Commit 01a10755 authored by yuguo-Jack's avatar yuguo-Jack
Browse files

2.5.2-dtk24.04

parent 63eb0da5
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <optional>
#include "paddle/cinn/adt/equation_constant.h"
namespace cinn::adt {
// Abstract provider of constant values used by equation functions.
// Concrete implementations map a symbolic Dim to its size Constant.
class EquationFunctionConstantsProvider {
public:
virtual ~EquationFunctionConstantsProvider() = default;
// Returns the size associated with `dim`.
virtual Constant GetDimSize(const Dim& dim) const = 0;
// Registers `dim` with `dim_value`; the bool presumably reports whether the
// insertion was new -- TODO confirm against implementations.
virtual bool AddDim(const Dim& dim, const Constant& dim_value) = 0;
protected:
// Only constructible through derived classes.
EquationFunctionConstantsProvider() = default;
};
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <unordered_set>
#include <vector>
#include "paddle/cinn/adt/equation.h"
#include "paddle/cinn/common/equation_graph_topo_walker.h"
namespace cinn::adt {
using Functions = Equations;
// clang-format off
/*
Graph = ([Variable], [Function], [Edge Variable Function], [Edge Function Variable])
Edge T0 T1 = (T0, T1)
*/
// clang-format on
// Bipartite equation graph over Variables and Functions (equations).
// Built once from a list of equations; exposes visitor closures that
// capture a shared_ptr to this graph so they stay valid after the
// original Graph handle goes out of scope.
class Graph final : public std::enable_shared_from_this<Graph> {
public:
// Variable -> the functions that consume it as an input.
using V2Fs = std::unordered_map<Variable, std::vector<const Function*>>;
// Function -> the variables on one side (inputs or outputs) of it.
using F2Vs = std::unordered_map<const Function*, std::vector<Variable>>;
// Factory; the constructor is private because shared_from_this requires
// the object to be owned by a shared_ptr.
static std::shared_ptr<Graph> New(const Functions& equations) {
return std::shared_ptr<Graph>{new Graph{equations}};
}
using VariableVisitorT = std::function<void(const Variable)>;
using FunctionVisitorT = std::function<void(const Function*)>;
// Visits the functions reachable from a variable.
using F4VVisitor =
std::function<void(const Variable, const FunctionVisitorT&)>;
// Visits the variables attached to a function.
using V4FVisitor =
std::function<void(const Function*, const VariableVisitorT&)>;
// Composes two visitors: visits lhs's results first, then rhs's.
static F4VVisitor Merge(const F4VVisitor& lhs, const F4VVisitor& rhs) {
return [=](const Variable variable,
const std::function<void(const Function*)>& Visit) {
lhs(variable, Visit);
rhs(variable, Visit);
};
}
static V4FVisitor Merge(const V4FVisitor& lhs, const V4FVisitor& rhs) {
return [=](const Function* function,
const std::function<void(const Variable)>& Visit) {
lhs(function, Visit);
rhs(function, Visit);
};
}
// Visitor over the consumer functions of a variable. Captures `self` by
// shared_ptr to keep the underlying maps alive.
F4VVisitor GetNextFunctionsVisitor() const {
auto self = this->shared_from_this();
return [self](const Variable variable,
const std::function<void(const Function*)>& Visit) {
const auto iter = self->variable2next_functions_->find(variable);
if (iter == self->variable2next_functions_->end()) {
return;
}
for (const Function* function : iter->second) {
Visit(function);
}
};
}
// Visitor over the input variables of a function.
V4FVisitor GetInputVariablesVisitor() const {
auto self = this->shared_from_this();
return [self](const Function* function,
const std::function<void(const Variable)>& Visit) {
const auto iter = self->function2in_variables_->find(function);
if (iter == self->function2in_variables_->end()) {
return;
}
for (const Variable variable : iter->second) {
Visit(variable);
}
};
}
// Visitor over the output variables of a function.
V4FVisitor GetOutputVariablesVisitor() const {
auto self = this->shared_from_this();
return [self](const Function* function,
const std::function<void(const Variable)>& Visit) {
const auto iter = self->function2out_variables_->find(function);
if (iter == self->function2out_variables_->end()) {
return;
}
for (const Variable variable : iter->second) {
Visit(variable);
}
};
}
// Walker over the union of two graphs (each visitor tries lhs then rhs).
static EquationGraphTopoWalker<Variable, const Function*> GetMergedWalker(
const Graph& lhs, const Graph& rhs) {
return EquationGraphTopoWalker<Variable, const Function*>(
/*NextFunctionsVisitor=*/Merge(lhs.GetNextFunctionsVisitor(),
rhs.GetNextFunctionsVisitor()),
/*InputVariablesVisitor=*/
Merge(lhs.GetInputVariablesVisitor(), rhs.GetInputVariablesVisitor()),
/*OutputVariablesVisitor=*/
Merge(lhs.GetOutputVariablesVisitor(),
rhs.GetOutputVariablesVisitor()));
}
// Walker over this graph alone.
EquationGraphTopoWalker<Variable, const Function*> GetGraphView() const {
return EquationGraphTopoWalker<Variable, const Function*>(
/*NextFunctionsVisitor=*/GetNextFunctionsVisitor(),
/*InputVariablesVisitor=*/GetInputVariablesVisitor(),
/*OutputVariablesVisitor=*/GetOutputVariablesVisitor());
}
const std::unordered_set<Variable>& GetVariables() const {
return variables_;
}
const Equations& GetEquations() const { return functions_; }
private:
// Indexes every equation's input/output variables. NOTE: edge maps store
// raw pointers into *functions_, so functions_ must never be mutated after
// construction.
explicit Graph(const Functions& equations)
: functions_(equations),
variable2next_functions_(std::make_shared<V2Fs>()),
function2in_variables_(std::make_shared<F2Vs>()),
function2out_variables_(std::make_shared<F2Vs>()) {
for (const Function& function : *functions_) {
CollectVariablesAndEdges(function);
}
}
// Records `function`'s variables and the edges on both directions.
void CollectVariablesAndEdges(const Function& function) {
const auto& [in_variables, out_variables] =
CollectInputAndOutputVariables(function);
for (const Variable& variable : in_variables) {
variables_.insert(variable);
(*variable2next_functions_)[variable].push_back(&function);
(*function2in_variables_)[&function].push_back(variable);
v2f_edges_.emplace_back(std::pair{variable, &function});
}
for (const Variable& variable : out_variables) {
variables_.insert(variable);
(*function2out_variables_)[&function].push_back(variable);
f2v_edges_.emplace_back(std::pair{&function, variable});
}
}
Functions functions_;
// tNext [Function] <- Variable
std::shared_ptr<V2Fs> variable2next_functions_;
// tIn [Variable] <- Function
std::shared_ptr<F2Vs> function2in_variables_;
// tOut [Variable] <- Function
std::shared_ptr<F2Vs> function2out_variables_;
std::unordered_set<Variable> variables_;
// For debug
std::vector<std::pair<const Variable, const Function*>> v2f_edges_;
std::vector<std::pair<const Function*, const Variable>> f2v_edges_;
};
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <unordered_map>
#include <variant>
#include "glog/logging.h"
#include "paddle/cinn/adt/equation.h"
#include "paddle/cinn/adt/equation_solver.h"
#include "paddle/cinn/adt/equation_value.h"
#include "paddle/cinn/adt/index_expr_infer_context.h"
#include "paddle/cinn/adt/print.h"
#include "paddle/cinn/adt/simplify_value.h"
#include "paddle/cinn/adt/tags.h"
#include "paddle/cinn/common/equation_graph_topo_walker.h"
namespace cinn::adt {
std::unordered_map<Variable, Value> InferValuesImpl(
const Identity<tOut<Iterator>, tIn<Iterator>>& id,
IndexExprInferContext* ctx) {
const auto& [out_iter, in_iter] = id.tuple();
Variable in_variable{in_iter.value()};
CHECK(ctx->HasValue(in_variable));
return {{out_iter.value(), ctx->GetValue(in_variable)}};
}
std::unordered_map<Variable, Value> InferValuesImpl(
const Identity<tOut<Index>, tIn<Index>>& id, IndexExprInferContext* ctx) {
const auto& [out_index, in_index] = id.tuple();
Variable in_variable{in_index.value()};
CHECK(ctx->HasValue(in_variable));
return {{out_index.value(), ctx->GetValue(in_variable)}};
}
// Returns true iff any two elements of `values` compare equal.
// Quadratic pairwise scan; value lists here are small.
bool HasReplicatedValues(const List<Value>& values) {
  const std::size_t count = values->size();
  for (std::size_t lhs = 0; lhs + 1 < count; ++lhs) {
    for (std::size_t rhs = lhs + 1; rhs < count; ++rhs) {
      if (values->at(lhs) == values->at(rhs)) {
        return true;
      }
    }
  }
  return false;
}
// IndexDot: flattens the input iterators against `dims` into one index
// value. If any iterator value is replicated the result is Undefined
// (the flattening would not be well-formed).
std::unordered_map<Variable, Value> InferValuesImpl(
const IndexDot<List<Dim>, tOut<Index>, tIn<List<Iterator>>>& dot,
IndexExprInferContext* ctx) {
const auto& [dims, out_index, in_iters] = dot.tuple();
List<Value> in_values;
for (const auto& iter : *in_iters.value()) {
in_values->emplace_back(ctx->GetValue(iter));
}
if (HasReplicatedValues(in_values)) {
return {{out_index.value(), Undefined{}}};
}
// Wrap each Dim as a Constant so IndexDotValue can carry them.
List<Constant> dim_constants{};
for (const auto& dim : *dims) {
dim_constants->emplace_back(dim);
}
IndexDotValue<Value, Constant> index_dot{in_values, dim_constants};
return {{out_index.value(), index_dot}};
}
std::unordered_map<Variable, Value> InferValuesImpl(
const GetBroadcastedIterator<Dim, tOut<Iterator>, tIn<Iterator>>& broadcast,
IndexExprInferContext* ctx) {
const auto& [dim, out_iterator, in_iterator] = broadcast.tuple();
BroadcastedIterator<Value, Constant> broadcast_iterator{
ctx->GetValue(in_iterator.value()), dim};
return {{out_iterator.value(), broadcast_iterator}};
}
// IndexUnDot: splits the input index into one value per output iterator.
// Each output iterator i receives ListGetItem(IndexUnDotValue(index, dims), i).
std::unordered_map<Variable, Value> InferValuesImpl(
const IndexUnDot<List<Dim>, tOut<List<Iterator>>, tIn<Index>>& undot,
IndexExprInferContext* ctx) {
const auto& [dims, out_iters, in_index] = undot.tuple();
// Wrap each Dim as a Constant so IndexUnDotValue can carry them.
List<Constant> dim_constants{};
for (const auto& dim : *dims) {
dim_constants->emplace_back(dim);
}
IndexUnDotValue<Value, Constant> index_undot{ctx->GetValue(in_index.value()),
dim_constants};
std::unordered_map<Variable, Value> ret{};
for (std::size_t idx = 0; idx < out_iters.value()->size(); ++idx) {
ListGetItem<Value, Constant> list_get_item{index_undot, idx};
ret.emplace(out_iters.value()->at(idx), list_get_item);
}
return ret;
}
// InMsg2OutMsg: copies inferred values from in-message indexes to the
// corresponding out-message indexes (positionally paired), and marks the
// fake op placeholder as Ok. Out-message output indexes are optional and
// skipped when absent. The CHECK on emplace guards against the same
// out-index appearing twice.
std::unordered_map<Variable, Value> InferValuesImpl(
const InMsg2OutMsg<tOut<FakeOpPlaceHolder>,
tOut<OpArgIndexes<std::optional<Index>>>,
tIn<OpArgIndexes<Index>>>& in_msg2out_msg,
IndexExprInferContext* ctx) {
const auto& [op_placeholder, out_msg_indexes, in_msg_indexes] =
in_msg2out_msg.tuple();
const auto& [out_msg_in_indexes, out_msg_out_indexes] =
out_msg_indexes.value().tuple();
const auto& [in_msg_in_indexes, in_msg_out_indexes] =
in_msg_indexes.value().tuple();
std::unordered_map<Variable, Value> ret{{op_placeholder.value(), Ok{}}};
// in/out index lists must be positionally aligned.
CHECK_EQ(out_msg_in_indexes.value()->size(),
in_msg_in_indexes.value()->size());
CHECK_EQ(out_msg_out_indexes.value()->size(),
in_msg_out_indexes.value()->size());
for (std::size_t i = 0; i < out_msg_in_indexes.value()->size(); ++i) {
const auto& value = ctx->GetValue(in_msg_in_indexes.value()->at(i));
CHECK(ret.emplace(out_msg_in_indexes.value()->at(i), value).second);
}
for (std::size_t i = 0; i < out_msg_out_indexes.value()->size(); ++i) {
const auto& value = ctx->GetValue(in_msg_out_indexes.value()->at(i));
const auto& out_index = out_msg_out_indexes.value()->at(i);
if (out_index.has_value()) {
CHECK(ret.emplace(out_index.value(), value).second);
}
}
return ret;
}
std::unordered_map<Variable, Value> InferValuesImpl(
const ConstantFunction<tOut<Iterator>, tIn<Index>>& constant_function,
IndexExprInferContext* ctx) {
const auto& [out_iter, in_index, constant] = constant_function.tuple();
return std::unordered_map<Variable, Value>{{out_iter.value(), constant}};
}
// Dispatches to the InferValuesImpl overload matching the concrete equation
// alternative held by `function`.
std::unordered_map<Variable, Value> InferValues(const Function* function,
                                                IndexExprInferContext* ctx) {
  const auto& dispatch = [&](auto&& equation) {
    return InferValuesImpl(equation, ctx);
  };
  return std::visit(dispatch, function->variant());
}
// Tag type wrapping a bool: "value inference succeeded".
DEFINE_ADT_TAG(tValueInferSuccess);
// Infers `function`'s output values, simplifies them, and merges them into
// `ctx`. Invokes OnFail(old_value, new_value) and returns its result when a
// value simplifies to Undefined (old_value = nullopt) or conflicts with a
// previously recorded value for the same variable.
template <typename OnFailT>
tValueInferSuccess<bool> MergeInferedValuesIntoCtx(const Function* function,
IndexExprInferContext* ctx,
const OnFailT& OnFail) {
auto output_variable2value = InferValues(function, ctx);
for (const auto& [variable, unsimplified_value] : output_variable2value) {
Value simplified_value({SimplifyValue(unsimplified_value, *ctx)});
if (simplified_value.Has<Undefined>()) {
// Simplification failed outright; there is no old value to report.
return OnFail(std::optional<Value>{std::nullopt}, simplified_value);
}
if (!ctx->HasValue(variable)) {
ctx->SetValue(variable, simplified_value);
} else {
// Variable was already inferred: it must re-infer to the same value.
std::optional<Value> opt_old_value = ctx->GetValue(variable);
if (simplified_value != opt_old_value.value()) {
return OnFail(opt_old_value, simplified_value);
}
}
}
return tValueInferSuccess<bool>{true};
}
// Convenience overload: on failure, logs the old and new values at VLOG(1)
// and reports failure instead of aborting.
tValueInferSuccess<bool> MergeInferedValuesIntoCtx(const Function* function,
                                                   IndexExprInferContext* ctx) {
  const auto& LogAndFail = [&](const std::optional<Value>& old_value,
                               const Value& new_value) {
    if (old_value.has_value()) {
      VLOG(1) << "opt_old_value = " << ToTxtString(old_value.value());
    }
    VLOG(1) << "simplified = " << ToTxtString(new_value);
    return tValueInferSuccess<bool>{false};
  };
  return MergeInferedValuesIntoCtx(function, ctx, LogAndFail);
}
// Solves every equation reachable from `starts`, recording results in `ctx`.
// CHECK-fails if any equation yields a conflicting or undefined value.
void SolveEquations(
    const EquationGraphTopoWalker<Variable, const Function*>& walker,
    const std::vector<Variable>& starts,
    IndexExprInferContext* ctx) {
  const auto& InferAndCheck = [&](const Function* function) {
    const tValueInferSuccess<bool> has_unique_value =
        MergeInferedValuesIntoCtx(function, ctx);
    CHECK(has_unique_value.value());
  };
  walker.WalkFunction(starts.begin(), starts.end(), InferAndCheck);
}
// Walks every equation reachable from `start` and LOG(FATAL)s (after logging
// both values at ERROR) on the first conflicting or undefined inference.
void CheckEquationsSolvable(
const EquationGraphTopoWalker<Variable, const Function*>& walker,
const Variable& start,
IndexExprInferContext* ctx) {
const auto& CheckNoConflictInferedValue = [&](const Function* function) {
MergeInferedValuesIntoCtx(
function,
ctx,
[&](const auto& opt_old_value, const auto& simplified_value) {
// NOTE(review): opt_old_value is std::optional<Value> and can be
// nullopt (Undefined case); presumably ToTxtString has an overload
// accepting the optional -- TODO confirm.
LOG(ERROR) << "old_value: " << ToTxtString(opt_old_value);
LOG(ERROR) << "simplified_value: " << ToTxtString(simplified_value);
LOG(FATAL) << "CheckEquationsSolvable Failed";
return tValueInferSuccess<bool>{false};
});
};
walker.WalkFunction(start, CheckNoConflictInferedValue);
}
// Like SolveEquations but non-fatal: returns whether every reachable
// equation inferred a unique, conflict-free value. Once a conflict is seen,
// remaining functions are visited but not inferred.
tHasNoConflictValue<bool> TrySolveEquations(
    const EquationGraphTopoWalker<Variable, const Function*>& walker,
    const Variable& start,
    IndexExprInferContext* ctx) {
  bool conflict_free = true;
  walker.WalkFunction(start, [&](const Function* function) {
    if (!conflict_free) {
      return;  // already failed; skip further inference
    }
    tValueInferSuccess<bool> unique = MergeInferedValuesIntoCtx(function, ctx);
    if (!unique.value()) {
      conflict_free = false;
    }
  });
  return tHasNoConflictValue<bool>{conflict_free};
}
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/cinn/adt/equation_value.h"
#include "paddle/cinn/adt/index_expr_infer_context.h"
#include "paddle/cinn/adt/tags.h"
namespace cinn::adt {
// NOTE(review): index_expr_infer_context.h is already included above, so this
// forward declaration is redundant (but harmless).
class IndexExprInferContext;
// Solves all equations reachable from `starts`; CHECK-fails on conflict.
void SolveEquations(
const EquationGraphTopoWalker<Variable, const Function*>& walker,
const std::vector<Variable>& starts,
IndexExprInferContext* ctx);
// Non-fatal variant: reports whether all inferred values were conflict-free.
tHasNoConflictValue<bool> TrySolveEquations(
const EquationGraphTopoWalker<Variable, const Function*>& walker,
const Variable& start,
IndexExprInferContext* ctx);
// Diagnostic variant: LOG(FATAL)s with details on the first conflict.
void CheckEquationsSolvable(
const EquationGraphTopoWalker<Variable, const Function*>& walker,
const Variable& start,
IndexExprInferContext* ctx);
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <optional>
#include <unordered_map>
#include <unordered_set>
#include "paddle/cinn/adt/equation.h"
#include "paddle/cinn/common/equation_graph_topo_walker.h"
#include "paddle/cinn/common/topo_walker.h"
namespace cinn::adt {
// Builds a function-level topological walker (a DAG over FT nodes) from an
// equation-graph walker. The first function seen to output a variable is
// treated as that variable's solo producer; `start` itself has none.
//
// BUG FIX: std::unordered_map has no emplace_back member, so the original
// emplace_back calls could never compile once this template was
// instantiated. They are replaced with emplace, whose keep-first-insertion
// semantics also preserve the "start has no producer" mark.
template <typename VT, typename FT>
common::TopoWalker<FT> GetDAGTopoWalker(
    const EquationGraphTopoWalker<VT, FT>& eg_walker, const VT& start) {
  auto var2solo_producer =
      std::make_shared<std::unordered_map<VT, std::optional<FT>>>();
  var2solo_producer->emplace(start, std::nullopt);
  eg_walker(start, [&](FT function) {
    eg_walker.VisitOutputVariables(function, [&](VT out_variable) {
      // emplace: if several functions output the same variable, the first
      // one visited wins as the solo producer.
      var2solo_producer->emplace(out_variable, function);
    });
  });
  // A function's predecessors are the producers of its input variables.
  const auto& VisitPrevNodes = [var2solo_producer, eg_walker](
                                   FT function,
                                   const std::function<void(FT)>& Visit) {
    eg_walker.VisitInputVariables(function, [&](VT in_variable) {
      const auto& opt_producer = var2solo_producer->at(in_variable);
      if (opt_producer.has_value()) {
        Visit(opt_producer.value());
      }
    });
  };
  // A function's successors are the consumers of the output variables it
  // solely produces.
  const auto& VisitNextNodes = [var2solo_producer, eg_walker](
                                   FT function,
                                   const std::function<void(FT)>& Visit) {
    eg_walker.VisitOutputVariables(function, [&](VT out_variable) {
      const auto& opt_producer = var2solo_producer->at(out_variable);
      if (opt_producer.has_value() && opt_producer.value() == function) {
        eg_walker.VisitNextFunctions(out_variable, Visit);
      } else {
        // Do nothing: another function owns this variable.
      }
    });
  };
  return common::TopoWalker<FT>(VisitPrevNodes, VisitNextNodes);
}
// Restricts `graph` to the functions satisfying IsSelected. The variable
// visitors CHECK that they are only queried for selected functions.
template <typename VT, typename FT>
EquationGraphTopoWalker<VT, FT> GetSubgraph(
    const EquationGraphTopoWalker<VT, FT>& graph,
    const std::function<bool(FT)>& IsSelected) {
  const auto& VisitNextFunctions =
      [graph, IsSelected](VT variable, const std::function<void(FT)>& Visit) {
        const auto& VisitIfSelected = [&](FT out_function) {
          if (IsSelected(out_function)) {
            Visit(out_function);
          }
        };
        graph.VisitNextFunctions(variable, VisitIfSelected);
      };
  const auto& VisitInputVariables =
      [graph, IsSelected](FT function, const std::function<void(VT)>& Visit) {
        CHECK(IsSelected(function));
        graph.VisitInputVariables(function, Visit);
      };
  const auto& VisitOutputVariables =
      [graph, IsSelected](FT function, const std::function<void(VT)>& Visit) {
        CHECK(IsSelected(function));
        graph.VisitOutputVariables(function, Visit);
      };
  return EquationGraphTopoWalker<VT, FT>(
      VisitNextFunctions, VisitInputVariables, VisitOutputVariables);
}
// Creates `num_dims` fresh symbolic dimensions.
inline List<Dim> MakeDims(std::size_t num_dims) {
  List<Dim> dims{};
  std::size_t remaining = num_dims;
  while (remaining-- > 0) {
    dims->emplace_back(UniqueId::New());
  }
  return dims;
}
template <typename DoEachT>
void IdentityConnect(const Index& out, const Index& in, const DoEachT& DoEach) {
DoEach(Identity<tOut<Index>, tIn<Index>>{out, in});
}
// Appends the identity equation `out <- in` to `*equations`.
// Uses emplace_back for consistency with the other inline wrappers in this
// header (Equal, MakeDot, MakeUnDot); the original used push_back.
inline void IdentityConnect(const Index& out,
                            const Index& in,
                            Equations* equations) {
  IdentityConnect(out, in, [&](const auto& equation) {
    (*equations)->emplace_back(equation);
  });
}
template <typename DoEachT>
void IdentityConnect(const Iterator& out,
const Iterator& in,
const DoEachT& DoEach) {
DoEach(Identity<tOut<Iterator>, tIn<Iterator>>{out, in});
}
// Appends the iterator identity equation `out <- in` to `*equations`.
// Uses emplace_back for consistency with the other inline wrappers in this
// header (Equal, MakeDot, MakeUnDot); the original used push_back.
inline void IdentityConnect(const Iterator& out,
                            const Iterator& in,
                            Equations* equations) {
  IdentityConnect(out, in, [&](const auto& equation) {
    (*equations)->emplace_back(equation);
  });
}
// Constrains two indexes to be equal by emitting identities both ways.
template <typename DoEachT>
void Equal(const Index& lhs, const Index& rhs, const DoEachT& DoEach) {
  IdentityConnect(/*out=*/lhs, /*in=*/rhs, DoEach);
  IdentityConnect(/*out=*/rhs, /*in=*/lhs, DoEach);
}
inline void Equal(const Index& lhs, const Index& rhs, Equations* equations) {
Equal(lhs, rhs, [&](const auto& equation) {
(*equations)->emplace_back(equation);
});
}
// Constrains two iterators to be equal by emitting identities both ways.
template <typename DoEachT>
void Equal(const Iterator& lhs, const Iterator& rhs, const DoEachT& DoEach) {
  IdentityConnect(/*out=*/lhs, /*in=*/rhs, DoEach);
  IdentityConnect(/*out=*/rhs, /*in=*/lhs, DoEach);
}
inline void Equal(const Iterator& lhs,
const Iterator& rhs,
Equations* equations) {
Equal(lhs, rhs, [&](const auto& equation) {
(*equations)->emplace_back(equation);
});
}
template <typename DoEachT>
void GenerateDotEquation(const List<Iterator>& iterators,
const List<Dim>& dims,
const Index& index,
const DoEachT& DoEach) {
DoEach(IndexDot<List<Dim>, tOut<Index>, tIn<List<Iterator>>>{
dims, index, iterators});
DoEach(IndexUnDot<List<Dim>, tOut<List<Iterator>>, tIn<Index>>{
dims, iterators, index});
}
// Creates a fresh index that is the dot of `iterators` with `dims`, emitting
// the connecting equations via DoEach.
template <typename DoEachT>
Index MakeDot(const List<Iterator>& iterators,
              const List<Dim>& dims,
              const DoEachT& DoEach) {
  const Index flattened{UniqueId::New()};
  GenerateDotEquation(iterators, dims, flattened, DoEach);
  return flattened;
}
// Convenience overload: appends the generated equations to `*equations`.
inline Index MakeDot(const List<Iterator>& iterators,
                     const List<Dim>& dims,
                     Equations* equations) {
  const auto& Append = [&](const auto& equation) {
    (*equations)->emplace_back(equation);
  };
  return MakeDot(iterators, dims, Append);
}
// Creates `num_iterators` fresh iterators.
inline List<Iterator> MakeIterators(std::size_t num_iterators) {
  List<Iterator> iterators{};
  for (std::size_t idx = 0; idx != num_iterators; ++idx) {
    iterators->emplace_back(UniqueId::New());
  }
  return iterators;
}
// Creates one fresh iterator per dimension of `dims` and emits the dot/undot
// equations relating them to `index`; returns the new iterators.
//
// BUG FIX: the original passed an empty iterator list to
// GenerateDotEquation, so the emitted IndexDot/IndexUnDot equations related
// `index` to zero iterators and the returned list was empty and useless to
// callers. Mirror MakeDot/MakeIterators and mint dims->size() iterators.
template <typename DoEachT>
List<Iterator> MakeUnDot(const Index& index,
                         const List<Dim>& dims,
                         const DoEachT& DoEach) {
  List<Iterator> ret{};
  for (std::size_t i = 0; i < dims->size(); ++i) {
    ret->emplace_back(UniqueId::New());
  }
  GenerateDotEquation(ret, dims, index, DoEach);
  return ret;
}
// Convenience overload: appends the generated equations to `*equations`.
inline List<Iterator> MakeUnDot(const Index& index,
                                const List<Dim>& dims,
                                Equations* equations) {
  const auto& Append = [&](const auto& equation) {
    (*equations)->emplace_back(equation);
  };
  return MakeUnDot(index, dims, Append);
}
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/cinn/adt/equation_value.h"
namespace cinn::adt {} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include "paddle/cinn/adt/adt.h"
#include "paddle/cinn/adt/equation.h"
#include "paddle/cinn/adt/match.h"
namespace cinn::adt {
// Tag wrapping a UniqueId that denotes a pointer identity.
DEFINE_ADT_TAG(tPointer);
// Symbolic dot of iterators with dims: (iterators, dims).
template <typename IteratorsT, typename DimsT>
struct IndexDotValue : public Tuple<IteratorsT, DimsT> {
using Tuple<IteratorsT, DimsT>::Tuple;
const IteratorsT& GetIteratorsValue() const {
return std::get<0>(this->tuple());
}
};
// Symbolic undot of an index against dims: (index, dims).
template <typename IndexT, typename DimsT>
struct IndexUnDotValue : public Tuple<IndexT, DimsT> {
using Tuple<IndexT, DimsT>::Tuple;
const IndexT& GetIndexValue() const { return std::get<0>(this->tuple()); }
};
// ListGetItem T ConstantT = (T, ConstantT)
template <typename T, typename ConstantT>
struct ListGetItem final : public Tuple<T, ConstantT> {
using Tuple<T, ConstantT>::Tuple;
const T& GetList() const { return std::get<0>(this->tuple()); }
};
// PtrGetItem T = (tPointer UniqueId, T)
template <typename T>
struct PtrGetItem final : public Tuple<tPointer<UniqueId>, T> {
using Tuple<tPointer<UniqueId>, T>::Tuple;
const T& GetArg1() const { return std::get<1>(this->tuple()); }
};
// Iterator value broadcast against a dimension: (value, dim).
template <typename ValueT, typename ConstantT>
struct BroadcastedIterator final : public Tuple<ValueT, ConstantT> {
using Tuple<ValueT, ConstantT>::Tuple;
const ValueT& GetArg0() const { return std::get<0>(this->tuple()); }
};
// The recursive value domain used by index-expression inference.
DEFINE_ADT_UNION(Value,
Undefined,
Ok,
Iterator,
Constant,
List<Value>,
IndexDotValue<Value, Constant>,
IndexUnDotValue<Value, Constant>,
ListGetItem<Value, Constant>,
BroadcastedIterator<Value, Constant>,
PtrGetItem<Value>);
OVERLOAD_OPERATOR_EQ_NE(Value, UnionEqual);
// Structural equality for each composite alternative (macros cannot take
// template-ids with commas, hence the aliases).
using IndexDot_Value_Constant = IndexDotValue<Value, Constant>;
OVERLOAD_OPERATOR_EQ_NE(IndexDot_Value_Constant, TupleEqual);
using IndexUnDot_Value_Constant = IndexUnDotValue<Value, Constant>;
OVERLOAD_OPERATOR_EQ_NE(IndexUnDot_Value_Constant, TupleEqual);
using ListGetItem_Value_Constant = ListGetItem<Value, Constant>;
OVERLOAD_OPERATOR_EQ_NE(ListGetItem_Value_Constant, TupleEqual);
using BroadcastedIterator_Value_Constant = BroadcastedIterator<Value, Constant>;
OVERLOAD_OPERATOR_EQ_NE(BroadcastedIterator_Value_Constant, TupleEqual);
OVERLOAD_OPERATOR_EQ_NE(PtrGetItem<Value>, TupleEqual);
// Structural hashing for Value; the union dispatcher is generated by
// OVERRIDE_UNION_GET_HASH_VALUE below.
inline std::size_t GetHashValue(const Value& value);
// Distinct constants keep Undefined and Ok from colliding.
inline std::size_t GetHashValueImpl(const Undefined& value) { return 0; }
inline std::size_t GetHashValueImpl(const Ok& value) { return 1; }
inline std::size_t GetHashValueImpl(const Iterator& value) {
return value.value().unique_id();
}
// NOTE(review): forwards to a GetHashValue(const Constant&) overload
// declared elsewhere (not the Value one above) -- TODO confirm.
inline std::size_t GetHashValueImpl(const Constant& value) {
return GetHashValue(value);
}
// Order-sensitive combination over list elements.
inline std::size_t GetHashValueImpl(const List<Value>& value) {
std::size_t ret = 0;
for (const auto& v : *value) {
ret = hash_combine(ret, GetHashValue(v));
}
return ret;
}
inline std::size_t GetHashValueImpl(
const IndexDotValue<Value, Constant>& value) {
const auto& [v, c] = value.tuple();
return hash_combine(GetHashValue(v), GetHashValue(c));
}
inline std::size_t GetHashValueImpl(
const IndexUnDotValue<Value, Constant>& value) {
const auto& [v, c] = value.tuple();
return hash_combine(GetHashValue(v), GetHashValue(c));
}
inline std::size_t GetHashValueImpl(const ListGetItem<Value, Constant>& value) {
const auto& [v, c] = value.tuple();
return hash_combine(GetHashValue(v), GetHashValue(c));
}
inline std::size_t GetHashValueImpl(
const BroadcastedIterator<Value, Constant>& value) {
const auto& [v, c] = value.tuple();
return hash_combine(GetHashValue(v), GetHashValue(c));
}
inline std::size_t GetHashValueImpl(const PtrGetItem<Value>& value) {
const auto& [pointer, c] = value.tuple();
return hash_combine(pointer.value().unique_id(), GetHashValue(c));
}
// Generates GetHashValue(const Value&) dispatching to the Impl overloads.
OVERRIDE_UNION_GET_HASH_VALUE(Value);
} // namespace cinn::adt
namespace std {
// Enables Value as a key of std::unordered_map/set. The unqualified call
// resolves to cinn::adt::GetHashValue via argument-dependent lookup.
template <>
struct hash<cinn::adt::Value> {
std::size_t operator()(const cinn::adt::Value& value) const {
return GetHashValue(value);
}
};
} // namespace std
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/adt/equation_value.h"
#include "paddle/cinn/adt/match.h"
namespace cinn::adt {
// Leaf pattern: a Constant matching a plain int64.
template <>
struct MatchTrait<Constant, std::int64_t> final {
static constexpr int is_template = false;
};
// Leaf pattern: a Constant matching a symbolic dimension id.
template <>
struct MatchTrait<Constant, tDim<UniqueId>> final {
static constexpr int is_template = false;
};
// Composite pattern: a constant list whose every element matches T.
template <typename T>
struct MatchTrait<Constant, List<T>> final {
using base_type = List<Constant>;
static constexpr int is_template = true;
template <template <typename> class Matcher>
static bool MatchChildren(const base_type& list) {
for (const auto& value : *list) {
if (!Matcher<Constant>::template Call<T>(value)) {
return false;
}
}
return true;
}
};
// Leaf patterns over Value alternatives.
template <>
struct MatchTrait<Value, Undefined> final {
static constexpr int is_template = false;
};
template <>
struct MatchTrait<Value, Iterator> final {
static constexpr int is_template = false;
};
// Composite pattern: a value list whose every element matches T.
template <typename T>
struct MatchTrait<Value, List<T>> final {
using base_type = List<Value>;
static constexpr int is_template = true;
template <template <typename> class Matcher>
static bool MatchChildren(const base_type& list) {
for (const auto& value : *list) {
if (!Matcher<Value>::template Call<T>(value)) {
return false;
}
}
return true;
}
};
// Generates MatchTrait for a two-slot Value alternative `name<type0, type1>`:
// matches when slot 0 matches T0 and slot 1 matches T1.
#define DEFINE_MATCH_TRAIT_VALUE_UNION_ARGSIZE_2(name, type0, type1) \
template <typename T0, typename T1> \
struct MatchTrait<Value, name<T0, T1>> final { \
using base_type = name<type0, type1>; \
\
static constexpr int is_template = true; \
\
template <template <typename> class Matcher> \
static bool MatchChildren(const base_type& value) { \
return Matcher<type0>::template Call<T0>(std::get<0>(value.tuple())) && \
Matcher<type1>::template Call<T1>(std::get<1>(value.tuple())); \
} \
};
DEFINE_MATCH_TRAIT_VALUE_UNION_ARGSIZE_2(ListGetItem, Value, Constant);
DEFINE_MATCH_TRAIT_VALUE_UNION_ARGSIZE_2(BroadcastedIterator, Value, Constant);
DEFINE_MATCH_TRAIT_VALUE_UNION_ARGSIZE_2(IndexDotValue, Value, Constant);
DEFINE_MATCH_TRAIT_VALUE_UNION_ARGSIZE_2(IndexUnDotValue, Value, Constant);
// Generates MatchTrait for a one-slot Value alternative `name<T>`:
// matches when slot 0 matches T.
#define DEFINE_ADT_MATCH_TRAIT_EQUATION(name) \
template <typename T> \
struct MatchTrait<Value, name<T>> final { \
using base_type = name<Value>; \
\
static constexpr int is_template = true; \
\
template <template <typename> class Matcher> \
static bool MatchChildren(const base_type& value) { \
return Matcher<Value>::template Call<T>(std::get<0>(value.tuple())); \
} \
};
DEFINE_ADT_MATCH_TRAIT_EQUATION(PtrGetItem);
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/adt/equation_value_match_trait.h"
#include "gtest/gtest.h"
#include "paddle/cinn/adt/equation_value.h"
#include "paddle/cinn/adt/match.h"
namespace cinn::adt::test {
TEST(Match, index_undot) {
Value expr =
IndexUnDotValue<Value, Constant>{Value{Ok()}, Constant{std::int64_t(1)}};
bool ret = cinn::adt::Match<IndexUnDotValue<Value, Constant>>(expr);
ASSERT_TRUE(ret);
}
TEST(Match, index_dot) {
Value expr =
IndexDotValue<Value, Constant>{Value{Ok()}, Constant{std::int64_t(1)}};
bool ret = cinn::adt::Match<IndexDotValue<Value, Constant>>(expr);
ASSERT_TRUE(ret);
}
TEST(Match, list) {
Value expr = List<Value>{Value{Ok()}, Value{Ok()}, Value{Ok()}};
bool ret = cinn::adt::Match<List<Value>>(expr);
ASSERT_TRUE(ret);
}
// ListGetItem built with a Constant index matches a pattern that names the
// index type std::int64_t directly.
TEST(Match, list_get_item) {
  const Value items = List<Value>{Value{Ok()}, Value{Ok()}, Value{Ok()}};
  const Value getter =
      ListGetItem<Value, Constant>{items, Constant{std::int64_t(1)}};
  ASSERT_TRUE((cinn::adt::Match<ListGetItem<Value, std::int64_t>>(getter)));
}
// A ListGetItem whose list operand is itself an IndexUnDotValue: the nested
// pattern ListGetItem<IndexUnDotValue<...>, ...> must match through the
// inner operand.
TEST(Match, list_get_item_index_undot) {
  Value undot1 =
      IndexUnDotValue<Value, Constant>{Value{Ok()}, Constant{std::int64_t(1)}};
  ASSERT_TRUE((cinn::adt::Match<IndexUnDotValue<Value, Constant>>(undot1)));
  Value expr = ListGetItem<Value, Constant>{undot1, Constant{std::int64_t(1)}};
  ASSERT_TRUE(
      (cinn::adt::Match<
          ListGetItem<IndexUnDotValue<Value, Constant>, std::int64_t>>(expr)));
}
// Builds List<ListGetItem<IndexUnDotValue<Value>, std::int64_t>> bottom-up,
// asserting that each intermediate value matches its pattern before the
// composite list is matched as a whole.
TEST(Match, list_list_get_item_index_undot) {
  Value undot =
      IndexUnDotValue<Value, Constant>{Value{Ok()}, Constant{std::int64_t(1)}};
  ASSERT_TRUE((cinn::adt::Match<IndexUnDotValue<Value, Constant>>(undot)));
  Value expr1 = ListGetItem<Value, Constant>{undot, Constant{std::int64_t(0)}};
  ASSERT_TRUE(
      (cinn::adt::Match<
          ListGetItem<IndexUnDotValue<Value, Constant>, std::int64_t>>(expr1)));
  Value expr2 = ListGetItem<Value, Constant>{undot, Constant{std::int64_t(1)}};
  ASSERT_TRUE(
      (cinn::adt::Match<
          ListGetItem<IndexUnDotValue<Value, Constant>, std::int64_t>>(expr2)));
  // Both getters share one undot operand; the list pattern must match every
  // element of the list.
  Value list = List<Value>{expr1, expr2};
  ASSERT_TRUE(
      (cinn::adt::Match<
          List<ListGetItem<IndexUnDotValue<Value, Constant>, std::int64_t>>>(
          list)));
}
// Deepest composition in this suite:
// IndexDotValue<List<ListGetItem<IndexUnDotValue<Value>, std::int64_t>>>.
// Each construction step is match-checked before being wrapped further.
TEST(Match, index_dot_list_list_get_item_index_undot) {
  Value undot1 =
      IndexUnDotValue<Value, Constant>{Value{Ok()}, Constant{std::int64_t(1)}};
  ASSERT_TRUE((cinn::adt::Match<IndexUnDotValue<Value, Constant>>(undot1)));
  Value expr1 = ListGetItem<Value, Constant>{undot1, Constant{std::int64_t(0)}};
  ASSERT_TRUE(
      (cinn::adt::Match<
          ListGetItem<IndexUnDotValue<Value, Constant>, std::int64_t>>(expr1)));
  Value expr2 = ListGetItem<Value, Constant>{undot1, Constant{std::int64_t(1)}};
  ASSERT_TRUE(
      (cinn::adt::Match<
          ListGetItem<IndexUnDotValue<Value, Constant>, std::int64_t>>(expr2)));
  Value list = List<Value>{expr1, expr2};
  ASSERT_TRUE(
      (cinn::adt::Match<
          List<ListGetItem<IndexUnDotValue<Value, Constant>, std::int64_t>>>(
          list)));
  // Finally wrap the list in an IndexDotValue and match the full pattern.
  Value dot = IndexDotValue<Value, Constant>{list, Constant{std::int64_t(1)}};
  ASSERT_TRUE(
      (cinn::adt::Match<IndexDotValue<
          List<ListGetItem<IndexUnDotValue<Value, Constant>, std::int64_t>>,
          Constant>>(dot)));
}
} // namespace cinn::adt::test
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include "paddle/cinn/adt/adt.h"
#include "paddle/cinn/adt/tags.h"
#include "paddle/cinn/adt/unique_id.h"
namespace cinn::adt {
// Variable kinds of the equation system; each wraps a process-unique id.
//
// Iterator = tIterator UniqueId
using Iterator = tIterator<UniqueId>;
// IteratorTuple = [Iterator]
using IteratorTuple = List<Iterator>;
// Index = tIndex UniqueId
using Index = tIndex<UniqueId>;
// FakeOpPlaceHolder = tOpPlaceHolder UniqueId
using FakeOpPlaceHolder = tOpPlaceHolder<UniqueId>;
// Variable = Iterator | Index | FakeOpPlaceHolder
DEFINE_ADT_UNION(Variable, Iterator, Index, FakeOpPlaceHolder);
// Equality on Variable compares the active alternative via UnionEqual.
OVERLOAD_OPERATOR_EQ_NE(Variable, UnionEqual);
} // namespace cinn::adt
namespace std {
// Hash the ADT id-wrappers by their underlying unique id so they can key
// unordered containers.
template <>
struct hash<cinn::adt::Iterator> final {
  std::size_t operator()(const cinn::adt::Iterator& iterator) const {
    return iterator.value().unique_id();
  }
};

template <>
struct hash<cinn::adt::Index> final {
  std::size_t operator()(const cinn::adt::Index& index) const {
    return index.value().unique_id();
  }
};

template <>
struct hash<cinn::adt::FakeOpPlaceHolder> final {
  std::size_t operator()(
      const cinn::adt::FakeOpPlaceHolder& placeholder) const {
    return placeholder.value().unique_id();
  }
};
// Hash for the Variable union: dispatch on the active alternative, then mix
// in the variant index so equal ids of different kinds hash differently.
template <>
struct hash<cinn::adt::Variable> final {
  std::size_t operator()(const cinn::adt::Variable& variable) const {
    // `>> match{...}` visits the active alternative (overloaded-lambda
    // pattern) and yields that alternative's hash.
    std::size_t hash_value =
        variable >>
        cinn::adt::match{
            [](const cinn::adt::Iterator& iterator) {
              return std::hash<cinn::adt::Iterator>()(iterator);
            },
            [](const cinn::adt::Index& index) {
              return std::hash<cinn::adt::Index>()(index);
            },
            [](const cinn::adt::FakeOpPlaceHolder& fake_op_placeholder) {
              return std::hash<cinn::adt::FakeOpPlaceHolder>()(
                  fake_op_placeholder);
            }};
    return cinn::adt::hash_combine(hash_value, variable.variant().index());
  }
};
} // namespace std
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/adt/generate_map_expr.h"
#include "paddle/cinn/adt/anchor_sd_equation_context.h"
#include "paddle/cinn/adt/equation.h"
#include "paddle/cinn/adt/equation_solver.h"
#include "paddle/cinn/adt/igroup.h"
#include "paddle/cinn/adt/index_expr_infer_context.h"
#include "paddle/cinn/adt/kgroup.h"
#include "paddle/cinn/adt/map_expr_ctx.h"
#include "paddle/cinn/adt/naive_bidirection_equation_generator.h"
#include "paddle/cinn/adt/naive_op_equation_context.h"
#include "paddle/cinn/adt/partition_op_stmts.h"
#include "paddle/cinn/adt/print.h"
#include "paddle/cinn/adt/schedule_descriptor.h"
#include "paddle/cinn/adt/tree.h"
#include "paddle/cinn/hlir/framework/pir/group.h"
#include "paddle/cinn/hlir/framework/pir/utils.h"
#include "paddle/cinn/runtime/flags.h"
#include "paddle/pir/core/operation.h"
#include "paddle/pir/core/value.h"
#include "glog/logging.h"
PD_DECLARE_bool(cinn_enable_map_expr);
namespace cinn::adt {
// Tree-merging policy for Stmt trees: statements whose leading loop
// iterators coincide get merged under a shared MapStmt node.
template <>
struct TreeMerger<Stmt> {
  using TreeT = Stmt;
  using tree_type = TreeT;
  using inner_type = typename TreeTrait<TreeT>::inner_type;
  using leaf_type = typename TreeTrait<TreeT>::leaf_type;

  using inner_data_type = typename inner_type::value_type;
  // Maps a leaf (op statement) to the loop iterators it lives under.
  std::function<inner_data_type(const leaf_type&)> GetInnerDataForLeaf;

  // Wraps `children` into a MapStmt annotated with `inner_data`.
  inner_type MakeInnerNode(const inner_data_type& inner_data,
                           const List<TreeT>& children) const {
    return MapStmt<Stmt>{inner_data, children};
  }

  using MergeResult = std::tuple<tCommon<inner_data_type>,
                                 tLhsRemainder<inner_data_type>,
                                 tRhsRemainder<inner_data_type>>;

  // Splits `lhs` and `rhs` into their longest common prefix plus the two
  // remainders. Indices are std::size_t (was int) to avoid signed/unsigned
  // comparison and narrowing from size().
  MergeResult MergeInnerValue(const inner_data_type& lhs,
                              const inner_data_type& rhs) const {
    inner_data_type common{};
    inner_data_type lhs_remainder{};
    inner_data_type rhs_remainder{};
    const std::size_t min_size = std::min(lhs->size(), rhs->size());
    std::size_t idx = 0;
    // Copy the common prefix.
    for (; idx < min_size; ++idx) {
      if (lhs->at(idx) == rhs->at(idx)) {
        common->emplace_back(lhs->at(idx));
      } else {
        break;
      }
    }
    // Everything after the first mismatch is a remainder on its own side.
    for (std::size_t lhs_idx = idx; lhs_idx < lhs->size(); ++lhs_idx) {
      lhs_remainder->emplace_back(lhs->at(lhs_idx));
    }
    for (std::size_t rhs_idx = idx; rhs_idx < rhs->size(); ++rhs_idx) {
      rhs_remainder->emplace_back(rhs->at(rhs_idx));
    }
    return MergeResult{common, lhs_remainder, rhs_remainder};
  }
};
namespace {
// Getter type: loop descriptor for a given schedule iterator.
using LoopDescriptor4IterVarT = std::function<LoopDescriptor(const Iterator&)>;

using AnchorTensor = Variable;
using FakeOpPlaceHolders = List<FakeOpPlaceHolder>;

// Wraps a pir operation pointer into an adt Op.
Op MakeOp(const ::pir::Operation* op) { return {op}; }
// Applies `DoEach` to every operand source (input value) of `op`.
template <typename DoEachT>
void VisitEachInputTensor(const ::pir::Operation* op, const DoEachT& DoEach) {
  const std::size_t num_operands = op->num_operands();
  for (std::size_t idx = 0; idx < num_operands; ++idx) {
    DoEach(op->operand_source(idx));
  }
}

// Collects the inputs of `op` as a list of adapter tensors.
List<Arg> MakeOpStmtInputList(const ::pir::Operation* op) {
  List<Arg> input_args{};
  VisitEachInputTensor(op, [&](const ::pir::Value& tensor) {
    input_args->emplace_back(adapter::Tensor{tensor});
  });
  return input_args;
}
// Applies `DoEach` to every result (output value) of `op`.
// NOTE: `result()` is non-const in pir, hence the const_cast.
template <typename DoEachT>
void VisitEachOutputTensor(const ::pir::Operation* op, const DoEachT& DoEach) {
  auto* mut_op = const_cast<::pir::Operation*>(op);
  const std::size_t num_results = op->num_results();
  for (std::size_t idx = 0; idx < num_results; ++idx) {
    DoEach(mut_op->result(idx));
  }
}

// Collects the outputs of `op` as a list of adapter tensors.
List<Arg> MakeOpStmtOutputList(const ::pir::Operation* op) {
  List<Arg> output_args{};
  VisitEachOutputTensor(op, [&](const ::pir::Value& tensor) {
    output_args->emplace_back(adapter::Tensor{tensor});
  });
  return output_args;
}
// Builds an OpStmt (op + inputs + outputs) for each op in `group` and feeds
// it to `DoEach`.
template <typename DoEachT>
void VisitEachOpStmt(const std::shared_ptr<hlir::framework::pir::Group>& group,
                     const DoEachT& DoEach) {
  for (const auto* op : group->CollectOps()) {
    OpStmt op_stmt{
        MakeOp(op), MakeOpStmtInputList(op), MakeOpStmtOutputList(op)};
    DoEach(op_stmt);
  }
}

// Queries the fusion-pattern kind of a pir operation.
hlir::framework::OpPatternKind GetOpPatternKind(const ::pir::Operation* node) {
  return hlir::framework::pir::CompatibleInfo::OpKind(*node);
}
// If `op_stmt` is a reduction, rewrites it into a ReduceInit + ReduceAcc
// statement pair appended to `ret` and returns true; otherwise leaves `ret`
// untouched and returns false.
bool CollectRewritedReductionOpStmts(const OpStmt& op_stmt, List<OpStmt>* ret) {
  const auto& [op, inputs, outputs] = op_stmt.tuple();
  CHECK(op.Has<const ::pir::Operation*>());
  if (GetOpPatternKind(op.Get<const ::pir::Operation*>()) ==
      hlir::framework::OpPatternKind::kReduction) {
    // The init statement only initializes the outputs; it has no inputs.
    tReduceInit<const ::pir::Operation*> init_op{
        op.Get<const ::pir::Operation*>()};
    (*ret)->emplace_back(OpStmt{init_op, List<Arg>{}, outputs});

    // Fully qualified `::pir` for consistency with the rest of this file
    // (was unqualified `pir`).
    tReduceAcc<const ::pir::Operation*> acc_op{
        op.Get<const ::pir::Operation*>()};
    (*ret)->emplace_back(OpStmt{acc_op, inputs, outputs});
    return true;
  } else {
    return false;
  }
}
// Appends `op_stmt` to `ret`, expanding reduction statements into their
// init/acc pair; non-reductions are appended verbatim.
void CollectRewritedOpStmts(const OpStmt& op_stmt, List<OpStmt>* ret) {
  if (!CollectRewritedReductionOpStmts(op_stmt, ret)) {
    (*ret)->emplace_back(op_stmt);
  }
}

// Converts every op of `group` into (possibly rewritten) op statements.
List<OpStmt> MakeOpStmts(
    const std::shared_ptr<hlir::framework::pir::Group>& group) {
  List<OpStmt> op_stmts{};
  VisitEachOpStmt(group, [&](const auto& op_stmt) {
    CollectRewritedOpStmts(op_stmt, &op_stmts);
  });
  return op_stmts;
}
// Partitions `op_stmts` into anchor-rooted sub-groups (AnchorGroup specs)
// and applies `DoEach` to each spec.
template <typename DoEachT>
void PartitionIGroupOpStmts(const List<OpStmt>& op_stmts,
                            const DoEachT& DoEach) {
  // One equation context per op statement.
  const auto& EquationCtx4OpStmt =
      config::GenerateContext4LocalOpStmt(op_stmts);
  auto direction_equation_generator =
      std::make_shared<NaiveBidirectionEquationGenerator>(op_stmts,
                                                          EquationCtx4OpStmt);
  const auto& igroup_specs = PartitionOpStmts(
      EquationCtx4OpStmt, op_stmts, direction_equation_generator);
  for (const auto& igroup_spec : igroup_specs) {
    DoEach(igroup_spec);
  }
}
std::shared_ptr<IGroup> MakeIGroup(const AnchorGroup& igroup_spec) {
std::shared_ptr<const EquationFunctionConstantsProvider> constants_provider{
new NaiveEquationFunctionConstantsProvider{
igroup_spec.op_stmts, igroup_spec.EquationCtx4OpStmt}};
std::shared_ptr<DirectionEquationGenerator> direction_equation_generator{
new NaiveBidirectionEquationGenerator{igroup_spec.op_stmts,
igroup_spec.EquationCtx4OpStmt}};
CheckEquationSolvable(
igroup_spec, constants_provider, direction_equation_generator);
return std::make_shared<IGroup>(igroup_spec.op_stmts,
igroup_spec.anchor_index,
igroup_spec.EquationCtx4OpStmt,
constants_provider);
}
// Builds one IGroup per partitioned sub-group of `group`'s op statements.
std::vector<std::shared_ptr<IGroup>> GenerateIGroups(
    const std::shared_ptr<hlir::framework::pir::Group>& group) {
  List<OpStmt> op_stmts = MakeOpStmts(group);
  CHECK(!op_stmts->empty());
  std::vector<std::shared_ptr<IGroup>> igroups{};
  PartitionIGroupOpStmts(op_stmts, [&](const auto& igroup_spec) {
    igroups.push_back(MakeIGroup(igroup_spec));
  });
  return igroups;
}

// Wraps the igroups into a kernel group. Only single-igroup kernels are
// supported here.
std::shared_ptr<KGroup> GenerateKGroups(
    const std::shared_ptr<hlir::framework::pir::Group>& group,
    const std::vector<std::shared_ptr<IGroup>>& igroups) {
  CHECK_EQ(igroups.size(), 1);
  return std::make_shared<KGroup>(group, igroups);
}
// Builds the equation-graph view for the schedule descriptor of `igroup`.
// Side effect: installs the AnchorSdEquationContext on `igroup` before the
// equations are read back out of it.
GraphView GenerateSdEquationGraphView(const std::shared_ptr<IGroup>& igroup,
                                      const ScheduleMesh& sched_mesh) {
  config::AnchorSdEquationContext ctx{sched_mesh, igroup->anchor_index()};
  igroup->set_anchor_sd_equation_ctx(ctx);
  Equations equations = igroup->anchor_sd_equation_ctx().value().equations();
  return Graph::New(equations)->GetGraphView();
}
using TensorIndexExpr = Value;
std::unordered_map<Variable, const Value> MakeSdIterator2Iterator(
const IGroup& igroup) {
std::unordered_map<Variable, const Value> ret{};
for (std::size_t i = 0; i < igroup.loop_iterators()->size(); ++i) {
CHECK(ret.emplace(igroup.loop_iterators()->at(i),
igroup.loop_iterators()->at(i))
.second);
}
return ret;
}
// Merges the igroup's default equation graph with the schedule-descriptor
// graph, then solves all equations starting from the loop iterators.
// Returns the populated inference context.
std::shared_ptr<IndexExprInferContext> SolveEquationsThenReturnCtx(
    const std::shared_ptr<IGroup>& igroup, const ScheduleMesh& sched_mesh) {
  const auto& sd_equation_graph_view =
      GenerateSdEquationGraphView(igroup, sched_mesh);
  GraphView igroup_view = igroup->GetDefaultGraphView();
  GraphView merged_view = igroup_view.Merge(sd_equation_graph_view);
  // Each loop iterator starts mapped to itself.
  const auto& init_var2value = MakeSdIterator2Iterator(*igroup);
  auto ctx = std::make_shared<IndexExprInferContext>(
      init_var2value, igroup->constants_provider());
  // Solving starts from the loop iterators of the schedule descriptor.
  std::vector<Variable> starts{};
  for (const auto& loop_iterator : *igroup->loop_iterators()) {
    starts.emplace_back(loop_iterator);
  }
  SolveEquations(merged_view, starts, ctx.get());
  return ctx;
}
// Returns a getter mapping a tensor to its inferred index expression. The
// lambda shares ownership of both the context and the igroup.
std::function<TensorIndexExpr(const Tensor&)> MakeGetterTensorIndexExpr(
    const std::shared_ptr<IndexExprInferContext>& ctx,
    const std::shared_ptr<IGroup>& igroup) {
  return [ctx, igroup](const Tensor& tensor) {
    // Every index of a given tensor carries the same Value, so the first
    // one is representative.
    const auto& first_index = igroup->GetIndexes(tensor).at(0);
    return ctx->GetValue(first_index);
  };
}

// Returns a getter mapping a tensor to the inferred values of its iterators.
TensorIteratorExpr4TensorT MakeGetterTensorIteratorExpr4Tensor(
    const std::shared_ptr<IndexExprInferContext>& ctx,
    const std::shared_ptr<IGroup>& igroup) {
  return [ctx, igroup](const Tensor& tensor) -> List<TensorIteratorExpr> {
    // Keep the list alive in a named reference while iterating it.
    const auto& tensor_iterators = igroup->GetTensorIterators(tensor);
    List<TensorIteratorExpr> iterator_exprs{};
    for (const auto& iterator : *tensor_iterators) {
      iterator_exprs->emplace_back(ctx->GetValue(iterator));
    }
    return iterator_exprs;
  };
}
// Zips `loop_iters` with `sd` into a lookup table and returns a getter from
// iterator to loop descriptor. The table is shared with the closure.
LoopDescriptor4IterVarT MakeGetterLoopDescriptor4IterVar(
    const LoopIterators& loop_iters, const LoopDescriptors& sd) {
  CHECK_EQ(loop_iters->size(), sd->size());
  using Cache = std::unordered_map<Iterator, LoopDescriptor>;
  auto iter2descriptor = std::make_shared<Cache>();
  for (std::size_t i = 0; i < loop_iters->size(); ++i) {
    CHECK(iter2descriptor->emplace(loop_iters->at(i), sd->at(i)).second);
  }
  return [iter2descriptor](const auto& iter_var) {
    return iter2descriptor->at(iter_var);
  };
}
// Creates a TreeMerger whose leaf lookup returns the loop iterators of
// `map_ir` for every one of its op statements.
TreeMerger<Stmt> MakeTreeMerger(const MapIr& map_ir) {
  using Cache = std::unordered_map<OpStmt, LoopIterators>;
  auto op2loop_iters = std::make_shared<Cache>();
  for (const auto& op_stmt : *(map_ir.op_stmts())) {
    CHECK(op2loop_iters->emplace(op_stmt, map_ir.loop_iterators()).second);
  }
  TreeMerger<Stmt> merger{};
  merger.GetInnerDataForLeaf =
      [op2loop_iters](const OpStmt& op_stmt) -> LoopIterators {
    return op2loop_iters->at(op_stmt);
  };
  return merger;
}
// Merges the per-MapIr statement trees into one top-level MapStmt. After all
// merges, exactly one root MapStmt must remain.
MapStmt<Stmt> MakeMapStmt(const MapIrList& map_irs) {
  List<Stmt> stmts{};
  for (const auto& map_ir : *map_irs) {
    const TreeMerger<Stmt>& tree_merger = MakeTreeMerger(map_ir);
    MergeTrees(tree_merger, &stmts, map_ir.op_stmts());
  }
  // Everything must have been folded under a single MapStmt root.
  CHECK_EQ(stmts->size(), 1);
  CHECK(stmts->at(0).Has<MapStmt<Stmt>>());
  return stmts->at(0).Get<MapStmt<Stmt>>();
}
// The tensor addressed by the igroup's anchor index.
Tensor GetAnchorTensor(const std::shared_ptr<IGroup>& igroup) {
  return igroup->anchor_tensor();
}

// Applies `DoEach` to each input value of the fusion group.
template <typename DoEachT>
void VisitInputTensor(const hlir::framework::pir::Group& group,
                      const DoEachT& DoEach) {
  for (const ::pir::Value& input_value : group.GetInputOpValues()) {
    DoEach(input_value);
  }
}

// Applies `DoEach` to each output value of the fusion group.
template <typename DoEachT>
void VisitOutputTensor(const hlir::framework::pir::Group& group,
                       const DoEachT& DoEach) {
  for (const ::pir::Value& output_value : group.GetOutputOpValues()) {
    DoEach(output_value);
  }
}
// Collects the kernel group's input values as adapter tensors.
List<Tensor> MakeInputTensors(const std::shared_ptr<KGroup>& kgroup) {
  List<Tensor> input_tensors{};
  VisitInputTensor(*kgroup->cinn_group(), [&](const ::pir::Value& node_data) {
    input_tensors->emplace_back(adapter::Tensor{node_data});
  });
  return input_tensors;
}

// Collects the kernel group's output values as adapter tensors.
List<Tensor> MakeOutputTensors(const std::shared_ptr<KGroup>& kgroup) {
  List<Tensor> output_tensors{};
  VisitOutputTensor(*kgroup->cinn_group(), [&](const ::pir::Value& node_data) {
    output_tensors->emplace_back(adapter::Tensor{node_data});
  });
  return output_tensors;
}
// Assembles one AnchoredMapStmt: fuses the igroup's op statements into
// MapIrs under `loop_iters`, merges them into a single MapStmt, and packages
// the schedule mesh, anchor tensor, and the three getters.
AnchoredMapStmt GenerateAnchoredMapStmt(
    const std::shared_ptr<IGroup>& igroup,
    const LoopIterators& loop_iters,
    const ScheduleMesh& sched_mesh,
    const LoopDescriptors& sd,
    const TensorIndexExpr4TensorT& TensorIndexExpr4Tensor,
    const TensorIteratorExpr4TensorT& TensorIteratorExpr4Tensor) {
  const auto& LoopDescriptor4IterVar =
      MakeGetterLoopDescriptor4IterVar(loop_iters, sd);
  const auto& map_irs = GenerateMapIrListForLoopFuse(
      igroup->op_stmts(), loop_iters, TensorIndexExpr4Tensor);
  return AnchoredMapStmt{MakeMapStmt(map_irs),
                         sched_mesh,
                         GetAnchorTensor(igroup),
                         TensorIndexExpr4Tensor,
                         TensorIteratorExpr4Tensor,
                         LoopDescriptor4IterVar};
}
// Convenience overload: derives the optimized schedule mesh, descriptors and
// solved inference context for `igroup`, then delegates to the full overload.
AnchoredMapStmt GenerateAnchoredMapStmt(const std::shared_ptr<IGroup>& igroup) {
  const auto& [sched_mesh, loop_types] =
      CreateOptimizedScheduleMesh(igroup->anchor_schedule_dims());
  const auto& sd = CreateScheduleDescriptor(sched_mesh, loop_types);
  // Solving also installs the schedule-descriptor equation context on igroup.
  const auto& ctx = SolveEquationsThenReturnCtx(igroup, sched_mesh);
  const auto& TensorIndexExpr4Tensor = MakeGetterTensorIndexExpr(ctx, igroup);
  const auto& TensorIteratorExpr4Tensor =
      MakeGetterTensorIteratorExpr4Tensor(ctx, igroup);
  const auto& schedule_iters = igroup->loop_iterators();
  return GenerateAnchoredMapStmt(igroup,
                                 schedule_iters,
                                 sched_mesh,
                                 sd,
                                 TensorIndexExpr4Tensor,
                                 TensorIteratorExpr4Tensor);
}
// One AnchoredMapStmt per igroup of the kernel group.
List<AnchoredMapStmt> MakeAnchoredMapStmts(
    const std::shared_ptr<KGroup>& kgroup) {
  List<AnchoredMapStmt> anchored_map_stmts{};
  for (const auto& igroup : kgroup->igroups()) {
    anchored_map_stmts->emplace_back(GenerateAnchoredMapStmt(igroup));
  }
  return anchored_map_stmts;
}

// Assembles the kernel-level expression.
// MapExpr = Kernel;
// Kernel = ([AnchoredMapStmt], In [Tensor], Out [Tensor])
MapExpr GenerateMapExpr(const std::shared_ptr<KGroup>& kgroup) {
  return MapExpr{MakeAnchoredMapStmts(kgroup),
                 MakeInputTensors(kgroup),
                 MakeOutputTensors(kgroup)};
}
} // namespace
// Public entry: fusion group -> igroups -> kernel group -> MapExpr.
MapExpr GenerateMapExpr(
    const std::shared_ptr<hlir::framework::pir::Group>& group) {
  const auto& igroups = GenerateIGroups(group);
  return GenerateMapExpr(GenerateKGroups(group, igroups));
}
// Generates a MapExpr for every fusion group in `groups`. Delegates to
// TryGenerateMapExprFromGroup (declared in generate_map_expr.h, defined
// below) so the per-group logic exists in exactly one place.
// No-op unless FLAGS_cinn_enable_map_expr is set.
void TryGenerateMapExprFromGraph(
    const hlir::framework::pir::GroupList& groups) {
  if (!FLAGS_cinn_enable_map_expr) {
    return;
  }
  for (const auto& fusion_group : groups) {
    TryGenerateMapExprFromGroup(fusion_group);
  }
}

// Generates a MapExpr for one fusion group and stores it in the group's
// MapExprCtx. No-op unless FLAGS_cinn_enable_map_expr is set.
void TryGenerateMapExprFromGroup(
    const std::shared_ptr<hlir::framework::pir::Group>& fusion_group) {
  if (!FLAGS_cinn_enable_map_expr) {
    return;
  }
  const auto& map_expr = GenerateMapExpr(fusion_group);
  VLOG(1) << ToTxtString(map_expr, fusion_group->group_id);
  fusion_group->set_map_expr_ctx(std::make_shared<MapExprCtx>(map_expr));
}
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/cinn/adt/m_expr.h"
namespace cinn::hlir::framework::pir {
struct Group;
using GroupList = std::vector<std::shared_ptr<Group>>;
} // namespace cinn::hlir::framework::pir
namespace cinn::adt {
MapExpr GenerateMapExpr(
const std::shared_ptr<hlir::framework::pir::Group>& group);
void TryGenerateMapExprFromGraph(const hlir::framework::pir::GroupList& groups);
void TryGenerateMapExprFromGroup(
const std::shared_ptr<hlir::framework::pir::Group>& fusion_group);
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/adt/get_sub_reshape_dim_ranges.h"
#include "paddle/cinn/adt/equation_constant.h"
namespace cinn::adt {
namespace {
std::int64_t GetNumel(const List<Constant>& constants) {
std::int64_t ret = 1;
for (const auto& constant : *constants) {
ret *= constant.Get<std::int64_t>();
}
return ret;
}
} // namespace
// Splits two dim lists with equal numel into aligned half-open index ranges
// [start, end) whose partial products match pairwise on both sides (the
// classic reshape-factorization sweep). Returns std::nullopt when the total
// element counts differ.
std::optional<std::tuple<std::vector<std::pair<int, int>>,
                         std::vector<std::pair<int, int>>>>
GetSubReshapeDimRanges(const List<Constant>& lhs_dims,
                       const List<Constant>& rhs_dims) {
  if (GetNumel(lhs_dims) != GetNumel(rhs_dims)) {
    return std::nullopt;
  }
  CHECK(!lhs_dims->empty());
  CHECK(!rhs_dims->empty());
  std::vector<std::pair<int, int>> lhs_ranges{};
  std::vector<std::pair<int, int>> rhs_ranges{};
  // Current candidate range on each side: [start, end).
  int lhs_start = 0;
  int rhs_start = 0;
  int lhs_end = 0;
  int rhs_end = 0;
  // Product of dims[0, end); `end` is clamped to dims->size().
  const auto GetProduct = [&](const List<Constant>& dims,
                              std::size_t end) -> std::int64_t {
    end = (end > dims->size() ? dims->size() : end);
    std::int64_t ret = 1;
    for (std::size_t i = 0; i < end; ++i) {
      CHECK(dims->at(i).Has<std::int64_t>());
      ret *= dims->at(i).Get<std::int64_t>();
    }
    return ret;
  };
  const auto LhsAcc = [&]() -> std::int64_t {
    return GetProduct(lhs_dims, lhs_end);
  };
  const auto RhsAcc = [&]() -> std::int64_t {
    return GetProduct(rhs_dims, rhs_end);
  };
  // Two-pointer sweep: advance whichever side has the smaller prefix
  // product; emit a range pair whenever the prefix products meet.
  while (lhs_end < lhs_dims->size() || rhs_end < rhs_dims->size()) {
    // An empty candidate range must be grown before comparing products.
    if (lhs_start == lhs_end) {
      lhs_end++;
    }
    if (rhs_start == rhs_end) {
      rhs_end++;
    }
    if (LhsAcc() == RhsAcc()) {
      lhs_ranges.emplace_back(std::make_pair(lhs_start, lhs_end));
      rhs_ranges.emplace_back(std::make_pair(rhs_start, rhs_end));
      lhs_start = lhs_end;
      rhs_start = rhs_end;
    } else if (LhsAcc() < RhsAcc()) {
      lhs_end++;
    } else if (LhsAcc() > RhsAcc()) {
      rhs_end++;
    } else {
      // NOTE(review): unreachable — the ==, <, > comparisons above are
      // exhaustive for int64 operands.
      LOG(FATAL) << "Dead code";
    }
  }
  CHECK(lhs_end == lhs_dims->size() && rhs_end == rhs_dims->size());
  // Flush the trailing range; its products agree because total numels match.
  if (lhs_start < lhs_end && rhs_start < rhs_end) {
    lhs_ranges.emplace_back(std::make_pair(lhs_start, lhs_end));
    rhs_ranges.emplace_back(std::make_pair(rhs_start, rhs_end));
  }
  return std::make_tuple(lhs_ranges, rhs_ranges);
}
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <optional>
#include <vector>
#include "paddle/cinn/adt/adt.h"
namespace cinn::adt {
class Constant;

// Aligns two reshape-compatible dim lists into pairwise [start, end) index
// ranges with equal element products; std::nullopt if total numels differ.
std::optional<std::tuple<std::vector<std::pair<int, int>>,
                         std::vector<std::pair<int, int>>>>
GetSubReshapeDimRanges(const List<Constant>& lhs_dims,
                       const List<Constant>& rhs_dims);
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/adt/igroup.h"
#include "paddle/cinn/adt/equation_solver.h"
#include "paddle/cinn/adt/index_expr_infer_context.h"
namespace cinn::adt {
namespace {
// Creates an inference context whose anchor iterators each map to
// themselves as their initial Value.
std::shared_ptr<IndexExprInferContext> MakeIndexExprInferContext(
    const IGroup& igroup) {
  // Named reference keeps the by-value iterator list alive while iterating.
  const auto& anchor_iterators = igroup.GetAnchorIterators();
  std::unordered_map<Variable, const Value> anchor_iterator2value{};
  for (const auto& anchor_iterator : *anchor_iterators) {
    CHECK(anchor_iterator2value.emplace(anchor_iterator, anchor_iterator)
              .second);
  }
  return std::make_shared<IndexExprInferContext>(anchor_iterator2value,
                                                 igroup.constants_provider());
}
// Solves the igroup's equation graph from the anchor iterators and returns
// a getter from any Iterator to its inferred Value. The closure shares
// ownership of the inference context.
std::function<Value(const Iterator&)> MakeGetterValue4Iterator(
    const IGroup* igroup) {
  GraphView igroup_view = igroup->GetDefaultGraphView();
  const auto& ctx = MakeIndexExprInferContext(*igroup);
  // GetAnchorIterators() returns a List by value. Bind it to a named const
  // reference (lifetime-extended) before dereferencing: dereferencing the
  // temporary directly in the range-initializer would destroy the list at
  // the end of that full-expression and leave the loop iterating freed
  // storage (pre-C++23 range-for semantics).
  const auto& anchor_iterators = igroup->GetAnchorIterators();
  std::vector<Variable> starts{};
  for (const auto& anchor_iterator : *anchor_iterators) {
    starts.emplace_back(anchor_iterator);
  }
  SolveEquations(igroup_view, starts, ctx.get());
  return [ctx](const Iterator& iterator) { return ctx->GetValue(iterator); };
}
// Converts the anchor tensor's shape into a list of loop sizes.
List<LoopSize> MakeAnchorLoopSize(const Tensor& tensor) {
  CHECK(tensor.Has<adapter::Tensor>());
  List<LoopSize> loop_sizes{};
  for (const auto& dim : tensor.Get<adapter::Tensor>().GetShape()) {
    loop_sizes->emplace_back(dim);
  }
  return loop_sizes;
}
} // namespace
// Computes the schedule dims of the anchor tensor: solves the equation graph
// for iterator values, reads the anchor tensor's shape, and combines them.
void IGroup::InitAnchorScheduleDims() {
  const auto& Value4Iterator = MakeGetterValue4Iterator(this);
  const auto& loop_size = MakeAnchorLoopSize(this->anchor_tensor());
  anchor_schedule_dims_ = MakeAnchorScheduleDims(
      *this, Value4Iterator, loop_size, this->GetAnchorIterators());
}
// Returns the iterator tuple bound to `index` by scanning the op statements
// for one whose equation context knows the index's argument position.
// Fatal when no op statement references `index`.
List<Iterator> IGroup::GetIndexIterators(const Index& index) const {
  // NOTE(review): `ret` is never written or returned — looks vestigial.
  List<Iterator> ret{};
  for (const auto& op_stmt : *op_stmts_) {
    const auto& ctx = EquationCtx4OpStmt_(op_stmt);
    const OpArgPos& arg_pos = ctx->GetOpArgPos(index);
    if (arg_pos.Has<tIn<std::size_t>>()) {
      return ctx->GetInIteratorTuple(arg_pos.Get<tIn<std::size_t>>().value());
    } else if (arg_pos.Has<tOut<std::size_t>>()) {
      return ctx->GetOutIteratorTuple(arg_pos.Get<tOut<std::size_t>>().value());
    } else if (arg_pos.Has<Undefined>()) {
      // This op statement does not reference `index`; keep scanning.
    } else {
      LOG(FATAL) << "Dead code";
    }
  }
  // LOG(FATAL) aborts, so falling off the end here is unreachable.
  LOG(FATAL) << "Can not find anchor iterators";
}
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <optional>
#include <vector>
#include "paddle/cinn/adt/adt.h"
#include "paddle/cinn/adt/anchor_sd_equation_context.h"
#include "paddle/cinn/adt/equation.h"
#include "paddle/cinn/adt/equation_function_constants_provider.h"
#include "paddle/cinn/adt/equation_graph.h"
#include "paddle/cinn/adt/m_expr.h"
#include "paddle/cinn/adt/m_ir.h"
#include "paddle/cinn/adt/naive_bidirection_equation_generator.h"
#include "paddle/cinn/adt/naive_equation_function_constants_provider.h"
#include "paddle/cinn/adt/naive_op_equation_context.h"
#include "paddle/cinn/adt/partition_op_stmts.h"
#include "paddle/cinn/adt/schedule_dim.h"
namespace cinn::adt {
// The anchor is the Index chosen as an igroup's representative.
using AnchorIndex = Index;
// Factory returning the equation context of a given op statement.
using EquationCtx4OpStmtT =
    std::function<std::shared_ptr<config::NaiveOpEquationContext>(
        const OpStmt&)>;
/**
* IGroup = Inline Group.
* Each IGroup must have an AnchorTensor as a representative.
* Each index of the AnchorTensor can be mapped to the unique index of any other
* Tensor in the IGroup. IGroup solves the problem of cross-thread data usage.
* Please note that syncthreads needs to be called in time.
*/
class IGroup final {
 public:
  IGroup(const IGroup&) = delete;
  IGroup(IGroup&&) = delete;
  // Builds the index<->tensor maps from the op statements and immediately
  // derives the anchor schedule dims.
  explicit IGroup(
      const List<OpStmt>& op_stmts,
      const AnchorIndex& anchor_index,
      const EquationCtx4OpStmtT& EquationCtx4OpStmt,
      const std::shared_ptr<const EquationFunctionConstantsProvider>&
          constants_provider)
      : op_stmts_(op_stmts),
        anchor_index_(anchor_index),
        EquationCtx4OpStmt_(EquationCtx4OpStmt),
        constants_provider_(constants_provider) {
    GenerateIndex2Tensor(
        op_stmts, EquationCtx4OpStmt, &index2tensor_, &tensor2indexes_);
    InitAnchorScheduleDims();
  }
  const List<OpStmt>& op_stmts() const { return op_stmts_; }
  const AnchorIndex& anchor_index() const { return anchor_index_; }
  // The tensor the anchor index addresses.
  const Tensor& anchor_tensor() const { return GetTensor(anchor_index()); }
  const List<ScheduleDim>& anchor_schedule_dims() const {
    return anchor_schedule_dims_;
  }
  const EquationCtx4OpStmtT& EquationCtx4OpStmt() const {
    return EquationCtx4OpStmt_;
  }
  const std::shared_ptr<const EquationFunctionConstantsProvider>&
  constants_provider() const {
    return constants_provider_;
  }
  // Equation-graph view over all op statements of this igroup.
  GraphView GetDefaultGraphView() const {
    auto direction_equation_generator =
        std::make_shared<NaiveBidirectionEquationGenerator>(
            op_stmts_, EquationCtx4OpStmt_);
    return MakeGlobalEquationGraphViewForPartition(
        EquationCtx4OpStmt_, op_stmts_, direction_equation_generator);
  }
  const Tensor& GetTensor(const Index& index) const {
    return index2tensor_.at(index);
  }
  // All indexes referring to `tensor` (one per op-argument occurrence).
  const std::vector<Index>& GetIndexes(const Tensor& tensor) const {
    return tensor2indexes_.at(tensor);
  }
  const std::optional<config::AnchorSdEquationContext>& anchor_sd_equation_ctx()
      const {
    return anchor_sd_equation_ctx_;
  }
  // Installs the schedule-descriptor context and registers its dim
  // constants with the (otherwise const) constants provider.
  void set_anchor_sd_equation_ctx(const config::AnchorSdEquationContext& ctx) {
    anchor_sd_equation_ctx_ = ctx;
    auto* mut_constants_provider =
        const_cast<EquationFunctionConstantsProvider*>(
            constants_provider_.get());
    for (const auto& [dim, dim_value] : ctx.dim2constant()) {
      CHECK(mut_constants_provider->AddDim(dim, dim_value));
    }
  }
  // Only valid after set_anchor_sd_equation_ctx() has been called.
  const List<Iterator>& loop_iterators() const {
    CHECK(anchor_sd_equation_ctx_.has_value());
    return anchor_sd_equation_ctx_.value().sd_iterators();
  }
  List<Iterator> GetIndexIterators(const Index& index) const;
  List<Iterator> GetTensorIterators(const Tensor& tensor) const {
    return GetIndexIterators(GetIndexes(tensor).at(0));
  }
  List<Iterator> GetAnchorIterators() const {
    return GetIndexIterators(anchor_index_);
  }

 private:
  void InitAnchorScheduleDims();
  // Fills index->tensor and tensor->indexes maps for every in/out argument
  // of every op statement. Each index maps to exactly one tensor.
  static void GenerateIndex2Tensor(
      const List<OpStmt>& op_stmts,
      const EquationCtx4OpStmtT& EquationCtx4OpStmt,
      std::unordered_map<Index, Tensor>* index2tensor,
      std::unordered_map<Tensor, std::vector<Index>>* tensor2indexes) {
    for (const auto& op_stmt : *op_stmts) {
      const auto& ctx = EquationCtx4OpStmt(op_stmt);
      const auto& [op, op_inputs, op_outputs] = op_stmt.tuple();
      for (std::size_t idx = 0; idx < op_inputs.value()->size(); ++idx) {
        const auto& index = ctx->GetInIndex(idx);
        const auto& tensor = op_inputs.value()->at(idx);
        CHECK(index2tensor->emplace(index, tensor).second);
        (*tensor2indexes)[tensor].emplace_back(index);
      }
      for (std::size_t idx = 0; idx < op_outputs.value()->size(); ++idx) {
        const auto& index = ctx->GetOutIndex(idx);
        const auto& tensor = op_outputs.value()->at(idx);
        CHECK(index2tensor->emplace(index, tensor).second);
        (*tensor2indexes)[tensor].emplace_back(index);
      }
    }
  }
  List<OpStmt> op_stmts_;
  AnchorIndex anchor_index_;
  EquationCtx4OpStmtT EquationCtx4OpStmt_;
  std::unordered_map<Index, Tensor> index2tensor_;
  std::unordered_map<Tensor, std::vector<Index>> tensor2indexes_;
  std::optional<config::AnchorSdEquationContext> anchor_sd_equation_ctx_;
  std::shared_ptr<const EquationFunctionConstantsProvider> constants_provider_;
  List<ScheduleDim> anchor_schedule_dims_;
};
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/adt/index_expr_infer_context.h"
#include "paddle/cinn/adt/equation_function_constants_provider.h"
namespace cinn::adt {
// Resolves the size of `dim` by delegating to the constants provider
// injected at construction (defined here, not in the header, so the header
// can keep EquationFunctionConstantsProvider as a forward declaration).
Constant IndexExprInferContext::GetDimSize(const Dim& dim) const {
  return constants_provider_->GetDimSize(dim);
}
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <unordered_map>
#include "paddle/cinn/adt/equation.h"
#include "paddle/cinn/adt/equation_value.h"
#include "paddle/cinn/adt/m_expr.h"
namespace cinn::adt {
class EquationFunctionConstantsProvider;
// Context threaded through index-expression equation inference.
// Holds the partial solution as a Variable -> Value map and consults an
// injected EquationFunctionConstantsProvider for symbolic dim sizes.
class IndexExprInferContext final {
 public:
  // Non-copyable, non-movable: passed around by reference during inference.
  IndexExprInferContext(const IndexExprInferContext&) = delete;
  IndexExprInferContext(IndexExprInferContext&&) = delete;
  // `init_variable2value` seeds the solution map; `constants_provider`
  // supplies dim sizes and is kept alive via shared ownership.
  explicit IndexExprInferContext(
      const std::unordered_map<Variable, const Value>& init_variable2value,
      const std::shared_ptr<const EquationFunctionConstantsProvider>&
          constants_provider)
      : variable2value_(init_variable2value),
        constants_provider_(constants_provider) {}
  // Inferred value of `variable`; throws std::out_of_range if unknown.
  const Value& GetValue(const Variable& variable) const {
    return variable2value_.at(variable);
  }
  // Records (variable, value); returns the map's emplace result, so the
  // bool member is false when the variable already had a value (the
  // existing value is NOT overwritten).
  auto SetValue(const Variable& variable, const Value& value) {
    return variable2value_.emplace(variable, value);
  }
  // True iff a value has been recorded for `variable`.
  bool HasValue(const Variable& variable) const {
    return variable2value_.count(variable) > 0;
  }
  // Size of `dim`, resolved by the constants provider (defined out-of-line
  // because the provider type is only forward-declared in this header).
  Constant GetDimSize(const Dim& dim) const;
 private:
  std::unordered_map<Variable, const Value> variable2value_;
  std::shared_ptr<const EquationFunctionConstantsProvider> constants_provider_;
};
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/adt/adt.h"
#include "paddle/cinn/adt/inline_translator_trait.h"
#include "paddle/cinn/adt/m_expr.h"
#include "paddle/cinn/adt/tree.h"
namespace cinn::adt {
// Inlines producer leaves into their consumers inside a tree of op stores.
//   SrcLeaf: Store<Tensor, OpCall<Load<Tensor>>>      (flat op call)
//   DstLeaf: Store<Tensor, Tree<OpCall, Load<Tensor>>> (nested op exprs)
// Within each contiguous run of sibling leaves, a store whose tensor is
// loaded by a later sibling is substituted into that sibling's argument
// list and dropped from the output.
template <template <typename> class MapT,
          template <typename>
          class OpCallT,
          typename TensorT>
struct InlineTranslator final {
  using SrcLeaf = Store<TensorT, OpCallT<Load<TensorT>>>;
  using OpExpr = Tree<OpCallT, Load<TensorT>>;
  using DstLeaf = Store<TensorT, OpExpr>;
  using SrcTree = Tree<MapT, SrcLeaf>;
  using DstTree = Tree<MapT, DstLeaf>;
  // Entry point. The root must be an inner MapT node (CHECKed).
  static DstTree Call(const SrcTree& src_tree) {
    CHECK((src_tree.template Has<MapT<SrcTree>>()));
    const MapT<DstTree> dst_tree =
        CallMap(src_tree.template Get<MapT<SrcTree>>());
    return DstTree{dst_tree};
  }
 private:
  // Converts one inner node: translate its children, then rebuild the node
  // via the trait (which preserves the node's own payload).
  static MapT<DstTree> CallMap(const MapT<SrcTree>& src_map) {
    const List<SrcTree> src_children =
        InlineTranslatorTrait<MapT>::GetTreeInnerNodeChildren(src_map);
    const List<DstTree> dst_children = CallList(src_children);
    return InlineTranslatorTrait<MapT>::ConvertMap(src_map, dst_children);
  }
  // Translates a child list segment by segment: inner-node segments recurse,
  // leaf segments go through the inlining pass.
  static List<DstTree> CallList(const List<SrcTree>& src_children) {
    List<DstTree> ret{};
    VisitEachContiguousSegment(
        src_children, [&](int start, int end, bool is_leaf) {
          if (!is_leaf) {
            for (int i = start; i < end; ++i) {
              ret->emplace_back(Call(src_children->at(i)));
            }
          } else {
            const auto& converted = TranslateContiguousLeaves(
                std::next(src_children->begin(), start),
                std::next(src_children->begin(), end));
            ret->insert(ret->end(), converted->begin(), converted->end());
          }
        });
    return ret;
  }
  // Position of a consuming load: which leaf, and which op-call argument.
  struct ConsumerPos {
    int leaf_index;
    int arg_index;
  };
  // using DstLeaf = Store<TensorT, OpExpr>;
  // Returns `consumer` with its `arg_index`-th argument (a Load of the
  // producer's tensor, CHECKed below) replaced by the producer's op tree.
  static DstLeaf UpdateConsumerArg(const DstLeaf& consumer,
                                   int arg_index,
                                   const DstLeaf& producer) {
    const auto& [consumer_tensor, consumer_tree] = consumer.tuple();
    CheckConsumerPosIsLoadTensor(consumer, arg_index);
    const auto& op_call = consumer_tree.template Get<OpCallT<OpExpr>>();
    const auto& op_call_children =
        InlineTranslatorTrait<OpCallT>::GetTreeInnerNodeChildren(op_call);
    const auto& ret_op_call_children =
        UpdateConsumerArg(op_call_children, arg_index, producer);
    const auto& ret_op_call = InlineTranslatorTrait<OpCallT>::ConvertMap(
        op_call, ret_op_call_children);
    OpExpr ret_op_call_tree = ret_op_call;
    return DstLeaf{consumer_tensor, ret_op_call_tree};
  }
  // Copies the argument list, substituting the producer's expression tree
  // for the load at `arg_index`; the load's tensor must match the
  // producer's stored tensor (CHECKed).
  static List<OpExpr> UpdateConsumerArg(const List<OpExpr>& op_call_children,
                                        int arg_index,
                                        const DstLeaf& producer) {
    const auto& [producer_tensor, producer_tree] = producer.tuple();
    const auto& arg = op_call_children->at(arg_index);
    const auto& arg_leaf = arg.template Get<Load<TensorT>>();
    const auto& [arg_tensor] = arg_leaf.tuple();
    CHECK(producer_tensor == arg_tensor);
    List<OpExpr> ret{};
    ret->assign(op_call_children->begin(), op_call_children->end());
    ret->at(arg_index) = producer_tree;
    return ret;
  }
  // using DstLeaf = Store<TensorT, OpExpr>;
  // Sanity check: the consumer's `arg_index`-th argument is still a plain
  // Load (i.e. nothing has been inlined there yet).
  static void CheckConsumerPosIsLoadTensor(const DstLeaf& consumer,
                                           int arg_index) {
    const auto& [tensor, consumer_tree] = consumer.tuple();
    CHECK((consumer_tree.template Has<OpCallT<OpExpr>>()));
    const auto& op_call = consumer_tree.template Get<OpCallT<OpExpr>>();
    const auto& op_call_children =
        InlineTranslatorTrait<OpCallT>::GetTreeInnerNodeChildren(op_call);
    const auto& op_call_child = op_call_children->at(arg_index);
    CHECK((op_call_child.template Has<Load<TensorT>>()));
  }
  // Visits each loaded tensor of a source leaf together with its argument
  // position. `tree` must be a SrcLeaf.
  template <typename DoEachT>
  static void VisitEachArg(const SrcTree& tree, const DoEachT& DoEach) {
    const auto& [_, op_call] = tree.template Get<SrcLeaf>().tuple();
    const auto& args =
        InlineTranslatorTrait<OpCallT>::GetTreeInnerNodeChildren(op_call);
    for (int i = 0; i < args->size(); ++i) {
      const auto& [tensor] = args->at(i).tuple();
      DoEach(tensor, i);
    }
  }
  // using SrcLeaf = Store<TensorT, OpCallT<Load<TensorT>>>;
  // For each producer leaf in [begin, end), lists the positions of LATER
  // leaves that load the producer's tensor (only forward consumers are
  // considered, so inlining never reorders computation).
  template <typename SrcTreeIterT>
  static std::vector<std::vector<ConsumerPos>> MakeProducerIndex2ConsumerPos(
      SrcTreeIterT begin, SrcTreeIterT end) {
    std::vector<std::vector<ConsumerPos>> producer_index2consumer_positions(
        end - begin);
    for (SrcTreeIterT producer = begin; producer != end; ++producer) {
      const auto& [producer_tensor, _] =
          (*producer).template Get<SrcLeaf>().tuple();
      for (SrcTreeIterT consumer = std::next(producer); consumer != end;
           ++consumer) {
        VisitEachArg(*consumer, [&](const TensorT arg_tensor, int arg_idx) {
          if (arg_tensor == producer_tensor) {
            auto* vec = &producer_index2consumer_positions.at(producer - begin);
            vec->push_back(ConsumerPos{.leaf_index = consumer - begin,
                                       .arg_index = arg_idx});
          }
        });
      }
    }
    return producer_index2consumer_positions;
  }
  // Inlining pass over one contiguous leaf segment:
  //   1) naively translate each leaf,
  //   2) substitute each consumed producer into all its consumers and drop
  //      the producer from the output,
  //   3) collect the surviving leaves in original order.
  template <typename SrcTreeIterT>
  static List<DstTree> TranslateContiguousLeaves(SrcTreeIterT begin,
                                                 SrcTreeIterT end) {
    int size = end - begin;
    const auto producer_idx2consumer_pos =
        MakeProducerIndex2ConsumerPos(begin, end);
    const auto& GetConsumerPos4ProducerIndex =
        [&](int index) -> std::vector<ConsumerPos> {
      return producer_idx2consumer_pos.at(index);
    };
    std::unordered_map<int, DstLeaf> index2dst_leaf{};
    // Init dst leaves
    for (int i = 0; i < size; ++i) {
      CHECK(index2dst_leaf.emplace(i, NaiveTranslateLeaf(*std::next(begin, i)))
                .second);
    }
    // Inline dst leaves
    for (int producer_i = 0; producer_i < size; ++producer_i) {
      const auto& consumer_positions = GetConsumerPos4ProducerIndex(producer_i);
      if (consumer_positions.empty()) {
        // Do nothing
      } else {
        DstLeaf producer = index2dst_leaf.at(producer_i);
        for (const auto& consumer_pos : consumer_positions) {
          DstLeaf consumer = index2dst_leaf.at(consumer_pos.leaf_index);
          index2dst_leaf.at(consumer_pos.leaf_index) =
              UpdateConsumerArg(consumer, consumer_pos.arg_index, producer);
        }
        index2dst_leaf.erase(producer_i);
      }
    }
    // Collect inlined leaves
    List<DstTree> ret{};
    for (int i = 0; i < size; ++i) {
      const auto& iter = index2dst_leaf.find(i);
      if (iter != index2dst_leaf.end()) {
        ret->emplace_back(iter->second);
      }
    }
    return ret;
  }
  // using SrcLeaf = Store<TensorT, OpCallT<Load<TensorT>>>;
  // using DstLeaf = Store<TensorT, OpExpr>;
  // One-to-one translation of a leaf: each Load argument is wrapped into an
  // OpExpr tree node; no inlining happens here.
  static DstLeaf NaiveTranslateLeaf(const SrcTree& src_tree) {
    CHECK(src_tree.template Has<SrcLeaf>());
    const auto& [tensor, op_call] = src_tree.template Get<SrcLeaf>().tuple();
    const List<Load<TensorT>>& src_loads =
        InlineTranslatorTrait<OpCallT>::GetTreeInnerNodeChildren(op_call);
    List<OpExpr> dst_loads{};
    for (const auto& src_load : *src_loads) {
      dst_loads->emplace_back(src_load);
    }
    OpCallT<OpExpr> dst_op_call =
        InlineTranslatorTrait<OpCallT>::ConvertMap(op_call, dst_loads);
    OpExpr dst_op_call_tree = dst_op_call;
    return DstLeaf{tensor, dst_op_call_tree};
  }
  // Partitions src_children into maximal runs that are uniformly leaves or
  // uniformly inner nodes, invoking DoEach(start, end, is_leaf) per run
  // (end exclusive). Empty input yields no calls.
  template <typename DoEachT /*void(&)(int start, int end, bool is_leaf)*/>
  static void VisitEachContiguousSegment(const List<SrcTree>& src_children,
                                         const DoEachT& DoEach) {
    std::vector<int> child_index2is_leaf(src_children->size(), 0);
    for (int i = 0; i < src_children->size(); ++i) {
      child_index2is_leaf.at(i) = src_children->at(i).template Has<SrcLeaf>();
    }
    int start = 0;
    for (int i = 1; i < child_index2is_leaf.size(); ++i) {
      if (child_index2is_leaf.at(i - 1) != child_index2is_leaf.at(i)) {
        DoEach(start, i, child_index2is_leaf.at(i - 1));
        start = i;
      } else {
        // Do nothing
      }
    }
    // Flush the trailing segment (skipped only when the input is empty).
    if (start != child_index2is_leaf.size()) {
      DoEach(start, child_index2is_leaf.size(), child_index2is_leaf.back());
    }
  }
};
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/adt/inline_translator.h"
#include <string>
#include "gtest/gtest.h"
namespace cinn::adt {
// Test-only trait specialization: treats a plain List as a tree inner node
// whose children are its own elements.
template <>
struct InlineTranslatorTrait<List> final {
  // A List inner node's children are the list itself.
  template <typename T>
  static List<T> GetTreeInnerNodeChildren(const List<T>& list) {
    return list;
  }
  // Rebuilds the inner node from converted children; a List carries no
  // payload beyond its children, so the result is just dst_children.
  template <typename SrcTreeT, typename DstTreeT>
  static List<DstTreeT> ConvertMap(const List<SrcTreeT>& src_map,
                                   const List<DstTreeT>& dst_children) {
    return dst_children;
  }
};
namespace test {
using SrcLeaf = Store<std::string, List<Load<std::string>>>;
using DstLeaf = Store<std::string, Tree<List, Load<std::string>>>;
using SrcTree = Tree<List, SrcLeaf>;
using DstTree = Tree<List, DstLeaf>;
// (Tree List (Store string (List (Load string)))) ->
// (Tree List (Store string (Tree List (Load string))))
// Src:
// c = [a, b];
// d = [c, e];
// f = [d, c];
// Dst:
// f = [[[a, b], e], [a, b]];
// Builds the three source stores (c, d, f), runs the translator, and then
// walks the resulting tree asserting that c was inlined into both d and f,
// and d into f, leaving a single dst leaf: f = [[[a, b], e], [a, b]].
TEST(InlineTranslator, Naive) {
  List<Tree<List, SrcLeaf>> src_op_calls{};
  src_op_calls->emplace_back(SrcLeaf{
      "c",
      List<Load<std::string>>{Load<std::string>{"a"}, Load<std::string>{"b"}}});
  src_op_calls->emplace_back(SrcLeaf{
      "d",
      List<Load<std::string>>{Load<std::string>{"c"}, Load<std::string>{"e"}}});
  src_op_calls->emplace_back(SrcLeaf{
      "f",
      List<Load<std::string>>{Load<std::string>{"d"}, Load<std::string>{"c"}}});
  SrcTree src_tree{src_op_calls};
  const DstTree& dst_tree =
      InlineTranslator<List, List, std::string>::Call(src_tree);
  // c and d were consumed and inlined, so only f's leaf survives.
  ASSERT_TRUE((dst_tree.Has<List<Tree<List, DstLeaf>>>()));
  const auto& dst_level0_leaves = dst_tree.Get<List<Tree<List, DstLeaf>>>();
  ASSERT_EQ(dst_level0_leaves->size(), 1);
  const auto& dst_level0_leaf = dst_level0_leaves->at(0);
  ASSERT_TRUE((dst_level0_leaf.Has<DstLeaf>()));
  const auto& [f, f_tree] = dst_level0_leaf.Get<DstLeaf>().tuple();
  ASSERT_EQ(f, "f");
  using NestedList = Tree<List, Load<std::string>>;
  ASSERT_TRUE((f_tree.Has<List<NestedList>>()));
  ASSERT_EQ((f_tree.Get<List<NestedList>>()->size()), 2);
  {
    // [[a, b], e]  — d's body, with c inlined as its first argument.
    NestedList d = f_tree.Get<List<NestedList>>()->at(0);
    ASSERT_TRUE((d.Has<List<NestedList>>()));
    ASSERT_EQ((d.Get<List<NestedList>>()->size()), 2);
    {
      // [a, b]  — c's body.
      NestedList c = d.Get<List<NestedList>>()->at(0);
      ASSERT_TRUE((c.Has<List<NestedList>>()));
      ASSERT_EQ((c.Get<List<NestedList>>()->size()), 2);
      {
        NestedList a = c.Get<List<NestedList>>()->at(0);
        ASSERT_TRUE((a.Has<Load<std::string>>()));
        const auto& [a_string] = a.Get<Load<std::string>>().tuple();
        ASSERT_EQ(a_string, "a");
      }
      {
        NestedList b = c.Get<List<NestedList>>()->at(1);
        ASSERT_TRUE((b.Has<Load<std::string>>()));
        const auto& [b_string] = b.Get<Load<std::string>>().tuple();
        ASSERT_EQ(b_string, "b");
      }
    }
    {
      // e is not produced by any leaf, so it stays a plain load.
      NestedList e = d.Get<List<NestedList>>()->at(1);
      ASSERT_TRUE((e.Has<Load<std::string>>()));
      const auto& [e_string] = e.Get<Load<std::string>>().tuple();
      ASSERT_EQ(e_string, "e");
    }
  }
  {
    // [a, b]  — c inlined a second time, directly into f's second argument.
    NestedList c = f_tree.Get<List<NestedList>>()->at(1);
    ASSERT_TRUE((c.Has<List<NestedList>>()));
    ASSERT_EQ((c.Get<List<NestedList>>()->size()), 2);
    {
      NestedList a = c.Get<List<NestedList>>()->at(0);
      ASSERT_TRUE((a.Has<Load<std::string>>()));
      const auto& [a_string] = a.Get<Load<std::string>>().tuple();
      ASSERT_EQ(a_string, "a");
    }
    {
      NestedList b = c.Get<List<NestedList>>()->at(1);
      ASSERT_TRUE((b.Has<Load<std::string>>()));
      const auto& [b_string] = b.Get<Load<std::string>>().tuple();
      ASSERT_EQ(b_string, "b");
    }
  }
}
} // namespace test
} // namespace cinn::adt
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment