gaoqiong / MIGraphX

Unverified commit b75e6aae, authored Dec 13, 2023 by shivadbhavsar, committed by GitHub on Dec 13, 2023

    Merge branch 'develop' into qdq_skip_ops

Parents: c335de61 a60bdb67
Showing 5 changed files with 152 additions and 129 deletions:

  src/include/migraphx/functional.hpp                               +11   -0
  src/include/migraphx/op/reshape.hpp                                +4   -82
  src/include/migraphx/op/reshape_lazy.hpp                           +57  -10
  src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp    +2   -0
  test/op_shape_test.cpp                                             +78  -37
src/include/migraphx/functional.hpp

@@ -27,6 +27,17 @@
 #include <utility>
 #include <migraphx/config.hpp>
 
+// Similiar to decltype(auto) except it will propagate any substitution failures
+// NOLINTNEXTLINE
+#define MIGRAPHX_RETURNS(...) \
+    ->decltype(__VA_ARGS__) { return __VA_ARGS__; }
+
+// Lifts an expression into a function object so it can be passed to a higher-order function
+// NOLINTNEXTLINE
+#define MIGRAPHX_LIFT(...) \
+    [](auto&&... private_lifts_xs) MIGRAPHX_RETURNS( \
+        (__VA_ARGS__)(static_cast<decltype(private_lifts_xs)>(private_lifts_xs)...))
+
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
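These are the same two macros that already exist in the GPU kernels header (src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp, further down in this diff); the change makes them available on the host side, where reshape_lazy.hpp below uses MIGRAPHX_LIFT(try_merge_pairs) to hand a function template to std::inner_product. As a minimal sketch of why the lift is needed, assume a hypothetical overload set sum (not part of MIGraphX): the bare name cannot be passed to an algorithm, but the lifted lambda can.

    #include <iostream>
    #include <numeric>
    #include <vector>

    // Copied from the diff above: MIGRAPHX_RETURNS propagates substitution failures,
    // MIGRAPHX_LIFT wraps a callable name in a generic lambda.
    #define MIGRAPHX_RETURNS(...) \
        ->decltype(__VA_ARGS__) { return __VA_ARGS__; }
    #define MIGRAPHX_LIFT(...) \
        [](auto&&... private_lifts_xs) MIGRAPHX_RETURNS( \
            (__VA_ARGS__)(static_cast<decltype(private_lifts_xs)>(private_lifts_xs)...))

    // Hypothetical overload set: "sum" does not name a single function object,
    // so std::accumulate(v.begin(), v.end(), 0, sum) would not compile.
    int sum(int a, int b) { return a + b; }
    double sum(double a, double b) { return a + b; }

    int main()
    {
        std::vector<int> v = {1, 2, 3, 4};
        // The lifted lambda defers overload resolution to the call site inside accumulate.
        std::cout << std::accumulate(v.begin(), v.end(), 0, MIGRAPHX_LIFT(sum)) << '\n'; // prints 10
        return 0;
    }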
src/include/migraphx/op/reshape.hpp

@@ -112,84 +112,6 @@ struct reshape
         return {s0.type(), output_dyn_dims};
     }
 
-    template <class Iterator>
-    static auto compute_end_dim(Iterator start, Iterator last, std::size_t dim)
-    {
-        std::size_t x = 1;
-        auto it       = std::find_if(start, last, [&](auto i) {
-            x *= i;
-            return x >= dim;
-        });
-        if(x != dim)
-            return start;
-        return it;
-    }
-
-    // This will attempt to alias the dimensions of the input shape to the lens of
-    // `rdims`. Unlike reshape_lazy though we can modify memory layout with copies and this
-    // can remove previous nullopts that were sent back for the alias case
-    static optional<shape> reshape_dims(const shape& input, const std::vector<std::size_t>& rdims)
-    {
-        if(input.standard())
-            return shape{input.type(), rdims};
-
-        const auto& idims    = input.lens();
-        const auto& istrides = input.strides();
-
-        std::vector<std::size_t> rstrides;
-        std::size_t i = 0;
-        std::size_t r = 0;
-        while(i < idims.size() and r < rdims.size())
-        {
-            auto idim = idims[i];
-            auto rdim = rdims[r];
-            if(rdim == idim)
-            {
-                rstrides.push_back(istrides[i]);
-            }
-            // squeeze
-            else if(rdim > idim)
-            {
-                auto start = idims.begin() + i;
-                auto it    = compute_end_dim(start, idims.end(), rdim);
-                auto n     = it - start;
-                assert((i + n) <= istrides.size());
-                i += n;
-                rstrides.push_back(istrides[i]);
-            }
-            // unsqueeze
-            else // if(rdim < idim)
-            {
-                auto start = rdims.begin() + i;
-                auto it    = compute_end_dim(start, rdims.end(), idim);
-                auto n     = it - start;
-                assert((r + n) <= rdims.size());
-                auto stride = istrides[i] * idim;
-                std::for_each(start, it + 1, [&](auto dim) {
-                    stride /= dim;
-                    rstrides.push_back(stride);
-                });
-                r += n;
-            }
-            i++;
-            r++;
-        }
-
-        // Handle trailing 1s
-        if(rstrides.size() < rdims.size() and not rstrides.empty())
-        {
-            auto stride = rstrides.back();
-            for(auto d : range(rdims.begin() + rstrides.size(), rdims.end()))
-            {
-                (void)d;
-                rstrides.push_back(stride);
-            }
-        }
-
-        return shape{input.type(), rdims, rstrides};
-    }
-
     shape static_compute_shape(std::vector<shape> inputs, std::size_t n_neg_dims) const
     {
         check_shapes{inputs, *this}.has(1);
@@ -219,14 +141,14 @@ struct reshape
             }
         }
 
-        auto s = reshape_dims(inputs.front(), rdims);
-        if(s->elements() != inputs.front().elements())
+        auto s = shape{inputs.front().type(), rdims};
+        if(s.elements() != inputs.front().elements())
             MIGRAPHX_THROW("reshape: Wrong number of elements for reshape: reshape has " +
-                           std::to_string(s->elements()) + " elements whereas the input has " +
+                           std::to_string(s.elements()) + " elements whereas the input has " +
                            std::to_string(inputs.front().elements()));
-        return *s;
+        return s;
     }
 
     shape compute_shape(std::vector<shape> inputs) const
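With compute_end_dim and reshape_dims gone, the non-dynamic reshape no longer tries to alias the input strides at all; it reports the standard (packed, row-major) shape for the requested dims and keeps only the element-count check, leaving stride aliasing to reshape_lazy. The following is a standalone sketch of that behaviour, not MIGraphX code, using the {4, 16} -> {4, 2, 8} case exercised by the updated tests:

    #include <cassert>
    #include <cstddef>
    #include <functional>
    #include <numeric>
    #include <vector>

    // Illustration only: shape{type, rdims} advertises packed row-major strides for the
    // requested dims, independent of how the input tensor was laid out.
    std::vector<std::size_t> standard_strides(const std::vector<std::size_t>& dims)
    {
        if(dims.empty())
            return {};
        std::vector<std::size_t> strides(dims.size(), 1);
        for(std::size_t i = dims.size() - 1; i > 0; i--)
            strides[i - 1] = strides[i] * dims[i];
        return strides;
    }

    int main()
    {
        std::vector<std::size_t> rdims = {4, 2, 8};
        // The only validation left in static_compute_shape: element counts must match.
        std::size_t input_elements = 4 * 16; // the {4, 16} input used in the unsqueeze tests
        std::size_t requested =
            std::accumulate(rdims.begin(), rdims.end(), std::size_t{1}, std::multiplies<>{});
        assert(requested == input_elements);
        auto strides = standard_strides(rdims); // {16, 8, 1}, whatever the input strides were
        (void)strides;
        return 0;
    }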
src/include/migraphx/op/reshape_lazy.hpp

@@ -110,22 +110,69 @@ struct reshape_lazy
         return it;
     }
 
+    template <class OptionalPair>
+    static OptionalPair try_merge_pairs(OptionalPair p2, OptionalPair p1)
+    {
+        if(not p1.has_value())
+            return nullopt;
+        if(not p2.has_value())
+            return nullopt;
+        auto dim1     = p1->first;
+        auto dim2     = p2->first;
+        auto stride1  = p1->second;
+        auto stride2  = p2->second;
+        auto elements = dim1 * dim2;
+        // Transposed
+        if(stride2 > stride1)
+            return nullopt;
+        // Broadcasted check to avoid division by zero
+        if(stride2 == 0)
+        {
+            if(stride1 == 0)
+                return {{elements, 0}};
+            return nullopt;
+        }
+        if(stride1 % stride2 != 0)
+            return nullopt;
+        auto space = (stride1 * dim1 + stride2 * dim2 - stride1) / stride2;
+        // Nonpacked
+        if(space != elements)
+            return nullopt;
+        return {{elements, stride2}};
+    }
+
+    template <class DimIterator, class StrideIterator>
+    static optional<std::size_t> merge_strides(DimIterator dim_start,
+                                               DimIterator dim_last,
+                                               StrideIterator stride_start,
+                                               StrideIterator stride_last)
+    {
+        if(dim_start == dim_last)
+            return nullopt;
+        (void)stride_start; // Is only used in the assert
+        assert(std::distance(dim_start, dim_last) == std::distance(stride_start, stride_last));
+        auto make_pair_optional = [&](auto dim, auto stride) {
+            return std::make_optional(std::make_pair(dim, stride));
+        };
+        auto dim_stride_pair = std::inner_product(
+            std::make_reverse_iterator(dim_last - 1),
+            std::make_reverse_iterator(dim_start),
+            std::make_reverse_iterator(stride_last - 1),
+            make_pair_optional(*std::prev(dim_last), *std::prev(stride_last)),
+            MIGRAPHX_LIFT(try_merge_pairs),
+            make_pair_optional);
+        if(not dim_stride_pair.has_value())
+            return nullopt;
+        return dim_stride_pair->second;
+    }
+
     template <class DimIterator, class StrideIterator>
     static auto can_strides_merge(DimIterator dim_start,
                                   DimIterator dim_last,
                                   StrideIterator stride_start,
                                   StrideIterator stride_last)
     {
-        auto cstride = *std::prev(stride_last);
-        return std::equal(std::make_reverse_iterator(dim_last),
-                          std::make_reverse_iterator(dim_start + 1),
-                          std::make_reverse_iterator(stride_last - 1),
-                          std::make_reverse_iterator(stride_start),
-                          [&](auto dim, auto stride) {
-                              cstride *= dim;
-                              return stride == cstride;
-                          });
+        assert(std::distance(dim_start, dim_last) == std::distance(stride_start, stride_last));
+        return merge_strides(dim_start, dim_last, stride_start, stride_last).has_value();
     }
 
     // This will attempt to alias the dimensions of the input shape to the lens of
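merge_strides folds the (dimension, stride) pairs from the innermost axis outward via std::inner_product, and try_merge_pairs decides whether the next outer axis can join the block accumulated so far: the outer stride must be an exact multiple of the inner one (not transposed, not partially broadcast), and the spanned space (stride1 * dim1 + stride2 * dim2 - stride1) / stride2 must equal dim1 * dim2, meaning the two axes cover one gapless block. Below is a simplified standalone restatement of that rule (the merge helper and its non-optional first argument are illustrative; the real try_merge_pairs takes two optionals so it can be folded), applied to the two stride layouts used by the new reshape_lazy_nonpacked_squeeze tests:

    #include <cstddef>
    #include <iostream>
    #include <optional>
    #include <utility>

    using dim_stride = std::pair<std::size_t, std::size_t>;

    // Simplified restatement of the merge rule from try_merge_pairs above:
    // outer is the next outer axis, inner the already-merged inner block.
    std::optional<dim_stride> merge(dim_stride outer, std::optional<dim_stride> inner)
    {
        if(not inner.has_value())
            return std::nullopt;
        auto [dim1, stride1] = outer;
        auto [dim2, stride2] = inner.value();
        auto elements        = dim1 * dim2;
        if(stride2 > stride1) // transposed
            return std::nullopt;
        if(stride2 == 0) // broadcast axes only merge with broadcast axes
        {
            if(stride1 == 0)
                return dim_stride{elements, 0};
            return std::nullopt;
        }
        if(stride1 % stride2 != 0)
            return std::nullopt;
        auto space = (stride1 * dim1 + stride2 * dim2 - stride1) / stride2;
        if(space != elements) // non-packed: a gap between the two axes
            return std::nullopt;
        return dim_stride{elements, stride2};
    }

    int main()
    {
        // dims {4, 16}, strides {32, 2}: space = (32*4 + 2*16 - 32) / 2 = 64 == 4*16 -> mergeable
        auto ok = merge({4, 32}, dim_stride{16, 2});
        // dims {4, 16}, strides {32, 1}: space = (32*4 + 1*16 - 32) / 1 = 112 != 64 -> not mergeable
        auto bad = merge({4, 32}, dim_stride{16, 1});
        std::cout << ok.has_value() << ' ' << bad.has_value() << '\n'; // prints 1 0
        return 0;
    }

This is why reshape_lazy_nonpacked_squeeze1 ({4, 16} with strides {32, 2}) still aliases to a {64} shape with stride 2, while the new reshape_lazy_nonpacked_squeeze2 ({4, 16} with strides {32, 1}) throws and only the copying reshape can produce {64}.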
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp

@@ -26,10 +26,12 @@
 #include <migraphx/kernels/integral_constant.hpp>
 
+// Similiar to decltype(auto) except it will propagate any substitution failures
 // NOLINTNEXTLINE
 #define MIGRAPHX_RETURNS(...) \
     ->decltype(__VA_ARGS__) { return __VA_ARGS__; }
 
+// Lifts an expression into a function object so it can be passed to a higher-order function
 // NOLINTNEXTLINE
 #define MIGRAPHX_LIFT(...) \
     [](auto&&... private_lifts_xs) MIGRAPHX_RETURNS( \
test/op_shape_test.cpp

@@ -2682,36 +2682,26 @@ TEST_CASE(reshape_shape_minus1_reshapes)
     }
 }
 
-// This uses the permutation to compute the reshape since its simpler than
-// trying to calculate strides. As we collapse or expand dimensions, we
-// remove the collapsed dimensions or duplicate the expanded dimensions in
-// the permutation. Then we renumber the permutation. So for dimensions of 4,
-// 24, 1, 1, 1 with a permutation of 1, 0, 2, 3, 4 that reshapes to 4, 1, 3,
-// 4, 2, we first remove the collapsed dimensions or duplicate the expanded
-// dimensions which gives 1, 0, 0, 0, 0. Then after renumbering we get a
-// final permutation of 4, 0, 1, 2, 3.
 TEST_CASE(reshape_nonstandard)
 {
     auto input = migraphx::shape::from_permutation(
         migraphx::shape::float_type, {4, 24, 1, 1, 1}, migraphx::invert_permutation({1, 0, 2, 3, 4}));
-    std::vector<std::pair<std::vector<std::size_t>, std::vector<int64_t>>> tests{
-        {{4, 24}, {1, 0}},
-        {{4, 24, 1, 1, 1, 1}, {1, 0, 2, 3, 4, 5}},
-        {{4, 8, 3, 1, 1}, {2, 0, 1, 3, 4}},
-        {{4, 1, 3, 4, 2}, {4, 0, 1, 2, 3}},
-        {{4, 1, 4, 3, 2}, {4, 0, 1, 2, 3}},
-        {{4, 2, 4, 3}, {3, 0, 1, 2}},
-        {{4, 2, 12, 1}, {2, 0, 1, 3}},
-        {{4, 2, 1, 12}, {3, 0, 1, 2}},
-        {{4, 4, 2, 3}, {3, 0, 1, 2}},
-        {{4, 8, 1, 3}, {3, 0, 1, 2}},
-        {{4, 8, 3, 1}, {2, 0, 1, 3}}};
-    for(const auto& [dims, perm] : tests)
+    std::vector<std::vector<std::size_t>> tests{{4, 24},
+                                                {4, 24, 1, 1, 1, 1},
+                                                {4, 8, 3, 1, 1},
+                                                {4, 1, 3, 4, 2},
+                                                {4, 1, 4, 3, 2},
+                                                {4, 2, 4, 3},
+                                                {4, 2, 12, 1},
+                                                {4, 2, 1, 12},
+                                                {4, 4, 2, 3},
+                                                {4, 8, 1, 3},
+                                                {4, 8, 3, 1}};
+    for(auto dims : tests)
     {
-        migraphx::shape output = migraphx::shape::from_permutation(
-            migraphx::shape::float_type, dims, migraphx::invert_permutation(perm));
+        migraphx::shape output = migraphx::shape{migraphx::shape::float_type, dims};
         expect_shape(output, migraphx::make_op("reshape", {{"dims", dims}}), input);
     }
 }

@@ -2721,8 +2711,7 @@ TEST_CASE(reshape_nonstandard_squeeze)
     auto input = migraphx::shape::from_permutation(
         migraphx::shape::float_type, {2, 16, 16, 1280}, migraphx::invert_permutation({0, 2, 3, 1}));
     std::vector<std::size_t> lens = {2, 256, 1280};
-    migraphx::shape output = migraphx::shape::from_permutation(
-        migraphx::shape::float_type, lens, migraphx::invert_permutation({0, 2, 1}));
+    migraphx::shape output = migraphx::shape{migraphx::shape::float_type, lens};
     expect_shape(output, migraphx::make_op("reshape", {{"dims", lens}}), input);
 }

@@ -2746,52 +2735,80 @@ TEST_CASE(reshape_nonstandard_error)
     }
 }
 
+TEST_CASE(reshape_transposed_squeeze)
+{
+    migraphx::shape input{migraphx::shape::float_type, {4, 16}, {1, 4}};
+    migraphx::shape output{migraphx::shape::float_type, {64}};
+    expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
+}
+
 TEST_CASE(reshape_nonpacked_unsqueeze1)
 {
     migraphx::shape input{migraphx::shape::float_type, {4, 16}, {32, 2}};
-    migraphx::shape output{migraphx::shape::float_type, {4, 2, 8}, {32, 16, 2}};
+    migraphx::shape output{migraphx::shape::float_type, {4, 2, 8}};
     expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
 }
 
 TEST_CASE(reshape_nonpacked_unsqueeze2)
 {
     migraphx::shape input{migraphx::shape::float_type, {4, 16}, {32, 2}};
-    migraphx::shape output{migraphx::shape::float_type, {2, 2, 16}, {64, 32, 2}};
+    migraphx::shape output{migraphx::shape::float_type, {2, 2, 16}};
     expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
 }
 
-TEST_CASE(reshape_nonpacked_squeeze)
+TEST_CASE(reshape_nonpacked_squeeze1)
 {
     migraphx::shape input{migraphx::shape::float_type, {4, 16}, {32, 2}};
-    migraphx::shape output{migraphx::shape::float_type, {64}, {2}};
+    migraphx::shape output{migraphx::shape::float_type, {64}};
+    expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
+}
+
+TEST_CASE(reshape_nonpacked_squeeze2)
+{
+    migraphx::shape input{migraphx::shape::float_type, {4, 16}, {32, 2}};
+    migraphx::shape output{migraphx::shape::float_type, {64}};
     expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
 }
 
 TEST_CASE(reshape_broadcast_unsqueeze1)
 {
     migraphx::shape input{migraphx::shape::float_type, {2, 256, 1280}, {0, 0, 1}};
-    migraphx::shape output{migraphx::shape::float_type, {2, 16, 16, 1280}, {0, 0, 0, 1}};
+    migraphx::shape output{migraphx::shape::float_type, {2, 16, 16, 1280}};
     expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
 }
 
 TEST_CASE(reshape_broadcast_unsqueeze2)
 {
     migraphx::shape input{migraphx::shape::float_type, {2, 256, 1280}, {0, 0, 1}};
-    migraphx::shape output{migraphx::shape::float_type, {2, 256, 16, 80}, {0, 0, 80, 1}};
+    migraphx::shape output{migraphx::shape::float_type, {2, 256, 16, 80}};
     expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
 }
 
-TEST_CASE(reshape_broadcast_squeeze)
+TEST_CASE(reshape_broadcast_squeeze1)
 {
     migraphx::shape input{migraphx::shape::float_type, {2, 16, 16, 1280}, {0, 0, 0, 1}};
-    migraphx::shape output{migraphx::shape::float_type, {2, 256, 1280}, {0, 0, 1}};
+    migraphx::shape output{migraphx::shape::float_type, {2, 256, 1280}};
+    expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
+}
+
+TEST_CASE(reshape_broadcast_squeeze2)
+{
+    migraphx::shape input{migraphx::shape::float_type, {4, 16}, {0, 1}};
+    migraphx::shape output{migraphx::shape::float_type, {64}};
+    expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
+}
+
+TEST_CASE(reshape_broadcast_squeeze3)
+{
+    migraphx::shape input{migraphx::shape::float_type, {4, 16}, {1, 0}};
+    migraphx::shape output{migraphx::shape::float_type, {64}};
     expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
 }
 
 TEST_CASE(reshape_broadcast_squeeze_memlayout_change)
 {
     migraphx::shape input{migraphx::shape::float_type, {2, 16, 16, 1280}, {0, 0, 0, 1}};
-    migraphx::shape output{migraphx::shape::float_type, {2, 16, 256, 80}, {0, 0, 0, 16}};
+    migraphx::shape output{migraphx::shape::float_type, {2, 16, 256, 80}};
     expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
 }

@@ -2960,6 +2977,12 @@ TEST_CASE(reshape_lazy_nonstandard_error)
     }
 }
 
+TEST_CASE(reshape_lazy_transposed_squeeze)
+{
+    migraphx::shape input{migraphx::shape::float_type, {4, 16}, {1, 4}};
+    throws_shape(migraphx::make_op("reshape_lazy", {{"dims", {64}}}), input);
+}
+
 TEST_CASE(reshape_lazy_nonpacked_unsqueeze1)
 {
     migraphx::shape input{migraphx::shape::float_type, {4, 16}, {32, 2}};

@@ -2974,13 +2997,19 @@ TEST_CASE(reshape_lazy_nonpacked_unsqueeze2)
     expect_shape(output, migraphx::make_op("reshape_lazy", {{"dims", output.lens()}}), input);
 }
 
-TEST_CASE(reshape_lazy_nonpacked_squeeze)
+TEST_CASE(reshape_lazy_nonpacked_squeeze1)
 {
     migraphx::shape input{migraphx::shape::float_type, {4, 16}, {32, 2}};
     migraphx::shape output{migraphx::shape::float_type, {64}, {2}};
     expect_shape(output, migraphx::make_op("reshape_lazy", {{"dims", output.lens()}}), input);
 }
 
+TEST_CASE(reshape_lazy_nonpacked_squeeze2)
+{
+    migraphx::shape input{migraphx::shape::float_type, {4, 16}, {32, 1}};
+    throws_shape(migraphx::make_op("reshape_lazy", {{"dims", {64}}}), input);
+}
+
 TEST_CASE(reshape_lazy_broadcast_unsqueeze1)
 {
     migraphx::shape input{migraphx::shape::float_type, {2, 256, 1280}, {0, 0, 1}};

@@ -2995,13 +3024,25 @@ TEST_CASE(reshape_lazy_broadcast_unsqueeze2)
     expect_shape(output, migraphx::make_op("reshape_lazy", {{"dims", output.lens()}}), input);
 }
 
-TEST_CASE(reshape_lazy_broadcast_squeeze)
+TEST_CASE(reshape_lazy_broadcast_squeeze1)
 {
     migraphx::shape input{migraphx::shape::float_type, {2, 16, 16, 1280}, {0, 0, 0, 1}};
     migraphx::shape output{migraphx::shape::float_type, {2, 256, 1280}, {0, 0, 1}};
     expect_shape(output, migraphx::make_op("reshape_lazy", {{"dims", output.lens()}}), input);
 }
 
+TEST_CASE(reshape_lazy_broadcast_squeeze2)
+{
+    migraphx::shape input{migraphx::shape::float_type, {4, 16}, {0, 1}};
+    throws_shape(migraphx::make_op("reshape_lazy", {{"dims", {64}}}), input);
+}
+
+TEST_CASE(reshape_lazy_broadcast_squeeze3)
+{
+    migraphx::shape input{migraphx::shape::float_type, {4, 16}, {1, 0}};
+    throws_shape(migraphx::make_op("reshape_lazy", {{"dims", {64}}}), input);
+}
+
 TEST_CASE(reshape_lazy_broadcast_squeeze_error)
 {
     migraphx::shape input{migraphx::shape::float_type, {2, 16, 16, 1280}, {0, 0, 0, 1}};
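Taken together, the updated tests encode the new split of responsibilities: reshape, which is allowed to copy, now accepts transposed, non-packed and broadcast inputs and reports a plain standard output shape, while reshape_lazy keeps aliasing semantics and throws whenever the input strides cannot be merged (the new *_transposed_squeeze, *_squeeze2 and *_squeeze3 cases). The snippet below is a standalone illustration, independent of MIGraphX, of the transposed case: visiting a {4, 16} tensor with strides {1, 4} in row-major order does not step through memory with a single constant stride, so a flat {64} view is impossible without a copy.

    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Why reshape_lazy_transposed_squeeze must throw while reshape_transposed_squeeze passes:
    // a {4, 16} tensor with strides {1, 4} (column-major) visited in row-major order jumps
    // around memory, so no single stride describes a flattened {64} view of the same buffer.
    int main()
    {
        const long long dims[2]    = {4, 16};
        const long long strides[2] = {1, 4};
        std::vector<long long> offsets;
        for(long long i = 0; i < dims[0]; i++)
            for(long long j = 0; j < dims[1]; j++)
                offsets.push_back(i * strides[0] + j * strides[1]);
        // offsets = 0, 4, 8, ..., 60, 1, 5, ... : not a single arithmetic progression.
        bool single_stride = true;
        for(std::size_t k = 1; k + 1 < offsets.size(); k++)
            single_stride = single_stride and (offsets[k + 1] - offsets[k]) == (offsets[1] - offsets[0]);
        std::cout << std::boolalpha << single_stride << '\n'; // prints false
        return 0;
    }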