Commit 01a10755 authored by yuguo-Jack's avatar yuguo-Jack
Browse files

2.5.2-dtk24.04

parent 63eb0da5
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/adt/print_utils/print_constant.h"
#include "paddle/cinn/adt/print_utils/print_equations.h"
#include "paddle/cinn/adt/print_utils/print_loop_size.h"
#include "paddle/cinn/adt/print_utils/print_map_expr.h"
#include "paddle/cinn/adt/print_utils/print_schedule_descriptor.h"
#include "paddle/cinn/adt/print_utils/print_schedule_dim.h"
#include "paddle/cinn/adt/print_utils/print_schedule_mesh.h"
#include "paddle/cinn/adt/print_utils/print_value.h"
if(NOT CINN_ONLY)
core_gather_headers()
gather_srcs(
cinnapi_src
SRCS
print_constant.cc
print_equations.cc
print_loop_size.cc
print_map_expr.cc
print_schedule_descriptor.cc
print_schedule_dim.cc
print_schedule_mesh.cc
print_value.cc)
endif()
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/adt/print_utils/print_constant.h"
#include "paddle/cinn/adt/equation_constant.h"
namespace cinn::adt {
namespace {
struct ToTxtStringStruct {
std::string operator()(const std::int64_t constant) {
return std::to_string(constant);
}
std::string operator()(const tDim<UniqueId>& constant) {
std::size_t constant_unique_id = constant.value().unique_id();
return "dim_" + std::to_string(constant_unique_id);
}
std::string operator()(const List<Constant>& constants) {
std::string ret;
ret += "[";
for (std::size_t idx = 0; idx < constants->size(); ++idx) {
if (idx != 0) {
ret += ", ";
}
ret += ToTxtString(constants.Get(idx));
}
ret += "]";
return ret;
}
};
} // namespace
std::string ToTxtString(const Constant& constant) {
return std::visit(ToTxtStringStruct{}, constant.variant());
}
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
namespace cinn::adt {
class Constant;
std::string ToTxtString(const Constant& constant);
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/adt/print_utils/print_equations.h"
#include <sstream>
#include <string>
#include "paddle/cinn/adt/equation_function.h"
#include "paddle/cinn/hlir/framework/pir/utils.h"
#include "paddle/pir/core/operation.h"
namespace cinn::adt {
namespace {
std::string ToTxtString(const tDim<UniqueId>& constant) {
std::size_t constant_unique_id = constant.value().unique_id();
return "dim_" + std::to_string(constant_unique_id);
}
std::string OpImpl(const ::pir::Operation* op) {
return hlir::framework::pir::CompatibleInfo::OpName(*op);
}
std::string OpImpl(const tReduceInit<const ::pir::Operation*>& op) {
return OpImpl(op.value()) + "_init";
}
std::string OpImpl(const tReduceAcc<const ::pir::Operation*>& op) {
return OpImpl(op.value()) + "_acc";
}
} // namespace
std::string ToTxtString(const Iterator& iterator) {
std::size_t iterator_unique_id = iterator.value().unique_id();
return "i_" + std::to_string(iterator_unique_id);
}
std::string ToTxtString(const Index& index) {
std::size_t index_unique_id = index.value().unique_id();
return "idx_" + std::to_string(index_unique_id);
}
std::string ToTxtString(const FakeOpPlaceHolder& op) {
std::size_t op_unique_id = op.value().unique_id();
return "op_" + std::to_string(op_unique_id);
}
std::string ToTxtString(const List<Index>& indexes) {
std::string ret;
ret += "[";
for (std::size_t idx = 0; idx < indexes->size(); ++idx) {
if (idx != 0) {
ret += ", ";
}
ret += ToTxtString(indexes.Get(idx));
}
ret += "]";
return ret;
}
std::string ToTxtString(const List<std::optional<Index>>& indexes) {
std::string ret;
ret += "[";
for (std::size_t idx = 0; idx < indexes->size(); ++idx) {
if (idx != 0) {
ret += ", ";
}
if (indexes->at(idx).has_value()) {
ret += ToTxtString(indexes.Get(idx).value());
}
}
ret += "]";
return ret;
}
std::string ToTxtString(const List<Iterator>& iterators) {
std::string ret;
ret += "[";
for (std::size_t idx = 0; idx < iterators->size(); ++idx) {
if (idx != 0) {
ret += ", ";
}
ret += ToTxtString(iterators.Get(idx));
}
ret += "]";
return ret;
}
std::string ToTxtString(const List<Dim>& dim_list) {
std::string ret;
ret += "[";
for (std::size_t idx = 0; idx < dim_list->size(); ++idx) {
if (idx != 0) {
ret += ", ";
}
ret += ToTxtString(dim_list.Get(idx));
}
ret += "]";
return ret;
}
std::string ToTxtString(const tInMsg<List<Index>>& in_msg_indexes) {
std::string ret;
const List<Index>& index_list = in_msg_indexes.value();
ret += ToTxtString(index_list);
return ret;
}
std::string ToTxtString(const tOutMsg<List<Index>>& out_msg_indexes) {
std::string ret;
const List<Index>& index_list = out_msg_indexes.value();
ret += ToTxtString(index_list);
return ret;
}
std::string ToTxtString(const std::vector<Index>& indexes) {
std::string ret;
ret += "vector(";
for (std::size_t idx = 0; idx < indexes.size(); ++idx) {
if (idx != 0) {
ret += ", ";
}
ret += ToTxtString(indexes.at(idx));
}
ret += ")";
return ret;
}
std::string ToTxtString(const List<OpStmt>& op_stmts,
const EquationCtx4OpStmtT& EquationCtx4OpStmt) {
std::string ret;
std::size_t count = 0;
for (const auto& op_stmt : *op_stmts) {
if (count++ != 0) {
ret += "\n";
}
const auto& [op, _0, _1] = op_stmt.tuple();
ret += std::visit([&](const auto& op_impl) { return OpImpl(op_impl); },
op.variant());
ret += ": \n";
const auto& ctx = EquationCtx4OpStmt(op_stmt);
ret += ToTxtString(ctx->equations(), "\n");
}
return ret;
}
namespace {
struct ToTxtStringStruct {
std::string operator()(
const Identity<tOut<Iterator>, tIn<Iterator>>& id) const {
std::string ret;
const auto& [out_iter, in_iter] = id.tuple();
ret += ToTxtString(out_iter.value()) + " = " + ToTxtString(in_iter.value());
return ret;
}
std::string operator()(const Identity<tOut<Index>, tIn<Index>>& id) const {
std::string ret;
const auto& [out_index, in_index] = id.tuple();
ret +=
ToTxtString(out_index.value()) + " = " + ToTxtString(in_index.value());
return ret;
}
std::string operator()(
const IndexDot<List<Dim>, tOut<Index>, tIn<List<Iterator>>>& dot) const {
std::string ret;
const auto& [dim_list, out_index_tag, in_iterator_list_tag] = dot.tuple();
const Index& out_index = out_index_tag.value();
const List<Iterator>& in_iterator_list = in_iterator_list_tag.value();
ret += ToTxtString(out_index) + " = IndexDot(" +
ToTxtString(in_iterator_list) + ")";
return ret;
}
std::string operator()(
const GetBroadcastedIterator<Dim, tOut<Iterator>, tIn<Iterator>>&
broadcast) const {
std::string ret;
const auto& [dim, out_iterator, in_iterator] = broadcast.tuple();
ret += ToTxtString(out_iterator.value()) + " = GetBroadcastedIterator(" +
ToTxtString(in_iterator.value()) + ", " + ToTxtString(dim) + ")";
return ret;
}
std::string operator()(
const IndexUnDot<List<Dim>, tOut<List<Iterator>>, tIn<Index>>& undot)
const {
std::string ret;
const auto& [dim_list, out_iterator_list_tag, in_index_tag] = undot.tuple();
const List<Iterator>& out_iterator_list = out_iterator_list_tag.value();
const Index& in_index = in_index_tag.value();
ret += ToTxtString(out_iterator_list) + " = IndexUnDot(" +
ToTxtString(in_index) + ")";
return ret;
}
std::string operator()(
const InMsg2OutMsg<tOut<FakeOpPlaceHolder>,
tOut<OpArgIndexes<std::optional<Index>>>,
tIn<OpArgIndexes<Index>>>& in_msg2out_msg) const {
std::string ret;
const auto& [out_op, out_indexs, in_indexs] = in_msg2out_msg.tuple();
const FakeOpPlaceHolder& op = out_op.value();
const auto& out_index_tuple = out_indexs.value();
const auto& in_index_tuple = in_indexs.value();
const auto& [out_msg_list_in, out_msg_list_out] = out_index_tuple.tuple();
const auto& [in_msg_list_in, in_msg_list_out] = in_index_tuple.tuple();
ret += ToTxtString(op) + ", ";
ret += "(" + ToTxtString(out_msg_list_in.value()) + ", " +
ToTxtString(out_msg_list_out.value()) + ") = InMsg2OutMsg(";
ret += ToTxtString(in_msg_list_in.value()) + ", " +
ToTxtString(in_msg_list_out.value()) + ")";
return ret;
}
std::string operator()(
const ConstantFunction<tOut<Iterator>, tIn<Index>>& constant) const {
std::string ret{};
return ret;
}
};
} // namespace
std::string ToTxtString(const Equation& equation) {
return std::visit(ToTxtStringStruct{}, equation.variant());
}
std::string ToTxtString(const Equations& equations,
const std::string& separator) {
std::stringstream ret;
std::size_t count = 0;
for (const auto& equation : *equations) {
if (count++ > 0) {
ret << separator;
}
ret << &equation << ": ";
ret << ToTxtString(equation);
}
return ret.str();
}
std::string ToTxtStringImpl(const Iterator& iterator) {
return ToTxtString(iterator);
}
std::string ToTxtStringImpl(const Index& index) { return ToTxtString(index); }
std::string ToTxtStringImpl(const FakeOpPlaceHolder& op) {
return ToTxtString(op);
}
std::string ToTxtString(const Variable& variable) {
return std::visit([&](const auto& impl) { return ToTxtStringImpl(impl); },
variable.variant());
}
std::string ToDotString(
const Equations& equations,
const std::optional<Variable>& start,
const std::unordered_set<Variable>& visited_variables,
const std::unordered_set<const void*>& visited_functions) {
std::stringstream ss;
const auto& GetFunctionUid = [&](const Equation& equation) {
std::stringstream ss;
ss << "f" << GetFunctionDataPtr(equation);
return ss.str();
};
ss << "digraph {\n";
const auto& FillFunctionColor = [&](const Equation& function) -> std::string {
if (visited_functions.count(GetFunctionDataPtr(function))) {
return ", style=filled, color=green";
} else {
return "";
}
};
std::unordered_set<Variable> variables{};
for (const auto& equation : *equations) {
const auto& [in_variables, out_variables] =
CollectInputAndOutputVariables(equation);
ss << GetFunctionUid(equation) << "["
<< "label=\"" << GetFunctionTypeName(equation) << "<"
<< GetFunctionDataPtr(equation) << ">"
<< "\"" << FillFunctionColor(equation) << "]\n";
for (const auto& in_variable : in_variables) {
ss << ToTxtString(in_variable) << " -> " << GetFunctionUid(equation)
<< ";\n";
variables.insert(in_variable);
}
for (const auto& out_variable : out_variables) {
ss << GetFunctionUid(equation) << " -> " << ToTxtString(out_variable)
<< ";\n";
variables.insert(out_variable);
}
}
const auto& GetColor = [&](const Variable& variable) {
if (start.has_value() && start.value() == variable) {
return "red";
} else {
return "green";
}
};
for (const auto& variable : variables) {
if (visited_variables.count(variable)) {
ss << ToTxtString(variable)
<< "[style=filled, color=" << GetColor(variable) << "];\n";
}
}
ss << "}\n";
return ss.str();
}
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <optional>
#include <vector>
#include "paddle/cinn/adt/equation.h"
#include "paddle/cinn/adt/m_expr.h"
#include "paddle/cinn/adt/partition_op_stmts.h"
namespace cinn::adt {
std::string ToTxtString(const Equation& equation);
std::string ToTxtString(const Equations& equations,
const std::string& separator = "\n");
std::string ToTxtString(const Iterator& iterator);
std::string ToTxtString(const Index& index);
std::string ToTxtString(const FakeOpPlaceHolder& op);
std::string ToTxtString(const List<Index>& indexes);
std::string ToTxtString(const List<std::optional<Index>>& indexes);
std::string ToTxtString(const List<Dim>& strides);
std::string ToTxtString(const List<Iterator>& iterators);
std::string ToTxtString(const tInMsg<List<Index>>& in_msg_indexes_);
std::string ToTxtString(const tOutMsg<List<Index>>& out_msg_indexes_);
std::string ToTxtString(const std::vector<Index>& indexes);
std::string ToTxtString(const List<OpStmt>& op_stmts,
const EquationCtx4OpStmtT& EquationCtx4OpStmt);
std::string ToDotString(
const Equations& equations,
const std::optional<Variable>& start,
const std::unordered_set<Variable>& visited_variables,
const std::unordered_set<const void*>& visited_functions);
std::string ToTxtString(const Variable& variable);
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/adt/print_utils/print_loop_size.h"
#include "paddle/cinn/adt/schedule_descriptor.h"
namespace cinn::adt {
std::string ToTxtString(const LoopSize& loop_size) {
return std::to_string(loop_size.Get<std::int64_t>());
}
std::string ToTxtString(const List<LoopSize>& loop_sizes) {
std::string ret;
ret += "[";
for (std::size_t idx = 0; idx < loop_sizes->size(); ++idx) {
if (idx != 0) {
ret += ", ";
}
ret += ToTxtString(loop_sizes.Get(idx));
}
ret += "]";
return ret;
}
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/adt/adt.h"
namespace cinn::adt {
class LoopSize;
std::string ToTxtString(const LoopSize& loop_size);
std::string ToTxtString(const List<LoopSize>& loop_sizes);
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/cinn/adt/m_expr.h"
#include "paddle/cinn/adt/print_utils/print_equations.h"
#include "paddle/cinn/adt/print_utils/print_map_expr.h"
#include "paddle/cinn/adt/print_utils/print_schedule_descriptor.h"
#include "paddle/cinn/adt/print_utils/print_schedule_mesh.h"
#include "paddle/cinn/adt/print_utils/print_value.h"
#include "paddle/cinn/adt/schedule_descriptor.h"
#include "paddle/cinn/hlir/framework/pir/utils.h"
namespace cinn::adt {
constexpr std::size_t kIndentSpaceSize = 2;
namespace {
std::string GetIndentString(std::size_t space_size) {
std::string ret{};
for (std::size_t i = 0; i < space_size; ++i) {
ret += std::string{" "};
}
return ret;
}
} // namespace
template <typename DoEachT>
void VisitEachArg(const List<Arg>& out_args,
const List<Arg>& in_args,
const DoEachT& DoEach) {
for (const auto& out_arg : *out_args) {
DoEach(out_arg, tOut<bool>{true});
}
for (const auto& in_arg : *in_args) {
DoEach(in_arg, tOut<bool>{false});
}
}
std::string ToTxtString(const Tensor& tensor) {
CHECK(tensor.Has<adapter::Tensor>());
std::string ret;
ret += "t_";
ret += hlir::framework::pir::CompatibleInfo::ValueName(
tensor.Get<adapter::Tensor>().node_data);
return ret;
}
std::string ToTxtString(const List<Arg>& out_args,
const List<Arg>& in_args,
bool with_semicolon,
const AnchoredMapStmt* anchored_map_stmt) {
std::string ret;
ret += "(";
std::size_t count = 0;
VisitEachArg(out_args, in_args, [&](const auto& arg, const auto& as_output) {
if (count++ > 0) {
ret += ", ";
}
if (as_output.value()) {
ret += "&";
}
ret += ToTxtString(arg);
ret += "[";
if (anchored_map_stmt != nullptr) {
ret += ToTxtString(anchored_map_stmt->GetTensorIndexExpr(arg));
}
ret += "]";
});
ret += ")";
if (with_semicolon) {
ret += ";\n";
}
return ret;
}
std::string ToTxtStringOpImpl(const ::pir::Operation* op) {
return hlir::framework::pir::CompatibleInfo::OpName(*op);
}
std::string ToTxtStringOpImpl(const tReduceInit<const ::pir::Operation*>& op) {
return ToTxtStringOpImpl(op.value()) + "_init";
}
std::string ToTxtStringOpImpl(const tReduceAcc<const ::pir::Operation*>& op) {
return ToTxtStringOpImpl(op.value()) + "_acc";
}
std::string ToTxtString(const Op& op) {
return std::visit([&](const auto& impl) { return ToTxtStringOpImpl(impl); },
op.variant());
}
std::string ToTxtStringImpl(const OpStmt& op_stmt,
std::size_t indent_size,
const AnchoredMapStmt* anchored_map_stmt) {
std::string ret;
const auto& [op, in_args, out_args] = op_stmt.tuple();
ret += GetIndentString(indent_size * kIndentSpaceSize);
ret += ToTxtString(op);
ret +=
ToTxtString(out_args.value(), in_args.value(), true, anchored_map_stmt);
return ret;
}
std::string ToTxtString(const OpStmt& op_stmt) {
return ToTxtStringImpl(op_stmt, 0, nullptr);
}
std::string ToTxtString(const LoopDescriptors& schedule_descriptor) {
std::string ret;
std::size_t count = 0;
for (const auto& loop_descriptor : *schedule_descriptor) {
if (count++ > 0) {
ret += ", ";
}
ret += ToTxtString(loop_descriptor);
}
return ret;
}
std::string ToTxtStringImpl(const MapStmt<Stmt>& map_stmt,
std::size_t indent_size,
const AnchoredMapStmt* anchored_map_stmt);
std::string ToTxtString(const Stmt& stmt,
std::size_t indent_size,
const AnchoredMapStmt* anchored_map_stmt) {
std::string ret{""};
ret += std::visit(
[&](const auto& impl) {
return ToTxtStringImpl(impl, indent_size, anchored_map_stmt);
},
stmt.variant());
return ret;
}
std::string ToTxtStringImpl(const MapStmt<Stmt>& map_stmt,
std::size_t indent_size,
const AnchoredMapStmt* anchored_map_stmt) {
std::string ret;
const auto& [loop_iterators, stmts] = map_stmt.tuple();
ret += GetIndentString(indent_size * kIndentSpaceSize) + "MapStmt(";
ret += ToTxtString(loop_iterators);
ret += ") {\n";
for (const auto& stmt : *stmts) {
ret += ToTxtString(stmt, indent_size + 1, anchored_map_stmt);
}
ret += GetIndentString(indent_size * kIndentSpaceSize) + "}\n";
return ret;
}
std::string ToTxtString(const AnchoredMapStmt& anchored_map_stmt,
std::size_t indent_size) {
std::string ret;
const auto& [map_stmt, schedule_mesh, anchor_tensor, _0, _1, _2] =
anchored_map_stmt.tuple();
ret += GetIndentString(indent_size * kIndentSpaceSize) + "AnchoredMapStmt(";
ret += ToTxtString(anchor_tensor.value());
ret += ") {\n";
ret += ToTxtString(map_stmt, indent_size + 1, &anchored_map_stmt);
ret += GetIndentString(indent_size * kIndentSpaceSize) + "}\n";
return ret;
}
std::string ToTxtString(const std::string& group_id, const MapExpr& map_expr) {
std::string ret;
const auto& [anchored_map_stmts, inputs, outputs] = map_expr.tuple();
ret += "\n" + group_id;
ret += ToTxtString(outputs.value(), inputs.value(), false, nullptr);
ret += " {\n";
for (const auto& anchored_map_stmt : *anchored_map_stmts) {
ret += ToTxtString(anchored_map_stmt, 1);
}
ret += "}\n";
return ret;
}
std::string ToTxtString(const MapExpr& map_expr, const std::string& group_id) {
std::string ret;
ret += ToTxtString(group_id, map_expr);
return ret;
}
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace cinn::adt {
class Kernel;
using MapExpr = Kernel;
std::string ToTxtString(const OpStmt& op_stmt);
std::string ToTxtString(const MapExpr& map_expr, const std::string& group_id);
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/adt/print_utils/print_schedule_descriptor.h"
#include "paddle/cinn/adt/schedule_descriptor.h"
namespace cinn::adt {
std::string ToTxtString(const LoopDescriptor& loop_descriptor) {
const auto& [loop_type, loop_size] = loop_descriptor.tuple();
std::string ret{};
auto* string = &ret;
loop_type >>
match{[&](const S0x&) { *string += "blockIdx.x"; },
[&](const S0y&) { *string += "blockIdx.y"; },
[&](const S0z&) { *string += "blockIdx.z"; },
[&](const S1x&) { *string += "threadIdx.x"; },
[&](const S1y&) { *string += "threadIdx.y"; },
[&](const S1z&) { *string += "threadIdx.z"; },
[&](const Temporal& temporal) {
*string += temporal.iter_var_name();
},
[&](const Vectorize& vectorize) {
*string += vectorize.iter_var_name();
},
[&](const Unroll& unroll) { *string += unroll.iter_var_name(); }};
CHECK(loop_size.Has<std::int64_t>());
*string += "=0.." + std::to_string(loop_size.Get<std::int64_t>());
return ret;
}
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
namespace cinn::adt {
class LoopDescriptor;
std::string ToTxtString(const LoopDescriptor& loop_descriptor);
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/adt/print_utils/print_schedule_dim.h"
#include "paddle/cinn/adt/print_utils/print_loop_size.h"
#include "paddle/cinn/adt/schedule_dim.h"
namespace cinn::adt {
namespace {
std::string ToTxtStringScheduleDimImpl(const tReduced<LoopSize>& loop_size) {
return "R(" + ToTxtString(loop_size.value()) + ")";
}
std::string ToTxtStringScheduleDimImpl(const tInjective<LoopSize>& loop_size) {
return "I(" + ToTxtString(loop_size.value()) + ")";
}
} // namespace
std::string ToTxtString(const ScheduleDim& schedule_dim) {
return std::visit(
[&](const auto& impl) { return ToTxtStringScheduleDimImpl(impl); },
schedule_dim.variant());
}
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
namespace cinn::adt {
class ScheduleDim;
std::string ToTxtString(const ScheduleDim& schedule_dim);
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/adt/print_utils/print_schedule_mesh.h"
#include "paddle/cinn/adt/print_utils/print_loop_size.h"
#include "paddle/cinn/adt/print_utils/print_schedule_dim.h"
#include "paddle/cinn/adt/schedule_mesh.h"
namespace cinn::adt {
std::string ToTxtString(const List<int>& ints) {
std::string ret;
ret += "[";
for (std::size_t idx = 0; idx < ints->size(); ++idx) {
if (idx != 0) {
ret += ", ";
}
ret += std::to_string(ints.Get(idx));
}
ret += "]";
return ret;
}
namespace {
std::string ToTxtString(const tMeshDim<List<LoopSize>>& mesh_dim_loop_sizes) {
std::string ret;
ret += "dims=" + ToTxtString(mesh_dim_loop_sizes.value());
return ret;
}
std::string ToTxtString(const tMeshPerm<List<int>>& mesh_perm_ints) {
std::string ret;
ret += "perm=" + ToTxtString(mesh_perm_ints.value());
return ret;
}
std::string ToTxtString(
const tMeshPaddingTo<List<LoopSize>>& mesh_padding_loop_sizes) {
std::string ret;
ret += "padding_to=" + ToTxtString(mesh_padding_loop_sizes.value());
return ret;
}
std::string ToTxtStringMeshImpl(const List<ScheduleDim>& schedule_dims) {
std::string ret;
ret += "[";
for (std::size_t idx = 0; idx < schedule_dims->size(); ++idx) {
if (idx != 0) {
ret += ", ";
}
ret += ToTxtString(schedule_dims.Get(idx));
}
ret += "]";
return ret;
}
std::string ToTxtStringMeshImpl(
const ScheduleMeshReshape<ScheduleMesh>& schedule_mesh_reshape) {
std::string ret;
const auto& [schedule_mesh, loop_sizes] = schedule_mesh_reshape.tuple();
ret += ToTxtString(schedule_mesh);
ret += ".reshape(";
ret += ToTxtString(loop_sizes);
ret += ")";
return ret;
}
std::string ToTxtStringMeshImpl(
const ScheduleMeshTranspose<ScheduleMesh>& schedule_mesh_transpose) {
std::string ret;
const auto& [schedule_mesh, loop_sizes] = schedule_mesh_transpose.tuple();
ret += ToTxtString(schedule_mesh);
ret += ".transpose(";
ret += ToTxtString(loop_sizes);
ret += ")";
return ret;
}
std::string ToTxtStringMeshImpl(
const ScheduleMeshPadding<ScheduleMesh>& schedule_mesh_padding) {
std::string ret;
const auto& [schedule_mesh, loop_sizes] = schedule_mesh_padding.tuple();
ret += ToTxtString(schedule_mesh);
ret += ".padding(";
ret += ToTxtString(loop_sizes);
ret += ")";
return ret;
}
} // namespace
std::string ToTxtString(const ScheduleMesh& schecule_mesh) {
return std::visit([&](const auto& impl) { return ToTxtStringMeshImpl(impl); },
schecule_mesh.variant());
}
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
namespace cinn::adt {
class ScheduleMesh;
std::string ToTxtString(const ScheduleMesh& schecule_mesh);
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/adt/print_utils/print_value.h"
#include "paddle/cinn/adt/equation_value.h"
#include "paddle/cinn/adt/print_utils/print_constant.h"
#include "paddle/cinn/adt/print_utils/print_equations.h"
namespace cinn::adt {
namespace {
std::string ToTxtString(const tPointer<UniqueId>& value) {
std::size_t value_unique_id = value.value().unique_id();
return "ptr_" + std::to_string(value_unique_id);
}
struct ToTxtStringStruct {
std::string operator()(const Undefined& value) { return "undefined"; }
std::string operator()(const Ok& value) { return "ok"; }
std::string operator()(const Iterator& value) {
std::string ret;
ret += ToTxtString(value);
return ret;
}
std::string operator()(const Constant& value) {
std::string ret;
ret += ToTxtString(value);
return ret;
}
std::string operator()(const List<Value>& value_list) {
std::string ret;
ret += "[";
for (std::size_t idx = 0; idx < value_list->size(); ++idx) {
if (idx != 0) {
ret += ", ";
}
ret += ToTxtString(value_list.Get(idx));
}
ret += "]";
return ret;
}
std::string operator()(const IndexDotValue<Value, Constant>& value) {
std::string ret;
const auto& [iters, constant] = value.tuple();
ret +=
"IndexDot(" + ToTxtString(iters) + ", " + ToTxtString(constant) + ")";
return ret;
}
std::string operator()(const IndexUnDotValue<Value, Constant>& value) {
std::string ret;
const auto& [_, constant] = value.tuple();
const Value& value_ = value.GetIndexValue();
ret += "IndexUnDot(" + ToTxtString(value_) + ", " + ToTxtString(constant) +
")";
return ret;
}
std::string operator()(const ListGetItem<Value, Constant>& list_get_item) {
std::string ret;
const auto& [value, constant] = list_get_item.tuple();
ret += "ListGetItem(" + ToTxtString(value) + ", " + ToTxtString(constant) +
")";
return ret;
}
std::string operator()(
const BroadcastedIterator<Value, Constant>& broadcast) {
std::string ret;
const auto& [value, constant] = broadcast.tuple();
ret += "BroadcastedIterator(" + ToTxtString(value) + ", " +
ToTxtString(constant) + ")";
return ret;
}
std::string operator()(const PtrGetItem<Value>& ptr_get_item) {
std::string ret;
const auto& [ptr_tag, value] = ptr_get_item.tuple();
ret +=
"PtrGetItem(" + ToTxtString(ptr_tag) + ", " + ToTxtString(value) + ")";
return ret;
}
};
} // namespace
std::string ToTxtString(const Value& value) {
return std::visit(ToTxtStringStruct{}, value.variant());
}
std::string ToTxtString(const std::optional<Value>& opt_value) {
if (opt_value.has_value()) {
return ToTxtString(opt_value.value());
} else {
return "";
}
}
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <optional>
#include <string>
namespace cinn::adt {
class Value;
std::string ToTxtString(const Value& value);
std::string ToTxtString(const std::optional<Value>& opt_value);
} // namespace cinn::adt
### 一、如何运行
#### 1.1 开启 FLAGS_cinn_enable_map_expr:
```
export FLAGS_cinn_enable_map_expr=true
```
#### 1.2 执行 python 脚本:
```
cd test/cinn/op && python test_relu_expr.py
```
### 二、输出预览
#### 2.1 简单示例:以 Tensor x 为输入,执行 sin 和 relu 两个算子
```
builder = NetBuilder("MapExprTest")
x = builder.create_input(Float(32), inputs["x"].shape, "x")
y = builder.sin(x)
out = builder.relu(y)
```
#### 2.2 输出结果
```
fill_constant_1_sin_0_max_2(&t_var_1, t_x) {
AnchoredMapStmt(t_var_0) {
MapStmt(blockIdx.x=0..1, threadIdx.x=0..64) {
fill_constant(&t_zero);
sin(&t_var_0, t_x);
max(&t_var_1, t_var_0, t_zero);
}
}
}
```
#### 2.3 各字段含义
| 字段 | 含义 |
| :------------ | :------------ |
| fill_constant_1_sin_0_max_2(&t_var_1, t_x) | MapExpr 名称为 fill_constant_1_sin_0_max_2(即当前 group 对应的 group_id),该 MapExpr 以 t_var_1 为输出,t_x 为输入,&为输出 Tensor 标识符|
| AnchoredMapStmt(t_var_0) | 以 t_var_0 为 AnchorTensor 的一系列 Stmt,从 t_var_0 的下标索引可以推断出 Stmt 内所有其他 Tensor 的下标 |
| MapStmt(blockIdx.x=0..1, threadIdx.x=0..64) | MapStmt 内所有 op 遵循如下调度策略:blockIdx.x 的取值为从 0 到 1,threadIdx.x 的取值为从 0 到 64 |
| fill_constant(&t_zero) | fill_constant 算子的输出 Tensor 为 t_zero |
| sin(&t_var_0, t_x) | sin 算子的输出 Tensor 为 t_var_0,输入 Tensor 为 t_x |
| max(&t_var_1, t_var_0, t_zero) | max 算子的输出 Tensor 为 t_var_1,输入 Tensor 为 t_var_0 和 t_zero |
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/adt/schedule_descriptor.h"
#include "paddle/cinn/adt/equation_solver.h"
#include "paddle/cinn/adt/igroup.h"
#include "paddle/cinn/adt/index_expr_infer_context.h"
#include "paddle/cinn/adt/kgroup.h"
#include "paddle/cinn/adt/schedule_dim.h"
namespace cinn::adt {
namespace {
std::vector<int32_t> GetTensorShape(const Tensor& tensor) {
CHECK(tensor.Has<adapter::Tensor>());
return tensor.Get<adapter::Tensor>().GetShape();
}
} // namespace
LoopDescriptors CreateScheduleDescriptor(const ScheduleMesh& sched_mesh,
const List<LoopType>& loop_types) {
const auto& sched_dims = GetOutputDimValues(sched_mesh);
CHECK_EQ(sched_dims->size(), loop_types->size());
LoopDescriptors ret{};
for (std::size_t i = 0; i < sched_dims->size(); ++i) {
const auto& sched_dim = sched_dims->at(i);
CHECK(sched_dim.Has<std::int64_t>());
ret->emplace_back(LoopDescriptor{loop_types->at(i),
LoopSize{sched_dim.Get<std::int64_t>()}});
}
return ret;
}
LoopDescriptors MakeNaiveScheduleDescriptor(
const std::shared_ptr<KGroup>& kgroup,
const std::shared_ptr<IGroup>& igroup) {
const Tensor& tensor = igroup->anchor_tensor();
List<LoopDescriptor> ret{};
const std::vector<int32_t> tensor_shape = GetTensorShape(tensor);
for (int32_t dim : tensor_shape) {
ret->emplace_back(LoopDescriptor{Temporal{}, dim});
}
return ret;
}
List<LoopSize> GenerateLoopSizeFromSd(const LoopDescriptors& sd) {
List<LoopSize> sd_sizes{};
for (const auto& loop_descriptor : *sd) {
const auto& [loop_type, loop_size] = loop_descriptor.tuple();
sd_sizes->emplace_back(loop_size);
}
return sd_sizes;
}
} // namespace cinn::adt
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment