Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
03ae8013
Commit
03ae8013
authored
May 18, 2022
by
Paul
Browse files
Fix unit test
parent
1e5f7133
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
3 deletions
+17
-3
src/reduce_dims.cpp
src/reduce_dims.cpp
+17
-3
No files found.
src/reduce_dims.cpp
View file @
03ae8013
...
@@ -17,9 +17,7 @@ bool reduce_dim(std::vector<shape>& shapes, std::size_t n)
...
@@ -17,9 +17,7 @@ bool reduce_dim(std::vector<shape>& shapes, std::size_t n)
auto
blen
=
s
.
lens
()[
n
+
1
];
auto
blen
=
s
.
lens
()[
n
+
1
];
if
(
astride
==
bstride
*
blen
or
alen
==
1
)
if
(
astride
==
bstride
*
blen
or
alen
==
1
)
{
new_lens
.
push_back
(
alen
*
blen
);
new_lens
.
push_back
(
alen
*
blen
);
}
}
}
if
(
new_lens
.
size
()
!=
shapes
.
size
())
if
(
new_lens
.
size
()
!=
shapes
.
size
())
return
false
;
return
false
;
...
@@ -37,10 +35,25 @@ bool reduce_dim(std::vector<shape>& shapes, std::size_t n)
...
@@ -37,10 +35,25 @@ bool reduce_dim(std::vector<shape>& shapes, std::size_t n)
return
true
;
return
true
;
}
}
void
reduce_dim1
(
std
::
vector
<
shape
>&
shapes
)
{
if
(
std
::
any_of
(
shapes
.
begin
(),
shapes
.
end
(),
[
&
](
const
auto
&
s
)
{
return
s
.
lens
().
back
()
!=
1
;
}))
return
;
for
(
auto
&
s
:
shapes
)
{
auto
lens
=
s
.
lens
();
auto
strides
=
s
.
strides
();
lens
.
pop_back
();
strides
.
pop_back
();
s
=
shape
{
s
.
type
(),
lens
,
strides
};
}
}
std
::
size_t
reduce_dim_all
(
std
::
vector
<
shape
>&
shapes
,
std
::
size_t
n
)
std
::
size_t
reduce_dim_all
(
std
::
vector
<
shape
>&
shapes
,
std
::
size_t
n
)
{
{
while
(
reduce_dim
(
shapes
,
n
)
and
n
<
shapes
.
size
())
{}
while
(
reduce_dim
(
shapes
,
n
)
and
n
<
shapes
.
size
())
{}
return
n
+
1
;
return
n
+
1
;
}
}
void
reduce_dim_all
(
std
::
vector
<
shape
>&
shapes
)
void
reduce_dim_all
(
std
::
vector
<
shape
>&
shapes
)
...
@@ -48,6 +61,7 @@ void reduce_dim_all(std::vector<shape>& shapes)
...
@@ -48,6 +61,7 @@ void reduce_dim_all(std::vector<shape>& shapes)
std
::
size_t
n
=
0
;
std
::
size_t
n
=
0
;
while
(
n
<
shapes
.
front
().
lens
().
size
()
-
1
)
while
(
n
<
shapes
.
front
().
lens
().
size
()
-
1
)
n
=
reduce_dim_all
(
shapes
,
n
);
n
=
reduce_dim_all
(
shapes
,
n
);
reduce_dim1
(
shapes
);
}
}
std
::
vector
<
std
::
size_t
>
base_lens
(
const
std
::
vector
<
shape
>&
shapes
)
std
::
vector
<
std
::
size_t
>
base_lens
(
const
std
::
vector
<
shape
>&
shapes
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment