Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
da78b0c0
Commit
da78b0c0
authored
May 01, 2023
by
Paul
Browse files
Format
parent
7f0c6da9
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
13 deletions
+21
-13
src/common_dims.cpp
src/common_dims.cpp
+21
-13
No files found.
src/common_dims.cpp
View file @
da78b0c0
...
...
@@ -28,7 +28,11 @@ static auto elements(const Range& r)
struct
common_dim_state
{
common_dim_state
(
const
std
::
vector
<
std
::
size_t
>&
pdims
,
std
::
vector
<
std
::
vector
<
std
::
size_t
>>&
paxes_map
)
:
dims
(
&
pdims
),
axes_map
(
&
paxes_map
),
it
(
dims
->
begin
())
{}
common_dim_state
(
const
std
::
vector
<
std
::
size_t
>&
pdims
,
std
::
vector
<
std
::
vector
<
std
::
size_t
>>&
paxes_map
)
:
dims
(
&
pdims
),
axes_map
(
&
paxes_map
),
it
(
dims
->
begin
())
{
}
const
std
::
vector
<
std
::
size_t
>*
dims
=
nullptr
;
std
::
vector
<
std
::
vector
<
std
::
size_t
>>*
axes_map
=
nullptr
;
std
::
vector
<
std
::
size_t
>::
const_iterator
it
{};
...
...
@@ -50,13 +54,15 @@ struct common_dim_state
void
add_multi_axes
(
std
::
size_t
naxes
,
std
::
size_t
start
)
{
auto
axes
=
compute_axes
(
naxes
,
start
);
std
::
transform
(
axes
.
begin
(),
axes
.
end
(),
std
::
back_inserter
(
*
axes_map
),
[
&
](
auto
axis
)
->
std
::
vector
<
std
::
size_t
>
{
return
{
axis
};
});
std
::
transform
(
axes
.
begin
(),
axes
.
end
(),
std
::
back_inserter
(
*
axes_map
),
[
&
](
auto
axis
)
->
std
::
vector
<
std
::
size_t
>
{
return
{
axis
};
});
}
std
::
vector
<
std
::
size_t
>
compute_axes
(
std
::
size_t
naxes
,
std
::
size_t
start
)
const
{
if
(
rem
!=
1
)
{
if
(
rem
!=
1
)
{
assert
(
start
>
0
);
naxes
++
;
start
--
;
...
...
@@ -67,7 +73,9 @@ struct common_dim_state
}
};
static
bool
commpute_common_dim
(
std
::
vector
<
std
::
size_t
>&
cd_dims
,
common_dim_state
&
state1
,
common_dim_state
&
state2
)
static
bool
commpute_common_dim
(
std
::
vector
<
std
::
size_t
>&
cd_dims
,
common_dim_state
&
state1
,
common_dim_state
&
state2
)
{
assert
(
state1
.
get
()
<=
state2
.
get
());
auto
d2
=
state2
.
get
();
...
...
@@ -105,12 +113,12 @@ common_dims common_dims::compute(const std::vector<std::size_t>& dims1,
auto
d2
=
state2
.
get
();
if
(
d1
<=
d2
)
{
if
(
commpute_common_dim
(
cd
.
dims
,
state1
,
state2
))
if
(
commpute_common_dim
(
cd
.
dims
,
state1
,
state2
))
return
{};
}
else
// if(d1 > d2)
{
if
(
commpute_common_dim
(
cd
.
dims
,
state2
,
state1
))
if
(
commpute_common_dim
(
cd
.
dims
,
state2
,
state1
))
return
{};
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment