Commit da78b0c0 authored by Paul's avatar Paul
Browse files

Format

parent 7f0c6da9
...@@ -28,8 +28,12 @@ static auto elements(const Range& r) ...@@ -28,8 +28,12 @@ static auto elements(const Range& r)
struct common_dim_state struct common_dim_state
{ {
common_dim_state(const std::vector<std::size_t>& pdims, std::vector<std::vector<std::size_t>>& paxes_map) : dims(&pdims), axes_map(&paxes_map), it(dims->begin()) {} common_dim_state(const std::vector<std::size_t>& pdims,
const std::vector<std::size_t>* dims = nullptr; std::vector<std::vector<std::size_t>>& paxes_map)
: dims(&pdims), axes_map(&paxes_map), it(dims->begin())
{
}
const std::vector<std::size_t>* dims = nullptr;
std::vector<std::vector<std::size_t>>* axes_map = nullptr; std::vector<std::vector<std::size_t>>* axes_map = nullptr;
std::vector<std::size_t>::const_iterator it{}; std::vector<std::size_t>::const_iterator it{};
std::size_t rem = 1; std::size_t rem = 1;
...@@ -50,13 +54,15 @@ struct common_dim_state ...@@ -50,13 +54,15 @@ struct common_dim_state
void add_multi_axes(std::size_t naxes, std::size_t start) void add_multi_axes(std::size_t naxes, std::size_t start)
{ {
auto axes = compute_axes(naxes, start); auto axes = compute_axes(naxes, start);
std::transform(axes.begin(), axes.end(), std::back_inserter(*axes_map), [&](auto axis) -> std::vector<std::size_t> { std::transform(axes.begin(),
return {axis}; axes.end(),
}); std::back_inserter(*axes_map),
[&](auto axis) -> std::vector<std::size_t> { return {axis}; });
} }
std::vector<std::size_t> compute_axes(std::size_t naxes, std::size_t start) const std::vector<std::size_t> compute_axes(std::size_t naxes, std::size_t start) const
{ {
if (rem != 1) { if(rem != 1)
{
assert(start > 0); assert(start > 0);
naxes++; naxes++;
start--; start--;
...@@ -67,12 +73,14 @@ struct common_dim_state ...@@ -67,12 +73,14 @@ struct common_dim_state
} }
}; };
static bool commpute_common_dim(std::vector<std::size_t>& cd_dims, common_dim_state& state1, common_dim_state& state2) static bool commpute_common_dim(std::vector<std::size_t>& cd_dims,
common_dim_state& state1,
common_dim_state& state2)
{ {
assert(state1.get() <= state2.get()); assert(state1.get() <= state2.get());
auto d2 = state2.get(); auto d2 = state2.get();
auto dims = state1.dims_for(d2); auto dims = state1.dims_for(d2);
auto n = elements(dims); auto n = elements(dims);
auto naxes = distance(dims); auto naxes = distance(dims);
// If not divisible then we can't compute a common dim // If not divisible then we can't compute a common dim
if((d2 % n) != 0) if((d2 % n) != 0)
...@@ -80,7 +88,7 @@ static bool commpute_common_dim(std::vector<std::size_t>& cd_dims, common_dim_st ...@@ -80,7 +88,7 @@ static bool commpute_common_dim(std::vector<std::size_t>& cd_dims, common_dim_st
auto rem = d2 / n; auto rem = d2 / n;
state1.add_multi_axes(naxes, cd_dims.size()); state1.add_multi_axes(naxes, cd_dims.size());
state2.add_axes(rem != 1 ? naxes + 1 : naxes, cd_dims.size()); state2.add_axes(rem != 1 ? naxes + 1 : naxes, cd_dims.size());
state1.rem = rem; state1.rem = rem;
state2.rem = 1; state2.rem = 1;
...@@ -105,12 +113,12 @@ common_dims common_dims::compute(const std::vector<std::size_t>& dims1, ...@@ -105,12 +113,12 @@ common_dims common_dims::compute(const std::vector<std::size_t>& dims1,
auto d2 = state2.get(); auto d2 = state2.get();
if(d1 <= d2) if(d1 <= d2)
{ {
if (commpute_common_dim(cd.dims, state1, state2)) if(commpute_common_dim(cd.dims, state1, state2))
return {}; return {};
} }
else // if(d1 > d2) else // if(d1 > d2)
{ {
if (commpute_common_dim(cd.dims, state2, state1)) if(commpute_common_dim(cd.dims, state2, state1))
return {}; return {};
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment