Commit 01a10755 authored by yuguo-Jack's avatar yuguo-Jack
Browse files

2.5.2-dtk24.04

parent 63eb0da5
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/adt/adt.h"
#include "paddle/cinn/adt/m_expr.h"
namespace cinn::adt {
template <template <typename> class MapT>
struct InlineTranslatorTrait;

// InlineTranslatorTrait specialization for MapStmt tree nodes.
// MapStmt T = ([Iterator], [T])
template <>
struct InlineTranslatorTrait<MapStmt> final {
  // Returns the child statement list of a MapStmt inner node.
  template <typename T>
  static List<T> GetTreeInnerNodeChildren(const MapStmt<T>& map_stmt) {
    const auto& [loop_iterators, children] = map_stmt.tuple();
    return children;
  }

  // Rebuilds a MapStmt over `dst_children`, reusing the source iterators.
  template <typename SrcTreeT, typename DstTreeT>
  static MapStmt<DstTreeT> ConvertMap(const MapStmt<SrcTreeT>& src_map,
                                      const List<DstTreeT>& dst_children) {
    const auto& [loop_iterators, src_children] = src_map.tuple();
    return MapStmt<DstTreeT>{loop_iterators, dst_children};
  }
};
// InlineTranslatorTrait specialization for OpCall tree nodes.
// OpCall T = (Op, [T])
template <>
struct InlineTranslatorTrait<OpCall> final {
  // Returns the argument list of an OpCall inner node.
  template <typename T>
  static List<T> GetTreeInnerNodeChildren(const OpCall<T>& op_call) {
    const auto& [callee_op, arg_tensors] = op_call.tuple();
    return arg_tensors;
  }

  // Rebuilds an OpCall over `dst_children`, reusing the source op.
  template <typename SrcTreeT, typename DstTreeT>
  static OpCall<DstTreeT> ConvertMap(const OpCall<SrcTreeT>& src_map,
                                     const List<DstTreeT>& dst_children) {
    const auto& [callee_op, ignored_args] = src_map.tuple();
    return OpCall<DstTreeT>{callee_op, dst_children};
  }
};
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/adt/kgroup.h"
#include "paddle/cinn/adt/equation_solver.h"
#include "paddle/cinn/adt/igroup.h"
#include "paddle/cinn/adt/index_expr_infer_context.h"
#include "paddle/cinn/adt/m_ir.h"
#include "paddle/cinn/adt/schedule_descriptor.h"
#include "paddle/cinn/adt/schedule_dim.h"
#include "paddle/cinn/hlir/framework/graph.h"
namespace cinn::adt {
using AnchorTensor = Variable;
namespace {

// Total element count of `tensor`; requires the adapter::Tensor alternative.
std::size_t GetTensorNumel(const Tensor& tensor) {
  CHECK(tensor.Has<adapter::Tensor>());
  const auto& adapter_tensor = tensor.Get<adapter::Tensor>();
  return adapter_tensor.GetNumel();
}

// Dimension extents of `tensor`; requires the adapter::Tensor alternative.
std::vector<int32_t> GetTensorShape(const Tensor& tensor) {
  CHECK(tensor.Has<adapter::Tensor>());
  const auto& adapter_tensor = tensor.Get<adapter::Tensor>();
  return adapter_tensor.GetShape();
}

}  // namespace
// Returns one LoopSize per dimension of the igroup's anchor-tensor shape;
// used as the default schedule extents for the igroup.
List<LoopSize> KGroup::GetDefaultScheduleSizes(
    const std::shared_ptr<IGroup>& igroup) const {
  const std::vector<int32_t> anchor_shape =
      GetTensorShape(igroup->anchor_tensor());
  List<LoopSize> schedule_sizes{};
  for (int32_t extent : anchor_shape) {
    schedule_sizes->emplace_back(LoopSize{extent});
  }
  return schedule_sizes;
}
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <vector>
#include "paddle/cinn/adt/m_expr.h"
namespace cinn::hlir::framework::pir {
struct Group;
} // namespace cinn::hlir::framework::pir
namespace cinn::adt {
class IGroup;
using cinn::adt::LoopDescriptors;
/**
* Kernel = KGroup = List<IGroup>.
* KGroup is a list of IGroups, KGroup uses shardable dimension to concatenate
* all ops. This dimension is shared by all IGroups and bound to BlockIdx.
*/
// A kernel group: one CINN group plus the IGroups fused into one kernel.
class KGroup final {
 public:
  explicit KGroup(
      const std::shared_ptr<hlir::framework::pir::Group>& cinn_group,
      const std::vector<std::shared_ptr<IGroup>>& igroups)
      : cinn_group_(cinn_group), igroups_(igroups) {}

  // Locks the weakly-held CINN group; CHECK-fails if it has expired,
  // so callers always receive a non-null pointer.
  std::shared_ptr<hlir::framework::pir::Group> cinn_group() const {
    return CHECK_NOTNULL(cinn_group_.lock());
  }

  // Returns the first igroup; throws std::out_of_range when igroups_ is
  // empty. Matches the current single-igroup usage noted below.
  const std::shared_ptr<IGroup>& GetSoleIGroup() const {
    return igroups_.at(0);
  }

  const std::vector<std::shared_ptr<IGroup>>& igroups() const {
    return igroups_;
  }

  // Default loop extents derived from `igroup`'s anchor tensor shape
  // (defined out of line).
  List<LoopSize> GetDefaultScheduleSizes(
      const std::shared_ptr<IGroup>& igroup) const;

 private:
  // Weak reference: the CINN group's lifetime is owned elsewhere.
  std::weak_ptr<hlir::framework::pir::Group> cinn_group_;
  // NOTE: Use single igroup temporarily. Actually KGroup contains
  // multiple IGroups
  std::vector<std::shared_ptr<IGroup>> igroups_;
  // TODO(Hongyu Jia): Add equations here to link igroups
};
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/adt/adt.h"
namespace cinn::adt {
// Comparison and boolean-connective ADT node types. Each DEFINE_ADT_*
// macro expands to a tuple-like template (e.g. EQ<T0, T1>) used as a
// Logical alternative below.
DEFINE_ADT_BINARY(EQ);
DEFINE_ADT_BINARY(LT);
DEFINE_ADT_BINARY(GT);
DEFINE_ADT_BINARY(NE);
DEFINE_ADT_BINARY(GE);
DEFINE_ADT_BINARY(LE);
DEFINE_ADT_BINARY(And);
DEFINE_ADT_BINARY(Or);
DEFINE_ADT_UNARY(Not);
// Logical T = EQ T T
//           | LT T T
//           | GT T T
//           | NE T T
//           | GE T T
//           | LE T T
//           | And (Logical T) (Logical T)
//           | Or (Logical T) (Logical T)
//           | Not (Logical T)
// Note: And/Or/Not recurse on Logical<ValueT> itself, so arbitrarily nested
// boolean formulas over ValueT comparisons are representable.
template <typename ValueT>
DEFINE_ADT_UNION(Logical,
                 EQ<ValueT, ValueT>,
                 LT<ValueT, ValueT>,
                 GT<ValueT, ValueT>,
                 NE<ValueT, ValueT>,
                 GE<ValueT, ValueT>,
                 LE<ValueT, ValueT>,
                 And<Logical<ValueT>, Logical<ValueT>>,
                 Or<Logical<ValueT>, Logical<ValueT>>,
                 Not<Logical<ValueT>>);
} // namespace cinn::adt
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include "paddle/cinn/adt/adapter_tensor.h"
#include "paddle/cinn/adt/adt.h"
#include "paddle/cinn/adt/equation_value.h"
#include "paddle/cinn/adt/schedule_descriptor.h"
#include "paddle/cinn/adt/schedule_mesh.h"
#include "paddle/cinn/adt/tags.h"
#include "paddle/cinn/adt/tree.h"
namespace pir {
class Operation;
}
namespace cinn {
namespace adt {
// Offset = Int64
using Offset = std::int64_t;
// Tag type for the global memory space. Equality is identity-based: an
// object only equals itself.
class GlobalMemoryType final {
 public:
  bool operator==(const GlobalMemoryType& other) const {
    return this == &other;
  }
};

// All GlobalMemoryType values hash to the address of one shared static
// instance, giving every value of this tag type the same hash.
inline std::size_t GetHashValueImpl(const GlobalMemoryType&) {
  static GlobalMemoryType global_memory_type;
  return reinterpret_cast<std::size_t>(&global_memory_type);
}
// Tag type for the shared memory space. Equality is identity-based: an
// object only equals itself.
class SharedMemoryType final {
 public:
  bool operator==(const SharedMemoryType& other) const {
    return this == &other;
  }
};

// All SharedMemoryType values hash to the address of one shared static
// instance, giving every value of this tag type the same hash.
inline std::size_t GetHashValueImpl(const SharedMemoryType&) {
  static SharedMemoryType shared_memory_type;
  return reinterpret_cast<std::size_t>(&shared_memory_type);
}
// MemoryType = GlobalMemoryType | SharedMemoryType
DEFINE_ADT_UNION(MemoryType, GlobalMemoryType, SharedMemoryType);
OVERLOAD_OPERATOR_EQ_NE(MemoryType, UnionEqual);
OVERRIDE_UNION_GET_HASH_VALUE(MemoryType);
// TempStorage = (Name, Offset, MemoryType)
// A named scratch buffer located at a byte offset within a memory space.
class TempStorage final : public Tuple<Name, Offset, MemoryType> {
 public:
  using Tuple<Name, Offset, MemoryType>::Tuple;
};
// Structural (element-wise) equality for TempStorage.
OVERLOAD_OPERATOR_EQ_NE(TempStorage, TupleEqual);
// Hash of a TempStorage: folds the name, offset, and memory type together,
// consistent with the structural equality above.
inline std::size_t GetHashValueImpl(const TempStorage& temp_storage) {
  const auto& [var_name, offset, memory_type] = temp_storage.tuple();
  std::size_t h = std::hash<std::string>()(var_name);
  h = hash_combine(h, offset);
  h = hash_combine(h, GetHashValue(memory_type));
  return h;
}
// SSAShadowTensor = (tSSAShadow Name, adapter::Tensor)
// A tensor renamed for SSA form: pairs the shadow name with the original
// adapter tensor it shadows.
class SSAShadowTensor final : public Tuple<tSSAShadow<Name>, adapter::Tensor> {
 public:
  using Tuple<tSSAShadow<Name>, adapter::Tensor>::Tuple;
};
// Structural (element-wise) equality for SSAShadowTensor.
OVERLOAD_OPERATOR_EQ_NE(SSAShadowTensor, TupleEqual);
// Hash support for the tagged shadow name.
OVERRIDE_TAG_GET_HASH_VALUE(tSSAShadow<Name>);
// Hash of an SSAShadowTensor: combines the shadow-name hash with the
// wrapped tensor's hash.
inline std::size_t GetHashValueImpl(const SSAShadowTensor& shadow_tensor) {
  const auto& [shadow_name, tensor] = shadow_tensor.tuple();
  const std::size_t name_hash = GetHashValue(shadow_name);
  return hash_combine(name_hash, GetHashValueImpl(tensor));
}
// Tensor = adapter::Tensor | SSAShadowTensor | TempStorage
DEFINE_ADT_UNION(Tensor, adapter::Tensor, SSAShadowTensor, TempStorage);
OVERRIDE_UNION_GET_HASH_VALUE(Tensor);
OVERLOAD_OPERATOR_EQ_NE(Tensor, UnionEqual);
// Op = const pir::Operation*
//    | tReduceInit<const pir::Operation*>
//    | tReduceAcc<const pir::Operation*>
DEFINE_ADT_UNION(Op,
                 const ::pir::Operation*,
                 tReduceInit<const ::pir::Operation*>,
                 tReduceAcc<const ::pir::Operation*>);
using Arg = Tensor;
// OpStmt = (Op, In [Arg], Out [Arg])
class OpStmt final : public Tuple<Op, tIn<List<Arg>>, tOut<List<Arg>>> {
 public:
  using Tuple<Op, tIn<List<Arg>>, tOut<List<Arg>>>::Tuple;

  // NOTE: equality is identity-based (compares tuple addresses), so two
  // structurally equal OpStmts built separately do NOT compare equal.
  bool operator==(const OpStmt& other) const {
    return &this->tuple() == &other.tuple();
  }
};

// Identity-based hash, consistent with OpStmt::operator== above.
inline std::size_t GetHashValue(const OpStmt& op_stmt_node) {
  return reinterpret_cast<std::size_t>(&op_stmt_node.tuple());
}

using LoopIterators = List<Iterator>;
// MapStmt T = ([Iterator], [T])
template <typename T>
class MapStmt final : public Tuple<LoopIterators, List<T>> {
 public:
  // NOTE(review): value_type aliases the iterator list, not T — presumably
  // for traversal helpers; confirm before relying on it elsewhere.
  using value_type = LoopIterators;
  using Tuple<LoopIterators, List<T>>::Tuple;
};
// Stmt = OpStmt | MapStmt Stmt
using Stmt = Tree<MapStmt, OpStmt>;
// Store = (output target, stored expression)
template <typename OutT, typename InT>
class Store final : public Tuple<OutT, InT> {
 public:
  using Tuple<OutT, InT>::Tuple;
};
// Load wraps a single read of its operand.
template <typename T>
class Load final : public Tuple<T> {
 public:
  using Tuple<T>::Tuple;
};
// OpCall T = (Op, [T])
template <typename T>
class OpCall final : public Tuple<Op, List<T>> {
 public:
  using Tuple<Op, List<T>>::Tuple;
};
// OpExpr = Tree OpCall (Load Tensor)
using OpExpr = Tree<OpCall, Load<Tensor>>;
// OpExprStmt = Store Tensor OpExpr
using OpExprStmt = Store<Tensor, OpExpr>;
using InlineStmt = Tree<MapStmt, OpExprStmt>;
// Functor aliases used to look up per-tensor index/iterator expressions.
using TensorIndexExpr = Value;
using TensorIndexExpr4TensorT = std::function<TensorIndexExpr(const Tensor&)>;
using TensorIteratorExpr = Value;
using TensorIteratorExpr4TensorT =
    std::function<List<TensorIteratorExpr>(const Tensor&)>;
using LoopDescriptor4LoopIteratorT =
    std::function<LoopDescriptor(const Iterator&)>;
// AnchoredMapStmt = (MapStmt Stmt, ScheduleMesh, tAnchor Tensor,
//                    TensorIndexExpr4TensorT, TensorIteratorExpr4TensorT,
//                    LoopDescriptor4LoopIteratorT)
// A MapStmt bundled with its schedule mesh, anchor tensor, and the lookup
// functors needed to query per-tensor index/iterator/loop information.
class AnchoredMapStmt final : public Tuple<MapStmt<Stmt>,
                                           ScheduleMesh,
                                           tAnchor<Tensor>,
                                           TensorIndexExpr4TensorT,
                                           TensorIteratorExpr4TensorT,
                                           LoopDescriptor4LoopIteratorT> {
 public:
  using Tuple<MapStmt<Stmt>,
              ScheduleMesh,
              tAnchor<Tensor>,
              TensorIndexExpr4TensorT,
              TensorIteratorExpr4TensorT,
              LoopDescriptor4LoopIteratorT>::Tuple;

  // Evaluates the stored index-expression functor (tuple element 3) for
  // `tensor`.
  TensorIndexExpr GetTensorIndexExpr(const Tensor& tensor) const {
    const auto& TensorIndexExpr4Tensor = std::get<3>(tuple());
    return TensorIndexExpr4Tensor(tensor);
  }
};
// Kernel = ([AnchoredMapStmt], In [Tensor], Out [Tensor])
class Kernel final : public Tuple<List<AnchoredMapStmt>,
                                  tIn<List<Tensor>>,
                                  tOut<List<Tensor>>> {
 public:
  using Tuple<List<AnchoredMapStmt>, tIn<List<Tensor>>, tOut<List<Tensor>>>::
      Tuple;
};
// MapExpr = Kernel;
using MapExpr = Kernel;
} // namespace adt
} // namespace cinn
namespace std {
// Hash support so cinn::adt::Tensor can key std::unordered_{map,set};
// delegates to the union's GetHashValue.
template <>
struct hash<cinn::adt::Tensor> {
  std::size_t operator()(const cinn::adt::Tensor& tensor) const {
    return cinn::adt::GetHashValue(tensor);
  }
};
// Hash support for OpStmt; identity-based, consistent with
// OpStmt::operator==.
template <>
struct hash<cinn::adt::OpStmt> {
  std::size_t operator()(const cinn::adt::OpStmt& op_stmt_node) const {
    return cinn::adt::GetHashValue(op_stmt_node);
  }
};
}  // namespace std
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iterator>
#include <unordered_map>
#include "paddle/cinn/adt/adt.h"
#include "paddle/cinn/adt/equation_solver.h"
#include "paddle/cinn/adt/index_expr_infer_context.h"
#include "paddle/cinn/adt/m_ir.h"
#include "paddle/cinn/adt/naive_equation_function_constants_provider.h"
#include "paddle/cinn/adt/naive_op_equation_context.h"
#include "paddle/cinn/adt/partition_op_stmts.h"
#include "paddle/cinn/adt/print.h"
#include "paddle/cinn/adt/write_broadcast_disabled_bidirection_equation_generator.h"
namespace cinn::adt {
// Invokes `DoEach` on every op statement in `op_stmts`, in order.
template <typename DoEachT>
void VisitEachOpStmt(const List<OpStmt>& op_stmts, const DoEachT& DoEach) {
  for (std::size_t i = 0; i < op_stmts->size(); ++i) {
    DoEach(op_stmts->at(i));
  }
}
// Collects every loop Iterator referenced by `tensor_index_expr` into `ret`
// (forward declaration; dispatcher defined below).
void CollectTensorIndexIterators(const TensorIndexExpr& tensor_index_expr,
                                 std::unordered_set<Iterator>* ret);

// Undefined values must not appear in a tensor index expression.
void CollectTensorIndexIteratorsImpl(const Undefined& tensor_index_expr,
                                     std::unordered_set<Iterator>* ret) {
  LOG(FATAL) << "Not Implemented";
}

// Ok values must not appear in a tensor index expression.
void CollectTensorIndexIteratorsImpl(const Ok& ok,
                                     std::unordered_set<Iterator>* ret) {
  LOG(FATAL) << "Not Implemented";
}

// Base case: record a bare iterator.
void CollectTensorIndexIteratorsImpl(const Iterator& iterator,
                                     std::unordered_set<Iterator>* ret) {
  ret->emplace(iterator);
}

// Constants contain no iterators.
void CollectTensorIndexIteratorsImpl(const Constant& constant,
                                     std::unordered_set<Iterator>* ret) {
  // Do nothing
}

// Recurse into each element of a value list.
void CollectTensorIndexIteratorsImpl(const List<Value>& tensor_index_expr,
                                     std::unordered_set<Iterator>* ret) {
  for (const auto& value : *tensor_index_expr) {
    CollectTensorIndexIterators(value, ret);
  }
}

// Recurse into the iterators operand of an IndexDot node.
void CollectTensorIndexIteratorsImpl(
    const IndexDotValue<Value, Constant>& tensor_index_expr,
    std::unordered_set<Iterator>* ret) {
  CollectTensorIndexIterators(tensor_index_expr.GetIteratorsValue(), ret);
}

// Recurse into the index operand of an IndexUnDot node.
void CollectTensorIndexIteratorsImpl(
    const IndexUnDotValue<Value, Constant>& tensor_index_expr,
    std::unordered_set<Iterator>* ret) {
  CollectTensorIndexIterators(tensor_index_expr.GetIndexValue(), ret);
}

// Recurse into the list operand of a ListGetItem node.
void CollectTensorIndexIteratorsImpl(
    const ListGetItem<Value, Constant>& tensor_index_expr,
    std::unordered_set<Iterator>* ret) {
  CollectTensorIndexIterators(tensor_index_expr.GetList(), ret);
}

// Recurse into the underlying iterator of a broadcasted iterator.
void CollectTensorIndexIteratorsImpl(
    const BroadcastedIterator<Value, Constant>& broadcasted_iterator,
    std::unordered_set<Iterator>* ret) {
  CollectTensorIndexIterators(broadcasted_iterator.GetArg0(), ret);
}

// Recurse into the second operand of a PtrGetItem node.
void CollectTensorIndexIteratorsImpl(const PtrGetItem<Value>& tensor_index_expr,
                                     std::unordered_set<Iterator>* ret) {
  CollectTensorIndexIterators(tensor_index_expr.GetArg1(), ret);
}

// Dispatcher: visits the active variant alternative and forwards to the
// matching Impl overload above.
void CollectTensorIndexIterators(const TensorIndexExpr& tensor_index_expr,
                                 std::unordered_set<Iterator>* ret) {
  std::visit(
      [&](const auto& impl) { CollectTensorIndexIteratorsImpl(impl, ret); },
      tensor_index_expr.variant());
}
// Returns the set of loop iterators appearing in `tensor_index_expr`.
std::unordered_set<Iterator> GetTensorIndexIterators(
    const TensorIndexExpr& tensor_index_expr) {
  std::unordered_set<Iterator> iterators{};
  CollectTensorIndexIterators(tensor_index_expr, &iterators);
  return iterators;
}
// Filters `loop_iters` down to those present in `tensor_index_loop_iters`,
// preserving the original loop order.
LoopIterators GetSortedSdIterators(
    const std::unordered_set<Iterator>& tensor_index_loop_iters,
    const LoopIterators& loop_iters) {
  LoopIterators sorted_iters{};
  for (const auto& iter : *loop_iters) {
    const bool is_used = tensor_index_loop_iters.count(iter) != 0;
    if (is_used) {
      sorted_iters->emplace_back(iter);
    }
  }
  return sorted_iters;
}
// Returns the subset of `loop_iters` that occurs in `tensor`'s index
// expression, in original loop order.
LoopIterators GetAnchorTensorLoopIterators(
    const Tensor& tensor,
    const LoopIterators& loop_iters,
    const std::function<TensorIndexExpr(const Tensor&)>&
        TensorIndexExpr4Tensor) {
  const auto& used_iters =
      GetTensorIndexIterators(TensorIndexExpr4Tensor(tensor));
  return GetSortedSdIterators(used_iters, loop_iters);
}
namespace {
// An undefined arg position means `index` maps to no argument slot; fatal.
Tensor GetTensorImpl(const OpStmt& op_stmt, const Undefined& undefined) {
  LOG(FATAL) << "position not found";
}
// Fetches the input tensor at position `pos`.
Tensor GetTensorImpl(const OpStmt& op_stmt, const tIn<std::size_t>& pos) {
  const auto& [op, in_args, out_args] = op_stmt.tuple();
  return in_args.value()->at(pos.value());
}
// Fetches the output tensor at position `pos`.
Tensor GetTensorImpl(const OpStmt& op_stmt, const tOut<std::size_t>& pos) {
  const auto& [op, in_args, out_args] = op_stmt.tuple();
  return out_args.value()->at(pos.value());
}
// Resolves `index` to an argument position via the equation context, then
// fetches the corresponding tensor from `op_stmt`.
Tensor GetTensor(const config::NaiveOpEquationContext& ctx,
                 const OpStmt& op_stmt,
                 const Index& index) {
  const auto& op_arg_pos = ctx.GetOpArgPos(index);
  return std::visit(
      [&](const auto& impl) { return GetTensorImpl(op_stmt, impl); },
      op_arg_pos.variant());
}
// The anchor tensor is the tensor bound to the group's anchor index.
Tensor GetAnchorTensor(const AnchorGroup& anchor_group) {
  const auto& ctx = *anchor_group.EquationCtx4OpStmt(anchor_group.op_stmt);
  return GetTensor(ctx, anchor_group.op_stmt, anchor_group.anchor_index);
}
// Maps each anchor index to the loop iterators used by its anchor tensor.
// Anchor indexes must be unique across groups (CHECKed on insertion).
std::unordered_map<Index, LoopIterators> GenerateAnchorIndex2LoopIterators(
    const std::vector<AnchorGroup>& partitioned_anchor_groups,
    const std::function<TensorIndexExpr(const Tensor&)>& TensorIndexExpr4Tensor,
    const LoopIterators& loop_iters) {
  std::unordered_map<Index, LoopIterators> anchor_index2loop_iters{};
  for (const auto& anchor_group : partitioned_anchor_groups) {
    const auto& anchor_index = anchor_group.anchor_index;
    const auto& anchor_tensor = GetAnchorTensor(anchor_group);
    const auto& anchor_loop_iters = GetAnchorTensorLoopIterators(
        anchor_tensor, loop_iters, TensorIndexExpr4Tensor);
    CHECK(anchor_index2loop_iters.emplace(anchor_index, anchor_loop_iters)
              .second);
  }
  return anchor_index2loop_iters;
}
}  // namespace
// Builds one MapIr per anchor group, pairing the group's op statements with
// the loop iterators its anchor tensor uses.
MapIrList ConvertAnchorGroups2MapIrList(
    const std::vector<AnchorGroup>& partitioned_anchor_groups,
    const std::function<TensorIndexExpr(const Tensor&)>& TensorIndexExpr4Tensor,
    const LoopIterators& loop_iters) {
  const auto& index2iters = GenerateAnchorIndex2LoopIterators(
      partitioned_anchor_groups, TensorIndexExpr4Tensor, loop_iters);
  MapIrList map_irs{};
  for (const auto& group : partitioned_anchor_groups) {
    const auto& group_iters = index2iters.at(group.anchor_index);
    map_irs->emplace_back(MapIr{group.op_stmts, group_iters});
  }
  return map_irs;
}
// Partitions `op_stmts` into anchor groups (with write-broadcast disabled
// equations) and converts each group into a MapIr under `loop_iters`.
MapIrList GenerateMapIrListForLoopFuse(
    const List<OpStmt>& op_stmts,
    const LoopIterators& loop_iters,
    const std::function<TensorIndexExpr(const Tensor&)>&
        TensorIndexExpr4Tensor) {
  const auto& ctx4op_stmt = config::GenerateContext4LocalOpStmt(op_stmts);
  const auto& equation_generator =
      std::make_shared<WriteBroadcastDisabledBidirectionEquationGenerator>(
          op_stmts, ctx4op_stmt);
  const auto& anchor_groups =
      PartitionOpStmts(ctx4op_stmt, op_stmts, equation_generator);
  return ConvertAnchorGroups2MapIrList(
      anchor_groups, TensorIndexExpr4Tensor, loop_iters);
}
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <list>
#include "paddle/cinn/adt/m_expr.h"
namespace cinn::adt {
using LoopIterators = List<Iterator>;
}
namespace cinn::adt {
class MapIr final {
public:
MapIr(const List<OpStmt>& op_stmts, const LoopIterators& loop_iterators)
: op_stmts_{op_stmts}, loop_iterators_(loop_iterators) {}
MapIr(const MapIr&) = default;
MapIr(MapIr&&) = default;
const List<OpStmt>& op_stmts() const { return op_stmts_; }
const LoopIterators& loop_iterators() const { return loop_iterators_; }
private:
List<OpStmt> op_stmts_;
LoopIterators loop_iterators_;
};
using MapIrList = List<MapIr>;
MapIrList GenerateMapIrListForLoopFuse(
const List<OpStmt>& op_stmts,
const LoopIterators& loop_iters,
const std::function<TensorIndexExpr(const Tensor&)>&
TensorIndexExpr4Tensor);
void CollectTensorIndexIterators(const TensorIndexExpr& tensor_index_expr,
std::unordered_set<Iterator>* ret);
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <unordered_map>
#include "paddle/cinn/adt/m_expr.h"
#include "paddle/cinn/ir/lowered_func.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/pir/core/operation.h"
namespace cinn::adt {
// Holds a MapExpr together with the lowered functions produced for each
// pir::Operation node. Non-copyable and non-movable.
class MapExprCtx final {
 public:
  using Node2LoweredFuncs =
      std::unordered_map<::pir::Operation*, std::vector<ir::LoweredFunc>>;

  MapExprCtx(const MapExprCtx&) = delete;
  MapExprCtx(MapExprCtx&&) = delete;

  explicit MapExprCtx(const MapExpr& map_expr) : map_expr_(map_expr) {}

  const MapExpr& map_expr() const { return map_expr_; }

  // Registers a deep copy of `lowered_funcs` for `node`. Each node may be
  // registered exactly once; duplicate registration is a fatal error.
  void UpdateOpLoweredFuncKey(
      ::pir::Operation* node,
      const std::vector<ir::LoweredFunc>& lowered_funcs) {
    const bool inserted =
        node2lowered_funcs_.emplace(node, ir::ir_utils::IRCopy(lowered_funcs))
            .second;
    CHECK(inserted);
  }

  const Node2LoweredFuncs& node2lowered_funcs() const {
    return node2lowered_funcs_;
  }

 private:
  const MapExpr map_expr_;
  Node2LoweredFuncs node2lowered_funcs_;
};
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <type_traits>
namespace cinn::adt {
template <typename SumT, typename T>
struct MatchTrait;
namespace detail {
template <bool is_expr>
struct ExprMatchTrait;
template <>
struct ExprMatchTrait<true> final {
template <typename ExprT, typename T>
struct match_trait_type {
static_assert(std::is_same<ExprT, T>::value, "");
static constexpr int is_template = false;
};
};
template <>
struct ExprMatchTrait<false> final {
template <typename ExprT, typename T>
using match_trait_type = MatchTrait<ExprT, T>;
};
template <bool is_leaf, typename ExprT>
struct DoMatch;
template <typename ExprT>
struct Match final {
template <typename source_pattern_type>
static bool Call(const ExprT& pattern_expr) {
static constexpr bool is_expr =
std::is_same<ExprT, source_pattern_type>::value;
static constexpr bool is_template = ExprMatchTrait<is_expr>::
template match_trait_type<ExprT, source_pattern_type>::is_template;
static constexpr bool is_leaf = is_expr || !is_template;
return DoMatch<is_leaf, ExprT>::template Call<source_pattern_type>(
pattern_expr);
}
};
template <typename ExprT>
struct DoMatch</*is_leaf*/ true, ExprT> final {
template <typename source_pattern_type>
static bool Call(const ExprT& pattern_expr) {
if constexpr (std::is_same<std::decay_t<ExprT>,
source_pattern_type>::value) {
return true;
}
return pattern_expr.Visit([](auto&& impl) {
if constexpr (std::is_same<std::decay_t<decltype(impl)>,
source_pattern_type>::value) {
return true;
} else {
return false;
}
});
}
};
template <typename ExprT>
struct DoMatch</*is_leaf*/ false, ExprT> final {
template <typename source_pattern_type>
static bool Call(const ExprT& pattern_expr) {
return pattern_expr.Visit([](auto&& impl) {
using pattern_type =
typename MatchTrait<ExprT, source_pattern_type>::base_type;
if constexpr (std::is_same<std::decay_t<decltype(impl)>,
pattern_type>::value) {
return MatchTrait<ExprT, source_pattern_type>::template MatchChildren<
Match>(impl);
} else {
return false;
}
});
}
};
} // namespace detail
template <typename SourcePatternT, typename ExprT>
bool Match(const ExprT& expr) {
return detail::Match<ExprT>::template Call<SourcePatternT>(expr);
}
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/adt/naive_bidirection_equation_generator.h"
#include "paddle/cinn/adt/equation_graph.h"
#include "paddle/cinn/adt/equation_solver.h"
#include "paddle/cinn/adt/naive_equation_function_constants_provider.h"
namespace cinn::adt {
namespace {
using EquationCtx4OpStmtT =
std::function<std::shared_ptr<config::NaiveOpEquationContext>(
const OpStmt&)>;
template <
typename DoEachT /*: void(&)(std::size_t, OpStmt, OpEquationContext)*/>
void VisitEachOpStmtAndEquationCtx(
const List<OpStmt>& op_stmts,
const EquationCtx4OpStmtT& EquationCtx4OpStmt,
const DoEachT& DoEach) {
for (std::size_t i = 0; i < op_stmts->size(); ++i) {
const auto& ctx = EquationCtx4OpStmt(op_stmts->at(i));
DoEach(i, op_stmts->at(i), ctx);
}
}
// Creates `num_args` fresh argument indexes, each with a unique id.
List<Index> MakeArgIndexes(std::size_t num_args) {
  List<Index> arg_indexes{};
  for (std::size_t i = 0; i < num_args; ++i) {
    arg_indexes->emplace_back(Index{UniqueId::New()});
  }
  return arg_indexes;
}
// Unwraps the optional out-message input indexes (every input must be
// present — CHECKed) and pairs them with the still-optional output indexes.
OpArgIndexes<std::optional<Index>> MakeOutMsgOpArgIndexes(
    const List<std::optional<Index>>& opt_out_msg_in_indexes,
    const List<std::optional<Index>>& opt_out_msg_out_indexes) {
  List<Index> out_msg_in_indexes{};
  for (const auto& opt_index : *opt_out_msg_in_indexes) {
    CHECK(opt_index.has_value());
    out_msg_in_indexes->emplace_back(opt_index.value());
  }
  return OpArgIndexes<std::optional<Index>>{out_msg_in_indexes,
                                            opt_out_msg_out_indexes};
}
// Bundles the in-message argument indexes into an OpArgIndexes value.
OpArgIndexes<Index> MakeInMsgOpArgIndexes(
    const List<Index>& in_msg_in_indexes,
    const List<Index>& in_msg_out_indexes) {
  return OpArgIndexes<Index>{in_msg_in_indexes, in_msg_out_indexes};
}

// Visits paired (in_msg, out_msg) indexes position by position; both lists
// must have equal size (CHECKed).
template <typename DoEachT>
void VisitEachInMsgOutMsgPair(const List<Index>& in_msg_indexes,
                              const List<Index>& out_msg_indexes,
                              const DoEachT& DoEach) {
  CHECK_EQ(in_msg_indexes->size(), out_msg_indexes->size());
  const std::size_t num_pairs = in_msg_indexes->size();
  for (std::size_t pair_idx = 0; pair_idx != num_pairs; ++pair_idx) {
    DoEach(in_msg_indexes->at(pair_idx), out_msg_indexes->at(pair_idx));
  }
}
// Looks up the out-message index for each in-message index; entries are
// std::nullopt where the generator has no mapping.
List<std::optional<Index>> GetOutMsgIndexes(
    const List<Index>& in_indexes,
    const NaiveBidirectionEquationGenerator& generator) {
  List<std::optional<Index>> out_indexes{};
  for (const auto& in_index : *in_indexes) {
    out_indexes->emplace_back(generator.OutMsgIndex4InMsgIndex(in_index));
  }
  return out_indexes;
}

using InMsg2OutMsgT = InMsg2OutMsg<tOut<FakeOpPlaceHolder>,
                                   tOut<OpArgIndexes<std::optional<Index>>>,
                                   tIn<OpArgIndexes<Index>>>;
} // namespace
// Creates a fresh out-message Index for every in-message arg index of every
// op statement (inputs first, then outputs) and records each pairing in
// in_msg_index2out_msg_index_. Duplicate in-message indexes are fatal.
void NaiveBidirectionEquationGenerator::InitInMsgIndex2OutMsgIndex() {
  const auto& InitEachOpInMsgIndex2OutMsgIndex =
      [&](const std::shared_ptr<config::NaiveOpEquationContext>& ctx,
          bool is_output) {
        List<Index> in_msg_indexes =
            is_output ? ctx->out_indexes() : ctx->in_indexes();
        // NOTE(review): out_msg_index_size comes from
        // Get{In,Out}TensorsRanks().size(), which is presumably equal to
        // in_msg_indexes->size() — VisitEachInMsgOutMsgPair CHECK_EQs the
        // two sizes at runtime.
        std::size_t out_msg_index_size = is_output
                                             ? ctx->GetOutTensorsRanks().size()
                                             : ctx->GetInTensorsRanks().size();
        List<Index> out_msg_indexes = MakeArgIndexes(out_msg_index_size);
        VisitEachInMsgOutMsgPair(
            in_msg_indexes,
            out_msg_indexes,
            [&](const Index& in_index, const Index& out_index) {
              // Each in-message index may map to exactly one out-message
              // index.
              CHECK(
                  this->in_msg_index2out_msg_index_.emplace(in_index, out_index)
                      .second);
            });
      };
  VisitEachOpStmtAndEquationCtx(
      this->op_stmts_,
      this->EquationCtx4OpStmt_,
      [&](std::size_t idx,
          const OpStmt& op_stmt,
          const std::shared_ptr<config::NaiveOpEquationContext>& ctx) {
        InitEachOpInMsgIndex2OutMsgIndex(ctx, /*is_output=*/false);
        InitEachOpInMsgIndex2OutMsgIndex(ctx, /*is_output=*/true);
      });
}
// Builds one InMsg2OutMsg equation per op statement, linking the op's
// in-message arg indexes to the out-message indexes created by
// InitInMsgIndex2OutMsgIndex. fake_op_placeholders_ and equations_ are kept
// index-aligned (one entry each per op statement).
void NaiveBidirectionEquationGenerator::InitEquations() {
  VisitEachOpStmtAndEquationCtx(
      this->op_stmts_,
      this->EquationCtx4OpStmt_,
      [&](std::size_t idx,
          const OpStmt& op_stmt,
          const std::shared_ptr<config::NaiveOpEquationContext>& ctx) {
        List<Index> in_msg_in_indexes = ctx->in_indexes();
        List<Index> in_msg_out_indexes = ctx->out_indexes();
        // OutMsgIndex4InMsgIndex yields std::nullopt for unmapped indexes;
        // MakeOutMsgOpArgIndexes CHECKs that every input lookup succeeded.
        List<std::optional<Index>> out_msg_in_indexes =
            GetOutMsgIndexes(in_msg_in_indexes, *this);
        List<std::optional<Index>> out_msg_out_indexes =
            GetOutMsgIndexes(in_msg_out_indexes, *this);
        Equation equation = InMsg2OutMsgT{
            ctx->fake_op_placeholder(),
            MakeOutMsgOpArgIndexes(out_msg_in_indexes, out_msg_out_indexes),
            MakeInMsgOpArgIndexes(in_msg_in_indexes, in_msg_out_indexes)};
        this->fake_op_placeholders_->emplace_back(ctx->fake_op_placeholder());
        this->equations_->emplace_back(equation);
      });
}
std::function<const OpStmt*(const FakeOpPlaceHolder&)>
NaiveBidirectionEquationGenerator::MakeGetterOpStmt4OpPlaceHolder() const {
using FakeOpPlaceHolder2OpStmt =
std::unordered_map<FakeOpPlaceHolder, OpStmt>;
const auto& fake_op_placeholder2op_stmt =
std::make_shared<FakeOpPlaceHolder2OpStmt>();
for (std::size_t i = 0; i < fake_op_placeholders_->size(); ++i) {
CHECK(fake_op_placeholder2op_stmt
->emplace(fake_op_placeholders_->at(i), op_stmts_->at(i))
.second);
}
return [fake_op_placeholder2op_stmt](
const FakeOpPlaceHolder& fake_op_placeholder) {
return &fake_op_placeholder2op_stmt->at(fake_op_placeholder);
};
}
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <unordered_map>
#include "paddle/cinn/adt/direction_equation_generator.h"
#include "paddle/cinn/adt/equation_function.h"
#include "paddle/cinn/adt/m_expr.h"
#include "paddle/cinn/adt/naive_op_equation_context.h"
namespace cinn::adt {
namespace config {
class NaiveOpEquationContext;
}
// Generates bidirectional (in-msg <-> out-msg) index equations for a list of
// op statements. Per-op variables come from the NaiveOpEquationContext that
// EquationCtx4OpStmt_ supplies; construction eagerly runs Init().
class NaiveBidirectionEquationGenerator : public DirectionEquationGenerator {
 public:
  // Getter mapping an op statement to its (shared) equation context.
  using EquationCtx4OpStmtT =
      std::function<std::shared_ptr<config::NaiveOpEquationContext>(
          const OpStmt&)>;
  NaiveBidirectionEquationGenerator(const NaiveBidirectionEquationGenerator&) =
      delete;
  NaiveBidirectionEquationGenerator(NaiveBidirectionEquationGenerator&&) =
      delete;
  NaiveBidirectionEquationGenerator(
      const List<OpStmt>& op_stmts,
      const EquationCtx4OpStmtT& EquationCtx4OpStmt)
      : op_stmts_(op_stmts), EquationCtx4OpStmt_(EquationCtx4OpStmt) {
    Init();
  }
  Equations GetDirectionEquations() const override { return equations_; }
  std::function<const OpStmt*(const FakeOpPlaceHolder&)>
  MakeGetterOpStmt4OpPlaceHolder() const override;
  // Returns the out-msg index registered for `index`, or std::nullopt when
  // no mapping exists.
  std::optional<Index> OutMsgIndex4InMsgIndex(
      const Index& index) const override {
    const auto& iter = in_msg_index2out_msg_index_.find(index);
    if (iter == in_msg_index2out_msg_index_.end()) {
      return std::nullopt;
    } else {
      return iter->second;
    }
  }
  const List<OpStmt>& op_stmts() const { return op_stmts_; }
  const EquationCtx4OpStmtT& EquationCtx4OpStmt() const {
    return EquationCtx4OpStmt_;
  }
  const Equations& equations() const { return equations_; }
 private:
  void InitInMsgIndex2OutMsgIndex();
  void InitEquations();
  // Index mapping must exist before equations are generated.
  void Init() {
    InitInMsgIndex2OutMsgIndex();
    InitEquations();
  }
 protected:
  List<OpStmt> op_stmts_;
  EquationCtx4OpStmtT EquationCtx4OpStmt_;
  Equations equations_;
  // Parallel to op_stmts_: fake_op_placeholders_->at(i) belongs to
  // op_stmts_->at(i).
  List<FakeOpPlaceHolder> fake_op_placeholders_;
  std::unordered_map<Index, Index> in_msg_index2out_msg_index_;
};
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <unordered_map>
#include "paddle/cinn/adt/equation_function_constants_provider.h"
#include "paddle/cinn/adt/m_expr.h"
#include "paddle/cinn/adt/naive_op_equation_context.h"
namespace cinn::adt {
class NaiveEquationFunctionConstantsProvider final
: public EquationFunctionConstantsProvider {
public:
using EquationCtx4OpStmtT =
std::function<std::shared_ptr<config::NaiveOpEquationContext>(
const OpStmt&)>;
NaiveEquationFunctionConstantsProvider(
const List<OpStmt>& op_stmts,
const EquationCtx4OpStmtT& EquationCtx4OpStmt) {
Init(op_stmts, EquationCtx4OpStmt);
}
NaiveEquationFunctionConstantsProvider(
const NaiveEquationFunctionConstantsProvider&) = delete;
NaiveEquationFunctionConstantsProvider(
NaiveEquationFunctionConstantsProvider&&) = delete;
Constant GetDimSize(const Dim& dim) const override {
const auto& iter = dim2constant_.find(dim);
CHECK(iter != dim2constant_.end());
return iter->second;
}
bool AddDim(const Dim& dim, const Constant& dim_value) override {
return dim2constant_.emplace(dim, dim_value).second;
}
private:
void Init(const List<OpStmt>& op_stmts,
const EquationCtx4OpStmtT& EquationCtx4OpStmt) {
for (const auto& op_stmt : *op_stmts) {
const auto& ctx = EquationCtx4OpStmt(op_stmt);
ctx->VisitEachArgPos(
[&](bool is_out, std::size_t arg_idx, std::size_t axis) {
const Dim& dim = ctx->GetDim(is_out, arg_idx, axis);
const Constant& constant = ctx->GetDimSize(is_out, arg_idx, axis);
CHECK(dim2constant_.emplace(dim, constant).second);
});
}
}
std::unordered_map<Dim, const Constant> dim2constant_;
};
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include "paddle/cinn/adt/adapter_tensor.h"
#include "paddle/cinn/adt/m_expr.h"
#include "paddle/cinn/adt/naive_op_equation_context.h"
#include "paddle/cinn/adt/op_arg_pos.h"
#include "paddle/cinn/adt/print.h"
#include "paddle/cinn/hlir/framework/pir/utils.h"
#include "paddle/cinn/utils/type_defs.h"
#include "glog/logging.h"
namespace cinn::adt::config {
// Dumps this context's equations in text form at VLOG level 1.
void NaiveOpEquationContext::Print() const {
  VLOG(1) << "Equations : \n" << ToTxtString(equations());
}
// Collects the rank of every argument tensor, preserving argument order.
// Every arg must wrap an adapter::Tensor; anything else aborts.
std::vector<std::uint64_t> MakeTensorRanks(const List<Arg>& arg_lists) {
  std::vector<std::uint64_t> ret;
  ret.reserve(arg_lists->size());  // one allocation; sizes are known upfront
  for (const auto& arg : *arg_lists) {
    CHECK(arg.Has<adapter::Tensor>());
    ret.push_back(arg.Get<adapter::Tensor>().GetRank());
  }
  return ret;
}
// Looks up the "generate_equations" attribute registered for this op in the
// CINN operator registry and invokes it to populate `ctx`. Aborts when the
// op has no registered generator.
void GenerateOpEquationsImpl(const ::pir::Operation* op_node,
                             const OpStmt& op_stmt,
                             config::NaiveOpEquationContext* ctx) {
  const auto& [_, inputs, outputs] = op_stmt.tuple();
  using GenerateEquationFunc =
      std::function<void(config::OpEquationContext * ctx)>;
  const auto& generate_equations =
      hlir::framework::Operator::GetAttrs<GenerateEquationFunc>(
          "generate_equations");
  const hlir::framework::Operator* cinn_op = hlir::framework::Operator::Get(
      hlir::framework::pir::CompatibleInfo::OpName(*op_node));
  CHECK(generate_equations.Find(cinn_op));
  generate_equations[cinn_op](ctx);
}
using GetArgStaticDimT = std::function<std::optional<std::int64_t>(
std::size_t tensor_idx, std::size_t dim_idx)>;
GetArgStaticDimT MakeGetArgStaticDimT(const List<Tensor>& tensors) {
return [=](std::size_t tensor_idx,
std::size_t dim_idx) -> std::optional<std::int64_t> {
if (tensor_idx >= tensors->size()) {
return std::nullopt;
}
CHECK(tensors->at(tensor_idx).Has<adapter::Tensor>());
const std::vector<int32_t> tensor_shape =
tensors->at(tensor_idx).Get<adapter::Tensor>().GetShape();
if (dim_idx >= tensor_shape.size()) {
return std::nullopt;
}
return tensor_shape.at(dim_idx);
};
}
// A reduce-acc op generates the same equations as the op it wraps.
void GenerateOpEquationsImpl(const tReduceAcc<const ::pir::Operation*>& op_node,
                             const OpStmt& op_stmt,
                             config::NaiveOpEquationContext* ctx) {
  GenerateOpEquationsImpl(op_node.value(), op_stmt, ctx);
}
// Reduce-init ops contribute no index equations.
void GenerateOpEquationsImpl(
    const tReduceInit<const ::pir::Operation*>& op_node,
    const OpStmt& op_stmt,
    config::NaiveOpEquationContext* ctx) {
  // Do nothing
}
// Dispatches equation generation on the concrete op variant held by
// `op_stmt` (plain op, reduce-init, or reduce-acc).
void GenerateOpEquations(const OpStmt& op_stmt,
                         config::NaiveOpEquationContext* ctx) {
  const auto& [op, inputs, outputs] = op_stmt.tuple();
  const auto& dispatch = [&](const auto& impl) {
    return GenerateOpEquationsImpl(impl, op_stmt, ctx);
  };
  return std::visit(dispatch, op.variant());
}
// Converts a pir operation's attributes into CINN's attribute map.
cinn::utils::AttributeMap GetOpAttrImpl(const ::pir::Operation* op_node) {
  return hlir::framework::pir::CompatibleInfo::ConvertAttributes(*op_node);
}
// Reduce-init ops carry no attributes.
cinn::utils::AttributeMap GetOpAttrImpl(
    const tReduceInit<const ::pir::Operation*>&) {
  return cinn::utils::AttributeMap{};
}
// A reduce-acc op exposes the attributes of the op it wraps.
cinn::utils::AttributeMap GetOpAttrImpl(
    const tReduceAcc<const ::pir::Operation*>& op_node) {
  return GetOpAttrImpl(op_node.value());
}
// Extracts the attribute map from whichever op variant `op_stmt` holds.
cinn::utils::AttributeMap GetOpAttr(const OpStmt& op_stmt) {
  const auto& [op_node, inputs, outputs] = op_stmt.tuple();
  const auto& get_attrs = [](const auto& impl) { return GetOpAttrImpl(impl); };
  return std::visit(get_attrs, op_node.variant());
}
// Creates an equation context sized by the op's input/output tensor ranks
// and static dim getters, then fills it with the op's equations.
std::shared_ptr<config::NaiveOpEquationContext> MakeContextAndGenerateEquations(
    const OpStmt& op_stmt) {
  const auto& [op, inputs, outputs] = op_stmt.tuple();
  const auto& ctx = std::make_shared<config::NaiveOpEquationContext>(
      MakeTensorRanks(inputs.value()),
      MakeTensorRanks(outputs.value()),
      MakeGetArgStaticDimT(inputs.value()),
      MakeGetArgStaticDimT(outputs.value()),
      GetOpAttr(op_stmt));
  GenerateOpEquations(op_stmt, ctx.get());
  return ctx;
}
std::function<std::shared_ptr<config::NaiveOpEquationContext>(const OpStmt&)>
GenerateContext4LocalOpStmt(const List<OpStmt>& op_stmts) {
using OpStmt2EquationContext =
std::unordered_map<OpStmt,
std::shared_ptr<config::NaiveOpEquationContext>>;
const auto& op_stmt2equation_ctx = std::make_shared<OpStmt2EquationContext>();
for (const auto& op_stmt : *op_stmts) {
const auto& ctx = MakeContextAndGenerateEquations(op_stmt);
CHECK(op_stmt2equation_ctx->emplace(op_stmt, ctx).second);
}
return [op_stmt2equation_ctx](const auto& op_stmt) {
return op_stmt2equation_ctx->at(op_stmt);
};
}
// Fallback comparison: only the (int64, int64) specialization below is
// supported; any other static-value type combination aborts.
template <typename T0, typename T1>
struct CompLogicalExpr {
  template <typename CompareT>
  static bool Call(const CompareT& Compare, const T0&, const T1&) {
    LOG(FATAL) << "Unimplemented";
  }
};
// Two statically-known int64 operands: apply the comparator directly.
template <>
struct CompLogicalExpr<std::int64_t, std::int64_t> {
  template <typename CompareT>
  static bool Call(const CompareT& Compare,
                   std::int64_t lhs,
                   std::int64_t rhs) {
    return Compare(lhs, rhs);
  }
};
// Unpacks the (lhs, rhs) operand pair and dispatches on the concrete types
// held by the two EquationStaticValue variants; non-int64 combinations
// abort in CompLogicalExpr's primary template.
template <typename CompareT>
bool CalculateLogicalExprImpl(
    const std::tuple<EquationStaticValue, EquationStaticValue>& tuple,
    const CompareT& Compare) {
  const auto& [lhs, rhs] = tuple;
  return std::visit(
      [&](auto&& lhs, auto&& rhs) {
        return CompLogicalExpr<
            std::decay_t<decltype(lhs)>,
            std::decay_t<decltype(rhs)>>::template Call<CompareT>(Compare,
                                                                  lhs,
                                                                  rhs);
      },
      lhs.variant(),
      rhs.variant());
}
// Builds a comparator lambda applying `op` to two std::int64_t operands.
#define MAKE_COMPARE_LAMBDA(op) \
  [](const std::int64_t lhs, const std::int64_t rhs) { return lhs op rhs; }
// Forward declaration: mutually recursive with the And/Or/Not overloads of
// ParseLogicalExprImpl below.
bool ParseLogicalExpr(const EquationStaticLogical& expr);
// Statically evaluates `lhs == rhs`.
bool ParseLogicalExprImpl(
    const EQ<EquationStaticValue, EquationStaticValue>& expr) {
  return CalculateLogicalExprImpl(expr.tuple(), MAKE_COMPARE_LAMBDA(==));
}
bool ParseLogicalExprImpl(
const LT<EquationStaticValue, EquationStaticValue>& expr) {
return CalculateLogicalExprImpl(
expr.tuple(),
[](const std::int64_t lhs, const std::int64_t rhs) { return lhs < rhs; });
}
bool ParseLogicalExprImpl(
const GT<EquationStaticValue, EquationStaticValue>& expr) {
return CalculateLogicalExprImpl(
expr.tuple(),
[](const std::int64_t lhs, const std::int64_t rhs) { return lhs > rhs; });
}
// Statically evaluates `lhs != rhs`.
bool ParseLogicalExprImpl(
    const NE<EquationStaticValue, EquationStaticValue>& expr) {
  return CalculateLogicalExprImpl(expr.tuple(), MAKE_COMPARE_LAMBDA(!=));
}
// Statically evaluates `lhs >= rhs`.
bool ParseLogicalExprImpl(
    const GE<EquationStaticValue, EquationStaticValue>& expr) {
  return CalculateLogicalExprImpl(expr.tuple(), MAKE_COMPARE_LAMBDA(>=));
}
// Statically evaluates `lhs <= rhs`.
bool ParseLogicalExprImpl(
    const LE<EquationStaticValue, EquationStaticValue>& expr) {
  return CalculateLogicalExprImpl(expr.tuple(), MAKE_COMPARE_LAMBDA(<=));
}
// Statically evaluates `lhs && rhs` (short-circuits when lhs is false).
// Fix: the original returned `ParseLogicalExpr(rhs) && ParseLogicalExpr(rhs)`,
// evaluating the right operand twice and ignoring the left one entirely.
bool ParseLogicalExprImpl(const And<Logical<EquationStaticValue>,
                                    Logical<EquationStaticValue>>& expr) {
  const auto& [lhs, rhs] = expr.tuple();
  return ParseLogicalExpr(lhs) && ParseLogicalExpr(rhs);
}
// Statically evaluates `lhs || rhs` (short-circuits when lhs is true).
// Fix: the original returned `ParseLogicalExpr(rhs) || ParseLogicalExpr(rhs)`,
// evaluating the right operand twice and ignoring the left one entirely.
bool ParseLogicalExprImpl(const Or<Logical<EquationStaticValue>,
                                   Logical<EquationStaticValue>>& expr) {
  const auto& [lhs, rhs] = expr.tuple();
  return ParseLogicalExpr(lhs) || ParseLogicalExpr(rhs);
}
// Statically evaluates `!operand`.
bool ParseLogicalExprImpl(const Not<Logical<EquationStaticValue>>& expr) {
  const auto& [unpacked_expr] = expr.tuple();
  return !ParseLogicalExpr(unpacked_expr);
}
// Dispatches on the concrete comparison/connective held by `expr`.
bool ParseLogicalExpr(const EquationStaticLogical& expr) {
  const auto& dispatch = [&](const auto& impl) {
    return ParseLogicalExprImpl(impl);
  };
  return std::visit(dispatch, expr.variant());
}
// The dim position refers to an input argument: query via GetInDim.
std::optional<std::int64_t> GetArgDimSizeImpl(
    const tIn<ArgDimPosDescriptor>& in_arg_dim_pos,
    const GetArgStaticDimT& GetInDim,
    const GetArgStaticDimT& GetOutDim) {
  return GetInDim(in_arg_dim_pos.value().tensor_idx,
                  in_arg_dim_pos.value().dim_idx);
}
// The dim position refers to an output argument: query via GetOutDim.
std::optional<std::int64_t> GetArgDimSizeImpl(
    const tOut<ArgDimPosDescriptor>& out_arg_dim_pos,
    const GetArgStaticDimT& GetInDim,
    const GetArgStaticDimT& GetOutDim) {
  return GetOutDim(out_arg_dim_pos.value().tensor_idx,
                   out_arg_dim_pos.value().dim_idx);
}
// The dim could not be resolved to any argument position: fatal error.
std::optional<std::int64_t> GetArgDimSizeImpl(
    const Undefined&,
    const GetArgStaticDimT& GetInDim,
    const GetArgStaticDimT& GetOutDim) {
  LOG(FATAL) << "position not found";
}
// Resolves a dim size by forwarding to the overload matching the
// in/out/undefined position held by `arg_dim_pos`.
std::optional<std::int64_t> GetArgDimSize(const OpArgDimPos& arg_dim_pos,
                                          const GetArgStaticDimT& GetInDim,
                                          const GetArgStaticDimT& GetOutDim) {
  const auto& dispatch = [&](const auto& impl) {
    return GetArgDimSizeImpl(impl, GetInDim, GetOutDim);
  };
  return std::visit(dispatch, arg_dim_pos.variant());
}
// Returns the static size of `dim`, located through its argument position;
// aborts when the size cannot be determined.
std::int64_t NaiveOpEquationContext::GetDimSize(const Dim& dim) const {
  const auto& dim_pos = GetArgDimPosDescriptor(dim);
  const auto& opt_size = GetArgDimSize(dim_pos, GetInDim_, GetOutDim_);
  if (!opt_size.has_value()) {
    LOG(FATAL) << "Dim not found";
  }
  return opt_size.value();
}
} // namespace cinn::adt::config
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <optional>
#include <unordered_map>
#include <vector>
#include "glog/logging.h"
#include "paddle/cinn/adt/adt.h"
#include "paddle/cinn/adt/equation.h"
#include "paddle/cinn/adt/logical.h"
#include "paddle/cinn/adt/m_expr.h"
#include "paddle/cinn/adt/op_arg_pos.h"
#include "paddle/cinn/adt/op_equation_context.h"
#include "paddle/cinn/hlir/framework/node.h"
namespace cinn::adt::config {
// Concrete OpEquationContext: allocates fresh Iterator/Dim/Index variables
// for every input/output tensor axis of one op and accumulates the equality
// equations relating them.
class NaiveOpEquationContext final : public OpEquationContext {
 public:
  NaiveOpEquationContext(const NaiveOpEquationContext&) = delete;
  NaiveOpEquationContext(NaiveOpEquationContext&&) = delete;
  // TODO(Hongyu Jia): std::optional<std::int64_t> -> Constant
  using GetArgStaticDimT = std::function<std::optional<std::int64_t>(
      std::size_t tensor_idx, std::size_t dim_idx)>;
  // Builds per-tensor iterator/dim tuples (sized by rank) and one Index per
  // tensor, then immediately equates each index with the dot-product of its
  // iterators and dims (GenerateDots).
  explicit NaiveOpEquationContext(
      const std::vector<std::uint64_t>& in_tensors_ranks,
      const std::vector<std::uint64_t>& out_tensors_ranks,
      GetArgStaticDimT GetInDim,
      GetArgStaticDimT GetOutDim,
      cinn::utils::AttributeMap attr_map_type)
      : in_tensors_ranks_(in_tensors_ranks),
        out_tensors_ranks_(out_tensors_ranks),
        GetInDim_(GetInDim),
        GetOutDim_(GetOutDim),
        equations_{},
        attr_map_type_(attr_map_type),
        fake_op_placeholder_{UniqueId::New()} {
    Init<Iterator>(&in_iterator_tuples_, in_tensors_ranks);
    Init<Iterator>(&out_iterator_tuples_, out_tensors_ranks);
    Init<Dim>(&in_dim_tuples_, in_tensors_ranks);
    Init<Dim>(&out_dim_tuples_, out_tensors_ranks);
    in_indexes_ = MakeArgIndexes(in_tensors_ranks.size());
    out_indexes_ = MakeArgIndexes(out_tensors_ranks.size());
    GenerateDots();
  }
  ~NaiveOpEquationContext() = default;
  const std::vector<std::uint64_t>& GetInTensorsRanks() const override {
    return in_tensors_ranks_;
  }
  const std::vector<std::uint64_t>& GetOutTensorsRanks() const override {
    return out_tensors_ranks_;
  }
  void Equal(const Iterator& lhs, const Iterator& rhs) override {
    this->Equal<Iterator>(lhs, rhs);
  }
  void Equal(const Index& lhs, const Index& rhs) override {
    this->Equal<Index>(lhs, rhs);
  }
  // Element-wise equality; both tuples must have the same size.
  void Equal(const IteratorTuple& lhs, const IteratorTuple& rhs) override {
    CHECK(lhs->size() == rhs->size());
    for (std::size_t i = 0; i < lhs->size(); ++i) {
      this->Equal(lhs->at(i), rhs->at(i));
    }
  }
  // Creates a fresh input iterator related to `out_tensor_iterator` through
  // a broadcast over `dim`, recording the broadcast equation.
  Iterator GetBroadcastedInputIterator(const Iterator& out_tensor_iterator,
                                       const Dim& dim) override {
    Iterator input_tensor_iterator{UniqueId::New()};
    using Function = GetBroadcastedIterator<Dim, tOut<Iterator>, tIn<Iterator>>;
    equations_->emplace_back(
        Function{dim, input_tensor_iterator, out_tensor_iterator});
    return input_tensor_iterator;
  }
  // Creates one iterator pinned to `constant`, emitting into `equations` a
  // constant equation for every tensor index of this op.
  Iterator MakeConstantIterator(std::size_t constant,
                                Equations* equations) const {
    using ConstF = ConstantFunction<tOut<Iterator>, tIn<Index>>;
    Iterator const_iter{UniqueId::New()};
    VisitEachTensorIndex([&](const auto& in_msg_index) {
      (*equations)->emplace_back(ConstF{const_iter, in_msg_index, constant});
    });
    return const_iter;
  }
  const IteratorTuple& GetInIteratorTuple(
      std::size_t input_idx) const override {
    return in_iterator_tuples_.at(input_idx);
  }
  const IteratorTuple& GetOutIteratorTuple(
      std::size_t output_idx) const override {
    return out_iterator_tuples_.at(output_idx);
  }
  const Index& GetInIndex(std::size_t input_idx) const override {
    return in_indexes_->at(input_idx);
  }
  const Index& GetOutIndex(std::size_t output_idx) const override {
    return out_indexes_->at(output_idx);
  }
  const DimTuple& GetInDimTuple(std::size_t input_idx) const override {
    return in_dim_tuples_.at(input_idx);
  }
  const DimTuple& GetOutDimTuple(std::size_t output_idx) const override {
    return out_dim_tuples_.at(output_idx);
  }
  const List<Index>& in_indexes() const { return in_indexes_; }
  const List<Index>& out_indexes() const { return out_indexes_; }
  const Equations& equations() const { return equations_; }
  void AddEquations(const Equations& equations) {
    for (const auto& equation : *equations) {
      equations_->emplace_back(equation);
    }
  }
  const FakeOpPlaceHolder& fake_op_placeholder() const {
    return fake_op_placeholder_;
  }
  // Visits all input indexes, then all output indexes.
  template <typename DoEachT>
  void VisitEachTensorIndex(const DoEachT& DoEach) const {
    VisitEachInputTensorIndex(DoEach);
    VisitEachOutputTensorIndex(DoEach);
  }
  template <typename DoEachT>
  void VisitEachInputTensorIndex(const DoEachT& DoEach) const {
    for (const auto& in_index : *in_indexes_) {
      DoEach(in_index);
    }
  }
  template <typename DoEachT>
  void VisitEachOutputTensorIndex(const DoEachT& DoEach) const {
    for (const auto& out_index : *out_indexes_) {
      DoEach(out_index);
    }
  }
  template <typename DoEachT>
  void VisitEachEquation(const DoEachT& DoEach) const {
    for (const auto& equation : *equations_) {
      DoEach(equation);
    }
  }
  // Calls DoEach(is_out, arg_idx, axis) for every axis of every input
  // argument, then every output argument.
  template <typename DoEachT>
  void VisitEachArgPos(const DoEachT& DoEach) const {
    for (std::size_t arg_idx = 0; arg_idx < in_tensors_ranks_.size();
         ++arg_idx) {
      for (std::size_t axis = 0; axis < in_tensors_ranks_.at(arg_idx); ++axis) {
        DoEach(/*is_out*/ false, arg_idx, axis);
      }
    }
    for (std::size_t arg_idx = 0; arg_idx < out_tensors_ranks_.size();
         ++arg_idx) {
      for (std::size_t axis = 0; axis < out_tensors_ranks_.at(arg_idx);
           ++axis) {
        DoEach(/*is_out*/ true, arg_idx, axis);
      }
    }
  }
  // Locates `index` among this op's args: input slot, output slot, or
  // Undefined when it belongs to neither.
  OpArgPos GetOpArgPos(const Index& index) const {
    const auto& input_pos = FindPos(in_indexes_, index);
    if (input_pos.has_value()) {
      return tIn<std::size_t>{input_pos.value()};
    }
    const auto& output_pos = FindPos(out_indexes_, index);
    if (output_pos.has_value()) {
      return tOut<std::size_t>{output_pos.value()};
    }
    return Undefined{};
  }
  std::int64_t GetDimSize(const Dim& dim) const;
  Dim GetDim(bool is_out, std::size_t arg_idx, std::size_t axis) const {
    if (is_out) {
      return out_dim_tuples_.at(arg_idx)->at(axis);
    } else {
      return in_dim_tuples_.at(arg_idx)->at(axis);
    }
  }
  // Static size of one axis, fetched through GetInDim_/GetOutDim_; aborts
  // when the size is not statically known.
  Constant GetDimSize(bool is_out,
                      std::size_t arg_idx,
                      std::size_t axis) const {
    const auto* Get = (is_out ? &GetOutDim_ : &GetInDim_);
    const auto& opt_dim = (*Get)(arg_idx, axis);
    CHECK(opt_dim.has_value());
    return opt_dim.value();
  }
  // Locates `dim` among this op's argument dims, mirroring GetOpArgPos.
  OpArgDimPos GetArgDimPosDescriptor(const Dim& dim) const {
    const auto& input_pos = FindArgDimPos(in_dim_tuples_, dim);
    if (input_pos.has_value()) {
      return tIn<ArgDimPosDescriptor>{input_pos.value()};
    }
    const auto& output_pos = FindArgDimPos(out_dim_tuples_, dim);
    if (output_pos.has_value()) {
      return tOut<ArgDimPosDescriptor>{output_pos.value()};
    }
    return Undefined{};
  }
  void Print() const;
 private:
  // Fills `vec` with one tuple per tensor, each holding rank-many fresh
  // value_type variables (Iterator or Dim).
  template <typename value_type, typename ContainerT>
  void Init(ContainerT* vec, const std::vector<std::uint64_t>& tensors_ranks) {
    for (std::size_t i = 0; i < tensors_ranks.size(); ++i) {
      vec->push_back(typename ContainerT::value_type{});
      for (std::size_t j = 0; j < tensors_ranks.at(i); ++j) {
        vec->at(i)->emplace_back(value_type{UniqueId::New()});
      }
    }
  }
  // Creates a fresh Index and records both IndexDot and IndexUnDot
  // equations, making the index <-> iterators mapping invertible.
  Index IndexDot(const IteratorTuple& iterator_tuple,
                 const DimTuple& dim_tuple) {
    CHECK(iterator_tuple->size() == dim_tuple->size());
    Index index{UniqueId::New()};
    equations_->emplace_back(
        adt::IndexDot<List<Dim>, tOut<Index>, tIn<List<Iterator>>>{
            dim_tuple, index, iterator_tuple});
    equations_->emplace_back(
        adt::IndexUnDot<List<Dim>, tOut<List<Iterator>>, tIn<Index>>{
            dim_tuple, iterator_tuple, index});
    return index;
  }
  static List<Index> MakeArgIndexes(std::size_t num_args) {
    List<Index> ret{};
    for (std::size_t i = 0; i < num_args; ++i) {
      Index index{UniqueId::New()};
      ret->emplace_back(index);
    }
    return ret;
  }
  // Equates each argument index with the dot of its iterator/dim tuples.
  void GenerateDots() {
    for (std::size_t i = 0; i < in_tensors_ranks_.size(); ++i) {
      Equal(GetInIndex(i), IndexDot(GetInIteratorTuple(i), GetInDimTuple(i)));
    }
    for (std::size_t i = 0; i < out_tensors_ranks_.size(); ++i) {
      Equal(GetOutIndex(i),
            IndexDot(GetOutIteratorTuple(i), GetOutDimTuple(i)));
    }
  }
  // Records equality as two Identity equations, one per direction.
  template <typename T>
  void Equal(const T& lhs, const T& rhs) {
    equations_->emplace_back(Identity<tOut<T>, tIn<T>>(lhs, rhs));
    equations_->emplace_back(Identity<tOut<T>, tIn<T>>(rhs, lhs));
  }
  // Linear scan for `index`; returns its position or std::nullopt.
  static std::optional<std::size_t> FindPos(const List<Index>& vector,
                                            const Index& index) {
    for (std::size_t i = 0; i < vector->size(); ++i) {
      if (vector->at(i) == index) {
        return i;
      }
    }
    return std::nullopt;
  }
  // Linear scan for `dim` across all tuples; returns (tensor, axis) or
  // std::nullopt.
  static std::optional<ArgDimPosDescriptor> FindArgDimPos(
      const std::vector<DimTuple>& dim_tuples, const Dim& dim) {
    for (std::size_t i = 0; i < dim_tuples.size(); ++i) {
      for (std::size_t j = 0; j < dim_tuples.at(i)->size(); ++j) {
        if (dim_tuples.at(i)->at(j) == dim) {
          return ArgDimPosDescriptor{i, j};
        }
      }
    }
    return std::nullopt;
  }
  const utils::Attribute& GetAttribute(const std::string& name) const {
    const auto& iter = attr_map_type_.find(name);
    CHECK(iter != attr_map_type_.end())
        << "Can't find Attribute with this name";
    return iter->second;
  }
  std::vector<std::uint64_t> in_tensors_ranks_;
  std::vector<std::uint64_t> out_tensors_ranks_;
  GetArgStaticDimT GetInDim_;
  GetArgStaticDimT GetOutDim_;
  Equations equations_;
  const cinn::utils::AttributeMap attr_map_type_;
  FakeOpPlaceHolder fake_op_placeholder_;
  std::vector<IteratorTuple> in_iterator_tuples_;
  std::vector<IteratorTuple> out_iterator_tuples_;
  std::vector<DimTuple> in_dim_tuples_;
  std::vector<DimTuple> out_dim_tuples_;
  List<Index> in_indexes_;
  List<Index> out_indexes_;
};
std::function<std::shared_ptr<config::NaiveOpEquationContext>(const OpStmt&)>
GenerateContext4LocalOpStmt(const List<OpStmt>& op_stmts);
} // namespace cinn::adt::config
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/adt/direction_equation_generator.h"
namespace cinn::adt {
// DirectionEquationGenerator variant without control-direction state; all
// overrides are declared here and defined out of line.
class NoCtrlDirectionEquationGenerator final
    : public DirectionEquationGenerator {
 public:
  NoCtrlDirectionEquationGenerator(const NoCtrlDirectionEquationGenerator&) =
      delete;
  NoCtrlDirectionEquationGenerator(NoCtrlDirectionEquationGenerator&&) = delete;
  NoCtrlDirectionEquationGenerator();
  Equations GetDirectionEquations() const override;
  std::function<const OpStmt*(const FakeOpPlaceHolder&)>
  MakeGetterOpStmt4OpPlaceHolder() const override;
  std::optional<Index> OutMsgIndex4InMsgIndex(
      const Index& index) const override;
 private:
};
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/adt/adt.h"
#include "paddle/cinn/adt/inline_translator_trait.h"
#include "paddle/cinn/adt/m_expr.h"
#include "paddle/cinn/adt/tree.h"
namespace cinn::adt {
// Rewrites a Tree<MapT, Store<Tensor, OpCall<Load>>> into a
// Tree<MapT, Store<Tensor, Tree<OpCallT, Load>>> without inlining: map
// structure is copied verbatim and every leaf op-call becomes a one-level
// OpExpr tree whose operands are Load leaves.
template <template <typename> class MapT,
          template <typename>
          class OpCallT,
          typename TensorT>
struct NoInlineTranslator final {
  using SrcLeaf = Store<TensorT, OpCallT<Load<TensorT>>>;
  using OpExpr = Tree<OpCallT, Load<TensorT>>;
  using DstLeaf = Store<TensorT, OpExpr>;
  using SrcTree = Tree<MapT, SrcLeaf>;
  using DstTree = Tree<MapT, DstLeaf>;
  static DstTree Call(const SrcTree& src_tree) { return Translate(src_tree); }
 private:
  // Dispatches on whether the node is an inner map node or a leaf store.
  static DstTree Translate(const SrcTree& src_tree) {
    return std::visit([&](const auto& impl) { return TranslateImpl(impl); },
                      src_tree.variant());
  }
  static DstTree TranslateImpl(const MapT<SrcTree>& src_map) {
    return DstTree{TranslateMap(src_map)};
  }
  // Recursively translates the children of an inner node, keeping the map's
  // own payload unchanged (see InlineTranslatorTrait<MapT>::ConvertMap).
  static MapT<DstTree> TranslateMap(const MapT<SrcTree>& src_map) {
    const List<SrcTree> src_children =
        InlineTranslatorTrait<MapT>::GetTreeInnerNodeChildren(src_map);
    const List<DstTree> dst_children = TranslateList(src_children);
    return InlineTranslatorTrait<MapT>::ConvertMap(src_map, dst_children);
  }
  static List<DstTree> TranslateList(const List<SrcTree>& src_children) {
    List<DstTree> ret{};
    for (const auto& src_child : *src_children) {
      ret->emplace_back(Translate(src_child));
    }
    return ret;
  }
  static DstTree TranslateImpl(const SrcLeaf& src_leaf) {
    return DstTree{TranslateLeaf(src_leaf)};
  }
  // using SrcLeaf = Store<TensorT, OpCallT<Load<TensorT>>>;
  // using DstLeaf = Store<TensorT, OpExpr>;
  // Wraps each Load operand as an OpExpr leaf and rebuilds the op call so
  // the stored expression becomes a one-level tree.
  static DstLeaf TranslateLeaf(const SrcLeaf& src_leaf) {
    const auto& [tensor, op_call] = src_leaf.tuple();
    const List<Load<TensorT>>& src_loads =
        InlineTranslatorTrait<OpCallT>::GetTreeInnerNodeChildren(op_call);
    List<OpExpr> dst_loads{};
    for (const auto& src_load : *src_loads) {
      dst_loads->emplace_back(src_load);
    }
    OpCallT<OpExpr> dst_op_call =
        InlineTranslatorTrait<OpCallT>::ConvertMap(op_call, dst_loads);
    OpExpr dst_op_call_tree = dst_op_call;
    return DstLeaf{tensor, dst_op_call_tree};
  }
};
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/adt/adt.h"
namespace cinn::adt {
// Identifies one axis of one argument tensor: which tensor (tensor_idx)
// and which of its dimensions (dim_idx).
struct ArgDimPosDescriptor {
  ArgDimPosDescriptor(std::size_t t_idx, std::size_t d_idx)
      : tensor_idx(t_idx), dim_idx(d_idx) {}
  std::size_t tensor_idx;
  std::size_t dim_idx;
};
// Position of an op argument: input slot, output slot, or unknown.
DEFINE_ADT_UNION(OpArgPos, Undefined, tIn<std::size_t>, tOut<std::size_t>);
// Position of an argument's dim: on an input, on an output, or unknown.
DEFINE_ADT_UNION(OpArgDimPos,
                 Undefined,
                 tIn<ArgDimPosDescriptor>,
                 tOut<ArgDimPosDescriptor>);
} // namespace cinn::adt
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <vector>
#include "glog/logging.h"
#include "paddle/cinn/adt/adt.h"
#include "paddle/cinn/adt/equation.h"
#include "paddle/cinn/hlir/framework/node.h"
namespace cinn::adt::config {
// Abstract interface that op "generate_equations" callbacks program
// against; see NaiveOpEquationContext for the concrete implementation.
class OpEquationContext {
 public:
  OpEquationContext(const OpEquationContext&) = delete;
  OpEquationContext(OpEquationContext&&) = delete;
  virtual ~OpEquationContext() {}
  // Ranks (axis counts) of the input/output tensors, in argument order.
  virtual const std::vector<std::uint64_t>& GetInTensorsRanks() const = 0;
  virtual const std::vector<std::uint64_t>& GetOutTensorsRanks() const = 0;
  // Assert equality between iterators / indexes / whole iterator tuples.
  virtual void Equal(const Iterator& lhs, const Iterator& rhs) = 0;
  virtual void Equal(const Index& lhs, const Index& rhs) = 0;
  virtual void Equal(const IteratorTuple& lhs, const IteratorTuple& rhs) = 0;
  virtual Iterator GetBroadcastedInputIterator(const Iterator& out_iterator,
                                               const Dim& dim) = 0;
  virtual const IteratorTuple& GetInIteratorTuple(
      std::size_t input_idx) const = 0;
  virtual const IteratorTuple& GetOutIteratorTuple(
      std::size_t output_idx) const = 0;
  virtual const Index& GetInIndex(std::size_t input_idx) const = 0;
  virtual const Index& GetOutIndex(std::size_t output_idx) const = 0;
  virtual const DimTuple& GetInDimTuple(std::size_t input_idx) const = 0;
  virtual const DimTuple& GetOutDimTuple(std::size_t output_idx) const = 0;
  // Typed access to an op attribute by name.
  template <typename T>
  const T& Attr(const std::string& name) const {
    return absl::get<T>(GetAttribute(name));
  }
 protected:
  OpEquationContext() = default;
  virtual const utils::Attribute& GetAttribute(
      const std::string& name) const = 0;
};
} // namespace cinn::adt::config
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include "paddle/cinn/adt/adt.h"
#include "paddle/cinn/adt/direction_equation_generator.h"
#include "paddle/cinn/adt/equation.h"
#include "paddle/cinn/adt/equation_solver.h"
#include "paddle/cinn/adt/equation_util.h"
#include "paddle/cinn/adt/index_expr_infer_context.h"
#include "paddle/cinn/adt/partition_op_stmts.h"
#include "paddle/cinn/adt/print.h"
namespace cinn::adt {
// Removes and returns the last candidate anchor index (LIFO order).
AnchorIndex PickThenEraseAnchorIndex(
    std::vector<AnchorIndex>* candidate_anchor_indexes) {
  AnchorIndex picked{candidate_anchor_indexes->back()};
  candidate_anchor_indexes->pop_back();
  return picked;
}
std::vector<AnchorIndex> InitCandidateAnchorIndex(
const EquationCtx4OpStmtT& EquationCtx4OpStmt,
const List<OpStmt>& op_stmts) {
std::vector<AnchorIndex> ret{};
for (const auto& op_stmt : *op_stmts) {
const auto& equation_ctx = EquationCtx4OpStmt(op_stmt);
equation_ctx->VisitEachTensorIndex(
[&](const auto& tensor_index) { ret.emplace_back(tensor_index); });
}
return ret;
}
// Walks the equation graph starting from `anchor_index`, collecting every
// op statement reachable through FakeOpPlaceHolder variables. Returns the
// unique op statement (if any) that has `anchor_index` among its own args,
// plus all visited op statements. Optionally records visited variables and
// functions into the caller-supplied sets (either may be nullptr).
std::pair<std::optional<OpStmt>, List<OpStmt>> FindVisitedOpStmts(
    const AnchorIndex& anchor_index,
    const GraphView& equation_graph,
    const std::function<const OpStmt*(const FakeOpPlaceHolder&)>&
        OpStmt4OpPlaceHolder,
    const EquationCtx4OpStmtT& EquationCtx4OpStmt,
    std::unordered_set<Variable>* visited_variables,
    std::unordered_set<const void*>* visited_functions) {
  std::optional<OpStmt> opt_anchor_op_stmt{std::nullopt};
  List<OpStmt> visited_op_stmts{};
  // Marks op_stmt as the anchor owner when the anchor index is one of its
  // args; the CHECK enforces that at most one visited op qualifies.
  const auto& TrySetAnchorOpStmt = [&](const auto& op_stmt) {
    const auto& op_arg_pos =
        EquationCtx4OpStmt(op_stmt)->GetOpArgPos(anchor_index);
    const bool valid = !op_arg_pos.template Has<Undefined>();
    if (valid) {
      CHECK(!opt_anchor_op_stmt.has_value());
      opt_anchor_op_stmt = op_stmt;
    }
  };
  // Called once per visited variable during the graph walk.
  const auto& DoEach = [&](const Variable variable) {
    if (visited_variables != nullptr) {
      visited_variables->insert(variable);
    }
    if (variable.Has<FakeOpPlaceHolder>()) {
      const auto& fake_op_placeholder = variable.Get<FakeOpPlaceHolder>();
      const auto& op_stmt = *OpStmt4OpPlaceHolder(fake_op_placeholder);
      visited_op_stmts->emplace_back(op_stmt);
      TrySetAnchorOpStmt(op_stmt);
    }
  };
  // Called once per visited equation function during the graph walk.
  const auto& DoEachFunction = [&](const Function* function) {
    if (visited_functions != nullptr) {
      visited_functions->insert(GetFunctionDataPtr(*function));
    }
  };
  std::array<AnchorIndex, 1> starts{anchor_index};
  equation_graph(starts.begin(), starts.end(), DoEach, DoEachFunction);
  return std::pair{opt_anchor_op_stmt, visited_op_stmts};
}
// Forwards every equation of every op_stmt's equation context to DoEach.
template <typename DoEachT>
void VisitEachEquation(const List<OpStmt>& op_stmts,
                       const EquationCtx4OpStmtT& EquationCtx4OpStmt,
                       const DoEachT& DoEach) {
  for (const auto& op_stmt : *op_stmts) {
    EquationCtx4OpStmt(op_stmt)->VisitEachEquation(DoEach);
  }
}
// Gathers the equations of all op_stmts into a single graph view.
GraphView MakeOpsGraphViewForPartition(
    const EquationCtx4OpStmtT& EquationCtx4OpStmt,
    const List<OpStmt>& op_stmts) {
  Equations all_equations{};
  const auto& Collect = [&](const auto& equation) {
    all_equations->emplace_back(equation);
  };
  VisitEachEquation(op_stmts, EquationCtx4OpStmt, Collect);
  return Graph::New(all_equations)->GetGraphView();
}
// Reports every tensor index of every op_stmt to DoEach, flagged with
// whether the index belongs to an output tensor (inputs first, then outputs).
template <typename DoEachT>
void VisitEachIndexAndAsOutput(const List<OpStmt>& op_stmts,
                               const EquationCtx4OpStmtT& EquationCtx4OpStmt,
                               const DoEachT& DoEach) {
  for (const auto& op_stmt : *op_stmts) {
    const auto& ctx = EquationCtx4OpStmt(op_stmt);
    const auto& DoEachInput = [&](const auto& index) {
      DoEach(op_stmt, index, tOut<bool>{false});
    };
    const auto& DoEachOutput = [&](const auto& index) {
      DoEach(op_stmt, index, tOut<bool>{true});
    };
    ctx->VisitEachInputTensorIndex(DoEachInput);
    ctx->VisitEachOutputTensorIndex(DoEachOutput);
  }
}
// Builds two getters:
//  - *AsOutput4Index: whether a tensor index belongs to an output tensor
//    (backed by a shared cache built once over all op_stmts);
//  - *OutMsgIndex4InMsgIndex: maps an in-message index to its out-message
//    counterpart via the direction equation generator (must exist).
void MakeGetters4Indexes(
    const List<OpStmt>& op_stmts,
    const EquationCtx4OpStmtT& EquationCtx4OpStmt,
    const std::shared_ptr<DirectionEquationGenerator>&
        direction_equation_generator,
    std::function<tOut<bool>(const Index&)>* AsOutput4Index,
    std::function<Index(const Index&)>* OutMsgIndex4InMsgIndex) {
  auto index2as_output =
      std::make_shared<std::unordered_map<Index, tOut<bool>>>();
  VisitEachIndexAndAsOutput(
      op_stmts,
      EquationCtx4OpStmt,
      [&](const auto& op_stmt, const auto& index, const auto& as_output) {
        // Every tensor index must be registered exactly once.
        CHECK(index2as_output->emplace(index, as_output).second);
      });
  *AsOutput4Index = [index2as_output](const Index& index) {
    return index2as_output->at(index);
  };
  *OutMsgIndex4InMsgIndex =
      [direction_equation_generator](const Index& index) -> Index {
    const auto& opt_out_msg_index =
        direction_equation_generator->OutMsgIndex4InMsgIndex(index);
    CHECK(opt_out_msg_index.has_value());
    return opt_out_msg_index.value();
  };
}
// Maps each tensor to every equation index (input or output position of any
// op_stmt) that refers to it, in op_stmt traversal order.
std::unordered_map<Tensor, std::vector<Index>> GenerateSameTensor2Indexes(
    const List<OpStmt>& op_stmts,
    const EquationCtx4OpStmtT& EquationCtx4OpStmt) {
  std::unordered_map<Tensor, std::vector<Index>> tensor2indexes{};
  for (const auto& op_stmt : *op_stmts) {
    const auto& ctx = EquationCtx4OpStmt(op_stmt);
    const auto& [op, op_inputs, op_outputs] = op_stmt.tuple();
    const auto& in_tensors = op_inputs.value();
    for (std::size_t i = 0; i < in_tensors->size(); ++i) {
      tensor2indexes[in_tensors->at(i)].emplace_back(ctx->GetInIndex(i));
    }
    const auto& out_tensors = op_outputs.value();
    for (std::size_t i = 0; i < out_tensors->size(); ++i) {
      tensor2indexes[out_tensors->at(i)].emplace_back(ctx->GetOutIndex(i));
    }
  }
  return tensor2indexes;
}
// Invokes DoEach once per tensor, passing all indexes that refer to it.
template <typename DoEachT>
void VisitIndexesOfSameTensor(const List<OpStmt>& op_stmts,
                              const EquationCtx4OpStmtT& EquationCtx4OpStmt,
                              const DoEachT& DoEach) {
  for (const auto& [tensor, indexes] :
       GenerateSameTensor2Indexes(op_stmts, EquationCtx4OpStmt)) {
    DoEach(indexes);
  }
}
// DoEachT is like void(*)(Index producer_index, Index consumer_index)
// Pairs every index with its most recent producer (a preceding write).
// If the first index is a read, the tensor must never be written here.
template <typename AsOutput4IndexT, typename DoEachT>
void VisitProducerConsumerPair(const std::vector<Index>& tensor_indexes,
                               const AsOutput4IndexT& AsOutput4Index,
                               const DoEachT& DoEach) {
  CHECK(!tensor_indexes.empty());
  const bool write_first = AsOutput4Index(tensor_indexes.at(0)).value();
  if (!write_first) {
    // Read-first tensors must be read-only throughout.
    for (const auto& tensor_index : tensor_indexes) {
      CHECK(!AsOutput4Index(tensor_index).value());
    }
    return;
  }
  auto producer = tensor_indexes.at(0);
  for (std::size_t i = 1; i < tensor_indexes.size(); ++i) {
    const auto& current = tensor_indexes.at(i);
    DoEach(producer, current);
    if (AsOutput4Index(current).value()) {
      // A new write becomes the producer for subsequent indexes.
      producer = current;
    }
  }
}
// DoEachT is like void(*)(Index producer_index, Index consumer_index)
// For each tensor, emits its producer/consumer index pairs to DoEach.
template <typename AsOutput4IndexT, typename DoEachT>
void VisitProducerConsumerTensorIndexPair(
    const List<OpStmt>& op_stmts,
    const EquationCtx4OpStmtT& EquationCtx4OpStmt,
    const AsOutput4IndexT& AsOutput4Index,
    const DoEachT& DoEach) {
  const auto& VisitOneTensor = [&](const auto& indexes) {
    VisitProducerConsumerPair(indexes, AsOutput4Index, DoEach);
  };
  VisitIndexesOfSameTensor(op_stmts, EquationCtx4OpStmt, VisitOneTensor);
}
// Records an Identity equation connecting the two indexes. Note the argument
// order: IdentityConnect takes (out, in), the reverse of this signature.
void CollectIdentity(const Index& in_tensor_index,
                     const Index& out_tensor_index,
                     Equations* equations) {
  IdentityConnect(out_tensor_index, in_tensor_index, equations);
}
// Builds a graph view of identity equations connecting producer/consumer
// tensor indexes (in both directions, via their out-message counterparts).
GraphView MakeParametersGraphViewForPartition(
    const EquationCtx4OpStmtT& EquationCtx4OpStmt,
    const List<OpStmt>& op_stmts,
    const std::shared_ptr<DirectionEquationGenerator>&
        direction_equation_generator) {
  std::function<tOut<bool>(const Index&)> AsOutput4Index{};
  std::function<Index(const Index&)> OutMsgIndex4InMsgIndex{};
  MakeGetters4Indexes(op_stmts,
                      EquationCtx4OpStmt,
                      direction_equation_generator,
                      &AsOutput4Index,
                      &OutMsgIndex4InMsgIndex);
  Equations equations{};
  VisitProducerConsumerTensorIndexPair(
      op_stmts,
      EquationCtx4OpStmt,
      AsOutput4Index,
      [&](const auto& producer_index, const auto& consumer_index) {
        // Connect both directions so the partition graph is undirected-ish.
        CollectIdentity(OutMsgIndex4InMsgIndex(producer_index),
                        consumer_index,
                        &equations);
        CollectIdentity(OutMsgIndex4InMsgIndex(consumer_index),
                        producer_index,
                        &equations);
      });
  return Graph::New(equations)->GetGraphView();
}
// Merges three graph views into the global partitioning graph:
// per-op equations, direction equations, and producer/consumer identities.
GraphView MakeGlobalEquationGraphViewForPartition(
    const EquationCtx4OpStmtT& EquationCtx4OpStmt,
    const List<OpStmt>& op_stmts,
    const std::shared_ptr<DirectionEquationGenerator>&
        direction_equation_generator) {
  const auto& ops_view =
      MakeOpsGraphViewForPartition(EquationCtx4OpStmt, op_stmts);
  const auto& direction_view =
      Graph::New(direction_equation_generator->GetDirectionEquations())
          ->GetGraphView();
  const auto& params_view = MakeParametersGraphViewForPartition(
      EquationCtx4OpStmt, op_stmts, direction_equation_generator);
  return ops_view.Merge(direction_view).Merge(params_view);
}
// Visits every tensor index of every op_stmt inside the anchor group.
template <typename DoEachT>
void VisitTensorIndex(const AnchorGroup& igroup_spec, const DoEachT& DoEach) {
  for (const auto& group_op_stmt : *igroup_spec.op_stmts) {
    igroup_spec.EquationCtx4OpStmt(group_op_stmt)
        ->VisitEachTensorIndex(DoEach);
  }
}
// Drops every existing group whose anchor is a tensor index already covered
// by igroup_spec (such groups are subsumed by the larger one).
void CleanSmallAnchorGroups(
    const AnchorGroup& igroup_spec,
    std::unordered_map<AnchorIndex, AnchorGroup>* anchor_index2igroup_spec) {
  const auto& EraseSubsumed = [&](const auto& tensor_index) {
    anchor_index2igroup_spec->erase(tensor_index);
  };
  VisitTensorIndex(igroup_spec, EraseSubsumed);
}
// Removes groups subsumed by igroup_spec, then registers igroup_spec under
// its own anchor index; the anchor must not already be present afterwards.
// (NOTE(review): "Updata" looks like a typo for "Update"; kept — public name.)
void UpdataAnchorIndex2AnchorGroup(
    const AnchorGroup& igroup_spec,
    std::unordered_map<AnchorIndex, AnchorGroup>* anchor_index2igroup_spec) {
  CleanSmallAnchorGroups(igroup_spec, anchor_index2igroup_spec);
  const bool inserted =
      anchor_index2igroup_spec->emplace(igroup_spec.anchor_index, igroup_spec)
          .second;
  CHECK(inserted);
}
// Retires every candidate anchor index covered by igroup_spec's tensor
// indexes (all occurrences, since candidates may contain duplicates).
void EraseCandidateAnchorIndexes(
    const AnchorGroup& igroup_spec,
    std::vector<AnchorIndex>* candidate_anchor_indexes) {
  VisitTensorIndex(igroup_spec, [&](const auto& tensor_index) {
    // Single-pass erase-remove instead of the previous repeated
    // find-from-begin + erase loop, which was O(n^2) per tensor index.
    candidate_anchor_indexes->erase(
        std::remove(candidate_anchor_indexes->begin(),
                    candidate_anchor_indexes->end(),
                    tensor_index),
        candidate_anchor_indexes->end());
  });
}
// Greedily partitions op_stmts into anchor groups: repeatedly pick a
// candidate anchor index, find every op_stmt reachable from it in the global
// equation graph, and register those op_stmts as one group. Groups whose
// anchors are covered by a later group are dropped, and all anchor candidates
// covered by a new group are retired. Finally CHECKs that every op_stmt was
// claimed by some group.
std::unordered_map<AnchorIndex, AnchorGroup> PartitionOpStmtsIntoAnchorGroups(
    std::vector<AnchorIndex>* candidate_anchor_indexes,
    const EquationCtx4OpStmtT& EquationCtx4OpStmt,
    const List<OpStmt>& op_stmts,
    const std::shared_ptr<DirectionEquationGenerator>&
        direction_equation_generator) {
  CHECK(!op_stmts->empty());
  std::unordered_map<AnchorIndex, AnchorGroup> anchor_index2igroup_spec{};
  const auto& OpStmt4OpPlaceHolder =
      direction_equation_generator->MakeGetterOpStmt4OpPlaceHolder();
  const auto& equation_graph_view = MakeGlobalEquationGraphViewForPartition(
      EquationCtx4OpStmt, op_stmts, direction_equation_generator);
  // Tracks the union of all op_stmts claimed by any group (for the final
  // completeness check).
  std::unordered_set<OpStmt> all_visited_op_stmts{};
  while (!candidate_anchor_indexes->empty()) {
    AnchorIndex anchor_index =
        PickThenEraseAnchorIndex(candidate_anchor_indexes);
    const auto& [opt_anchor_op_stmt, visited_op_stmts] =
        FindVisitedOpStmts(anchor_index,
                           equation_graph_view,
                           OpStmt4OpPlaceHolder,
                           EquationCtx4OpStmt,
                           /*visited_variables=*/nullptr,
                           /*visited_functions=*/nullptr);
    if (visited_op_stmts->empty()) {
      // Nothing reachable from this anchor; try the next candidate.
      continue;
    }
    CHECK(opt_anchor_op_stmt.has_value());
    all_visited_op_stmts.insert(visited_op_stmts->begin(),
                                visited_op_stmts->end());
    AnchorGroup igroup_spec{anchor_index,
                            opt_anchor_op_stmt.value(),
                            visited_op_stmts,
                            EquationCtx4OpStmt};
    // Register the new group and retire every anchor candidate it covers.
    UpdataAnchorIndex2AnchorGroup(igroup_spec, &anchor_index2igroup_spec);
    EraseCandidateAnchorIndexes(igroup_spec, candidate_anchor_indexes);
  }
  CHECK_EQ(all_visited_op_stmts.size(), op_stmts->size())
      << "Some fake_op_placeholders are not visited";
  return anchor_index2igroup_spec;
}
// Logs this group's anchor index at VLOG level 1 for debugging. Printing the
// equations themselves is currently disabled (commented-out ctx->Print()).
void AnchorGroup::PrintEquations() const {
  // `ctx` is only used by the disabled Print() call below.
  const auto& ctx = EquationCtx4OpStmt(op_stmt);
  VLOG(1) << "anchor_index: ";
  VLOG(1) << ToTxtString(anchor_index);
  VLOG(1) << "AnchorGroup.equations: ";
  // ctx->Print();
}
std::unordered_map<Variable, const Value> MakeAnchorIndex2Ok(
const AnchorGroup& igroup_spec) {
return {{igroup_spec.anchor_index, Ok{}}};
}
// Applies DoEach to each op_stmt in the group, stopping early as soon as
// DoEach returns a true tBreak. Returns the break flag that ended the loop.
template <typename DoEachT>
tBreak<bool> AggregateAnchorGroupOpStmt(const AnchorGroup& igroup_spec,
                                        const DoEachT& DoEach) {
  for (const auto& op_stmt : *igroup_spec.op_stmts) {
    const tBreak<bool> should_break = DoEach(op_stmt);
    if (should_break.value()) {
      return should_break;
    }
  }
  return tBreak<bool>{false};
}
// Verifies that, starting from the group's anchor index, equation inference
// reaches every op_stmt in the group (i.e. each op's fake_op_placeholder
// obtains a value). CHECK-fails otherwise.
void CheckEquationSolvable(
    const AnchorGroup& igroup_spec,
    const std::shared_ptr<const EquationFunctionConstantsProvider>&
        constants_provider,
    const std::shared_ptr<DirectionEquationGenerator>&
        direction_equation_generator) {
  const auto& equation_graph_view =
      MakeGlobalEquationGraphViewForPartition(igroup_spec.EquationCtx4OpStmt,
                                              igroup_spec.op_stmts,
                                              direction_equation_generator);
  // Seed the inference context with the anchor index marked as solved.
  const auto& init_var2value = MakeAnchorIndex2Ok(igroup_spec);
  IndexExprInferContext ctx{init_var2value, constants_provider};
  // An op_stmt counts as solved once its fake_op_placeholder has a value.
  const auto& IsOpSolved = [&](const auto& op_stmt) {
    const auto& equation_ctx = *igroup_spec.EquationCtx4OpStmt(op_stmt);
    const auto& fake_op_placeholder = equation_ctx.fake_op_placeholder();
    return ctx.HasValue(fake_op_placeholder);
  };
  // Propagate values through the graph, then assert full coverage.
  CheckEquationsSolvable(equation_graph_view, igroup_spec.anchor_index, &ctx);
  AggregateAnchorGroupOpStmt(igroup_spec, [&](const auto& op_stmt) {
    CHECK(IsOpSolved(op_stmt));
    return tBreak<bool>{false};
  });
}
std::function<std::size_t(const OpStmt&)> MakeGetterOrderValue4OpStmt(
const List<OpStmt>& op_stmts) {
using OpStmt2OrderValue = std::unordered_map<OpStmt, std::size_t>;
const auto& op_stmt2order_value = std::make_shared<OpStmt2OrderValue>();
for (std::size_t idx = 0; idx < op_stmts->size(); ++idx) {
CHECK(op_stmt2order_value->emplace(op_stmts->at(idx), idx).second);
}
return [op_stmt2order_value](const auto& op_stmt) {
return op_stmt2order_value->at(op_stmt);
};
}
// Mutable visitation: DoEach receives a pointer to each anchor group.
template <typename DoEachT>
void VisitEachAnchorGroup(
    std::unordered_map<AnchorIndex, AnchorGroup>* anchor_index2igroup_spec,
    const DoEachT& DoEach) {
  for (auto& entry : *anchor_index2igroup_spec) {
    DoEach(&entry.second);
  }
}
// Read-only visitation: DoEach receives each anchor group by const reference.
template <typename DoEachT>
void VisitEachAnchorGroup(const std::unordered_map<AnchorIndex, AnchorGroup>&
                              anchor_index2igroup_spec,
                          const DoEachT& DoEach) {
  for (const auto& entry : anchor_index2igroup_spec) {
    DoEach(entry.second);
  }
}
// Restores original program order inside each group's op_stmt list.
void SortAnchorGroupOpStmts(
    std::unordered_map<AnchorIndex, AnchorGroup>* anchor_index2igroup_spec,
    const std::function<std::size_t(const OpStmt&)>& OrderValue4OpStmt) {
  VisitEachAnchorGroup(anchor_index2igroup_spec, [&](auto* igroup_spec) {
    std::sort(igroup_spec->op_stmts->begin(),
              igroup_spec->op_stmts->end(),
              [&](const auto& lhs, const auto& rhs) {
                return OrderValue4OpStmt(lhs) < OrderValue4OpStmt(rhs);
              });
  });
}
std::vector<AnchorGroup> SortedAnchorGroups(
const std::unordered_map<AnchorIndex, AnchorGroup>&
anchor_index2igroup_spec,
const std::function<std::size_t(const OpStmt&)>& OrderValue4OpStmt) {
std::vector<AnchorGroup> ret{};
VisitEachAnchorGroup(anchor_index2igroup_spec, [&](const auto& igroup_spec) {
ret.emplace_back(igroup_spec);
});
const auto& OrderValue4AnchorGroup = [&](const AnchorGroup& igroup_spec) {
return OrderValue4OpStmt(igroup_spec.op_stmts->back());
};
std::sort(ret.begin(), ret.end(), [&](const auto& lhs, const auto& rhs) {
return OrderValue4AnchorGroup(lhs) < OrderValue4AnchorGroup(rhs);
});
return ret;
}
// Sorts the op_stmts inside each group by program order, then returns the
// groups themselves sorted by the order of each group's last op_stmt.
std::vector<AnchorGroup> SortedAnchorGroups(
    std::unordered_map<AnchorIndex, AnchorGroup>* anchor_index2igroup_spec,
    const List<OpStmt>& op_stmts) {
  const auto& OrderValue4OpStmt = MakeGetterOrderValue4OpStmt(op_stmts);
  SortAnchorGroupOpStmts(anchor_index2igroup_spec, OrderValue4OpStmt);
  return SortedAnchorGroups(*anchor_index2igroup_spec, OrderValue4OpStmt);
}
// Entry point of partitioning:
//  1) every tensor index becomes an anchor candidate;
//  2) op_stmts are greedily grouped around reachable anchors;
//  3) groups are returned in program order.
std::vector<AnchorGroup> PartitionOpStmts(
    const EquationCtx4OpStmtT& EquationCtx4OpStmt,
    const List<OpStmt>& op_stmts,
    const std::shared_ptr<DirectionEquationGenerator>&
        direction_equation_generator) {
  auto candidate_anchor_indexes =
      InitCandidateAnchorIndex(EquationCtx4OpStmt, op_stmts);
  auto anchor_index2igroup_spec =
      PartitionOpStmtsIntoAnchorGroups(&candidate_anchor_indexes,
                                       EquationCtx4OpStmt,
                                       op_stmts,
                                       direction_equation_generator);
  return SortedAnchorGroups(&anchor_index2igroup_spec, op_stmts);
}
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/adt/adt.h"
#include "paddle/cinn/adt/equation_function_constants_provider.h"
#include "paddle/cinn/adt/equation_graph.h"
#include "paddle/cinn/adt/m_expr.h"
#include "paddle/cinn/adt/naive_op_equation_context.h"
#include "paddle/cinn/hlir/framework/graph.h"
namespace cinn::adt {
class EquationFunctionConstantsProvider;
class DirectionEquationGenerator;
// Getter mapping an OpStmt to its per-op equation context.
using EquationCtx4OpStmtT =
    std::function<std::shared_ptr<config::NaiveOpEquationContext>(
        const OpStmt&)>;
// An anchor is a tensor index chosen as the root of a partition.
using AnchorIndex = Index;
// One partition of op_stmts rooted at `anchor_index`.
struct AnchorGroup {
  AnchorIndex anchor_index;
  // The op_stmt that owns `anchor_index` as one of its argument positions.
  OpStmt op_stmt;
  // All op_stmts belonging to this partition.
  List<OpStmt> op_stmts;
  // Retrieves each member op_stmt's equation context.
  EquationCtx4OpStmtT EquationCtx4OpStmt;
  // Logs the anchor index for debugging (VLOG level 1).
  void PrintEquations() const;
};
// Partitions op_stmts into anchor groups, returned in program order.
std::vector<AnchorGroup> PartitionOpStmts(
    const EquationCtx4OpStmtT& EquationCtx4OpStmt,
    const List<OpStmt>& op_stmts,
    const std::shared_ptr<DirectionEquationGenerator>&
        direction_equation_generator);
// CHECK-fails unless every op_stmt in `igroup_spec` is solvable starting
// from the group's anchor index.
void CheckEquationSolvable(
    const AnchorGroup& igroup_spec,
    const std::shared_ptr<const EquationFunctionConstantsProvider>&
        constant_provider,
    const std::shared_ptr<DirectionEquationGenerator>&
        direction_equation_generator);
// Merges per-op equations, direction equations and producer/consumer
// identity equations into one graph view.
GraphView MakeGlobalEquationGraphViewForPartition(
    const EquationCtx4OpStmtT& EquationCtx4OpStmt,
    const List<OpStmt>& op_stmts,
    const std::shared_ptr<DirectionEquationGenerator>&
        direction_equation_generator);
} // namespace cinn::adt
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment