Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
9738d01d
Commit
9738d01d
authored
Dec 09, 2022
by
charlie
Browse files
Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_pad
parents
e2b5a392
d411aa69
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
372 additions
and
183 deletions
+372
-183
src/include/migraphx/check_shapes.hpp
src/include/migraphx/check_shapes.hpp
+1
-1
src/include/migraphx/op/dot.hpp
src/include/migraphx/op/dot.hpp
+51
-22
src/targets/ref/lowering.cpp
src/targets/ref/lowering.cpp
+2
-2
test/op_shape_test.cpp
test/op_shape_test.cpp
+193
-140
test/ref_dot_op_test.cpp
test/ref_dot_op_test.cpp
+125
-18
No files found.
src/include/migraphx/check_shapes.hpp
View file @
9738d01d
...
...
@@ -198,7 +198,7 @@ struct check_shapes
*/
const
check_shapes
&
same_ndims
()
const
{
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
s
.
max_lens
().
size
();
}))
if
(
not
this
->
same
([](
const
shape
&
s
)
{
return
s
.
ndim
();
}))
MIGRAPHX_THROW
(
prefix
()
+
"Number of dimensions do not match"
);
return
*
this
;
}
...
...
src/include/migraphx/op/dot.hpp
View file @
9738d01d
...
...
@@ -28,6 +28,7 @@
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/gemm.hpp>
#include <migraphx/dyn_output.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -38,41 +39,69 @@ struct dot
std
::
string
name
()
const
{
return
"dot"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
same_type
().
has
(
2
);
check_shapes
{
inputs
,
*
this
,
true
}.
same_type
().
same_ndims
().
has
(
2
);
const
shape
&
a
=
inputs
.
at
(
0
);
const
shape
&
b
=
inputs
.
at
(
1
);
auto
t
=
a
.
type
();
if
(
not
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[](
auto
s
)
{
return
s
.
lens
().
size
()
>=
2
;
}))
if
(
not
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[](
auto
s
)
{
return
s
.
ndim
()
>=
2
;
}))
{
MIGRAPHX_THROW
(
"DOT: dot only accept 2 or more dim
s operands
"
);
MIGRAPHX_THROW
(
"DOT: dot only accept
s operands with
2 or more dim
ensions
"
);
}
// only handle the case that the batch size of a and b are the same
if
(
not
std
::
equal
(
a
.
lens
().
rbegin
()
+
2
,
a
.
lens
().
rend
(),
b
.
lens
().
rbegin
()
+
2
,
b
.
lens
().
rend
()))
if
(
a
.
dynamic
()
or
b
.
dynamic
())
{
MIGRAPHX_THROW
(
"DOT: batch size of A and B mismatch: {"
+
to_string_range
(
a
.
lens
())
+
"} x {"
+
to_string_range
(
b
.
lens
())
+
"}"
);
auto
s0
=
a
.
to_dynamic
();
auto
s1
=
b
.
to_dynamic
();
if
(
not
std
::
equal
(
s0
.
dyn_dims
().
rbegin
()
+
2
,
s0
.
dyn_dims
().
rend
(),
s1
.
dyn_dims
().
rbegin
()
+
2
,
s1
.
dyn_dims
().
rend
()))
{
MIGRAPHX_THROW
(
"DOT: dynamic outer dimensions of A and B mismatch: {"
+
to_string_range
(
s0
.
dyn_dims
())
+
"} x {"
+
to_string_range
(
s1
.
dyn_dims
())
+
"}"
);
}
std
::
size_t
dim_0
=
s0
.
ndim
()
-
2
;
std
::
size_t
dim_1
=
s0
.
ndim
()
-
1
;
if
(
s0
.
dyn_dims
()[
dim_1
]
!=
s1
.
dyn_dims
()[
dim_0
])
{
MIGRAPHX_THROW
(
"DOT: dynamic inner dimensions do not match: {"
+
to_string_range
(
s0
.
dyn_dims
())
+
"} x {"
+
to_string_range
(
s1
.
dyn_dims
())
+
"}"
);
}
auto
out_dyn_dims
=
s0
.
dyn_dims
();
out_dyn_dims
[
dim_1
]
=
s1
.
dyn_dims
()[
dim_1
];
return
{
t
,
out_dyn_dims
};
}
std
::
size_t
dim_0
=
a
.
lens
().
size
()
-
2
;
std
::
size_t
dim_1
=
a
.
lens
().
size
()
-
1
;
if
(
a
.
lens
()[
dim_1
]
!=
b
.
lens
()[
dim_0
])
else
{
MIGRAPHX_THROW
(
"DOT: inner dimensions do not match: {"
+
to_string_range
(
a
.
lens
())
+
"} x {"
+
to_string_range
(
b
.
lens
())
+
"}"
);
}
// only handle the case that all the dimensions except the last two are the same
if
(
not
std
::
equal
(
a
.
lens
().
rbegin
()
+
2
,
a
.
lens
().
rend
(),
b
.
lens
().
rbegin
()
+
2
,
b
.
lens
().
rend
()))
{
MIGRAPHX_THROW
(
"DOT: static outer dimensions of A and B mismatch: {"
+
to_string_range
(
a
.
lens
())
+
"} x {"
+
to_string_range
(
b
.
lens
())
+
"}"
);
}
auto
out_lens
=
a
.
lens
();
out_lens
[
dim_1
]
=
b
.
lens
()[
dim_1
];
return
{
t
,
out_lens
};
std
::
size_t
dim_0
=
a
.
ndim
()
-
2
;
std
::
size_t
dim_1
=
a
.
ndim
()
-
1
;
if
(
a
.
lens
()[
dim_1
]
!=
b
.
lens
()[
dim_0
])
{
MIGRAPHX_THROW
(
"DOT: static inner dimensions do not match: {"
+
to_string_range
(
a
.
lens
())
+
"} x {"
+
to_string_range
(
b
.
lens
())
+
"}"
);
}
auto
out_lens
=
a
.
lens
();
out_lens
[
dim_1
]
=
b
.
lens
()[
dim_1
];
return
{
t
,
out_lens
};
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
=
argument
{
out
put_shape
};
argument
result
=
argument
{
dyn_out
.
com
put
ed
_shape
};
visit_all
(
result
,
args
[
0
],
args
[
1
])(
[
&
](
auto
cmat
,
auto
amat
,
auto
bmat
)
{
gemm
(
cmat
,
amat
,
bmat
,
1.0
f
,
0.0
f
);
});
return
result
;
...
...
src/targets/ref/lowering.cpp
View file @
9738d01d
...
...
@@ -383,9 +383,9 @@ struct ref_gemm
std
::
string
name
()
const
{
return
"ref::dot"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
return
op
.
compute_shape
(
inputs
);
}
argument
compute
(
context
&
,
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
,
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
out
put_shape
};
argument
result
{
dyn_out
.
com
put
ed
_shape
};
migemm
(
result
,
args
[
0
],
args
[
1
],
1.0
f
,
0.0
f
);
return
result
;
...
...
test/op_shape_test.cpp
View file @
9738d01d
...
...
@@ -467,6 +467,199 @@ TEST_CASE(deconvolution_shape)
weights_3d
);
}
TEST_CASE
(
dot_ndim_error0
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_ndim_error1
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
,
2
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_ndim_error2
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_ndim_error3
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
6
,
5
,
4
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_ndim_error4
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
5
,
7
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_mismatch_inner_error0
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
10
,
8
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_mismatch_inner_error1
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
4
,
6
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
,
8
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_mismatch_inner_error2
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
4
,
4
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_mismatch_inner_error3
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
5
,
7
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_mismatch_outer_error
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
4
,
6
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
2
,
5
,
8
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_2D_test0
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
,
8
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4
,
8
}},
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_2D_test1
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
,
4
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
4
}},
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_2D_test2
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
,
8
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4
,
8
}},
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_2D_test3
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
1
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
1
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
1
}},
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_3D_test0
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
5
,
8
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
4
,
8
}},
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_3D_test_1
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
6
,
1
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
6
,
5
,
4
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
6
,
1
,
4
}},
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_3D_test2
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
5
,
7
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
4
,
7
}},
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_4D_test
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
6
,
1
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
6
,
5
,
4
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
6
,
1
,
4
}},
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_dyn_static_test0
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
0
},
{
5
,
5
,
0
}}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
,
8
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
0
},
{
8
,
8
,
0
}}},
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_dyn_static_mismatch_error
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
0
},
{
3
,
3
,
0
},
{
5
,
5
,
0
},
{
5
,
5
,
0
}}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
,
8
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_dyn_dyn_test0
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
0
},
{
5
,
5
,
0
}}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{{
5
,
5
,
0
},
{
6
,
8
,
8
}}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
0
},
{
6
,
8
,
8
}}},
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_dyn_dyn_test1
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
0
},
{
4
,
5
,
5
}}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{{
4
,
5
,
5
},
{
6
,
8
,
8
}}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
0
},
{
6
,
8
,
8
}}},
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_dyn_mismatch_test0
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
0
},
{
5
,
5
,
0
},
{
5
,
5
,
0
}}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
5
,
8
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_dyn_mismatch_test1
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{{
4
,
4
,
0
},
{
5
,
5
,
0
},
{
2
,
5
,
0
}}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
4
,
5
,
8
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
flatten_shape
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
4
,
6
,
8
}};
...
...
@@ -638,46 +831,6 @@ TEST_CASE(gather)
}
}
// 3 input arguments
TEST_CASE
(
gemm
)
{
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
10
,
8
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
4
,
6
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
,
8
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
,
8
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4
,
8
}},
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
5
,
8
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
4
,
8
}},
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
4
,
6
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
2
,
5
,
8
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
}
TEST_CASE
(
get_tuple_elem_test
)
{
migraphx
::
shape
s0
{
migraphx
::
shape
::
bool_type
,
{
1
,
1
}};
...
...
@@ -1131,106 +1284,6 @@ TEST_CASE(lstm)
}
}
// 2 inputs arguments
TEST_CASE
(
matmul
)
{
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
,
2
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
,
4
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
4
}},
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
4
,
4
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
6
,
5
,
4
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
6
,
1
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
6
,
5
,
4
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
6
,
1
,
4
}},
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
6
,
1
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
6
,
5
,
4
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
6
,
1
,
4
}},
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
5
,
8
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4
,
8
}},
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
1
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
1
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
1
}},
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
5
,
7
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
4
,
7
}},
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
5
,
7
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
4
,
5
}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
5
,
7
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
}
TEST_CASE
(
multibroadcast
)
{
{
...
...
test/ref_dot_op_test.cpp
View file @
9738d01d
...
...
@@ -35,7 +35,7 @@
#include <migraphx/half.hpp>
template
<
class
T
>
void
matmul
_test
()
void
dot_2d
_test
()
{
migraphx
::
program
p
;
...
...
@@ -82,11 +82,11 @@ void matmul_test()
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
c
,
results_vector
));
}
TEST_CASE_REGISTER
(
matmul
_test
<
float
>
)
TEST_CASE_REGISTER
(
matmul
_test
<
double
>
)
TEST_CASE_REGISTER
(
dot_2d
_test
<
float
>
)
TEST_CASE_REGISTER
(
dot_2d
_test
<
double
>
)
template
<
class
T
>
void
matmul
_test
_ex
()
void
dot_4d
_test
()
{
migraphx
::
program
p
;
...
...
@@ -133,10 +133,10 @@ void matmul_test_ex()
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
c
,
results_vector
));
}
TEST_CASE_REGISTER
(
matmul
_test
_ex
<
float
>
)
TEST_CASE_REGISTER
(
matmul
_test
_ex
<
double
>
)
TEST_CASE_REGISTER
(
dot_4d
_test
<
float
>
)
TEST_CASE_REGISTER
(
dot_4d
_test
<
double
>
)
TEST_CASE
(
matmul_mutli_dim_2
)
TEST_CASE
(
dot_3D_test
)
{
migraphx
::
program
p
;
...
...
@@ -189,7 +189,7 @@ TEST_CASE(matmul_mutli_dim_2)
EXPECT
(
migraphx
::
verify_range
(
m
,
m_res
));
}
TEST_CASE
(
gemm_mutli_dim_2_beta
0
)
TEST_CASE
(
dot_3D_C_test
0
)
{
migraphx
::
program
p
;
...
...
@@ -265,7 +265,7 @@ TEST_CASE(gemm_mutli_dim_2_beta0)
EXPECT
(
migraphx
::
verify_range
(
m
,
m_res
));
}
TEST_CASE
(
gemm_beta_0
)
TEST_CASE
(
dot_3D_C_test1
)
{
migraphx
::
program
p
;
...
...
@@ -324,7 +324,7 @@ TEST_CASE(gemm_beta_0)
EXPECT
(
migraphx
::
verify_range
(
m
,
m_res
));
}
TEST_CASE
(
matmul_mutli_dim_2_3
)
TEST_CASE
(
dot_4D_test1
)
{
migraphx
::
program
p
;
...
...
@@ -363,7 +363,7 @@ TEST_CASE(matmul_mutli_dim_2_3)
EXPECT
(
migraphx
::
verify_range
(
m
,
m_res
));
}
TEST_CASE
(
gemm_mutli_dim1_2_3
)
TEST_CASE
(
dot_4D_alpha_beta_test
)
{
migraphx
::
program
p
;
...
...
@@ -417,7 +417,7 @@ TEST_CASE(gemm_mutli_dim1_2_3)
EXPECT
(
migraphx
::
verify_range
(
m
,
m_res
));
}
TEST_CASE
(
gemm_mutli_3args
)
TEST_CASE
(
dot_4D_alpha_beta_C_test
)
{
migraphx
::
program
p
;
...
...
@@ -469,7 +469,7 @@ TEST_CASE(gemm_mutli_3args)
EXPECT
(
migraphx
::
verify_range
(
m
,
m_res
));
}
TEST_CASE
(
gemm_3args
)
TEST_CASE
(
dot_2D_C_test0
)
{
{
migraphx
::
program
p
;
...
...
@@ -533,7 +533,7 @@ TEST_CASE(gemm_3args)
}
}
TEST_CASE
(
matmul
_vv_inner_product
)
TEST_CASE
(
dot
_vv_inner_product
)
{
{
migraphx
::
program
p
;
...
...
@@ -608,7 +608,7 @@ TEST_CASE(matmul_vv_inner_product)
}
}
TEST_CASE
(
matmul
_vm
)
TEST_CASE
(
dot
_vm
)
{
{
migraphx
::
program
p
;
...
...
@@ -778,7 +778,7 @@ TEST_CASE(matmul_vm)
}
}
TEST_CASE
(
matmul
_mv
)
TEST_CASE
(
dot
_mv
)
{
{
migraphx
::
program
p
;
...
...
@@ -899,7 +899,7 @@ TEST_CASE(matmul_mv)
}
}
TEST_CASE
(
matmul
_mm1
)
TEST_CASE
(
dot
_mm1
)
{
{
migraphx
::
program
p
;
...
...
@@ -1006,7 +1006,7 @@ TEST_CASE(matmul_mm1)
}
}
TEST_CASE
(
matmul
_mm2
)
TEST_CASE
(
dot
_mm2
)
{
{
migraphx
::
program
p
;
...
...
@@ -1193,6 +1193,113 @@ TEST_CASE(matmul_mm2)
}
}
TEST_CASE
(
dot_dyn_2D_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
0
},
{
5
,
5
,
0
}}};
auto
ap
=
mm
->
add_parameter
(
"a"
,
a_shape
);
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
{
5
,
3
}};
auto
bp
=
mm
->
add_parameter
(
"b"
,
b_shape
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
ap
,
bp
);
p
.
compile
(
migraphx
::
ref
::
target
{});
std
::
vector
<
float
>
a
=
{
-
0.00925222
,
0.56250403
,
0.70107397
,
0.75402161
,
-
0.505885
,
1.33628943
,
-
0.11413
,
-
0.31270559
,
1.59336732
,
-
0.19361027
,
-
0.91620867
,
0.40108416
,
-
0.06969921
,
0.68483471
,
-
0.39906632
,
-
1.66423624
,
0.69040076
,
-
1.31490171
,
-
0.11282616
,
-
0.79391814
};
std
::
vector
<
float
>
b
=
{
6.09568541e-01
,
-
6.10527007e-01
,
3.66646462e-01
,
1.18951101e-01
,
5.58777432e-01
,
-
3.21296298e-01
,
-
5.95997198e-01
,
-
5.01425721e-01
,
-
2.84606807e-01
,
-
5.73673557e-01
,
-
8.99430260e-01
,
-
4.25103093e-01
,
1.53027987e+00
,
-
3.81407415e-04
,
-
3.29650255e-01
};
migraphx
::
shape
input_fixed_shape
{
migraphx
::
shape
::
float_type
,
{
4
,
5
}};
migraphx
::
parameter_map
params
;
params
[
"a"
]
=
migraphx
::
argument
(
input_fixed_shape
,
a
.
data
());
params
[
"b"
]
=
migraphx
::
argument
(
b_shape
,
b
.
data
());
auto
result
=
p
.
eval
(
params
).
back
();
std
::
vector
<
float
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
c
=
{
-
1.56327541e+00
,
-
7.09570140e-01
,
-
5.37424982e-01
,
-
2.22994831e-01
,
-
2.15586437e+00
,
2.09177941e-03
,
-
1.47279677e+00
,
2.02627040e-01
,
-
6.04527691e-01
,
-
1.29885596e+00
,
2.16294914e+00
,
-
1.48101497e-01
};
EXPECT
(
migraphx
::
verify_range
(
c
,
results_vector
));
}
TEST_CASE
(
dot_dyn_4D_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{{
1
,
1
,
0
},
{
1
,
1
,
0
},
{
4
,
6
,
4
},
{
5
,
5
,
0
}}};
auto
al
=
mm
->
add_parameter
(
"a"
,
a_shape
);
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
5
,
3
}};
auto
bl
=
mm
->
add_parameter
(
"b"
,
b_shape
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
al
,
bl
);
p
.
compile
(
migraphx
::
ref
::
target
{});
std
::
vector
<
float
>
a
=
{
-
0.00925222
,
0.56250403
,
0.70107397
,
0.75402161
,
-
0.505885
,
1.33628943
,
-
0.11413
,
-
0.31270559
,
1.59336732
,
-
0.19361027
,
-
0.91620867
,
0.40108416
,
-
0.06969921
,
0.68483471
,
-
0.39906632
,
-
1.66423624
,
0.69040076
,
-
1.31490171
,
-
0.11282616
,
-
0.79391814
};
std
::
vector
<
float
>
b
=
{
6.09568541e-01
,
-
6.10527007e-01
,
3.66646462e-01
,
1.18951101e-01
,
5.58777432e-01
,
-
3.21296298e-01
,
-
5.95997198e-01
,
-
5.01425721e-01
,
-
2.84606807e-01
,
-
5.73673557e-01
,
-
8.99430260e-01
,
-
4.25103093e-01
,
1.53027987e+00
,
-
3.81407415e-04
,
-
3.29650255e-01
};
migraphx
::
shape
input_fixed_shape0
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
4
,
5
}};
migraphx
::
shape
input_fixed_shape1
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
5
,
3
}};
migraphx
::
parameter_map
params
;
params
[
"a"
]
=
migraphx
::
argument
(
input_fixed_shape0
,
a
.
data
());
params
[
"b"
]
=
migraphx
::
argument
(
input_fixed_shape1
,
b
.
data
());
auto
result
=
p
.
eval
(
params
).
back
();
std
::
vector
<
float
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
c
=
{
-
1.56327541e+00
,
-
7.09570140e-01
,
-
5.37424982e-01
,
-
2.22994831e-01
,
-
2.15586437e+00
,
2.09177941e-03
,
-
1.47279677e+00
,
2.02627040e-01
,
-
6.04527691e-01
,
-
1.29885596e+00
,
2.16294914e+00
,
-
1.48101497e-01
};
EXPECT
(
migraphx
::
verify_range
(
c
,
results_vector
));
}
TEST_CASE
(
quant_dot_2args_multi4
)
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment