Commit da78b0c0 authored by Paul's avatar Paul
Browse files

Format

parent 7f0c6da9
...@@ -28,7 +28,11 @@ static auto elements(const Range& r) ...@@ -28,7 +28,11 @@ static auto elements(const Range& r)
struct common_dim_state struct common_dim_state
{ {
common_dim_state(const std::vector<std::size_t>& pdims, std::vector<std::vector<std::size_t>>& paxes_map) : dims(&pdims), axes_map(&paxes_map), it(dims->begin()) {} common_dim_state(const std::vector<std::size_t>& pdims,
std::vector<std::vector<std::size_t>>& paxes_map)
: dims(&pdims), axes_map(&paxes_map), it(dims->begin())
{
}
const std::vector<std::size_t>* dims = nullptr; const std::vector<std::size_t>* dims = nullptr;
std::vector<std::vector<std::size_t>>* axes_map = nullptr; std::vector<std::vector<std::size_t>>* axes_map = nullptr;
std::vector<std::size_t>::const_iterator it{}; std::vector<std::size_t>::const_iterator it{};
...@@ -50,13 +54,15 @@ struct common_dim_state ...@@ -50,13 +54,15 @@ struct common_dim_state
void add_multi_axes(std::size_t naxes, std::size_t start) void add_multi_axes(std::size_t naxes, std::size_t start)
{ {
auto axes = compute_axes(naxes, start); auto axes = compute_axes(naxes, start);
std::transform(axes.begin(), axes.end(), std::back_inserter(*axes_map), [&](auto axis) -> std::vector<std::size_t> { std::transform(axes.begin(),
return {axis}; axes.end(),
}); std::back_inserter(*axes_map),
[&](auto axis) -> std::vector<std::size_t> { return {axis}; });
} }
std::vector<std::size_t> compute_axes(std::size_t naxes, std::size_t start) const std::vector<std::size_t> compute_axes(std::size_t naxes, std::size_t start) const
{ {
if (rem != 1) { if(rem != 1)
{
assert(start > 0); assert(start > 0);
naxes++; naxes++;
start--; start--;
...@@ -67,7 +73,9 @@ struct common_dim_state ...@@ -67,7 +73,9 @@ struct common_dim_state
} }
}; };
static bool commpute_common_dim(std::vector<std::size_t>& cd_dims, common_dim_state& state1, common_dim_state& state2) static bool commpute_common_dim(std::vector<std::size_t>& cd_dims,
common_dim_state& state1,
common_dim_state& state2)
{ {
assert(state1.get() <= state2.get()); assert(state1.get() <= state2.get());
auto d2 = state2.get(); auto d2 = state2.get();
...@@ -105,12 +113,12 @@ common_dims common_dims::compute(const std::vector<std::size_t>& dims1, ...@@ -105,12 +113,12 @@ common_dims common_dims::compute(const std::vector<std::size_t>& dims1,
auto d2 = state2.get(); auto d2 = state2.get();
if(d1 <= d2) if(d1 <= d2)
{ {
if (commpute_common_dim(cd.dims, state1, state2)) if(commpute_common_dim(cd.dims, state1, state2))
return {}; return {};
} }
else // if(d1 > d2) else // if(d1 > d2)
{ {
if (commpute_common_dim(cd.dims, state2, state1)) if(commpute_common_dim(cd.dims, state2, state1))
return {}; return {};
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment