Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
da78b0c0
Commit
da78b0c0
authored
May 01, 2023
by
Paul
Browse files
Format
parent
7f0c6da9
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
13 deletions
+21
-13
src/common_dims.cpp
src/common_dims.cpp
+21
-13
No files found.
src/common_dims.cpp
View file @
da78b0c0
...
@@ -28,8 +28,12 @@ static auto elements(const Range& r)
...
@@ -28,8 +28,12 @@ static auto elements(const Range& r)
struct
common_dim_state
struct
common_dim_state
{
{
common_dim_state
(
const
std
::
vector
<
std
::
size_t
>&
pdims
,
std
::
vector
<
std
::
vector
<
std
::
size_t
>>&
paxes_map
)
:
dims
(
&
pdims
),
axes_map
(
&
paxes_map
),
it
(
dims
->
begin
())
{}
common_dim_state
(
const
std
::
vector
<
std
::
size_t
>&
pdims
,
const
std
::
vector
<
std
::
size_t
>*
dims
=
nullptr
;
std
::
vector
<
std
::
vector
<
std
::
size_t
>>&
paxes_map
)
:
dims
(
&
pdims
),
axes_map
(
&
paxes_map
),
it
(
dims
->
begin
())
{
}
const
std
::
vector
<
std
::
size_t
>*
dims
=
nullptr
;
std
::
vector
<
std
::
vector
<
std
::
size_t
>>*
axes_map
=
nullptr
;
std
::
vector
<
std
::
vector
<
std
::
size_t
>>*
axes_map
=
nullptr
;
std
::
vector
<
std
::
size_t
>::
const_iterator
it
{};
std
::
vector
<
std
::
size_t
>::
const_iterator
it
{};
std
::
size_t
rem
=
1
;
std
::
size_t
rem
=
1
;
...
@@ -50,13 +54,15 @@ struct common_dim_state
...
@@ -50,13 +54,15 @@ struct common_dim_state
void
add_multi_axes
(
std
::
size_t
naxes
,
std
::
size_t
start
)
void
add_multi_axes
(
std
::
size_t
naxes
,
std
::
size_t
start
)
{
{
auto
axes
=
compute_axes
(
naxes
,
start
);
auto
axes
=
compute_axes
(
naxes
,
start
);
std
::
transform
(
axes
.
begin
(),
axes
.
end
(),
std
::
back_inserter
(
*
axes_map
),
[
&
](
auto
axis
)
->
std
::
vector
<
std
::
size_t
>
{
std
::
transform
(
axes
.
begin
(),
return
{
axis
};
axes
.
end
(),
});
std
::
back_inserter
(
*
axes_map
),
[
&
](
auto
axis
)
->
std
::
vector
<
std
::
size_t
>
{
return
{
axis
};
});
}
}
std
::
vector
<
std
::
size_t
>
compute_axes
(
std
::
size_t
naxes
,
std
::
size_t
start
)
const
std
::
vector
<
std
::
size_t
>
compute_axes
(
std
::
size_t
naxes
,
std
::
size_t
start
)
const
{
{
if
(
rem
!=
1
)
{
if
(
rem
!=
1
)
{
assert
(
start
>
0
);
assert
(
start
>
0
);
naxes
++
;
naxes
++
;
start
--
;
start
--
;
...
@@ -67,12 +73,14 @@ struct common_dim_state
...
@@ -67,12 +73,14 @@ struct common_dim_state
}
}
};
};
static
bool
commpute_common_dim
(
std
::
vector
<
std
::
size_t
>&
cd_dims
,
common_dim_state
&
state1
,
common_dim_state
&
state2
)
static
bool
commpute_common_dim
(
std
::
vector
<
std
::
size_t
>&
cd_dims
,
common_dim_state
&
state1
,
common_dim_state
&
state2
)
{
{
assert
(
state1
.
get
()
<=
state2
.
get
());
assert
(
state1
.
get
()
<=
state2
.
get
());
auto
d2
=
state2
.
get
();
auto
d2
=
state2
.
get
();
auto
dims
=
state1
.
dims_for
(
d2
);
auto
dims
=
state1
.
dims_for
(
d2
);
auto
n
=
elements
(
dims
);
auto
n
=
elements
(
dims
);
auto
naxes
=
distance
(
dims
);
auto
naxes
=
distance
(
dims
);
// If not divisible then we can't compute a common dim
// If not divisible then we can't compute a common dim
if
((
d2
%
n
)
!=
0
)
if
((
d2
%
n
)
!=
0
)
...
@@ -80,7 +88,7 @@ static bool commpute_common_dim(std::vector<std::size_t>& cd_dims, common_dim_st
...
@@ -80,7 +88,7 @@ static bool commpute_common_dim(std::vector<std::size_t>& cd_dims, common_dim_st
auto
rem
=
d2
/
n
;
auto
rem
=
d2
/
n
;
state1
.
add_multi_axes
(
naxes
,
cd_dims
.
size
());
state1
.
add_multi_axes
(
naxes
,
cd_dims
.
size
());
state2
.
add_axes
(
rem
!=
1
?
naxes
+
1
:
naxes
,
cd_dims
.
size
());
state2
.
add_axes
(
rem
!=
1
?
naxes
+
1
:
naxes
,
cd_dims
.
size
());
state1
.
rem
=
rem
;
state1
.
rem
=
rem
;
state2
.
rem
=
1
;
state2
.
rem
=
1
;
...
@@ -105,12 +113,12 @@ common_dims common_dims::compute(const std::vector<std::size_t>& dims1,
...
@@ -105,12 +113,12 @@ common_dims common_dims::compute(const std::vector<std::size_t>& dims1,
auto
d2
=
state2
.
get
();
auto
d2
=
state2
.
get
();
if
(
d1
<=
d2
)
if
(
d1
<=
d2
)
{
{
if
(
commpute_common_dim
(
cd
.
dims
,
state1
,
state2
))
if
(
commpute_common_dim
(
cd
.
dims
,
state1
,
state2
))
return
{};
return
{};
}
}
else
// if(d1 > d2)
else
// if(d1 > d2)
{
{
if
(
commpute_common_dim
(
cd
.
dims
,
state2
,
state1
))
if
(
commpute_common_dim
(
cd
.
dims
,
state2
,
state1
))
return
{};
return
{};
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment