Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
08dbaa12
"...targets/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "d976c04b350db812bf40eaa355ba332b33029f1f"
Unverified
Commit
08dbaa12
authored
Dec 12, 2022
by
Charlie Lin
Committed by
GitHub
Dec 12, 2022
Browse files
Merge branch 'develop' into dynamic_reduce
parents
4c11aeb5
b41c1f01
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
523 additions
and
194 deletions
+523
-194
src/include/migraphx/check_shapes.hpp
src/include/migraphx/check_shapes.hpp
+1
-1
src/include/migraphx/op/dot.hpp
src/include/migraphx/op/dot.hpp
+51
-22
src/include/migraphx/op/softmax.hpp
src/include/migraphx/op/softmax.hpp
+5
-5
src/targets/gpu/compile_hip.cpp
src/targets/gpu/compile_hip.cpp
+2
-2
src/targets/ref/lowering.cpp
src/targets/ref/lowering.cpp
+6
-6
test/onnx/gen_onnx.py
test/onnx/gen_onnx.py
+10
-0
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+17
-0
test/onnx/softmax_dyn_test.onnx
test/onnx/softmax_dyn_test.onnx
+0
-0
test/op_shape_test.cpp
test/op_shape_test.cpp
+243
-140
test/ref_dot_op_test.cpp
test/ref_dot_op_test.cpp
+125
-18
test/ref_ops_test.cpp
test/ref_ops_test.cpp
+63
-0
No files found.
src/include/migraphx/check_shapes.hpp
View file @
08dbaa12
...
...
@@ -198,7 +198,7 @@ struct check_shapes
*/
const
check_shapes
&
same_ndims
()
const
{
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
s
.
max_lens
().
size
();
}))
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
s
.
ndim
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Number of dimensions do not match"
);
return
*
this
;
}
...
...
src/include/migraphx/op/dot.hpp
View file @
08dbaa12
...
...
@@ -28,6 +28,7 @@
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/gemm.hpp>
#include <migraphx/dyn_output.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -38,41 +39,69 @@ struct dot
std
::
string
name
()
const
{
return
"dot"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
same_type
().
has
(
2
);
check_shapes
{
inputs
,
*
this
,
true
}.
same_type
().
same_ndims
().
has
(
2
);
const
shape
&
a
=
inputs
.
at
(
0
);
const
shape
&
b
=
inputs
.
at
(
1
);
auto
t
=
a
.
type
();
if
(
not
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[](
auto
s
)
{
return
s
.
lens
().
size
()
>=
2
;
}))
if
(
not
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[](
auto
s
)
{
return
s
.
ndim
()
>=
2
;
}))
{
MIGRAPHX_THROW
(
"DOT: dot only accept 2 or more dim
s operands
"
);
MIGRAPHX_THROW
(
"DOT: dot only accept
s operands with
2 or more dim
ensions
"
);
}
// only handle the case that the batch size of a and b are the same
if
(
not
std
::
equal
(
a
.
lens
().
rbegin
()
+
2
,
a
.
lens
().
rend
(),
b
.
lens
().
rbegin
()
+
2
,
b
.
lens
().
rend
()))
if
(
a
.
dynamic
()
or
b
.
dynamic
())
{
MIGRAPHX_THROW
(
"DOT: batch size of A and B mismatch: {"
+
to_string_range
(
a
.
lens
())
+
"} x {"
+
to_string_range
(
b
.
lens
())
+
"}"
);
auto
s0
=
a
.
to_dynamic
();
auto
s1
=
b
.
to_dynamic
();
if
(
not
std
::
equal
(
s0
.
dyn_dims
().
rbegin
()
+
2
,
s0
.
dyn_dims
().
rend
(),
s1
.
dyn_dims
().
rbegin
()
+
2
,
s1
.
dyn_dims
().
rend
()))
{
MIGRAPHX_THROW
(
"DOT: dynamic outer dimensions of A and B mismatch: {"
+
to_string_range
(
s0
.
dyn_dims
())
+
"} x {"
+
to_string_range
(
s1
.
dyn_dims
())
+
"}"
);
}
std
::
size_t
dim_0
=
s0
.
ndim
()
-
2
;
std
::
size_t
dim_1
=
s0
.
ndim
()
-
1
;
if
(
s0
.
dyn_dims
()[
dim_1
]
!=
s1
.
dyn_dims
()[
dim_0
])
{
MIGRAPHX_THROW
(
"DOT: dynamic inner dimensions do not match: {"
+
to_string_range
(
s0
.
dyn_dims
())
+
"} x {"
+
to_string_range
(
s1
.
dyn_dims
())
+
"}"
);
}
auto
out_dyn_dims
=
s0
.
dyn_dims
();
out_dyn_dims
[
dim_1
]
=
s1
.
dyn_dims
()[
dim_1
];
return
{
t
,
out_dyn_dims
};
}
std
::
size_t
dim_0
=
a
.
lens
().
size
()
-
2
;
std
::
size_t
dim_1
=
a
.
lens
().
size
()
-
1
;
if
(
a
.
lens
()[
dim_1
]
!=
b
.
lens
()[
dim_0
])
else
{
MIGRAPHX_THROW
(
"DOT: inner dimensions do not match: {"
+
to_string_range
(
a
.
lens
())
+
"} x {"
+
to_string_range
(
b
.
lens
())
+
"}"
);
}
// only handle the case that all the dimensions except the last two are the same
if
(
not
std
::
equal
(
a
.
lens
().
rbegin
()
+
2
,
a
.
lens
().
rend
(),
b
.
lens
().
rbegin
()
+
2
,
b
.
lens
().
rend
()))
{
MIGRAPHX_THROW
(
"DOT: static outer dimensions of A and B mismatch: {"
+
to_string_range
(
a
.
lens
())
+
"} x {"
+
to_string_range
(
b
.
lens
())
+
"}"
);
}
auto
out_lens
=
a
.
lens
();
out_lens
[
dim_1
]
=
b
.
lens
()[
dim_1
];
return
{
t
,
out_lens
};
std
::
size_t
dim_0
=
a
.
ndim
()
-
2
;
std
::
size_t
dim_1
=
a
.
ndim
()
-
1
;
if
(
a
.
lens
()[
dim_1
]
!=
b
.
lens
()[
dim_0
])
{
MIGRAPHX_THROW
(
"DOT: static inner dimensions do not match: {"
+
to_string_range
(
a
.
lens
())
+
"} x {"
+
to_string_range
(
b
.
lens
())
+
"}"
);
}
auto
out_lens
=
a
.
lens
();
out_lens
[
dim_1
]
=
b
.
lens
()[
dim_1
];
return
{
t
,
out_lens
};
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
=
argument
{
out
put_shape
};
argument
result
=
argument
{
dyn_out
.
com
put
ed
_shape
};
visit_all
(
result
,
args
[
0
],
args
[
1
])(
[
&
](
auto
cmat
,
auto
amat
,
auto
bmat
)
{
gemm
(
cmat
,
amat
,
bmat
,
1.0
f
,
0.0
f
);
});
return
result
;
...
...
src/include/migraphx/op/softmax.hpp
View file @
08dbaa12
...
...
@@ -53,15 +53,15 @@ struct softmax
std
::
string
name
()
const
{
return
"softmax"
;
}
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
if
(
inputs
.
at
(
0
).
packed
())
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
auto
s0
=
inputs
[
0
];
if
(
s0
.
dynamic
()
or
s0
.
packed
())
{
return
inputs
.
at
(
0
)
;
return
s0
;
}
else
{
auto
lens
=
inputs
.
at
(
0
).
lens
();
return
{
inputs
.
at
(
0
).
type
(),
lens
};
return
{
s0
.
type
(),
s0
.
lens
()};
}
}
...
...
src/targets/gpu/compile_hip.cpp
View file @
08dbaa12
...
...
@@ -185,7 +185,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
options
.
push_back
(
"-fno-gpu-rdc"
);
options
.
push_back
(
" -O"
+
string_value_of
(
MIGRAPHX_GPU_OPTIMIZE
{},
"3"
));
options
.
push_back
(
"-Wno-cuda-compat"
);
options
.
push_back
(
"--
cuda-gpu
-arch="
+
arch
);
options
.
push_back
(
"--
offload
-arch="
+
arch
);
prog
.
compile
(
options
);
return
{
prog
.
get_code_obj
()};
}
...
...
@@ -237,7 +237,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
}
else
if
(
is_hip_clang_compiler
())
{
params
+=
" --
cuda-gpu
-arch="
+
arch
;
params
+=
" --
offload
-arch="
+
arch
;
params
+=
" --cuda-device-only"
;
params
+=
" -O"
+
string_value_of
(
MIGRAPHX_GPU_OPTIMIZE
{},
"3"
)
+
" "
;
}
...
...
src/targets/ref/lowering.cpp
View file @
08dbaa12
...
...
@@ -383,9 +383,9 @@ struct ref_gemm
std
::
string
name
()
const
{
return
"ref::dot"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
return
op
.
compute_shape
(
inputs
);
}
argument
compute
(
context
&
,
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
,
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
out
put_shape
};
argument
result
{
dyn_out
.
com
put
ed
_shape
};
migemm
(
result
,
args
[
0
],
args
[
1
],
1.0
f
,
0.0
f
);
return
result
;
...
...
@@ -449,10 +449,10 @@ struct ref_softmax : auto_register_op<ref_softmax<Op>>
{
return
op
.
normalize_compute_shape
(
inputs
);
}
argument
compute
(
context
&
,
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
,
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
out
put_shape
};
auto
batch_lens
=
out
put_shape
.
lens
();
argument
result
{
dyn_out
.
com
put
ed
_shape
};
auto
batch_lens
=
dyn_out
.
com
put
ed
_shape
.
lens
();
int64_t
tuned_axis
=
tune_axis
(
args
[
0
].
get_shape
().
lens
().
size
(),
op
.
axis
,
op
.
name
());
std
::
size_t
n_dims
=
batch_lens
[
tuned_axis
];
batch_lens
[
tuned_axis
]
=
1
;
...
...
@@ -475,7 +475,7 @@ struct ref_softmax : auto_register_op<ref_softmax<Op>>
for
(
std
::
size_t
j
=
0
;
j
<
n_dims
;
++
j
)
{
idx
[
tuned_axis
]
=
j
;
std
::
size_t
index
=
out
put_shape
.
index
(
idx
);
std
::
size_t
index
=
dyn_out
.
com
put
ed
_shape
.
index
(
idx
);
output
[
index
]
=
std
::
exp
(
input
[
index
]
-
batch_max
[
i
]);
}
...
...
test/onnx/gen_onnx.py
View file @
08dbaa12
...
...
@@ -5843,6 +5843,16 @@ def softmax_nonstd_input_test():
return
([
node0
,
node1
],
[
x
],
[
y
])
@
onnx_test
def
softmax_dyn_test
():
x
=
helper
.
make_tensor_value_info
(
'0'
,
TensorProto
.
FLOAT
,
[
None
,
3
,
4
,
4
])
y
=
helper
.
make_tensor_value_info
(
'1'
,
TensorProto
.
FLOAT
,
[
None
,
3
,
4
,
4
])
node
=
onnx
.
helper
.
make_node
(
'Softmax'
,
inputs
=
[
'0'
],
outputs
=
[
'1'
])
return
([
node
],
[
x
],
[
y
])
@
onnx_test
def
softsign_test
():
x
=
helper
.
make_tensor_value_info
(
'x'
,
TensorProto
.
FLOAT
,
[
5
])
...
...
test/onnx/onnx_test.cpp
View file @
08dbaa12
...
...
@@ -5701,6 +5701,23 @@ TEST_CASE(softmax_nonstd_input_test)
EXPECT
(
p
==
prog
);
}
TEST_CASE
(
softmax_dyn_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
l0
=
mm
->
add_parameter
(
"0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
0
},
{
3
,
3
,
0
},
{
4
,
4
,
0
},
{
4
,
4
,
0
}}});
auto
ret
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"softmax"
,
{{
"axis"
,
-
1
}}),
l0
);
mm
->
add_return
({
ret
});
migraphx
::
onnx_options
options
;
options
.
default_dyn_dim_value
=
{
1
,
4
,
0
};
auto
prog
=
migraphx
::
parse_onnx
(
"softmax_dyn_test.onnx"
,
options
);
EXPECT
(
p
==
prog
);
}
TEST_CASE
(
softplus_test
)
{
migraphx
::
program
p
;
...
...
test/onnx/softmax_dyn_test.onnx
0 → 100644
View file @
08dbaa12
File added
test/op_shape_test.cpp
View file @
08dbaa12
...
...
@@ -467,6 +467,199 @@ TEST_CASE(deconvolution_shape)
weights_3d
);
}
TEST_CASE
(
dot_ndim_error0
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_ndim_error1
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
,
2
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_ndim_error2
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_ndim_error3
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
6
,
5
,
4
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_ndim_error4
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
5
,
7
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_mismatch_inner_error0
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
10
,
8
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_mismatch_inner_error1
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
4
,
6
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
,
8
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_mismatch_inner_error2
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
4
,
4
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_mismatch_inner_error3
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
5
,
7
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_mismatch_outer_error
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
4
,
6
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
2
,
5
,
8
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_2D_test0
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
,
8
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4
,
8
}},
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_2D_test1
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
,
4
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
4
}},
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_2D_test2
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
,
8
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4
,
8
}},
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_2D_test3
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
1
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
1
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
1
}},
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_3D_test0
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
5
,
8
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
4
,
8
}},
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_3D_test_1
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
6
,
1
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
6
,
5
,
4
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
6
,
1
,
4
}},
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_3D_test2
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
5
,
7
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
4
,
7
}},
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_4D_test
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
6
,
1
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
6
,
5
,
4
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
6
,
1
,
4
}},
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_dyn_static_test0
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
0
},
{
5
,
5
,
0
}}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
,
8
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
0
},
{
8
,
8
,
0
}}},
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_dyn_static_mismatch_error
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
0
},
{
3
,
3
,
0
},
{
5
,
5
,
0
},
{
5
,
5
,
0
}}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
,
8
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_dyn_dyn_test0
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
0
},
{
5
,
5
,
0
}}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{{
5
,
5
,
0
},
{
6
,
8
,
8
}}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
0
},
{
6
,
8
,
8
}}},
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_dyn_dyn_test1
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
0
},
{
4
,
5
,
5
}}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{{
4
,
5
,
5
},
{
6
,
8
,
8
}}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
0
},
{
6
,
8
,
8
}}},
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_dyn_mismatch_test0
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
0
},
{
5
,
5
,
0
},
{
5
,
5
,
0
}}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
5
,
8
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_dyn_mismatch_test1
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{{
4
,
4
,
0
},
{
5
,
5
,
0
},
{
2
,
5
,
0
}}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
4
,
5
,
8
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
flatten_shape
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
4
,
6
,
8
}};
...
...
@@ -638,46 +831,6 @@ TEST_CASE(gather)
}
}
// 3 input arguments
TEST_CASE
(
gemm
)
{
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
10
,
8
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
4
,
6
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
,
8
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
,
8
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4
,
8
}},
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
5
,
8
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
4
,
8
}},
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
4
,
6
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
2
,
5
,
8
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
}
TEST_CASE
(
get_tuple_elem_test
)
{
migraphx
::
shape
s0
{
migraphx
::
shape
::
bool_type
,
{
1
,
1
}};
...
...
@@ -1131,106 +1284,6 @@ TEST_CASE(lstm)
}
}
// 2 inputs arguments
TEST_CASE
(
matmul
)
{
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
,
2
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
,
4
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
4
}},
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
4
,
4
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
6
,
5
,
4
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
6
,
1
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
6
,
5
,
4
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
6
,
1
,
4
}},
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
6
,
1
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
6
,
5
,
4
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
6
,
1
,
4
}},
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
,
8
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4
,
8
}},
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
1
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
1
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
1
}},
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
5
,
7
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
4
,
7
}},
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
5
,
7
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
5
,
7
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
}
TEST_CASE
(
multibroadcast
)
{
{
...
...
@@ -2173,6 +2226,56 @@ TEST_CASE(slice_shape)
TEST_CASE
(
softmax
)
{
test_softmax_variations
<
migraphx
::
op
::
softmax
>
();
}
TEST_CASE
(
softmax_dyn0
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
0
},
{
3
,
3
,
0
},
{
4
,
4
,
0
},
{
5
,
5
,
0
}}};
expect_shape
(
input
,
migraphx
::
make_op
(
"softmax"
,
{{
"axis"
,
0
}}),
input
);
}
TEST_CASE
(
softmax_dyn1
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
1
,
1
,
0
},
{
3
,
3
,
0
},
{
4
,
6
,
0
},
{
5
,
8
,
6
}}};
expect_shape
(
input
,
migraphx
::
make_op
(
"softmax"
,
{{
"axis"
,
0
}}),
input
);
}
TEST_CASE
(
test_argmax
)
{
{
migraphx
::
shape
input
{
migraphx
::
shape
::
half_type
,
{
2
,
3
,
4
,
5
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
int64_type
,
{
1
,
3
,
4
,
5
}},
migraphx
::
make_op
(
"argmax"
,
{{
"axis"
,
0
}}),
input
);
}
{
migraphx
::
shape
input
{
migraphx
::
shape
::
half_type
,
{
2
,
3
,
4
,
5
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
int64_type
,
{
2
,
1
,
4
,
5
}},
migraphx
::
make_op
(
"argmax"
,
{{
"axis"
,
1
}}),
input
);
}
{
migraphx
::
shape
input
{
migraphx
::
shape
::
half_type
,
{
2
,
3
,
4
,
5
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
int64_type
,
{
2
,
3
,
1
,
5
}},
migraphx
::
make_op
(
"argmax"
,
{{
"axis"
,
2
}}),
input
);
}
{
migraphx
::
shape
input
{
migraphx
::
shape
::
half_type
,
{
2
,
3
,
4
,
5
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
int64_type
,
{
2
,
3
,
4
,
1
}},
migraphx
::
make_op
(
"argmax"
,
{{
"axis"
,
3
}}),
input
);
}
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}};
throws_shape
(
migraphx
::
make_op
(
"argmax"
,
{{
"axis"
,
4
}}),
input
);
}
}
TEST_CASE
(
test_argmin
)
{
{
...
...
test/ref_dot_op_test.cpp
View file @
08dbaa12
...
...
@@ -35,7 +35,7 @@
#include <migraphx/half.hpp>
template
<
class
T
>
void
matmul
_test
()
void
dot_2d
_test
()
{
migraphx
::
program
p
;
...
...
@@ -82,11 +82,11 @@ void matmul_test()
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
c
,
results_vector
));
}
TEST_CASE_REGISTER
(
matmul
_test
<
float
>
)
TEST_CASE_REGISTER
(
matmul
_test
<
double
>
)
TEST_CASE_REGISTER
(
dot_2d
_test
<
float
>
)
TEST_CASE_REGISTER
(
dot_2d
_test
<
double
>
)
template
<
class
T
>
void
matmul
_test
_ex
()
void
dot_4d
_test
()
{
migraphx
::
program
p
;
...
...
@@ -133,10 +133,10 @@ void matmul_test_ex()
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
c
,
results_vector
));
}
TEST_CASE_REGISTER
(
matmul
_test
_ex
<
float
>
)
TEST_CASE_REGISTER
(
matmul
_test
_ex
<
double
>
)
TEST_CASE_REGISTER
(
dot_4d
_test
<
float
>
)
TEST_CASE_REGISTER
(
dot_4d
_test
<
double
>
)
TEST_CASE
(
matmul_mutli_dim_2
)
TEST_CASE
(
dot_3D_test
)
{
migraphx
::
program
p
;
...
...
@@ -189,7 +189,7 @@ TEST_CASE(matmul_mutli_dim_2)
EXPECT
(
migraphx
::
verify_range
(
m
,
m_res
));
}
TEST_CASE
(
gemm_mutli_dim_2_beta
0
)
TEST_CASE
(
dot_3D_C_test
0
)
{
migraphx
::
program
p
;
...
...
@@ -265,7 +265,7 @@ TEST_CASE(gemm_mutli_dim_2_beta0)
EXPECT
(
migraphx
::
verify_range
(
m
,
m_res
));
}
TEST_CASE
(
gemm_beta_0
)
TEST_CASE
(
dot_3D_C_test1
)
{
migraphx
::
program
p
;
...
...
@@ -324,7 +324,7 @@ TEST_CASE(gemm_beta_0)
EXPECT
(
migraphx
::
verify_range
(
m
,
m_res
));
}
TEST_CASE
(
matmul_mutli_dim_2_3
)
TEST_CASE
(
dot_4D_test1
)
{
migraphx
::
program
p
;
...
...
@@ -363,7 +363,7 @@ TEST_CASE(matmul_mutli_dim_2_3)
EXPECT
(
migraphx
::
verify_range
(
m
,
m_res
));
}
TEST_CASE
(
gemm_mutli_dim1_2_3
)
TEST_CASE
(
dot_4D_alpha_beta_test
)
{
migraphx
::
program
p
;
...
...
@@ -417,7 +417,7 @@ TEST_CASE(gemm_mutli_dim1_2_3)
EXPECT
(
migraphx
::
verify_range
(
m
,
m_res
));
}
TEST_CASE
(
gemm_mutli_3args
)
TEST_CASE
(
dot_4D_alpha_beta_C_test
)
{
migraphx
::
program
p
;
...
...
@@ -469,7 +469,7 @@ TEST_CASE(gemm_mutli_3args)
EXPECT
(
migraphx
::
verify_range
(
m
,
m_res
));
}
TEST_CASE
(
gemm_3args
)
TEST_CASE
(
dot_2D_C_test0
)
{
{
migraphx
::
program
p
;
...
...
@@ -533,7 +533,7 @@ TEST_CASE(gemm_3args)
}
}
TEST_CASE
(
matmul
_vv_inner_product
)
TEST_CASE
(
dot
_vv_inner_product
)
{
{
migraphx
::
program
p
;
...
...
@@ -608,7 +608,7 @@ TEST_CASE(matmul_vv_inner_product)
}
}
TEST_CASE
(
matmul
_vm
)
TEST_CASE
(
dot
_vm
)
{
{
migraphx
::
program
p
;
...
...
@@ -778,7 +778,7 @@ TEST_CASE(matmul_vm)
}
}
TEST_CASE
(
matmul
_mv
)
TEST_CASE
(
dot
_mv
)
{
{
migraphx
::
program
p
;
...
...
@@ -899,7 +899,7 @@ TEST_CASE(matmul_mv)
}
}
TEST_CASE
(
matmul
_mm1
)
TEST_CASE
(
dot
_mm1
)
{
{
migraphx
::
program
p
;
...
...
@@ -1006,7 +1006,7 @@ TEST_CASE(matmul_mm1)
}
}
TEST_CASE
(
matmul
_mm2
)
TEST_CASE
(
dot
_mm2
)
{
{
migraphx
::
program
p
;
...
...
@@ -1193,6 +1193,113 @@ TEST_CASE(matmul_mm2)
}
}
TEST_CASE
(
dot_dyn_2D_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
0
},
{
5
,
5
,
0
}}};
auto
ap
=
mm
->
add_parameter
(
"a"
,
a_shape
);
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
{
5
,
3
}};
auto
bp
=
mm
->
add_parameter
(
"b"
,
b_shape
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
ap
,
bp
);
p
.
compile
(
migraphx
::
ref
::
target
{});
std
::
vector
<
float
>
a
=
{
-
0.00925222
,
0.56250403
,
0.70107397
,
0.75402161
,
-
0.505885
,
1.33628943
,
-
0.11413
,
-
0.31270559
,
1.59336732
,
-
0.19361027
,
-
0.91620867
,
0.40108416
,
-
0.06969921
,
0.68483471
,
-
0.39906632
,
-
1.66423624
,
0.69040076
,
-
1.31490171
,
-
0.11282616
,
-
0.79391814
};
std
::
vector
<
float
>
b
=
{
6.09568541e-01
,
-
6.10527007e-01
,
3.66646462e-01
,
1.18951101e-01
,
5.58777432e-01
,
-
3.21296298e-01
,
-
5.95997198e-01
,
-
5.01425721e-01
,
-
2.84606807e-01
,
-
5.73673557e-01
,
-
8.99430260e-01
,
-
4.25103093e-01
,
1.53027987e+00
,
-
3.81407415e-04
,
-
3.29650255e-01
};
migraphx
::
shape
input_fixed_shape
{
migraphx
::
shape
::
float_type
,
{
4
,
5
}};
migraphx
::
parameter_map
params
;
params
[
"a"
]
=
migraphx
::
argument
(
input_fixed_shape
,
a
.
data
());
params
[
"b"
]
=
migraphx
::
argument
(
b_shape
,
b
.
data
());
auto
result
=
p
.
eval
(
params
).
back
();
std
::
vector
<
float
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
c
=
{
-
1.56327541e+00
,
-
7.09570140e-01
,
-
5.37424982e-01
,
-
2.22994831e-01
,
-
2.15586437e+00
,
2.09177941e-03
,
-
1.47279677e+00
,
2.02627040e-01
,
-
6.04527691e-01
,
-
1.29885596e+00
,
2.16294914e+00
,
-
1.48101497e-01
};
EXPECT
(
migraphx
::
verify_range
(
c
,
results_vector
));
}
TEST_CASE
(
dot_dyn_4D_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{{
1
,
1
,
0
},
{
1
,
1
,
0
},
{
4
,
6
,
4
},
{
5
,
5
,
0
}}};
auto
al
=
mm
->
add_parameter
(
"a"
,
a_shape
);
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
5
,
3
}};
auto
bl
=
mm
->
add_parameter
(
"b"
,
b_shape
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
al
,
bl
);
p
.
compile
(
migraphx
::
ref
::
target
{});
std
::
vector
<
float
>
a
=
{
-
0.00925222
,
0.56250403
,
0.70107397
,
0.75402161
,
-
0.505885
,
1.33628943
,
-
0.11413
,
-
0.31270559
,
1.59336732
,
-
0.19361027
,
-
0.91620867
,
0.40108416
,
-
0.06969921
,
0.68483471
,
-
0.39906632
,
-
1.66423624
,
0.69040076
,
-
1.31490171
,
-
0.11282616
,
-
0.79391814
};
std
::
vector
<
float
>
b
=
{
6.09568541e-01
,
-
6.10527007e-01
,
3.66646462e-01
,
1.18951101e-01
,
5.58777432e-01
,
-
3.21296298e-01
,
-
5.95997198e-01
,
-
5.01425721e-01
,
-
2.84606807e-01
,
-
5.73673557e-01
,
-
8.99430260e-01
,
-
4.25103093e-01
,
1.53027987e+00
,
-
3.81407415e-04
,
-
3.29650255e-01
};
migraphx
::
shape
input_fixed_shape0
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
4
,
5
}};
migraphx
::
shape
input_fixed_shape1
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
5
,
3
}};
migraphx
::
parameter_map
params
;
params
[
"a"
]
=
migraphx
::
argument
(
input_fixed_shape0
,
a
.
data
());
params
[
"b"
]
=
migraphx
::
argument
(
input_fixed_shape1
,
b
.
data
());
auto
result
=
p
.
eval
(
params
).
back
();
std
::
vector
<
float
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
c
=
{
-
1.56327541e+00
,
-
7.09570140e-01
,
-
5.37424982e-01
,
-
2.22994831e-01
,
-
2.15586437e+00
,
2.09177941e-03
,
-
1.47279677e+00
,
2.02627040e-01
,
-
6.04527691e-01
,
-
1.29885596e+00
,
2.16294914e+00
,
-
1.48101497e-01
};
EXPECT
(
migraphx
::
verify_range
(
c
,
results_vector
));
}
TEST_CASE
(
quant_dot_2args_multi4
)
{
{
...
...
test/ref_ops_test.cpp
View file @
08dbaa12
...
...
@@ -7215,6 +7215,69 @@ TEST_CASE(softmax_test)
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
s
));
}
TEST_CASE
(
softmax_dyn_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{{
1
,
10
,
0
},
{
1
,
3
,
3
},
{
4
,
4
,
0
},
{
2
,
2
,
2
}}};
auto
al
=
mm
->
add_parameter
(
"a"
,
a_shape
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"softmax"
,
{{
"axis"
,
1
}}),
al
);
p
.
compile
(
migraphx
::
ref
::
target
{});
std
::
vector
<
float
>
a
=
{
-
5.61869681e-01
,
9.07827199e-01
,
1.29255986e+00
,
3.18533443e-02
,
-
1.22183852e-03
,
-
2.83830553e-01
,
-
1.03245842e+00
,
-
9.28322077e-01
,
-
8.82696748e-01
,
1.11327164e-01
,
-
9.20038462e-01
,
8.47388089e-01
,
2.51734018e-01
,
1.50563884e+00
,
2.23056650e+00
,
-
6.17576987e-02
,
-
1.00264274e-01
,
-
6.10369384e-01
,
1.17537189e+00
,
-
2.51560897e-01
,
-
8.50333512e-01
,
-
8.03578615e-01
,
-
6.51194930e-01
,
-
2.58137047e-01
,
4.65528190e-01
,
3.23284641e-02
,
-
1.54700470e+00
,
1.38096774e+00
,
5.39869189e-01
,
-
7.56884992e-01
,
1.81503093e+00
,
-
2.11269641e+00
,
1.92466557e+00
,
1.77230799e+00
,
2.21660900e+00
,
1.56777036e+00
,
-
2.08995026e-03
,
3.50566894e-01
,
-
1.15042710e+00
,
-
1.18577778e+00
,
8.90633047e-01
,
-
6.63949102e-02
,
1.44661188e+00
,
1.59215283e+00
,
-
2.56262213e-01
,
9.39079225e-01
,
4.07298543e-02
,
3.86590779e-01
,
6.09607756e-01
,
8.22331488e-01
,
-
2.82126725e-01
,
-
9.49052632e-01
,
-
4.24012303e-01
,
-
5.32990396e-01
,
-
3.18386006e+00
,
3.27092171e-01
,
-
1.33315325e+00
,
3.62459183e-01
,
3.74710828e-01
,
-
1.30302286e+00
,
1.79680198e-01
,
-
4.51832324e-01
,
4.34282750e-01
,
-
7.09520102e-01
,
6.20333970e-01
,
-
1.28712380e+00
,
2.04130828e-01
,
-
7.70607769e-01
,
1.61889160e+00
,
-
1.50951004e+00
,
-
4.10505563e-01
,
-
3.56566496e-02
,
-
1.29747534e+00
,
-
1.49967879e-01
,
7.77626812e-01
,
-
8.28408226e-02
,
2.73412596e-02
,
5.79780899e-03
,
9.87900198e-02
,
-
7.95276761e-01
,
-
1.38536084e+00
,
-
6.63573861e-01
,
3.89783204e-01
,
-
1.30670881e+00
,
-
7.62425125e-01
,
-
4.04883057e-01
,
6.24344349e-01
,
3.68128955e-01
,
-
1.01577950e+00
,
-
3.06715906e-01
,
5.67961395e-01
,
2.98198581e-01
,
-
1.63613629e+00
,
-
3.75131965e-01
,
-
6.75393403e-01
,
2.59172034e+00
,
6.75538957e-01
,
9.07939598e-02
,
1.92257717e-01
,
-
1.21592450e+00
,
-
2.73682117e-01
,
1.25232983e+00
,
-
1.39969170e+00
,
-
1.91483587e-01
,
2.57732719e-01
,
3.10056299e-01
,
1.41833842e+00
,
-
1.81386679e-01
,
3.92868072e-01
,
-
8.14771175e-01
,
2.02392387e+00
,
-
9.42091495e-02
,
-
3.77683818e-01
,
2.05638766e+00
,
2.93796062e-01
,
-
6.02131486e-01
,
2.70461679e-01
,
-
8.92358482e-01
,
1.04388881e+00
,
2.66154885e-01
};
migraphx
::
parameter_map
params
;
migraphx
::
shape
input_fixed_shape
{
migraphx
::
shape
::
float_type
,
{
5
,
3
,
4
,
2
}};
params
[
"a"
]
=
migraphx
::
argument
(
input_fixed_shape
,
a
.
data
());
auto
result
=
p
.
eval
(
params
).
back
();
std
::
vector
<
float
>
results_vector
(
120
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
s
=
{
0.30191708
,
0.59879845
,
0.50029165
,
0.24915339
,
0.36823985
,
0.13190967
,
0.0349741
,
0.18750034
,
0.21905553
,
0.27000085
,
0.0547399
,
0.56318235
,
0.47422904
,
0.78964758
,
0.91381913
,
0.44601166
,
0.47902739
,
0.13120073
,
0.4449684
,
0.18766427
,
0.15753111
,
0.07844277
,
0.05120674
,
0.36648798
,
0.14637007
,
0.13152322
,
0.01560997
,
0.29065287
,
0.49196178
,
0.10550152
,
0.81890774
,
0.06369215
,
0.62972021
,
0.74931765
,
0.67285055
,
0.35034987
,
0.28612873
,
0.31931475
,
0.04220394
,
0.16093165
,
0.22390974
,
0.11915915
,
0.3115395
,
0.35899726
,
0.22190949
,
0.57518375
,
0.13888834
,
0.7753762
,
0.4642328
,
0.57055861
,
0.21954368
,
0.34515455
,
0.09486015
,
0.40631217
,
0.01842281
,
0.48770609
,
0.06652815
,
0.36023033
,
0.42343026
,
0.24226256
,
0.17348589
,
0.44066274
,
0.6865865
,
0.17296699
,
0.46923906
,
0.06921105
,
0.3570261
,
0.4125829
,
0.73165393
,
0.15302512
,
0.29499072
,
0.33932695
,
0.30852377
,
0.40762195
,
0.40170741
,
0.36259529
,
0.60848355
,
0.42618036
,
0.31721094
,
0.02960522
,
0.28256637
,
0.24389413
,
0.2725659
,
0.10663581
,
0.27622163
,
0.28264219
,
0.53652936
,
0.09476089
,
0.40890986
,
0.34848392
,
0.32572666
,
0.53076893
,
0.11529481
,
0.29117745
,
0.14625968
,
0.8756339
,
0.49818122
,
0.10656087
,
0.1813329
,
0.17664003
,
0.21410346
,
0.80408043
,
0.02315119
,
0.27155462
,
0.32804728
,
0.13268511
,
0.61795473
,
0.49703068
,
0.41696799
,
0.10175809
,
0.71028161
,
0.29929739
,
0.17377149
,
0.76075399
,
0.20071237
,
0.32632929
,
0.36892858
,
0.09416146
,
0.26656723
,
0.42914796
};
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
s
));
}
TEST_CASE
(
sqdiff_test
)
{
migraphx
::
program
p
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment