Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
7f0c6da9
Commit
7f0c6da9
authored
May 01, 2023
by
Paul
Browse files
Fix tests
parent
9acc5aad
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
94 additions
and
48 deletions
+94
-48
src/common_dims.cpp
src/common_dims.cpp
+71
-48
src/include/migraphx/ranges.hpp
src/include/migraphx/ranges.hpp
+3
-0
test/common_dims.cpp
test/common_dims.cpp
+20
-0
No files found.
src/common_dims.cpp
View file @
7f0c6da9
...
...
@@ -13,31 +13,27 @@ static auto compute_end_dim(Iterator start, Iterator last, std::size_t dim)
std
::
size_t
x
=
1
;
auto
it
=
std
::
find_if
(
start
,
last
,
[
&
](
auto
i
)
{
x
*=
i
;
return
x
>
=
dim
;
return
x
>
dim
;
});
if
(
x
!=
dim
)
if
(
x
<
dim
)
return
start
;
return
it
;
}
template
<
class
Iterator
>
static
auto
elements
(
Iterator
start
,
Iterator
last
)
{
return
std
::
accumulate
(
start
,
last
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
}
template
<
class
Range
>
static
auto
elements
(
const
Range
&
r
)
{
return
elements
(
r
.
begin
(),
r
.
end
());
return
std
::
accumulate
(
r
.
begin
(),
r
.
end
()
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{}
);
}
struct
common_dim_state
{
common_dim_state
(
const
std
::
vector
<
std
::
size_t
>&
pdims
)
:
dims
(
&
pdims
),
it
(
dims
->
begin
())
{}
const
std
::
vector
<
std
::
size_t
>*
dims
;
std
::
vector
<
std
::
size_t
>::
const_iterator
it
;
common_dim_state
(
const
std
::
vector
<
std
::
size_t
>&
pdims
,
std
::
vector
<
std
::
vector
<
std
::
size_t
>>&
paxes_map
)
:
dims
(
&
pdims
),
axes_map
(
&
paxes_map
),
it
(
dims
->
begin
())
{}
const
std
::
vector
<
std
::
size_t
>*
dims
=
nullptr
;
std
::
vector
<
std
::
vector
<
std
::
size_t
>>*
axes_map
=
nullptr
;
std
::
vector
<
std
::
size_t
>::
const_iterator
it
{};
std
::
size_t
rem
=
1
;
std
::
size_t
get
()
const
{
return
*
it
;
}
std
::
size_t
get
()
const
{
return
*
it
/
rem
;
}
bool
is_end
()
const
{
return
it
==
dims
->
end
();
}
void
next
(
std
::
size_t
i
=
1
)
{
it
+=
i
;
}
auto
dims_for
(
std
::
size_t
d
)
const
...
...
@@ -45,53 +41,80 @@ struct common_dim_state
auto
dim_end
=
compute_end_dim
(
it
,
dims
->
end
(),
d
);
return
range
(
it
,
dim_end
);
}
void
add_axes
(
std
::
size_t
naxes
,
std
::
size_t
start
)
{
auto
axes
=
compute_axes
(
naxes
,
start
);
axes_map
->
push_back
(
std
::
move
(
axes
));
}
void
add_multi_axes
(
std
::
size_t
naxes
,
std
::
size_t
start
)
{
auto
axes
=
compute_axes
(
naxes
,
start
);
std
::
transform
(
axes
.
begin
(),
axes
.
end
(),
std
::
back_inserter
(
*
axes_map
),
[
&
](
auto
axis
)
->
std
::
vector
<
std
::
size_t
>
{
return
{
axis
};
});
}
std
::
vector
<
std
::
size_t
>
compute_axes
(
std
::
size_t
naxes
,
std
::
size_t
start
)
const
{
if
(
rem
!=
1
)
{
assert
(
start
>
0
);
naxes
++
;
start
--
;
}
std
::
vector
<
std
::
size_t
>
axes
(
naxes
);
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
start
);
return
axes
;
}
};
common_dims
common_dims
::
compute
(
const
std
::
vector
<
std
::
size_t
>&
dims1
,
const
std
::
vector
<
std
::
size_t
>&
dims2
)
static
bool
commpute_common_dim
(
std
::
vector
<
std
::
size_t
>&
cd_dims
,
common_dim_state
&
state1
,
common_dim_state
&
state2
)
{
assert
(
state1
.
get
()
<=
state2
.
get
());
auto
d2
=
state2
.
get
();
auto
dims
=
state1
.
dims_for
(
d2
);
auto
n
=
elements
(
dims
);
auto
naxes
=
distance
(
dims
);
// If not divisible then we can't compute a common dim
if
((
d2
%
n
)
!=
0
)
return
true
;
auto
rem
=
d2
/
n
;
state1
.
add_multi_axes
(
naxes
,
cd_dims
.
size
());
state2
.
add_axes
(
rem
!=
1
?
naxes
+
1
:
naxes
,
cd_dims
.
size
());
state1
.
rem
=
rem
;
state2
.
rem
=
1
;
cd_dims
.
insert
(
cd_dims
.
end
(),
dims
.
begin
(),
dims
.
end
());
if
(
state1
.
rem
!=
1
)
cd_dims
.
push_back
(
state1
.
rem
);
state1
.
next
(
distance
(
dims
));
state2
.
next
();
return
false
;
}
common_dims
common_dims
::
compute
(
const
std
::
vector
<
std
::
size_t
>&
dims1
,
const
std
::
vector
<
std
::
size_t
>&
dims2
)
{
assert
(
elements
(
dims1
)
==
elements
(
dims2
));
common_dims
cd
;
auto
it1
=
dims1
.
begin
();
auto
it2
=
dims2
.
begin
();
std
::
size_t
rem1
=
1
;
std
::
size_t
rem2
=
1
;
while
(
it1
!=
dims1
.
end
()
and
it2
!=
dims2
.
end
())
common_dim_state
state1
{
dims1
,
cd
.
axes_map1
};
common_dim_state
state2
{
dims2
,
cd
.
axes_map2
};
while
(
not
state1
.
is_end
()
and
not
state2
.
is_end
())
{
auto
d1
=
*
it1
;
auto
d2
=
*
it2
;
if
(
d1
=
=
d2
)
auto
d1
=
state1
.
get
()
;
auto
d2
=
state2
.
get
()
;
if
(
d1
<
=
d2
)
{
cd
.
axes_map1
.
push_back
({
cd
.
dims
.
size
()});
cd
.
axes_map2
.
push_back
({
cd
.
dims
.
size
()});
cd
.
dims
.
push_back
(
d1
);
it1
++
;
it2
++
;
if
(
commpute_common_dim
(
cd
.
dims
,
state1
,
state2
))
return
{};
}
else
if
(
d1
<
d2
)
{
auto
dim_end
=
compute_end_dim
(
it1
,
dims1
.
begin
(),
d2
);
auto
dims
=
range
(
it1
,
dim_end
);
auto
n
=
elements
(
dims
);
if
(
n
!=
d2
)
else
// if(d1 > d2)
{
// If not divisible then we can't compute a common dims
if
((
d2
%
n
)
!=
0
)
if
(
commpute_common_dim
(
cd
.
dims
,
state2
,
state1
))
return
{};
rem1
=
d2
/
n
;
}
std
::
vector
<
std
::
size_t
>
axes
(
distance
(
dims
));
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
cd
.
dims
.
size
());
cd
.
axes_map1
.
push_back
(
axes
);
cd
.
axes_map2
.
push_back
(
axes
);
cd
.
dims
.
insert
(
cd
.
dims
.
end
(),
dims
.
begin
(),
dims
.
end
());
if
(
rem1
!=
1
)
cd
.
dims
.
push_back
(
rem1
);
it1
+=
distance
(
dims
);
it2
++
;
}
}
assert
(
elements
(
dims1
)
==
elements
(
cd
.
dims
));
return
cd
;
}
...
...
src/include/migraphx/ranges.hpp
View file @
7f0c6da9
...
...
@@ -248,6 +248,9 @@ struct iterator_range
Iterator
begin
()
const
{
return
start
;
}
Iterator
end
()
const
{
return
last
;
}
bool
empty
()
const
{
return
start
==
last
;
}
decltype
(
auto
)
front
()
const
{
return
*
start
;
}
};
template
<
class
Iterator
,
MIGRAPHX_REQUIRES
(
not
std
::
is_integral
<
Iterator
>{})
>
...
...
test/common_dims.cpp
View file @
7f0c6da9
#include <migraphx/common_dims.hpp>
#include <test.hpp>
using
axes_map
=
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
;
TEST_CASE
(
common_d1_less
)
{
auto
cd
=
migraphx
::
common_dims
::
compute
({
2
,
32
,
40
,
8
},
{
2
,
1280
,
8
});
EXPECT
(
cd
.
dims
==
std
::
vector
<
std
::
size_t
>
{
2
,
32
,
40
,
8
});
EXPECT
(
cd
.
axes_map1
==
axes_map
{{
0
},
{
1
},
{
2
},
{
3
}});
EXPECT
(
cd
.
axes_map2
==
axes_map
{{
0
},
{
1
,
2
},
{
3
}});
}
TEST_CASE
(
common1
)
{
auto
cd
=
migraphx
::
common_dims
::
compute
({
2
,
32
,
2560
},
{
2
,
1280
,
8
,
8
});
EXPECT
(
cd
.
dims
==
std
::
vector
<
std
::
size_t
>
{
2
,
32
,
40
,
8
,
8
});
EXPECT
(
cd
.
axes_map1
==
axes_map
{{
0
},
{
1
},
{
2
,
3
,
4
}});
EXPECT
(
cd
.
axes_map2
==
axes_map
{{
0
},
{
1
,
2
},
{
3
},
{
4
}});
}
TEST_CASE
(
common2
)
{
auto
cd
=
migraphx
::
common_dims
::
compute
({
2
,
1280
,
8
,
8
},
{
2
,
32
,
2560
});
EXPECT
(
cd
.
dims
==
std
::
vector
<
std
::
size_t
>
{
2
,
32
,
40
,
8
,
8
});
EXPECT
(
cd
.
axes_map1
==
axes_map
{{
0
},
{
1
,
2
},
{
3
},
{
4
}});
EXPECT
(
cd
.
axes_map2
==
axes_map
{{
0
},
{
1
},
{
2
,
3
,
4
}});
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment