Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
20b1d690
Commit
20b1d690
authored
Sep 20, 2019
by
Paul
Browse files
Merge branch 'develop' into tests
parents
17aaaa1e
ba729cfc
Changes
281
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2450 additions
and
544 deletions
+2450
-544
src/targets/gpu/int8_gemm_pack.cpp
src/targets/gpu/int8_gemm_pack.cpp
+37
-0
src/targets/gpu/logsoftmax.cpp
src/targets/gpu/logsoftmax.cpp
+4
-4
src/targets/gpu/lowering.cpp
src/targets/gpu/lowering.cpp
+81
-5
src/targets/gpu/pack_int8_args.cpp
src/targets/gpu/pack_int8_args.cpp
+74
-0
src/targets/gpu/quant_convolution.cpp
src/targets/gpu/quant_convolution.cpp
+125
-0
src/targets/gpu/relu.cpp
src/targets/gpu/relu.cpp
+0
-36
src/targets/gpu/sigmoid.cpp
src/targets/gpu/sigmoid.cpp
+0
-36
src/targets/gpu/softmax.cpp
src/targets/gpu/softmax.cpp
+13
-0
src/targets/gpu/tanh.cpp
src/targets/gpu/tanh.cpp
+0
-36
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+21
-7
src/targets/gpu/write_literals.cpp
src/targets/gpu/write_literals.cpp
+1
-1
src/tf/CMakeLists.txt
src/tf/CMakeLists.txt
+2
-1
src/tf/tf.cpp
src/tf/tf.cpp
+549
-297
test/CMakeLists.txt
test/CMakeLists.txt
+2
-2
test/cpu_dot_op_test.cpp
test/cpu_dot_op_test.cpp
+390
-0
test/cpu_ops_test.cpp
test/cpu_ops_test.cpp
+540
-59
test/eliminate_common_subexpression_test.cpp
test/eliminate_common_subexpression_test.cpp
+2
-2
test/eliminate_contiguous_test.cpp
test/eliminate_contiguous_test.cpp
+37
-3
test/eliminate_pad_test.cpp
test/eliminate_pad_test.cpp
+0
-19
test/gpu/ops_test.cpp
test/gpu/ops_test.cpp
+572
-36
No files found.
src/targets/gpu/int8_gemm_pack.cpp
0 → 100644
View file @
20b1d690
#include <migraphx/gpu/int8_gemm_pack.hpp>
#include <migraphx/gpu/device/int8_gemm_pack.hpp>
#include <migraphx/gpu/context.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
shape
hip_int8_gemm_pack_a
::
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
check_shapes
{{
inputs
.
at
(
0
)},
*
this
}.
has
(
1
).
not_broadcasted
().
packed
();
return
inputs
.
at
(
0
);
}
argument
hip_int8_gemm_pack_a
::
compute
(
context
&
ctx
,
const
shape
&
,
const
std
::
vector
<
argument
>&
args
)
const
{
device
::
int8_gemm_pack_a
(
ctx
.
get_stream
().
get
(),
args
[
1
],
args
[
0
]);
return
args
[
1
];
}
shape
hip_int8_gemm_pack_b
::
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
check_shapes
{{
inputs
.
at
(
0
)},
*
this
}.
has
(
1
).
not_broadcasted
().
packed
();
return
inputs
.
at
(
0
);
}
argument
hip_int8_gemm_pack_b
::
compute
(
context
&
ctx
,
const
shape
&
,
const
std
::
vector
<
argument
>&
args
)
const
{
device
::
int8_gemm_pack_b
(
ctx
.
get_stream
().
get
(),
args
[
1
],
args
[
0
]);
return
args
[
1
];
}
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/logsoftmax.cpp
View file @
20b1d690
...
...
@@ -15,11 +15,11 @@ shape hip_logsoftmax::compute_shape(const std::vector<shape>& inputs) const
return
op
.
compute_shape
({
inputs
.
at
(
0
)});
}
argument
hip_logsoftmax
::
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
argument
hip_logsoftmax
::
compute
(
context
&
ctx
,
const
shape
&
,
const
std
::
vector
<
argument
>&
args
)
const
{
return
device
::
logsoftmax
(
ctx
.
get_stream
().
get
(),
output_shape
,
args
,
op
.
axis
);
device
::
logsoftmax
(
ctx
.
get_stream
().
get
(),
args
.
back
(),
args
.
front
(),
op
.
axis
);
return
args
.
back
();
}
}
// namespace gpu
...
...
src/targets/gpu/lowering.cpp
View file @
20b1d690
...
...
@@ -11,9 +11,12 @@
#include <migraphx/gpu/device/contiguous.hpp>
#include <migraphx/gpu/device/add.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/gpu/argmax.hpp>
#include <migraphx/gpu/argmin.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/convolution.hpp>
#include <migraphx/gpu/quant_convolution.hpp>
#include <migraphx/gpu/contiguous.hpp>
#include <migraphx/gpu/relu.hpp>
#include <migraphx/gpu/sigmoid.hpp>
...
...
@@ -24,9 +27,12 @@
#include <migraphx/gpu/logsoftmax.hpp>
#include <migraphx/gpu/add.hpp>
#include <migraphx/gpu/sub.hpp>
#include <migraphx/gpu/div.hpp>
#include <migraphx/gpu/exp.hpp>
#include <migraphx/gpu/erf.hpp>
#include <migraphx/gpu/log.hpp>
#include <migraphx/gpu/sin.hpp>
#include <migraphx/gpu/sign.hpp>
#include <migraphx/gpu/cos.hpp>
#include <migraphx/gpu/tan.hpp>
#include <migraphx/gpu/sinh.hpp>
...
...
@@ -47,6 +53,14 @@
#include <migraphx/gpu/lrn.hpp>
#include <migraphx/gpu/convert.hpp>
#include <migraphx/gpu/clip.hpp>
#include <migraphx/gpu/reduce_sum.hpp>
#include <migraphx/gpu/round.hpp>
#include <migraphx/gpu/rsqrt.hpp>
#include <migraphx/gpu/sqrt.hpp>
#include <migraphx/gpu/reduce_mean.hpp>
#include <migraphx/gpu/pow.hpp>
#include <migraphx/gpu/sqdiff.hpp>
#include <migraphx/gpu/int8_conv_pack.hpp>
#include <utility>
#include <functional>
#include <algorithm>
...
...
@@ -72,10 +86,8 @@ struct miopen_apply
void
init
()
{
this
->
last
=
instruction
::
get_output_alias
(
std
::
prev
(
prog
->
end
()));
add_miopen_simple_op
<
miopen_relu
>
(
"relu"
,
make_relu
);
add_miopen_simple_op
<
miopen_sigmoid
>
(
"sigmoid"
,
make_sigmoid
);
add_miopen_simple_op
<
miopen_abs
>
(
"abs"
,
make_abs
);
add_miopen_simple_op
<
miopen_tanh
>
(
"tanh"
,
make_tanh
);
add_miopen_extend_op
<
miopen_leaky_relu
,
op
::
leaky_relu
>
(
"leaky_relu"
,
make_leaky_relu
);
add_miopen_extend_op
<
miopen_elu
,
op
::
elu
>
(
"elu"
,
make_elu
);
...
...
@@ -83,31 +95,48 @@ struct miopen_apply
add_generic_op
<
hip_add
>
(
"add"
);
add_generic_op
<
hip_sub
>
(
"sub"
);
add_generic_op
<
hip_exp
>
(
"exp"
);
add_generic_op
<
hip_erf
>
(
"erf"
);
add_generic_op
<
hip_log
>
(
"log"
);
add_generic_op
<
hip_sin
>
(
"sin"
);
add_generic_op
<
hip_cos
>
(
"cos"
);
add_generic_op
<
hip_tan
>
(
"tan"
);
add_generic_op
<
hip_sinh
>
(
"sinh"
);
add_generic_op
<
hip_cosh
>
(
"cosh"
);
add_generic_op
<
hip_tanh
>
(
"tanh"
);
add_generic_op
<
hip_asin
>
(
"asin"
);
add_generic_op
<
hip_acos
>
(
"acos"
);
add_generic_op
<
hip_atan
>
(
"atan"
);
add_generic_op
<
hip_sqrt
>
(
"sqrt"
);
add_generic_op
<
hip_mul
>
(
"mul"
);
add_generic_op
<
hip_div
>
(
"div"
);
add_generic_op
<
hip_max
>
(
"max"
);
add_generic_op
<
hip_min
>
(
"min"
);
add_generic_op
<
hip_rsqrt
>
(
"rsqrt"
);
add_generic_op
<
hip_round
>
(
"round"
);
add_generic_op
<
hip_pow
>
(
"pow"
);
add_generic_op
<
hip_sqdiff
>
(
"sqdiff"
);
add_generic_op
<
hip_relu
>
(
"relu"
);
add_generic_op
<
hip_sign
>
(
"sign"
);
add_generic_op
<
hip_sigmoid
>
(
"sigmoid"
);
add_extend_op
<
miopen_gemm
,
op
::
dot
>
(
"dot"
);
add_extend_op
<
miopen_contiguous
,
op
::
contiguous
>
(
"contiguous"
);
add_extend_op
<
hip_concat
,
op
::
concat
>
(
"concat"
);
add_extend_op
<
miopen
_softmax
,
op
::
softmax
>
(
"softmax"
);
add_extend_op
<
hip
_softmax
,
op
::
softmax
>
(
"softmax"
);
add_extend_op
<
hip_logsoftmax
,
op
::
logsoftmax
>
(
"logsoftmax"
);
add_extend_op
<
hip_argmax
,
op
::
argmax
>
(
"argmax"
);
add_extend_op
<
hip_argmin
,
op
::
argmin
>
(
"argmin"
);
add_extend_op
<
hip_gather
,
op
::
gather
>
(
"gather"
);
add_extend_op
<
hip_pad
,
op
::
pad
>
(
"pad"
);
add_extend_op
<
hip_convert
,
op
::
convert
>
(
"convert"
);
add_extend_op
<
hip_clip
,
op
::
clip
>
(
"clip"
);
add_extend_op
<
hip_reduce_sum
,
op
::
reduce_sum
>
(
"reduce_sum"
);
add_extend_op
<
hip_reduce_mean
,
op
::
reduce_mean
>
(
"reduce_mean"
);
add_gemm_op
<
op
::
dot
>
(
"dot"
);
add_gemm_op
<
op
::
quant_dot
>
(
"quant_dot"
);
add_lrn_op
();
add_convolution_op
();
add_quant_convolution_op
();
add_pooling_op
();
add_batch_norm_inference_op
();
}
...
...
@@ -154,6 +183,53 @@ struct miopen_apply
});
}
template
<
class
Op
>
void
add_gemm_op
(
std
::
string
name
)
{
apply_map
.
emplace
(
name
,
[
=
](
instruction_ref
ins
)
{
auto
&&
op
=
any_cast
<
Op
>
(
ins
->
get_operator
());
auto
beta
=
op
.
beta
;
std
::
vector
<
instruction_ref
>
refs
=
ins
->
inputs
();
if
((
refs
.
size
()
==
2
)
or
(
refs
.
size
()
==
3
and
refs
.
back
()
->
outputs
().
size
()
>
1
)
or
(
ins
==
last
))
{
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
());
if
(
refs
.
size
()
==
2
)
{
beta
=
0
;
refs
.
push_back
(
output
);
}
else
{
auto
copy_out
=
prog
->
insert_instruction
(
ins
,
hip_copy
{},
refs
.
back
(),
output
);
refs
.
back
()
=
copy_out
;
refs
.
push_back
(
copy_out
);
}
}
else
{
refs
.
push_back
(
refs
.
back
());
}
return
prog
->
replace_instruction
(
ins
,
rocblas_gemm
<
Op
>
{
Op
{
op
.
alpha
,
beta
}},
refs
);
});
}
void
add_quant_convolution_op
()
{
apply_map
.
emplace
(
"quant_convolution"
,
[
=
](
instruction_ref
ins
)
{
auto
&&
op
=
any_cast
<
op
::
quant_convolution
>
(
ins
->
get_operator
());
auto
conv
=
miopen_quant_convolution
{
op
,
make_conv
(
op
)};
auto
ws
=
conv
.
compile
(
ctx
,
ins
->
get_shape
(),
to_shapes
(
ins
->
inputs
()));
auto
args
=
ins
->
inputs
();
auto
workspace
=
insert_allocation
(
ins
,
ws
,
"workspace"
);
auto
output
=
insert_allocation
(
ins
,
ins
->
get_shape
());
return
prog
->
replace_instruction
(
ins
,
conv
,
args
[
0
],
args
[
1
],
workspace
,
output
);
});
}
void
add_pooling_op
()
{
apply_map
.
emplace
(
"pooling"
,
[
=
](
instruction_ref
ins
)
{
...
...
src/targets/gpu/pack_int8_args.cpp
0 → 100644
View file @
20b1d690
#include <migraphx/gpu/pack_int8_args.hpp>
#include <migraphx/gpu/int8_gemm_pack.hpp>
#include <migraphx/gpu/int8_conv_pack.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
void
pack_int8_args
::
apply
(
program
&
p
)
const
{
for
(
auto
ins
:
iterator_for
(
p
))
{
if
(
ins
->
name
()
==
"gpu::quant_gemm"
)
{
auto
inputs
=
ins
->
inputs
();
bool
transa
=
inputs
[
0
]
->
get_shape
().
transposed
();
bool
transb
=
inputs
[
1
]
->
get_shape
().
transposed
();
if
(
!
transb
)
{
auto
packed_b
=
p
.
insert_instruction
(
ins
,
hip_allocate
{
inputs
[
1
]
->
get_shape
()});
auto
output_b
=
p
.
insert_instruction
(
ins
,
hip_int8_gemm_pack_a
{},
{
inputs
[
1
],
packed_b
});
instruction
::
replace_argument
(
ins
,
inputs
[
1
],
output_b
);
}
if
(
transa
)
{
auto
packed_a
=
p
.
insert_instruction
(
ins
,
hip_allocate
{
inputs
[
0
]
->
get_shape
()});
auto
output_a
=
p
.
insert_instruction
(
ins
,
hip_int8_gemm_pack_b
{},
{
inputs
[
0
],
packed_a
});
instruction
::
replace_argument
(
ins
,
inputs
[
0
],
output_a
);
}
}
else
if
(
ins
->
name
()
==
"gpu::quant_convolution"
)
{
auto
inputs
=
ins
->
inputs
();
auto
packed_x
=
p
.
insert_instruction
(
ins
,
hip_allocate
{
pack_int8_shape
(
inputs
[
0
]
->
get_shape
())});
auto
output_x
=
p
.
insert_instruction
(
ins
,
miopen_int8_conv_pack
{},
{
inputs
[
0
],
packed_x
});
instruction
::
replace_argument
(
ins
,
inputs
[
0
],
output_x
);
auto
packed_w
=
p
.
insert_instruction
(
ins
,
hip_allocate
{
pack_int8_shape
(
inputs
[
1
]
->
get_shape
())});
auto
output_w
=
p
.
insert_instruction
(
ins
,
miopen_int8_conv_pack
{},
{
inputs
[
1
],
packed_w
});
instruction
::
replace_argument
(
ins
,
inputs
[
1
],
output_w
);
}
}
}
shape
pack_int8_args
::
pack_int8_shape
(
const
shape
&
s
)
const
{
if
(
s
.
type
()
!=
shape
::
int8_type
)
{
MIGRAPHX_THROW
(
"PACK_INT8_ARGS: only process int8_type"
);
}
auto
lens
=
s
.
lens
();
auto
strides
=
s
.
strides
();
lens
[
1
]
=
(
lens
[
1
]
+
3
)
/
4
*
4
;
strides
[
0
]
=
strides
[
1
]
*
lens
[
1
];
return
{
s
.
type
(),
lens
,
strides
};
}
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/quant_convolution.cpp
0 → 100644
View file @
20b1d690
#include <migraphx/gpu/quant_convolution.hpp>
#include <migraphx/gpu/device/convert.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/generate.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
shape
miopen_quant_convolution
::
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
4
).
standard
();
return
op
.
compute_shape
({
inputs
.
at
(
0
),
inputs
.
at
(
1
)});
}
argument
miopen_quant_convolution
::
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
{
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
(),
true
);
auto
w_desc
=
make_tensor
(
args
[
1
].
get_shape
(),
true
);
auto
y_desc
=
make_tensor
(
output_shape
);
float
alpha
=
1
;
float
beta
=
0
;
auto
status
=
miopenConvolutionForward
(
ctx
.
get_stream
().
get_miopen
(),
&
alpha
,
x_desc
.
get
(),
args
[
0
].
implicit
(),
w_desc
.
get
(),
args
[
1
].
implicit
(),
cd
.
get
(),
algo
,
&
beta
,
y_desc
.
get
(),
args
[
3
].
implicit
(),
args
[
2
].
implicit
(),
args
[
2
].
get_shape
().
bytes
());
if
(
status
!=
miopenStatusSuccess
)
{
MIGRAPHX_THROW
(
"QUANT_CONVOLUTION: run convolution forward failed"
);
}
return
args
[
3
];
}
shape
miopen_quant_convolution
::
compile
(
context
&
ctx
,
const
shape
&
output_shape
,
std
::
vector
<
shape
>
inputs
)
{
shape
workspace_shape
{};
auto
x_desc
=
make_tensor
(
inputs
[
0
],
true
);
auto
w_desc
=
make_tensor
(
inputs
[
1
],
true
);
auto
y_desc
=
make_tensor
(
output_shape
);
std
::
size_t
workspace_size
=
0
;
miopenConvolutionForwardGetWorkSpaceSize
(
ctx
.
get_stream
().
get_miopen
(),
w_desc
.
get
(),
x_desc
.
get
(),
cd
.
get
(),
y_desc
.
get
(),
&
workspace_size
);
workspace_shape
=
shape
{
shape
::
int8_type
,
{
workspace_size
}};
auto
arg_vec4_x
=
to_gpu
(
generate_argument
(
pack_int8_shape
(
inputs
[
0
])));
auto
arg_vec4_w
=
to_gpu
(
generate_argument
(
pack_int8_shape
(
inputs
[
1
])));
auto
y
=
allocate_gpu
(
output_shape
);
auto
workspace
=
allocate_gpu
(
workspace_shape
);
int
algo_count
=
1
;
miopenConvAlgoPerf_t
perf
;
auto
status
=
miopenFindConvolutionForwardAlgorithm
(
ctx
.
get_stream
().
get_miopen
(),
x_desc
.
get
(),
arg_vec4_x
.
implicit
(),
w_desc
.
get
(),
arg_vec4_w
.
implicit
(),
cd
.
get
(),
y_desc
.
get
(),
y
.
implicit
(),
1
,
&
algo_count
,
&
perf
,
workspace
.
implicit
(),
workspace_size
,
false
);
if
(
status
!=
miopenStatusSuccess
)
{
MIGRAPHX_THROW
(
"QUANT_CONVOLUTION: find convolution failed"
);
}
handle
=
ctx
.
get_stream
().
get_miopen
();
algo
=
perf
.
fwd_algo
;
return
shape
{
shape
::
int8_type
,
{
perf
.
memory
}};
}
void
miopen_quant_convolution
::
finalize
(
context
&
ctx
,
const
shape
&
output_shape
,
std
::
vector
<
shape
>
inputs
)
{
if
(
handle
==
ctx
.
get_stream
().
get_miopen
())
return
;
// Check that workspace hasn't changed
auto
size
=
inputs
.
at
(
2
).
bytes
();
auto
ws
=
compile
(
ctx
,
output_shape
,
std
::
move
(
inputs
));
if
(
ws
.
bytes
()
>
size
)
MIGRAPHX_THROW
(
"Workspace has changed during finalization."
);
}
shape
miopen_quant_convolution
::
pack_int8_shape
(
const
shape
&
s
)
const
{
if
(
s
.
type
()
!=
shape
::
int8_type
)
{
MIGRAPHX_THROW
(
"PACK_INT8_SHAPE: only process int8_type"
);
}
auto
lens
=
s
.
lens
();
auto
strides
=
s
.
strides
();
lens
[
1
]
=
(
lens
[
1
]
+
3
)
/
4
*
4
;
strides
[
0
]
=
strides
[
1
]
*
lens
[
1
];
return
{
s
.
type
(),
lens
,
strides
};
}
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/relu.cpp
deleted
100644 → 0
View file @
17aaaa1e
#include <migraphx/gpu/relu.hpp>
#include <migraphx/gpu/context.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
shape
miopen_relu
::
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
2
).
not_broadcasted
();
return
inputs
.
at
(
1
);
}
argument
miopen_relu
::
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
{
float
alpha
=
1
;
float
beta
=
0
;
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
y_desc
=
make_tensor
(
output_shape
);
miopenActivationForward
(
ctx
.
get_stream
().
get_miopen
(),
ad
.
get
(),
&
alpha
,
x_desc
.
get
(),
args
[
0
].
implicit
(),
&
beta
,
y_desc
.
get
(),
args
[
1
].
implicit
());
return
args
[
1
];
}
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/sigmoid.cpp
deleted
100644 → 0
View file @
17aaaa1e
#include <migraphx/gpu/sigmoid.hpp>
#include <migraphx/gpu/context.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
shape
miopen_sigmoid
::
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
2
).
not_broadcasted
();
return
inputs
.
at
(
1
);
}
argument
miopen_sigmoid
::
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
{
float
alpha
=
1
;
float
beta
=
0
;
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
y_desc
=
make_tensor
(
output_shape
);
miopenActivationForward
(
ctx
.
get_stream
().
get_miopen
(),
ad
.
get
(),
&
alpha
,
x_desc
.
get
(),
args
[
0
].
implicit
(),
&
beta
,
y_desc
.
get
(),
args
[
1
].
implicit
());
return
args
[
1
];
}
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/softmax.cpp
View file @
20b1d690
#include <migraphx/gpu/softmax.hpp>
#include <migraphx/gpu/device/softmax.hpp>
#include <migraphx/gpu/context.hpp>
namespace
migraphx
{
...
...
@@ -30,6 +31,18 @@ argument miopen_softmax::compute(context& ctx,
return
args
[
1
];
}
shape
hip_softmax
::
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
2
).
standard
();
return
op
.
compute_shape
({
inputs
.
at
(
0
)});
}
argument
hip_softmax
::
compute
(
context
&
ctx
,
const
shape
&
,
const
std
::
vector
<
argument
>&
args
)
const
{
device
::
softmax
(
ctx
.
get_stream
().
get
(),
args
.
back
(),
args
.
front
(),
op
.
axis
);
return
args
.
back
();
}
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/tanh.cpp
deleted
100644 → 0
View file @
17aaaa1e
#include <migraphx/gpu/tanh.hpp>
#include <migraphx/gpu/context.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
shape
miopen_tanh
::
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
2
).
packed
();
return
inputs
.
at
(
0
);
}
argument
miopen_tanh
::
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
{
float
alpha
=
1
;
float
beta
=
0
;
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
y_desc
=
make_tensor
(
output_shape
);
miopenActivationForward
(
ctx
.
get_stream
().
get_miopen
(),
ad
.
get
(),
&
alpha
,
x_desc
.
get
(),
args
[
0
].
implicit
(),
&
beta
,
y_desc
.
get
(),
args
[
1
].
implicit
());
return
args
[
1
];
}
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/target.cpp
View file @
20b1d690
...
...
@@ -13,14 +13,16 @@
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/propagate_constant.hpp>
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/common_subexpression
_elimination
.hpp>
#include <migraphx/
fwd_conv
_batchnorm
_rewrite
.hpp>
#include <migraphx/
eliminate_
common_subexpression.hpp>
#include <migraphx/
rewrite
_batchnorm.hpp>
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/rewrite_pooling.hpp>
#include <migraphx/eliminate_concat.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/gpu/concat_gpu_opt.hpp>
#include <migraphx/gpu/schedule_model.hpp>
#include <migraphx/gpu/adjust_allocation.hpp>
#include <migraphx/gpu/pack_int8_args.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/schedule.hpp>
...
...
@@ -36,23 +38,26 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
// clang-format off
return
{
dead_code_elimination
{},
simplify_reshapes
{},
dead_code_elimination
{},
eliminate_identity
{},
eliminate_pad
{},
dead_code_elimination
{},
fwd_conv
_batchnorm
_rewrite
{},
rewrite
_batchnorm
{},
dead_code_elimination
{},
rewrite_rnn
{},
rewrite_pooling
{},
dead_code_elimination
{},
//common_subexpression_elimination{},
//dead_code_elimination{},
simplify_algebra
{},
eliminate_common_subexpression
{},
dead_code_elimination
{},
propagate_constant
{},
simplify_algebra
{},
dead_code_elimination
{},
auto_contiguous
{},
simplify_reshapes
{},
dead_code_elimination
{},
propagate_constant
{},
dead_code_elimination
{},
lowering
{
ctx
},
eliminate_concat
{
concat_gpu_optimization
{}},
dead_code_elimination
{},
...
...
@@ -60,6 +65,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
dead_code_elimination
{},
adjust_allocation
{},
dead_code_elimination
{},
pack_int8_args
{},
dead_code_elimination
{},
fuse_ops
{
&
ctx
},
dead_code_elimination
{},
write_literals
{
&
ctx
},
...
...
@@ -78,6 +85,13 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
std
::
string
target
::
name
()
const
{
return
"miopen"
;
}
migraphx
::
context
target
::
get_context
()
const
{
return
context
{};
}
argument
target
::
copy_to
(
const
argument
&
arg
)
const
{
return
gpu
::
to_gpu
(
arg
);
}
argument
target
::
copy_from
(
const
argument
&
arg
)
const
{
return
gpu
::
from_gpu
(
arg
);
}
argument
target
::
allocate
(
const
shape
&
s
)
const
{
return
gpu
::
allocate_gpu
(
s
);
}
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/write_literals.cpp
View file @
20b1d690
...
...
@@ -45,7 +45,7 @@ void write_literals::apply(program& p) const
literal
l
=
ins
->
get_literal
();
auto
pre
=
p
.
add_literal
(
l
);
auto
alloc
=
p
.
insert_instruction
(
std
::
next
(
pre
),
hip_allocate
{
l
.
get_shape
()});
p
.
replace_instruction
(
ins
,
hip_copy
{},
pre
,
alloc
);
p
.
replace_instruction
(
ins
,
hip_copy
_to_gpu
{},
pre
,
alloc
);
}
else
{
...
...
src/tf/CMakeLists.txt
View file @
20b1d690
...
...
@@ -21,6 +21,7 @@ set_target_properties(tf-proto PROPERTIES POSITION_INDEPENDENT_CODE On)
add_library
(
migraphx_tf tf.cpp
)
set_target_properties
(
migraphx_tf PROPERTIES EXPORT_NAME tf
)
rocm_set_soversion
(
migraphx_tf
${
PROJECT_VERSION
}
)
rocm_clang_tidy_check
(
migraphx_tf
)
target_link_libraries
(
migraphx_tf PRIVATE tf-proto
)
target_link_libraries
(
migraphx_tf PUBLIC migraphx
)
...
...
@@ -31,7 +32,7 @@ rocm_install_targets(
add_executable
(
read_tf read_tf.cpp
)
rocm_clang_tidy_check
(
read_tf
)
target_link_libraries
(
read_tf migraphx_tf
)
target_link_libraries
(
read_tf migraphx_tf
migraphx_cpu
)
if
(
MIGRAPHX_ENABLE_GPU
)
add_executable
(
verify_tf verify_tf.cpp
)
...
...
src/tf/tf.cpp
View file @
20b1d690
...
...
@@ -17,6 +17,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/config.hpp>
#include <migraphx/tf.hpp>
#include <migraphx/pad_calc.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -24,8 +25,7 @@ inline namespace MIGRAPHX_INLINE_NS {
struct
tf_parser
{
using
attribute_map
=
std
::
unordered_map
<
std
::
string
,
tensorflow
::
AttrValue
>
;
using
node_map
=
std
::
unordered_map
<
std
::
string
,
tensorflow
::
NodeDef
>
;
// using input_node_map = std::unordered_map<std::string, std::unordered_set<std::string>>;
using
node_map
=
std
::
map
<
std
::
string
,
tensorflow
::
NodeDef
>
;
using
op_func
=
std
::
function
<
instruction_ref
(
attribute_map
,
std
::
vector
<
instruction_ref
>
)
>
;
node_map
nodes
;
...
...
@@ -36,7 +36,50 @@ struct tf_parser
std
::
unordered_map
<
std
::
string
,
op_func
>
ops
;
std
::
vector
<
size_t
>
parse_axes
(
const
attribute_map
&
attributes
,
const
std
::
string
&
s
)
const
bool
should_transpose
(
instruction_ref
ins
)
const
{
return
is_nhwc
and
ins
->
get_shape
().
lens
().
size
()
==
4
;
}
instruction_ref
to_nhwc
(
instruction_ref
ins
)
{
if
(
should_transpose
(
ins
))
return
prog
.
add_instruction
(
op
::
transpose
{{
0
,
2
,
3
,
1
}},
ins
);
return
ins
;
}
instruction_ref
to_nchw
(
instruction_ref
ins
)
{
if
(
should_transpose
(
ins
))
return
prog
.
add_instruction
(
op
::
transpose
{{
0
,
3
,
1
,
2
}},
ins
);
return
ins
;
}
instruction_ref
to_kcxy
(
instruction_ref
ins
)
{
if
(
should_transpose
(
ins
))
return
prog
.
add_instruction
(
op
::
transpose
{{
3
,
2
,
0
,
1
}},
ins
);
return
ins
;
}
instruction_ref
make_contiguous
(
instruction_ref
ins
)
{
if
(
ins
->
get_shape
().
standard
())
return
ins
;
else
return
prog
.
add_instruction
(
op
::
contiguous
{},
ins
);
}
std
::
vector
<
instruction_ref
>
to_nchw
(
const
std
::
vector
<
instruction_ref
>&
args
)
{
std
::
vector
<
instruction_ref
>
result
(
args
.
size
());
std
::
transform
(
args
.
begin
(),
args
.
end
(),
result
.
begin
(),
[
&
](
auto
ins
)
{
return
this
->
to_nchw
(
ins
);
});
return
result
;
}
std
::
vector
<
size_t
>
parse_axes
(
const
attribute_map
&
attributes
,
const
std
::
string
&
s
,
const
size_t
num_dims
)
const
{
auto
attrs
=
attributes
.
at
(
s
).
list
().
i
();
std
::
vector
<
size_t
>
axes
;
...
...
@@ -44,14 +87,14 @@ struct tf_parser
if
(
is_nhwc
)
{
std
::
transform
(
axes
.
begin
(),
axes
.
end
(),
axes
.
begin
(),
[
&
](
size_t
axis
)
{
return
parse_axis
(
axis
);
return
parse_axis
(
axis
,
num_dims
);
});
}
return
axes
;
}
template
<
class
T
>
std
::
vector
<
T
>
parse_axes
(
std
::
vector
<
T
>
axes
)
const
std
::
vector
<
T
>
parse_axes
(
std
::
vector
<
T
>
axes
,
const
size_t
num_dims
)
const
{
if
(
is_nhwc
)
{
...
...
@@ -59,7 +102,7 @@ struct tf_parser
std
::
transform
(
axes
.
begin
(),
axes
.
end
(),
std
::
back_inserter
(
new_axes
),
[
&
](
size_t
axis
)
{
return
parse_axis
(
axis
);
});
[
&
](
size_t
axis
)
{
return
parse_axis
(
axis
,
num_dims
);
});
return
new_axes
;
}
return
axes
;
...
...
@@ -74,17 +117,17 @@ struct tf_parser
std
::
vector
<
T
>
new_data
(
prev_data
.
size
());
for
(
size_t
i
=
0
;
i
<
new_data
.
size
();
i
++
)
{
auto
new_idx
=
parse_axis
(
i
);
auto
new_idx
=
parse_axis
(
i
,
new_data
.
size
()
);
new_data
.
at
(
new_idx
)
=
prev_data
.
at
(
i
);
}
prev_data
=
new_data
;
}
template
<
class
T
>
T
parse_axis
(
const
T
&
dim
)
const
T
parse_axis
(
const
T
&
dim
,
const
size_t
num_dims
)
const
{
T
new_dim
=
dim
;
if
(
is_nhwc
)
if
(
is_nhwc
and
num_dims
>=
4
)
{
switch
(
dim
)
{
...
...
@@ -105,70 +148,109 @@ struct tf_parser
return
axes
;
}
std
::
vector
<
int64_t
>
get_axes_from_mask
(
const
size_t
num_axes
,
const
uint32_t
mask
)
{
uint32_t
bitwise_compare
=
1
;
std
::
vector
<
int64_t
>
axes
;
for
(
size_t
i
=
0
;
i
<
num_axes
;
i
++
)
{
// the LSB corresponds to axis 0 when determining which axes to begin
if
(((
mask
>>
i
)
&
bitwise_compare
)
==
1
)
axes
.
push_back
(
1
);
else
axes
.
push_back
(
0
);
}
return
axes
;
}
tf_parser
()
{
add_generic_op
(
"All"
,
op
::
identity
{});
add_generic_op
(
"Identity"
,
op
::
identity
{});
add_generic_op
(
"LessEqual"
,
op
::
identity
{});
add_generic_op
(
"Relu"
,
op
::
relu
{});
add_generic_op
(
"Relu6"
,
op
::
clip
{
6.0
,
0.0
});
add_generic_op
(
"Rsqrt"
,
op
::
rsqrt
{});
add_generic_op
(
"Tanh"
,
op
::
tanh
{});
add_generic_op
(
"StopGradient"
,
op
::
identity
{});
add_binary_op
(
"Add"
,
op
::
add
{});
add_binary_op
(
"Mul"
,
op
::
mul
{});
add_binary_op
(
"Pow"
,
op
::
pow
{});
add_binary_op
(
"SquaredDifference"
,
op
::
sqdiff
{});
add_binary_op
(
"Sub"
,
op
::
sub
{});
add_mem_op
(
"AvgPool"
,
&
tf_parser
::
parse_pooling
);
add_mem_op
(
"BatchMatMul"
,
&
tf_parser
::
parse_matmul
,
false
);
add_mem_op
(
"BatchMatMulV2"
,
&
tf_parser
::
parse_matmul
,
false
);
add_mem_op
(
"BiasAdd"
,
&
tf_parser
::
parse_biasadd
);
add_mem_op
(
"ConcatV2"
,
&
tf_parser
::
parse_concat
);
add_mem_op
(
"Cast"
,
&
tf_parser
::
parse_cast
,
false
);
add_mem_op
(
"ConcatV2"
,
&
tf_parser
::
parse_concat
,
false
);
add_mem_op
(
"Const"
,
&
tf_parser
::
parse_constant
);
add_mem_op
(
"Conv2D"
,
&
tf_parser
::
parse_conv
);
add_mem_op
(
"DepthwiseConv2dNative"
,
&
tf_parser
::
parse_depthwiseconv
);
add_mem_op
(
"ExpandDims"
,
&
tf_parser
::
parse_expanddims
,
false
);
add_mem_op
(
"FusedBatchNorm"
,
&
tf_parser
::
parse_batchnorm
);
add_mem_op
(
"MatMul"
,
&
tf_parser
::
parse_matmul
);
add_mem_op
(
"GatherV2"
,
&
tf_parser
::
parse_gather
,
false
);
add_mem_op
(
"MatMul"
,
&
tf_parser
::
parse_matmul
,
false
);
add_mem_op
(
"MaxPool"
,
&
tf_parser
::
parse_pooling
);
add_mem_op
(
"Mean"
,
&
tf_parser
::
parse_mean
);
add_mem_op
(
"Pack"
,
&
tf_parser
::
parse_pack
);
add_mem_op
(
"Mean"
,
&
tf_parser
::
parse_mean
,
false
);
add_mem_op
(
"OneHot"
,
&
tf_parser
::
parse_onehot
,
false
);
add_mem_op
(
"Pack"
,
&
tf_parser
::
parse_pack
,
false
);
add_mem_op
(
"Pad"
,
&
tf_parser
::
parse_pad
);
add_mem_op
(
"Reshape"
,
&
tf_parser
::
parse_reshape
);
add_mem_op
(
"Softmax"
,
&
tf_parser
::
parse_softmax
);
add_mem_op
(
"Squeeze"
,
&
tf_parser
::
parse_squeeze
);
add_mem_op
(
"StridedSlice"
,
&
tf_parser
::
parse_stridedslice
);
}
template
<
class
F
>
void
add_op
(
std
::
string
name
,
F
f
)
{
ops
.
emplace
(
name
,
f
);
add_mem_op
(
"Reshape"
,
&
tf_parser
::
parse_reshape
,
false
);
add_mem_op
(
"Slice"
,
&
tf_parser
::
parse_slice
,
false
);
add_mem_op
(
"Softmax"
,
&
tf_parser
::
parse_softmax
<
op
::
softmax
>
,
false
);
add_mem_op
(
"Squeeze"
,
&
tf_parser
::
parse_squeeze
,
false
);
add_mem_op
(
"StridedSlice"
,
&
tf_parser
::
parse_stridedslice
,
false
);
add_mem_op
(
"Transpose"
,
&
tf_parser
::
parse_transpose
,
false
);
}
// Multi output op
template
<
class
F
>
void
add_
multi_
op
(
std
::
string
name
,
F
f
)
void
add_op
(
std
::
string
name
,
F
f
,
bool
transpose
=
true
)
{
ops
.
emplace
(
name
,
f
);
if
(
transpose
)
{
ops
.
emplace
(
name
,
op_func
{[
=
](
const
attribute_map
&
attributes
,
const
std
::
vector
<
instruction_ref
>&
args
)
->
instruction_ref
{
return
to_nhwc
(
f
(
attributes
,
to_nchw
(
args
)));
}});
}
else
{
ops
.
emplace
(
name
,
f
);
}
}
template
<
class
F
>
void
add_mem_op
(
std
::
string
name
,
F
f
)
void
add_mem_op
(
std
::
string
name
,
F
f
,
bool
transpose
=
true
)
{
add_op
(
name
,
[
=
](
auto
&&
...
xs
)
{
return
std
::
mem_fn
(
f
)(
*
this
,
name
,
std
::
forward
<
decltype
(
xs
)
>
(
xs
)...);
});
add_op
(
name
,
[
=
](
auto
&&
...
xs
)
{
return
std
::
mem_fn
(
f
)(
*
this
,
name
,
std
::
forward
<
decltype
(
xs
)
>
(
xs
)...);
},
transpose
);
}
template
<
class
T
>
void
add_binary_op
(
std
::
string
name
,
T
x
)
{
add_op
(
name
,
[
this
,
x
](
const
attribute_map
&
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
if
(
args
.
size
()
!=
2
)
MIGRAPHX_THROW
(
"binary operators should have 2 operands"
);
auto
l0
=
args
[
1
];
if
(
contains
(
attributes
,
"data_format"
))
{
if
(
is_nhwc
)
{
l0
=
prog
.
add_instruction
(
op
::
transpose
{{
0
,
3
,
1
,
2
}},
args
[
1
]);
}
}
return
add_broadcastable_binary_op
(
args
[
0
],
l0
,
x
);
});
add_op
(
name
,
[
this
,
x
](
const
attribute_map
&
,
std
::
vector
<
instruction_ref
>
args
)
{
if
(
args
.
size
()
!=
2
)
MIGRAPHX_THROW
(
"binary operators should have 2 operands"
);
// TODO
// if(contains(attributes, "data_format"))
// {
// if(is_nhwc)
// {
// l0 = prog.add_instruction(op::transpose{{0, 3, 1, 2}}, args[1]);
// }
// }
return
add_broadcastable_binary_op
(
args
[
0
],
args
[
1
],
x
);
},
false
);
}
template
<
class
T
>
...
...
@@ -207,20 +289,22 @@ struct tf_parser
auto
l0
=
prog
.
add_instruction
(
op
::
multibroadcast
{
output_lens
},
arg0
);
auto
l1
=
prog
.
add_instruction
(
op
::
multibroadcast
{
output_lens
},
arg1
);
return
prog
.
add_instruction
(
x
,
l0
,
l1
);
return
to_nhwc
(
prog
.
add_instruction
(
x
,
to_nchw
(
l0
)
,
to_nchw
(
l1
))
);
}
else
{
return
prog
.
add_instruction
(
x
,
{
arg0
,
arg1
}
);
return
to_nhwc
(
prog
.
add_instruction
(
x
,
{
to_nchw
(
arg0
)
,
to_nchw
(
arg1
)})
);
}
}
template
<
class
T
>
void
add_generic_op
(
std
::
string
name
,
T
x
)
void
add_generic_op
(
std
::
string
name
,
T
x
,
bool
transpose
=
true
)
{
add_op
(
name
,
[
this
,
x
](
const
attribute_map
&
,
std
::
vector
<
instruction_ref
>
args
)
{
return
prog
.
add_instruction
(
x
,
args
);
});
add_op
(
name
,
[
this
,
x
](
const
attribute_map
&
,
std
::
vector
<
instruction_ref
>
args
)
{
return
prog
.
add_instruction
(
x
,
args
);
},
transpose
);
}
instruction_ref
...
...
@@ -245,12 +329,19 @@ struct tf_parser
return
prog
.
add_instruction
(
op
::
add
{},
args
[
0
],
l0
);
}
instruction_ref
parse_cast
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
shape
::
type_t
type
=
parse_type
(
attributes
.
at
(
"DstT"
).
type
());
return
prog
.
add_instruction
(
op
::
convert
{
type
},
std
::
move
(
args
));
}
instruction_ref
parse_concat
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
// get index for axis within args
size_t
axis_idx
=
attributes
.
at
(
"N"
).
i
();
size_t
axis
=
parse_axis
(
args
[
axis_idx
]
->
eval
().
at
<
int64_t
>
()
)
;
size_t
axis
=
args
[
axis_idx
]
->
eval
().
at
<
int64_t
>
();
op
::
concat
op
{
axis
};
// return only first N arguments (assuming last index is the axis value)
return
prog
.
add_instruction
(
...
...
@@ -261,45 +352,14 @@ struct tf_parser
attribute_map
attributes
,
const
std
::
vector
<
instruction_ref
>&
)
{
literal
v
=
parse_tensor
(
attributes
.
at
(
"value"
).
tensor
());
auto
l0
=
prog
.
add_literal
(
v
);
size_t
num_axes
=
l0
->
get_shape
().
lens
().
size
();
if
(
num_axes
>=
4
)
{
std
::
vector
<
int64_t
>
transpose_axes
=
get_axes
(
num_axes
);
reorder_data
(
transpose_axes
);
l0
=
prog
.
add_instruction
(
op
::
transpose
{
transpose_axes
},
l0
);
}
return
l0
;
literal
v
=
parse_tensor
(
attributes
.
at
(
"value"
).
tensor
());
return
prog
.
add_literal
(
v
);
}
instruction_ref
parse_conv
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
op
::
convolution
op
;
if
(
contains
(
attributes
,
"padding"
))
{
const
std
::
string
&
pad_mode
=
attributes
.
at
(
"padding"
).
s
();
if
(
pad_mode
.
find
(
"SAME"
)
!=
std
::
string
::
npos
)
{
op
.
padding_mode
=
op
::
padding_mode_t
::
same
;
}
else
if
(
pad_mode
.
find
(
"EXPLICIT"
)
!=
std
::
string
::
npos
)
{
std
::
vector
<
size_t
>
padding
;
copy
(
attributes
.
at
(
"explicit_paddings"
).
list
().
i
(),
std
::
back_inserter
(
padding
));
if
(
padding
.
size
()
!=
4
)
{
MIGRAPHX_THROW
(
"padding should have 4 values"
);
}
if
(
padding
[
0
]
!=
padding
[
2
]
||
padding
[
1
]
!=
padding
[
3
])
{
MIGRAPHX_THROW
(
"migraphx does not support asymetric padding"
);
}
op
.
padding
[
0
]
=
padding
[
0
];
op
.
padding
[
1
]
=
padding
[
1
];
}
}
if
(
contains
(
attributes
,
"strides"
))
{
std
::
vector
<
size_t
>
stride
;
...
...
@@ -324,22 +384,58 @@ struct tf_parser
op
.
dilation
[
0
]
=
dilation
[
2
];
op
.
dilation
[
1
]
=
dilation
[
3
];
}
auto
weights
=
args
[
1
];
// check if weights are from a constant
if
(
weights
->
name
()
!=
"@param"
)
auto
weights
=
to_kcxy
(
args
[
1
]);
auto
l0
=
args
[
0
];
if
(
contains
(
attributes
,
"padding"
))
{
if
(
is_nhwc
)
const
std
::
string
&
pad_mode
=
attributes
.
at
(
"padding"
).
s
();
if
(
pad_mode
.
find
(
"SAME"
)
!=
std
::
string
::
npos
)
{
weights
=
prog
.
add_instruction
(
op
::
transpose
{{
1
,
3
,
0
,
2
}},
args
[
1
]);
op
.
padding_mode
=
op
::
padding_mode_t
::
same
;
std
::
vector
<
size_t
>
weight_dims
=
weights
->
get_shape
().
lens
();
size_t
weight_h
=
weight_dims
[
2
];
size_t
weight_w
=
weight_dims
[
3
];
auto
input_dims
=
l0
->
get_shape
().
lens
();
size_t
input_h
=
input_dims
[
2
];
size_t
input_w
=
input_dims
[
3
];
std
::
vector
<
int64_t
>
pads
(
input_dims
.
size
());
calculate_padding
(
0
,
pads
,
input_h
,
op
.
stride
[
0
],
op
.
dilation
[
0
],
weight_h
);
calculate_padding
(
1
,
pads
,
input_w
,
op
.
stride
[
1
],
op
.
dilation
[
1
],
weight_w
);
if
(
pads
[
0
]
!=
pads
[
2
]
||
pads
[
1
]
!=
pads
[
3
])
{
std
::
vector
<
int64_t
>
padding
=
{
0
,
0
,
pads
[
0
],
pads
[
1
],
0
,
0
,
pads
[
2
],
pads
[
3
]};
l0
=
prog
.
add_instruction
(
migraphx
::
op
::
pad
{
padding
},
l0
);
}
else
{
op
.
padding
[
0
]
=
pads
[
0
];
op
.
padding
[
1
]
=
pads
[
1
];
}
}
else
else
if
(
pad_mode
.
find
(
"VALID"
)
!=
std
::
string
::
npos
)
{
weights
=
prog
.
add_instruction
(
op
::
transpose
{{
3
,
2
,
0
,
1
}},
args
[
1
]);
op
.
padding_mode
=
op
::
padding_mode_t
::
valid
;
}
else
if
(
pad_mode
.
find
(
"EXPLICIT"
)
!=
std
::
string
::
npos
)
{
std
::
vector
<
size_t
>
padding
;
copy
(
attributes
.
at
(
"explicit_paddings"
).
list
().
i
(),
std
::
back_inserter
(
padding
));
if
(
padding
.
size
()
!=
4
)
{
MIGRAPHX_THROW
(
"padding should have 4 values"
);
}
if
(
padding
[
0
]
!=
padding
[
2
]
||
padding
[
1
]
!=
padding
[
3
])
{
MIGRAPHX_THROW
(
"migraphx does not support asymetric padding"
);
}
op
.
padding
[
0
]
=
padding
[
0
];
op
.
padding
[
1
]
=
padding
[
1
];
}
}
return
prog
.
add_instruction
(
op
,
{
args
[
0
],
weights
});
return
prog
.
add_instruction
(
op
,
{
l0
,
to_kcxy
(
args
[
1
])});
}
instruction_ref
parse_depthwiseconv
(
const
std
::
string
&
,
...
...
@@ -349,14 +445,7 @@ struct tf_parser
op
::
convolution
op
;
size_t
num_channels
=
args
[
0
]
->
get_shape
().
lens
()[
1
];
op
.
group
=
num_channels
;
if
(
contains
(
attributes
,
"padding"
))
{
const
std
::
string
&
pad_mode
=
attributes
.
at
(
"padding"
).
s
();
if
(
pad_mode
.
find
(
"SAME"
)
!=
std
::
string
::
npos
)
{
op
.
padding_mode
=
op
::
padding_mode_t
::
same
;
}
}
if
(
contains
(
attributes
,
"strides"
))
{
std
::
vector
<
size_t
>
stride
;
...
...
@@ -369,17 +458,54 @@ struct tf_parser
op
.
stride
[
0
]
=
stride
[
2
];
op
.
stride
[
1
]
=
stride
[
3
];
}
auto
weights
=
args
[
1
];
// check if weights are from a constant
if
(
weights
->
name
()
!=
"@param
"
)
auto
weights
=
to_kcxy
(
args
[
1
]);
if
(
contains
(
attributes
,
"dilations
"
)
)
{
if
(
is_nhwc
)
std
::
vector
<
size_t
>
dilation
;
copy
(
attributes
.
at
(
"dilations"
).
list
().
i
(),
std
::
back_inserter
(
dilation
));
reorder_data
(
dilation
);
if
(
dilation
.
size
()
!=
4
)
{
weights
=
prog
.
add_instruction
(
op
::
transpose
{{
1
,
3
,
0
,
2
}},
args
[
1
]
);
MIGRAPHX_THROW
(
"dilation should have 4 values"
);
}
else
op
.
dilation
[
0
]
=
dilation
[
2
];
op
.
dilation
[
1
]
=
dilation
[
3
];
}
auto
l0
=
args
[
0
];
if
(
contains
(
attributes
,
"padding"
))
{
const
std
::
string
&
pad_mode
=
attributes
.
at
(
"padding"
).
s
();
if
(
pad_mode
.
find
(
"SAME"
)
!=
std
::
string
::
npos
)
{
weights
=
prog
.
add_instruction
(
op
::
transpose
{{
3
,
2
,
0
,
1
}},
args
[
1
]);
op
.
padding_mode
=
op
::
padding_mode_t
::
same
;
std
::
vector
<
size_t
>
weight_dims
=
weights
->
get_shape
().
lens
();
size_t
weight_h
=
weight_dims
[
2
];
size_t
weight_w
=
weight_dims
[
3
];
auto
input_dims
=
l0
->
get_shape
().
lens
();
size_t
input_h
=
input_dims
[
2
];
size_t
input_w
=
input_dims
[
3
];
std
::
vector
<
int64_t
>
pads
(
input_dims
.
size
());
calculate_padding
(
0
,
pads
,
input_h
,
op
.
stride
[
0
],
op
.
dilation
[
0
],
weight_h
);
calculate_padding
(
1
,
pads
,
input_w
,
op
.
stride
[
1
],
op
.
dilation
[
1
],
weight_w
);
if
(
pads
[
0
]
!=
pads
[
2
]
||
pads
[
1
]
!=
pads
[
3
])
{
std
::
vector
<
int64_t
>
padding
=
{
0
,
0
,
pads
[
0
],
pads
[
1
],
0
,
0
,
pads
[
2
],
pads
[
3
]};
l0
=
prog
.
add_instruction
(
migraphx
::
op
::
pad
{
padding
},
l0
);
}
else
{
op
.
padding
[
0
]
=
pads
[
0
];
op
.
padding
[
1
]
=
pads
[
1
];
}
}
else
if
(
pad_mode
.
find
(
"VALID"
)
!=
std
::
string
::
npos
)
{
op
.
padding_mode
=
op
::
padding_mode_t
::
valid
;
}
}
...
...
@@ -394,10 +520,37 @@ struct tf_parser
new_weights_shape
[
0
]
=
out_channels
;
new_weights_shape
[
1
]
=
1
;
// Make sure weights are contiguous before doing reshape
auto
c
weights
=
prog
.
add_instruction
(
op
::
contiguous
{},
weights
);
auto
new_weights
=
prog
.
add_instruction
(
op
::
reshape
{
new_weights_shape
},
c
weights
);
auto
new_
weights
=
prog
.
add_instruction
(
op
::
reshape
{
new_weights_shape
},
make_contiguous
(
weights
)
)
;
return
prog
.
add_instruction
(
op
,
{
args
[
0
],
new_weights
});
return
prog
.
add_instruction
(
op
,
{
l0
,
new_weights
});
}
instruction_ref
parse_expanddims
(
const
std
::
string
&
,
const
attribute_map
&
,
std
::
vector
<
instruction_ref
>
args
)
{
std
::
vector
<
size_t
>
input_dims
=
args
[
0
]
->
get_shape
().
lens
();
std
::
vector
<
int64_t
>
new_dims
(
input_dims
.
begin
(),
input_dims
.
end
());
size_t
num_dims
=
input_dims
.
size
();
int32_t
dim
=
args
[
1
]
->
eval
().
at
<
int32_t
>
();
if
(
dim
<
0
)
{
new_dims
.
insert
(
new_dims
.
begin
()
+
(
num_dims
+
dim
+
1
),
1
);
}
else
{
new_dims
.
insert
(
new_dims
.
begin
()
+
dim
,
1
);
}
return
prog
.
add_instruction
(
op
::
reshape
{
new_dims
},
args
[
0
]);
}
instruction_ref
parse_gather
(
const
std
::
string
&
,
const
attribute_map
&
,
std
::
vector
<
instruction_ref
>
args
)
{
int
axis
=
args
[
2
]
->
eval
().
at
<
int32_t
>
();
op
::
gather
op
{
axis
};
return
prog
.
add_instruction
(
op
,
{
args
[
0
],
args
[
1
]});
}
instruction_ref
...
...
@@ -412,7 +565,16 @@ struct tf_parser
}
if
(
contains
(
attributes
,
"transpose_b"
))
{
transb
=
attributes
.
at
(
"transpose_a"
).
b
();
transb
=
attributes
.
at
(
"transpose_b"
).
b
();
}
if
(
contains
(
attributes
,
"adj_x"
))
{
transa
=
attributes
.
at
(
"adj_x"
).
b
();
}
if
(
contains
(
attributes
,
"adj_y"
))
{
transb
=
attributes
.
at
(
"adj_y"
).
b
();
}
std
::
vector
<
int64_t
>
perm
(
args
[
0
]
->
get_shape
().
lens
().
size
());
...
...
@@ -429,23 +591,44 @@ struct tf_parser
instruction_ref
parse_mean
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
auto
axes
=
parse_axes
(
args
[
1
]
->
eval
().
get
<
int32_t
>
().
to_vector
());
bool
keep_dims
=
attributes
.
at
(
"keep_dims"
).
b
();
std
::
vector
<
int32_t
>
hw_axes
{
2
,
3
};
// check if conditions for GlobalAvgPool are met
auto
lens
=
args
[
0
]
->
get_shape
().
lens
();
if
(
axes
==
hw_axes
and
lens
.
size
()
==
4
)
auto
axes
=
args
[
1
]
->
eval
().
get
<
int32_t
>
().
to_vector
<
int64_t
>
();
if
(
keep_dims
)
{
return
prog
.
add_instruction
(
op
::
reduce_mean
{
axes
},
args
[
0
]);
}
else
{
op
::
pooling
op
{
"average"
};
op
.
lengths
[
0
]
=
lens
[
2
];
op
.
lengths
[
1
]
=
lens
[
3
];
auto
l0
=
prog
.
add_instruction
(
op
,
args
.
front
());
if
(
keep_dims
)
return
l0
;
return
prog
.
add_instruction
(
op
::
squeeze
{
std
::
vector
<
int64_t
>
(
hw_axes
.
begin
(),
hw_axes
.
end
())},
l0
);
auto
ins
=
prog
.
add_instruction
(
op
::
reduce_mean
{
axes
},
args
[
0
]);
return
prog
.
add_instruction
(
op
::
squeeze
{
axes
},
ins
);
}
MIGRAPHX_THROW
(
"MIGraphX does not support mean outside of GlobalAvgPool transformation"
);
}
instruction_ref
parse_onehot
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
size_t
depth
=
static_cast
<
size_t
>
(
args
[
1
]
->
eval
().
at
<
int32_t
>
());
int64_t
axis
=
-
1
;
float
on_value
=
args
[
2
]
->
eval
().
at
<
float
>
();
float
off_value
=
args
[
3
]
->
eval
().
at
<
float
>
();
std
::
vector
<
float
>
depth_input
(
depth
*
depth
,
off_value
);
for
(
int
i
=
0
;
i
<
depth
;
i
++
)
{
depth_input
[
depth
*
i
+
i
]
=
on_value
;
}
if
(
contains
(
attributes
,
"axis"
))
axis
=
attributes
.
at
(
"axis"
).
i
();
if
(
axis
==
-
1
)
{
shape
s
{
shape
::
float_type
,
{
depth
,
depth
}};
auto
l0
=
prog
.
add_literal
({
s
,
depth_input
});
return
prog
.
add_instruction
(
op
::
gather
{
0
},
{
l0
,
args
[
0
]});
}
MIGRAPHX_THROW
(
"MIGraphX does not support axis != -1"
);
}
instruction_ref
parse_pack
(
const
std
::
string
&
,
...
...
@@ -463,16 +646,14 @@ struct tf_parser
MIGRAPHX_THROW
(
"TF_PARSER: axis value of "
+
to_string
(
axis
)
+
" must be smaller than input size "
+
to_string
(
input_size
));
}
// check if input arg needs axis to be converted to NCHW
if
(
input_size
>=
4
)
axis
=
parse_axis
(
axis
);
std
::
transform
(
args
.
begin
(),
args
.
end
(),
std
::
back_inserter
(
unsqueezed_args
),
[
&
](
instruction_ref
arg
)
{
return
prog
.
add_instruction
(
op
::
unsqueeze
{{
axis
}},
arg
);
});
return
prog
.
add_instruction
(
op
::
concat
{
static_cast
<
size_t
>
(
axis
)},
unsqueezed_args
);
return
to_nhwc
(
prog
.
add_instruction
(
op
::
concat
{
static_cast
<
size_t
>
(
axis
)},
unsqueezed_args
));
}
instruction_ref
...
...
@@ -508,18 +689,6 @@ struct tf_parser
{
op
::
pooling
op
{
starts_with
(
name
,
"Max"
)
?
"max"
:
"average"
};
if
(
contains
(
attributes
,
"padding"
))
{
const
std
::
string
&
pad_mode
=
attributes
.
at
(
"padding"
).
s
();
if
(
pad_mode
.
find
(
"SAME"
)
!=
std
::
string
::
npos
)
{
op
.
padding_mode
=
op
::
padding_mode_t
::
same
;
}
else
if
(
pad_mode
.
find
(
"VALID"
)
!=
std
::
string
::
npos
)
{
op
.
padding_mode
=
op
::
padding_mode_t
::
valid
;
}
}
if
(
contains
(
attributes
,
"strides"
))
{
std
::
vector
<
size_t
>
stride
;
...
...
@@ -544,7 +713,39 @@ struct tf_parser
op
.
lengths
[
0
]
=
ksize
[
2
];
op
.
lengths
[
1
]
=
ksize
[
3
];
}
return
prog
.
add_instruction
(
op
,
args
[
0
]);
auto
l0
=
args
[
0
];
if
(
contains
(
attributes
,
"padding"
))
{
const
std
::
string
&
pad_mode
=
attributes
.
at
(
"padding"
).
s
();
if
(
pad_mode
.
find
(
"SAME"
)
!=
std
::
string
::
npos
)
{
op
.
padding_mode
=
op
::
padding_mode_t
::
same
;
auto
input_dims
=
l0
->
get_shape
().
lens
();
size_t
input_h
=
input_dims
[
2
];
size_t
input_w
=
input_dims
[
3
];
std
::
vector
<
int64_t
>
pads
(
input_dims
.
size
());
calculate_padding
(
0
,
pads
,
input_h
,
op
.
stride
[
0
],
1
,
op
.
lengths
[
0
]);
calculate_padding
(
1
,
pads
,
input_w
,
op
.
stride
[
1
],
1
,
op
.
lengths
[
1
]);
if
(
pads
[
0
]
!=
pads
[
2
]
||
pads
[
1
]
!=
pads
[
3
])
{
std
::
vector
<
int64_t
>
padding
=
{
0
,
0
,
pads
[
0
],
pads
[
1
],
0
,
0
,
pads
[
2
],
pads
[
3
]};
l0
=
prog
.
add_instruction
(
migraphx
::
op
::
pad
{
padding
,
std
::
numeric_limits
<
float
>::
lowest
()},
l0
);
}
else
{
op
.
padding
[
0
]
=
pads
[
0
];
op
.
padding
[
1
]
=
pads
[
1
];
}
}
else
if
(
pad_mode
.
find
(
"VALID"
)
!=
std
::
string
::
npos
)
{
op
.
padding_mode
=
op
::
padding_mode_t
::
valid
;
}
}
return
prog
.
add_instruction
(
op
,
l0
);
}
instruction_ref
...
...
@@ -555,7 +756,7 @@ struct tf_parser
MIGRAPHX_THROW
(
"reshape needs 2 arguments (input, new_shape)"
);
auto
s
=
args
[
1
]
->
eval
();
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
op
.
dims
));
});
return
prog
.
add_instruction
(
op
,
args
[
0
]);
return
prog
.
add_instruction
(
op
,
make_contiguous
(
args
[
0
])
)
;
}
void
parse_from
(
std
::
istream
&
is
)
...
...
@@ -572,13 +773,46 @@ struct tf_parser
}
instruction_ref
parse_softmax
(
const
std
::
string
&
,
const
attribute_map
&
,
std
::
vector
<
instruction_ref
>
args
)
parse_slice
(
const
std
::
string
&
,
const
attribute_map
&
,
std
::
vector
<
instruction_ref
>
args
)
{
op
::
slice
op
;
auto
starts
=
args
[
1
]
->
eval
().
get
<
int32_t
>
().
to_vector
();
auto
size
=
args
[
2
]
->
eval
().
get
<
int32_t
>
().
to_vector
();
auto
axes
=
args
[
0
]
->
get_shape
().
lens
();
size_t
num_axes
=
axes
.
size
();
op
.
starts
=
std
::
vector
<
int64_t
>
(
starts
.
begin
(),
starts
.
end
());
op
.
ends
=
std
::
vector
<
int64_t
>
(
num_axes
);
op
.
axes
=
std
::
vector
<
int64_t
>
(
num_axes
);
std
::
iota
(
op
.
axes
.
begin
(),
op
.
axes
.
end
(),
0
);
for
(
size_t
i
=
0
;
i
<
num_axes
;
i
++
)
{
if
(
size
[
i
]
==
-
1
)
op
.
ends
[
i
]
=
axes
[
i
];
else
op
.
ends
[
i
]
=
starts
[
i
]
+
size
[
i
];
}
return
prog
.
add_instruction
(
op
,
make_contiguous
(
args
[
0
]));
}
// template to facilitate the logsoftmax later
template
<
class
Op
>
instruction_ref
parse_softmax
(
const
std
::
string
&
,
const
attribute_map
&
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
auto
dims
=
args
.
front
()
->
get_shape
().
lens
();
auto
r
=
prog
.
add_instruction
(
op
::
reshape
{{
long
(
dims
[
0
]),
long
(
dims
[
1
]),
1
,
1
}},
args
.
front
());
auto
s
=
prog
.
add_instruction
(
op
::
softmax
{},
r
);
return
prog
.
add_instruction
(
op
::
reshape
{{
long
(
dims
[
0
]),
long
(
dims
[
1
])}},
s
);
int
axis
=
-
1
;
auto
num_dims
=
args
[
0
]
->
get_shape
().
lens
().
size
();
if
(
contains
(
attributes
,
"axis"
))
{
axis
=
static_cast
<
int
>
(
attributes
.
at
(
"axis"
).
i
());
}
if
(
axis
<
0
)
{
axis
+=
num_dims
;
}
return
prog
.
add_instruction
(
Op
{
axis
},
make_contiguous
(
args
[
0
]));
}
instruction_ref
parse_squeeze
(
const
std
::
string
&
,
...
...
@@ -586,20 +820,21 @@ struct tf_parser
std
::
vector
<
instruction_ref
>
args
)
{
op
::
squeeze
op
;
auto
axes
=
parse_axes
(
attributes
,
"squeeze_dims"
);
auto
input_dims
=
args
[
0
]
->
get_shape
().
lens
();
auto
axes
=
attributes
.
at
(
"squeeze_dims"
).
list
().
i
();
copy
(
axes
,
std
::
back_inserter
(
op
.
axes
));
auto
args0_dims
=
args
[
0
]
->
get_shape
().
lens
();
if
(
op
.
axes
.
empty
())
// no squeeze_dims provided, remove any dim that equals 1
{
for
(
size_t
i
=
0
;
i
<
args0
_dims
.
size
();
i
++
)
for
(
size_t
i
=
0
;
i
<
input
_dims
.
size
();
i
++
)
{
if
(
args0
_dims
.
at
(
i
)
==
1
)
if
(
input
_dims
.
at
(
i
)
==
1
)
{
op
.
axes
.
push_back
(
i
);
}
}
}
return
prog
.
add_instruction
(
op
,
args
[
0
]);
return
prog
.
add_instruction
(
op
,
make_contiguous
(
args
[
0
])
)
;
}
instruction_ref
parse_stridedslice
(
const
std
::
string
&
,
...
...
@@ -607,39 +842,68 @@ struct tf_parser
std
::
vector
<
instruction_ref
>
args
)
{
op
::
slice
op
;
auto
starts
=
args
[
1
]
->
eval
().
get
<
int32_t
>
().
to_vector
();
auto
ends
=
args
[
2
]
->
eval
().
get
<
int32_t
>
().
to_vector
();
size_t
num_axes
=
args
[
0
]
->
get_shape
().
lens
().
size
();
if
(
num_axes
>=
4
)
{
reorder_data
(
starts
);
reorder_data
(
ends
);
}
auto
starts
=
args
[
1
]
->
eval
().
get
<
int32_t
>
().
to_vector
();
auto
ends
=
args
[
2
]
->
eval
().
get
<
int32_t
>
().
to_vector
();
auto
l0
=
args
[
0
];
size_t
num_axes
=
l0
->
get_shape
().
lens
().
size
();
std
::
vector
<
size_t
>
axes
=
l0
->
get_shape
().
lens
();
op
.
starts
=
std
::
vector
<
int64_t
>
(
starts
.
begin
(),
starts
.
end
());
op
.
ends
=
std
::
vector
<
int64_t
>
(
ends
.
begin
(),
ends
.
end
());
op
.
axes
=
std
::
vector
<
int64_t
>
(
num_axes
);
std
::
iota
(
op
.
axes
.
begin
(),
op
.
axes
.
end
(),
0
);
uint32_t
begin_mask
=
0
;
uint32_t
end_mask
=
0
;
uint32_t
shrink_axis_mask
=
0
;
uint32_t
bitwise_compare
=
1
;
std
::
vector
<
int64_t
>
squeeze_axes
;
if
(
contains
(
attributes
,
"begin_mask"
))
begin_mask
=
static_cast
<
uint32_t
>
(
attributes
.
at
(
"begin_mask"
).
i
());
if
(
contains
(
attributes
,
"end_mask"
))
end_mask
=
static_cast
<
uint32_t
>
(
attributes
.
at
(
"end_mask"
).
i
());
if
(
contains
(
attributes
,
"shrink_axis_mask"
))
shrink_axis_mask
=
static_cast
<
uint32_t
>
(
attributes
.
at
(
"shrink_axis_mask"
).
i
());
std
::
vector
<
int64_t
>
begin_axes
=
get_axes_from_mask
(
num_axes
,
begin_mask
);
std
::
vector
<
int64_t
>
end_axes
=
get_axes_from_mask
(
num_axes
,
end_mask
);
for
(
size_t
i
=
0
;
i
<
num_axes
;
i
++
)
{
if
(
begin_axes
.
at
(
i
)
==
1
)
{
op
.
starts
.
at
(
i
)
=
0
;
}
if
(
end_axes
.
at
(
i
)
==
1
)
{
op
.
ends
.
at
(
i
)
=
axes
.
at
(
i
);
}
}
auto
l1
=
prog
.
add_instruction
(
op
,
l0
);
if
(
shrink_axis_mask
==
0
)
return
l1
;
for
(
size_t
i
=
0
;
i
<
num_axes
;
i
++
)
{
// the LSB corresponds to axis 0 when determining which axes to squeeze
if
(((
shrink_axis_mask
>>
i
)
&
bitwise_compare
)
==
1
)
squeeze_axes
.
push_back
(
i
);
}
if
(
num_axes
>=
4
)
{
squeeze_axes
=
parse_axes
(
squeeze_axes
);
}
auto
l0
=
prog
.
add_instruction
(
op
,
args
[
0
]);
return
prog
.
add_instruction
(
op
::
squeeze
{
squeeze_axes
},
l0
);
return
prog
.
add_instruction
(
op
::
squeeze
{
squeeze_axes
},
l1
);
}
instruction_ref
parse_transpose
(
const
std
::
string
&
,
const
attribute_map
&
,
std
::
vector
<
instruction_ref
>
args
)
{
auto
perm
=
args
[
1
]
->
eval
().
get
<
int32_t
>
().
to_vector
();
op
::
transpose
op
;
op
.
dims
=
std
::
vector
<
int64_t
>
(
perm
.
begin
(),
perm
.
end
());
return
prog
.
add_instruction
(
op
,
args
.
front
());
}
void
parse_graph
(
const
tensorflow
::
GraphDef
&
graph
)
...
...
@@ -656,7 +920,7 @@ struct tf_parser
reorder_data
(
dims
);
}
shape
s
=
shape
{
shape_type
,
dims
};
instructions
[
name
]
=
prog
.
add_parameter
(
name
,
s
);
instructions
[
name
]
=
to_nhwc
(
prog
.
add_parameter
(
name
,
s
)
)
;
}
for
(
auto
&&
p
:
nodes
)
{
...
...
@@ -669,10 +933,16 @@ struct tf_parser
if
(
instructions
.
count
(
name
)
==
0
)
{
auto
&&
node
=
nodes
.
at
(
name
);
// assert ops ignored
if
(
node
.
op
()
==
"Assert"
or
contains
(
name
,
"Assert"
))
return
;
std
::
vector
<
instruction_ref
>
args
;
for
(
auto
&&
input
:
node
.
input
())
{
// control dependencies (signified by ^ before the name) are ignored
if
(
contains
(
input
,
"^"
))
continue
;
if
(
nodes
.
count
(
input
)
>
0
)
{
auto
&&
iname
=
get_name
(
nodes
.
at
(
input
));
...
...
@@ -732,72 +1002,56 @@ struct tf_parser
shape
::
type_t
shape_type
{};
switch
(
t
)
{
case
tensorflow
::
DataType
::
DT_INVALID
:
break
;
// throw std::runtime_error("Unsupported type UNDEFINED");
case
tensorflow
::
DataType
::
DT_FLOAT
:
shape_type
=
shape
::
float_type
;
break
;
case
tensorflow
::
DataType
::
DT_DOUBLE
:
shape_type
=
shape
::
double_type
;
break
;
case
tensorflow
::
DataType
::
DT_INT32
:
shape_type
=
shape
::
int32_type
;
break
;
case
tensorflow
::
DataType
::
DT_UINT8
:
break
;
// throw std::runtime_error("Unsupported type UINT8");
case
tensorflow
::
DataType
::
DT_INT16
:
shape_type
=
shape
::
int16_type
;
break
;
case
tensorflow
::
DataType
::
DT_INT8
:
shape_type
=
shape
::
int8_type
;
break
;
case
tensorflow
::
DataType
::
DT_INT64
:
shape_type
=
shape
::
int64_type
;
break
;
case
tensorflow
::
DataType
::
DT_UINT16
:
shape_type
=
shape
::
uint16_type
;
break
;
case
tensorflow
::
DataType
::
DT_HALF
:
shape_type
=
shape
::
half_type
;
break
;
case
tensorflow
::
DataType
::
DT_UINT32
:
shape_type
=
shape
::
uint32_type
;
break
;
case
tensorflow
::
DataType
::
DT_UINT64
:
shape_type
=
shape
::
uint64_type
;
break
;
case
tensorflow
::
DataType
::
DT_INVALID
:
case
tensorflow
::
DataType
::
DT_UINT8
:
case
tensorflow
::
DataType
::
DT_STRING
:
break
;
// throw std::runtime_error("Unsupported type STRING");
case
tensorflow
::
DataType
::
DT_COMPLEX64
:
break
;
// throw std::runtime_error("Unsupported type COMPLEX64");
case
tensorflow
::
DataType
::
DT_INT64
:
shape_type
=
shape
::
int64_type
;
break
;
case
tensorflow
::
DataType
::
DT_BOOL
:
break
;
// throw std::runtime_error("Unsupported type BOOL");
case
tensorflow
::
DataType
::
DT_QINT8
:
break
;
// throw std::runtime_error("Unsupported type QINT8");
case
tensorflow
::
DataType
::
DT_QUINT8
:
break
;
// throw std::runtime_error("Unsupported type QUINT8");
case
tensorflow
::
DataType
::
DT_QINT32
:
break
;
// throw std::runtime_error("Unsupported type QINT32");
case
tensorflow
::
DataType
::
DT_BFLOAT16
:
break
;
// throw std::runtime_error("Unsupported type BFLOAT16");
case
tensorflow
::
DataType
::
DT_QINT16
:
break
;
// throw std::runtime_error("Unsupported type QINT16");
case
tensorflow
::
DataType
::
DT_QUINT16
:
break
;
// throw std::runtime_error("Unsupported type QUINT16");
case
tensorflow
::
DataType
::
DT_UINT16
:
shape_type
=
shape
::
uint16_type
;
break
;
case
tensorflow
::
DataType
::
DT_COMPLEX128
:
break
;
// throw std::runtime_error("Unsupported type COMPLEX128");
case
tensorflow
::
DataType
::
DT_HALF
:
shape_type
=
shape
::
half_type
;
break
;
case
tensorflow
::
DataType
::
DT_RESOURCE
:
break
;
// throw std::runtime_error("Unsupported type RESOURCE");
case
tensorflow
::
DataType
::
DT_VARIANT
:
break
;
// throw std::runtime_error("Unsupported type VARIANT");
case
tensorflow
::
DataType
::
DT_UINT32
:
shape_type
=
shape
::
uint32_type
;
break
;
case
tensorflow
::
DataType
::
DT_UINT64
:
shape_type
=
shape
::
uint64_type
;
break
;
// tf pb should not use these types
case
tensorflow
::
DataType
::
DT_FLOAT_REF
:
break
;
case
tensorflow
::
DataType
::
DT_DOUBLE_REF
:
break
;
case
tensorflow
::
DataType
::
DT_INT32_REF
:
break
;
case
tensorflow
::
DataType
::
DT_UINT8_REF
:
break
;
case
tensorflow
::
DataType
::
DT_INT16_REF
:
break
;
case
tensorflow
::
DataType
::
DT_INT8_REF
:
break
;
case
tensorflow
::
DataType
::
DT_STRING_REF
:
break
;
case
tensorflow
::
DataType
::
DT_COMPLEX64_REF
:
break
;
case
tensorflow
::
DataType
::
DT_INT64_REF
:
break
;
case
tensorflow
::
DataType
::
DT_BOOL_REF
:
break
;
case
tensorflow
::
DataType
::
DT_QINT8_REF
:
break
;
case
tensorflow
::
DataType
::
DT_QUINT8_REF
:
break
;
case
tensorflow
::
DataType
::
DT_QINT32_REF
:
break
;
case
tensorflow
::
DataType
::
DT_BFLOAT16_REF
:
break
;
case
tensorflow
::
DataType
::
DT_QINT16_REF
:
break
;
case
tensorflow
::
DataType
::
DT_QUINT16_REF
:
break
;
case
tensorflow
::
DataType
::
DT_UINT16_REF
:
break
;
case
tensorflow
::
DataType
::
DT_COMPLEX128_REF
:
break
;
case
tensorflow
::
DataType
::
DT_HALF_REF
:
break
;
case
tensorflow
::
DataType
::
DT_RESOURCE_REF
:
break
;
case
tensorflow
::
DataType
::
DT_VARIANT_REF
:
break
;
case
tensorflow
::
DataType
::
DT_UINT32_REF
:
break
;
case
tensorflow
::
DataType
::
DT_UINT64_REF
:
break
;
case
tensorflow
::
DataType
::
DataType_INT_MAX_SENTINEL_DO_NOT_USE_
:
break
;
case
tensorflow
::
DataType
::
DT_FLOAT_REF
:
case
tensorflow
::
DataType
::
DT_DOUBLE_REF
:
case
tensorflow
::
DataType
::
DT_INT32_REF
:
case
tensorflow
::
DataType
::
DT_UINT8_REF
:
case
tensorflow
::
DataType
::
DT_INT16_REF
:
case
tensorflow
::
DataType
::
DT_INT8_REF
:
case
tensorflow
::
DataType
::
DT_STRING_REF
:
case
tensorflow
::
DataType
::
DT_COMPLEX64_REF
:
case
tensorflow
::
DataType
::
DT_INT64_REF
:
case
tensorflow
::
DataType
::
DT_BOOL_REF
:
case
tensorflow
::
DataType
::
DT_QINT8_REF
:
case
tensorflow
::
DataType
::
DT_QUINT8_REF
:
case
tensorflow
::
DataType
::
DT_QINT32_REF
:
case
tensorflow
::
DataType
::
DT_BFLOAT16_REF
:
case
tensorflow
::
DataType
::
DT_QINT16_REF
:
case
tensorflow
::
DataType
::
DT_QUINT16_REF
:
case
tensorflow
::
DataType
::
DT_UINT16_REF
:
case
tensorflow
::
DataType
::
DT_COMPLEX128_REF
:
case
tensorflow
::
DataType
::
DT_HALF_REF
:
case
tensorflow
::
DataType
::
DT_RESOURCE_REF
:
case
tensorflow
::
DataType
::
DT_VARIANT_REF
:
case
tensorflow
::
DataType
::
DT_UINT32_REF
:
case
tensorflow
::
DataType
::
DT_UINT64_REF
:
case
tensorflow
::
DataType
::
DataType_INT_MAX_SENTINEL_DO_NOT_USE_
:
case
tensorflow
::
DataType
::
DataType_INT_MIN_SENTINEL_DO_NOT_USE_
:
break
;
}
return
shape_type
;
...
...
@@ -812,61 +1066,59 @@ struct tf_parser
const
std
::
string
&
s
=
t
.
tensor_content
();
switch
(
t
.
dtype
())
{
case
tensorflow
::
DataType
::
DT_INVALID
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_FLOAT
:
return
literal
{{
shape
::
float_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_
UINT8
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_
BOOL
:
case
tensorflow
::
DataType
::
DT_INT8
:
return
literal
{{
shape
::
int8_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_UINT16
:
return
literal
{{
shape
::
uint16_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_INT16
:
return
literal
{{
shape
::
int16_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_INT32
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_INT64
:
return
literal
{{
shape
::
int64_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_STRING
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_BOOL
:
return
literal
{{
shape
::
int8_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_HALF
:
return
literal
{{
shape
::
half_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_DOUBLE
:
return
literal
{{
shape
::
double_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_UINT32
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT64
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_COMPLEX64
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_COMPLEX128
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QINT8
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QUINT8
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QINT32
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_BFLOAT16
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QINT16
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QUINT16
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_RESOURCE
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_VARIANT
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_FLOAT_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_DOUBLE_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_INT32_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT8_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_INT16_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_INT8_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_STRING_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_COMPLEX64_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_INT64_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_BOOL_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QINT8_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QUINT8_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QINT32_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_BFLOAT16_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QINT16_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QUINT16_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT16_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_COMPLEX128_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_HALF_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_RESOURCE_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_VARIANT_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT32_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT64_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_INVALID
:
case
tensorflow
::
DataType
::
DT_UINT8
:
case
tensorflow
::
DataType
::
DT_STRING
:
case
tensorflow
::
DataType
::
DT_UINT32
:
case
tensorflow
::
DataType
::
DT_UINT64
:
case
tensorflow
::
DataType
::
DT_COMPLEX64
:
case
tensorflow
::
DataType
::
DT_COMPLEX128
:
case
tensorflow
::
DataType
::
DT_QINT8
:
case
tensorflow
::
DataType
::
DT_QUINT8
:
case
tensorflow
::
DataType
::
DT_QINT32
:
case
tensorflow
::
DataType
::
DT_BFLOAT16
:
case
tensorflow
::
DataType
::
DT_QINT16
:
case
tensorflow
::
DataType
::
DT_QUINT16
:
case
tensorflow
::
DataType
::
DT_RESOURCE
:
case
tensorflow
::
DataType
::
DT_VARIANT
:
case
tensorflow
::
DataType
::
DT_FLOAT_REF
:
case
tensorflow
::
DataType
::
DT_DOUBLE_REF
:
case
tensorflow
::
DataType
::
DT_INT32_REF
:
case
tensorflow
::
DataType
::
DT_UINT8_REF
:
case
tensorflow
::
DataType
::
DT_INT16_REF
:
case
tensorflow
::
DataType
::
DT_INT8_REF
:
case
tensorflow
::
DataType
::
DT_STRING_REF
:
case
tensorflow
::
DataType
::
DT_COMPLEX64_REF
:
case
tensorflow
::
DataType
::
DT_INT64_REF
:
case
tensorflow
::
DataType
::
DT_BOOL_REF
:
case
tensorflow
::
DataType
::
DT_QINT8_REF
:
case
tensorflow
::
DataType
::
DT_QUINT8_REF
:
case
tensorflow
::
DataType
::
DT_QINT32_REF
:
case
tensorflow
::
DataType
::
DT_BFLOAT16_REF
:
case
tensorflow
::
DataType
::
DT_QINT16_REF
:
case
tensorflow
::
DataType
::
DT_QUINT16_REF
:
case
tensorflow
::
DataType
::
DT_UINT16_REF
:
case
tensorflow
::
DataType
::
DT_COMPLEX128_REF
:
case
tensorflow
::
DataType
::
DT_HALF_REF
:
case
tensorflow
::
DataType
::
DT_RESOURCE_REF
:
case
tensorflow
::
DataType
::
DT_VARIANT_REF
:
case
tensorflow
::
DataType
::
DT_UINT32_REF
:
case
tensorflow
::
DataType
::
DT_UINT64_REF
:
case
tensorflow
::
DataType
::
DataType_INT_MAX_SENTINEL_DO_NOT_USE_
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DataType_INT_MIN_SENTINEL_DO_NOT_USE_
:
throw
std
::
runtime_error
(
""
);
}
...
...
@@ -874,11 +1126,9 @@ struct tf_parser
}
switch
(
t
.
dtype
())
{
case
tensorflow
::
DataType
::
DT_INVALID
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_FLOAT
:
return
create_literal
(
shape
::
float_type
,
dims
,
get_data_vals
(
t
.
float_val
(),
shape_size
));
case
tensorflow
::
DataType
::
DT_UINT8
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_INT8
:
return
create_literal
(
shape
::
int8_type
,
dims
,
get_data_vals
(
t
.
int_val
(),
shape_size
));
case
tensorflow
::
DataType
::
DT_UINT16
:
...
...
@@ -890,7 +1140,6 @@ struct tf_parser
case
tensorflow
::
DataType
::
DT_INT64
:
return
create_literal
(
shape
::
int64_type
,
dims
,
get_data_vals
(
t
.
int64_val
(),
shape_size
));
case
tensorflow
::
DataType
::
DT_STRING
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_BOOL
:
return
create_literal
(
shape
::
int32_type
,
dims
,
get_data_vals
(
t
.
bool_val
(),
shape_size
));
case
tensorflow
::
DataType
::
DT_HALF
:
...
...
@@ -906,43 +1155,45 @@ struct tf_parser
}
case
tensorflow
::
DataType
::
DT_DOUBLE
:
return
literal
{{
shape
::
double_type
,
dims
},
get_data_vals
(
t
.
double_val
(),
shape_size
)};
case
tensorflow
::
DataType
::
DT_UINT32
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT64
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_COMPLEX64
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_COMPLEX128
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QINT8
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QUINT8
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QINT32
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_BFLOAT16
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QINT16
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QUINT16
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_RESOURCE
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_VARIANT
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_FLOAT_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_DOUBLE_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_INT32_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT8_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_INT16_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_INT8_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_STRING_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_COMPLEX64_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_INT64_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_BOOL_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QINT8_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QUINT8_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QINT32_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_BFLOAT16_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QINT16_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QUINT16_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT16_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_COMPLEX128_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_HALF_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_RESOURCE_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_VARIANT_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT32_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT64_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_INVALID
:
case
tensorflow
::
DataType
::
DT_UINT8
:
case
tensorflow
::
DataType
::
DT_STRING
:
case
tensorflow
::
DataType
::
DT_UINT32
:
case
tensorflow
::
DataType
::
DT_UINT64
:
case
tensorflow
::
DataType
::
DT_COMPLEX64
:
case
tensorflow
::
DataType
::
DT_COMPLEX128
:
case
tensorflow
::
DataType
::
DT_QINT8
:
case
tensorflow
::
DataType
::
DT_QUINT8
:
case
tensorflow
::
DataType
::
DT_QINT32
:
case
tensorflow
::
DataType
::
DT_BFLOAT16
:
case
tensorflow
::
DataType
::
DT_QINT16
:
case
tensorflow
::
DataType
::
DT_QUINT16
:
case
tensorflow
::
DataType
::
DT_RESOURCE
:
case
tensorflow
::
DataType
::
DT_VARIANT
:
case
tensorflow
::
DataType
::
DT_FLOAT_REF
:
case
tensorflow
::
DataType
::
DT_DOUBLE_REF
:
case
tensorflow
::
DataType
::
DT_INT32_REF
:
case
tensorflow
::
DataType
::
DT_UINT8_REF
:
case
tensorflow
::
DataType
::
DT_INT16_REF
:
case
tensorflow
::
DataType
::
DT_INT8_REF
:
case
tensorflow
::
DataType
::
DT_STRING_REF
:
case
tensorflow
::
DataType
::
DT_COMPLEX64_REF
:
case
tensorflow
::
DataType
::
DT_INT64_REF
:
case
tensorflow
::
DataType
::
DT_BOOL_REF
:
case
tensorflow
::
DataType
::
DT_QINT8_REF
:
case
tensorflow
::
DataType
::
DT_QUINT8_REF
:
case
tensorflow
::
DataType
::
DT_QINT32_REF
:
case
tensorflow
::
DataType
::
DT_BFLOAT16_REF
:
case
tensorflow
::
DataType
::
DT_QINT16_REF
:
case
tensorflow
::
DataType
::
DT_QUINT16_REF
:
case
tensorflow
::
DataType
::
DT_UINT16_REF
:
case
tensorflow
::
DataType
::
DT_COMPLEX128_REF
:
case
tensorflow
::
DataType
::
DT_HALF_REF
:
case
tensorflow
::
DataType
::
DT_RESOURCE_REF
:
case
tensorflow
::
DataType
::
DT_VARIANT_REF
:
case
tensorflow
::
DataType
::
DT_UINT32_REF
:
case
tensorflow
::
DataType
::
DT_UINT64_REF
:
case
tensorflow
::
DataType
::
DataType_INT_MAX_SENTINEL_DO_NOT_USE_
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DataType_INT_MIN_SENTINEL_DO_NOT_USE_
:
throw
std
::
runtime_error
(
""
);
}
...
...
@@ -1006,6 +1257,7 @@ program parse_tf(const std::string& name, bool is_nhwc)
#else
parser
.
parse_from
(
input
);
#endif
parser
.
to_nchw
(
std
::
prev
(
parser
.
prog
.
end
()));
return
std
::
move
(
parser
.
prog
);
}
...
...
test/CMakeLists.txt
View file @
20b1d690
...
...
@@ -119,7 +119,7 @@ foreach(ONNX_TEST ${ONNX_TESTS})
set
(
TEST_NAME test_
${
BASE_NAME
}
)
add_executable
(
${
TEST_NAME
}
${
TES_ONNX_DIR
}
/
${
ONNX_TEST
}
)
rocm_clang_tidy_check
(
${
TEST_NAME
}
)
target_link_libraries
(
${
TEST_NAME
}
migraphx_onnx
)
target_link_libraries
(
${
TEST_NAME
}
migraphx_onnx
migraphx_cpu
)
target_include_directories
(
${
TEST_NAME
}
PUBLIC include
)
add_test
(
NAME
${
TEST_NAME
}
COMMAND $<TARGET_FILE:
${
TEST_NAME
}
> WORKING_DIRECTORY
${
CMAKE_CURRENT_SOURCE_DIR
}
/onnx
)
add_dependencies
(
tests
${
TEST_NAME
}
)
...
...
@@ -129,7 +129,7 @@ endforeach()
# tf test
add_executable
(
test_tf tf/tf_test.cpp
)
rocm_clang_tidy_check
(
test_tf
)
target_link_libraries
(
test_tf migraphx_tf
)
target_link_libraries
(
test_tf migraphx_tf
migraphx_cpu
)
target_include_directories
(
test_tf PUBLIC include
)
add_test
(
NAME test_tf COMMAND $<TARGET_FILE:test_tf> WORKING_DIRECTORY
${
CMAKE_CURRENT_SOURCE_DIR
}
/tf
)
add_dependencies
(
tests test_tf
)
...
...
test/cpu_dot_op_test.cpp
View file @
20b1d690
...
...
@@ -1093,4 +1093,394 @@ TEST_CASE(matmul_mm2)
}
}
TEST_CASE(quant_dot_2args_multi4)
{
    // int8 quant_dot where every matrix dimension is a multiple of 4, covering
    // all four transpose combinations of the two operands.  Each scope computes
    // a {4, 8} int32 result from iota-filled int8 inputs.

    // Build an iota-initialized int8 buffer of the given element count.
    auto make_input = [](std::size_t elems) {
        std::vector<int8_t> data(elems);
        std::iota(data.begin(), data.end(), 0);
        return data;
    };

    // Compile for the CPU target, evaluate, and compare against the expected
    // values.  The output is collected into an int vector because quant_dot
    // accumulates in int32; a float buffer could lose precision for large
    // accumulations.
    auto run_and_check = [](migraphx::program& p, const std::vector<int>& gold) {
        p.compile(migraphx::cpu::target{});
        auto result = p.eval({});
        std::vector<int> out;
        result.visit([&](auto output) { out.assign(output.begin(), output.end()); });
        EXPECT(migraphx::verify_range(out, gold));
    };

    {
        // A{4,4} * B{4,8}
        migraphx::program p;
        migraphx::shape m1_shape{migraphx::shape::int8_type, {4, 4}};
        migraphx::shape m2_shape{migraphx::shape::int8_type, {4, 8}};
        auto l1 = p.add_literal(migraphx::literal{m1_shape, make_input(4 * 4)});
        auto l2 = p.add_literal(migraphx::literal{m2_shape, make_input(4 * 8)});
        p.add_instruction(migraphx::op::quant_dot{}, l1, l2);
        std::vector<int> gold = {112, 118, 124, 130, 136, 142, 148, 154,
                                 304, 326, 348, 370, 392, 414, 436, 458,
                                 496, 534, 572, 610, 648, 686, 724, 762,
                                 688, 742, 796, 850, 904, 958, 1012, 1066};
        run_and_check(p, gold);
    }

    {
        // transpose(A{4,4}) * B{4,8}
        migraphx::program p;
        migraphx::shape m1_shape{migraphx::shape::int8_type, {4, 4}};
        migraphx::shape m2_shape{migraphx::shape::int8_type, {4, 8}};
        auto l1  = p.add_literal(migraphx::literal{m1_shape, make_input(4 * 4)});
        auto tl1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
        auto l2  = p.add_literal(migraphx::literal{m2_shape, make_input(4 * 8)});
        p.add_instruction(migraphx::op::quant_dot{}, tl1, l2);
        std::vector<int> gold = {448, 472, 496, 520, 544, 568, 592, 616,
                                 496, 524, 552, 580, 608, 636, 664, 692,
                                 544, 576, 608, 640, 672, 704, 736, 768,
                                 592, 628, 664, 700, 736, 772, 808, 844};
        run_and_check(p, gold);
    }

    {
        // A{4,4} * transpose(B{8,4})
        migraphx::program p;
        migraphx::shape m1_shape{migraphx::shape::int8_type, {4, 4}};
        migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 4}};
        auto l1  = p.add_literal(migraphx::literal{m1_shape, make_input(4 * 4)});
        auto l2  = p.add_literal(migraphx::literal{m2_shape, make_input(4 * 8)});
        auto tl2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l2);
        p.add_instruction(migraphx::op::quant_dot{}, l1, tl2);
        std::vector<int> gold = {14, 38, 62, 86, 110, 134, 158, 182,
                                 38, 126, 214, 302, 390, 478, 566, 654,
                                 62, 214, 366, 518, 670, 822, 974, 1126,
                                 86, 302, 518, 734, 950, 1166, 1382, 1598};
        run_and_check(p, gold);
    }

    {
        // transpose(A{4,4}) * transpose(B{8,4})
        migraphx::program p;
        migraphx::shape m1_shape{migraphx::shape::int8_type, {4, 4}};
        migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 4}};
        auto l1  = p.add_literal(migraphx::literal{m1_shape, make_input(4 * 4)});
        auto tl1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
        auto l2  = p.add_literal(migraphx::literal{m2_shape, make_input(4 * 8)});
        auto tl2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l2);
        p.add_instruction(migraphx::op::quant_dot{}, tl1, tl2);
        std::vector<int> gold = {56, 152, 248, 344, 440, 536, 632, 728,
                                 62, 174, 286, 398, 510, 622, 734, 846,
                                 68, 196, 324, 452, 580, 708, 836, 964,
                                 74, 218, 362, 506, 650, 794, 938, 1082};
        run_and_check(p, gold);
    }
}
TEST_CASE(quant_dot_2args_general)
{
    // int8 quant_dot on sizes that are not multiples of 4, covering every
    // transpose combination of the two operands; the last two scopes also
    // exercise the alpha scaling parameter.
    {
        // A{3,4} * B{4,5}
        migraphx::program prog;
        migraphx::shape lhs_shape{migraphx::shape::int8_type, {3, 4}};
        migraphx::shape rhs_shape{migraphx::shape::int8_type, {4, 5}};
        std::vector<int8_t> lhs_data(3 * 4);
        std::vector<int8_t> rhs_data(4 * 5);
        std::iota(lhs_data.begin(), lhs_data.end(), 0);
        std::iota(rhs_data.begin(), rhs_data.end(), 0);
        auto lhs = prog.add_literal(migraphx::literal{lhs_shape, lhs_data});
        auto rhs = prog.add_literal(migraphx::literal{rhs_shape, rhs_data});
        prog.add_instruction(migraphx::op::quant_dot{}, lhs, rhs);
        std::vector<int> expected = {
            70, 76, 82, 88, 94, 190, 212, 234, 256, 278, 310, 348, 386, 424, 462};
        prog.compile(migraphx::cpu::target{});
        auto result = prog.eval({});
        std::vector<float> actual;
        result.visit([&](auto output) { actual.assign(output.begin(), output.end()); });
        EXPECT(migraphx::verify_range(actual, expected));
    }

    {
        // transpose(A{4,3}) * B{4,5}
        migraphx::program prog;
        migraphx::shape lhs_shape{migraphx::shape::int8_type, {4, 3}};
        migraphx::shape rhs_shape{migraphx::shape::int8_type, {4, 5}};
        std::vector<int8_t> lhs_data(4 * 3);
        std::vector<int8_t> rhs_data(4 * 5);
        std::iota(lhs_data.begin(), lhs_data.end(), 0);
        std::iota(rhs_data.begin(), rhs_data.end(), 0);
        auto lhs  = prog.add_literal(migraphx::literal{lhs_shape, lhs_data});
        auto tlhs = prog.add_instruction(migraphx::op::transpose{{1, 0}}, lhs);
        auto rhs  = prog.add_literal(migraphx::literal{rhs_shape, rhs_data});
        prog.add_instruction(migraphx::op::quant_dot{}, tlhs, rhs);
        std::vector<int> expected = {
            210, 228, 246, 264, 282, 240, 262, 284, 306, 328, 270, 296, 322, 348, 374};
        prog.compile(migraphx::cpu::target{});
        auto result = prog.eval({});
        std::vector<float> actual;
        result.visit([&](auto output) { actual.assign(output.begin(), output.end()); });
        EXPECT(migraphx::verify_range(actual, expected));
    }

    {
        // A{3,4} * transpose(B{5,4}), alpha = 2
        migraphx::program prog;
        migraphx::shape lhs_shape{migraphx::shape::int8_type, {3, 4}};
        migraphx::shape rhs_shape{migraphx::shape::int8_type, {5, 4}};
        std::vector<int8_t> lhs_data(3 * 4);
        std::vector<int8_t> rhs_data(5 * 4);
        std::iota(lhs_data.begin(), lhs_data.end(), 0);
        std::iota(rhs_data.begin(), rhs_data.end(), 0);
        auto lhs  = prog.add_literal(migraphx::literal{lhs_shape, lhs_data});
        auto rhs  = prog.add_literal(migraphx::literal{rhs_shape, rhs_data});
        auto trhs = prog.add_instruction(migraphx::op::transpose{{1, 0}}, rhs);
        prog.add_instruction(migraphx::op::quant_dot{2}, lhs, trhs);
        std::vector<int> expected = {
            28, 76, 124, 172, 220, 76, 252, 428, 604, 780, 124, 428, 732, 1036, 1340};
        prog.compile(migraphx::cpu::target{});
        auto result = prog.eval({});
        std::vector<float> actual;
        result.visit([&](auto output) { actual.assign(output.begin(), output.end()); });
        EXPECT(migraphx::verify_range(actual, expected));
    }

    {
        // transpose(A{4,3}) * transpose(B{5,4}), alpha = 3 (beta = 2 is
        // irrelevant here since there is no third argument)
        migraphx::program prog;
        migraphx::shape lhs_shape{migraphx::shape::int8_type, {4, 3}};
        migraphx::shape rhs_shape{migraphx::shape::int8_type, {5, 4}};
        std::vector<int8_t> lhs_data(4 * 3);
        std::vector<int8_t> rhs_data(5 * 4);
        std::iota(lhs_data.begin(), lhs_data.end(), 0);
        std::iota(rhs_data.begin(), rhs_data.end(), 0);
        auto lhs  = prog.add_literal(migraphx::literal{lhs_shape, lhs_data});
        auto tlhs = prog.add_instruction(migraphx::op::transpose{{1, 0}}, lhs);
        auto rhs  = prog.add_literal(migraphx::literal{rhs_shape, rhs_data});
        auto trhs = prog.add_instruction(migraphx::op::transpose{{1, 0}}, rhs);
        prog.add_instruction(migraphx::op::quant_dot{3, 2}, tlhs, trhs);
        std::vector<int> expected = {
            126, 342, 558, 774, 990, 144, 408, 672, 936, 1200, 162, 474, 786, 1098, 1410};
        prog.compile(migraphx::cpu::target{});
        auto result = prog.eval({});
        std::vector<float> actual;
        result.visit([&](auto output) { actual.assign(output.begin(), output.end()); });
        EXPECT(migraphx::verify_range(actual, expected));
    }
}
TEST_CASE(quant_dot_3args_general)
{
    // int8 quant_dot with a third int32 argument C, i.e. gemm-style
    // result = alpha * A * B + beta * C, testing all four transpose
    // combinations of A and B along with different alpha/beta values.

    // iota-filled int8 input of the given element count.
    auto make_input = [](std::size_t elems) {
        std::vector<int8_t> data(elems);
        std::iota(data.begin(), data.end(), 0);
        return data;
    };
    // iota-filled int32 accumulator starting at 2.
    auto make_bias = [](std::size_t elems) {
        std::vector<int> data(elems);
        std::iota(data.begin(), data.end(), 2);
        return data;
    };
    // Compile for the CPU target, evaluate, and compare the int32 output.
    // An int vector is used for readback since quant_dot produces int32.
    auto run_and_check = [](migraphx::program& p, const std::vector<int>& gold) {
        p.compile(migraphx::cpu::target{});
        auto result = p.eval({});
        std::vector<int> out;
        result.visit([&](auto output) { out.assign(output.begin(), output.end()); });
        EXPECT(migraphx::verify_range(out, gold));
    };

    {
        // A{2,8} * B{8,7} + C{2,7}, alpha = beta = 1
        migraphx::program p;
        migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 8}};
        migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}};
        migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
        auto l1 = p.add_literal(migraphx::literal{m1_shape, make_input(2 * 8)});
        auto l2 = p.add_literal(migraphx::literal{m2_shape, make_input(8 * 7)});
        auto l3 = p.add_literal(migraphx::literal{m3_shape, make_bias(2 * 7)});
        p.add_instruction(migraphx::op::quant_dot{}, l1, l2, l3);
        std::vector<int> gold = {982, 1011, 1040, 1069, 1098, 1127, 1156,
                                 2557, 2650, 2743, 2836, 2929, 3022, 3115};
        run_and_check(p, gold);
    }

    {
        // transpose(A{8,2}) * B{8,7} + 3 * C{2,7}
        migraphx::program p;
        migraphx::shape m1_shape{migraphx::shape::int8_type, {8, 2}};
        migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}};
        migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
        auto l1  = p.add_literal(migraphx::literal{m1_shape, make_input(2 * 8)});
        auto tl1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
        auto l2  = p.add_literal(migraphx::literal{m2_shape, make_input(8 * 7)});
        auto l3  = p.add_literal(migraphx::literal{m3_shape, make_bias(2 * 7)});
        p.add_instruction(migraphx::op::quant_dot{1, 3}, tl1, l2, l3);
        std::vector<int> gold = {1966, 2025, 2084, 2143, 2202, 2261, 2320,
                                 2183, 2250, 2317, 2384, 2451, 2518, 2585};
        run_and_check(p, gold);
    }

    {
        // 2 * A{2,8} * transpose(B{7,8}) + 3 * C{2,7}
        migraphx::program p;
        migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 8}};
        migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 8}};
        migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
        auto l1  = p.add_literal(migraphx::literal{m1_shape, make_input(2 * 8)});
        auto l2  = p.add_literal(migraphx::literal{m2_shape, make_input(8 * 7)});
        auto tl2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l2);
        auto l3  = p.add_literal(migraphx::literal{m3_shape, make_bias(2 * 7)});
        p.add_instruction(migraphx::op::quant_dot{2, 3}, l1, tl2, l3);
        std::vector<int> gold = {286, 737, 1188, 1639, 2090, 2541, 2992,
                                 755, 2230, 3705, 5180, 6655, 8130, 9605};
        run_and_check(p, gold);
    }

    {
        // 3 * transpose(A{8,2}) * transpose(B{7,8}) + 2 * C{2,7}
        migraphx::program p;
        migraphx::shape m1_shape{migraphx::shape::int8_type, {8, 2}};
        migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 8}};
        migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
        auto l1  = p.add_literal(migraphx::literal{m1_shape, make_input(2 * 8)});
        auto tl1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
        auto l2  = p.add_literal(migraphx::literal{m2_shape, make_input(8 * 7)});
        auto tl2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l2);
        auto l3  = p.add_literal(migraphx::literal{m3_shape, make_bias(2 * 7)});
        p.add_instruction(migraphx::op::quant_dot{3, 2}, tl1, tl2, l3);
        std::vector<int> gold = {844, 2190, 3536, 4882, 6228, 7574, 8920,
                                 942, 2480, 4018, 5556, 7094, 8632, 10170};
        run_and_check(p, gold);
    }
}
TEST_CASE(quant_dot_3args_batch)
{
    // Batched (rank-4) int8 quant_dot with a third int32 argument,
    // with and without transposed inner matrix dimensions.
    {
        // {2,2} batches of A{2,4} * B{4,7} + 2 * C{2,7}
        migraphx::program prog;
        migraphx::shape a_shape{migraphx::shape::int8_type, {2, 2, 2, 4}};
        migraphx::shape b_shape{migraphx::shape::int8_type, {2, 2, 4, 7}};
        migraphx::shape c_shape{migraphx::shape::int32_type, {2, 2, 2, 7}};
        std::vector<int8_t> a_data(4 * 2 * 4);
        std::vector<int8_t> b_data(4 * 4 * 7);
        std::vector<int> c_data(4 * 2 * 7);
        std::iota(a_data.begin(), a_data.end(), 0);
        std::iota(b_data.begin(), b_data.end(), 0);
        std::iota(c_data.begin(), c_data.end(), 2);
        auto a = prog.add_literal(migraphx::literal{a_shape, a_data});
        auto b = prog.add_literal(migraphx::literal{b_shape, b_data});
        auto c = prog.add_literal(migraphx::literal{c_shape, c_data});
        prog.add_instruction(migraphx::op::quant_dot{1, 2}, a, b, c);
        std::vector<int> expected = {
            102,   110,   118,   126,   134,   142,   150,   284,   308,   332,
            356,   380,   404,   428,   1530,  1570,  1610,  1650,  1690,  1730,
            1770,  2160,  2216,  2272,  2328,  2384,  2440,  2496,  4750,  4822,
            4894,  4966,  5038,  5110,  5182,  5828,  5916,  6004,  6092,  6180,
            6268,  6356,  9762,  9866,  9970,  10074, 10178, 10282, 10386, 11288,
            11408, 11528, 11648, 11768, 11888, 12008};
        prog.compile(migraphx::cpu::target{});
        auto result = prog.eval({});
        std::vector<float> actual;
        result.visit([&](auto output) { actual.assign(output.begin(), output.end()); });
        EXPECT(migraphx::verify_range(actual, expected));
    }

    {
        // Both operands transposed on their last two axes:
        // 2 * transpose(A{4,3}) * transpose(B{6,4}) + 3 * C{3,6} per batch
        migraphx::program prog;
        migraphx::shape a_shape{migraphx::shape::int8_type, {2, 2, 4, 3}};
        migraphx::shape b_shape{migraphx::shape::int8_type, {2, 2, 6, 4}};
        migraphx::shape c_shape{migraphx::shape::int32_type, {2, 2, 3, 6}};
        std::vector<int8_t> a_data(48);
        std::vector<int8_t> b_data(96);
        std::vector<int> c_data(72);
        std::iota(a_data.begin(), a_data.end(), 0);
        std::iota(b_data.begin(), b_data.end(), 0);
        std::iota(c_data.begin(), c_data.end(), 2);
        auto a  = prog.add_literal(migraphx::literal{a_shape, a_data});
        auto ta = prog.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, a);
        auto b  = prog.add_literal(migraphx::literal{b_shape, b_data});
        auto tb = prog.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, b);
        auto c  = prog.add_literal(migraphx::literal{c_shape, c_data});
        prog.add_instruction(migraphx::op::quant_dot{2, 3}, ta, tb, c);
        std::vector<int> expected = {
            90,    237,   384,   531,   678,   825,   120,   299,   478,
            657,   836,   1015,  150,   361,   572,   783,   994,   1205,
            3456,  3987,  4518,  5049,  5580,  6111,  3678,  4241,  4804,
            5367,  5930,  6493,  3900,  4495,  5090,  5685,  6280,  6875,
            11430, 12345, 13260, 14175, 15090, 16005, 11844, 12791, 13738,
            14685, 15632, 16579, 12258, 13237, 14216, 15195, 16174, 17153,
            24012, 25311, 26610, 27909, 29208, 30507, 24618, 25949, 27280,
            28611, 29942, 31273, 25224, 26587, 27950, 29313, 30676, 32039};
        prog.compile(migraphx::cpu::target{});
        auto result = prog.eval({});
        std::vector<float> actual;
        result.visit([&](auto output) { actual.assign(output.begin(), output.end()); });
        EXPECT(migraphx::verify_range(actual, expected));
    }
}
// Entry point: forwards the command-line arguments to the test framework,
// which discovers and runs every registered TEST_CASE in this file.
int main(int argc, const char* argv[]) { test::run(argc, argv); }
test/cpu_ops_test.cpp
View file @
20b1d690
...
...
@@ -5,6 +5,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/onnx.hpp>
#include "test.hpp"
...
...
@@ -527,6 +528,51 @@ TEST_CASE(exp_test)
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
TEST_CASE(erf_test)
{
    // Elementwise erf() on the CPU target, checked against reference values.
    migraphx::program prog;
    migraphx::shape in_shape{migraphx::shape::float_type, {4}};
    auto input = prog.add_literal(
        migraphx::literal{in_shape, {0.73785057, 1.58165966, -0.43597795, -0.01677432}});
    prog.add_instruction(migraphx::op::erf{}, input);
    prog.compile(migraphx::cpu::target{});
    auto result = prog.eval({});
    std::vector<float> actual;
    result.visit([&](auto output) { actual.assign(output.begin(), output.end()); });
    const std::vector<float> expected = {0.70327317, 0.97470088, -0.46247893, -0.01892602};
    EXPECT(migraphx::verify_range(actual, expected));
}
TEST_CASE(sqrt_test)
{
    // Elementwise sqrt() on the CPU target, checked against reference values.
    migraphx::program prog;
    migraphx::shape in_shape{migraphx::shape::float_type, {5}};
    auto input = prog.add_literal(migraphx::literal{
        in_shape, {1.02481645, 0.85643062, 0.03404123, 0.92791926, 0.10569184}});
    prog.add_instruction(migraphx::op::sqrt{}, input);
    prog.compile(migraphx::cpu::target{});
    auto result = prog.eval({});
    std::vector<float> actual;
    result.visit([&](auto output) { actual.assign(output.begin(), output.end()); });
    const std::vector<float> expected = {
        1.01233218, 0.92543537, 0.18450265, 0.96328566, 0.32510282};
    EXPECT(migraphx::verify_range(actual, expected));
}
TEST_CASE(sign_test)
{
    // Elementwise sign() on the CPU target: positive -> 1, negative -> -1,
    // zero -> 0.
    migraphx::program prog;
    migraphx::shape in_shape{migraphx::shape::float_type, {5}};
    auto input = prog.add_literal(migraphx::literal{
        in_shape, {1.02481645, 0.85643062, -0.03404123, -0.92791926, 0.0}});
    prog.add_instruction(migraphx::op::sign{}, input);
    prog.compile(migraphx::cpu::target{});
    auto result = prog.eval({});
    std::vector<float> actual;
    result.visit([&](auto output) { actual.assign(output.begin(), output.end()); });
    const std::vector<float> expected = {1.0, 1.0, -1.0, -1.0, 0.0};
    EXPECT(migraphx::verify_range(actual, expected));
}
TEST_CASE
(
log_test
)
{
migraphx
::
program
p
;
...
...
@@ -541,6 +587,21 @@ TEST_CASE(log_test)
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
TEST_CASE(pow_test)
{
    // Elementwise pow(base, exponent) on the CPU target.
    migraphx::program prog;
    migraphx::shape in_shape{migraphx::shape::float_type, {3}};
    auto base     = prog.add_literal(migraphx::literal{in_shape, {1, 2, 3}});
    auto exponent = prog.add_literal(migraphx::literal{in_shape, {1, 2, 3}});
    prog.add_instruction(migraphx::op::pow{}, base, exponent);
    prog.compile(migraphx::cpu::target{});
    auto result = prog.eval({});
    std::vector<float> actual;
    result.visit([&](auto output) { actual.assign(output.begin(), output.end()); });
    // 1^1, 2^2, 3^3
    const std::vector<float> expected = {1.0f, 4.0f, 27.0f};
    EXPECT(migraphx::verify_range(actual, expected));
}
TEST_CASE
(
sin_test
)
{
migraphx
::
program
p
;
...
...
@@ -929,6 +990,21 @@ TEST_CASE(maxpool_test)
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
c
));
}
TEST_CASE(softmax_simple_test)
{
    // Softmax over axis 1 of a {1, 2} tensor; the two outputs must sum to 1.
    migraphx::program prog;
    std::vector<float> input_data = {0.25, 0.75};
    std::vector<float> expected   = {0.377541, 0.622459};
    migraphx::shape in_shape{migraphx::shape::float_type, {1, 2}};
    auto input = prog.add_literal(migraphx::literal{in_shape, input_data});
    prog.add_instruction(migraphx::op::softmax{1}, input);
    prog.compile(migraphx::cpu::target{});
    auto result = prog.eval({});
    std::vector<float> actual;
    result.visit([&](auto output) { actual.assign(output.begin(), output.end()); });
    EXPECT(migraphx::verify_range(actual, expected));
}
TEST_CASE
(
softmax_test
)
{
migraphx
::
program
p
;
...
...
@@ -1002,14 +1078,13 @@ TEST_CASE(logsoftmax_test_axis_0)
-
0.99628491
,
1.04314606
,
-
1.22943315
,
0.76930403
,
0.31106618
};
std
::
vector
<
float
>
s
=
{
-
2.71138556
,
-
5.85030702
,
-
3.74063578
,
-
4.22915517
,
-
6.15821977
,
-
5.96072346
,
-
3.57208097
,
-
5.78313166
,
-
5.51435497
,
-
3.67224195
,
-
3.88393048
,
-
2.57061599
,
-
5.54431083
,
-
6.27880025
,
-
5.1878749
,
-
6.1318955
,
-
5.29178545
,
-
4.22537886
,
-
3.75693516
,
-
7.07047099
,
-
4.45763333
,
-
4.66281846
,
-
6.18290503
,
-
4.11886536
,
-
6.17408292
,
-
4.18030052
,
-
4.64570814
,
-
4.64354473
,
-
3.06629525
,
-
3.80807681
,
-
4.69162374
,
-
5.53605222
,
-
3.20969275
,
-
4.82645674
,
-
6.63942356
,
-
4.73634471
,
-
3.86003866
,
-
5.32738981
,
-
4.22249802
,
-
4.51258693
,
-
2.41455206
,
-
3.48343199
,
-
5.86215889
,
-
4.93435935
,
-
4.83713408
,
-
2.97471885
,
-
2.16666459
,
-
3.69133151
,
-
4.71640968
,
-
5.64652924
,
-
3.60709827
,
-
5.87967748
,
-
3.8809403
,
-
4.33917815
};
-
0.135261
,
-
2.843968
,
-
0.659995
,
-
0.488413
,
-
1.051857
,
-
2.812936
,
-
0.250956
,
-
0.353985
,
-
1.155980
,
-
0.603651
,
-
0.211969
,
-
0.175371
,
-
1.336552
,
-
3.885010
,
-
1.871544
,
-
0.837083
,
-
0.887745
,
-
0.433338
,
-
1.158864
,
-
4.911197
,
-
1.147972
,
-
0.666711
,
-
0.996874
,
-
0.981418
,
-
0.851145
,
-
0.853988
,
-
0.858112
,
-
2.067420
,
-
0.059956
,
-
0.727436
,
-
0.950881
,
-
0.429689
,
-
0.061906
,
-
1.505332
,
-
1.210277
,
-
0.377970
,
-
0.791448
,
-
1.655428
,
-
1.827253
,
-
0.304828
,
-
0.020762
,
-
0.167101
,
-
0.567346
,
-
0.530319
,
-
1.045094
,
-
0.376648
,
-
0.007391
,
-
0.381670
,
-
0.720302
,
-
0.460499
,
-
0.469651
,
-
0.556740
,
-
0.554628
,
-
0.551582
};
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
3
,
3
}};
auto
al
=
p
.
add_literal
(
migraphx
::
literal
{
a_shape
,
a
});
...
...
@@ -1036,14 +1111,13 @@ TEST_CASE(logsoftmax_test_axis_1)
-
0.99628491
,
1.04314606
,
-
1.22943315
,
0.76930403
,
0.31106618
};
std
::
vector
<
float
>
s
=
{
-
1.77931988
,
-
4.91824134
,
-
2.80857010
,
-
3.29708949
,
-
5.22615409
,
-
5.02865778
,
-
2.64001529
,
-
4.85106598
,
-
4.58228929
,
-
2.74017627
,
-
2.95186480
,
-
1.63855031
,
-
4.61224515
,
-
5.34673457
,
-
4.25580922
,
-
5.19982982
,
-
4.35971977
,
-
3.29331318
,
-
2.82486948
,
-
6.13840531
,
-
3.52556765
,
-
3.73075278
,
-
5.25083935
,
-
3.18679968
,
-
5.24201724
,
-
3.24823484
,
-
3.71364246
,
-
4.14309917
,
-
2.56584969
,
-
3.30763125
,
-
4.19117818
,
-
5.03560666
,
-
2.70924719
,
-
4.32601118
,
-
6.13897800
,
-
4.23589915
,
-
3.35959310
,
-
4.82694425
,
-
3.72205246
,
-
4.01214137
,
-
1.91410650
,
-
2.98298643
,
-
5.36171333
,
-
4.43391379
,
-
4.33668852
,
-
2.47427329
,
-
1.66621903
,
-
3.19088595
,
-
4.21596412
,
-
5.14608368
,
-
3.10665271
,
-
5.37923192
,
-
3.38049474
,
-
3.83873259
};
-
0.550468
,
-
2.132973
,
-
1.549746
,
-
0.650533
,
-
1.051529
,
-
2.248570
,
-
0.141017
,
-
2.028357
,
-
1.947730
,
-
1.511324
,
-
0.166597
,
-
0.379726
,
-
1.965689
,
-
1.172109
,
-
1.475721
,
-
2.700831
,
-
1.537011
,
-
0.658754
,
-
1.596017
,
-
3.353137
,
-
2.266743
,
-
1.084197
,
-
1.076214
,
-
0.406712
,
-
2.743019
,
-
0.425526
,
-
1.079083
,
-
2.139486
,
-
1.270584
,
-
1.024088
,
-
1.154231
,
-
3.201762
,
-
0.888957
,
-
0.532855
,
-
3.103583
,
-
1.221339
,
-
1.355980
,
-
3.531678
,
-
1.438510
,
-
0.975194
,
-
0.080261
,
-
1.162697
,
-
1.568557
,
-
1.398519
,
-
1.322129
,
-
0.470660
,
-
0.370953
,
-
0.907343
,
-
1.179017
,
-
3.312239
,
-
1.286363
,
-
1.586076
,
-
0.345100
,
-
0.824173
};
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
3
,
3
}};
auto
al
=
p
.
add_literal
(
migraphx
::
literal
{
a_shape
,
a
});
...
...
@@ -1070,14 +1144,13 @@ TEST_CASE(logsoftmax_test_axis_2)
-
0.99628491
,
1.04314606
,
-
1.22943315
,
0.76930403
,
0.31106618
};
std
::
vector
<
float
>
s
=
{
-
0.79763715
,
-
3.93655861
,
-
1.82688737
,
-
2.31540676
,
-
4.24447136
,
-
4.04697505
,
-
1.65833256
,
-
3.86938325
,
-
3.60060656
,
-
1.81223672
,
-
2.02392525
,
-
0.71061076
,
-
3.68430560
,
-
4.41879502
,
-
3.32786967
,
-
4.27189027
,
-
3.43178022
,
-
2.36537363
,
-
1.35498658
,
-
4.66852241
,
-
2.05568475
,
-
2.26086988
,
-
3.78095645
,
-
1.71691678
,
-
3.77213434
,
-
1.77835194
,
-
2.24375956
,
-
2.74631770
,
-
1.16906822
,
-
1.91084978
,
-
2.79439671
,
-
3.63882519
,
-
1.31246572
,
-
2.92922971
,
-
4.74219653
,
-
2.83911768
,
-
2.19738500
,
-
3.66473615
,
-
2.55984436
,
-
2.84993327
,
-
0.75189840
,
-
1.82077833
,
-
4.19950523
,
-
3.27170569
,
-
3.17448042
,
-
1.65286841
,
-
0.84481415
,
-
2.36948107
,
-
3.39455924
,
-
4.32467880
,
-
2.28524783
,
-
4.55782704
,
-
2.55908986
,
-
3.01732771
};
-
0.495957
,
-
1.031212
,
-
0.245531
,
-
2.013726
,
-
1.339125
,
-
2.465619
,
-
1.356652
,
-
0.964037
,
-
2.019250
,
-
0.214522
,
-
0.289569
,
-
0.234392
,
-
2.086591
,
-
2.684439
,
-
2.851651
,
-
2.674176
,
-
1.697424
,
-
1.889155
,
-
0.401029
,
-
3.064586
,
-
1.173030
,
-
1.306912
,
-
2.177020
,
-
0.834262
,
-
2.818177
,
-
0.174415
,
-
1.361105
,
-
1.024571
,
-
0.106766
,
-
1.167645
,
-
1.072650
,
-
2.576522
,
-
0.569261
,
-
1.207483
,
-
3.679894
,
-
2.095913
,
-
0.504264
,
-
3.039291
,
-
1.290559
,
-
1.156812
,
-
0.126453
,
-
0.551493
,
-
2.506384
,
-
2.646261
,
-
1.905195
,
-
0.206994
,
-
0.191369
,
-
0.959754
,
-
1.948685
,
-
3.671233
,
-
0.875521
,
-
3.111952
,
-
1.905644
,
-
1.6076011
};
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
3
,
3
}};
auto
al
=
p
.
add_literal
(
migraphx
::
literal
{
a_shape
,
a
});
...
...
@@ -1104,14 +1177,13 @@ TEST_CASE(logsoftmax_test_axis_3)
-
0.99628491
,
1.04314606
,
-
1.22943315
,
0.76930403
,
0.31106618
};
std
::
vector
<
float
>
s
=
{
-
0.33690375
,
-
3.47582521
,
-
1.36615397
,
-
0.27936556
,
-
2.20843016
,
-
2.01093385
,
-
0.22551114
,
-
2.43656183
,
-
2.16778514
,
-
1.57241522
,
-
1.78410375
,
-
0.47078926
,
-
1.06745881
,
-
1.80194823
,
-
0.71102288
,
-
2.30719726
,
-
1.46708721
,
-
0.40068062
,
-
0.42698261
,
-
3.74051844
,
-
1.12768078
,
-
1.07891856
,
-
2.59900513
,
-
0.53496546
,
-
2.56139951
,
-
0.56761711
,
-
1.03302473
,
-
2.09771276
,
-
0.52046328
,
-
1.26224484
,
-
1.76322959
,
-
2.60765807
,
-
0.28129860
,
-
0.81424303
,
-
2.62720985
,
-
0.72413100
,
-
0.65570381
,
-
2.12305496
,
-
1.01816317
,
-
2.48063402
,
-
0.38259915
,
-
1.45147908
,
-
1.84310238
,
-
0.91530284
,
-
0.81807757
,
-
1.31692881
,
-
0.50887455
,
-
2.03354147
,
-
1.48767160
,
-
2.41779116
,
-
0.37836019
,
-
2.56853147
,
-
0.56979429
,
-
1.02803214
};
-
0.336904
,
-
3.475825
,
-
1.366154
,
-
0.279366
,
-
2.208430
,
-
2.010934
,
-
0.225511
,
-
2.436562
,
-
2.167785
,
-
1.572415
,
-
1.784104
,
-
0.470789
,
-
1.067459
,
-
1.801948
,
-
0.711023
,
-
2.307197
,
-
1.467087
,
-
0.400681
,
-
0.426983
,
-
3.740518
,
-
1.127681
,
-
1.078919
,
-
2.599005
,
-
0.534965
,
-
2.561400
,
-
0.567617
,
-
1.033025
,
-
2.097713
,
-
0.520463
,
-
1.262245
,
-
1.763230
,
-
2.607658
,
-
0.281299
,
-
0.814243
,
-
2.627210
,
-
0.724131
,
-
0.655704
,
-
2.123055
,
-
1.018163
,
-
2.480634
,
-
0.382599
,
-
1.451479
,
-
1.843102
,
-
0.915303
,
-
0.818078
,
-
1.316929
,
-
0.508875
,
-
2.033541
,
-
1.487672
,
-
2.417791
,
-
0.378360
,
-
2.568531
,
-
0.569794
,
-
1.028032
};
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
3
,
3
}};
auto
al
=
p
.
add_literal
(
migraphx
::
literal
{
a_shape
,
a
});
...
...
@@ -1124,38 +1196,112 @@ TEST_CASE(logsoftmax_test_axis_3)
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
s
));
}
TEST_CASE
(
logsoft
max_test_
axis_4
)
TEST_CASE
(
arg
max_test_
0
)
{
migraphx
::
program
p
;
std
::
vector
<
float
>
a
=
{
1.93885877
,
-
1.20006269
,
0.90960855
,
0.42108916
,
-
1.50797544
,
-
1.31047913
,
1.07816336
,
-
1.13288733
,
-
0.86411064
,
0.97800238
,
0.76631385
,
2.07962834
,
-
0.8940665
,
-
1.62855592
,
-
0.53763057
,
-
1.48165117
,
-
0.64154112
,
0.42486547
,
0.89330917
,
-
2.42022666
,
0.192611
,
-
0.01257413
,
-
1.5326607
,
0.53137897
,
-
1.52383859
,
0.46994381
,
0.00453619
,
0.0066996
,
1.58394908
,
0.84216752
,
-
0.04137941
,
-
0.88580789
,
1.44055158
,
-
0.17621241
,
-
1.98917923
,
-
0.08610038
,
0.79020567
,
-
0.67714548
,
0.42774631
,
0.1376574
,
2.23569227
,
1.16681234
,
-
1.21191456
,
-
0.28411502
,
-
0.18688975
,
1.67552548
,
2.48357974
,
0.95891282
,
-
0.06616535
,
-
0.99628491
,
1.04314606
,
-
1.22943315
,
0.76930403
,
0.31106618
};
std
::
vector
<
float
>
data
=
{
1.2255
,
1.6834
,
-
2.0305
,
-
0.3221
,
0.4701
,
0.2583
,
0.7545
,
2.5758
,
-
1.6849
,
0.0928
,
0.9022
,
-
0.8765
,
-
0.4090
,
0.9301
,
2.0724
,
-
1.5706
,
0.4867
,
-
0.1493
,
0.6957
,
-
0.2179
,
0.7142
,
0.7177
,
0.0183
,
1.3497
};
std
::
vector
<
int64_t
>
res_gold
=
{
0
,
0
,
1
,
0
,
1
,
0
,
0
,
0
,
1
,
1
,
0
,
1
};
migraphx
::
shape
data_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
}};
auto
dl
=
p
.
add_literal
(
migraphx
::
literal
{
data_shape
,
data
});
p
.
add_instruction
(
migraphx
::
op
::
argmax
{
0
},
dl
);
p
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result
=
p
.
eval
({});
std
::
vector
<
int64_t
>
result_vec
;
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
s
=
{
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
,
0.00000000
};
EXPECT
(
migraphx
::
verify_range
(
result_vec
,
res_gold
));
}
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
3
,
3
}};
auto
al
=
p
.
add_literal
(
migraphx
::
literal
{
a_shape
,
a
});
int
axis
=
4
;
p
.
add_instruction
(
migraphx
::
op
::
logsoftmax
{
axis
},
al
);
TEST_CASE
(
argmax_test_1
)
{
migraphx
::
program
p
;
std
::
vector
<
float
>
data
=
{
1.2255
,
1.6834
,
-
2.0305
,
-
0.3221
,
0.4701
,
0.2583
,
0.7545
,
2.5758
,
-
1.6849
,
0.0928
,
0.9022
,
-
0.8765
,
-
0.4090
,
0.9301
,
2.0724
,
-
1.5706
,
0.4867
,
-
0.1493
,
0.6957
,
-
0.2179
,
0.7142
,
0.7177
,
0.0183
,
1.3497
};
std
::
vector
<
int64_t
>
res_gold
=
{
0
,
0
,
2
,
1
,
2
,
0
,
0
,
2
};
migraphx
::
shape
data_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
}};
auto
dl
=
p
.
add_literal
(
migraphx
::
literal
{
data_shape
,
data
});
p
.
add_instruction
(
migraphx
::
op
::
argmax
{
1
},
dl
);
p
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result
=
p
.
eval
({});
std
::
vector
<
float
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
s
));
std
::
vector
<
int64_t
>
result_vec
;
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
result_vec
,
res_gold
));
}
TEST_CASE
(
argmax_test_2
)
{
migraphx
::
program
p
;
std
::
vector
<
float
>
data
=
{
1.2255
,
1.6834
,
-
2.0305
,
-
0.3221
,
0.4701
,
0.2583
,
0.7545
,
2.5758
,
-
1.6849
,
0.0928
,
0.9022
,
-
0.8765
,
-
0.4090
,
0.9301
,
2.0724
,
-
1.5706
,
0.4867
,
-
0.1493
,
0.6957
,
-
0.2179
,
0.7142
,
0.7177
,
0.0183
,
1.3497
};
std
::
vector
<
int64_t
>
res_gold
=
{
1
,
3
,
2
,
2
,
2
,
3
};
migraphx
::
shape
data_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
}};
auto
dl
=
p
.
add_literal
(
migraphx
::
literal
{
data_shape
,
data
});
p
.
add_instruction
(
migraphx
::
op
::
argmax
{
2
},
dl
);
p
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result
=
p
.
eval
({});
std
::
vector
<
int64_t
>
result_vec
;
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
result_vec
,
res_gold
));
}
TEST_CASE
(
argmin_test_0
)
{
migraphx
::
program
p
;
std
::
vector
<
float
>
data
=
{
1.2255
,
1.6834
,
-
2.0305
,
-
0.3221
,
0.4701
,
0.2583
,
0.7545
,
2.5758
,
-
1.6849
,
0.0928
,
0.9022
,
-
0.8765
,
-
0.4090
,
0.9301
,
2.0724
,
-
1.5706
,
0.4867
,
-
0.1493
,
0.6957
,
-
0.2179
,
0.7142
,
0.7177
,
0.0183
,
1.3497
};
std
::
vector
<
int64_t
>
res_gold
=
{
1
,
1
,
0
,
1
,
0
,
1
,
1
,
1
,
0
,
0
,
1
,
0
};
migraphx
::
shape
data_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
}};
auto
dl
=
p
.
add_literal
(
migraphx
::
literal
{
data_shape
,
data
});
p
.
add_instruction
(
migraphx
::
op
::
argmin
{
0
},
dl
);
p
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result
=
p
.
eval
({});
std
::
vector
<
int64_t
>
result_vec
;
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
result_vec
,
res_gold
));
}
TEST_CASE
(
argmin_test_1
)
{
migraphx
::
program
p
;
std
::
vector
<
float
>
data
=
{
1.2255
,
1.6834
,
-
2.0305
,
-
0.3221
,
0.4701
,
0.2583
,
0.7545
,
2.5758
,
-
1.6849
,
0.0928
,
0.9022
,
-
0.8765
,
-
0.4090
,
0.9301
,
2.0724
,
-
1.5706
,
0.4867
,
-
0.1493
,
0.6957
,
-
0.2179
,
0.7142
,
0.7177
,
0.0183
,
1.3497
};
std
::
vector
<
int64_t
>
res_gold
=
{
2
,
2
,
0
,
2
,
0
,
1
,
2
,
0
};
migraphx
::
shape
data_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
}};
auto
dl
=
p
.
add_literal
(
migraphx
::
literal
{
data_shape
,
data
});
p
.
add_instruction
(
migraphx
::
op
::
argmin
{
1
},
dl
);
p
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result
=
p
.
eval
({});
std
::
vector
<
int64_t
>
result_vec
;
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
result_vec
,
res_gold
));
}
TEST_CASE
(
argmin_test_2
)
{
migraphx
::
program
p
;
std
::
vector
<
float
>
data
=
{
1.2255
,
1.6834
,
-
2.0305
,
-
0.3221
,
0.4701
,
0.2583
,
0.7545
,
2.5758
,
-
1.6849
,
0.0928
,
0.9022
,
-
0.8765
,
-
0.4090
,
0.9301
,
2.0724
,
-
1.5706
,
0.4867
,
-
0.1493
,
0.6957
,
-
0.2179
,
0.7142
,
0.7177
,
0.0183
,
1.3497
};
std
::
vector
<
int64_t
>
res_gold
=
{
2
,
1
,
0
,
3
,
3
,
2
};
migraphx
::
shape
data_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
}};
auto
dl
=
p
.
add_literal
(
migraphx
::
literal
{
data_shape
,
data
});
p
.
add_instruction
(
migraphx
::
op
::
argmin
{
2
},
dl
);
p
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result
=
p
.
eval
({});
std
::
vector
<
int64_t
>
result_vec
;
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
result_vec
,
res_gold
));
}
TEST_CASE
(
conv2d_test
)
...
...
@@ -1338,6 +1484,107 @@ TEST_CASE(conv2d_padding_stride_test)
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
s
));
}
TEST_CASE
(
quant_conv2d_test
)
{
migraphx
::
program
p
;
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
int8_type
,
{
2
,
3
,
4
,
4
}};
std
::
vector
<
int8_t
>
a
(
2
*
3
*
4
*
4
);
std
::
iota
(
a
.
begin
(),
a
.
end
(),
0
);
auto
al
=
p
.
add_literal
(
migraphx
::
literal
{
a_shape
,
a
});
migraphx
::
shape
c_shape
{
migraphx
::
shape
::
int8_type
,
{
2
,
3
,
3
,
3
}};
std
::
vector
<
int8_t
>
c
(
2
*
3
*
3
*
3
);
std
::
iota
(
c
.
begin
(),
c
.
end
(),
0
);
auto
cl
=
p
.
add_literal
(
migraphx
::
literal
{
c_shape
,
c
});
p
.
add_instruction
(
migraphx
::
op
::
quant_convolution
{},
al
,
cl
);
p
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result
=
p
.
eval
({});
std
::
vector
<
int32_t
>
s
=
{
10197
,
10548
,
11601
,
11952
,
25506
,
26586
,
29826
,
30906
,
27045
,
27396
,
28449
,
28800
,
77346
,
78426
,
81666
,
82746
};
std
::
vector
<
int32_t
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
s
));
}
TEST_CASE
(
quant_conv2d_padding_test
)
{
migraphx
::
program
p
;
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
int8_type
,
{
2
,
3
,
4
,
4
}};
std
::
vector
<
int8_t
>
a
(
2
*
3
*
4
*
4
);
std
::
iota
(
a
.
begin
(),
a
.
end
(),
0
);
auto
al
=
p
.
add_literal
(
migraphx
::
literal
{
a_shape
,
a
});
migraphx
::
shape
c_shape
{
migraphx
::
shape
::
int8_type
,
{
2
,
3
,
3
,
3
}};
std
::
vector
<
int8_t
>
c
(
2
*
3
*
3
*
3
);
std
::
iota
(
c
.
begin
(),
c
.
end
(),
0
);
auto
cl
=
p
.
add_literal
(
migraphx
::
literal
{
c_shape
,
c
});
p
.
add_instruction
(
migraphx
::
op
::
quant_convolution
{{{
1
,
1
}},
{{
1
,
1
}}},
al
,
cl
);
p
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result
=
p
.
eval
({});
std
::
vector
<
int32_t
>
s
=
{
4521
,
6753
,
7014
,
4635
,
6858
,
10197
,
10548
,
6939
,
7830
,
11601
,
11952
,
7839
,
5007
,
7383
,
7590
,
4953
,
10515
,
15987
,
16734
,
11277
,
16821
,
25506
,
26586
,
17874
,
19737
,
29826
,
30906
,
20718
,
13593
,
20505
,
21198
,
14187
,
13161
,
19281
,
19542
,
12699
,
18522
,
27045
,
27396
,
17739
,
19494
,
28449
,
28800
,
18639
,
11919
,
17319
,
17526
,
11289
,
34707
,
51843
,
52590
,
34893
,
51813
,
77346
,
78426
,
52002
,
54729
,
81666
,
82746
,
54846
,
36057
,
53769
,
54462
,
36075
};
std
::
vector
<
int32_t
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
s
));
}
TEST_CASE
(
quant_conv2d_padding_stride_test
)
{
migraphx
::
program
p
;
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
int8_type
,
{
2
,
3
,
4
,
4
}};
std
::
vector
<
int8_t
>
a
(
2
*
3
*
4
*
4
);
std
::
iota
(
a
.
begin
(),
a
.
end
(),
0
);
auto
al
=
p
.
add_literal
(
migraphx
::
literal
{
a_shape
,
a
});
migraphx
::
shape
c_shape
{
migraphx
::
shape
::
int8_type
,
{
2
,
3
,
3
,
3
}};
std
::
vector
<
int8_t
>
c
(
2
*
3
*
3
*
3
);
std
::
iota
(
c
.
begin
(),
c
.
end
(),
0
);
auto
cl
=
p
.
add_literal
(
migraphx
::
literal
{
c_shape
,
c
});
p
.
add_instruction
(
migraphx
::
op
::
quant_convolution
{{{
1
,
1
}},
{{
2
,
2
}}},
al
,
cl
);
p
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result
=
p
.
eval
({});
std
::
vector
<
int32_t
>
s
=
{
4521
,
7014
,
7830
,
11952
,
10515
,
16734
,
19737
,
30906
,
13161
,
19542
,
19494
,
28800
,
34707
,
52590
,
54729
,
82746
};
std
::
vector
<
int32_t
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
s
));
}
TEST_CASE
(
transpose_test
)
{
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
2
,
3
}};
...
...
@@ -1574,7 +1821,7 @@ TEST_CASE(fp32_fp16_test)
auto
test_case
=
[
&
](
std
::
vector
<
std
::
string
>&&
op_names
)
{
std
::
vector
<
float
>
gold_res
=
{
2.0
,
4.0
,
6.0
,
8.0
,
10.0
,
12.0
};
auto
p
=
create_program
();
migraphx
::
quantize
(
p
,
op_names
);
migraphx
::
quantize
_fp16
(
p
,
op_names
);
p
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result
=
p
.
eval
({});
std
::
vector
<
float
>
res
;
...
...
@@ -1603,4 +1850,238 @@ TEST_CASE(clip_test)
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
TEST_CASE
(
reduce_sum_axis0
)
{
migraphx
::
program
p
;
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
3
,
2
,
2
}};
auto
input
=
migraphx
::
literal
{
s
,
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
}};
auto
l0
=
p
.
add_literal
(
input
);
p
.
add_instruction
(
migraphx
::
op
::
reduce_sum
{{
0
}},
l0
);
p
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result
=
p
.
eval
({});
std
::
vector
<
float
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
{
15
,
18
,
21
,
24
};
EXPECT
(
results_vector
==
gold
);
}
TEST_CASE
(
reduce_sum_axis1
)
{
migraphx
::
program
p
;
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
3
,
2
,
2
}};
auto
input
=
migraphx
::
literal
{
s
,
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
}};
auto
l0
=
p
.
add_literal
(
input
);
p
.
add_instruction
(
migraphx
::
op
::
reduce_sum
{{
1
}},
l0
);
p
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result
=
p
.
eval
({});
std
::
vector
<
float
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
{
4
,
6
,
12
,
14
,
20
,
22
};
EXPECT
(
results_vector
==
gold
);
}
TEST_CASE
(
reduce_sum_axis2
)
{
migraphx
::
program
p
;
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
3
,
2
,
2
}};
auto
input
=
migraphx
::
literal
{
s
,
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
}};
auto
l0
=
p
.
add_literal
(
input
);
p
.
add_instruction
(
migraphx
::
op
::
reduce_sum
{{
2
}},
l0
);
p
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result
=
p
.
eval
({});
std
::
vector
<
float
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
{
3
,
7
,
11
,
15
,
19
,
23
};
EXPECT
(
results_vector
==
gold
);
}
TEST_CASE
(
reduce_sum_axis02
)
{
migraphx
::
program
p
;
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
3
,
2
,
2
}};
auto
input
=
migraphx
::
literal
{
s
,
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
}};
auto
l0
=
p
.
add_literal
(
input
);
p
.
add_instruction
(
migraphx
::
op
::
reduce_sum
{{
0
,
2
}},
l0
);
p
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result
=
p
.
eval
({});
std
::
vector
<
float
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
{
33
,
45
};
EXPECT
(
results_vector
==
gold
);
}
TEST_CASE
(
reduce_sum_axis12
)
{
migraphx
::
program
p
;
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
3
,
2
,
2
}};
auto
input
=
migraphx
::
literal
{
s
,
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
}};
auto
l0
=
p
.
add_literal
(
input
);
p
.
add_instruction
(
migraphx
::
op
::
reduce_sum
{{
1
,
2
}},
l0
);
p
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result
=
p
.
eval
({});
std
::
vector
<
float
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
{
10
,
26
,
42
};
EXPECT
(
results_vector
==
gold
);
}
TEST_CASE
(
rsqrt_test
)
{
migraphx
::
program
p
;
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
3
}};
auto
l
=
p
.
add_literal
(
migraphx
::
literal
{
s
,
{
4.0
,
16.0
,
64.0
}});
p
.
add_instruction
(
migraphx
::
op
::
rsqrt
{},
l
);
p
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result
=
p
.
eval
({});
std
::
vector
<
float
>
results_vector
(
3
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
=
{
0.5
,
0.25
,
0.125
};
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
TEST_CASE
(
reduce_mean_axis1
)
{
migraphx
::
program
p
;
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
3
,
2
,
2
}};
auto
input
=
migraphx
::
literal
{
s
,
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
}};
auto
l0
=
p
.
add_literal
(
input
);
p
.
add_instruction
(
migraphx
::
op
::
reduce_mean
{{
1
}},
l0
);
p
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result
=
p
.
eval
({});
std
::
vector
<
float
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
{
2
,
3
,
6
,
7
,
10
,
11
};
EXPECT
(
results_vector
==
gold
);
}
TEST_CASE
(
reduce_mean_axis2
)
{
migraphx
::
program
p
;
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
3
,
2
,
2
}};
auto
input
=
migraphx
::
literal
{
s
,
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
}};
auto
l0
=
p
.
add_literal
(
input
);
p
.
add_instruction
(
migraphx
::
op
::
reduce_mean
{{
2
}},
l0
);
p
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result
=
p
.
eval
({});
std
::
vector
<
float
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
{
1.5
f
,
3.5
f
,
5.5
f
,
7.5
f
,
9.5
f
,
11.5
f
};
EXPECT
(
results_vector
==
gold
);
}
TEST_CASE
(
reduce_mean_axis02
)
{
migraphx
::
program
p
;
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
3
,
2
,
2
}};
auto
input
=
migraphx
::
literal
{
s
,
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
}};
auto
l0
=
p
.
add_literal
(
input
);
p
.
add_instruction
(
migraphx
::
op
::
reduce_mean
{{
0
,
2
}},
l0
);
p
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result
=
p
.
eval
({});
std
::
vector
<
float
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
{
5.5
,
7.5
};
EXPECT
(
results_vector
==
gold
);
}
TEST_CASE
(
reduce_mean_axis12
)
{
migraphx
::
program
p
;
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
3
,
2
,
2
}};
auto
input
=
migraphx
::
literal
{
s
,
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
}};
auto
l0
=
p
.
add_literal
(
input
);
p
.
add_instruction
(
migraphx
::
op
::
reduce_mean
{{
1
,
2
}},
l0
);
p
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result
=
p
.
eval
({});
std
::
vector
<
float
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
{
2.5
f
,
6.5
f
,
10.5
f
};
EXPECT
(
results_vector
==
gold
);
}
TEST_CASE
(
reduce_mean_int
)
{
migraphx
::
program
p
;
migraphx
::
shape
s
{
migraphx
::
shape
::
int32_type
,
{
3
,
2
,
2
}};
auto
input
=
migraphx
::
literal
{
s
,
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
}};
auto
l0
=
p
.
add_literal
(
input
);
p
.
add_instruction
(
migraphx
::
op
::
reduce_mean
{{
1
,
2
}},
l0
);
p
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result
=
p
.
eval
({});
std
::
vector
<
int
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
int
>
gold
{
2
,
6
,
10
};
EXPECT
(
results_vector
==
gold
);
}
TEST_CASE
(
sqdiff_test
)
{
migraphx
::
program
p
;
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
3
}};
auto
l1
=
p
.
add_literal
(
migraphx
::
literal
{
s
,
{
-
1
,
0
,
1
}});
auto
l2
=
p
.
add_literal
(
migraphx
::
literal
{
s
,
{
1
,
2
,
3
}});
p
.
add_instruction
(
migraphx
::
op
::
sqdiff
{},
l1
,
l2
);
p
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result
=
p
.
eval
({});
std
::
vector
<
float
>
results_vector
(
3
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
=
{
4
,
4
,
4
};
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
TEST_CASE
(
round_test
)
{
migraphx
::
program
p
;
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
9
}};
auto
l
=
p
.
add_literal
(
migraphx
::
literal
{
s
,
{
1.1
,
1.5
,
1.6
,
-
1.1
,
-
1.5
,
-
1.6
,
0.0
,
2.0
,
-
2.0
}});
p
.
add_instruction
(
migraphx
::
op
::
round
{},
l
);
p
.
compile
(
migraphx
::
cpu
::
target
{});
auto
result
=
p
.
eval
({});
std
::
vector
<
float
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
for
(
auto
v
:
results_vector
)
{
std
::
cout
<<
v
<<
"
\t
"
;
}
std
::
cout
<<
std
::
endl
;
std
::
vector
<
float
>
gold
=
{
1.0
,
2.0
,
2.0
,
-
1.0
,
-
2.0
,
-
2.0
,
0.0
,
2.0
,
-
2.0
};
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
TEST_CASE
(
op_capture
)
{
migraphx
::
program
p
;
migraphx
::
shape
s1
{
migraphx
::
shape
::
float_type
,
{
3
,
3
}};
migraphx
::
shape
s2
{
migraphx
::
shape
::
float_type
,
{
3
,
6
}};
std
::
vector
<
float
>
d1
(
s1
.
elements
());
std
::
vector
<
float
>
d2
(
s2
.
elements
());
std
::
iota
(
d1
.
begin
(),
d1
.
end
(),
0.0
f
);
std
::
iota
(
d2
.
begin
(),
d2
.
end
(),
0.0
f
);
auto
p1
=
p
.
add_literal
(
s1
,
d1
);
auto
p2
=
p
.
add_literal
(
s1
,
d1
);
auto
pb
=
p
.
add_literal
(
s2
,
d2
);
auto
pc
=
p
.
add_literal
(
s2
,
d2
);
auto
pa
=
p
.
add_instruction
(
migraphx
::
op
::
add
{},
p1
,
p2
);
auto
ps
=
p
.
add_instruction
(
migraphx
::
op
::
dot
{},
pa
,
pb
,
pc
);
p
.
add_instruction
(
migraphx
::
op
::
dot
{},
pa
,
ps
);
migraphx
::
program
capture_p
=
p
;
migraphx
::
target
t
=
migraphx
::
cpu
::
target
{};
migraphx
::
capture_arguments
(
capture_p
,
t
,
{
"dot"
});
p
.
compile
(
migraphx
::
cpu
::
target
{});
capture_p
.
compile
(
migraphx
::
cpu
::
target
{});
auto
cap_res
=
capture_p
.
eval
({});
auto
res
=
p
.
eval
({});
std
::
vector
<
float
>
vec
;
std
::
vector
<
float
>
cap_vec
;
cap_res
.
visit
([
&
](
auto
output
)
{
cap_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
res
.
visit
([
&
](
auto
output
)
{
vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
vec
,
cap_vec
));
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/common_subexpression_
elimination_
test.cpp
→
test/
eliminate_
common_subexpression_test.cpp
View file @
20b1d690
#include <migraphx/common_subexpression
_elimination
.hpp>
#include <migraphx/
eliminate_
common_subexpression.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/op/add.hpp>
#include <basic_ops.hpp>
...
...
@@ -9,7 +9,7 @@ struct cse_target
std
::
string
name
()
const
{
return
"dce"
;
}
std
::
vector
<
migraphx
::
pass
>
get_passes
(
migraphx
::
context
&
)
const
{
return
{
migraphx
::
common_subexpression
_elimination
{},
migraphx
::
dead_code_elimination
{}};
return
{
migraphx
::
eliminate_
common_subexpression
{},
migraphx
::
dead_code_elimination
{}};
}
migraphx
::
context
get_context
()
const
{
return
{};
}
};
...
...
test/eliminate_contiguous_test.cpp
View file @
20b1d690
...
...
@@ -22,7 +22,7 @@ struct eliminate_contiguous_target
TEST_CASE
(
standard_op
)
{
migraphx
::
program
p
;
auto
l
=
p
.
add_
literal
(
get_2x2
()
);
auto
l
=
p
.
add_
parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
2
,
2
}}
);
auto
t
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
1
,
0
}},
l
);
auto
c
=
p
.
add_instruction
(
migraphx
::
op
::
contiguous
{},
t
);
p
.
add_instruction
(
pass_standard_op
{},
c
);
...
...
@@ -31,18 +31,40 @@ TEST_CASE(standard_op)
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
count
);
}
TEST_CASE
(
non_
standard_op
)
TEST_CASE
(
standard_op
_const
)
{
migraphx
::
program
p
;
auto
l
=
p
.
add_literal
(
get_2x2
());
auto
t
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
1
,
0
}},
l
);
auto
c
=
p
.
add_instruction
(
migraphx
::
op
::
contiguous
{},
t
);
p
.
add_instruction
(
pass_standard_op
{},
c
);
p
.
compile
(
eliminate_contiguous_target
{});
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
2
);
}
TEST_CASE
(
non_standard_op
)
{
migraphx
::
program
p
;
auto
l
=
p
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
2
,
2
}});
auto
t
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
1
,
0
}},
l
);
auto
c
=
p
.
add_instruction
(
migraphx
::
op
::
contiguous
{},
t
);
p
.
add_instruction
(
pass_op
{},
c
);
auto
count
=
std
::
distance
(
p
.
begin
(),
p
.
end
());
p
.
compile
(
eliminate_contiguous_target
{});
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
count
);
}
TEST_CASE
(
non_standard_op_const
)
{
migraphx
::
program
p
;
auto
l
=
p
.
add_literal
(
get_2x2
());
auto
t
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
1
,
0
}},
l
);
auto
c
=
p
.
add_instruction
(
migraphx
::
op
::
contiguous
{},
t
);
p
.
add_instruction
(
pass_op
{},
c
);
p
.
compile
(
eliminate_contiguous_target
{});
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
2
);
}
TEST_CASE
(
transpose_gemm
)
{
migraphx
::
program
p
;
...
...
@@ -59,7 +81,7 @@ TEST_CASE(transpose_gemm)
TEST_CASE
(
transpose_standard_op
)
{
migraphx
::
program
p
;
auto
l
=
p
.
add_
literal
(
get_2x2
()
);
auto
l
=
p
.
add_
parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
2
,
2
}}
);
auto
t
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
1
,
0
}},
l
);
auto
c
=
p
.
add_instruction
(
migraphx
::
op
::
contiguous
{},
t
);
auto
sn
=
p
.
add_instruction
(
migraphx
::
op
::
sin
{},
c
);
...
...
@@ -69,6 +91,18 @@ TEST_CASE(transpose_standard_op)
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
count
);
}
TEST_CASE
(
transpose_standard_op_const
)
{
migraphx
::
program
p
;
auto
l
=
p
.
add_literal
(
get_2x2
());
auto
t
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
1
,
0
}},
l
);
auto
c
=
p
.
add_instruction
(
migraphx
::
op
::
contiguous
{},
t
);
auto
sn
=
p
.
add_instruction
(
migraphx
::
op
::
sin
{},
c
);
p
.
add_instruction
(
pass_standard_op
{},
sn
);
p
.
compile
(
eliminate_contiguous_target
{});
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
3
);
}
TEST_CASE
(
no_packed_unary_op
)
{
migraphx
::
program
p
;
...
...
test/eliminate_pad_test.cpp
View file @
20b1d690
...
...
@@ -83,23 +83,4 @@ TEST_CASE(rewrite_test_asymmetric)
p
.
begin
(),
p
.
end
(),
[](
const
migraphx
::
instruction
&
ins
)
{
return
ins
.
name
()
==
"pad"
;
}));
}
TEST_CASE
(
rewrite_test_same_padding
)
{
migraphx
::
program
p
;
size_t
img_dim
[
2
]
=
{
2
,
2
};
size_t
channels
=
1
;
std
::
vector
<
int32_t
>
input
(
channels
*
img_dim
[
0
]
*
img_dim
[
1
]);
std
::
iota
(
input
.
begin
(),
input
.
end
(),
0
);
migraphx
::
shape
s_img
{
migraphx
::
shape
::
int32_type
,
{
1
,
channels
,
img_dim
[
0
],
img_dim
[
1
]}};
auto
l_img
=
p
.
add_literal
(
migraphx
::
literal
{
s_img
,
input
});
auto
padded_img
=
p
.
add_instruction
(
migraphx
::
op
::
pad
{{
0
,
0
,
1
,
1
,
0
,
0
,
1
,
1
}},
l_img
);
create_conv
(
padded_img
,
channels
,
p
,
migraphx
::
op
::
padding_mode_t
::
same
);
p
.
compile
(
eliminate_pad_target
{});
EXPECT
(
std
::
any_of
(
p
.
begin
(),
p
.
end
(),
[](
const
migraphx
::
instruction
&
ins
)
{
return
ins
.
name
()
==
"pad"
;
}));
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/gpu/
miopen
.cpp
→
test/gpu/
ops_test
.cpp
View file @
20b1d690
...
...
@@ -243,6 +243,43 @@ struct test_exp : verify_program<test_exp>
}
};
struct
test_erf
:
verify_program
<
test_erf
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
6
}};
auto
param
=
p
.
add_parameter
(
"x"
,
s
);
p
.
add_instruction
(
migraphx
::
op
::
erf
{},
param
);
return
p
;
}
};
struct
test_sqrt
:
verify_program
<
test_sqrt
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
6
}};
auto
param
=
p
.
add_parameter
(
"x"
,
s
);
auto
param_abs
=
p
.
add_instruction
(
migraphx
::
op
::
abs
{},
param
);
p
.
add_instruction
(
migraphx
::
op
::
sqrt
{},
param_abs
);
return
p
;
}
};
struct
test_sign
:
verify_program
<
test_sign
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
migraphx
::
shape
s
{
migraphx
::
shape
::
double_type
,
{
2
,
3
,
4
,
6
}};
auto
param
=
p
.
add_parameter
(
"x"
,
s
);
p
.
add_instruction
(
migraphx
::
op
::
sign
{},
param
);
return
p
;
}
};
struct
test_log
:
verify_program
<
test_log
>
{
migraphx
::
program
create_program
()
const
...
...
@@ -255,6 +292,20 @@ struct test_log : verify_program<test_log>
}
};
struct
test_pow
:
verify_program
<
test_pow
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
6
}};
std
::
vector
<
float
>
vec_e
(
s
.
elements
(),
2.0
f
);
auto
b
=
p
.
add_parameter
(
"x"
,
s
);
auto
e
=
p
.
add_literal
(
migraphx
::
literal
(
s
,
vec_e
));
p
.
add_instruction
(
migraphx
::
op
::
pow
{},
b
,
e
);
return
p
;
}
};
struct
test_sin
:
verify_program
<
test_sin
>
{
migraphx
::
program
create_program
()
const
...
...
@@ -451,6 +502,24 @@ struct test_triadd2 : verify_program<test_triadd2>
}
};
// Verifies a broadcasted multiply followed by a broadcasted add (x * a + b).
struct test_mul_add : verify_program<test_mul_add>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        const migraphx::shape data_shape{migraphx::shape::float_type, {2, 3}};
        const migraphx::shape vec_shape{migraphx::shape::float_type, {3}};
        auto x = prog.add_parameter("x", data_shape);
        auto a = prog.add_parameter("a", vec_shape);
        auto b = prog.add_parameter("b", vec_shape);
        // Broadcast the 1-D scale/bias vectors along axis 1 to match x.
        auto scale = prog.add_instruction(migraphx::op::broadcast{1, data_shape.lens()}, a);
        auto bias  = prog.add_instruction(migraphx::op::broadcast{1, data_shape.lens()}, b);
        auto scaled = prog.add_instruction(migraphx::op::mul{}, x, scale);
        prog.add_instruction(migraphx::op::add{}, scaled, bias);
        return prog;
    }
};
struct
test_add_broadcast
:
verify_program
<
test_add_broadcast
>
{
migraphx
::
program
create_program
()
const
...
...
@@ -569,13 +638,45 @@ struct test_sub2 : verify_program<test_sub2>
}
};
struct
test_
softmax
:
verify_program
<
test_
softmax
>
struct
test_
div
:
verify_program
<
test_
div
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
x
=
p
.
add_parameter
(
"x"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
5
,
3
,
4
,
2
}});
p
.
add_instruction
(
migraphx
::
op
::
softmax
{},
x
);
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
3
}};
auto
x
=
p
.
add_parameter
(
"x"
,
s
);
auto
y
=
p
.
add_parameter
(
"y"
,
s
);
auto
z
=
p
.
add_parameter
(
"z"
,
s
);
auto
diff
=
p
.
add_instruction
(
migraphx
::
op
::
div
{},
x
,
y
);
p
.
add_instruction
(
migraphx
::
op
::
div
{},
diff
,
z
);
return
p
;
}
};
// Verifies division with a broadcasted second divisor: (x / y) / broadcast(z).
struct test_div2 : verify_program<test_div2>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        const migraphx::shape data_shape{migraphx::shape::float_type, {2, 3}};
        const migraphx::shape vec_shape{migraphx::shape::float_type, {3}};
        auto x = prog.add_parameter("x", data_shape);
        auto y = prog.add_parameter("y", data_shape);
        auto z = prog.add_parameter("z", vec_shape);
        auto z_bcast  = prog.add_instruction(migraphx::op::broadcast{1, data_shape.lens()}, z);
        auto quotient = prog.add_instruction(migraphx::op::div{}, x, y);
        prog.add_instruction(migraphx::op::div{}, quotient, z_bcast);
        return prog;
    }
};
// Verifies softmax applied along axis 0.
struct test_softmax1 : verify_program<test_softmax1>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        const migraphx::shape in_shape{migraphx::shape::float_type, {5, 3, 3, 4}};
        auto input = prog.add_parameter("x", in_shape);
        prog.add_instruction(migraphx::op::softmax{0}, input);
        return prog;
    }
};
...
...
@@ -592,6 +693,53 @@ struct test_softmax2 : verify_program<test_softmax2>
}
};
// Softmax verification parameterized over reduction axis and element type.
template <int Axis, migraphx::shape::type_t T>
struct test_softmax : verify_program<test_softmax<Axis, T>>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        const migraphx::shape data_shape{T, {512, 4, 1067, 6}};
        auto input = prog.add_parameter("0", data_shape);
        prog.add_instruction(migraphx::op::softmax{Axis}, input);
        return prog;
    }
};

// Explicit instantiations covering each axis across float/double/half.
template struct test_softmax<0, migraphx::shape::float_type>;
template struct test_softmax<2, migraphx::shape::float_type>;
template struct test_softmax<1, migraphx::shape::double_type>;
template struct test_softmax<3, migraphx::shape::double_type>;
template struct test_softmax<0, migraphx::shape::half_type>;
template struct test_softmax<1, migraphx::shape::half_type>;
template struct test_softmax<2, migraphx::shape::half_type>;
template struct test_softmax<3, migraphx::shape::half_type>;
// Verification for argmax/argmin along a given axis; the trailing dimension
// of 1025 exercises the non-power-of-two reduction path.
template <class T, int Axis>
struct test_arg_ops : verify_program<test_arg_ops<T, Axis>>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        const migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4, 1025}};
        auto input = prog.add_parameter("data", data_shape);
        prog.add_instruction(T{Axis}, input);
        return prog;
    }
};

// Explicit instantiations for every axis of both argmax and argmin.
template struct test_arg_ops<migraphx::op::argmax, 0>;
template struct test_arg_ops<migraphx::op::argmax, 1>;
template struct test_arg_ops<migraphx::op::argmax, 2>;
template struct test_arg_ops<migraphx::op::argmax, 3>;
template struct test_arg_ops<migraphx::op::argmin, 0>;
template struct test_arg_ops<migraphx::op::argmin, 1>;
template struct test_arg_ops<migraphx::op::argmin, 2>;
template struct test_arg_ops<migraphx::op::argmin, 3>;
struct
test_conv
:
verify_program
<
test_conv
>
{
migraphx
::
program
create_program
()
const
...
...
@@ -679,6 +827,77 @@ struct test_add_relu : verify_program<test_add_relu>
}
};
// Verifies a fused add + sigmoid sequence.
struct test_add_sigmoid : verify_program<test_add_sigmoid>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        const migraphx::shape in_shape{migraphx::shape::float_type, {4, 3, 3, 3}};
        auto x   = prog.add_parameter("x", in_shape);
        auto y   = prog.add_parameter("y", in_shape);
        auto sum = prog.add_instruction(migraphx::op::add{}, x, y);
        prog.add_instruction(migraphx::op::sigmoid{}, sum);
        return prog;
    }
};
// Verifies a fused add + tanh sequence.
struct test_add_tanh : verify_program<test_add_tanh>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        const migraphx::shape in_shape{migraphx::shape::float_type, {4, 3, 3, 3}};
        auto x   = prog.add_parameter("x", in_shape);
        auto y   = prog.add_parameter("y", in_shape);
        auto sum = prog.add_instruction(migraphx::op::add{}, x, y);
        prog.add_instruction(migraphx::op::tanh{}, sum);
        return prog;
    }
};
// Verifies a three-way add (x + y + z) followed by relu.
struct test_triadd_relu : verify_program<test_triadd_relu>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        const migraphx::shape in_shape{migraphx::shape::float_type, {4, 3, 3, 3}};
        auto x = prog.add_parameter("x", in_shape);
        auto y = prog.add_parameter("y", in_shape);
        auto z = prog.add_parameter("z", in_shape);
        auto partial = prog.add_instruction(migraphx::op::add{}, x, y);
        auto total   = prog.add_instruction(migraphx::op::add{}, partial, z);
        prog.add_instruction(migraphx::op::relu{}, total);
        return prog;
    }
};
// Verifies a three-way add (x + y + z) followed by sigmoid.
struct test_triadd_sigmoid : verify_program<test_triadd_sigmoid>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        const migraphx::shape in_shape{migraphx::shape::float_type, {4, 3, 3, 3}};
        auto x = prog.add_parameter("x", in_shape);
        auto y = prog.add_parameter("y", in_shape);
        auto z = prog.add_parameter("z", in_shape);
        auto partial = prog.add_instruction(migraphx::op::add{}, x, y);
        auto total   = prog.add_instruction(migraphx::op::add{}, partial, z);
        prog.add_instruction(migraphx::op::sigmoid{}, total);
        return prog;
    }
};
// Verifies a three-way add (x + y + z) followed by tanh.
struct test_triadd_tanh : verify_program<test_triadd_tanh>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        const migraphx::shape in_shape{migraphx::shape::float_type, {4, 3, 3, 3}};
        auto x = prog.add_parameter("x", in_shape);
        auto y = prog.add_parameter("y", in_shape);
        auto z = prog.add_parameter("z", in_shape);
        auto partial = prog.add_instruction(migraphx::op::add{}, x, y);
        auto total   = prog.add_instruction(migraphx::op::add{}, partial, z);
        prog.add_instruction(migraphx::op::tanh{}, total);
        return prog;
    }
};
struct
test_sigmoid
:
verify_program
<
test_sigmoid
>
{
migraphx
::
program
create_program
()
const
...
...
@@ -1238,6 +1457,114 @@ struct gemm_multi_3args_alpha0 : verify_program<gemm_multi_3args_alpha0>
}
};
// int8 GEMM with accumulator argument: C = A(2x8) * B(8x7) + C(2x7, int32).
struct quant_dot_3args_1 : verify_program<quant_dot_3args_1>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        const migraphx::shape a_shape{migraphx::shape::int8_type, {2, 8}};
        const migraphx::shape b_shape{migraphx::shape::int8_type, {8, 7}};
        const migraphx::shape c_shape{migraphx::shape::int32_type, {2, 7}};
        auto a = prog.add_parameter("a", a_shape);
        auto b = prog.add_parameter("b", b_shape);
        auto c = prog.add_parameter("c", c_shape);
        prog.add_instruction(migraphx::op::quant_dot{}, a, b, c);
        return prog;
    }
};
// int8 GEMM with a transposed A input and alpha=1, beta=3.
struct quant_dot_3args_2 : verify_program<quant_dot_3args_2>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        const migraphx::shape a_shape{migraphx::shape::int8_type, {8, 2}};
        const migraphx::shape b_shape{migraphx::shape::int8_type, {8, 7}};
        const migraphx::shape c_shape{migraphx::shape::int32_type, {2, 7}};
        auto a   = prog.add_parameter("a", a_shape);
        auto a_t = prog.add_instruction(migraphx::op::transpose{{1, 0}}, a);
        auto b   = prog.add_parameter("b", b_shape);
        auto c   = prog.add_parameter("c", c_shape);
        prog.add_instruction(migraphx::op::quant_dot{1, 3}, a_t, b, c);
        return prog;
    }
};
// int8 GEMM with a transposed B input and alpha=2, beta=3.
struct quant_dot_3args_3 : verify_program<quant_dot_3args_3>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        const migraphx::shape a_shape{migraphx::shape::int8_type, {2, 8}};
        const migraphx::shape b_shape{migraphx::shape::int8_type, {7, 8}};
        const migraphx::shape c_shape{migraphx::shape::int32_type, {2, 7}};
        auto a   = prog.add_parameter("a", a_shape);
        auto b   = prog.add_parameter("b", b_shape);
        auto b_t = prog.add_instruction(migraphx::op::transpose{{1, 0}}, b);
        auto c   = prog.add_parameter("c", c_shape);
        prog.add_instruction(migraphx::op::quant_dot{2, 3}, a, b_t, c);
        return prog;
    }
};
// int8 GEMM with both inputs transposed and alpha=3, beta=2.
struct quant_dot_3args_4 : verify_program<quant_dot_3args_4>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        const migraphx::shape a_shape{migraphx::shape::int8_type, {8, 2}};
        const migraphx::shape b_shape{migraphx::shape::int8_type, {7, 8}};
        const migraphx::shape c_shape{migraphx::shape::int32_type, {2, 7}};
        auto a   = prog.add_parameter("a", a_shape);
        auto a_t = prog.add_instruction(migraphx::op::transpose{{1, 0}}, a);
        auto b   = prog.add_parameter("b", b_shape);
        auto b_t = prog.add_instruction(migraphx::op::transpose{{1, 0}}, b);
        auto c   = prog.add_parameter("c", c_shape);
        prog.add_instruction(migraphx::op::quant_dot{3, 2}, a_t, b_t, c);
        return prog;
    }
};
// Batched int8 GEMM (batch dims 3x2) with both operands transposed on the
// last two axes, alpha=3, beta=2.
struct batch_quant_dot_1 : verify_program<batch_quant_dot_1>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        const migraphx::shape a_shape{migraphx::shape::int8_type, {3, 2, 8, 2}};
        const migraphx::shape b_shape{migraphx::shape::int8_type, {3, 2, 7, 8}};
        const migraphx::shape c_shape{migraphx::shape::int32_type, {3, 2, 2, 7}};
        auto a   = prog.add_parameter("a", a_shape);
        auto a_t = prog.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, a);
        auto b   = prog.add_parameter("b", b_shape);
        auto b_t = prog.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, b);
        auto c   = prog.add_parameter("c", c_shape);
        prog.add_instruction(migraphx::op::quant_dot{3, 2}, a_t, b_t, c);
        return prog;
    }
};
// Batched int8 GEMM (batch dims 3x2) on plain layouts, alpha=1, beta=3.
struct batch_quant_dot_2 : verify_program<batch_quant_dot_2>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        const migraphx::shape a_shape{migraphx::shape::int8_type, {3, 2, 2, 8}};
        const migraphx::shape b_shape{migraphx::shape::int8_type, {3, 2, 8, 7}};
        const migraphx::shape c_shape{migraphx::shape::int32_type, {3, 2, 2, 7}};
        auto a = prog.add_parameter("a", a_shape);
        auto b = prog.add_parameter("b", b_shape);
        auto c = prog.add_parameter("c", c_shape);
        prog.add_instruction(migraphx::op::quant_dot{1, 3}, a, b, c);
        return prog;
    }
};
struct
test_contiguous
:
verify_program
<
test_contiguous
>
{
migraphx
::
program
create_program
()
const
...
...
@@ -1367,6 +1694,83 @@ struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling>
}
};
// int8 quantized convolution with default attributes.
struct quant_conv : verify_program<quant_conv>
{
    migraphx::program create_program()
    {
        migraphx::program prog;
        const migraphx::shape input_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
        const migraphx::shape weight_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
        auto input   = prog.add_parameter("a", input_shape);
        auto weights = prog.add_parameter("c", weight_shape);
        prog.add_instruction(migraphx::op::quant_convolution{}, input, weights);
        return prog;
    }
};
// int8 quantized convolution using the "same" padding mode.
struct quant_conv_default_mode : verify_program<quant_conv_default_mode>
{
    migraphx::program create_program()
    {
        migraphx::program prog;
        const migraphx::shape input_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
        const migraphx::shape weight_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
        auto input   = prog.add_parameter("a", input_shape);
        auto weights = prog.add_parameter("c", weight_shape);
        prog.add_instruction(
            migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}, migraphx::op::same},
            input,
            weights);
        return prog;
    }
};
// int8 quantized convolution using the "valid" padding mode.
struct quant_conv_valid_mode : verify_program<quant_conv_valid_mode>
{
    migraphx::program create_program()
    {
        migraphx::program prog;
        const migraphx::shape input_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
        const migraphx::shape weight_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
        auto input   = prog.add_parameter("a", input_shape);
        auto weights = prog.add_parameter("c", weight_shape);
        prog.add_instruction(
            migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}, migraphx::op::valid},
            input,
            weights);
        return prog;
    }
};
// int8 quantized convolution with explicit 1x1 padding, unit stride.
struct quant_conv_padding : verify_program<quant_conv_padding>
{
    migraphx::program create_program()
    {
        migraphx::program prog;
        const migraphx::shape input_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
        const migraphx::shape weight_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
        auto input   = prog.add_parameter("a", input_shape);
        auto weights = prog.add_parameter("c", weight_shape);
        prog.add_instruction(
            migraphx::op::quant_convolution{{{1, 1}}, {{1, 1}}}, input, weights);
        return prog;
    }
};
};
struct
quant_conv_padding_stride
:
verify_program
<
quant_conv_padding_stride
>
{
migraphx
::
program
create_program
()
{
migraphx
::
program
p
;
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
int8_type
,
{
2
,
3
,
4
,
4
}};
auto
pa
=
p
.
add_parameter
(
"a"
,
a_shape
);
migraphx
::
shape
c_shape
{
migraphx
::
shape
::
int8_type
,
{
2
,
3
,
3
,
3
}};
auto
pc
=
p
.
add_parameter
(
"c"
,
c_shape
);
p
.
add_instruction
(
migraphx
::
op
::
quant_convolution
{{{
1
,
1
}},
{{
2
,
2
}}},
pa
,
pc
);
return
p
;
}
};
struct
test_concat
:
verify_program
<
test_concat
>
{
migraphx
::
program
create_program
()
const
...
...
@@ -1441,6 +1845,22 @@ struct test_pad : verify_program<test_pad>
}
};
// Pads a 2x2 literal with one trailing row/column, using int8-lowest as the
// fill value.
struct test_pad_int8 : verify_program<test_pad_int8>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        std::vector<int8_t> values = {0, 1, 2, 3};
        // NOTE(review): the shape is float_type while the source data is int8
        // (the literal converts) — presumably intentional, worth confirming.
        const migraphx::shape lit_shape{migraphx::shape::float_type, {2, 2}};
        auto lit = prog.add_literal(migraphx::literal{lit_shape, values});
        migraphx::op::pad pad_op{};
        pad_op.value = std::numeric_limits<int8_t>::lowest();
        pad_op.pads  = {0, 0, 1, 1};
        prog.add_instruction(pad_op, lit);
        return prog;
    }
};
struct
test_pooling_autopad
:
verify_program
<
test_pooling_autopad
>
{
migraphx
::
program
create_program
()
const
...
...
@@ -2631,10 +3051,11 @@ struct test_lstm_forward_last : verify_program<test_lstm_forward_last>
auto
und
=
p
.
add_instruction
(
migraphx
::
op
::
undefined
{});
auto
output
=
p
.
add_instruction
(
migraphx
::
op
::
gru
{
hidden_size
,
{
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{}},
migraphx
::
op
::
rnn_direction
::
forward
,
clip
},
migraphx
::
op
::
lstm
{
hidden_size
,
{
migraphx
::
op
::
sigmoid
{},
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
tanh
{}},
migraphx
::
op
::
rnn_direction
::
forward
,
clip
},
seq
,
w
,
r
,
...
...
@@ -3308,33 +3729,13 @@ struct test_lstm_bidirct_default_actv2 : verify_program<test_lstm_bidirct_defaul
}
};
template
<
int
Axis
>
struct
test_logsoftmax
:
verify_program
<
test_logsoftmax
<
Axis
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
5
,
6
}};
auto
param
=
p
.
add_parameter
(
"0"
,
s
);
p
.
add_instruction
(
migraphx
::
op
::
logsoftmax
{
Axis
},
param
);
return
p
;
}
};
template
struct
test_logsoftmax
<
0
>;
template
struct
test_logsoftmax
<
1
>;
template
struct
test_logsoftmax
<
2
>;
template
struct
test_logsoftmax
<
3
>;
template
struct
test_logsoftmax
<
4
>;
template
<
int
Axis
>
struct
test_logsoftmax_1
:
verify_program
<
test_logsoftmax_1
<
Axis
>>
// Log-softmax verification parameterized over reduction axis and element type.
template <int Axis, migraphx::shape::type_t T>
struct test_logsoftmax : verify_program<test_logsoftmax<Axis, T>>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        const migraphx::shape data_shape{T, {10, 4, 2080, 6}};
        auto input = prog.add_parameter("0", data_shape);
        prog.add_instruction(migraphx::op::logsoftmax{Axis}, input);
        return prog;
    }
};

// Explicit instantiations spanning each axis across float/double/half.
template struct test_logsoftmax<0, migraphx::shape::float_type>;
template struct test_logsoftmax<1, migraphx::shape::float_type>;
template struct test_logsoftmax<2, migraphx::shape::float_type>;
template struct test_logsoftmax<3, migraphx::shape::float_type>;
template struct test_logsoftmax<1, migraphx::shape::double_type>;
template struct test_logsoftmax<3, migraphx::shape::double_type>;
template struct test_logsoftmax<1, migraphx::shape::half_type>;
template struct test_logsoftmax<0, migraphx::shape::half_type>;
template struct test_logsoftmax<2, migraphx::shape::half_type>;
template struct test_logsoftmax<3, migraphx::shape::half_type>;
struct
test_fp32_fp16_lall
:
verify_program
<
test_fp32_fp16_lall
>
{
...
...
@@ -3356,7 +3765,7 @@ struct test_fp32_fp16_lall : verify_program<test_fp32_fp16_lall>
auto
l1
=
p
.
add_literal
(
migraphx
::
literal
(
s
,
data
));
auto
l2
=
p
.
add_parameter
(
"p2"
,
s
);
p
.
add_instruction
(
migraphx
::
op
::
add
{},
l1
,
l2
);
migraphx
::
quantize
(
p
,
{
"all"
});
migraphx
::
quantize
_fp16
(
p
,
{
"all"
});
return
p
;
};
};
...
...
@@ -3372,7 +3781,7 @@ struct test_fp32_fp16_ladd : verify_program<test_fp32_fp16_ladd>
auto
l1
=
p
.
add_literal
(
migraphx
::
literal
(
s
,
data
));
auto
l2
=
p
.
add_parameter
(
"p2"
,
s
);
p
.
add_instruction
(
migraphx
::
op
::
add
{},
l1
,
l2
);
migraphx
::
quantize
(
p
,
{
"add"
});
migraphx
::
quantize
_fp16
(
p
,
{
"add"
});
return
p
;
};
};
...
...
@@ -3388,7 +3797,7 @@ struct test_fp32_fp16_add : verify_program<test_fp32_fp16_add>
auto
sum
=
p
.
add_instruction
(
migraphx
::
op
::
add
{},
p1
,
p2
);
auto
diff
=
p
.
add_instruction
(
migraphx
::
op
::
sub
{},
sum
,
p2
);
p
.
add_instruction
(
migraphx
::
op
::
add
{},
diff
,
p1
);
migraphx
::
quantize
(
p
,
{
"add"
});
migraphx
::
quantize
_fp16
(
p
,
{
"add"
});
return
p
;
};
...
...
@@ -3405,7 +3814,134 @@ struct test_fp32_fp16_sub : verify_program<test_fp32_fp16_sub>
auto
sum
=
p
.
add_instruction
(
migraphx
::
op
::
add
{},
p1
,
p2
);
auto
diff
=
p
.
add_instruction
(
migraphx
::
op
::
sub
{},
sum
,
p2
);
p
.
add_instruction
(
migraphx
::
op
::
add
{},
diff
,
p1
);
migraphx
::
quantize
(
p
,
{
"sub"
});
migraphx
::
quantize_fp16
(
p
,
{
"sub"
});
return
p
;
};
};
// Verifies reduce_sum along axis 1 with a large (1026) reduction extent.
struct test_reduce_sum : verify_program<test_reduce_sum>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        const migraphx::shape in_shape{migraphx::shape::float_type, {3, 1026, 4, 3}};
        auto input = prog.add_parameter("x", in_shape);
        prog.add_instruction(migraphx::op::reduce_sum{{1}}, input);
        return prog;
    }
};
// Verifies reduce_sum along axis 1 on int32 input.
struct test_reduce_sum_int : verify_program<test_reduce_sum_int>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        const migraphx::shape in_shape{migraphx::shape::int32_type, {3, 4, 8, 8}};
        auto input = prog.add_parameter("x", in_shape);
        prog.add_instruction(migraphx::op::reduce_sum{{1}}, input);
        return prog;
    }
};
// Verifies reduce_sum along axis 1 on half-precision input.
struct test_reduce_sum_half : verify_program<test_reduce_sum_half>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        const migraphx::shape in_shape{migraphx::shape::half_type, {3, 4, 8, 8}};
        auto input = prog.add_parameter("x", in_shape);
        prog.add_instruction(migraphx::op::reduce_sum{{1}}, input);
        return prog;
    }
};
// Verifies rsqrt; the input is clipped to [1, FLT_MAX] first so the
// reciprocal square root is well-defined.
struct test_rsqrt : verify_program<test_rsqrt>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        const migraphx::shape in_shape{migraphx::shape::float_type, {1, 3, 16, 16}};
        auto input   = prog.add_parameter("x", in_shape);
        auto clipped = prog.add_instruction(
            migraphx::op::clip{std::numeric_limits<float>::max(), 1.0}, input);
        prog.add_instruction(migraphx::op::rsqrt{}, clipped);
        return prog;
    }
};
// Verifies reduce_mean along axis 1.
struct test_reduce_mean : verify_program<test_reduce_mean>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        const migraphx::shape in_shape{migraphx::shape::float_type, {3, 9, 4, 3}};
        auto input = prog.add_parameter("x", in_shape);
        prog.add_instruction(migraphx::op::reduce_mean{{1}}, input);
        return prog;
    }
};
// Verifies reduce_mean along the last axis of a transformer-like shape.
struct test_reduce_mean2 : verify_program<test_reduce_mean2>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        const migraphx::shape in_shape{migraphx::shape::float_type, {1, 128, 768}};
        auto input = prog.add_parameter("x", in_shape);
        prog.add_instruction(migraphx::op::reduce_mean{{2}}, input);
        return prog;
    }
};
// Verifies reduce_mean along axis 1 on int32 input with a large extent.
struct test_reduce_mean_int : verify_program<test_reduce_mean_int>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        const migraphx::shape in_shape{migraphx::shape::int32_type, {3, 1024, 8, 8}};
        auto input = prog.add_parameter("x", in_shape);
        prog.add_instruction(migraphx::op::reduce_mean{{1}}, input);
        return prog;
    }
};
// Verifies reduce_mean along axis 2 on half-precision input.
struct test_reduce_mean_half : verify_program<test_reduce_mean_half>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        const migraphx::shape in_shape{migraphx::shape::half_type, {3, 1024, 8, 8}};
        auto input = prog.add_parameter("x", in_shape);
        prog.add_instruction(migraphx::op::reduce_mean{{2}}, input);
        return prog;
    }
};
// Verifies the elementwise round operator.
struct test_round : verify_program<test_round>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        const migraphx::shape in_shape{migraphx::shape::float_type, {2, 3, 4, 6}};
        auto input = prog.add_parameter("x", in_shape);
        prog.add_instruction(migraphx::op::round{}, input);
        return prog;
    }
};
struct
test_convert
:
verify_program
<
test_convert
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
migraphx
::
shape
sa
{
migraphx
::
shape
::
float_type
,
{
8
,
24
}};
migraphx
::
shape
sb
{
migraphx
::
shape
::
float_type
,
{
24
,
6
}};
auto
pa
=
p
.
add_parameter
(
"a"
,
sa
);
auto
pb
=
p
.
add_parameter
(
"b"
,
sb
);
auto
ia
=
p
.
add_instruction
(
migraphx
::
op
::
convert
{
migraphx
::
shape
::
int8_type
},
pa
);
auto
ib
=
p
.
add_instruction
(
migraphx
::
op
::
convert
{
migraphx
::
shape
::
int8_type
},
pb
);
p
.
add_instruction
(
migraphx
::
op
::
quant_dot
{},
ia
,
ib
);
return
p
;
};
...
...
Prev
1
…
6
7
8
9
10
11
12
13
14
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment