Commit 7f0c6da9 authored by Paul's avatar Paul
Browse files

Fix tests

parent 9acc5aad
...@@ -13,31 +13,27 @@ static auto compute_end_dim(Iterator start, Iterator last, std::size_t dim) ...@@ -13,31 +13,27 @@ static auto compute_end_dim(Iterator start, Iterator last, std::size_t dim)
std::size_t x = 1; std::size_t x = 1;
auto it = std::find_if(start, last, [&](auto i) { auto it = std::find_if(start, last, [&](auto i) {
x *= i; x *= i;
return x >= dim; return x > dim;
}); });
if(x != dim) if(x < dim)
return start; return start;
return it; return it;
} }
template <class Iterator>
static auto elements(Iterator start, Iterator last)
{
return std::accumulate(start, last, std::size_t{1}, std::multiplies<>{});
}
template <class Range> template <class Range>
static auto elements(const Range& r) static auto elements(const Range& r)
{ {
return elements(r.begin(), r.end()); return std::accumulate(r.begin(), r.end(), std::size_t{1}, std::multiplies<>{});
} }
struct common_dim_state struct common_dim_state
{ {
common_dim_state(const std::vector<std::size_t>& pdims) : dims(&pdims), it(dims->begin()) {} common_dim_state(const std::vector<std::size_t>& pdims, std::vector<std::vector<std::size_t>>& paxes_map) : dims(&pdims), axes_map(&paxes_map), it(dims->begin()) {}
const std::vector<std::size_t>* dims; const std::vector<std::size_t>* dims = nullptr;
std::vector<std::size_t>::const_iterator it; std::vector<std::vector<std::size_t>>* axes_map = nullptr;
std::vector<std::size_t>::const_iterator it{};
std::size_t rem = 1; std::size_t rem = 1;
std::size_t get() const { return *it; } std::size_t get() const { return *it / rem; }
bool is_end() const { return it == dims->end(); } bool is_end() const { return it == dims->end(); }
void next(std::size_t i = 1) { it += i; } void next(std::size_t i = 1) { it += i; }
auto dims_for(std::size_t d) const auto dims_for(std::size_t d) const
...@@ -45,53 +41,80 @@ struct common_dim_state ...@@ -45,53 +41,80 @@ struct common_dim_state
auto dim_end = compute_end_dim(it, dims->end(), d); auto dim_end = compute_end_dim(it, dims->end(), d);
return range(it, dim_end); return range(it, dim_end);
} }
void add_axes(std::size_t naxes, std::size_t start)
{
auto axes = compute_axes(naxes, start);
axes_map->push_back(std::move(axes));
}
void add_multi_axes(std::size_t naxes, std::size_t start)
{
auto axes = compute_axes(naxes, start);
std::transform(axes.begin(), axes.end(), std::back_inserter(*axes_map), [&](auto axis) -> std::vector<std::size_t> {
return {axis};
});
}
std::vector<std::size_t> compute_axes(std::size_t naxes, std::size_t start) const
{
if (rem != 1) {
assert(start > 0);
naxes++;
start--;
}
std::vector<std::size_t> axes(naxes);
std::iota(axes.begin(), axes.end(), start);
return axes;
}
}; };
common_dims static bool commpute_common_dim(std::vector<std::size_t>& cd_dims, common_dim_state& state1, common_dim_state& state2)
common_dims::compute(const std::vector<std::size_t>& dims1, const std::vector<std::size_t>& dims2) {
assert(state1.get() <= state2.get());
auto d2 = state2.get();
auto dims = state1.dims_for(d2);
auto n = elements(dims);
auto naxes = distance(dims);
// If not divisible then we can't compute a common dim
if((d2 % n) != 0)
return true;
auto rem = d2 / n;
state1.add_multi_axes(naxes, cd_dims.size());
state2.add_axes(rem != 1 ? naxes + 1 : naxes, cd_dims.size());
state1.rem = rem;
state2.rem = 1;
cd_dims.insert(cd_dims.end(), dims.begin(), dims.end());
if(state1.rem != 1)
cd_dims.push_back(state1.rem);
state1.next(distance(dims));
state2.next();
return false;
}
common_dims common_dims::compute(const std::vector<std::size_t>& dims1,
const std::vector<std::size_t>& dims2)
{ {
assert(elements(dims1) == elements(dims2)); assert(elements(dims1) == elements(dims2));
common_dims cd; common_dims cd;
auto it1 = dims1.begin(); common_dim_state state1{dims1, cd.axes_map1};
auto it2 = dims2.begin(); common_dim_state state2{dims2, cd.axes_map2};
std::size_t rem1 = 1; while(not state1.is_end() and not state2.is_end())
std::size_t rem2 = 1;
while(it1 != dims1.end() and it2 != dims2.end())
{ {
auto d1 = *it1; auto d1 = state1.get();
auto d2 = *it2; auto d2 = state2.get();
if(d1 == d2) if(d1 <= d2)
{ {
cd.axes_map1.push_back({cd.dims.size()}); if (commpute_common_dim(cd.dims, state1, state2))
cd.axes_map2.push_back({cd.dims.size()}); return {};
cd.dims.push_back(d1);
it1++;
it2++;
} }
else if(d1 < d2) else // if(d1 > d2)
{
auto dim_end = compute_end_dim(it1, dims1.begin(), d2);
auto dims = range(it1, dim_end);
auto n = elements(dims);
if(n != d2)
{ {
// If not divisible then we can't compute a common dims if (commpute_common_dim(cd.dims, state2, state1))
if((d2 % n) != 0)
return {}; return {};
rem1 = d2 / n;
}
std::vector<std::size_t> axes(distance(dims));
std::iota(axes.begin(), axes.end(), cd.dims.size());
cd.axes_map1.push_back(axes);
cd.axes_map2.push_back(axes);
cd.dims.insert(cd.dims.end(), dims.begin(), dims.end());
if(rem1 != 1)
cd.dims.push_back(rem1);
it1 += distance(dims);
it2++;
} }
} }
assert(elements(dims1) == elements(cd.dims));
return cd; return cd;
} }
......
...@@ -248,6 +248,9 @@ struct iterator_range ...@@ -248,6 +248,9 @@ struct iterator_range
Iterator begin() const { return start; } Iterator begin() const { return start; }
Iterator end() const { return last; } Iterator end() const { return last; }
bool empty() const { return start == last; }
decltype(auto) front() const { return *start; }
}; };
template <class Iterator, MIGRAPHX_REQUIRES(not std::is_integral<Iterator>{})> template <class Iterator, MIGRAPHX_REQUIRES(not std::is_integral<Iterator>{})>
......
#include <migraphx/common_dims.hpp> #include <migraphx/common_dims.hpp>
#include <test.hpp> #include <test.hpp>
using axes_map = std::vector<std::vector<std::size_t>>;
TEST_CASE(common_d1_less)
{
auto cd = migraphx::common_dims::compute({2, 32, 40, 8}, {2, 1280, 8});
EXPECT(cd.dims == std::vector<std::size_t>{2, 32, 40, 8});
EXPECT(cd.axes_map1 == axes_map{{0}, {1}, {2}, {3}});
EXPECT(cd.axes_map2 == axes_map{{0}, {1, 2}, {3}});
}
TEST_CASE(common1) TEST_CASE(common1)
{ {
auto cd = migraphx::common_dims::compute({2, 32, 2560}, {2, 1280, 8, 8}); auto cd = migraphx::common_dims::compute({2, 32, 2560}, {2, 1280, 8, 8});
EXPECT(cd.dims == std::vector<std::size_t>{2, 32, 40, 8, 8}); EXPECT(cd.dims == std::vector<std::size_t>{2, 32, 40, 8, 8});
EXPECT(cd.axes_map1 == axes_map{{0}, {1}, {2, 3, 4}});
EXPECT(cd.axes_map2 == axes_map{{0}, {1, 2}, {3}, {4}});
}
TEST_CASE(common2)
{
auto cd = migraphx::common_dims::compute({2, 1280, 8, 8}, {2, 32, 2560});
EXPECT(cd.dims == std::vector<std::size_t>{2, 32, 40, 8, 8});
EXPECT(cd.axes_map1 == axes_map{{0}, {1, 2}, {3}, {4}});
EXPECT(cd.axes_map2 == axes_map{{0}, {1}, {2, 3, 4}});
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment