"vscode:/vscode.git/clone" did not exist on "6a09cedf0fb80930d5d4153cdf4ed1d182233718"
Commit 01a10755 authored by yuguo-Jack's avatar yuguo-Jack
Browse files

2.5.2-dtk24.04

parent 63eb0da5
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/adt/adt.h"
namespace cinn::adt {
// Placeholder type meaning "let the framework choose the size automatically".
class AutoSize final {};
// LoopSize = Int64
// Extent of one loop; only static (int64) extents are representable.
DEFINE_ADT_UNION(LoopSize, std::int64_t);
// S(Spatial): S0 = BlockIdx; S1 = ThreadIdx
// LoopType = S0x | S0y | S0z | S1x | S1y | S1z | Temporal | Vectorize |
// Unroll
// Spatial loops bound to GPU block indices (S0 = BlockIdx).
class S0x final {
 public:
  bool IsSpatial() const { return true; }
};

class S0y final {
 public:
  bool IsSpatial() const { return true; }
};

class S0z final {
 public:
  bool IsSpatial() const { return true; }
};

// Spatial loops bound to GPU thread indices (S1 = ThreadIdx).
class S1x final {
 public:
  bool IsSpatial() const { return true; }
};

class S1y final {
 public:
  bool IsSpatial() const { return true; }
};

class S1z final {
 public:
  bool IsSpatial() const { return true; }
};
// Ordinary sequential (non-spatial) loop.
class Temporal final {
 public:
  bool IsSpatial() const { return false; }

  // Name of the iterator variable this loop binds.
  // NOTE(review): iter_var_name_ is never assigned in this header, so it is
  // default-constructed empty here — presumably populated elsewhere; confirm.
  const std::string& iter_var_name() const { return iter_var_name_; }

 private:
  std::string iter_var_name_;
};

// Non-spatial loop marked for vectorization.
class Vectorize final {
 public:
  bool IsSpatial() const { return false; }

  // Name of the iterator variable this loop binds (see note on Temporal).
  const std::string& iter_var_name() const { return iter_var_name_; }

 private:
  std::string iter_var_name_;
};

// Non-spatial loop marked for unrolling.
class Unroll final {
 public:
  bool IsSpatial() const { return false; }

  // Name of the iterator variable this loop binds (see note on Temporal).
  const std::string& iter_var_name() const { return iter_var_name_; }

 private:
  std::string iter_var_name_;
};
DEFINE_ADT_UNION(
    LoopType, S0x, S0y, S0z, S1x, S1y, S1z, Temporal, Vectorize, Unroll);

// LoopDescriptor = (LoopType, LoopSize)
// One loop of a schedule: its binding kind and its extent.
class LoopDescriptor final : public Tuple<LoopType, LoopSize> {
 public:
  using Tuple<LoopType, LoopSize>::Tuple;

  const LoopType& GetLoopType() const { return std::get<0>(this->tuple()); }

  const LoopSize& GetLoopSize() const { return std::get<1>(this->tuple()); }

  bool operator==(const LoopDescriptor& other) const {
    // Compares tuple addresses, i.e. identity rather than structural
    // equality. NOTE(review): presumably intentional for shared ADT nodes,
    // but confirm callers do not expect value equality.
    return &this->tuple() == &other.tuple();
  }
};

// LoopDescriptors = [LoopDescriptor]
using LoopDescriptors = List<LoopDescriptor>;
// True iff `loop_type` binds to a GPU block/thread index (S0*/S1*).
inline bool IsSpatial(const LoopType& loop_type) {
  const auto& query = [](const auto& loop_type_impl) {
    return loop_type_impl.IsSpatial();
  };
  return std::visit(query, loop_type.variant());
}
// Extracts the extent of every loop in `sd` (defined in the .cc file).
List<LoopSize> GenerateLoopSizeFromSd(const LoopDescriptors& sd);

class KGroup;
class IGroup;
class ScheduleMesh;

// Pairs the schedule mesh with per-axis loop types into loop descriptors.
LoopDescriptors CreateScheduleDescriptor(const ScheduleMesh& sched_mesh,
                                         const List<LoopType>& loop_types);

// Builds a schedule descriptor for an igroup within its kgroup.
LoopDescriptors MakeNaiveScheduleDescriptor(
    const std::shared_ptr<KGroup>& kgroup,
    const std::shared_ptr<IGroup>& igroup);
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/adt/schedule_dim.h"
#include "paddle/cinn/adt/equation_function_constants_provider.h"
#include "paddle/cinn/adt/equation_graph.h"
#include "paddle/cinn/adt/equation_solver.h"
#include "paddle/cinn/adt/igroup.h"
#include "paddle/cinn/adt/index_expr_infer_context.h"
#include "paddle/cinn/adt/m_ir.h"
#include "paddle/cinn/adt/naive_op_equation_context.h"
#include "paddle/cinn/adt/print.h"
namespace cinn::adt {
namespace {
// Applies `DoEach` to the NaiveOpEquationContext of every op_stmt in `igroup`.
template <typename DoEachT>
void VisitEachOpEquationContext(const IGroup& igroup, const DoEachT& DoEach) {
  // Hoisted out of the loop: the context getter does not depend on op_stmt,
  // so fetching it per iteration was redundant work.
  const auto& EquationCtx4OpStmt = igroup.EquationCtx4OpStmt();
  for (const auto& op_stmt : *igroup.op_stmts()) {
    const auto& ctx = EquationCtx4OpStmt(op_stmt);
    DoEach(ctx);
  }
}
// Flattens the iterator tuples of all inputs of `ctx` into one list.
List<Iterator> GetOpEquationCtxInputIterators(
    const std::shared_ptr<config::NaiveOpEquationContext>& ctx) {
  List<Iterator> flattened{};
  const std::size_t input_size = ctx->in_indexes()->size();
  for (std::size_t input_idx = 0; input_idx != input_size; ++input_idx) {
    const auto& iterator_tuple = ctx->GetInIteratorTuple(input_idx);
    for (const auto& input_iterator : *iterator_tuple) {
      flattened->emplace_back(input_iterator);
    }
  }
  return flattened;
}
// Builds an inference context seeded so that every input iterator initially
// maps to itself.
std::shared_ptr<IndexExprInferContext> InitIndexExprInferContext(
    const std::shared_ptr<config::NaiveOpEquationContext>& ctx,
    const List<Iterator>& input_iterators,
    const std::shared_ptr<const EquationFunctionConstantsProvider>&
        constants_provider) {
  std::unordered_map<Variable, const Value> iterator2initial_value{};
  for (const auto& input_iterator : *input_iterators) {
    // Each iterator must be seeded exactly once.
    CHECK(
        iterator2initial_value.emplace(input_iterator, input_iterator).second);
  }
  return std::make_shared<IndexExprInferContext>(iterator2initial_value,
                                                 constants_provider);
}
// Applies `DoEach` to the iterator tuple of every input tensor of `op_ctx`.
template <typename DoEachT>
void VisitEachInputIteratorTuple(
    const std::shared_ptr<config::NaiveOpEquationContext>& op_ctx,
    const DoEachT& DoEach) {
  const std::size_t input_size = op_ctx->in_indexes()->size();
  for (std::size_t input_idx = 0; input_idx != input_size; ++input_idx) {
    DoEach(op_ctx->GetInIteratorTuple(input_idx));
  }
}
// Applies `DoEach` to every iterator of every output tensor of `op_ctx`.
template <typename DoEachT>
void VisitEachOutputIterator(
    const std::shared_ptr<config::NaiveOpEquationContext>& op_ctx,
    const DoEachT& DoEach) {
  const std::size_t output_size = op_ctx->out_indexes()->size();
  for (std::size_t output_idx = 0; output_idx != output_size; ++output_idx) {
    const auto& iterator_tuple = op_ctx->GetOutIteratorTuple(output_idx);
    for (const auto& output_iterator : *iterator_tuple) {
      DoEach(output_iterator);
    }
  }
}
// Collects into `unused_input_iterators` the input iterators that never
// appear in any inferred output index expression — i.e. the reduced
// (contracted) iterators. Does nothing if any output failed to infer.
void FilterReducedIterator(
    const std::shared_ptr<IndexExprInferContext>& infer_ctx,
    const std::shared_ptr<config::NaiveOpEquationContext>& op_ctx,
    const List<Iterator>& input_iterators,
    std::unordered_set<Iterator>* unused_input_iterators) {
  std::unordered_set<Iterator> used{};
  bool is_output_infered = true;
  VisitEachOutputIterator(op_ctx, [&](const Iterator& output_iterator) {
    if (infer_ctx->HasValue(output_iterator)) {
      const auto& iterator_expr = infer_ctx->GetValue(output_iterator);
      CollectTensorIndexIterators(iterator_expr, &used);
    } else {
      // At least one output index was not solved; the used-set would be
      // incomplete, so report nothing for this op.
      is_output_infered = false;
    }
  });
  if (!is_output_infered) {
    return;
  }
  for (const auto& input_iterator : *input_iterators) {
    if (used.find(input_iterator) == used.end()) {
      unused_input_iterators->emplace(input_iterator);
    }
  }
}
// For each input tensor of `ctx`, solves the equation graph starting from
// that input's iterators alone; iterators that never reach any output
// expression are reduced. Returns the union over all inputs.
std::unordered_set<Iterator> GenerateReducedIterator(
    const std::shared_ptr<config::NaiveOpEquationContext>& ctx,
    const std::shared_ptr<const EquationFunctionConstantsProvider>&
        constants_provider) {
  const auto& graph_view = Graph::New(ctx->equations())->GetGraphView();
  std::unordered_set<Iterator> ret{};
  VisitEachInputIteratorTuple(ctx, [&](const List<Iterator>& input_iterators) {
    const auto& infer_ctx =
        InitIndexExprInferContext(ctx, input_iterators, constants_provider);
    // Solving starts only from this input's iterators.
    std::vector<Variable> starts{};
    for (const auto& iterator : *input_iterators) {
      starts.emplace_back(iterator);
    }
    SolveEquations(graph_view, starts, infer_ctx.get());
    /*
      y = Reduce(x)
      y_i = f(x_i, x_j, ...)
      used_set = {x_i, x_j, ...}
      reduce_iterator = input_all_iterator_set - used_set
    */
    FilterReducedIterator(infer_ctx, ctx, input_iterators, &ret);
  });
  return ret;
}
// Gathers every schedule-descriptor iterator that must stay temporal: for
// each op context, maps its reduced input iterators through `Value4Iterator`
// and collects the sd iterators occurring in the mapped expressions.
std::unordered_set<Iterator> FilterTemporalIterators(
    const IGroup& igroup,
    const std::function<Value(const Iterator&)>& Value4Iterator) {
  std::unordered_set<Iterator> ret{};
  VisitEachOpEquationContext(
      igroup, [&](const std::shared_ptr<config::NaiveOpEquationContext>& ctx) {
        std::unordered_set<Iterator> reduced_iterators =
            GenerateReducedIterator(ctx, igroup.constants_provider());
        for (const auto& input_reduced_iterator : reduced_iterators) {
          const auto& sd_iterator_expr = Value4Iterator(input_reduced_iterator);
          CollectTensorIndexIterators(sd_iterator_expr, &ret);
        }
      });
  return ret;
}
} // namespace
// Tags each anchor loop extent as reduced or injective: loop i is reduced
// iff its anchor iterator appears among the temporal sd iterators.
// NOTE(review): assumes anchor_iterators has at least loop_sizes->size()
// entries — only loop_sizes is iterated; confirm at call sites.
List<ScheduleDim> MakeAnchorScheduleDims(
    const IGroup& igroup,
    const std::function<Value(const Iterator&)>& Value4Iterator,
    const List<LoopSize>& loop_sizes,
    const List<Iterator>& anchor_iterators) {
  std::unordered_set<Iterator> temporal_sd_iterators =
      FilterTemporalIterators(igroup, Value4Iterator);
  List<ScheduleDim> ret{};
  for (std::size_t i = 0; i < loop_sizes->size(); ++i) {
    const auto& loop_iterator = anchor_iterators->at(i);
    if (temporal_sd_iterators.count(loop_iterator) > 0) {
      ret->emplace_back(tReduced<LoopSize>{loop_sizes->at(i)});
    } else {
      ret->emplace_back(tInjective<LoopSize>{loop_sizes->at(i)});
    }
  }
  return ret;
}
// Unwraps the extent from either the tReduced or tInjective alternative.
LoopSize GetLoopSize(const ScheduleDim& sched_dim) {
  const auto& extract_size = [](const auto& tagged_dim) {
    return tagged_dim.value();
  };
  return std::visit(extract_size, sched_dim.variant());
}
// Returns the positions of all reduced dims in `loop_sizes`.
List<int> GetReduceAxis(const List<ScheduleDim>& loop_sizes) {
  List<int> reduce_axis{};
  for (std::size_t axis = 0; axis != loop_sizes->size(); ++axis) {
    const auto& sched_dim = loop_sizes->at(axis);
    if (sched_dim.Has<tReduced<LoopSize>>()) {
      reduce_axis->emplace_back(axis);
    } else if (!sched_dim.Has<tInjective<LoopSize>>()) {
      // Neither alternative matched: the union grew without updating this.
      LOG(FATAL) << "Dead code";
    }
  }
  return reduce_axis;
}
// Returns the positions of all injective dims in `loop_sizes`.
List<int> GetInjectiveAxis(const List<ScheduleDim>& loop_sizes) {
  List<int> injective_axis{};
  for (std::size_t axis = 0; axis != loop_sizes->size(); ++axis) {
    const auto& sched_dim = loop_sizes->at(axis);
    if (sched_dim.Has<tInjective<LoopSize>>()) {
      injective_axis->emplace_back(axis);
    } else if (!sched_dim.Has<tReduced<LoopSize>>()) {
      // Neither alternative matched: the union grew without updating this.
      LOG(FATAL) << "Dead code";
    }
  }
  return injective_axis;
}
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <optional>
#include "paddle/cinn/adt/equation_value.h"
#include "paddle/cinn/adt/equation_variable.h"
#include "paddle/cinn/adt/schedule_descriptor.h"
namespace cinn::adt {
DEFINE_ADT_TAG(tReduced);
DEFINE_ADT_TAG(tInjective);

// ScheduleDim: one loop extent tagged either as reduced (contracted by a
// reduction) or injective (flows through to the output index).
DEFINE_ADT_UNION(ScheduleDim, tReduced<LoopSize>, tInjective<LoopSize>);

// Returns the extent regardless of the reduced/injective tag.
LoopSize GetLoopSize(const ScheduleDim& sched_dim);
// Positions of reduced dims in `loop_sizes`.
List<int> GetReduceAxis(const List<ScheduleDim>& loop_sizes);
// Positions of injective dims in `loop_sizes`.
List<int> GetInjectiveAxis(const List<ScheduleDim>& loop_sizes);

class IGroup;

// Tags each anchor loop as reduced/injective using the igroup's equations.
List<ScheduleDim> MakeAnchorScheduleDims(
    const IGroup& igroup,
    const std::function<Value(const Iterator&)>& Value4Iterator,
    const List<LoopSize>& loop_sizes,
    const List<Iterator>& anchor_iterators);
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/adt/schedule_mesh.h"
namespace cinn::adt {
namespace {
// Base case: a plain loop-dim list is its own input; rank is its size.
std::size_t GetInputRankImpl(const List<ScheduleDim>& sched_dims) {
  return sched_dims->size();
}

// Transform nodes never change the input rank: recurse into the wrapped mesh.
std::size_t GetInputRankImpl(
    const ScheduleMeshReshape<ScheduleMesh>& sched_reshape) {
  const auto& [sched_mesh, _] = sched_reshape.tuple();
  return GetInputRank(sched_mesh);
}

std::size_t GetInputRankImpl(
    const ScheduleMeshTranspose<ScheduleMesh>& sched_transpose) {
  const auto& [sched_mesh, _] = sched_transpose.tuple();
  return GetInputRank(sched_mesh);
}

std::size_t GetInputRankImpl(
    const ScheduleMeshPadding<ScheduleMesh>& sched_padding) {
  const auto& [sched_mesh, _] = sched_padding.tuple();
  return GetInputRank(sched_mesh);
}
} // namespace
// Rank of the innermost (pre-transform) loop-dim list of `sched_mesh`.
std::size_t GetInputRank(const ScheduleMesh& sched_mesh) {
  const auto& dispatch = [](const auto& mesh_impl) {
    return GetInputRankImpl(mesh_impl);
  };
  return std::visit(dispatch, sched_mesh.variant());
}
namespace {
// Base case: rank of a plain loop-dim list.
std::size_t GetOutputRankImpl(const List<ScheduleDim>& sched_dims) {
  return sched_dims->size();
}

// Reshape: output rank is the number of target dims.
std::size_t GetOutputRankImpl(
    const ScheduleMeshReshape<ScheduleMesh>& sched_reshape) {
  const auto& [_, shapes] = sched_reshape.tuple();
  return shapes.value()->size();
}

// Transpose: rank is preserved; the permutation must cover every dim.
std::size_t GetOutputRankImpl(
    const ScheduleMeshTranspose<ScheduleMesh>& sched_transpose) {
  const auto& [sched_mesh, perm] = sched_transpose.tuple();
  CHECK_EQ(GetOutputRank(sched_mesh), perm.value()->size());
  return perm.value()->size();
}

// Padding: output rank is the number of padded dims.
std::size_t GetOutputRankImpl(
    const ScheduleMeshPadding<ScheduleMesh>& sched_padding) {
  const auto& [_, padding_to] = sched_padding.tuple();
  return padding_to.value()->size();
}
} // namespace
// Rank of the outermost transform's result.
std::size_t GetOutputRank(const ScheduleMesh& sched_mesh) {
  const auto& dispatch = [](const auto& mesh_impl) {
    return GetOutputRankImpl(mesh_impl);
  };
  return std::visit(dispatch, sched_mesh.variant());
}
namespace {
// Base case: static extents of a plain loop-dim list.
List<Constant> GetOutputDimValuesImpl(const List<ScheduleDim>& sched_dims) {
  List<Constant> ret{};
  for (const auto& sched_dim : *sched_dims) {
    const auto& loop_size = GetLoopSize(sched_dim);
    CHECK(loop_size.Has<std::int64_t>());
    ret->emplace_back(loop_size.Get<std::int64_t>());
  }
  return ret;
}

// Reshape: the output shape is the reshape target itself.
List<Constant> GetOutputDimValuesImpl(
    const ScheduleMeshReshape<ScheduleMesh>& sched_reshape) {
  const auto& [_, shape] = sched_reshape.tuple();
  List<Constant> ret{};
  for (const auto& loop_size : *shape.value()) {
    CHECK(loop_size.Has<std::int64_t>());
    ret->emplace_back(loop_size.Get<std::int64_t>());
  }
  return ret;
}

// Transpose: permute the wrapped mesh's output dims by `perm`.
List<Constant> GetOutputDimValuesImpl(
    const ScheduleMeshTranspose<ScheduleMesh>& sched_transpose) {
  const auto& [sched_mesh, perm] = sched_transpose.tuple();
  // `input_dims` = output dims of the mesh being transposed.
  const auto& input_dims = GetOutputDimValues(sched_mesh);
  List<Constant> ret{};
  for (const auto& idx : *perm.value()) {
    ret->emplace_back(input_dims->at(idx));
  }
  return ret;
}

// Padding: the output shape is the padding target itself.
List<Constant> GetOutputDimValuesImpl(
    const ScheduleMeshPadding<ScheduleMesh>& sched_padding) {
  const auto& [_, shape] = sched_padding.tuple();
  List<Constant> ret{};
  for (const auto& loop_size : *shape.value()) {
    CHECK(loop_size.Has<std::int64_t>());
    ret->emplace_back(loop_size.Get<std::int64_t>());
  }
  return ret;
}
} // namespace
// Static extents of the outermost transform's result.
List<Constant> GetOutputDimValues(const ScheduleMesh& sched_mesh) {
  const auto& dispatch = [](const auto& mesh_impl) {
    return GetOutputDimValuesImpl(mesh_impl);
  };
  return std::visit(dispatch, sched_mesh.variant());
}
namespace {
// Base case: a plain loop-dim list is already the input mesh.
ScheduleMesh GetInputScheduleMeshImpl(const List<ScheduleDim>& sched_dims) {
  return sched_dims;
}

// Transform nodes: strip the node and recurse into the wrapped mesh.
ScheduleMesh GetInputScheduleMeshImpl(
    const ScheduleMeshReshape<ScheduleMesh>& sched_reshape) {
  const auto& [sched_mesh, _] = sched_reshape.tuple();
  return GetInputScheduleMesh(sched_mesh);
}

ScheduleMesh GetInputScheduleMeshImpl(
    const ScheduleMeshTranspose<ScheduleMesh>& sched_transpose) {
  const auto& [sched_mesh, _] = sched_transpose.tuple();
  return GetInputScheduleMesh(sched_mesh);
}

ScheduleMesh GetInputScheduleMeshImpl(
    const ScheduleMeshPadding<ScheduleMesh>& sched_padding) {
  const auto& [sched_mesh, _] = sched_padding.tuple();
  return GetInputScheduleMesh(sched_mesh);
}
} // namespace
// Strips every transform node, returning the innermost loop-dim list.
ScheduleMesh GetInputScheduleMesh(const ScheduleMesh& sched_mesh) {
  const auto& dispatch = [](const auto& mesh_impl) {
    return GetInputScheduleMeshImpl(mesh_impl);
  };
  return std::visit(dispatch, sched_mesh.variant());
}
namespace {
// Thread-dim extent used when splitting the flattened loop into
// (S0x blocks, S1x threads) below.
constexpr int kThreadSize = 1024;

// Strategy interface: `Match` decides whether the policy can handle a set of
// loop dims; `Optimize` rewrites them into a schedule mesh plus the loop
// types to bind to each output axis. Non-copyable, non-movable.
class ScheduleMeshPolicy {
 public:
  ScheduleMeshPolicy(const ScheduleMeshPolicy&) = delete;
  ScheduleMeshPolicy(ScheduleMeshPolicy&&) = delete;
  virtual ~ScheduleMeshPolicy() = default;

  virtual bool Match(const List<ScheduleDim>& loop_sizes) const = 0;
  virtual std::tuple<ScheduleMesh, List<LoopType>> Optimize(
      const List<ScheduleDim>& loop_sizes) const = 0;

 protected:
  ScheduleMeshPolicy() = default;
};
// Policy for loop nests that are entirely injective with static extents:
// flatten everything to 1-D, pad to a multiple of kThreadSize, then split
// into (blocks, threads).
class AllInjectiveScheduleMeshPolicy final : public ScheduleMeshPolicy {
 public:
  AllInjectiveScheduleMeshPolicy() = default;

  // Accepts only if every dim is injective AND statically sized.
  bool Match(const List<ScheduleDim>& loop_sizes) const override {
    for (const auto& sched_dim : *loop_sizes) {
      if (!sched_dim.Has<tInjective<LoopSize>>()) {
        return false;
      }
      if (!GetLoopSize(sched_dim).Has<std::int64_t>()) {
        return false;
      }
    }
    return true;
  }

  std::tuple<ScheduleMesh, List<LoopType>> Optimize(
      const List<ScheduleDim>& loop_sizes) const override {
    ScheduleMesh sched_mesh{loop_sizes};
    // Total iteration count across all loops.
    std::int64_t acc = 1;
    for (const auto& sched_dim : *loop_sizes) {
      acc *= GetLoopSize(sched_dim).Get<std::int64_t>();
    }
    // [d0, d1, ...] -> [acc] -> [round_up(acc)] -> [blocks, kThreadSize]
    sched_mesh = MeshReshape(sched_mesh, {acc});
    sched_mesh = MeshPaddingRoundUp(sched_mesh, {kThreadSize});
    sched_mesh = MeshReshape(sched_mesh, {-1, kThreadSize});
    return std::make_tuple(sched_mesh, List<LoopType>{S0x{}, S1x{}});
  }
};
// Returns a new list holding the elements of `lhs` followed by those of
// `rhs`.
List<int> ConcatIntLists(const List<int>& lhs, const List<int>& rhs) {
  List<int> concated{};
  for (int elem : *lhs) {
    concated->emplace_back(elem);
  }
  for (int elem : *rhs) {
    concated->emplace_back(elem);
  }
  return concated;
}
// Returns a new vector holding the elements of `lhs` followed by those of
// `rhs`.
std::vector<std::int64_t> ConcatIntLists(const std::vector<std::int64_t>& lhs,
                                         const std::vector<std::int64_t>& rhs) {
  std::vector<std::int64_t> ret{};
  ret.reserve(lhs.size() + rhs.size());
  // Fixed: the original copied element-by-element through an `int` loop
  // variable, silently truncating 64-bit values.
  ret.insert(ret.end(), lhs.begin(), lhs.end());
  ret.insert(ret.end(), rhs.begin(), rhs.end());
  return ret;
}
// Concatenates `lhs` and `rhs` into a vector of engaged optionals (used to
// express "align every dim" for MeshPaddingRoundUp).
std::vector<std::optional<std::int64_t>> ConcatIntListsToOptionalList(
    const std::vector<std::int64_t>& lhs,
    const std::vector<std::int64_t>& rhs) {
  std::vector<std::optional<std::int64_t>> ret{};
  ret.reserve(lhs.size() + rhs.size());
  // Fixed: the original iterated with `int`, silently truncating 64-bit
  // values before wrapping them in optionals.
  for (std::int64_t elem : lhs) {
    ret.emplace_back(elem);
  }
  for (std::int64_t elem : rhs) {
    ret.emplace_back(elem);
  }
  return ret;
}
// Fallback policy for any mix of injective and reduced dims with static
// extents: move injective dims to the front, flatten them into one axis,
// pad that axis to a multiple of kThreadSize and split it into
// (blocks, threads), keeping each reduce dim as a trailing Temporal loop.
class GeneralScheduleMeshPolicy final : public ScheduleMeshPolicy {
 public:
  GeneralScheduleMeshPolicy() = default;

  // Accepts whenever every extent is statically known.
  bool Match(const List<ScheduleDim>& loop_sizes) const override {
    for (const auto& sched_dim : *loop_sizes) {
      if (!GetLoopSize(sched_dim).Has<std::int64_t>()) {
        return false;
      }
    }
    return true;
  }

  std::tuple<ScheduleMesh, List<LoopType>> Optimize(
      const List<ScheduleDim>& loop_sizes) const override {
    const auto& injective_axes = GetInjectiveAxis(loop_sizes);
    const auto& reduce_axes = GetReduceAxis(loop_sizes);
    // Static extents of the reduce dims, preserved through every transform.
    std::vector<std::int64_t> reduce_shape{};
    for (int reduce_axis : *reduce_axes) {
      reduce_shape.emplace_back(
          GetLoopSize(loop_sizes->at(reduce_axis)).Get<std::int64_t>());
    }
    ScheduleMesh sched_mesh{loop_sizes};
    // [injective..., reduce...]
    sched_mesh =
        MeshTranspose(sched_mesh, ConcatIntLists(injective_axes, reduce_axes));
    // [flattened_injective, reduce...]
    sched_mesh = MeshReshape(sched_mesh, ConcatIntLists({-1}, reduce_shape));
    // Pad the flattened injective axis to kThreadSize; aligning a reduce dim
    // to its own extent is a no-op.
    sched_mesh = MeshPaddingRoundUp(
        sched_mesh, ConcatIntListsToOptionalList({kThreadSize}, reduce_shape));
    // [blocks, kThreadSize, reduce...]
    sched_mesh = MeshReshape(sched_mesh,
                             ConcatIntLists({-1, kThreadSize}, reduce_shape));
    // One Temporal loop per reduce dim after (S0x, S1x).
    List<LoopType> loop_types{S0x{}, S1x{}};
    for (std::size_t i = 0; i < reduce_axes->size(); ++i) {
      loop_types->emplace_back(Temporal{});
    }
    return std::make_tuple(sched_mesh, loop_types);
  }
};
const std::vector<std::unique_ptr<ScheduleMeshPolicy>>&
GetAllScheduleMeshPolicies() {
static std::vector<std::unique_ptr<ScheduleMeshPolicy>> policies{};
policies.emplace_back(std::make_unique<AllInjectiveScheduleMeshPolicy>());
policies.emplace_back(std::make_unique<GeneralScheduleMeshPolicy>());
return policies;
}
} // namespace
// Dispatches to the first policy whose Match accepts `loop_sizes`.
// GeneralScheduleMeshPolicy accepts any statically-sized loops, so reaching
// the LOG(FATAL) means some extent was dynamic/non-int64.
std::tuple<ScheduleMesh, List<LoopType>> CreateOptimizedScheduleMesh(
    const List<ScheduleDim>& loop_sizes) {
  for (const auto& policy : GetAllScheduleMeshPolicies()) {
    if (policy->Match(loop_sizes)) {
      return policy->Optimize(loop_sizes);
    }
  }
  LOG(FATAL) << "Dead code, no valid schedule mesh policy found";
}
// Wraps `sched_mesh` in a reshape node targeting `shape`. At most one entry
// may be -1; it is inferred from the remaining extents. Checks that the
// total element count is preserved.
ScheduleMesh MeshReshape(const ScheduleMesh& sched_mesh,
                         const std::vector<std::int64_t>& shape) {
  const auto& origin_shape = GetOutputDimValues(sched_mesh);
  // Total element count of the current output shape.
  std::int64_t origin_numel = 1;
  for (const auto& dim : *origin_shape) {
    CHECK(dim.Has<std::int64_t>());
    origin_numel *= dim.Get<std::int64_t>();
  }
  // Product of the explicit target dims; `dynamic_shape` records whether a
  // single -1 placeholder was seen.
  std::int64_t numel = 1;
  bool dynamic_shape = false;
  for (const auto& dim : shape) {
    if (dim < 0) {
      // Only one -1 is allowed, and no other negative values.
      CHECK(dim == -1 && !dynamic_shape);
      dynamic_shape = true;
    } else {
      numel *= dim;
    }
  }
  CHECK(dynamic_shape || numel == origin_numel);
  List<LoopSize> reshape_to{};
  for (const auto& dim : shape) {
    if (dim < 0) {
      // Infer the -1 dim; the explicit dims must divide the total evenly.
      CHECK_EQ(origin_numel % numel, 0);
      reshape_to->emplace_back(origin_numel / numel);
    } else {
      reshape_to->emplace_back(dim);
    }
  }
  return ScheduleMeshReshape<ScheduleMesh>(sched_mesh, reshape_to);
}
// Wraps `sched_mesh` in a transpose node applying permutation `perm`.
ScheduleMesh MeshTranspose(const ScheduleMesh& sched_mesh,
                           const List<int>& perm) {
  ScheduleMeshTranspose<ScheduleMesh> transposed{sched_mesh, perm};
  return transposed;
}
// Wraps `sched_mesh` in a padding node and sanity-checks that padding never
// shrinks a dim whose extent is statically known on both sides.
ScheduleMesh MeshPadding(const ScheduleMesh& sched_mesh,
                         const List<LoopSize>& padding_to) {
  const auto& ret = ScheduleMeshPadding<ScheduleMesh>(sched_mesh, padding_to);
  const auto& input_dims = GetOutputDimValues(sched_mesh);
  const auto& output_dims = GetOutputDimValues(ret);
  CHECK_EQ(input_dims->size(), output_dims->size());
  for (std::size_t i = 0; i < input_dims->size(); ++i) {
    // Only dims known statically on both sides can be validated.
    if (input_dims->at(i).Has<std::int64_t>() &&
        output_dims->at(i).Has<std::int64_t>()) {
      CHECK_LE(input_dims->at(i).Get<std::int64_t>(),
               output_dims->at(i).Get<std::int64_t>());
    }
  }
  return ret;
}
// Pads each output dim of `sched_mesh` up to the next multiple of the
// corresponding `align_sizes` entry; entries without a value leave the dim
// unchanged. Returns `sched_mesh` untouched when nothing needs padding.
ScheduleMesh MeshPaddingRoundUp(
    const ScheduleMesh& sched_mesh,
    const std::vector<std::optional<std::int64_t>>& align_sizes) {
  const auto& shape = GetOutputDimValues(sched_mesh);
  CHECK_EQ(shape->size(), align_sizes.size());
  List<LoopSize> padding_to{};
  bool create_new_sched_mesh = false;
  for (std::size_t i = 0; i < shape->size(); ++i) {
    CHECK(shape->at(i).Has<std::int64_t>());
    std::int64_t dim = shape->at(i).Get<std::int64_t>();
    if (!align_sizes.at(i).has_value()) {
      // Fixed: the original `continue`d without recording this dim, so a mix
      // of aligned and unaligned dims yielded a padding list with fewer
      // entries than the mesh rank.
      padding_to->emplace_back(dim);
      continue;
    }
    std::int64_t align_size = align_sizes.at(i).value();
    // Round up to the next multiple of align_size.
    std::int64_t padding_size =
        (dim + align_size - 1) / align_size * align_size;
    if (padding_size != dim) {
      create_new_sched_mesh = true;
    }
    padding_to->emplace_back(padding_size);
  }
  if (!create_new_sched_mesh) {
    return sched_mesh;
  }
  return ScheduleMeshPadding<ScheduleMesh>(sched_mesh, padding_to);
}
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/adt/schedule_dim.h"
namespace cinn::adt {
DEFINE_ADT_TAG(tMeshDim);
DEFINE_ADT_TAG(tMeshPerm);
DEFINE_ADT_TAG(tMeshPaddingTo);

// Reshape node: (inner mesh, target dims).
template <typename T>
class ScheduleMeshReshape final : public Tuple<T, tMeshDim<List<LoopSize>>> {
 public:
  using Tuple<T, tMeshDim<List<LoopSize>>>::Tuple;
};

// Transpose node: (inner mesh, permutation).
template <typename T>
class ScheduleMeshTranspose final : public Tuple<T, tMeshPerm<List<int>>> {
 public:
  using Tuple<T, tMeshPerm<List<int>>>::Tuple;
};

// Padding node: (inner mesh, padded dims).
template <typename T>
class ScheduleMeshPadding final
    : public Tuple<T, tMeshPaddingTo<List<LoopSize>>> {
 public:
  using Tuple<T, tMeshPaddingTo<List<LoopSize>>>::Tuple;
};

// A schedule mesh is either a plain loop-dim list or one of the transform
// nodes above applied recursively to another mesh.
DEFINE_ADT_UNION(ScheduleMesh,
                 List<ScheduleDim>,
                 ScheduleMeshReshape<ScheduleMesh>,
                 ScheduleMeshTranspose<ScheduleMesh>,
                 ScheduleMeshPadding<ScheduleMesh>);

// Checked constructors for the transform nodes.
ScheduleMesh MeshReshape(const ScheduleMesh& sched_mesh,
                         const std::vector<std::int64_t>& shape);
ScheduleMesh MeshTranspose(const ScheduleMesh& sched_mesh,
                           const List<int>& perm);
ScheduleMesh MeshPadding(const ScheduleMesh& sched_mesh,
                         const List<LoopSize>& padding_to);
// Pads each dim up to a multiple of the matching align size (disengaged
// optional = leave the dim alone).
ScheduleMesh MeshPaddingRoundUp(
    const ScheduleMesh& sched_mesh,
    const std::vector<std::optional<std::int64_t>>& align_size);

// Rank of the innermost loop-dim list / of the outermost transform result.
std::size_t GetInputRank(const ScheduleMesh& sched_mesh);
std::size_t GetOutputRank(const ScheduleMesh& sched_mesh);
// Static extents of the outermost transform result.
List<Constant> GetOutputDimValues(const ScheduleMesh& sched_mesh);
// Strips all transform nodes, returning the innermost loop-dim list.
ScheduleMesh GetInputScheduleMesh(const ScheduleMesh& sched_mesh);

// Picks a policy and rewrites `loop_sizes` into an optimized mesh plus the
// loop types to bind to each output axis.
std::tuple<ScheduleMesh, List<LoopType>> CreateOptimizedScheduleMesh(
    const List<ScheduleDim>& loop_sizes);
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <optional>
#include <typeinfo>
#include "paddle/cinn/adt/equation.h"
#include "paddle/cinn/adt/equation_value_match_trait.h"
#include "paddle/cinn/adt/get_sub_reshape_dim_ranges.h"
#include "paddle/cinn/adt/index_expr_infer_context.h"
#include "paddle/cinn/adt/match.h"
#include "paddle/cinn/adt/simplify_value.h"
namespace cinn::adt {
// Applies rewriter T to `expr` only when T's source pattern matches;
// otherwise returns `expr` unchanged.
template <typename T, typename ExprT>
ExprT MatchAndRewrite(const ExprT& expr, const IndexExprInferContext& ctx) {
  const bool pattern_matched =
      cinn::adt::Match<typename T::source_pattern_type>(expr);
  return pattern_matched ? T().MatchAndRewrite(expr, ctx) : expr;
}
// Rewrites BroadcastedIterator(iter, Dim) by replacing the symbolic Dim with
// its concrete int64 size from the context, when known.
struct SimplifyBroadcastedIteratorByReplacingToStaticDim {
  using source_pattern_type = BroadcastedIterator<Value, Dim>;

  Value MatchAndRewrite(const Value& value, const IndexExprInferContext& ctx) {
    const auto& [iterator, dim] =
        value.Get<BroadcastedIterator<Value, Constant>>().tuple();
    const Constant& int64_dim = ctx.GetDimSize(dim.Get<Dim>());
    if (int64_dim.Has<std::int64_t>()) {
      return BroadcastedIterator<Value, Constant>{
          iterator, int64_dim.Get<std::int64_t>()};
    } else {
      // Size still symbolic; leave the expression untouched.
      return value;
    }
  }
};
// A broadcast over extent 1 pins the iterator to 0; any other static extent
// makes the broadcast a no-op, so the wrapped iterator is returned as-is.
struct SimplifyBroadcastedIterator {
  using source_pattern_type = BroadcastedIterator<Value, std::int64_t>;

  Value MatchAndRewrite(const Value& value, const IndexExprInferContext& ctx) {
    const auto& [inner_iterator, extent] =
        value.Get<BroadcastedIterator<Value, Constant>>().tuple();
    const bool broadcast_over_one = (extent.Get<std::int64_t>() == 1);
    if (!broadcast_over_one) {
      return inner_iterator;
    }
    return Constant{std::int64_t(0)};
  }
};
// Replaces the symbolic Dim list of an IndexDot with concrete int64 sizes
// when every dim is statically known; the operand list is simplified either
// way.
struct SimplifyDot {
  using source_pattern_type = IndexDotValue<Value, List<Dim>>;

  Value MatchAndRewrite(const Value& value, const IndexExprInferContext& ctx) {
    const auto& [iterators, dims_constants] =
        value.Get<IndexDotValue<Value, Constant>>().tuple();
    List<Constant> int64_dims{};
    for (const auto& dim_constant : *dims_constants.Get<List<Constant>>()) {
      const Constant& int64_dim = ctx.GetDimSize(dim_constant.Get<Dim>());
      if (int64_dim.Has<std::int64_t>()) {
        int64_dims->emplace_back(int64_dim);
      } else {
        // Some dim is still symbolic: keep the original dim list.
        return IndexDotValue<Value, Constant>{SimplifyValue(iterators, ctx),
                                              dims_constants};
      }
    }
    return IndexDotValue<Value, Constant>{SimplifyValue(iterators, ctx),
                                          int64_dims};
  }
};
// Mirror of SimplifyDot for IndexUnDot: replaces the symbolic Dim list with
// concrete int64 sizes when all dims are statically known; the index operand
// is simplified either way.
struct SimplifyUnDot {
  using source_pattern_type = IndexUnDotValue<Value, List<Dim>>;

  Value MatchAndRewrite(const Value& value, const IndexExprInferContext& ctx) {
    const auto& [index, dims_constants] =
        value.Get<IndexUnDotValue<Value, Constant>>().tuple();
    List<Constant> int64_dims{};
    for (const auto& dim_constant : *dims_constants.Get<List<Constant>>()) {
      const Constant& int64_dim = ctx.GetDimSize(dim_constant.Get<Dim>());
      if (int64_dim.Has<std::int64_t>()) {
        int64_dims->emplace_back(int64_dim);
      } else {
        // Some dim is still symbolic: keep the original dim list.
        return IndexUnDotValue<Value, Constant>{SimplifyValue(index, ctx),
                                                dims_constants};
      }
    }
    return IndexUnDotValue<Value, Constant>{SimplifyValue(index, ctx),
                                            int64_dims};
  }
};
// Simplifies every element of a List<Value> recursively.
struct SimplifyList {
  using source_pattern_type = List<Value>;

  Value MatchAndRewrite(const Value& value, const IndexExprInferContext& ctx) {
    List<Value> simplified{};
    const auto& elements = value.Get<List<Value>>();
    for (const auto& element : *elements) {
      simplified->emplace_back(SimplifyValue(element, ctx));
    }
    return simplified;
  }
};
// Cancels IndexDot([ListGetItem(undot, 0), ListGetItem(undot, 1), ...], dims)
// back to undot's inner index when every item unpacks the SAME IndexUnDot
// value, the item indices are exactly 0..n-1 in order, and the dot and undot
// dim lists match.
struct SimplifyDotUndot {
  using source_pattern_type =
      IndexDotValue<List<ListGetItem<IndexUnDotValue<Value, List<std::int64_t>>,
                                     std::int64_t>>,
                    List<std::int64_t>>;

  Value MatchAndRewrite(const Value& value, const IndexExprInferContext& ctx) {
    const auto& [list_get_item_values, dot_dims] =
        value.Get<IndexDotValue<Value, Constant>>().tuple();
    const auto& list_get_items = list_get_item_values.Get<List<Value>>();
    std::optional<Value> pre_index_undot{std::nullopt};
    for (std::size_t i = 0; i < list_get_items->size(); ++i) {
      const auto& [index_undot_value, constant_idx] =
          list_get_items.Get(i).Get<ListGetItem<Value, Constant>>().tuple();
      // Fixed operator-precedence bug: the original condition was
      // `!constant_idx.Get<std::int64_t>() == i`, which parses as
      // `(!x) == i` and wrongly accepted mismatched indices (e.g. x=5, i=1).
      if (constant_idx.Get<std::int64_t>() != static_cast<std::int64_t>(i)) {
        return IndexDotValue<Value, Constant>{
            SimplifyValue(list_get_item_values, ctx), dot_dims};
      }
      if (pre_index_undot.has_value()) {
        // Every item must unpack the same IndexUnDot value.
        if (!(pre_index_undot.value() == index_undot_value)) {
          return IndexDotValue<Value, Constant>{
              SimplifyValue(list_get_item_values, ctx), dot_dims};
        } else {
          // do nothing
        }
      } else {
        pre_index_undot = index_undot_value;
      }
    }
    CHECK(pre_index_undot.has_value());
    const auto& [index_value, undot_dims] =
        pre_index_undot.value().Get<IndexUnDotValue<Value, Constant>>().tuple();
    CHECK(dot_dims.Has<List<Constant>>());
    CHECK(undot_dims.Has<List<Constant>>());
    if (dot_dims == undot_dims) {
      // dot(undot(x)) with identical dims is the identity.
      return index_value;
    }
    return IndexDotValue<Value, Constant>{
        SimplifyValue(list_get_item_values, ctx), dot_dims};
  }
};
// Cancels ListGetItem(IndexUnDot(IndexDot(iters, dims), dims), i) to
// iters[i] when the dot and undot dim lists are identical.
struct SimplifyUndotDot {
  using source_pattern_type = ListGetItem<
      IndexUnDotValue<IndexDotValue<List<Value>, List<std::int64_t>>,
                      List<std::int64_t>>,
      std::int64_t>;

  Value MatchAndRewrite(const Value& value, const IndexExprInferContext& ctx) {
    const auto& [index_undot_value, constant_idx] =
        value.Get<ListGetItem<Value, Constant>>().tuple();
    const auto& [index_value, undot_dims] =
        index_undot_value.Get<IndexUnDotValue<Value, Constant>>().tuple();
    const auto& [index_dot_values, dot_dims] =
        index_value.Get<IndexDotValue<Value, Constant>>().tuple();
    const auto& iter_values = index_dot_values.Get<List<Value>>();
    CHECK(dot_dims.Has<List<Constant>>());
    CHECK(undot_dims.Has<List<Constant>>());
    if (dot_dims == undot_dims) {
      // undot(dot(x)) is the identity, so item i is just the i-th iterator.
      return iter_values.Get(constant_idx.Get<std::int64_t>());
    } else {
      return ListGetItem<Value, Constant>{SimplifyValue(index_undot_value, ctx),
                                          constant_idx};
    }
  }
};
// Pushes simplification into the list operand of a ListGetItem.
struct SimplifyListGetItem {
  using source_pattern_type = ListGetItem<Value, Constant>;

  Value MatchAndRewrite(const Value& value, const IndexExprInferContext& ctx) {
    const auto& [list_operand, index_constant] =
        value.Get<ListGetItem<Value, Constant>>().tuple();
    const auto& simplified_list = SimplifyValue(list_operand, ctx);
    return ListGetItem<Value, Constant>{simplified_list, index_constant};
  }
};
// ListGetItem over a literal list selects the element directly.
struct SimplifyListGetItemList {
  using source_pattern_type = ListGetItem<List<Value>, std::int64_t>;

  Value MatchAndRewrite(const Value& value, const IndexExprInferContext& ctx) {
    const auto& [list_operand, index_constant] =
        value.Get<ListGetItem<Value, Constant>>().tuple();
    const std::int64_t element_index = index_constant.Get<std::int64_t>();
    return list_operand.Get<List<Value>>().Get(element_index);
  }
};
struct SimplifyGcdShape {
using source_pattern_type = ListGetItem<
IndexUnDotValue<IndexDotValue<List<Value>, List<std::int64_t>>,
List<std::int64_t>>,
std::int64_t>;
bool IsConstantListAllPositiveInt64(const List<Constant>& constants) {
for (const auto& constant : *constants) {
if (!constant.Has<std::int64_t>() || constant.Get<std::int64_t>() <= 0) {
return false;
}
}
return true;
}
Value MatchAndRewrite(const Value& value, const IndexExprInferContext& ctx) {
const auto& [index_undot_value, constant_idx] =
value.Get<ListGetItem<Value, Constant>>().tuple();
const auto& [index_value, undot_dims] =
index_undot_value.Get<IndexUnDotValue<Value, Constant>>().tuple();
const auto& [index_dot_values, dot_dims] =
index_value.Get<IndexDotValue<Value, Constant>>().tuple();
const auto& iter_values = index_dot_values.Get<List<Value>>();
CHECK(dot_dims.Has<List<Constant>>());
CHECK(undot_dims.Has<List<Constant>>());
const auto& undot_dim_values = undot_dims.Get<List<Constant>>();
const auto& dot_dim_values = dot_dims.Get<List<Constant>>();
CHECK(IsConstantListAllPositiveInt64(undot_dim_values));
CHECK(IsConstantListAllPositiveInt64(dot_dim_values));
const auto& sub_reshape_dim_ranges =
GetSubReshapeDimRanges(undot_dim_values, dot_dim_values);
if (!sub_reshape_dim_ranges.has_value()) {
return ListGetItem<Value, Constant>{SimplifyValue(index_undot_value, ctx),
constant_idx};
}
const auto& [undot_dim_ranges, dot_dim_ranges] =
sub_reshape_dim_ranges.value();
if (undot_dim_ranges.size() >= 1) {
const auto& [sub_range_idx, sub_range_item_idx] = GetSubRangeItemIdx(
undot_dim_ranges, constant_idx.Get<std::int64_t>());
List<Constant> sub_range_undot_dims = GetSubRangeDotDims(
undot_dim_values, undot_dim_ranges.at(sub_range_idx));
List<Value> sub_range_dot_iterators = GetSubRangeDotIterators(
iter_values, dot_dim_ranges.at(sub_range_idx));
List<Constant> sub_range_dot_dims =
GetSubRangeDotDims(dot_dim_values, dot_dim_ranges.at(sub_range_idx));
if (sub_range_dot_dims == sub_range_undot_dims) {
return sub_range_dot_iterators.Get(sub_range_item_idx);
} else {
IndexDotValue<Value, Constant> sub_range_dot{sub_range_dot_iterators,
sub_range_dot_dims};
if (sub_range_undot_dims->size() == 1) {
CHECK_EQ(sub_range_item_idx, 0);
return sub_range_dot;
} else {
IndexUnDotValue<Value, Constant> sub_range_undot{
sub_range_dot, sub_range_undot_dims};
return ListGetItem<Value, Constant>{sub_range_undot,
sub_range_item_idx};
}
}
}
return ListGetItem<Value, Constant>{SimplifyValue(index_undot_value, ctx),
constant_idx};
}
std::pair<int, int> GetSubRangeItemIdx(
const std::vector<std::pair<int, int>>& ranges,
std::int64_t index) const {
for (std::size_t i = 0; i < ranges.size(); ++i) {
const auto& [begin, end] = ranges.at(i);
if (index >= begin && index < end) {
return std::pair<int, int>{i, index - begin};
}
}
}
List<Value> GetSubRangeDotIterators(const List<Value>& iterators,
const std::pair<int, int>& range) const {
return GetSubRange<List<Value>>(iterators, range);
}
List<Constant> GetSubRangeDotDims(const List<Constant>& dims,
const std::pair<int, int>& range) const {
return GetSubRange<List<Constant>>(dims, range);
}
template <typename ContainerT>
ContainerT GetSubRange(const ContainerT& container,
const std::pair<int, int>& range) const {
CheckRange(container, range);
ContainerT ret{};
ret->assign(std::next(container->begin(), range.first),
std::next(container->begin(), range.second));
return ret;
}
template <typename ContainerT>
void CheckRange(const ContainerT& container,
const std::pair<int, int>& range) const {
CHECK_GE(range.first, 0);
CHECK_GE(range.second, 0);
CHECK_LE(range.first, container->size());
CHECK_LE(range.second, container->size());
CHECK_LT(range.first, range.second);
}
};
struct SimplifyDotDot {
  using source_pattern_type = IndexDotValue<List<Value>, List<std::int64_t>>;

  // Product of all dims; every dim must be an int64 constant.
  std::int64_t Product(const List<Constant>& dims) {
    std::int64_t product = 1;
    for (const auto& dim : *dims) {
      CHECK(dim.Has<std::int64_t>());
      product *= dim.Get<std::int64_t>();
    }
    return product;
  }

  // Flattens a nested IndexDot child into its parent dot when the child's
  // dim product equals the slot's dim; other children are kept unchanged.
  Value MatchAndRewrite(const Value& value, const IndexExprInferContext& ctx) {
    const auto& [dot_values, dims] =
        value.Get<IndexDotValue<Value, Constant>>().tuple();
    const auto& value_list = dot_values.Get<List<Value>>();
    const auto& dim_list = dims.Get<List<Constant>>();
    CHECK_EQ(value_list->size(), dim_list->size());
    List<Value> flattened_values{};
    List<Constant> flattened_dims{};
    for (std::size_t i = 0; i < value_list->size(); ++i) {
      const auto& child = value_list->at(i);
      const auto& child_dim = dim_list->at(i).Get<std::int64_t>();
      bool inlined = false;
      if (Match<source_pattern_type>(child)) {
        const auto& [sub_values, sub_dims] =
            child.Get<IndexDotValue<Value, Constant>>().tuple();
        const auto& sub_dim_list = sub_dims.Get<List<Constant>>();
        if (Product(sub_dim_list) == child_dim) {
          // Shapes agree: splice the nested dot's iterators and dims in.
          const auto& sub_value_list = sub_values.Get<List<Value>>();
          for (std::size_t j = 0; j < sub_value_list->size(); ++j) {
            flattened_values->emplace_back(sub_value_list->at(j));
            flattened_dims->emplace_back(sub_dim_list->at(j));
          }
          inlined = true;
        }
      }
      if (!inlined) {
        flattened_values->emplace_back(child);
        flattened_dims->emplace_back(child_dim);
      }
    }
    return IndexDotValue<Value, Constant>{flattened_values, flattened_dims};
  }
};
// Applies one round of rewrite rules to the top layer of `value`; children
// are only touched where a rule explicitly recurses via SimplifyValue.
// NOTE(review): the passes run in a fixed sequence — presumably later rules
// rely on normalization done by earlier ones; confirm before reordering.
Value SimplifyValue(Value value, const IndexExprInferContext& ctx) {
  value = MatchAndRewrite<SimplifyBroadcastedIteratorByReplacingToStaticDim>(
      value, ctx);
  value = MatchAndRewrite<SimplifyBroadcastedIterator>(value, ctx);
  value = MatchAndRewrite<SimplifyDot>(value, ctx);
  value = MatchAndRewrite<SimplifyUnDot>(value, ctx);
  value = MatchAndRewrite<SimplifyList>(value, ctx);
  value = MatchAndRewrite<SimplifyListGetItem>(value, ctx);
  value = MatchAndRewrite<SimplifyDotUndot>(value, ctx);
  value = MatchAndRewrite<SimplifyUndotDot>(value, ctx);
  value = MatchAndRewrite<SimplifyListGetItemList>(value, ctx);
  value = MatchAndRewrite<SimplifyGcdShape>(value, ctx);
  value = MatchAndRewrite<SimplifyDotDot>(value, ctx);
  return value;
}
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/adt/equation_value.h"
#include "paddle/cinn/adt/index_expr_infer_context.h"
namespace cinn::adt {
Value SimplifyValue(Value value, const IndexExprInferContext& ctx);
}
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/adt/adt.h"
namespace cinn::adt {
// Phantom tag types: each DEFINE_ADT_TAG introduces a wrapper template that
// gives a plain value a distinct, self-documenting type (e.g. tIn<T> vs
// tOut<T>), so mismatched uses fail at compile time.
DEFINE_ADT_TAG(tIn);
DEFINE_ADT_TAG(tOut);
DEFINE_ADT_TAG(tSSAShadow);
DEFINE_ADT_TAG(tAnchor);
DEFINE_ADT_TAG(tIterator);
DEFINE_ADT_TAG(tIndex);
DEFINE_ADT_TAG(tDim);
DEFINE_ADT_TAG(tOpPlaceHolder);
DEFINE_ADT_TAG(tInMsg);
DEFINE_ADT_TAG(tOutMsg);
DEFINE_ADT_TAG(tBreak);
DEFINE_ADT_TAG(tHasNoConflictValue);
DEFINE_ADT_TAG(tReduceInit);
DEFINE_ADT_TAG(tReduceAcc);
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <optional>
#include "paddle/cinn/adt/adt.h"
#include "paddle/cinn/adt/tags.h"
namespace cinn::adt {
// Tree InnerT LeafT = LeafT | InnerT (Tree InnerT LeafT)
template <template <typename> class InnerT, typename LeafT>
DEFINE_ADT_UNION(Tree, LeafT, InnerT<Tree<InnerT, LeafT>>);
// TreeInnerNode T TreeT = (T, [TreeT])
// TreeInner<T>::Node<TreeT> is an inner tree node carrying a payload of type
// T plus a list of child subtrees; the outer struct exists only to curry T so
// that Node can be passed as the single-parameter InnerT of Tree<>.
template <typename T>
struct TreeInner {
  template <typename TreeT>
  struct Node final : public Tuple<T, List<TreeT>> {
    // Payload type, used by TreeMerger as inner_data_type.
    using value_type = T;
    using Tuple<T, List<TreeT>>::Tuple;
  };
};
// Trait extracting the inner-node and leaf types from a Tree instantiation.
template <typename T>
struct TreeTrait;
template <template <typename> class InnerT, typename LeafT>
struct TreeTrait<Tree<InnerT, LeafT>> {
  using inner_type = InnerT<Tree<InnerT, LeafT>>;
  using leaf_type = LeafT;
};
DEFINE_ADT_TAG(tCommon);
DEFINE_ADT_TAG(tLhsRemainder);
DEFINE_ADT_TAG(tRhsRemainder);
// Customization point for tree merging: users specialize TreeMerger for their
// TreeT and implement the three member functions declared (not defined) here.
// See the TreeMerger<test::IntVecTree> specialization in the tests for an
// example implementation.
template <typename TreeT>
struct TreeMerger {
  using tree_type = TreeT;
  using inner_type = typename TreeTrait<TreeT>::inner_type;
  using leaf_type = typename TreeTrait<TreeT>::leaf_type;
  using inner_data_type = typename inner_type::value_type;
  // Payload an inner node should carry when wrapping `leaf`.
  inner_data_type GetInnerDataForLeaf(const leaf_type& leaf) const;
  // Constructs an inner node from a payload and its children.
  inner_type MakeInnerNode(const inner_data_type& inner_data,
                           const List<TreeT>& children) const;
  // (common part, lhs-only remainder, rhs-only remainder) of two payloads.
  using MergeResult = std::tuple<tCommon<inner_data_type>,
                                 tLhsRemainder<inner_data_type>,
                                 tRhsRemainder<inner_data_type>>;
  MergeResult MergeInnerValue(const inner_data_type& lhs,
                              const inner_data_type& rhs) const;
};
template <typename TreeMergerT>
List<typename TreeMergerT::tree_type> MergeTwoInnerTree(
const TreeMergerT& tree_merger,
const typename TreeMergerT::tree_type& lhs,
const typename TreeMergerT::tree_type& rhs);
// Merges two inner nodes according to how their payloads overlap, returning
// either one merged tree or the two originals side by side.
template <typename TreeMergerT>
List<typename TreeMergerT::tree_type> MergeTwoInnerTreeImpl(
    const TreeMergerT& tree_merger,
    const typename TreeTrait<typename TreeMergerT::tree_type>::inner_type& lhs,
    const typename TreeTrait<typename TreeMergerT::tree_type>::inner_type&
        rhs) {
  using TreeT = typename TreeMergerT::tree_type;
  using leaf_type = typename TreeTrait<TreeT>::leaf_type;
  using inner_type = typename TreeTrait<TreeT>::inner_type;
  using inner_data_type = typename inner_type::value_type;
  const auto& [lhs_inner_data, lhs_children] = lhs.tuple();
  const auto& [rhs_inner_data, rhs_children] = rhs.tuple();
  const auto& [common, lhs_remainder, rhs_remainder] =
      tree_merger.MergeInnerValue(lhs_inner_data, rhs_inner_data);
  // The common part is empty iff each remainder equals its full payload.
  bool is_common_empty = (lhs_remainder.value() == lhs_inner_data &&
                          rhs_remainder.value() == rhs_inner_data);
  if (is_common_empty) {
    // Nothing shared: keep both trees as siblings.
    return List<TreeT>{lhs, rhs};
  } else if (common.value() == lhs_inner_data &&
             common.value() == rhs_inner_data) {
    // Identical payloads: concatenate children under one node.
    List<TreeT> merged_children{};
    merged_children->insert(
        merged_children->end(), lhs_children->begin(), lhs_children->end());
    merged_children->insert(
        merged_children->end(), rhs_children->begin(), rhs_children->end());
    const auto ret = tree_merger.MakeInnerNode(common.value(), merged_children);
    return List<TreeT>{ret};
  } else if (common.value() == lhs_inner_data &&
             common.value() != rhs_inner_data) {
    // lhs is a prefix of rhs: push rhs' remainder down into lhs' last child.
    const auto new_rhs =
        tree_merger.MakeInnerNode(rhs_remainder.value(), rhs_children);
    const TreeT last_lhs_child = lhs_children->back();
    const auto merged_last_children =
        MergeTwoInnerTree(tree_merger, last_lhs_child, new_rhs);
    List<TreeT> new_lhs_children{};
    new_lhs_children->insert(new_lhs_children->end(),
                             lhs_children->begin(),
                             std::prev(lhs_children->end()));
    new_lhs_children->insert(new_lhs_children->end(),
                             merged_last_children->begin(),
                             merged_last_children->end());
    const auto ret =
        tree_merger.MakeInnerNode(common.value(), new_lhs_children);
    return List<TreeT>{ret};
  } else if (common.value() != lhs_inner_data &&
             common.value() == rhs_inner_data) {
    // rhs is a prefix of lhs: push lhs' remainder down into rhs' first child.
    const auto new_lhs =
        tree_merger.MakeInnerNode(lhs_remainder.value(), lhs_children);
    const TreeT first_rhs_child = *rhs_children->begin();
    const auto merged_first_children =
        MergeTwoInnerTree(tree_merger, new_lhs, first_rhs_child);
    List<TreeT> new_rhs_children = merged_first_children;
    new_rhs_children->insert(new_rhs_children->end(),
                             std::next(rhs_children->begin()),
                             rhs_children->end());
    const auto ret =
        tree_merger.MakeInnerNode(common.value(), new_rhs_children);
    return List<TreeT>{ret};
  } else if (common.value() != lhs_inner_data &&
             common.value() != rhs_inner_data) {
    // Partial overlap: both remainders become children of a new common node.
    const auto new_lhs =
        tree_merger.MakeInnerNode(lhs_remainder.value(), lhs_children);
    const auto new_rhs =
        tree_merger.MakeInnerNode(rhs_remainder.value(), rhs_children);
    const auto ret = tree_merger.MakeInnerNode(common.value(),
                                               List<TreeT>{new_lhs, new_rhs});
    return List<TreeT>{ret};
  } else {
    LOG(FATAL) << "Dead code";
  }
}
// Dispatch wrapper: only two inner nodes can actually be merged; any pairing
// that involves a leaf is returned unmerged, side by side.
template <typename TreeMergerT>
List<typename TreeMergerT::tree_type> MergeTwoInnerTree(
    const TreeMergerT& tree_merger,
    const typename TreeMergerT::tree_type& lhs,
    const typename TreeMergerT::tree_type& rhs) {
  using TreeT = typename TreeMergerT::tree_type;
  using inner_type = typename TreeTrait<TreeT>::inner_type;
  return std::visit(
      [&](const auto& lhs, const auto& rhs) -> List<TreeT> {
        if constexpr (std::is_same_v<std::decay_t<decltype(lhs)>, inner_type> &&
                      std::is_same_v<std::decay_t<decltype(rhs)>, inner_type>) {
          return MergeTwoInnerTreeImpl(tree_merger, lhs, rhs);
        } else {
          // At least one side is a leaf: nothing to merge.
          return List<TreeT>{lhs, rhs};
        }
      },
      lhs.variant(),
      rhs.variant());
}
// Folds `leaves` into the accumulator forest `acc`: each leaf is wrapped in
// an inner node and merged against the last tree of the accumulator, so
// consecutive leaves with overlapping inner data share a path.
template <typename TreeMergerT>
void MergeTrees(
    const TreeMergerT& tree_merger,
    List<typename TreeMergerT::tree_type>* acc,
    const List<typename TreeTrait<typename TreeMergerT::tree_type>::leaf_type>&
        leaves) {
  using TreeT = typename TreeMergerT::tree_type;
  if (leaves->empty()) {
    return;
  }
  using leaf_type = typename TreeTrait<TreeT>::leaf_type;
  using inner_type = typename TreeTrait<TreeT>::inner_type;
  using inner_data_type = typename inner_type::value_type;
  // Wraps a single leaf into an inner node so it can be merged.
  const auto& MakeTreeFromLeaf = [&](const leaf_type& leaf) -> TreeT {
    const inner_data_type inner_data = tree_merger.GetInnerDataForLeaf(leaf);
    const auto ret =
        tree_merger.MakeInnerNode(inner_data, List<TreeT>{TreeT{leaf}});
    return ret;
  };
  // Handle init
  std::size_t leaf_start = 0;
  if ((*acc)->empty()) {
    (*acc)->emplace_back(MakeTreeFromLeaf(leaves->at(0)));
    leaf_start = 1;
  }
  for (std::size_t i = leaf_start; i < leaves->size(); ++i) {
    // Replace the accumulator's last tree with the merge result (which may
    // be one merged tree or the two originals).
    const auto merged = MergeTwoInnerTree(
        tree_merger, (*acc)->back(), MakeTreeFromLeaf(leaves->at(i)));
    (*acc)->erase(std::prev((*acc)->end()));
    (*acc)->insert((*acc)->end(), merged->begin(), merged->end());
  }
}
// Convenience wrapper over MergeTrees: merges all leaves into a fresh,
// initially empty forest and returns it.
template <typename TreeMergerT>
List<typename TreeMergerT::tree_type> MakeMergedTrees(
    const TreeMergerT& tree_merger,
    const List<typename TreeTrait<typename TreeMergerT::tree_type>::leaf_type>&
        leaves) {
  List<typename TreeMergerT::tree_type> merged{};
  MergeTrees(tree_merger, &merged, leaves);
  return merged;
}
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/adt/tree.h"
#include "gtest/gtest.h"
namespace cinn::adt {
namespace test {
using IntTreeLeafT = std::vector<int>;
using IntTreeInnerDataT = std::vector<int>;
using IntVecTree = Tree<TreeInner<IntTreeInnerDataT>::Node, IntTreeLeafT>;
using IntTreeInnerT = TreeInner<IntTreeInnerDataT>::template Node<IntVecTree>;
} // namespace test
// TreeMerger specialization for the test tree: payloads are vector<int> and
// the common part of two payloads is their longest common prefix.
template <>
struct TreeMerger<test::IntVecTree> {
  using tree_type = test::IntVecTree;
  using inner_type = typename TreeTrait<test::IntVecTree>::inner_type;
  using leaf_type = typename TreeTrait<test::IntVecTree>::leaf_type;
  using inner_data_type = typename inner_type::value_type;

  // A leaf's inner data is the leaf itself.
  inner_data_type GetInnerDataForLeaf(const leaf_type& leaf) const {
    return leaf;
  }

  inner_type MakeInnerNode(const inner_data_type& inner_data,
                           const List<test::IntVecTree>& children) const {
    return inner_type{inner_data, children};
  }

  using MergeResult = std::tuple<tCommon<inner_data_type>,
                                 tLhsRemainder<inner_data_type>,
                                 tRhsRemainder<inner_data_type>>;

  // Splits lhs/rhs into (longest common prefix, lhs rest, rhs rest).
  // Uses std::size_t indices throughout — the original narrowed
  // std::min(size_t, size_t) into an int and compared int against size()
  // (signed/unsigned mismatch).
  MergeResult MergeInnerValue(const inner_data_type& lhs,
                              const inner_data_type& rhs) const {
    inner_data_type common{};
    inner_data_type lhs_remainder{};
    inner_data_type rhs_remainder{};
    const std::size_t min_size = std::min(lhs.size(), rhs.size());
    std::size_t idx = 0;
    for (; idx < min_size; ++idx) {
      if (lhs.at(idx) == rhs.at(idx)) {
        common.emplace_back(lhs.at(idx));
      } else {
        break;
      }
    }
    for (std::size_t lhs_idx = idx; lhs_idx < lhs.size(); ++lhs_idx) {
      lhs_remainder.emplace_back(lhs.at(lhs_idx));
    }
    for (std::size_t rhs_idx = idx; rhs_idx < rhs.size(); ++rhs_idx) {
      rhs_remainder.emplace_back(rhs.at(rhs_idx));
    }
    return MergeResult{common, lhs_remainder, rhs_remainder};
  }
};
namespace test {
// Two leaves with no common prefix stay as two separate single-child trees.
TEST(IntVecTree, naive) {
  List<IntTreeLeafT> leaves{IntTreeLeafT{1, 2, 3}, IntTreeLeafT{4, 5, 6}};
  TreeMerger<test::IntVecTree> tree_merger{};
  List<IntVecTree> ret = MakeMergedTrees(tree_merger, leaves);
  ASSERT_EQ(ret->size(), 2);
  ASSERT_TRUE(ret->at(0).Has<IntTreeInnerT>());
  const auto& [inner_data0, children0] =
      ret->at(0).Get<IntTreeInnerT>().tuple();
  ASSERT_TRUE((inner_data0 == IntTreeLeafT{1, 2, 3}));
  ASSERT_TRUE((children0->size() == 1));
  ASSERT_TRUE((children0->at(0).Has<IntTreeLeafT>()));
  ASSERT_TRUE((children0->at(0).Get<IntTreeLeafT>() == IntTreeLeafT{1, 2, 3}));
  ASSERT_TRUE(ret->at(1).Has<IntTreeInnerT>());
  const auto& [inner_data1, children1] =
      ret->at(1).Get<IntTreeInnerT>().tuple();
  ASSERT_TRUE((inner_data1 == IntTreeLeafT{4, 5, 6}));
  ASSERT_TRUE((children1->size() == 1));
  ASSERT_TRUE((children1->at(0).Has<IntTreeLeafT>()));
  ASSERT_TRUE((children1->at(0).Get<IntTreeLeafT>() == IntTreeLeafT{4, 5, 6}));
}
// Identical leaves merge into one inner node that owns both leaves.
TEST(IntVecTree, left_equal_right) {
  List<IntTreeLeafT> leaves{IntTreeLeafT{1, 2, 3}, IntTreeLeafT{1, 2, 3}};
  List<IntVecTree> ret =
      MakeMergedTrees(TreeMerger<test::IntVecTree>{}, leaves);
  ASSERT_EQ(ret->size(), 1);
  ASSERT_TRUE(ret->at(0).Has<IntTreeInnerT>());
  const auto& [inner_data0, children0] =
      ret->at(0).Get<IntTreeInnerT>().tuple();
  ASSERT_TRUE((inner_data0 == IntTreeLeafT{1, 2, 3}));
  ASSERT_TRUE((children0->size() == 2));
  ASSERT_TRUE((children0->at(0).Has<IntTreeLeafT>()));
  ASSERT_TRUE((children0->at(0).Get<IntTreeLeafT>() == IntTreeLeafT{1, 2, 3}));
  ASSERT_TRUE((children0->at(1).Has<IntTreeLeafT>()));
  ASSERT_TRUE((children0->at(1).Get<IntTreeLeafT>() == IntTreeLeafT{1, 2, 3}));
}
// lhs carries an extra suffix {4, 5}: it becomes a nested inner node under
// the shared prefix, as the FIRST child (the lhs side).
TEST(IntVecTree, left_gt_right) {
  List<IntTreeLeafT> leaves{IntTreeLeafT{1, 2, 3, 4, 5}, IntTreeLeafT{1, 2, 3}};
  List<IntVecTree> ret =
      MakeMergedTrees(TreeMerger<test::IntVecTree>{}, leaves);
  ASSERT_EQ(ret->size(), 1);
  ASSERT_TRUE(ret->at(0).Has<IntTreeInnerT>());
  const auto& [inner_data0, children0] =
      ret->at(0).Get<IntTreeInnerT>().tuple();
  ASSERT_TRUE((inner_data0 == IntTreeLeafT{1, 2, 3}));
  ASSERT_TRUE((children0->size() == 2));
  ASSERT_TRUE((children0->at(0).Has<IntTreeInnerT>()));
  const auto& [inner_data_left0, children_left0] =
      children0->at(0).Get<IntTreeInnerT>().tuple();
  ASSERT_TRUE((inner_data_left0 == IntTreeLeafT{4, 5}));
  ASSERT_TRUE((children_left0->size() == 1));
  ASSERT_TRUE((children_left0->at(0).Has<IntTreeLeafT>()));
  ASSERT_TRUE((children_left0->at(0).Get<IntTreeLeafT>() ==
               IntTreeLeafT{1, 2, 3, 4, 5}));
  ASSERT_TRUE((children0->at(1).Has<IntTreeLeafT>()));
  ASSERT_TRUE((children0->at(1).Get<IntTreeLeafT>() == IntTreeLeafT{1, 2, 3}));
}
// rhs carries the extra suffix {4, 5}: it nests under the shared prefix as
// the SECOND child (the rhs side).
TEST(IntVecTree, left_lt_right) {
  List<IntTreeLeafT> leaves{IntTreeLeafT{1, 2, 3}, IntTreeLeafT{1, 2, 3, 4, 5}};
  List<IntVecTree> ret =
      MakeMergedTrees(TreeMerger<test::IntVecTree>{}, leaves);
  ASSERT_EQ(ret->size(), 1);
  ASSERT_TRUE(ret->at(0).Has<IntTreeInnerT>());
  const auto& [inner_data0, children0] =
      ret->at(0).Get<IntTreeInnerT>().tuple();
  ASSERT_TRUE((inner_data0 == IntTreeLeafT{1, 2, 3}));
  ASSERT_TRUE((children0->size() == 2));
  ASSERT_TRUE((children0->at(0).Has<IntTreeLeafT>()));
  ASSERT_TRUE((children0->at(0).Get<IntTreeLeafT>() == IntTreeLeafT{1, 2, 3}));
  ASSERT_TRUE((children0->at(1).Has<IntTreeInnerT>()));
  const auto& [inner_data_right0, children_right0] =
      children0->at(1).Get<IntTreeInnerT>().tuple();
  ASSERT_TRUE((inner_data_right0 == IntTreeLeafT{4, 5}));
  ASSERT_TRUE((children_right0->size() == 1));
  ASSERT_TRUE((children_right0->at(0).Has<IntTreeLeafT>()));
  ASSERT_TRUE((children_right0->at(0).Get<IntTreeLeafT>() ==
               IntTreeLeafT{1, 2, 3, 4, 5}));
}
// Diverging suffixes {4, 5} vs {6, 7} become two sibling inner nodes under
// the common prefix {1, 2, 3}.
TEST(IntVecTree, left_ne_right) {
  List<IntTreeLeafT> leaves{IntTreeLeafT{1, 2, 3, 4, 5},
                            IntTreeLeafT{1, 2, 3, 6, 7}};
  List<IntVecTree> ret =
      MakeMergedTrees(TreeMerger<test::IntVecTree>{}, leaves);
  ASSERT_EQ(ret->size(), 1);
  ASSERT_TRUE(ret->at(0).Has<IntTreeInnerT>());
  const auto& [inner_data0, children0] =
      ret->at(0).Get<IntTreeInnerT>().tuple();
  ASSERT_TRUE((inner_data0 == IntTreeLeafT{1, 2, 3}));
  ASSERT_TRUE((children0->size() == 2));
  ASSERT_TRUE((children0->at(0).Has<IntTreeInnerT>()));
  const auto& [inner_data_left0, children_left0] =
      children0->at(0).Get<IntTreeInnerT>().tuple();
  ASSERT_TRUE((inner_data_left0 == IntTreeLeafT{4, 5}));
  ASSERT_TRUE((children_left0->size() == 1));
  ASSERT_TRUE((children_left0->at(0).Has<IntTreeLeafT>()));
  ASSERT_TRUE((children_left0->at(0).Get<IntTreeLeafT>() ==
               IntTreeLeafT{1, 2, 3, 4, 5}));
  ASSERT_TRUE((children0->at(1).Has<IntTreeInnerT>()));
  const auto& [inner_data_right0, children_right0] =
      children0->at(1).Get<IntTreeInnerT>().tuple();
  ASSERT_TRUE((inner_data_right0 == IntTreeLeafT{6, 7}));
  ASSERT_TRUE((children_right0->size() == 1));
  ASSERT_TRUE((children_right0->at(0).Has<IntTreeLeafT>()));
  ASSERT_TRUE((children_right0->at(0).Get<IntTreeLeafT>() ==
               IntTreeLeafT{1, 2, 3, 6, 7}));
}
} // namespace test
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <atomic>
#include <cstddef>
#include <functional>
namespace cinn::adt {
// Value type holding a process-unique, monotonically increasing id drawn
// from a global atomic sequence.
class UniqueId final {
 public:
  // A fresh id is drawn from the sequence on default construction.
  UniqueId() : unique_id_(NewSeqNumber()) {}
  UniqueId(const UniqueId&) = default;
  UniqueId(UniqueId&&) = default;
  UniqueId& operator=(const UniqueId&) = default;
  UniqueId& operator=(UniqueId&&) = default;

  // Factory; equivalent to default construction.
  static UniqueId New() { return UniqueId{NewSeqNumber()}; }

  bool operator==(const UniqueId& rhs) const {
    return unique_id_ == rhs.unique_id_;
  }
  bool operator!=(const UniqueId& rhs) const { return !(*this == rhs); }
  bool operator<(const UniqueId& rhs) const {
    return unique_id_ < rhs.unique_id_;
  }

  std::size_t unique_id() const { return unique_id_; }

 private:
  // The counter is a function-local atomic: the first handed-out id is 1.
  static std::size_t NewSeqNumber() {
    static std::atomic<std::size_t> seq_number{0};
    return ++seq_number;
  }
  explicit UniqueId(std::size_t unique_id) : unique_id_(unique_id) {}

  std::size_t unique_id_;
};
} // namespace cinn::adt
namespace std {
// Hash support so UniqueId can key std::unordered_map/set; the id itself is
// already a well-distributed std::size_t, so it is used directly.
template <>
struct hash<cinn::adt::UniqueId> final {
  std::size_t operator()(const cinn::adt::UniqueId& unique_id) const {
    return unique_id.unique_id();
  }
};
} // namespace std
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/adt/write_broadcast_disabled_bidirection_equation_generator.h"
#include "paddle/cinn/adt/equation_graph.h"
#include "paddle/cinn/adt/equation_solver.h"
#include "paddle/cinn/adt/naive_equation_function_constants_provider.h"
namespace cinn::adt {
namespace {
using EquationCtx4OpStmtT =
std::function<std::shared_ptr<config::NaiveOpEquationContext>(
const OpStmt&)>;
template <
    typename DoEachT /*: void(&)(std::size_t, OpStmt, OpEquationContext)*/>
void VisitEachOpStmtAndEquationCtx(
    const List<OpStmt>& op_stmts,
    const EquationCtx4OpStmtT& EquationCtx4OpStmt,
    const DoEachT& DoEach) {
  // Walk the op statements in order, pairing each with its equation context.
  for (std::size_t idx = 0; idx < op_stmts->size(); ++idx) {
    const auto& op_stmt = op_stmts->at(idx);
    const auto& equation_ctx = EquationCtx4OpStmt(op_stmt);
    DoEach(idx, op_stmt, equation_ctx);
  }
}
using InMsg2OutMsgT = InMsg2OutMsg<tOut<FakeOpPlaceHolder>,
tOut<OpArgIndexes<std::optional<Index>>>,
tIn<OpArgIndexes<Index>>>;
// Builds the initial variable->value map for solving: only the anchor index
// is known, seeded with the trivial value Ok.
std::unordered_map<Variable, const Value> MakeAnchorIndex2Ok(
    const Index& anchor_index) {
  return {{anchor_index, Ok{}}};
}
bool LocalEquationsSolvable(
const GraphView& graph_view,
const Index& anchor_index,
const FakeOpPlaceHolder& fake_op_placeholder,
const std::shared_ptr<const EquationFunctionConstantsProvider>&
constants_provider) {
const auto& init_var2value = MakeAnchorIndex2Ok(anchor_index);
IndexExprInferContext ctx{init_var2value, constants_provider};
bool has_no_conflict_value =
TrySolveEquations(graph_view, anchor_index, &ctx).value();
return has_no_conflict_value && ctx.HasValue(fake_op_placeholder);
}
// Copies each out-msg output index through, except those whose in-msg
// counterpart appears in the erased list — those become std::nullopt.
List<std::optional<Index>> GetMaskedOutIndexes(
    const List<Index>& in_msg_out_indexes,
    const List<std::optional<Index>>& out_msg_out_indexes,
    const std::vector<Index>& erased_in_msg_out_tensor_indexes) {
  CHECK_EQ(in_msg_out_indexes->size(), out_msg_out_indexes->size());
  const auto& IsErased = [&](const Index& index) {
    return std::find(erased_in_msg_out_tensor_indexes.begin(),
                     erased_in_msg_out_tensor_indexes.end(),
                     index) != erased_in_msg_out_tensor_indexes.end();
  };
  List<std::optional<Index>> masked{};
  for (std::size_t i = 0; i < in_msg_out_indexes->size(); ++i) {
    if (IsErased(in_msg_out_indexes->at(i))) {
      masked->emplace_back(std::nullopt);
    } else {
      masked->emplace_back(out_msg_out_indexes->at(i));
    }
  }
  return masked;
}
// Rebuilds an InMsg2OutMsg equation with the out-msg output indexes masked
// (set to nullopt) for every erased in-msg tensor index; the in-msg side is
// kept unchanged.
Equation EraseIndexes(
    const Equation& equation,
    const std::vector<Index>& erased_in_msg_out_tensor_indexes) {
  const auto& in_msg2out_msg = equation.Get<InMsg2OutMsgT>();
  const auto& [op_placeholder, out_msg_indexes, in_msg_indexes] =
      in_msg2out_msg.tuple();
  // `_` is the in-msg input-index list, which is not needed here.
  const auto& [_, in_msg_out_indexes] = in_msg_indexes.value().tuple();
  const auto& [out_msg_in_indexes, out_msg_out_indexes] =
      out_msg_indexes.value().tuple();
  const auto& masked_out_indexes =
      GetMaskedOutIndexes(in_msg_out_indexes.value(),
                          out_msg_out_indexes.value(),
                          erased_in_msg_out_tensor_indexes);
  OpArgIndexes<std::optional<Index>> new_out_msg_indexes{out_msg_in_indexes,
                                                         masked_out_indexes};
  Equation ret_equation =
      InMsg2OutMsgT{op_placeholder, new_out_msg_indexes, in_msg_indexes};
  return ret_equation;
}
// Collects the output tensor indexes of `ctx`'s op whose local equations
// cannot be solved from the output alone (i.e. write-broadcast outputs).
std::vector<Index> GenerateWriteBroadcastTensorIndexs(
    const std::shared_ptr<config::NaiveOpEquationContext>& ctx,
    const Equations& in_msg2out_msg_equations,
    const std::shared_ptr<const EquationFunctionConstantsProvider>&
        constants_provider) {
  // Join the op-local equation graph with the in-msg/out-msg equations.
  const auto& equation_graph_view =
      Graph::New(ctx->equations())->GetGraphView();
  GraphView merged_graph_view = equation_graph_view.Merge(
      Graph::New(in_msg2out_msg_equations)->GetGraphView());
  const auto& fake_op_placeholder = ctx->fake_op_placeholder();
  std::vector<Index> unsolvable_out_indexes{};
  ctx->VisitEachOutputTensorIndex([&](const auto& out_index) {
    if (!LocalEquationsSolvable(merged_graph_view,
                                out_index,
                                fake_op_placeholder,
                                constants_provider)) {
      unsolvable_out_indexes.emplace_back(out_index);
    }
  });
  return unsolvable_out_indexes;
}
} // namespace
// Produces the direction equations of the naive generator, but with the
// out-msg indexes of write-broadcast output tensors masked out (erased), so
// downstream solving never writes through a broadcast.
Equations
WriteBroadcastDisabledBidirectionEquationGenerator::GetDirectionEquations()
    const {
  std::shared_ptr<const EquationFunctionConstantsProvider> constants_provider{
      new NaiveEquationFunctionConstantsProvider{
          naive_bidirection_equation_generator_.op_stmts(),
          naive_bidirection_equation_generator_.EquationCtx4OpStmt()}};
  Equations ret{};
  VisitEachOpStmtAndEquationCtx(
      naive_bidirection_equation_generator_.op_stmts(),
      naive_bidirection_equation_generator_.EquationCtx4OpStmt(),
      [&](std::size_t idx,
          const OpStmt& op_stmt,
          const std::shared_ptr<config::NaiveOpEquationContext>& ctx) {
        const auto& in_msg2out_msg_equations =
            naive_bidirection_equation_generator_.equations();
        // Output indexes whose local equations are unsolvable get erased
        // from this op's equation.
        const auto& truncated_output_tensor_idxes =
            GenerateWriteBroadcastTensorIndexs(
                ctx, in_msg2out_msg_equations, constants_provider);
        ret->emplace_back(EraseIndexes(in_msg2out_msg_equations->at(idx),
                                       truncated_output_tensor_idxes));
      });
  return ret;
}
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <unordered_map>
#include "paddle/cinn/adt/naive_bidirection_equation_generator.h"
namespace cinn::adt {
// Decorator over NaiveBidirectionEquationGenerator that disables writing
// through broadcast outputs: GetDirectionEquations() (defined in the .cc)
// erases the unsolvable output indexes; the other queries delegate directly.
class WriteBroadcastDisabledBidirectionEquationGenerator final
    : public DirectionEquationGenerator {
 public:
  using EquationCtx4OpStmtT =
      std::function<std::shared_ptr<config::NaiveOpEquationContext>(
          const OpStmt&)>;
  WriteBroadcastDisabledBidirectionEquationGenerator(
      const WriteBroadcastDisabledBidirectionEquationGenerator&) = delete;
  WriteBroadcastDisabledBidirectionEquationGenerator(
      WriteBroadcastDisabledBidirectionEquationGenerator&&) = delete;
  WriteBroadcastDisabledBidirectionEquationGenerator(
      const List<OpStmt>& op_stmts,
      const EquationCtx4OpStmtT& EquationCtx4OpStmt)
      : naive_bidirection_equation_generator_(op_stmts, EquationCtx4OpStmt) {}
  Equations GetDirectionEquations() const override;
  // Delegates to the wrapped naive generator.
  std::function<const OpStmt*(const FakeOpPlaceHolder&)>
  MakeGetterOpStmt4OpPlaceHolder() const override {
    return naive_bidirection_equation_generator_
        .MakeGetterOpStmt4OpPlaceHolder();
  }
  // Delegates to the wrapped naive generator.
  std::optional<Index> OutMsgIndex4InMsgIndex(
      const Index& index) const override {
    return naive_bidirection_equation_generator_.OutMsgIndex4InMsgIndex(index);
  }

 private:
  NaiveBidirectionEquationGenerator naive_bidirection_equation_generator_;
};
} // namespace cinn::adt
core_gather_headers()
gather_srcs(cinnapi_src SRCS ast_gen.cc tensor_group.cc)
cinn_cc_test(test_ast_gen_ius SRCS ast_gen_test.cc DEPS cinncore)
cinn_cc_test(test_tensor_group SRCS tensor_group_test.cc DEPS cinncore)
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/ast_gen_ius/ast_gen.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/operation.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/optim/replace_var_with_expr.h"
namespace cinn {
namespace ast_gen_ius {
// Lowers a reduce expression into a Store that accumulates into `tensor` at
// `axis_exprs` using the reduce operator (sum/mul/max/min/all/any). A body
// that is not an ir::Reduce is stored as-is.
ir::Expr ConvertReduceBody(ir::Expr body,
                           ir::Tensor tensor,
                           const std::vector<Expr>& axis_exprs) {
  ir::Reduce* reduce_node = body.As<ir::Reduce>();
  if (!reduce_node) {
    // Plain (non-reduce) compute: a direct store.
    return ir::Store::Make(tensor, body, axis_exprs);
  }
  // Each case reads the current accumulator value `tensor(axis_exprs)` and
  // combines it with the reduce body.
  switch (reduce_node->reduce_type) {
    case ir::Reduce::kSum:
      return ir::Store::Make(
          tensor, tensor(axis_exprs) + reduce_node->body, axis_exprs);
    case ir::Reduce::kMul:
      return ir::Store::Make(
          tensor, tensor(axis_exprs) * reduce_node->body, axis_exprs);
    case ir::Reduce::kMax:
      return ir::Store::Make(
          tensor,
          ir::Max::Make(tensor(axis_exprs), reduce_node->body),
          axis_exprs);
    case ir::Reduce::kMin:
      return ir::Store::Make(
          tensor,
          ir::Min::Make(tensor(axis_exprs), reduce_node->body),
          axis_exprs);
    case ir::Reduce::kAll:
      return ir::Store::Make(
          tensor, tensor(axis_exprs) && reduce_node->body, axis_exprs);
    case ir::Reduce::kAny:
      return ir::Store::Make(
          tensor, tensor(axis_exprs) || reduce_node->body, axis_exprs);
    default:
      CINN_NOT_IMPLEMENTED
  }
}
// Builds the loop-nest + schedule-block AST (ir::Expr) that computes `tensor`.
// For a reduce tensor the result is a block of two parts: an init loop nest
// storing the reduce init value, followed by the reduction loop nest (spatial
// loops wrapping the reduce-axis loops). The helper "__reduce_init" tensor
// created here is registered into `tensor_group` and marked to share the
// output tensor's buffer.
ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) {
  const std::vector<ir::Var>& axis = tensor->axis();
  const std::vector<ir::Expr>& shape = tensor->shape;
  size_t axis_len = axis.size();
  CHECK_EQ(shape.size(), axis_len) << "Internal Error: Tensor has different "
                                      "shape and axis length in AstGen";
  // Spatial axis vars used as store indices.
  std::vector<ir::Expr> axis_exprs;
  for (const auto& a : axis) {
    axis_exprs.push_back(a);
  }
  if (tensor->is_reduce_tensor()) {
    // Make an init Tensor for domain without reduce axis
    Expr init_value = tensor->GetReduceInitVal();
    // TODO(zhhsplendid): Clean the handcoded "__reduce_init" string
    std::string reduce_init_name = tensor->name + "__reduce_init";
    const std::vector<Expr>& domain = tensor->domain_without_reduce_axis();
    ir::Tensor init_tensor = lang::Compute(
        domain,
        [=](const std::vector<Expr>& axis) { return init_value; },
        reduce_init_name);
    // The init tensor shares the output buffer and must run before the
    // reduction body.
    tensor_group->Insert(init_tensor);
    tensor_group->MarkShareMemBuffer(tensor, init_tensor);
    tensor_group->CtrlDepend(tensor, init_tensor);
    Expr init_body = ir::Store::Make(init_tensor, init_value, axis_exprs);
    // create schedule block itervars, i0,i1...
    std::vector<ir::Var> block_vars;
    std::vector<ir::Expr> iter_values;
    // reduce body and reduce init schedule block should have different objects
    // for same axis so we re-create objects
    std::vector<Var> axis_vars = common::GenDefaultAxis(axis_len);
    for (int i = 0; i < shape.size(); ++i) {
      block_vars.push_back(Var(Expr(0),
                               shape[i],
                               cinn::UniqName("i" + std::to_string(i)),
                               /*is_reduce = */ false));
      optim::ReplaceVarWithExpr(&init_body, axis[i], block_vars[i]);
      axis_vars[i]->is_reduce_axis = false;
      // Size-1 dims are bound to the constant 0 instead of a loop var.
      if (shape[i] == Expr(1)) {
        iter_values.push_back(Expr(0));
      } else {
        iter_values.push_back(axis_vars[i]);
      }
    }
    init_body = ir::ScheduleBlockRealize::Make(
        iter_values,
        ir::ScheduleBlock::Make(
            block_vars, {}, {}, reduce_init_name, init_body));
    // For the remaining reduce axis, make reduce body
    const std::vector<ir::Var>& reduce_axis = tensor->reduce_axis;
    ir::Expr reduce_body =
        ConvertReduceBody(tensor->body(), tensor, axis_exprs);
    // create schedule block itervars, i0,i1...
    std::vector<ir::Var> reduce_block_vars;
    std::vector<ir::Expr> reduce_iter_values;
    // reduce body and reduce init schedule block should have different objects
    // for same axis so we re-create objects
    std::vector<Var> reduce_axis_vars = common::GenDefaultAxis(axis_len);
    for (int i = 0; i < shape.size(); ++i) {
      reduce_block_vars.push_back(Var(Expr(0),
                                      shape[i],
                                      cinn::UniqName("i" + std::to_string(i)),
                                      /*is_reduce = */ false));
      reduce_axis_vars[i]->is_reduce_axis = false;
      // NOTE(review): `axis_vars[i]` (not `reduce_axis_vars[i]`) is pushed
      // below, so the freshly created reduce_axis_vars are never used as
      // iter values — confirm whether sharing axis_vars with the init block
      // is intended, given the comment above about re-creating objects.
      if (shape[i] == Expr(1)) {
        reduce_iter_values.push_back(Expr(0));
      } else {
        reduce_iter_values.push_back(axis_vars[i]);
      }
    }
    // Append block vars / iter values for the reduce axes themselves.
    for (int i = 0; i < reduce_axis.size(); ++i) {
      int count = shape.size() + i;
      reduce_block_vars.push_back(
          Var(reduce_axis[i]->lower_bound,
              reduce_axis[i]->upper_bound,
              cinn::UniqName("i" + std::to_string(count)),
              /*is_reduce = */ true));
      ir::Var reduce_axis_var = reduce_axis[i];
      reduce_axis_var->is_reduce_axis = true;
      reduce_iter_values.push_back(reduce_axis_var);
    }
    // Substitute spatial axes first, then reduce axes, with the block vars.
    for (int i = 0; i < axis.size(); ++i) {
      optim::ReplaceVarWithExpr(&reduce_body, axis[i], reduce_block_vars[i]);
    }
    for (int i = axis.size(); i < reduce_block_vars.size(); ++i) {
      optim::ReplaceVarWithExpr(
          &reduce_body, reduce_axis[i - axis.size()], reduce_block_vars[i]);
    }
    reduce_body = ir::ScheduleBlockRealize::Make(
        reduce_iter_values,
        ir::ScheduleBlock::Make(
            reduce_block_vars, {}, {}, tensor->name, reduce_body));
    // Innermost loops iterate the reduce axes (serial, host), built from the
    // last reduce axis outwards.
    for (int i = static_cast<int>(reduce_axis.size()) - 1; i >= 0; --i) {
      reduce_body = ir::For::Make(reduce_axis[i],
                                  reduce_axis[i]->lower_bound,
                                  reduce_axis[i]->upper_bound,
                                  ir::ForType::Serial,
                                  ir::DeviceAPI::Host,
                                  ir::Block::Make({reduce_body}));
    }
    // Put the two parts together
    ir::Expr body = ir::Block::Make({init_body, reduce_body});
    // Wrap the spatial loops, innermost first; the innermost loop takes the
    // two-part block directly without an extra wrapping Block.
    for (int i = static_cast<int>(axis_len) - 1; i >= 0; --i) {
      ir::Var loop_var = axis[i];
      ir::Expr loop_extent = shape[i];
      body = ir::For::Make(
          loop_var,
          Expr(0),
          loop_extent,
          ir::ForType::Serial,
          ir::DeviceAPI::Host,
          i == static_cast<int>(axis_len) - 1 ? body : ir::Block::Make({body}));
    }
    return body;
  } else {
    // Non-reduce tensor: a single store wrapped by one schedule block and
    // the spatial loop nest.
    ir::Expr body = ir::Store::Make(tensor, tensor->body(), axis_exprs);
    // create schedule block itervars, i0,i1...
    std::vector<ir::Var> block_vars;
    std::vector<ir::Expr> iter_values;
    std::vector<Var> axis_vars = common::GenDefaultAxis(axis_len);
    for (int i = 0; i < shape.size(); ++i) {
      block_vars.push_back(Var(
          Expr(0), shape[i], cinn::UniqName("i" + std::to_string(i)), false));
      optim::ReplaceVarWithExpr(&body, axis[i], block_vars[i]);
      axis_vars[i]->is_reduce_axis = false;
      // Size-1 dims are bound to the constant 0 instead of a loop var.
      if (shape[i] == Expr(1)) {
        iter_values.push_back(Expr(0));
      } else {
        iter_values.push_back(axis_vars[i]);
      }
    }
    body = ir::ScheduleBlockRealize::Make(
        iter_values,
        ir::ScheduleBlock::Make(block_vars, {}, {}, tensor->name, body));
    // Wrap the spatial loops, innermost first.
    for (int i = static_cast<int>(axis_len) - 1; i >= 0; --i) {
      ir::Var loop_var = axis[i];
      ir::Expr loop_extent = shape[i];
      body = ir::For::Make(loop_var,
                           Expr(0),
                           loop_extent,
                           ir::ForType::Serial,
                           ir::DeviceAPI::Host,
                           ir::Block::Make({body}));
    }
    return body;
  }
}
} // namespace ast_gen_ius
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/tensor.h"
namespace cinn {
namespace ast_gen_ius {
/**
 * Generator of CINN AST (ir::Expr) for a tensor computation.
 */
class AstGen {
 public:
  // Builds the AST expression (loop nest + schedule blocks) that computes
  // `tensor`. Helper tensors created along the way (e.g. the reduce-init
  // tensor) are recorded into `tensor_group`.
  static ir::Expr Build(const ir::Tensor& tensor, TensorGroup* tensor_group);
};
} // namespace ast_gen_ius
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <vector>
#include "paddle/cinn/ast_gen_ius/ast_gen.h"
#include "paddle/cinn/ast_gen_ius/tensor_group.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/lang/builtin.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/lang/placeholder.h"
namespace cinn {
namespace ast_gen_ius {
using cinn::ir::Expr;
using cinn::ir::Tensor;
// Smoke test: AstGen::Build should produce a printable AST for a simple
// element-wise Relu compute over a 4-D tensor.
TEST(AstGen, Build) {
  std::vector<Expr> dims = {Expr(10), Expr(10), Expr(10), Expr(10)};
  lang::Placeholder<float> input("A", dims);
  Tensor relu_out = lang::Compute(
      dims,
      [&](const std::vector<Expr>& indice) {
        return lang::Relu(input(indice), 0);
      },
      "relu_test");
  TensorGroup group({relu_out});
  Expr ast = AstGen::Build(relu_out, &group);
  LOG(INFO) << ast;
}
} // namespace ast_gen_ius
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/ast_gen_ius/tensor_group.h"
#include <unordered_map>
#include <vector>
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
#include "paddle/cinn/poly/stage.h"
namespace cinn {
namespace ast_gen_ius {
// Builds a group from output tensors: each one is recorded as an output and
// inserted (Insert also pulls in its dependencies).
TensorGroup::TensorGroup(const std::vector<ir::Tensor>& tensors) {
  for (const ir::Tensor& t : tensors) {
    output_tensor_names_.insert(t->name);
    Insert(t);
  }
}
// Dumps every tensor in the group and its control dependencies via VLOG(6).
void TensorGroup::ShowLog() const {
  VLOG(6) << "Showing log for TensorGroup";
  for (const auto& entry : name_to_tensor_) {
    VLOG(6) << "Tensor name = " << entry.first << " depends on {";
    auto dep_it = ctrl_dep_.find(entry.first);
    if (dep_it != ctrl_dep_.end()) {
      for (const std::string& dep_name : dep_it->second) {
        VLOG(6) << dep_name;
      }
    }
    VLOG(6) << "}";
  }
}
// Builds a group from a name->tensor map; every mapped tensor is treated as
// an output tensor and inserted with its dependencies.
TensorGroup::TensorGroup(
    const std::unordered_map<std::string, ir::Tensor>& tensor_map) {
  for (const auto& kv : tensor_map) {
    const ir::Tensor& t = kv.second;
    output_tensor_names_.insert(t->name);
    Insert(t);
  }
}
TensorGroup::~TensorGroup() = default;
// True iff a tensor with `name` has been registered in this group.
bool TensorGroup::Contain(const std::string& name) const {
  return name_to_tensor_.count(name) != 0;
}
// Registers `tensor`, records control dependencies on every tensor its body
// reads, and recursively inserts those dependencies.
void TensorGroup::Insert(const ir::Tensor& tensor) {
  // emplace is a no-op when the name is already registered.
  name_to_tensor_.emplace(tensor->name, tensor);
  // std::set de-duplicates the tensors referenced by the body.
  std::set<ir::Tensor> dep_tensors;
  std::set<ir::Expr> used_tensors = ir::ir_utils::CollectIRNodes(
      tensor->body(), [](const Expr* x) { return x->as_tensor(); });
  for (const Expr& e : used_tensors) {
    const ir::Tensor dep = e.as_tensor_ref();
    dep_tensors.insert(dep);
    CtrlDepend(tensor, dep);
  }
  for (const ir::Tensor& dep : dep_tensors) {
    Insert(dep);
  }
}
// Returns the tensor registered under `name`.
// NOTE(review): uses operator[], so looking up an unknown name silently
// default-inserts an empty ir::Tensor instead of failing — presumably
// callers only pass names already in the group; verify at call sites.
ir::Tensor TensorGroup::Get(const std::string& name) {
  return name_to_tensor_[name];
}
std::set<ir::Tensor> TensorGroup::GetAllTensors() {
std::set<ir::Tensor> all_tensors;
for (const std::pair<std::string, ir::Tensor>& p : name_to_tensor_) {
all_tensors.insert(p.second);
}
return all_tensors;
}
// Returns the tensors of the group in deterministic (alphabetical within
// each level) topological order, excluding pure-input function arguments.
// `func_args` are the tensors passed as function arguments; those that are
// not output tensors are dropped from the result.
//
// Implementation: Kahn's algorithm. A reverse-adjacency map is built once so
// that each dependency edge is visited a constant number of times; the
// previous version rescanned all of ctrl_dep_ for every popped node, which
// was O(V*E).
std::vector<ir::Tensor> TensorGroup::GetGenFuncTopoOrder(
    const std::vector<ir::Tensor>& func_args) {
  // in_degree[name] = number of tensors `name` depends on.
  std::unordered_map<std::string, int> in_degree;
  // dependents[d] = names of tensors that depend on d (reverse edges).
  std::unordered_map<std::string, std::vector<std::string>> dependents;
  for (const auto& dep_pair : ctrl_dep_) {
    const std::unordered_set<std::string>& dep_tensor_names = dep_pair.second;
    in_degree[dep_pair.first] = dep_tensor_names.size();
    VLOG(6) << "indegree[" << dep_pair.first
            << "] = " << dep_tensor_names.size();
    for (const std::string& dep_name : dep_tensor_names) {
      dependents[dep_name].push_back(dep_pair.first);
    }
  }
  std::vector<ir::Tensor> ret;
  // Using set instead of vector/stack in order to get fix alaphbeta order topo
  std::set<std::string> node_set;
  for (const auto& name_tensor : name_to_tensor_) {
    if (!in_degree.count(name_tensor.first)) {
      node_set.insert(name_tensor.first);
    }
  }
  // Function arguments that are pure inputs (not outputs) are excluded from
  // the returned order.
  std::set<std::string> input_arg_names;
  for (const ir::Tensor& arg : func_args) {
    input_arg_names.insert(arg->name);
  }
  for (const std::string& name : output_tensor_names_) {
    input_arg_names.erase(name);
  }
  while (!node_set.empty()) {
    const std::string cur = *(node_set.begin());
    node_set.erase(node_set.begin());
    if (!input_arg_names.count(cur)) {
      ret.push_back(name_to_tensor_[cur]);
    }
    // Release cur's dependents; any that reach in-degree 0 become ready.
    auto it = dependents.find(cur);
    if (it != dependents.end()) {
      for (const std::string& dependent : it->second) {
        if (--in_degree[dependent] == 0) {
          node_set.insert(dependent);
        }
      }
    }
  }
  return ret;
}
// Records that `tensor` must be computed after `to_dep`, registering both
// endpoints in the group if they are not already present.
void TensorGroup::CtrlDepend(const ir::Tensor& tensor,
                             const ir::Tensor& to_dep) {
  ctrl_dep_[tensor->name].insert(to_dep->name);
  // emplace keeps any existing registration untouched.
  name_to_tensor_.emplace(tensor->name, tensor);
  name_to_tensor_.emplace(to_dep->name, to_dep);
}
std::set<ir::Tensor> TensorGroup::GetCrtlDepTensors(
const std::string& tensor_name) {
if (!ctrl_dep_.count(tensor_name)) {
return {};
}
std::set<ir::Tensor> ret;
for (const std::string& dep_name : ctrl_dep_[tensor_name]) {
ret.insert(name_to_tensor_[dep_name]);
}
return ret;
}
// Union-Find "find" with path compression: returns the root name of the
// buffer-sharing set containing `tensor_name`, creating a singleton set on
// first sight.
std::string TensorGroup::GetShareMemRootName(const std::string& tensor_name) {
  auto it = share_memory_tensor_.find(tensor_name);
  if (it == share_memory_tensor_.end()) {
    // First time we see this tensor: it is its own root.
    share_memory_tensor_[tensor_name] = tensor_name;
    return tensor_name;
  }
  if (it->second == tensor_name) {
    return tensor_name;
  }
  // Copy the parent first: the recursive call may touch the map, and we must
  // not rely on `it` (or references into the map) across it.
  const std::string parent = it->second;
  const std::string root = GetShareMemRootName(parent);
  share_memory_tensor_[tensor_name] = root;  // path compression
  return root;
}
// Union-Find "union": marks that `to_share` uses the same buffer as
// `tensor` by linking the root of to_share's set to the root of tensor's.
// No buffer is actually allocated here (see AllocateBuffers).
void TensorGroup::MarkShareMemBuffer(const ir::Tensor& tensor,
                                     const ir::Tensor& to_share) {
  // Same evaluation order as `lhs = rhs` in C++17: rhs root first.
  const std::string tensor_root = GetShareMemRootName(tensor->name);
  share_memory_tensor_[GetShareMemRootName(to_share->name)] = tensor_root;
}
// Materializes buffers for every tensor in the group, honoring the
// union-find buffer-sharing marks: only the root of each share-set gets a
// fresh buffer; every other member binds to its root's buffer.
// Returns the full name->tensor map after binding.
absl::flat_hash_map<std::string, ir::Tensor> TensorGroup::AllocateBuffers() {
  std::unordered_set<std::string> allocated_roots;
  for (auto& name_tensor : name_to_tensor_) {
    std::string root_name = GetShareMemRootName(name_tensor.first);
    // Allocate root buffer
    if (!allocated_roots.count(root_name)) {
      ir::Tensor root_tensor = name_to_tensor_[root_name];
      // Skip roots that already have a buffer or carry no data (void type).
      if (!root_tensor->buffer.defined() && !root_tensor->type().is_void()) {
        root_tensor->WithBuffer();
        VLOG(6) << "Bind root_tensor " << root_name << " with buffer "
                << root_tensor->buffer->name;
      }
      allocated_roots.insert(root_name);
    }
    // Share buffer
    if (root_name != name_tensor.first) {
      ir::Tensor& root_tensor = name_to_tensor_[root_name];
      ir::Tensor& tensor = name_tensor.second;
      // The buffer shape is saved and restored around Bind() — presumably
      // Bind would otherwise overwrite it; confirm against ir::_Tensor_::Bind.
      auto keep_shape = root_tensor->buffer->shape;
      tensor->Bind(root_tensor->buffer);
      root_tensor->buffer->shape = keep_shape;
      tensor->buffer->shape = keep_shape;
      VLOG(6) << "Share buffer " << root_name << " with " << name_tensor.first;
    }
  }
  return name_to_tensor_;
}
// For each stage whose tensor has no buffer yet but is marked to share one,
// binds it to every already-defined buffer among the tensors it is marked to
// share with, preserving each such buffer's original shape.
void StageMapShareMemory(const poly::StageMap& stages) {
  // Index every stage tensor by name for the lookups below.
  absl::flat_hash_map<std::string, ir::_Tensor_*> tensor_map;
  for (auto& stage : stages) {
    ir::_Tensor_* t = stage.second->tensor();
    tensor_map[t->name] = t;
  }
  for (auto& stage : stages) {
    ir::_Tensor_* cur = stage.second->tensor();
    if (cur->buffer.defined() ||
        stage.second->meta.tensors_to_share_buffer_with.empty()) {
      continue;
    }
    for (auto& str : stage.second->meta.tensors_to_share_buffer_with) {
      ir::_Tensor_* other = tensor_map[str];
      if (!other->buffer.defined()) {
        continue;
      }
      // Keep the shared buffer's shape intact across Bind.
      auto edited_shape = other->buffer->shape;
      cur->Bind(other->buffer);
      other->buffer->shape = edited_shape;
      VLOG(3) << "Stage Tensor " << cur->name << " bind buffer to "
              << other->name << " , " << other->buffer->name;
    }
  }
}
// Builds a TensorGroup from every stage in `stage_map` that carries an
// expression, then propagates the buffer-sharing marks recorded on the
// stages. Returns the populated group.
//
// Fix: the previous version also collected a `reshape_tensors` set (names
// ending in "_reshape") that was never read anywhere — dead code, removed.
TensorGroup ConvertStageMapToTensorGroup(const poly::StageMap& stage_map) {
  std::vector<ir::Tensor> stage_tensors;
  for (auto iter = stage_map.begin(); iter != stage_map.end(); ++iter) {
    if (iter->second->has_expression()) {
      stage_tensors.push_back(ir::Tensor(iter->second->tensor()));
    }
  }
  ast_gen_ius::TensorGroup tensor_group(stage_tensors);
  StageMapShareMemory(stage_map);
  return tensor_group;
}
} // namespace ast_gen_ius
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <absl/container/flat_hash_map.h>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/poly/stage.h"
namespace cinn {
namespace ast_gen_ius {
/**
* Collection which maintains the relation between Tensor(s) such as control
* dependency, memory sharing ... it is used in AST generation
*/
class TensorGroup {
 public:
  /**
   * Constructor for a TensorGroup; the argument tensors should be the output
   * tensors of the AST body to be generated. Their dependent tensors are
   * collected recursively during construction.
   */
  explicit TensorGroup(const std::vector<ir::Tensor>& tensors);
  /**
   * Same as the vector constructor, but takes a name->tensor map; every
   * mapped tensor is treated as an output tensor.
   */
  explicit TensorGroup(
      const std::unordered_map<std::string, ir::Tensor>& tensor_map);
  /**
   * Destructor.
   */
  ~TensorGroup();
  /** Logs (VLOG level 6) every tensor and its control dependencies. */
  void ShowLog() const;
  /**
   * Returns true if this TensorGroup contains a tensor with the input name.
   */
  bool Contain(const std::string& name) const;
  /**
   * Inserts a Tensor into the TensorGroup, together with the tensors its
   * body reads (recursively), recording the control dependencies.
   */
  void Insert(const ir::Tensor& tensor);
  /**
   * Returns the Tensor in the TensorGroup with the given name.
   */
  ir::Tensor Get(const std::string& name);
  /**
   * Returns all Tensors in the TensorGroup.
   */
  std::set<ir::Tensor> GetAllTensors();
  /**
   * Marks that `tensor` depends on (must be computed after) `to_dep`.
   */
  void CtrlDepend(const ir::Tensor& tensor, const ir::Tensor& to_dep);
  /**
   * Gets all tensors which the tensor with the given name depends on.
   * (The "Crtl" spelling is part of the public API; callers use it as-is.)
   */
  std::set<ir::Tensor> GetCrtlDepTensors(const std::string& tensor_name);
  /**
   * Gets the Union-Find root tensor name of the buffer-sharing set that
   * contains the tensor with the input name.
   */
  std::string GetShareMemRootName(const std::string& tensor_name);
  /**
   * Marks that two tensors share memory. Only records the relation in the
   * Union-Find structure; no actual buffer sharing/allocation happens here.
   */
  void MarkShareMemBuffer(const ir::Tensor& tensor, const ir::Tensor& to_share);
  /**
   * Allocates buffers for the Tensors in the TensorGroup, honoring the
   * shared-memory marks via the Union-Find structure.
   */
  absl::flat_hash_map<std::string, ir::Tensor> AllocateBuffers();
  /**
   * Returns tensors in topological order, with pure-input `func_args`
   * removed. Because the order is used for generating the function body, the
   * arguments themselves don't have to be generated.
   */
  std::vector<ir::Tensor> GetGenFuncTopoOrder(
      const std::vector<ir::Tensor>& func_args = {});
 private:
  /** collection of output tensor names */
  std::set<std::string> output_tensor_names_;
  /** collection of all tensors in this TensorGroup, keyed by name */
  absl::flat_hash_map<std::string, ir::Tensor> name_to_tensor_;
  /** maps a tensor name to the set of tensor names it depends on */
  std::unordered_map<std::string, std::unordered_set<std::string>> ctrl_dep_;
  /**
   * Union-Find-style parent map: each tensor name whose buffer is shared
   * maps towards the root name of its sharing set.
   */
  std::unordered_map<std::string, std::string> share_memory_tensor_;
};
// TODO(zhhsplendid): remove stage_map need to change all fcompute CINNValuePack
// we will change it in the next PR
TensorGroup ConvertStageMapToTensorGroup(const poly::StageMap& stage_map);
} // namespace ast_gen_ius
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <absl/container/flat_hash_map.h>
#include <gtest/gtest.h>
#include <vector>
#include "paddle/cinn/ast_gen_ius/tensor_group.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/lang/placeholder.h"
namespace cinn {
namespace ast_gen_ius {
using ir::Expr;
using ir::Tensor;
using ir::Var;
using lang::Compute;
using lang::Placeholder;
// Basic TensorGroup API check on a two-tensor graph: B = A + 1.
TEST(TensorGroup, Easy) {
  Expr rows(100);
  Expr cols(15);
  Placeholder<float> input("A", {rows, cols});
  Tensor result = Compute(
      {rows, cols},
      [=](Var i, Var j) -> Expr { return input(i, j) + 1.f; },
      "B");
  TensorGroup group({result});
  // Both the output and its dependency must be registered.
  ASSERT_TRUE(group.Contain("A"));
  ASSERT_TRUE(group.Contain("B"));
  ASSERT_EQ(group.Get("B")->name, "B");
  ASSERT_EQ(group.Get("A")->name, "A");
  ASSERT_EQ(group.GetAllTensors().size(), 2UL);
  // Dependency edges: B depends on A, A on nothing.
  ASSERT_EQ(group.GetCrtlDepTensors("A").size(), 0UL);
  ASSERT_EQ(group.GetCrtlDepTensors("B").size(), 1UL);
  ASSERT_TRUE(group.GetCrtlDepTensors("B").count(input));
  // Pure inputs passed as args are dropped from the topo order.
  std::vector<ir::Tensor> topo_tensors =
      group.GetGenFuncTopoOrder({input.tensor(), result});
  ASSERT_EQ(topo_tensors.size(), 1UL);
  ASSERT_EQ(topo_tensors[0]->name, "B");
  // Before any marks each tensor is its own share root.
  ASSERT_EQ(group.GetShareMemRootName("A"), "A");
  ASSERT_EQ(group.GetShareMemRootName("B"), "B");
  group.MarkShareMemBuffer(group.Get("A"), group.Get("B"));
  absl::flat_hash_map<std::string, ir::Tensor> buffered_tensors =
      group.AllocateBuffers();
  ASSERT_EQ(buffered_tensors["A"]->buffer->name,
            buffered_tensors["B"]->buffer->name);
}
// Checks dependency tracking, topological ordering, and buffer sharing on a
// diamond-shaped five-tensor graph. The expected topo order relies on the
// deterministic alphabetical tie-breaking of GetGenFuncTopoOrder.
TEST(TensorGroup, GraphTopo) {
  auto M = Expr(16);
  auto N = Expr(16);
  /*
   *  A   B
   * / \ /
   * C  D
   *  \ /
   *   E
   */
  Placeholder<float> A("A", {M, N});
  Placeholder<float> B("B", {M, N});
  Tensor C = Compute(
      {M, N}, [=](Var i, Var j) -> Expr { return A(i, j) + 1.f; }, "C");
  Tensor D = Compute(
      {M, N}, [=](Var i, Var j) -> Expr { return A(i, j) + B(i, j); }, "D");
  Tensor E = Compute(
      {M, N}, [=](Var i, Var j) -> Expr { return C(i, j) / D(i, j); }, "E");
  TensorGroup tensor_group({C, D, E});
  // All five tensors (outputs + placeholder inputs) must be registered.
  std::vector<std::string> check_names = {"A", "B", "C", "D", "E"};
  ASSERT_EQ(tensor_group.GetAllTensors().size(), check_names.size());
  for (const std::string& name : check_names) {
    ASSERT_TRUE(tensor_group.Contain(name));
    ASSERT_EQ(tensor_group.Get(name)->name, name);
  }
  // Direct dependency edges of the diamond.
  ASSERT_TRUE(tensor_group.GetCrtlDepTensors("E").count(D));
  ASSERT_TRUE(tensor_group.GetCrtlDepTensors("E").count(C));
  ASSERT_TRUE(tensor_group.GetCrtlDepTensors("D").count(A));
  ASSERT_TRUE(tensor_group.GetCrtlDepTensors("D").count(B));
  ASSERT_TRUE(tensor_group.GetCrtlDepTensors("C").count(A));
  // Without args, all five tensors appear in alphabetical topo order.
  std::vector<ir::Tensor> topo_tensors = tensor_group.GetGenFuncTopoOrder();
  ASSERT_EQ(topo_tensors.size(), check_names.size());
  for (size_t i = 0; i < check_names.size(); ++i) {
    ASSERT_EQ(topo_tensors[i]->name, check_names[i]);
  }
  // Passing A and B as args removes them (pure inputs) from the order.
  std::vector<ir::Tensor> topo_except_argu =
      tensor_group.GetGenFuncTopoOrder({A.tensor(), B.tensor()});
  ASSERT_EQ(topo_except_argu.size(), 3);
  for (int i = 0; i < 3; ++i) {
    ASSERT_EQ(topo_except_argu[i]->name, check_names[i + 2]);
  }
  // Initially every tensor is its own share-memory root.
  for (size_t i = 0; i < check_names.size(); ++i) {
    ASSERT_EQ(tensor_group.GetShareMemRootName(check_names[i]), check_names[i]);
  }
  // Chain the share marks: A <- B <- C <- D collapses into one set.
  tensor_group.MarkShareMemBuffer(tensor_group.Get("A"), tensor_group.Get("B"));
  tensor_group.MarkShareMemBuffer(tensor_group.Get("B"), tensor_group.Get("C"));
  tensor_group.MarkShareMemBuffer(tensor_group.Get("C"), tensor_group.Get("D"));
  ASSERT_EQ(tensor_group.GetShareMemRootName("A"),
            tensor_group.GetShareMemRootName("D"));
  absl::flat_hash_map<std::string, ir::Tensor> buffered_tensors =
      tensor_group.AllocateBuffers();
  ASSERT_EQ(buffered_tensors["A"]->buffer->name,
            buffered_tensors["D"]->buffer->name);
}
} // namespace ast_gen_ius
} // namespace cinn
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment