Commit 01a10755 authored by yuguo-Jack's avatar yuguo-Jack
Browse files

2.5.2-dtk24.04

parent 63eb0da5
......@@ -26,7 +26,7 @@
#include "paddle/cinn/hlir/pass/use_pass.h"
#include "paddle/cinn/utils/data_util.h"
DEFINE_string(model_dir, "", "");
PD_DEFINE_string(model_dir, "", "");
namespace cinn {
namespace frontend {
......@@ -66,7 +66,8 @@ TEST(batch_norm_meta, batch_norm_meta) {
auto scope = BuildScope(target, graph);
LOG(INFO) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......@@ -104,7 +105,8 @@ TEST(reduction, reduce) {
auto scope = BuildScope(target, graph);
LOG(INFO) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......@@ -136,7 +138,8 @@ TEST(Compare, Compare) {
auto scope = BuildScope(target, graph);
LOG(INFO) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
scope->Var<hlir::framework::Tensor>("A");
......
......@@ -16,6 +16,10 @@ gather_srcs(
transform.cc
vision.cc)
if(NOT CINN_ONLY)
gather_srcs(cinnapi_src SRCS map_expr_to_ir.cc)
endif()
cinn_cc_test(test_cinn_pe_elementwise SRCS pe_elementwise_test.cc DEPS cinncore)
cinn_cc_test(test_cinn_pe_broadcast SRCS pe_broadcast_test.cc DEPS cinncore)
cinn_cc_test(test_cinn_pe_transform SRCS pe_transform_test.cc DEPS cinncore)
......
......@@ -31,14 +31,39 @@
#include "paddle/cinn/hlir/pe/schedule.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/optim/ir_simplify.h"
#include "paddle/cinn/optim/replace_var_with_expr.h"
#include "paddle/cinn/poly/isl_utils.h"
#include "paddle/cinn/utils/string.h"
PD_DECLARE_bool(cinn_new_group_scheduler);
namespace cinn {
namespace hlir {
namespace pe {
// Marks as reduce axes the iter_vars of `block` whose binding values reference
// the loop variable of `loop`.
//
// `loop` must be an ir::For node and `block` an ir::ScheduleBlockRealize.
// For each (iter_var, iter_value) binding pair of the schedule block, if the
// bound value expression contains the For's loop variable (matched by name),
// the corresponding iter_var gets `is_reduce_axis = true`.
void SetReduceAxis(ir::Expr loop, ir::Expr block) {
  std::string var_name = loop.As<ir::For>()->loop_var->name;
  std::vector<ir::Var> iter_vars = block.As<ir::ScheduleBlockRealize>()
                                       ->schedule_block.As<ir::ScheduleBlock>()
                                       ->iter_vars;
  std::vector<ir::Expr> iter_values =
      block.As<ir::ScheduleBlockRealize>()->iter_values;
  // iter_vars and iter_values are parallel arrays: value i binds var i.
  CHECK_EQ(iter_vars.size(), iter_values.size());
  for (int i = 0; i < iter_values.size(); ++i) {
    // Collect any variable node inside the binding value whose name matches
    // the loop variable; a non-empty result means this axis iterates the loop.
    std::set<Expr> contains = ir::ir_utils::CollectIRNodesWithoutTensor(
        iter_values[i],
        [&var_name](const Expr *expr) {
          return expr->As<ir::_Var_>() != nullptr &&
                 expr->As<ir::_Var_>()->name == var_name;
        },
        true);
    if (!contains.empty()) {
      iter_vars[i]->is_reduce_axis = true;
    }
  }
}
void IRElementwiseSchedule(ir::IRSchedule &ir_sch, // NOLINT
const std::vector<int> &output_shape,
const common::Target &target) {
......@@ -46,15 +71,15 @@ void IRElementwiseSchedule(ir::IRSchedule &ir_sch, // NOLINT
<< ir_sch.GetModule().GetExprs().at(0);
if (target == common::DefaultNVGPUTarget()) {
auto blocks = ir_sch.GetAllBlocks();
ir_sch.FlattenLoops(ir_sch.GetLoops(blocks[0]), true);
std::vector<ir::Expr> loops = ir_sch.GetLoops(blocks[0]);
ir::Expr loop = ir_sch.Fuse(loops);
auto loops = ir_sch.GetLoops(blocks[0]);
auto size = std::accumulate(
output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
if (size <= target.max_num_threads()) {
ir_sch.Bind(loops[0], "threadIdx.x");
ir_sch.Bind(loop, "threadIdx.x");
} else {
auto splited = ir_sch.Split(loops[0], {-1, target.max_num_threads()});
auto splited = ir_sch.Split(loop, {-1, target.max_num_threads()});
ir_sch.Bind(splited[0], "blockIdx.x");
ir_sch.Bind(splited[1], "threadIdx.x");
}
......@@ -74,15 +99,15 @@ void IRInjectiveSchedule(ir::IRSchedule &ir_sch, // NOLINT
<< ir_sch.GetModule().GetExprs().at(0);
if (target == common::DefaultNVGPUTarget()) {
auto blocks = ir_sch.GetAllBlocks();
ir_sch.FlattenLoops(ir_sch.GetLoops(blocks[0]), false);
std::vector<ir::Expr> loops = ir_sch.GetLoops(blocks[0]);
ir::Expr loop = ir_sch.Fuse(loops);
auto loops = ir_sch.GetLoops(blocks[0]);
auto size = std::accumulate(
output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
if (size <= target.max_num_threads()) {
ir_sch.Bind(loops[0], "threadIdx.x");
ir_sch.Bind(loop, "threadIdx.x");
} else {
auto splited = ir_sch.Split(loops[0], {-1, target.max_num_threads()});
auto splited = ir_sch.Split(loop, {-1, target.max_num_threads()});
ir_sch.Bind(splited[0], "blockIdx.x");
ir_sch.Bind(splited[1], "threadIdx.x");
}
......@@ -457,9 +482,15 @@ void IRCudaScheduleBlockReduceInternal(ir::IRSchedule &ir_sch, // NOLINT
if (loops_tmp_out.size() == 1) {
ir_sch.Bind(loops_tmp_out[0], "threadIdx.x");
ir_sch.Bind(loops_out[0], "threadIdx.x");
if (FLAGS_cinn_new_group_scheduler) {
SetReduceAxis(loops_tmp_out[0], ir_sch.GetBlock(tmp_out->name));
}
} else {
ir_sch.Bind(loops_tmp_out[0], "blockIdx.x");
ir_sch.Bind(loops_tmp_out[1], "threadIdx.x");
if (FLAGS_cinn_new_group_scheduler) {
SetReduceAxis(loops_tmp_out[1], ir_sch.GetBlock(tmp_out->name));
}
if (loops_out.size() == 1) {
ir_sch.Split(loops_out[0], {-1, 1});
......@@ -471,8 +502,12 @@ void IRCudaScheduleBlockReduceInternal(ir::IRSchedule &ir_sch, // NOLINT
for (auto &tensor : {tmp_out}) {
auto block = ir_sch.GetBlock(tensor->name);
if (FLAGS_cinn_new_group_scheduler) {
ir_sch.SetBuffer(block, "local");
} else {
ir_sch.SetBuffer(block, "local", true);
}
}
VLOG(3) << "After IRCudaScheduleBlockReduceInternal : "
<< ir_sch.GetModule().GetExprs().at(0);
......@@ -600,6 +635,9 @@ void IRCudaScheduleBlockReduce(ir::IRSchedule &ir_sch, // NOLINT
ir_sch.Bind(loops[0], "blockIdx.x");
ir_sch.Bind(loops[1], "threadIdx.x");
if (FLAGS_cinn_new_group_scheduler) {
SetReduceAxis(loops[1], ir_sch.GetBlock(tmp_out->name));
}
}
// out
{
......@@ -614,8 +652,12 @@ void IRCudaScheduleBlockReduce(ir::IRSchedule &ir_sch, // NOLINT
for (auto &tensor : {reduce_tmp_out, tmp_out}) {
auto block = ir_sch.GetBlock(tensor->name);
if (FLAGS_cinn_new_group_scheduler) {
ir_sch.SetBuffer(block, "local");
} else {
ir_sch.SetBuffer(block, "local", true);
}
}
VLOG(3) << "After IRCudaScheduleBlockReduce : "
<< ir_sch.GetModule().GetExprs().at(0);
......@@ -633,7 +675,7 @@ void IRCudaScheduleBlockShuffleReduce(ir::IRSchedule &ir_sch, // NOLINT
// simplify reshape index
auto hand_write_simplify = [](std::vector<ir::Expr> loops, ir::Expr block) {
// check exist select.
auto find_select = ir::CollectIRNodesInOrder(
auto find_select = ir::ir_utils::CollectIRNodesInOrder(
block, [&](const Expr *x) { return x->As<ir::Select>(); });
if (find_select.size() > 0) {
return;
......@@ -667,14 +709,16 @@ void IRCudaScheduleBlockShuffleReduce(ir::IRSchedule &ir_sch, // NOLINT
index = index + ir::Expr(schedule_block->iter_vars[idx]) * stride;
}
auto exprs = ir::CollectIRNodesInOrder(
auto exprs = ir::ir_utils::CollectIRNodesInOrder(
block, [&](const Expr *x) { return x->As<ir::Load>(); });
CHECK_EQ(exprs.size(), 1);
auto load = exprs.front().As<ir::Load>();
load->indices = {index};
};
if (!FLAGS_cinn_new_group_scheduler) {
hand_write_simplify(ir_sch.GetLoops(reshape->name),
ir_sch.GetBlock(reshape->name));
}
auto block = ir_sch.GetBlock(reshape->name);
ir_sch.ComputeInline(block);
VLOG(4) << "After simplify reshape index : "
......@@ -709,7 +753,7 @@ void IRCudaScheduleBlockShuffleReduce(ir::IRSchedule &ir_sch, // NOLINT
break;
}
auto exprs = ir::CollectIRNodesInOrder(
auto exprs = ir::ir_utils::CollectIRNodesInOrder(
block, [&](const Expr *x) { return x->As<ir::Load>(); });
for (auto expr : exprs) {
auto load = expr.As<ir::Load>();
......@@ -955,10 +999,14 @@ void IRCudaTwoStepReduceSchedule(ir::IRSchedule &ir_sch, // NOLINT
ir_sch.ComputeInline(reshape_block);
auto internal_block = ir_sch.GetBlock(internal->name);
ir_sch.SetBuffer(internal_block, "local", true);
auto tmp_out_block = ir_sch.GetBlock(tmp_out->name);
if (FLAGS_cinn_new_group_scheduler) {
ir_sch.SetBuffer(internal_block, "local");
ir_sch.SetBuffer(tmp_out_block, "local");
} else {
ir_sch.SetBuffer(internal_block, "local", true);
ir_sch.SetBuffer(tmp_out_block, "local", true);
}
// The current one-dimensional reduce does not make full use of SM.
// This case is optimized into a two-dimensional.
......@@ -978,9 +1026,15 @@ void IRCudaTwoStepReduceSchedule(ir::IRSchedule &ir_sch, // NOLINT
ir_sch.Bind(loops[0], "blockIdx.x");
ir_sch.Bind(loops[1], "threadIdx.y");
ir_sch.Bind(loops[2], "threadIdx.x");
if (FLAGS_cinn_new_group_scheduler && tensor->name == tmp_out->name) {
SetReduceAxis(loops[2], ir_sch.GetBlock(tmp_out->name));
}
} else {
ir_sch.Bind(loops[0], "blockIdx.x");
ir_sch.Bind(loops[1], "threadIdx.x");
if (FLAGS_cinn_new_group_scheduler && tensor->name == tmp_out->name) {
SetReduceAxis(loops[1], ir_sch.GetBlock(tmp_out->name));
}
}
}
VLOG(3) << "After IRCudaTwoStepReduceSchedule : "
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/hlir/pe/map_expr_to_ir.h"
#include <unordered_map>
#include <vector>
#include "paddle/cinn/adt/equation_value_match_trait.h"
#include "paddle/cinn/adt/inline_translator.h"
#include "paddle/cinn/adt/m_expr.h"
#include "paddle/cinn/adt/map_expr_ctx.h"
#include "paddle/cinn/adt/match.h"
#include "paddle/cinn/adt/no_inline_translator.h"
#include "paddle/cinn/adt/print.h"
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/runtime/flags.h"
PD_DECLARE_bool(cinn_enable_map_expr_inline);
namespace cinn::adt {
namespace {
// Integer type used for loop bounds and iterator literals in the emitted IR.
using IteratorInt = std::int32_t;
// Maps each pir operation to the lowered functions generated for it.
using Node2LoweredFuncs =
    std::unordered_map<::pir::Operation*, std::vector<ir::LoweredFunc>>;
// Getter: the list of iterator expressions describing how a tensor is indexed.
using TensorIteratorExpr4TensorT =
    std::function<adt::List<adt::TensorIteratorExpr>(const adt::Tensor&)>;
// Getter: the concrete ir::Expr indices to use for a tensor access.
using IterExprs4TensorT =
    std::function<std::vector<ir::Expr>(const adt::Tensor&)>;
// Getter: the loop descriptor (type/size) associated with a loop iterator.
using LoopDescriptor4LoopIteratorT =
    std::function<adt::LoopDescriptor(const adt::Iterator&)>;
class MapExprToIrTranslator {
public:
explicit MapExprToIrTranslator(const MapExpr& map_expr,
const Node2LoweredFuncs& node2lowered_funcs,
const common::Target& target)
: map_expr_(map_expr),
node2lowered_funcs_(&node2lowered_funcs),
target_(target) {
const auto& [anchored_map_stmts, _0, _1] = map_expr.tuple();
CHECK_EQ(anchored_map_stmts->size(), 1);
TensorIteratorExpr4Tensor = std::get<4>(anchored_map_stmts->at(0).tuple());
LoopDescriptor4LoopIterator =
std::get<5>(anchored_map_stmts->at(0).tuple());
}
ir::Expr Translate() const {
VLOG(1) << "Translate MapExpr: ";
VLOG(1) << ToTxtString(map_expr_, "");
return ir::Block::Make({Translate(map_expr_)});
}
private:
ir::Expr GetStoreExprForOp(const ::pir::Operation* op) const {
const auto& iter =
node2lowered_funcs_->find(const_cast<::pir::Operation*>(op));
CHECK(iter != node2lowered_funcs_->end());
const auto& lowered_funcs = iter->second;
CHECK_EQ(lowered_funcs.size(), 1);
std::optional<ir::Expr> ret{std::nullopt};
VisitEachStoreExpr(lowered_funcs.at(0), [&](const ir::Expr& expr) {
CHECK(!ret.has_value());
ret = expr;
});
CHECK(ret.has_value());
return ret.value();
}
ir::Expr GetStoreExprForOp(
const tReduceInit<const ::pir::Operation*>& op) const {
const auto& iter =
node2lowered_funcs_->find(const_cast<::pir::Operation*>(op.value()));
CHECK(iter != node2lowered_funcs_->end());
const auto& lowered_funcs = iter->second;
CHECK_EQ(lowered_funcs.size(), 1);
std::vector<ir::Expr> stores{};
VisitEachStoreExpr(lowered_funcs.at(0), [&](const ir::Expr& expr) {
stores.emplace_back(expr);
});
CHECK_EQ(stores.size(), 2);
return stores.at(0);
}
ir::Expr GetStoreExprForOp(
const tReduceAcc<const ::pir::Operation*>& op) const {
const auto& iter =
node2lowered_funcs_->find(const_cast<::pir::Operation*>(op.value()));
CHECK(iter != node2lowered_funcs_->end());
const auto& lowered_funcs = iter->second;
CHECK_EQ(lowered_funcs.size(), 1);
std::vector<ir::Expr> stores{};
VisitEachStoreExpr(lowered_funcs.at(0), [&](const ir::Expr& expr) {
stores.emplace_back(expr);
});
CHECK_EQ(stores.size(), 2);
return stores.at(1);
}
std::optional<ir::Expr> GetStoreExprImpl(
const OpCall<OpExpr>& op_expr) const {
const auto& [op, _] = op_expr.tuple();
return std::visit([&](const auto& impl) { return GetStoreExprForOp(impl); },
op.variant());
}
std::optional<ir::Expr> GetStoreExprImpl(const Load<Tensor>& load) const {
return std::nullopt;
}
// using OpExpr = Tree<OpCall, Load<Tensor>>;
std::optional<ir::Expr> GetStoreExpr(const OpExpr& op_expr) const {
return std::visit([&](const auto& impl) { return GetStoreExprImpl(impl); },
op_expr.variant());
}
template <typename DoEachT>
void VisitEachStoreExpr(const ir::Expr& expr, const DoEachT& DoEach) const {
switch (expr.node_type()) {
case ir::IrNodeTy::_LoweredFunc_:
VisitEachStoreExpr(expr.as_lowered_func()->body, DoEach);
break;
case ir::IrNodeTy::Block:
for (const auto& stmt : expr.As<ir::Block>()->stmts) {
VisitEachStoreExpr(stmt, DoEach);
}
break;
case ir::IrNodeTy::ScheduleBlockRealize:
VisitEachStoreExpr(expr.As<ir::ScheduleBlockRealize>()->schedule_block,
DoEach);
break;
case ir::IrNodeTy::ScheduleBlock:
VisitEachStoreExpr(expr.As<ir::ScheduleBlock>()->body, DoEach);
break;
case ir::IrNodeTy::For:
VisitEachStoreExpr(expr.As<ir::For>()->body, DoEach);
break;
case ir::IrNodeTy::Store:
DoEach(expr);
break;
default:
LOG(FATAL) << "Visit node_type = " << expr.node_type()
<< ", not supported!";
break;
}
}
template <typename DoEachT>
void VisitEachStmt(const List<Stmt>& stmts, const DoEachT& DoEach) const {
for (const auto& stmt : *stmts) {
DoEach(stmt);
}
}
ir::Expr Translate(const MapExpr& map_expr) const {
const auto& [anchored_map_stmts, _0, _1] = map_expr.tuple();
CHECK_EQ(anchored_map_stmts->size(), 1);
return Translate(anchored_map_stmts->at(0));
}
ir::Expr Translate(const AnchoredMapStmt& anchored_map_stmt) const {
const MapStmt<Stmt>& map_stmt = std::get<0>(anchored_map_stmt.tuple());
ir::Expr ret = Translate(map_stmt);
ret = ir::ScheduleBlock::Make({}, {}, {}, "root", ret);
ret = ir::ScheduleBlockRealize::Make({}, ret);
return ret;
}
using InternalLeafStmt = Store<Tensor, OpCall<Load<Tensor>>>;
using InternalStmt = Tree<MapStmt, InternalLeafStmt>;
InternalStmt ConvertToInternalStmtImpl(const OpStmt& op_stmt) const {
const auto& [op, inputs, outputs] = op_stmt.tuple();
CHECK_EQ(outputs.value()->size(), 1);
List<Load<Tensor>> loads{};
for (const auto& in : *inputs.value()) {
loads->emplace_back(Load<Tensor>{in});
}
OpCall<Load<Tensor>> op_call{op, loads};
return InternalLeafStmt{outputs.value()->at(0), op_call};
}
InternalStmt ConvertToInternalStmtImpl(const MapStmt<Stmt>& map_stmt) const {
const auto& [iterators, stmts] = map_stmt.tuple();
List<InternalStmt> children{};
for (const auto& stmt : *stmts) {
children->emplace_back(ConvertToInternalStmt(stmt));
}
return MapStmt<InternalStmt>{iterators, children};
}
InternalStmt ConvertToInternalStmt(const Stmt& stmt) const {
return std::visit(
[&](const auto& impl) { return ConvertToInternalStmtImpl(impl); },
stmt.variant());
}
InlineStmt ConvertToInlineStmt(const InternalStmt& internal_stmt) const {
if (FLAGS_cinn_enable_map_expr_inline) {
return InlineTranslator<MapStmt, OpCall, Tensor>::Call(internal_stmt);
} else {
return NoInlineTranslator<MapStmt, OpCall, Tensor>::Call(internal_stmt);
}
LOG(FATAL) << "Dead code";
}
std::optional<ir::Expr> TranslateOpExprImpl(
const Load<Tensor>& load,
const std::optional<Tensor>& opt_output_tensor,
const IterExprs4TensorT& IterExprs4Tensor) const {
return std::nullopt;
}
std::vector<ir::Expr> TranslateTensorIndexImpl(
const OpCall<OpExpr>& op_call,
const IterExprs4TensorT& IterExprs4Tensor) const {
LOG(FATAL) << "Dead code, no TensorIndexExpr for OpCall";
}
std::vector<ir::Expr> TranslateTensorIndexImpl(
const Load<Tensor>& op_call,
const IterExprs4TensorT& IterExprs4Tensor) const {
const auto& [tensor] = op_call.tuple();
return IterExprs4Tensor(tensor);
}
// using OpExpr = Tree<OpCall, Load<Tensor>>;
std::vector<ir::Expr> TranslateTensorIndex(
const OpExpr& op_expr, const IterExprs4TensorT& IterExprs4Tensor) const {
return std::visit(
[&](const auto& impl) {
return TranslateTensorIndexImpl(impl, IterExprs4Tensor);
},
op_expr.variant());
}
std::optional<ir::Expr> MakeLoadExpr(
const ir::Expr& input_expr,
const List<OpExpr>& op_expr_children,
const IterExprs4TensorT& IterExprs4Tensor) const {
ir::Expr store_rvalue = ir::ir_utils::IRCopy(input_expr);
CHECK_EQ(store_rvalue->operands.size(), 0);
CHECK_EQ(op_expr_children->size(), 1);
store_rvalue.As<ir::Load>()->indices =
TranslateTensorIndex(op_expr_children->at(0), IterExprs4Tensor);
return store_rvalue;
}
std::optional<ir::Expr> MakeCallExpr(
const ir::Expr& input_expr,
const List<OpExpr>& op_expr_children,
const IterExprs4TensorT& IterExprs4Tensor) const {
ir::Expr store_rvalue = ir::ir_utils::IRCopy(input_expr);
CHECK_EQ(store_rvalue->operands.size(), 0);
CHECK(!op_expr_children->empty());
CHECK_EQ((store_rvalue.As<ir::Call>()->read_args.size()),
(op_expr_children->size()));
for (int i = 0; i < op_expr_children->size(); ++i) {
const auto& opt_operant = TranslateOpExpr(
op_expr_children->at(i), std::nullopt, IterExprs4Tensor);
if (opt_operant.has_value()) {
store_rvalue.As<ir::Call>()->read_args.at(i) = opt_operant.value();
} else {
store_rvalue.As<ir::Call>()->read_args.at(i).As<ir::Load>()->indices =
TranslateTensorIndex(op_expr_children->at(i), IterExprs4Tensor);
}
}
return store_rvalue;
}
std::optional<ir::Expr> MakeGeneralExpr(
const ir::Expr& input_expr,
const List<OpExpr>& op_expr_children,
const IterExprs4TensorT& IterExprs4Tensor) const {
ir::Expr store_rvalue = ir::ir_utils::IRCopy(input_expr);
CHECK_EQ(store_rvalue->operands.size(), op_expr_children->size());
for (int i = 0; i < op_expr_children->size(); ++i) {
const auto& opt_operant = TranslateOpExpr(
op_expr_children->at(i), std::nullopt, IterExprs4Tensor);
if (opt_operant.has_value()) {
store_rvalue->operands.at(i) = opt_operant.value();
} else {
store_rvalue->operands.at(i).As<ir::Load>()->indices =
TranslateTensorIndex(op_expr_children->at(i), IterExprs4Tensor);
}
}
return store_rvalue;
}
std::optional<ir::Expr> TranslateOpCallImpl(
const ::pir::Operation*,
const OpCall<OpExpr>& op_expr,
const std::optional<Tensor>& opt_output_tensor,
const IterExprs4TensorT& IterExprs4Tensor) const {
const auto& [_, op_expr_children] = op_expr.tuple();
std::optional<ir::Expr> store_expr = GetStoreExpr(op_expr);
CHECK(store_expr.has_value());
ir::Expr store_rvalue = store_expr.value().As<ir::Store>()->value;
if (store_rvalue.As<ir::Load>()) {
return MakeLoadExpr(store_rvalue, op_expr_children, IterExprs4Tensor);
} else if (store_rvalue.As<ir::Call>()) {
return MakeCallExpr(store_rvalue, op_expr_children, IterExprs4Tensor);
} else {
if (!op_expr_children->empty()) {
return MakeGeneralExpr(
store_rvalue, op_expr_children, IterExprs4Tensor);
} else {
// Do nothing
}
}
return store_rvalue;
}
std::optional<ir::Expr> TranslateOpCallImpl(
const tReduceInit<const ::pir::Operation*>&,
const OpCall<OpExpr>& op_expr,
const std::optional<Tensor>& opt_output_tensor,
const IterExprs4TensorT& IterExprs4Tensor) const {
const auto& [_, op_expr_children] = op_expr.tuple();
std::optional<ir::Expr> store_expr = GetStoreExpr(op_expr);
CHECK(store_expr.has_value());
ir::Expr store_rvalue = store_expr.value().As<ir::Store>()->value;
VLOG(1) << "tReduceInit store_rvalue:\n" << store_rvalue;
CHECK_EQ(store_rvalue->operands.size(), 0);
CHECK_EQ(op_expr_children->size(), 0);
return store_rvalue;
}
std::optional<ir::Expr> TranslateOpCallImpl(
const tReduceAcc<const ::pir::Operation*>&,
const OpCall<OpExpr>& op_expr,
const std::optional<Tensor>& opt_output_tensor,
const IterExprs4TensorT& IterExprs4Tensor) const {
const auto& [_, op_expr_children] = op_expr.tuple();
std::optional<ir::Expr> store_expr = GetStoreExpr(op_expr);
CHECK(store_expr.has_value());
ir::Expr store_rvalue = store_expr.value().As<ir::Store>()->value;
VLOG(1) << "tReduceAcc store_rvalue:\n" << store_rvalue;
CHECK_EQ(store_rvalue->operands.size(), 2);
CHECK_EQ(op_expr_children->size(), 1);
CHECK(opt_output_tensor.has_value());
store_rvalue->operands.at(0).As<ir::Load>()->indices =
IterExprs4Tensor(opt_output_tensor.value());
{
const auto& opt_operant = TranslateOpExpr(
op_expr_children->at(0), std::nullopt, IterExprs4Tensor);
if (opt_operant.has_value()) {
store_rvalue->operands.at(1) = opt_operant.value();
} else {
store_rvalue->operands.at(1).As<ir::Load>()->indices =
TranslateTensorIndex(op_expr_children->at(0), IterExprs4Tensor);
}
}
return store_rvalue;
}
std::optional<ir::Expr> TranslateOpExprImpl(
const OpCall<OpExpr>& op_expr,
const std::optional<Tensor>& opt_output_tensor,
const IterExprs4TensorT& IterExprs4Tensor) const {
const auto& [op, op_expr_children] = op_expr.tuple();
return std::visit(
[&](const auto& impl) {
return TranslateOpCallImpl(
impl, op_expr, opt_output_tensor, IterExprs4Tensor);
},
op.variant());
}
// using OpExpr = Tree<OpCall, Load<Tensor>>;
std::optional<ir::Expr> TranslateOpExpr(
const OpExpr& op_expr,
const std::optional<Tensor>& opt_output_tensor,
const IterExprs4TensorT& IterExprs4Tensor) const {
return std::visit(
[&](const auto& impl) {
return TranslateOpExprImpl(impl, opt_output_tensor, IterExprs4Tensor);
},
op_expr.variant());
}
ir::Expr TranslateImpl(const OpExprStmt& op_expr_stmt) const {
return Translate(op_expr_stmt);
}
template <typename DoEachT /*void(&)(const Value&)*/>
void VisitEachIteratorValue(const Tensor& tensor,
const DoEachT& DoEach) const {
const List<Value>& iterator_values = TensorIteratorExpr4Tensor(tensor);
for (const auto& iterator_value : *iterator_values) {
DoEach(iterator_value);
}
}
template <typename DoEachT /*void(&)(const Value&)*/>
void VisitEachIteratorValueImpl(const OpCall<OpExpr>& op_call,
const DoEachT& DoEach) const {
const auto& [_, children] = op_call.tuple();
for (const auto& child : *children) {
VisitEachIteratorValue(child, DoEach);
}
}
template <typename DoEachT /*void(&)(const Value&)*/>
void VisitEachIteratorValueImpl(const Load<Tensor>& load,
const DoEachT& DoEach) const {
const auto& [tensor] = load.tuple();
VisitEachIteratorValue(tensor, DoEach);
}
template <typename DoEachT /*void(&)(const Value&)*/>
void VisitEachIteratorValue(const OpExpr& op_expr,
const DoEachT& DoEach) const {
return std::visit(
[&](const auto& impl) {
return VisitEachIteratorValueImpl(impl, DoEach);
},
op_expr.variant());
}
template <typename DoEachT /*void(&)(const Value&)*/>
void VisitEachIteratorValue(const OpExprStmt& op_expr_stmt,
const DoEachT& DoEach) const {
const auto& [tensor, op_expr] = op_expr_stmt.tuple();
VisitEachIteratorValue(tensor, DoEach);
VisitEachIteratorValue(op_expr, DoEach);
}
IterExprs4TensorT MakeGetterIterExprs4Tensor(
const OpExprStmt& op_expr_stmt,
std::vector<std::pair<ir::Var, ir::Expr>>* binding_var2value) const {
std::unordered_map<Value, std::pair<ir::Var, ir::Expr>> value2var_expr{};
VisitEachIteratorValue(op_expr_stmt, [&](const Value& value) {
if (value2var_expr.count(value) == 0) {
ir::Var var{std::string("m_expr_i_") +
std::to_string(UniqueId::New().unique_id())};
ir::Expr expr = TranslateTensorIterator(value);
CHECK(value2var_expr.emplace(value, std::make_pair(var, expr)).second);
} else {
// Do nothing
}
});
for (const auto& [_, pair] : value2var_expr) {
binding_var2value->push_back(pair);
}
return [value2var_expr, this](const Tensor& tensor) {
const List<Value>& iterator_values = TensorIteratorExpr4Tensor(tensor);
std::vector<ir::Expr> ret{};
ret.reserve(iterator_values->size());
for (const auto& iterator_value : *iterator_values) {
const auto& it = value2var_expr.find(iterator_value);
CHECK(it != value2var_expr.end());
ret.emplace_back(it->second.first);
}
return ret;
};
}
std::vector<ir::Var> GetVectorOfPairFirst(
const std::vector<std::pair<ir::Var, ir::Expr>>& pairs) const {
std::vector<ir::Var> ret{};
ret.reserve(pairs.size());
for (const auto& pair : pairs) {
ret.emplace_back(pair.first);
}
return ret;
}
std::vector<ir::Expr> GetVectorOfPairSecond(
const std::vector<std::pair<ir::Var, ir::Expr>>& pairs) const {
std::vector<ir::Expr> ret{};
ret.reserve(pairs.size());
for (const auto& pair : pairs) {
ret.emplace_back(pair.second);
}
return ret;
}
// using OpExprStmt = Store<Tensor, OpExpr>;
ir::Expr Translate(const OpExprStmt& op_expr_stmt) const {
const auto& [output_tensor, op_expr] = op_expr_stmt.tuple();
std::optional<ir::Expr> store_expr = GetStoreExpr(op_expr);
CHECK(store_expr.has_value());
std::optional<Tensor> opt_output_tensor = output_tensor;
std::vector<std::pair<ir::Var, ir::Expr>> binding_var2value{};
const auto& IterExprs4Tensor =
MakeGetterIterExprs4Tensor(op_expr_stmt, &binding_var2value);
const auto& opt_rvalue =
TranslateOpExpr(op_expr, opt_output_tensor, IterExprs4Tensor);
CHECK(opt_rvalue.has_value());
const auto& output_expr =
ir::Store::Make(store_expr.value().As<ir::Store>()->tensor,
opt_rvalue.value(),
IterExprs4Tensor(output_tensor));
ir::Expr ret = ir::ScheduleBlock::Make(
GetVectorOfPairFirst(binding_var2value),
{},
{},
output_expr.As<ir::Store>()->tensor.as_tensor()->name,
output_expr);
ret = ir::ScheduleBlockRealize::Make(
GetVectorOfPairSecond(binding_var2value), ret);
return ret;
}
ir::Expr Translate(const List<InlineStmt>& stmts) const {
std::vector<ir::Expr> exprs;
for (const auto& stmt : *stmts) {
exprs.emplace_back(Translate(stmt));
}
return ir::Block::Make(exprs);
}
ir::Expr TranslateImpl(const MapStmt<InlineStmt>& map_stmt) const {
const auto& [iterators, stmts] = map_stmt.tuple();
ir::Expr ret = Translate(stmts);
CHECK_GT(iterators->size(), 0);
for (int i = iterators->size() - 1; i >= 0; --i) {
const auto& iterator = iterators->at(i);
const auto& ld = LoopDescriptor4LoopIterator(iterator);
ir::Var var{"v_" + std::to_string(iterator.value().unique_id())};
ir::Expr min{IteratorInt(0)};
ir::Expr extent = GetLoopSize(ld);
const auto& [for_type, vectorize_info, bind_info] = GetForTypeAndInfo(ld);
ir::DeviceAPI device_api = GetDeviceApi();
ret = ir::For::Make(var,
min,
extent,
for_type,
device_api,
ret,
vectorize_info,
bind_info);
}
return ret;
}
ir::Expr Translate(const InlineStmt& inline_stmt) const {
return std::visit([&](const auto& impl) { return TranslateImpl(impl); },
inline_stmt.variant());
}
ir::Expr Translate(const MapStmt<Stmt>& map_stmt) const {
Stmt stmt = map_stmt;
InternalStmt internal_stmt = ConvertToInternalStmt(stmt);
InlineStmt inline_stmt = ConvertToInlineStmt(internal_stmt);
return Translate(inline_stmt);
}
ir::DeviceAPI GetDeviceApi() const { return ir::DeviceAPI::Host; }
ir::Expr GetLoopSize(const LoopDescriptor& ld) const {
const auto& [_, loop_size] = ld.tuple();
CHECK(loop_size.Has<std::int64_t>());
return ir::Expr{IteratorInt(loop_size.Get<std::int64_t>())};
}
std::tuple<ir::ForType, ir::VectorizeInfo, ir::BindInfo>
GetForTypeAndInfoImpl(const S0x& loop_type, const LoopDescriptor& ld) const {
ir::ForType for_type = ir::ForType::GPUBlock;
ir::BindInfo bind_info{for_type, 0, ir::DeviceAPI::GPU};
return std::make_tuple(for_type, ir::VectorizeInfo(), bind_info);
}
std::tuple<ir::ForType, ir::VectorizeInfo, ir::BindInfo>
GetForTypeAndInfoImpl(const S0y& loop_type, const LoopDescriptor& ld) const {
ir::ForType for_type = ir::ForType::GPUBlock;
ir::BindInfo bind_info{for_type, 1, ir::DeviceAPI::GPU};
return std::make_tuple(for_type, ir::VectorizeInfo(), bind_info);
}
std::tuple<ir::ForType, ir::VectorizeInfo, ir::BindInfo>
GetForTypeAndInfoImpl(const S0z& loop_type, const LoopDescriptor& ld) const {
ir::ForType for_type = ir::ForType::GPUBlock;
ir::BindInfo bind_info{for_type, 2, ir::DeviceAPI::GPU};
return std::make_tuple(for_type, ir::VectorizeInfo(), bind_info);
}
std::tuple<ir::ForType, ir::VectorizeInfo, ir::BindInfo>
GetForTypeAndInfoImpl(const S1x& loop_type, const LoopDescriptor& ld) const {
ir::ForType for_type = ir::ForType::GPUThread;
ir::BindInfo bind_info{for_type, 0, ir::DeviceAPI::GPU};
return std::make_tuple(for_type, ir::VectorizeInfo(), bind_info);
}
std::tuple<ir::ForType, ir::VectorizeInfo, ir::BindInfo>
GetForTypeAndInfoImpl(const S1y& loop_type, const LoopDescriptor& ld) const {
ir::ForType for_type = ir::ForType::GPUThread;
ir::BindInfo bind_info{for_type, 1, ir::DeviceAPI::GPU};
return std::make_tuple(for_type, ir::VectorizeInfo(), bind_info);
}
std::tuple<ir::ForType, ir::VectorizeInfo, ir::BindInfo>
GetForTypeAndInfoImpl(const S1z& loop_type, const LoopDescriptor& ld) const {
ir::ForType for_type = ir::ForType::GPUThread;
ir::BindInfo bind_info{for_type, 2, ir::DeviceAPI::GPU};
return std::make_tuple(for_type, ir::VectorizeInfo(), bind_info);
}
std::tuple<ir::ForType, ir::VectorizeInfo, ir::BindInfo>
GetForTypeAndInfoImpl(const Temporal& loop_type,
const LoopDescriptor& ld) const {
return std::make_tuple(
ir::ForType::Serial, ir::VectorizeInfo(), ir::BindInfo());
}
std::tuple<ir::ForType, ir::VectorizeInfo, ir::BindInfo>
GetForTypeAndInfoImpl(const Vectorize& loop_type,
const LoopDescriptor& ld) const {
LOG(FATAL) << "Vectorize not supported yet";
}
std::tuple<ir::ForType, ir::VectorizeInfo, ir::BindInfo>
GetForTypeAndInfoImpl(const Unroll& loop_type,
const LoopDescriptor& ld) const {
LOG(FATAL) << "Unroll not supported yet";
}
std::tuple<ir::ForType, ir::VectorizeInfo, ir::BindInfo> GetForTypeAndInfo(
const LoopDescriptor& ld) const {
const auto& [loop_type, _] = ld.tuple();
return std::visit(
[&](const auto& impl) { return GetForTypeAndInfoImpl(impl, ld); },
loop_type.variant());
}
ir::Expr Mul(const ir::Expr& a, std::int64_t b) const {
if (b == 1) {
return a;
} else {
ir::Expr b_expr{IteratorInt(b)};
return ir::Mul::Make(a, b_expr);
}
}
ir::Expr Accumulate(const std::vector<ir::Expr>& strided_exprs) const {
if (strided_exprs.size() == 0) {
LOG(FATAL) << "Dead code";
} else if (strided_exprs.size() == 1) {
return strided_exprs.at(0);
} else {
ir::Expr ret = strided_exprs.at(0);
for (int i = 1; i < strided_exprs.size(); ++i) {
ret = ir::Add::Make(ret, strided_exprs.at(i));
}
return ret;
}
LOG(FATAL) << "Dead code";
}
std::int64_t GetStride(const List<Constant>& dims, int start) const {
CHECK_GE(start, -1);
std::int64_t ret = 1;
for (int idx = start + 1; idx < dims->size(); ++idx) {
CHECK(dims->at(idx).Has<std::int64_t>());
ret *= dims->at(idx).Get<std::int64_t>();
}
return ret;
}
using IndexDotValueOfList = IndexDotValue<List<Value>, List<std::int64_t>>;
ir::Expr TranslateIndexDotValueOfList(const Value& value) const {
const auto& [list_value, dot_dims_value] =
value.Get<IndexDotValue<Value, Constant>>().tuple();
const auto& values = list_value.Get<List<Value>>();
const auto& dim_values = dot_dims_value.Get<List<Constant>>();
CHECK_EQ(values->size(), dim_values->size());
std::vector<ir::Expr> strided_exprs{};
for (std::size_t i = 0; i < values->size(); ++i) {
const auto& value_expr = TranslateTensorIterator(values->at(i));
const auto& stride_value = GetStride(dim_values, i);
strided_exprs.emplace_back(Mul(value_expr, stride_value));
}
return Accumulate(strided_exprs);
}
using ListGetItemOfUnDot =
ListGetItem<IndexUnDotValue<Value, List<std::int64_t>>, std::int64_t>;
// Lowers ListGetItem(IndexUnDot(flat_index, dims), axis) into the classic
// coordinate extraction: (flat_index % stride(axis - 1)) / stride(axis),
// which recovers the `axis`-th multi-dimensional coordinate.
ir::Expr TranslateListGetItemOfUnDot(const Value& value) const {
  const auto& [undot_value, idx_value] =
      value.Get<ListGetItem<Value, Constant>>().tuple();
  const auto& [tensor_index_value, dims_value] =
      undot_value.Get<IndexUnDotValue<Value, Constant>>().tuple();
  const ir::Expr index_expr = TranslateTensorIterator(tensor_index_value);
  const std::int64_t axis = idx_value.Get<std::int64_t>();
  const auto& dims = dims_value.Get<List<Constant>>();
  ir::Expr mod_operand{IteratorInt(GetStride(dims, axis - 1))};
  ir::Expr div_operand{IteratorInt(GetStride(dims, axis))};
  return ir::Div::Make(ir::Mod::Make(index_expr, mod_operand), div_operand);
}
// Lowers a leaf Iterator value to an ir::Var named "v_<unique_id>", so each
// distinct iterator maps to a stable, unique loop variable.
ir::Expr TranslateIterator(const Value& value) const {
  const auto& iterator = value.Get<Iterator>();
  const std::string var_name =
      "v_" + std::to_string(iterator.value().unique_id());
  return ir::Var(var_name);
}
// Dispatches a tensor-iterator value to the matching lowering routine.
// The Match order is preserved from the original: dot-of-list first, then
// get-item-of-undot, then a plain iterator; anything else is fatal.
ir::Expr TranslateTensorIterator(const Value& value) const {
  if (Match<IndexDotValueOfList>(value)) {
    return TranslateIndexDotValueOfList(value);
  }
  if (Match<ListGetItemOfUnDot>(value)) {
    return TranslateListGetItemOfUnDot(value);
  }
  if (Match<Iterator>(value)) {
    return TranslateIterator(value);
  }
  LOG(FATAL) << "Not supported yet! " << ToTxtString(value);
}
std::vector<ir::Expr> Translate(
const List<TensorIteratorExpr>& iterator_exprs) const {
std::vector<ir::Expr> ret{};
for (const auto& iterator_expr : *iterator_exprs) {
ret.emplace_back(TranslateTensorIterator(iterator_expr));
}
return ret;
}
MapExpr map_expr_;
const Node2LoweredFuncs* node2lowered_funcs_;
const common::Target target_;
TensorIteratorExpr4TensorT TensorIteratorExpr4Tensor;
LoopDescriptor4LoopIteratorT LoopDescriptor4LoopIterator;
};
} // namespace
// Entry point: builds a translator over the context's MapExpr and its
// per-node lowered functions, then produces the final ir::Expr.
ir::Expr MapExprToIr(const MapExprCtx& map_expr_ctx,
                     const common::Target& target) {
  MapExprToIrTranslator translator(
      map_expr_ctx.map_expr(), map_expr_ctx.node2lowered_funcs(), target);
  ir::Expr expr = translator.Translate();
  VLOG(1) << "Finish MapExprToIr\n" << expr;
  return expr;
}
} // namespace cinn::adt
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/adt/m_expr.h"
#include "paddle/cinn/adt/map_expr_ctx.h"
#include "paddle/cinn/ir/ir.h"
namespace cinn::common {
struct Target;
}
namespace cinn::adt {
ir::Expr MapExprToIr(const MapExprCtx& map_expr_ctx,
const common::Target& target);
}
......@@ -1077,6 +1077,31 @@ std::vector<ir::Tensor> TwoStepBlockReduceAny(const ir::Tensor& A,
Expr(false));
}
// Maps a reduce operator node to the name of the external cross-thread
// (block-level, shared-memory) reduction runtime function. Arithmetic
// reductions (sum/prod/max/min) are suffixed with the tensor's element
// type; the boolean all/any reductions use fixed, untyped names.
std::string CrossThreadReduceExternalFuncName(const ir::Expr& op,
                                              const ir::Expr& tensor) {
  CHECK_NOTNULL(tensor.as_tensor());
  // Lazily builds the typed name; only evaluated for arithmetic reductions
  // so the boolean cases never invoke Type2StrForReduce.
  const auto typed_name = [&](const char* reduce_kind) {
    return std::string("cinn_block_reduce_") + reduce_kind +
           Type2StrForReduce(tensor.as_tensor()->type()) + "_internal_shm";
  };
  if (op.As<ir::Add>()) return typed_name("sum");
  if (op.As<ir::Mul>()) return typed_name("prod");
  if (op.As<ir::Max>()) return typed_name("max");
  if (op.As<ir::Min>()) return typed_name("min");
  if (op.As<ir::And>()) return "cinn_block_reduce_all_internal_shm";
  if (op.As<ir::Or>()) return "cinn_block_reduce_any_internal_shm";
  LOG(FATAL) << "Reduce type: " << op << " Not supported yet!";
  return "";
}
} // namespace pe
} // namespace hlir
} // namespace cinn
......@@ -467,6 +467,11 @@ std::vector<ir::Tensor> TwoStepBlockReduceAny(
const std::vector<int>& axes,
const bool keep_dim,
const std::string& output_name = "T_Reduce_Any_out");
std::string CrossThreadReduceExternalFuncName(const ir::Expr& op,
const ir::Expr& tensor);
std::string Type2StrForReduce(common::Type type);
} // namespace pe
} // namespace hlir
} // namespace cinn
......@@ -31,7 +31,7 @@
#include "paddle/cinn/poly/isl_utils.h"
#include "paddle/cinn/utils/string.h"
DECLARE_bool(cinn_use_cuda_vectorize);
PD_DECLARE_bool(cinn_use_cuda_vectorize);
namespace cinn {
namespace hlir {
namespace pe {
......
......@@ -5,6 +5,9 @@ gather_srcs(
SRCS
ir.cc
ir_base.cc
ir_visitor.cc
ir_printer.cc
ir_mutator.cc
function_definition.cc
buffer.cc
function_base.cc
......@@ -15,9 +18,13 @@ gather_srcs(
lowered_func.cc
intrinsic_ops.cc
layout.cc
schedule_block_graph.cc)
schedule_block_graph.cc
dim.cc)
add_subdirectory(ir_analyzer)
add_subdirectory(op)
add_subdirectory(test)
add_subdirectory(utils)
add_subdirectory(schedule)
add_subdirectory(group_schedule)
add_subdirectory(dy_schedule)
......@@ -16,8 +16,8 @@
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/ir_visitor.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
#include "paddle/cinn/runtime/intrinsic.h"
#include "paddle/cinn/utils/string.h"
......
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/ir/dim.h"
#include "paddle/cinn/ir/ir.h"
namespace cinn {
namespace ir {
// Dereference operators expose the underlying _Dim_ node held by the handle.
const _Dim_* Dim::operator->() const { return As<_Dim_>(); }
_Dim_* Dim::operator->() { return As<_Dim_>(); }
// Thin forwarding accessors over the stored (currently placeholder)
// SymbolicDimOp and the cached dim expression.
SymbolicDimOp _Dim_::GetSymbolicDim() const { return sym_dim; }
bool _Dim_::IsDynamic() const { return sym_dim.IsDynamic(); }
std::string _Dim_::GetSymbolName() const { return sym_dim.GetSymName(); }
int64_t _Dim_::GetRealDimSize() const { return sym_dim.GetDimSize(); }
Expr _Dim_::GetDimExpr() const { return dim_expr; }
// Constructs a _Dim_ node named `name` wrapping `sym_dim`.
// Dynamic dims get a symbolic Var expression; static dims get a constant
// expression built from the concrete dim size.
Dim _Dim_::Make(const std::string& name, const SymbolicDimOp& sym_dim) {
  auto* node = make_shared<_Dim_>();
  node->name = name;
  node->sym_dim = sym_dim;
  // NOTE(review): the dynamic branch creates the Var with
  // type_of<std::string>() — presumably a placeholder until the real
  // pir SymbolicDimOp is integrated; confirm the intended element type.
  node->dim_expr = sym_dim.IsDynamic()
                       ? Expr(Var(sym_dim.GetSymName(), type_of<std::string>()))
                       : Expr(sym_dim.GetDimSize());
  return Dim(node);
}
} // namespace ir
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/ir/ir_base.h"
namespace cinn {
namespace ir {
struct _Dim_;
// This fake class is to pass the CI, and will be replaced by
// pir::shape::SymbolicDimOp when pir is completely integrated.
class SymbolicDimOp {
 public:
  // Placeholder implementations: no symbol name, zero size, never dynamic.
  // The surrounding file states this class exists only to pass CI until
  // pir::shape::SymbolicDimOp is integrated.
  const std::string GetSymName() const { return ""; }
  int64_t GetDimSize() const { return 0; }
  bool IsDynamic() const { return false; }
};
//! Wrapper for _Dim_
class Dim : public IrNodeRef {
 public:
  Dim() = default;
  // Wraps an existing IR node; takes no ownership beyond IrNodeRef semantics.
  explicit Dim(IrNode* n) : IrNodeRef(n) {}
  // Implicit conversion so a Dim can be used wherever an Expr is expected.
  operator Expr() const { return Expr(ptr()); }
  // Accessors to the underlying _Dim_ node (defined in dim.cc).
  const _Dim_* operator->() const;
  _Dim_* operator->();
};
/**
* Definition of _Dim_.
*/
struct _Dim_ : ExprNode<_Dim_> {
  //! The name of this struct.
  std::string name;

  // (TODO: zhangzheng) Replace this fake class by pir::shape::SymbolicDimOp
  SymbolicDimOp sym_dim;

  // Expression form of the dim: a symbolic Var when dynamic, a constant
  // when static (built by Make).
  Expr dim_expr;

  SymbolicDimOp GetSymbolicDim() const;

  bool IsDynamic() const;

  std::string GetSymbolName() const;

  int64_t GetRealDimSize() const;

  Expr GetDimExpr() const;

  // Factory: builds the node and derives dim_expr from sym_dim.
  static Dim Make(const std::string& name, const SymbolicDimOp& sym_dim);

  static const IrNodeTy _node_type_ = IrNodeTy::_Dim_;
};
} // namespace ir
} // namespace cinn
core_gather_headers()

# Register the dynamic-shape schedule primitive sources into the cinnapi
# source list; each .cc implements one group of DyScheduleImpl primitives.
gather_srcs(
  cinnapi_src
  SRCS
  base.cc
  compute_location.cc
  for_type.cc
  loop_transformation.cc
  reduction.cc
  storage.cc)
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/common/macros.h"
#include "paddle/cinn/ir/dy_schedule/ir_schedule.h"
namespace cinn {
namespace ir {
// All base schedule primitives below are placeholders for the dynamic-shape
// schedule: each one immediately raises CINN_NOT_IMPLEMENTED. The non-void
// functions rely on that macro not returning normally.
void DyScheduleImpl::MergeExprs() { CINN_NOT_IMPLEMENTED; }

bool DyScheduleImpl::HasBlock(const std::string& block_name) const {
  CINN_NOT_IMPLEMENTED;
}

std::vector<Expr> DyScheduleImpl::GetLoops(const Expr& block) const {
  CINN_NOT_IMPLEMENTED;
}

std::vector<Expr> DyScheduleImpl::GetLoops(
    const std::string& block_name) const {
  CINN_NOT_IMPLEMENTED;
}

std::vector<Expr> DyScheduleImpl::GetAllBlocks() const { CINN_NOT_IMPLEMENTED; }

std::vector<Expr> DyScheduleImpl::GetChildBlocks(const Expr& expr) const {
  CINN_NOT_IMPLEMENTED;
}

Expr DyScheduleImpl::GetBlock(const std::string& block_name) const {
  CINN_NOT_IMPLEMENTED;
}

Expr DyScheduleImpl::GetRootBlock(const Expr& expr) const {
  CINN_NOT_IMPLEMENTED;
}

DeviceAPI DyScheduleImpl::GetDeviceAPI() const { CINN_NOT_IMPLEMENTED; }

void DyScheduleImpl::Annotate(const Expr& block,
                              const std::string& key,
                              const attr_t& value) {
  CINN_NOT_IMPLEMENTED;
}

void DyScheduleImpl::Unannotate(Expr& block,
                                const std::string& key) {  // NOLINT
  CINN_NOT_IMPLEMENTED;
}

void DyScheduleImpl::CopyTransformAndLoopInfo(const Expr& block,
                                              const Expr& block_target) {
  CINN_NOT_IMPLEMENTED;
}

void DyScheduleImpl::CopyTransformAndLoopInfo(
    const std::string& block_name, const std::string& block_target_name) {
  CINN_NOT_IMPLEMENTED;
}

Expr DyScheduleImpl::SampleCategorical(
    utils::LinearRandomEngine::StateType* rand_seed,
    const std::vector<int>& candidates,
    const std::vector<float>& probs) {
  CINN_NOT_IMPLEMENTED;
}

std::vector<Expr> DyScheduleImpl::SamplePerfectTile(
    utils::LinearRandomEngine::StateType* rand_seed,
    const Expr& loop,
    int n,
    int max_innermost_factor) {
  CINN_NOT_IMPLEMENTED;
}
} // namespace ir
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/common/macros.h"
#include "paddle/cinn/ir/dy_schedule/ir_schedule.h"
namespace cinn {
namespace ir {
// Compute-location primitives (compute_at / compute_inline families) are not
// yet implemented for the dynamic-shape schedule; each raises
// CINN_NOT_IMPLEMENTED.
void DyScheduleImpl::ComputeAt(const Expr& block,
                               const Expr& loop,
                               bool keep_unit_loops) {
  CINN_NOT_IMPLEMENTED;
}

void DyScheduleImpl::SimpleComputeAt(const Expr& block, const Expr& loop) {
  CINN_NOT_IMPLEMENTED;
}

void DyScheduleImpl::ReverseComputeAt(const Expr& block,
                                      const Expr& loop,
                                      bool keep_unit_loops) {
  CINN_NOT_IMPLEMENTED;
}

void DyScheduleImpl::ComputeInline(const Expr& schedule_block) {
  CINN_NOT_IMPLEMENTED;
}

void DyScheduleImpl::ReverseComputeInline(const Expr& schedule_block) {
  CINN_NOT_IMPLEMENTED;
}
} // namespace ir
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/common/macros.h"
#include "paddle/cinn/ir/dy_schedule/ir_schedule.h"
namespace cinn {
namespace ir {
// Loop-annotation primitives (parallel/vectorize/unroll/bind) are not yet
// implemented for the dynamic-shape schedule; each raises
// CINN_NOT_IMPLEMENTED.
void DyScheduleImpl::MutateForType(const Expr& loop,
                                   ForType for_type,
                                   int factor) {
  CINN_NOT_IMPLEMENTED;
}

void DyScheduleImpl::Parallel(const Expr& loop) { CINN_NOT_IMPLEMENTED; }

void DyScheduleImpl::Vectorize(const Expr& loop, int factor) {
  CINN_NOT_IMPLEMENTED;
}

void DyScheduleImpl::Unroll(const Expr& loop) { CINN_NOT_IMPLEMENTED; }

void DyScheduleImpl::Bind(const Expr& loop, const std::string& thread_axis) {
  CINN_NOT_IMPLEMENTED;
}
} // namespace ir
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/ir/schedule/schedule_base.h"
PD_DECLARE_int32(cinn_error_message_level);
namespace cinn {
namespace ir {
/**
* A struct helps to implement dynamic shape Schedule primitives.
*/
class DyScheduleImpl : public ScheduleBase {
 public:
  DyScheduleImpl() = delete;
  // NOTE(review): `debug_flag` is accepted but a literal `false` is forwarded
  // to ScheduleBase — confirm whether the flag should be passed through.
  explicit DyScheduleImpl(const ModuleExpr& module_expr,
                          bool debug_flag = false,
                          utils::ErrorMessageLevel err_msg_level =
                              utils::ErrorMessageLevel::kGeneral)
      : ScheduleBase(module_expr, false, err_msg_level) {}
  explicit DyScheduleImpl(ModuleExpr&& module_expr)
      : ScheduleBase(std::move(module_expr)) {}

  // --- Block / loop queries ---
  void MergeExprs();
  bool HasBlock(const std::string& block_name) const;
  std::vector<Expr> GetLoops(const Expr& block) const;
  std::vector<Expr> GetLoops(const std::string& block_name) const;
  std::vector<Expr> GetAllBlocks() const;
  std::vector<Expr> GetChildBlocks(const Expr& expr) const;
  Expr GetBlock(const std::string& block_name) const;

  // --- Loop transformations ---
  std::vector<Expr> Split(const Expr& loop, const std::vector<int>& factors);
  std::vector<Expr> SamplePerfectTile(
      utils::LinearRandomEngine::StateType* rand_seed,
      const Expr& loop,
      int n,
      int max_innermost_factor);
  Expr Fuse(const std::vector<Expr>& loops);
  Expr Fuse(const std::string& block_name, const std::vector<int>& loops_index);
  Expr Fuse(const Expr& block, const std::vector<int>& loops_index);

  // --- Compute location ---
  void ComputeAt(const Expr& block, const Expr& loop, bool keep_unit_loops);
  void SimpleComputeAt(const Expr& block, const Expr& loop);
  void ReverseComputeAt(const Expr& block,
                        const Expr& loop,
                        bool keep_unit_loops);
  Expr GetRootBlock(const Expr& expr) const;

  // --- Storage / caching ---
  Expr CacheRead(const Expr& block,
                 int read_buffer_index,
                 const std::string& memory_type);
  Expr CacheWrite(const Expr& block,
                  int write_buffer_index,
                  const std::string& memory_type);
  void SyncThreads(const Expr& ir_node, bool after_node = true);
  void SetBuffer(Expr& block,  // NOLINT
                 const std::string& memory_type,
                 bool fixed = false);

  // --- Reordering / loop typing ---
  Expr Reorder(const std::vector<Expr>& loops);
  Expr Reorder(const std::string& block_name,
               const std::vector<int>& loops_index);
  Expr Reorder(const Expr& block, const std::vector<int>& loops_index);
  DeviceAPI GetDeviceAPI() const;
  void MutateForType(const Expr& loop, ForType for_type, int factor = -1);
  void Parallel(const Expr& loop);
  void Vectorize(const Expr& loop, int factor);
  void Unroll(const Expr& loop);

  // --- Inlining / binding / reduction ---
  void ComputeInline(const Expr& schedule_block);
  void ReverseComputeInline(const Expr& schedule_block);
  void Bind(const Expr& loop, const std::string& thread_axis);
  Expr Rfactor(const Expr& rf_loop, int rf_axis);
  Expr FactorizeReduction(const Expr& rf_loop, int rf_axis);

  // --- Misc ---
  Expr AddUnitLoop(const Expr& block) const;
  void Annotate(const Expr& block, const std::string& key, const attr_t& value);
  void Unannotate(Expr& block, const std::string& key);  // NOLINT
  void FlattenLoops(const std::vector<Expr>& loops,
                    const bool force_flat = false);
  void CopyTransformAndLoopInfo(const Expr& block, const Expr& block_target);
  void CopyTransformAndLoopInfo(const std::string& block_name,
                                const std::string& block_target_name);
  Expr SampleCategorical(utils::LinearRandomEngine::StateType* rand_seed,
                         const std::vector<int>& candidates,
                         const std::vector<float>& probs);
};
} // namespace ir
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/common/macros.h"
#include "paddle/cinn/ir/dy_schedule/ir_schedule.h"
namespace cinn {
namespace ir {
// Loop-transformation primitives (split/fuse/reorder/flatten) are not yet
// implemented for the dynamic-shape schedule; each raises
// CINN_NOT_IMPLEMENTED. The non-void functions rely on the macro not
// returning normally.
std::vector<Expr> DyScheduleImpl::Split(const Expr& loop,
                                        const std::vector<int>& factors) {
  CINN_NOT_IMPLEMENTED;
}

Expr DyScheduleImpl::Fuse(const std::vector<Expr>& loops) {
  CINN_NOT_IMPLEMENTED;
}

Expr DyScheduleImpl::Fuse(const std::string& block_name,
                          const std::vector<int>& loops_index) {
  CINN_NOT_IMPLEMENTED;
}

Expr DyScheduleImpl::Fuse(const Expr& block,
                          const std::vector<int>& loops_index) {
  CINN_NOT_IMPLEMENTED;
}

Expr DyScheduleImpl::Reorder(const std::vector<Expr>& loops) {
  CINN_NOT_IMPLEMENTED;
}

Expr DyScheduleImpl::Reorder(const std::string& block_name,
                             const std::vector<int>& loops_index) {
  CINN_NOT_IMPLEMENTED;
}

Expr DyScheduleImpl::Reorder(const Expr& block,
                             const std::vector<int>& loops_index) {
  CINN_NOT_IMPLEMENTED;
}

Expr DyScheduleImpl::AddUnitLoop(const Expr& block) const {
  CINN_NOT_IMPLEMENTED;
}

void DyScheduleImpl::FlattenLoops(const std::vector<Expr>& loops,
                                  const bool force_flat) {
  CINN_NOT_IMPLEMENTED;
}
} // namespace ir
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/common/macros.h"
#include "paddle/cinn/ir/dy_schedule/ir_schedule.h"
namespace cinn {
namespace ir {
// Reduction-factorization primitives are not yet implemented for the
// dynamic-shape schedule; each raises CINN_NOT_IMPLEMENTED.
Expr DyScheduleImpl::Rfactor(const Expr& rf_loop, int rf_axis) {
  CINN_NOT_IMPLEMENTED;
}

Expr DyScheduleImpl::FactorizeReduction(const Expr& rf_loop, int rf_axis) {
  CINN_NOT_IMPLEMENTED;
}
} // namespace ir
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/common/macros.h"
#include "paddle/cinn/ir/dy_schedule/ir_schedule.h"
namespace cinn {
namespace ir {
// Storage primitives (cache read/write, thread sync, buffer placement) are
// not yet implemented for the dynamic-shape schedule; each raises
// CINN_NOT_IMPLEMENTED.
Expr DyScheduleImpl::CacheRead(const Expr& block,
                               int read_buffer_index,
                               const std::string& memory_type) {
  CINN_NOT_IMPLEMENTED;
}

Expr DyScheduleImpl::CacheWrite(const Expr& block,
                                int write_buffer_index,
                                const std::string& memory_type) {
  CINN_NOT_IMPLEMENTED;
}

void DyScheduleImpl::SyncThreads(const Expr& ir_node, bool after_node) {
  CINN_NOT_IMPLEMENTED;
}

void DyScheduleImpl::SetBuffer(Expr& block,  // NOLINT
                               const std::string& memory_type,
                               bool fixed) {
  CINN_NOT_IMPLEMENTED;
}
} // namespace ir
} // namespace cinn
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment