"src/replace_allocate.cpp" did not exist on "f59806195c88a1c7571ac4aa554ec4e682c18866"
Commit 7f0c6da9 authored by Paul's avatar Paul
Browse files

Fix tests

parent 9acc5aad
......@@ -13,31 +13,27 @@ static auto compute_end_dim(Iterator start, Iterator last, std::size_t dim)
std::size_t x = 1;
auto it = std::find_if(start, last, [&](auto i) {
x *= i;
return x >= dim;
return x > dim;
});
if(x != dim)
if(x < dim)
return start;
return it;
}
template <class Iterator>
static auto elements(Iterator start, Iterator last)
{
return std::accumulate(start, last, std::size_t{1}, std::multiplies<>{});
}
template <class Range>
static auto elements(const Range& r)
{
return elements(r.begin(), r.end());
return std::accumulate(r.begin(), r.end(), std::size_t{1}, std::multiplies<>{});
}
struct common_dim_state
{
common_dim_state(const std::vector<std::size_t>& pdims) : dims(&pdims), it(dims->begin()) {}
const std::vector<std::size_t>* dims;
std::vector<std::size_t>::const_iterator it;
common_dim_state(const std::vector<std::size_t>& pdims, std::vector<std::vector<std::size_t>>& paxes_map) : dims(&pdims), axes_map(&paxes_map), it(dims->begin()) {}
const std::vector<std::size_t>* dims = nullptr;
std::vector<std::vector<std::size_t>>* axes_map = nullptr;
std::vector<std::size_t>::const_iterator it{};
std::size_t rem = 1;
std::size_t get() const { return *it; }
std::size_t get() const { return *it / rem; }
bool is_end() const { return it == dims->end(); }
void next(std::size_t i = 1) { it += i; }
auto dims_for(std::size_t d) const
......@@ -45,53 +41,80 @@ struct common_dim_state
auto dim_end = compute_end_dim(it, dims->end(), d);
return range(it, dim_end);
}
void add_axes(std::size_t naxes, std::size_t start)
{
auto axes = compute_axes(naxes, start);
axes_map->push_back(std::move(axes));
}
void add_multi_axes(std::size_t naxes, std::size_t start)
{
auto axes = compute_axes(naxes, start);
std::transform(axes.begin(), axes.end(), std::back_inserter(*axes_map), [&](auto axis) -> std::vector<std::size_t> {
return {axis};
});
}
std::vector<std::size_t> compute_axes(std::size_t naxes, std::size_t start) const
{
if (rem != 1) {
assert(start > 0);
naxes++;
start--;
}
std::vector<std::size_t> axes(naxes);
std::iota(axes.begin(), axes.end(), start);
return axes;
}
};
common_dims
common_dims::compute(const std::vector<std::size_t>& dims1, const std::vector<std::size_t>& dims2)
static bool commpute_common_dim(std::vector<std::size_t>& cd_dims, common_dim_state& state1, common_dim_state& state2)
{
assert(state1.get() <= state2.get());
auto d2 = state2.get();
auto dims = state1.dims_for(d2);
auto n = elements(dims);
auto naxes = distance(dims);
// If not divisible then we can't compute a common dim
if((d2 % n) != 0)
return true;
auto rem = d2 / n;
state1.add_multi_axes(naxes, cd_dims.size());
state2.add_axes(rem != 1 ? naxes + 1 : naxes, cd_dims.size());
state1.rem = rem;
state2.rem = 1;
cd_dims.insert(cd_dims.end(), dims.begin(), dims.end());
if(state1.rem != 1)
cd_dims.push_back(state1.rem);
state1.next(distance(dims));
state2.next();
return false;
}
common_dims common_dims::compute(const std::vector<std::size_t>& dims1,
const std::vector<std::size_t>& dims2)
{
assert(elements(dims1) == elements(dims2));
common_dims cd;
auto it1 = dims1.begin();
auto it2 = dims2.begin();
std::size_t rem1 = 1;
std::size_t rem2 = 1;
while(it1 != dims1.end() and it2 != dims2.end())
common_dim_state state1{dims1, cd.axes_map1};
common_dim_state state2{dims2, cd.axes_map2};
while(not state1.is_end() and not state2.is_end())
{
auto d1 = *it1;
auto d2 = *it2;
if(d1 == d2)
auto d1 = state1.get();
auto d2 = state2.get();
if(d1 <= d2)
{
cd.axes_map1.push_back({cd.dims.size()});
cd.axes_map2.push_back({cd.dims.size()});
cd.dims.push_back(d1);
it1++;
it2++;
if (commpute_common_dim(cd.dims, state1, state2))
return {};
}
else if(d1 < d2)
else // if(d1 > d2)
{
auto dim_end = compute_end_dim(it1, dims1.begin(), d2);
auto dims = range(it1, dim_end);
auto n = elements(dims);
if(n != d2)
{
// If not divisible then we can't compute a common dims
if((d2 % n) != 0)
return {};
rem1 = d2 / n;
}
std::vector<std::size_t> axes(distance(dims));
std::iota(axes.begin(), axes.end(), cd.dims.size());
cd.axes_map1.push_back(axes);
cd.axes_map2.push_back(axes);
cd.dims.insert(cd.dims.end(), dims.begin(), dims.end());
if(rem1 != 1)
cd.dims.push_back(rem1);
it1 += distance(dims);
it2++;
if (commpute_common_dim(cd.dims, state2, state1))
return {};
}
}
assert(elements(dims1) == elements(cd.dims));
return cd;
}
......
......@@ -248,6 +248,9 @@ struct iterator_range
Iterator begin() const { return start; }
Iterator end() const { return last; }
bool empty() const { return start == last; }
decltype(auto) front() const { return *start; }
};
template <class Iterator, MIGRAPHX_REQUIRES(not std::is_integral<Iterator>{})>
......
#include <migraphx/common_dims.hpp>
#include <test.hpp>
using axes_map = std::vector<std::vector<std::size_t>>;
TEST_CASE(common_d1_less)
{
auto cd = migraphx::common_dims::compute({2, 32, 40, 8}, {2, 1280, 8});
EXPECT(cd.dims == std::vector<std::size_t>{2, 32, 40, 8});
EXPECT(cd.axes_map1 == axes_map{{0}, {1}, {2}, {3}});
EXPECT(cd.axes_map2 == axes_map{{0}, {1, 2}, {3}});
}
TEST_CASE(common1)
{
auto cd = migraphx::common_dims::compute({2, 32, 2560}, {2, 1280, 8, 8});
EXPECT(cd.dims == std::vector<std::size_t>{2, 32, 40, 8, 8});
EXPECT(cd.axes_map1 == axes_map{{0}, {1}, {2, 3, 4}});
EXPECT(cd.axes_map2 == axes_map{{0}, {1, 2}, {3}, {4}});
}
TEST_CASE(common2)
{
auto cd = migraphx::common_dims::compute({2, 1280, 8, 8}, {2, 32, 2560});
EXPECT(cd.dims == std::vector<std::size_t>{2, 32, 40, 8, 8});
EXPECT(cd.axes_map1 == axes_map{{0}, {1, 2}, {3}, {4}});
EXPECT(cd.axes_map2 == axes_map{{0}, {1}, {2, 3, 4}});
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment