Commit 01a10755 authored by yuguo-Jack's avatar yuguo-Jack
Browse files

2.5.2-dtk24.04

parent 63eb0da5
......@@ -25,9 +25,8 @@
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/buffer.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
#include "paddle/cinn/optim/tensor_write_tell.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/ir_visitor.h"
#include "paddle/cinn/runtime/intrinsic.h"
#include "paddle/cinn/utils/string.h"
......@@ -65,6 +64,16 @@ LoweredFunc _LoweredFunc_::Make(const std::string& name,
return LoweredFunc(n);
}
LoweredFunc _LoweredFunc_::Make(const std::string& name,
const std::vector<Argument>& args,
const Expr& body) {
auto* n = make_shared<_LoweredFunc_>();
n->name = name;
n->args = args;
n->body = body;
return LoweredFunc(n);
}
void _LoweredFunc_::CheckValid() const {
// check there is at least one output
int out_count = 0;
......@@ -83,7 +92,7 @@ std::vector<const Expr*> _LoweredFunc_::expr_fields() const { return {&body}; }
void _LoweredFunc_::PrepareCudaAxisInfoFromBody() {
std::set<Expr> bound_for_exprs =
ir::CollectIRNodes(body, [](const Expr* expr) {
ir::ir_utils::CollectIRNodes(body, [](const Expr* expr) {
const ir::For* for_expr = expr->As<ir::For>();
return for_expr != nullptr && for_expr->is_binded();
});
......@@ -209,8 +218,7 @@ void _LoweredFunc_::AllocTempBuffer() {}
void _LoweredFunc_::PrepareBufferCastExprs(bool with_expr_gen_tensor) {
buffer_data_cast_exprs.clear();
// collect write.
optim::TensorWriteTeller write_teller;
write_teller.Collect(&body);
auto write_teller = ir::ir_utils::CollectTensorNeedsWrite(&body);
auto tensors = CollectAllTensorReference(with_expr_gen_tensor);
std::sort(tensors.begin(),
......@@ -224,7 +232,7 @@ void _LoweredFunc_::PrepareBufferCastExprs(bool with_expr_gen_tensor) {
if (!tensor->buffer.defined()) continue;
Type value_type = tensor->type().ElementOf();
bool is_const = !write_teller.IsWrite(tensor->name);
bool is_const = !write_teller.count(tensor->name);
value_type.set_cpp_handle();
value_type.set_cpp_const(is_const);
Var variable = _Var_::Make(tensor->name, value_type);
......@@ -250,8 +258,7 @@ std::vector<Expr> _LoweredFunc_::CudaAliasVarExprs() const {
}
// collect write.
std::vector<Expr> res;
optim::TensorWriteTeller write_teller;
write_teller.Collect(&body);
auto write_teller = ir::ir_utils::CollectTensorNeedsWrite(&body);
auto tensors = CollectAllTensorReference();
std::sort(tensors.begin(),
......@@ -269,7 +276,7 @@ std::vector<Expr> _LoweredFunc_::CudaAliasVarExprs() const {
continue;
}
Type value_type = tensor->type().ElementOf();
bool is_const = !write_teller.IsWrite(tensor->name);
bool is_const = !write_teller.count(tensor->name);
value_type.set_cpp_handle();
value_type.set_cpp_const(is_const);
Var variable = _Var_::Make(tensor->name, value_type);
......@@ -406,11 +413,11 @@ std::vector<Tensor> _LoweredFunc_::CollectAllTensorReference(
bool with_expr_gen_tensor) const {
std::set<Expr> tensor_exprs =
with_expr_gen_tensor
? ir::CollectIRNodes(
? ir::ir_utils::CollectIRNodes(
body, [](const Expr* expr) { return expr->As<ir::_Tensor_>(); })
: ir::CollectIRNodesWithoutTensor(body, [](const Expr* expr) {
return expr->As<ir::_Tensor_>();
});
: ir::ir_utils::CollectIRNodesWithoutTensor(
body,
[](const Expr* expr) { return expr->As<ir::_Tensor_>(); });
std::vector<Tensor> tensors;
// remove the duplicate tensor by their name.
......
......@@ -30,8 +30,10 @@ class _LoweredFunc_;
* the function signature of generated code.
*/
struct Argument {
//! Input or output.
enum class IO { kInput = 0, kOutput = 1 };
//! kInput: arg is input
//! kOutput: arg is output
//! kUnknown: arg maybe input or output
enum class IO { kInput = 0, kOutput = 1, kUnknown = 2 };
IO io{IO::kInput};
......@@ -164,6 +166,13 @@ struct _LoweredFunc_ : ExprNode<_LoweredFunc_> {
const Expr& body,
const std::vector<ir::Buffer>& temp_bufs);
// A simple version of the make function method,
// regardless of the argument buffer information and IO information of
// Argument, after building the function to optimize the buffer through pass
static LoweredFunc Make(const std::string& name,
const std::vector<Argument>& args,
const Expr& body);
bool is_gpu_host() const { return cuda_axis_info.valid(); }
void Verify() const override {}
......
......@@ -16,6 +16,7 @@
#include <memory>
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/optim/ir_simplify.h"
#include "paddle/cinn/optim/optimize.h"
......@@ -48,12 +49,19 @@ void Module::Builder::AddBuffer(ir::Buffer buffer) {
}
}
void Module::Builder::AddPredicate(ir::Expr predicate) {
module_->predicates.push_back(predicate);
}
void Module::Builder::Clear() {
module_->buffers.clear();
module_->functions.clear();
module_->submodules.clear();
module_->predicates.clear();
}
Target::Arch Module::Builder::GetTargetArch() { return module_->target.arch; }
Module Module::Builder::Build() {
if (module_->functions.empty()) {
VLOG(1) << "Module has no functions";
......@@ -61,7 +69,8 @@ Module Module::Builder::Build() {
auto res = ir::Module(module_.get());
return optim::Optimize(res, module_->target);
res = optim::Optimize(res, module_->target);
return res;
}
ir::_Module_ *Module::self() { return p_->as<ir::_Module_>(); }
......
......@@ -44,7 +44,9 @@ class Module : public ir::IrNodeRef {
void AddFunction(ir::LoweredFunc func);
void AddFunctionWithoutOptim(const ir::LoweredFunc& func);
void AddBuffer(ir::Buffer buffer);
void AddPredicate(ir::Expr predicate);
void Clear();
Target::Arch GetTargetArch();
Module Build();
......
......@@ -49,10 +49,12 @@ Operation ComputeOp::Make(const std::string &name,
n->reduce_axis = reduce_axis;
n->tag = tag;
n->attrs = attrs;
auto axis = common::GenDefaultAxis(domain.size());
std::vector<Expr> _axis;
for (auto &x : axis) _axis.push_back(x);
n->body = {handle(_axis)};
n->axis = common::GenDefaultAxis(domain.size());
std::vector<Expr> tmp_axis;
for (auto &x : n->axis) {
tmp_axis.push_back(x);
}
n->body = {handle(tmp_axis)};
n->reduce_axis = reduce_axis;
return Operation(n);
}
......
......@@ -105,6 +105,8 @@ struct BufferShareOp : public _Operation_ {
*/
struct ComputeOp : public _Operation_ {
using handle_t = std::function<Expr(const std::vector<Expr> &)>;
//! Var on each dimension
std::vector<Var> axis;
//! Var on each reduction axis, if the body is a Reduction.
std::vector<Var> reduce_axis;
//! Shape of the output.
......
cinn_proto_library(schedule_desc_proto SRCS schedule_desc.proto)
core_gather_headers()
gather_srcs(cinnapi_src SRCS ir_schedule.cc ir_schedule_util.cc
ir_schedule_error.cc schedule_desc.cc)
gather_srcs(
cinnapi_src
SRCS
schedule_base.cc
ir_schedule.cc
ir_schedule_util.cc
ir_schedule_error.cc
schedule_desc.cc)
foreach(header ${schedule_desc_proto_HDRS})
set(core_proto_includes
......
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Used in FactorizeReduction
#pragma once
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/schedule/ir_schedule_util.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/optim/replace_var_with_expr.h"
#include "paddle/cinn/utils/error.h"
namespace cinn {
namespace ir {
// Create the new Reduction-Factorized tensor,
// only used for FactorizeReduction schedule primitive.
Tensor CreateRFTensor(const Tensor& original_tensor,
const Expr& rf_loop,
int rf_axis) {
std::string name = common::UniqName(original_tensor->name + "_rf");
std::vector<Expr> new_shape = original_tensor->shape;
new_shape.insert(new_shape.begin() + rf_axis, rf_loop.As<For>()->extent);
Tensor rf_tensor = _Tensor_::Make(name,
original_tensor->type(),
new_shape,
new_shape,
original_tensor->operation,
original_tensor->reduce_axis);
rf_tensor->WithBuffer("global", name, original_tensor->type());
return rf_tensor;
}
// Base class to create a new reduce block,
// only used for FactorizeReduction schedule primitive.
class ReduceBlockCreater {
public:
ReduceBlockCreater(const Expr& original_block,
const std::vector<Expr>& original_loops,
const Expr& rf_loop,
const Expr& original_update_stmt,
const ir::Tensor& rf_tensor,
bool is_rf_block)
: original_block_(original_block),
original_loops_(original_loops),
rf_loop_(rf_loop),
original_update_stmt_(original_update_stmt),
rf_tensor_(rf_tensor),
is_rf_block_(is_rf_block) {
const ScheduleBlockRealize* block_real =
original_block_.As<ir::ScheduleBlockRealize>();
CHECK_NOTNULL(block_real);
num_block_iters_ = block_real->iter_values.size();
}
void CreateBlock() {
CreateRFIter();
for (int i = 0; i < num_block_iters_; ++i) {
CreateNormalIter(i);
}
CreateUpdateStmt();
std::string new_update_block_name =
original_block_.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->name;
if (is_rf_block_) {
new_update_block_name = rf_tensor_->name;
}
std::string new_init_block_name =
ir::GenReduceInitTensorNameOf(new_update_block_name);
VLOG(5) << "new_init_block_name = " << new_init_block_name;
const ir::Tensor& real_tensor =
is_rf_block_
? rf_tensor_
: original_update_stmt_.As<ir::Store>()->tensor.as_tensor_ref();
Expr init_value = real_tensor->GetReduceInitVal();
const std::vector<Expr>& domain = real_tensor->domain_without_reduce_axis();
ir::Tensor init_tensor = lang::Compute(
domain,
[=](const std::vector<Expr>& axis) { return init_value; },
new_init_block_name);
init_tensor->Bind(real_tensor->buffer);
Expr init_stmt = ir::Store::Make(
init_tensor, init_value, new_update_stmt_.As<ir::Store>()->indices);
new_init_sch_block_ = ScheduleBlock::Make(
new_init_iter_vars_, {}, {}, new_init_block_name, init_stmt);
new_init_block_realize_ =
ScheduleBlockRealize::Make(new_init_iter_values_, new_init_sch_block_);
new_update_sch_block_ = ScheduleBlock::Make(
new_iter_vars_, {}, {}, new_update_block_name, new_update_stmt_);
new_update_block_realize_ =
ScheduleBlockRealize::Make(new_iter_values_, new_update_sch_block_);
VLOG(4) << "new_update_block_realize:\n" << new_update_block_realize_;
}
Expr CreateLoops() {
int num_loops = original_loops_.size();
std::vector<Expr> new_loops(num_loops);
Expr body = new_update_block_realize_;
bool has_add_init_block = false;
for (int i = num_loops - 1; i >= 0; --i) {
bool is_spatial_loop =
new_spatial_loop_var_names_.count(
original_loops_[i].As<For>()->loop_var->name) > 0;
bool is_rf_loop = rf_loop_.As<For>()->loop_var->name ==
original_loops_[i].As<For>()->loop_var->name;
// Skip non rf reduction loops of write back block.
if (!is_rf_block_ && !is_spatial_loop && !is_rf_loop) {
continue;
}
// Add reduce init block.
if (!has_add_init_block && is_spatial_loop) {
body = Block::Make({new_init_block_realize_, body});
has_add_init_block = true;
}
// Add loops
Var loop_var = ir_utils::IRCopy(original_loops_[i].As<For>()->loop_var);
Expr min = ir_utils::IRCopy(original_loops_[i].As<For>()->min);
Expr extent = ir_utils::IRCopy(original_loops_[i].As<For>()->extent);
body = For::Make(loop_var,
min,
extent,
original_loops_[i].As<For>()->for_type(),
original_loops_[i].As<For>()->device_api,
body,
original_loops_[i].As<For>()->vectorize_info(),
original_loops_[i].As<For>()->bind_info());
VLOG(5) << "new body:\n" << body;
}
VLOG(4) << "new loop nest:\n" << body;
return body;
}
private:
virtual void CreateRFIter() = 0;
virtual void CreateNormalIter(int idx) = 0;
virtual void CreateUpdateStmt() = 0;
public:
Var rf_var_;
std::vector<Expr> rf_tensor_access_indices_;
protected:
const Expr& original_block_;
const std::vector<Expr>& original_loops_;
const Expr& rf_loop_;
const Expr& original_update_stmt_;
const ir::Tensor& rf_tensor_;
std::map<Var, Expr, CompVar> original_indice2new_expr_;
int num_block_iters_;
bool is_rf_block_;
std::vector<Var> new_iter_vars_;
std::vector<Expr> new_iter_values_;
std::vector<Var> new_init_iter_vars_;
std::vector<Expr> new_init_iter_values_;
std::unordered_set<std::string> new_spatial_loop_var_names_;
Expr new_update_stmt_;
Expr new_update_sch_block_;
Expr new_update_block_realize_;
Expr new_init_sch_block_;
Expr new_init_block_realize_;
};
// Implement class for building Reduction-Factorized block,
// only used for FactorizeReduction schedule primitive.
class RFBlockCreater : public ReduceBlockCreater {
public:
RFBlockCreater(const Expr& original_block,
const std::vector<Expr>& original_loops,
const Expr& rf_loop,
const Expr& original_update_stmt,
const ir::Tensor& rf_tensor,
const std::map<Var, Expr, CompVar>& var2loops,
int rf_axis)
: ReduceBlockCreater(original_block,
original_loops,
rf_loop,
original_update_stmt,
rf_tensor,
true),
var2loops_(var2loops),
rf_axis_(rf_axis) {}
private:
void CreateRFIter() override {
std::string loop_var_name = rf_loop_.As<ir::For>()->loop_var->name;
std::string rf_var_name = "v" + loop_var_name;
rf_var_ = Var(rf_loop_.As<ir::For>()->min,
rf_loop_.As<ir::For>()->extent,
rf_var_name,
/* is_reduce = */ false);
loop_var2block_iters_[rf_loop_.As<ir::For>()->loop_var] = rf_var_;
new_iter_vars_.push_back(rf_var_);
new_iter_values_.push_back(rf_loop_.As<ir::For>()->loop_var);
new_init_iter_vars_.push_back(rf_var_);
new_init_iter_values_.push_back(rf_loop_.As<ir::For>()->loop_var);
new_spatial_loop_var_names_.insert(rf_loop_.As<ir::For>()->loop_var->name);
VLOG(4) << "create new_rf_var = " << rf_var_
<< ", with iter value = " << new_iter_values_.back();
}
void CreateNormalIter(int idx) override {
Var original_iter_var = original_block_.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->iter_vars[idx];
Expr original_iter_value =
original_block_.As<ir::ScheduleBlockRealize>()->iter_values[idx];
// The original iter is either a spatial iter, or a reduction iter that
// doesn't touch the rf loop. In this case reuse the old iter var and its
// corresponding iter value.
if (!original_iter_var->is_reduce_axis) {
new_iter_vars_.push_back(original_iter_var);
new_iter_values_.push_back(original_iter_value);
new_init_iter_vars_.push_back(original_iter_var);
new_init_iter_values_.push_back(original_iter_value);
ir_utils::CollectIRNodesWithoutTensor(
original_iter_value, [&](const Expr* x) {
if (x->as_var()) {
new_spatial_loop_var_names_.insert(x->as_var()->name);
}
return false;
});
return;
} else if (!ContainVar({original_iter_value},
rf_loop_.As<ir::For>()->loop_var->name)) {
new_iter_vars_.push_back(original_iter_var);
new_iter_values_.push_back(original_iter_value);
return;
}
CHECK(original_iter_var->is_reduce_axis);
// This iter is a reduction iter and touches the rfactor loop. So we try to
// create a new iter for each loop var that appear in the original iter
// value.
std::vector<Var> vars_in_original_iter_values;
ir_utils::CollectIRNodesWithoutTensor(
original_iter_value, [&](const Expr* x) {
if (x->as_var()) {
vars_in_original_iter_values.push_back(x->as_var_ref());
}
return false;
});
for (const Var& loop_var : vars_in_original_iter_values) {
if (var2loops_.count(loop_var) == 0) {
continue;
}
Expr loop = var2loops_.at(loop_var);
if (loop_var2block_iters_.count(loop_var) == 0) {
Var new_iter_var(loop.As<ir::For>()->min,
loop.As<ir::For>()->extent,
"v" + loop_var->name,
/* is_reduce = */ true);
new_iter_vars_.push_back(new_iter_var);
new_iter_values_.emplace_back(loop_var);
loop_var2block_iters_[loop_var] = new_iter_var;
}
}
// Substitute the original iter values with new iter vars,
// and store the new iter values in original_indice2new_expr_,
// it will be used in Load/Store indices.
Expr new_iters = ir_utils::IRCopy(original_iter_value);
ReplaceExpr(&new_iters, loop_var2block_iters_);
original_indice2new_expr_[original_iter_var] = new_iters;
VLOG(4) << "original_indice2new_expr_[" << original_iter_var
<< "] = " << new_iters;
}
void CreateUpdateStmt() override {
rf_tensor_access_indices_ = original_update_stmt_.As<ir::Store>()->indices;
rf_tensor_access_indices_.insert(
rf_tensor_access_indices_.begin() + rf_axis_, rf_var_);
Expr original_store_body = original_update_stmt_.As<ir::Store>()->value;
Expr new_store_body = ir_utils::IRCopy(original_store_body);
#define REPLACE_RF_TENSOR(Op) \
if (new_store_body.As<Op>()) { \
auto* node = new_store_body.As<Op>(); \
CHECK(node); \
auto& operand = node->a(); \
operand = Load::Make(rf_tensor_, rf_tensor_access_indices_); \
}
REPLACE_RF_TENSOR(Add)
REPLACE_RF_TENSOR(Mul)
REPLACE_RF_TENSOR(Max)
REPLACE_RF_TENSOR(Min)
REPLACE_RF_TENSOR(And)
REPLACE_RF_TENSOR(Or)
REPLACE_RF_TENSOR(LT)
REPLACE_RF_TENSOR(LE)
REPLACE_RF_TENSOR(GT)
REPLACE_RF_TENSOR(GE)
#undef REPLACE_RF_TENSOR
new_update_stmt_ =
ir::Store::Make(rf_tensor_, new_store_body, rf_tensor_access_indices_);
ReplaceExpr(&new_update_stmt_, original_indice2new_expr_);
VLOG(4) << "new_update_stmt of rf block: \n" << new_update_stmt_;
}
private:
const std::map<Var, Expr, CompVar>& var2loops_;
int rf_axis_;
std::map<Var, Expr, CompVar> loop_var2block_iters_;
};
// Implement class for building Writing-Back block,
// only used for FactorizeReduction schedule primitive.
class RBBlockCreater : public ReduceBlockCreater {
public:
RBBlockCreater(const Expr& original_block,
const std::vector<Expr>& original_loops,
const Expr& rf_loop,
const Expr& original_update_stmt,
const ir::Tensor& rf_tensor,
const std::vector<Expr>& rf_tensor_access_indices,
const Var& rf_block_rf_iter_var)
: ReduceBlockCreater(original_block,
original_loops,
rf_loop,
original_update_stmt,
rf_tensor,
false),
rf_tensor_access_indices_(rf_tensor_access_indices),
rf_block_rf_iter_var_(rf_block_rf_iter_var) {}
private:
void CreateRFIter() override {
std::string loop_var_name = rf_loop_.As<ir::For>()->loop_var->name;
std::string rf_var_name = "v" + loop_var_name;
rf_var_ = Var(rf_loop_.As<ir::For>()->min,
rf_loop_.As<ir::For>()->extent,
rf_var_name,
/* is_reduce = */ true);
new_iter_vars_.push_back(rf_var_);
new_iter_values_.push_back(rf_loop_.As<ir::For>()->loop_var);
original_indice2new_expr_[rf_block_rf_iter_var_] = Expr(rf_var_);
VLOG(4) << "create new_rf_var = " << rf_var_
<< ", with iter value = " << new_iter_values_.back();
}
void CreateNormalIter(int idx) override {
Var original_iter_var = original_block_.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->iter_vars[idx];
Expr original_iter_value =
original_block_.As<ir::ScheduleBlockRealize>()->iter_values[idx];
if (!original_iter_var->is_reduce_axis) {
new_iter_vars_.push_back(original_iter_var);
new_iter_values_.push_back(original_iter_value);
new_init_iter_vars_.push_back(original_iter_var);
new_init_iter_values_.push_back(original_iter_value);
ir_utils::CollectIRNodesWithoutTensor(
original_iter_value, [&](const Expr* x) {
if (x->as_var()) {
new_spatial_loop_var_names_.insert(x->as_var()->name);
}
return false;
});
// original_indice2new_expr_[original_iter_var] = new_iter_vars_.back();
VLOG(4) << "create new iter var = " << new_iter_vars_.back()
<< ", with iter value = " << new_iter_values_.back();
}
}
void CreateUpdateStmt() override {
Expr original_store_body = original_update_stmt_.As<ir::Store>()->value;
Expr new_store_body = ir_utils::IRCopy(original_store_body);
#define REPLACE_RF_TENSOR(Op) \
if (new_store_body.As<Op>()) { \
auto* node = new_store_body.As<Op>(); \
CHECK(node); \
auto& operand = node->b(); \
operand = Load::Make(rf_tensor_, rf_tensor_access_indices_); \
}
REPLACE_RF_TENSOR(Add)
REPLACE_RF_TENSOR(Mul)
REPLACE_RF_TENSOR(Max)
REPLACE_RF_TENSOR(Min)
REPLACE_RF_TENSOR(And)
REPLACE_RF_TENSOR(Or)
REPLACE_RF_TENSOR(LT)
REPLACE_RF_TENSOR(LE)
REPLACE_RF_TENSOR(GT)
REPLACE_RF_TENSOR(GE)
#undef REPLACE_RF_TENSOR
Expr original_store_tensor = original_update_stmt_.As<ir::Store>()->tensor;
std::vector<Expr> original_store_indices =
original_update_stmt_.As<ir::Store>()->indices;
new_update_stmt_ = ir::Store::Make(
original_store_tensor, new_store_body, original_store_indices);
ReplaceExpr(&new_update_stmt_, original_indice2new_expr_);
VLOG(4) << "new_update_stmt of write back block: \n" << new_update_stmt_;
}
private:
const std::vector<Expr>& rf_tensor_access_indices_;
const Var& rf_block_rf_iter_var_;
};
} // namespace ir
} // namespace cinn
......@@ -27,56 +27,46 @@
#include "paddle/cinn/common/cas.h"
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/common/dev_info_manager.h"
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/ir/dy_schedule/ir_schedule.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_analyzer/ir_analyzer.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/ir_visitor.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/schedule/factorize_reduction.h"
#include "paddle/cinn/ir/schedule/ir_schedule_error.h"
#include "paddle/cinn/ir/schedule/ir_schedule_util.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/optim/ir_simplify.h"
#include "paddle/cinn/optim/replace_var_with_expr.h"
#include "paddle/cinn/utils/string.h"
DECLARE_int32(cinn_error_message_level);
PD_DECLARE_int32(cinn_error_message_level);
namespace cinn {
namespace ir {
/**
* A struct helps to implement Schedule primitives.
* A struct helps to implement static shape Schedule primitives.
*/
class ScheduleImpl {
class StScheduleImpl : public ScheduleBase {
public:
ScheduleImpl() = default;
explicit ScheduleImpl(const ModuleExpr& module_expr,
bool debug_flag = false,
utils::ErrorMessageLevel err_msg_level =
utils::ErrorMessageLevel::kGeneral)
: module_expr_(module_expr), debug_flag_(debug_flag) {
err_msg_level_ = static_cast<utils::ErrorMessageLevel>(
FLAGS_cinn_error_message_level || static_cast<int>(err_msg_level));
}
explicit ScheduleImpl(ModuleExpr&& module_expr)
: module_expr_(std::move(module_expr)) {}
//! Set the debug flag.
void SetDebugFlag(bool debug_flag) { debug_flag_ = debug_flag; }
//! Get the ModuleExpr stored in ScheduleImpl.
const ModuleExpr& GetModule() const { return module_expr_; }
StScheduleImpl() = delete;
explicit StScheduleImpl(const ModuleExpr& module_expr,
bool debug_flag = false,
utils::ErrorMessageLevel err_msg_level =
utils::ErrorMessageLevel::kGeneral)
: ScheduleBase(module_expr, false, err_msg_level) {}
explicit StScheduleImpl(ModuleExpr&& module_expr)
: ScheduleBase(std::move(module_expr)) {}
void MergeExprs();
void SetExprs(const std::vector<Expr>& exprs) {
module_expr_.SetExprs(exprs);
}
bool HasBlock(const std::string& block_name) const;
std::vector<Expr> GetLoops(const Expr& block) const;
std::vector<Expr> GetLoops(const std::string& block_name) const;
std::vector<Expr> GetAllBlocks() const;
......@@ -120,6 +110,7 @@ class ScheduleImpl {
void ReverseComputeInline(const Expr& schedule_block);
void Bind(const Expr& loop, const std::string& thread_axis);
Expr Rfactor(const Expr& rf_loop, int rf_axis);
Expr FactorizeReduction(const Expr& rf_loop, int rf_axis);
Expr AddUnitLoop(const Expr& block) const;
void Annotate(const Expr& block, const std::string& key, const attr_t& value);
void Unannotate(Expr& block, const std::string& key); // NOLINT
......@@ -131,14 +122,32 @@ class ScheduleImpl {
Expr SampleCategorical(utils::LinearRandomEngine::StateType* rand_seed,
const std::vector<int>& candidates,
const std::vector<float>& probs);
};
private:
void Replace(const Expr& src_sref, const Expr& tgt_stmt);
std::unique_ptr<ScheduleBase> ScheduleBase::Make(
const ModuleExpr& module_expr,
bool debug_flag,
utils::ErrorMessageLevel err_msg_level,
bool is_dynamic) {
if (is_dynamic) {
return std::make_unique<DyScheduleImpl>(
module_expr, debug_flag, err_msg_level);
} else {
return std::make_unique<StScheduleImpl>(
module_expr, debug_flag, err_msg_level);
}
return nullptr;
}
ModuleExpr module_expr_;
bool debug_flag_{false};
utils::ErrorMessageLevel err_msg_level_ = utils::ErrorMessageLevel::kGeneral;
};
std::unique_ptr<ScheduleBase> ScheduleBase::Make(ModuleExpr&& module_expr,
bool is_dynamic) {
if (is_dynamic) {
return std::make_unique<DyScheduleImpl>(std::move(module_expr));
} else {
return std::make_unique<StScheduleImpl>(std::move(module_expr));
}
return nullptr;
}
/** \brief A macro that guards the beginning of each implementation of schedule
*/
......@@ -156,8 +165,8 @@ class ScheduleImpl {
CINN_THROW(err_hanlder.FormatErrorMessage(err_msg_level)); \
}
std::vector<Expr> ScheduleImpl::Split(const Expr& loop,
const std::vector<int>& factors) {
std::vector<Expr> StScheduleImpl::Split(const Expr& loop,
const std::vector<int>& factors) {
CHECK(loop.As<ir::For>())
<< "Expr param of Split must be For node! Please check.";
auto* for_node = loop.As<ir::For>();
......@@ -189,7 +198,7 @@ std::vector<Expr> ScheduleImpl::Split(const Expr& loop,
new_loop_vars.push_back(temp_var);
}
substitute_value = common::AutoSimplify(substitute_value);
Expr new_node = optim::IRCopy(for_node->body);
Expr new_node = ir::ir_utils::IRCopy(for_node->body);
ReplaceExpr(&new_node, {for_node->loop_var}, {substitute_value});
std::vector<Expr> splited_loops;
splited_loops.resize(processed_factors.size());
......@@ -213,7 +222,7 @@ std::vector<Expr> ScheduleImpl::Split(const Expr& loop,
return splited_loops;
}
Expr ScheduleImpl::Fuse(const std::vector<Expr>& loops) {
Expr StScheduleImpl::Fuse(const std::vector<Expr>& loops) {
VLOG(3) << "Tring to fuse:\n" << cinn::utils::Join(loops, "\n");
std::vector<const ir::For*> for_nodes;
std::vector<Var> loop_vars;
......@@ -252,7 +261,7 @@ Expr ScheduleImpl::Fuse(const std::vector<Expr>& loops) {
}
substitute_value[0] = fused_expr;
Expr fused_body = optim::IRCopy(for_nodes.back()->body);
Expr fused_body = ir::ir_utils::IRCopy(for_nodes.back()->body);
ReplaceExpr(&fused_body, loop_vars, substitute_value);
optim::Simplify(&fused_body);
Expr fused_extent(1);
......@@ -274,8 +283,8 @@ Expr ScheduleImpl::Fuse(const std::vector<Expr>& loops) {
return new_stmt;
}
Expr ScheduleImpl::Fuse(const std::string& block_name,
const std::vector<int>& loops_index) {
Expr StScheduleImpl::Fuse(const std::string& block_name,
const std::vector<int>& loops_index) {
std::vector<Expr> all_loops = this->GetLoops(block_name);
std::vector<Expr> loops_expr;
loops_expr.reserve(loops_index.size());
......@@ -293,8 +302,8 @@ Expr ScheduleImpl::Fuse(const std::string& block_name,
return this->Fuse(loops_expr);
}
Expr ScheduleImpl::Fuse(const Expr& block,
const std::vector<int>& loops_index) {
Expr StScheduleImpl::Fuse(const Expr& block,
const std::vector<int>& loops_index) {
std::vector<Expr> all_loops = this->GetLoops(block);
std::vector<Expr> loops_expr;
loops_expr.reserve(loops_index.size());
......@@ -312,16 +321,16 @@ Expr ScheduleImpl::Fuse(const Expr& block,
return this->Fuse(loops_expr);
}
void ScheduleImpl::MutateForType(const Expr& loop,
ForType for_type,
int factor) {
void StScheduleImpl::MutateForType(const Expr& loop,
ForType for_type,
int factor) {
auto* for_node = loop.As<ir::For>();
CHECK(for_node) << "loop param must be For node! Please check.";
CHECK(for_node->is_serial())
<< "loop is not serial, current forloop type is "
<< static_cast<int>(for_node->for_type()) << ", and it cannot become "
<< static_cast<int>(for_type);
auto loop_copy = optim::IRCopy(loop);
auto loop_copy = ir::ir_utils::IRCopy(loop);
auto* new_for_node = loop_copy.As<ir::For>();
CHECK(new_for_node);
new_for_node->set_for_type(for_type);
......@@ -335,20 +344,21 @@ void ScheduleImpl::MutateForType(const Expr& loop,
this->Replace(loop, loop_copy);
}
void ScheduleImpl::Parallel(const Expr& loop) {
void StScheduleImpl::Parallel(const Expr& loop) {
MutateForType(loop, ForType::Parallel);
}
void ScheduleImpl::Vectorize(const Expr& loop, int factor) {
void StScheduleImpl::Vectorize(const Expr& loop, int factor) {
CHECK_GT(factor, 0) << "vectorize factor should be more than 0";
MutateForType(loop, ForType::Vectorized, factor);
}
void ScheduleImpl::Unroll(const Expr& loop) {
void StScheduleImpl::Unroll(const Expr& loop) {
MutateForType(loop, ForType::Unrolled);
}
void ScheduleImpl::Bind(const Expr& loop, const std::string& thread_axis) {
void StScheduleImpl::Bind(const Expr& loop, const std::string& thread_axis) {
#ifdef CINN_WITH_CUDA
static std::set<std::string> thread_axes = {"blockIdx.x",
"blockIdx.y",
"blockIdx.z",
......@@ -358,11 +368,24 @@ void ScheduleImpl::Bind(const Expr& loop, const std::string& thread_axis) {
CHECK(thread_axes.count(thread_axis))
<< "thread_axis " << thread_axis << " is not supported";
int offset = thread_axis.back() - 'x';
auto cur_dev_info =
common::DevInfoMgr<common::Target::Arch::NVGPU>::GetDevInfo(0);
const std::array<int, 3> kMaxBlockDims = cur_dev_info->GetMaxBlockDims();
const std::array<int, 3> kMaxGridDims = cur_dev_info->GetMaxGridDims();
auto check_offset = [&](const char& c) -> bool {
auto extent = loop.As<ir::For>()->extent.as_int32();
return extent <= (c == 'b' ? kMaxGridDims[offset] : kMaxBlockDims[offset]);
};
if (thread_axis[0] == 'b') {
CHECK(check_offset(thread_axis[0]))
<< "Invalid Bind! The extent of loop is out of range on grid size!\n";
MutateForType(loop, ForType::GPUBlock, offset);
} else {
CHECK(check_offset(thread_axis[0]))
<< "Invalid Bind! The extent of loop is out of range on block size!\n";
MutateForType(loop, ForType::GPUThread, offset);
}
#endif
}
// The struct used to mutate new rfactor forloop and its' schedule block.
......@@ -674,7 +697,7 @@ struct RfCreater : public ir::IRMutator<> {
CHECK(root_realize);
auto root_block = root_realize->schedule_block.As<ScheduleBlock>();
CHECK(root_block);
Expr root_loop = optim::IRCopy(root_block->body);
Expr root_loop = ir::ir_utils::IRCopy(root_block->body);
if (auto block = root_loop.As<Block>()) {
CHECK_EQ(block->stmts.size(), 1U)
<< "rfactor root should only have one block stmt";
......@@ -685,13 +708,13 @@ struct RfCreater : public ir::IRMutator<> {
auto rf_for = rf_loop_.As<For>();
CHECK(rf_for);
// create new rfactor forloops
Expr new_rf_forloop = optim::IRCopy(root_loop);
Expr new_rf_forloop = ir::ir_utils::IRCopy(root_loop);
RfMutator rf_mutator(rf_loop_, rf_axis_);
rf_mutator(&new_rf_forloop);
VLOG(3) << "After RfMutator, new rf_forloop is\n" << new_rf_forloop;
auto new_rf_tensor = rf_mutator.GetNewRfTensor();
// create final write-back forloops
Expr final_forloop = optim::IRCopy(root_loop);
Expr final_forloop = ir::ir_utils::IRCopy(root_loop);
FinalMutator final_mutator(rf_loop_, rf_axis_, new_rf_tensor);
final_mutator(&final_forloop);
VLOG(3) << "After FinalMuator, final write-back forloop is\n"
......@@ -707,7 +730,7 @@ struct RfCreater : public ir::IRMutator<> {
int rf_axis_;
};
Expr ScheduleImpl::Rfactor(const Expr& rf_loop, int rf_axis) {
Expr StScheduleImpl::Rfactor(const Expr& rf_loop, int rf_axis) {
CHECKRfactorValidation(rf_loop, rf_axis);
// get root ScheduleBlockRealize
Expr root = GetRootBlock(rf_loop);
......@@ -717,11 +740,84 @@ Expr ScheduleImpl::Rfactor(const Expr& rf_loop, int rf_axis) {
return rf_create.CreateRfAllStmts();
}
Expr StScheduleImpl::FactorizeReduction(const Expr& rf_loop, int rf_axis) {
std::string primitive = "FactorizeReduction";
// Get child block of the rf_loop and check.
std::vector<Expr> blocks = GetChildBlocks(rf_loop);
if (blocks.size() != 1) {
std::ostringstream os;
os << "The rf_loop is required to have only one child block, but got "
<< blocks.size() << std::endl;
throw IRScheduleErrorHandler(primitive, os.str(), this->module_expr_);
}
Expr original_block = blocks.at(0);
Expr root_block = GetRootBlock(original_block);
// TODO(BiynXu): Add CheckReductionBlock()
// Collect the loops of the block.
// Construct a map from loop var names to corresponding loops.
std::vector<Expr> original_loops = this->GetLoops(original_block);
CHECK_GT(original_loops.size(), 0);
VLOG(3) << "before FactorizeReduction, original computational body of the "
"reduction is:\n"
<< original_loops[0];
std::map<Var, Expr, CompVar> var2loops;
for (const Expr& loop : original_loops) {
var2loops[loop.As<For>()->loop_var] = loop;
}
// Get original stmt of reduction update and original store tensor.
Expr original_update_body = original_block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->body;
Expr original_update_stmt;
CHECK(original_update_body.As<Block>() || original_update_body.As<Store>());
if (original_update_body.As<Block>()) {
CHECK_EQ(original_update_body.As<Block>()->stmts.size(), 1);
original_update_stmt = original_update_body.As<Block>()->stmts[0];
} else if (original_update_body.As<Store>()) {
original_update_stmt = original_update_body;
}
Tensor original_tensor =
original_update_stmt.As<Store>()->tensor.as_tensor_ref();
// Create new blocks and loops.
Tensor rf_tensor = CreateRFTensor(original_tensor, rf_loop, rf_axis);
RFBlockCreater rf_block_creater(original_block,
original_loops,
rf_loop,
original_update_stmt,
rf_tensor,
var2loops,
rf_axis);
rf_block_creater.CreateBlock();
RBBlockCreater wb_block_creater(original_block,
original_loops,
rf_loop,
original_update_stmt,
rf_tensor,
rf_block_creater.rf_tensor_access_indices_,
rf_block_creater.rf_var_);
wb_block_creater.CreateBlock();
Expr rf_body = rf_block_creater.CreateLoops();
Expr wb_body = wb_block_creater.CreateLoops();
Expr new_computational_body = Block::Make({rf_body, wb_body});
// Replace and update the AST.
this->Replace(original_loops[0], new_computational_body);
VLOG(3) << "After FactorizeReduction, new computational body of the "
"reduction is:\n"
<< new_computational_body;
return rf_tensor;
}
struct CacheReadRewriter : public ir::IRMutator<> {
public:
static Expr Rewrite(const Expr& root, CacheBlockInfo* info) {
CacheReadRewriter rewriter(root, info);
Expr new_root = optim::IRCopy(root);
Expr new_root = ir::ir_utils::IRCopy(root);
rewriter(&new_root);
return new_root;
}
......@@ -762,12 +858,12 @@ struct CacheWriteRewriter : public ir::IRMutator<> {
public:
static Expr Rewrite(const Expr& root, CacheBlockInfo* info) {
CacheWriteRewriter rewriter(root, info);
Expr new_root = optim::IRCopy(root);
Expr new_root = ir::ir_utils::IRCopy(root);
rewriter.mutate_cache_block = true;
rewriter(&info->cache_block);
rewriter.mutate_cache_block = false;
rewriter(&new_root);
auto find_tensor = ir::CollectIRNodesWithoutTensor(
auto find_tensor = ir::ir_utils::CollectIRNodesWithoutTensor(
new_root,
[&](const Expr* x) {
return x->As<Store>() &&
......@@ -775,7 +871,7 @@ struct CacheWriteRewriter : public ir::IRMutator<> {
},
true);
if (!find_tensor.empty()) {
auto find_store = ir::CollectIRNodesWithoutTensor(
auto find_store = ir::ir_utils::CollectIRNodesWithoutTensor(
(*find_tensor.begin()), [&](const Expr* x) {
return x->As<Load>() &&
(x->As<Load>()->tensor == Expr(info->write_tensor));
......@@ -862,17 +958,14 @@ struct ChangeBodyToBlock : public ir::IRMutator<> {
}
};
DeviceAPI ScheduleImpl::GetDeviceAPI() const {
DeviceAPI StScheduleImpl::GetDeviceAPI() const {
auto exprs = this->GetModule().GetExprs();
auto find_for_nodes = ir::CollectIRNodesWithoutTensor(
exprs.front(), [&](const Expr* x) { return x->As<ir::For>(); }, true);
CHECK(!find_for_nodes.empty());
return (*find_for_nodes.begin()).As<ir::For>()->device_api;
return analyzer::GetDeviceAPI(exprs);
}
Expr ScheduleImpl::CacheRead(const Expr& block,
int read_tensor_index,
const std::string& memory_type) {
Expr StScheduleImpl::CacheRead(const Expr& block,
int read_tensor_index,
const std::string& memory_type) {
CHECK(block.As<ScheduleBlockRealize>());
auto root = GetRootBlock(block);
ChangeBodyToBlock::Change(&root);
......@@ -898,9 +991,9 @@ Expr ScheduleImpl::CacheRead(const Expr& block,
return new_block;
}
Expr ScheduleImpl::CacheWrite(const Expr& block,
int write_buffer_index,
const std::string& memory_type) {
Expr StScheduleImpl::CacheWrite(const Expr& block,
int write_buffer_index,
const std::string& memory_type) {
CHECK(block.As<ScheduleBlockRealize>());
auto root = GetRootBlock(block);
ChangeBodyToBlock::Change(&root);
......@@ -925,7 +1018,7 @@ Expr ScheduleImpl::CacheWrite(const Expr& block,
->schedule_block.As<ScheduleBlock>()
->body);
auto find_cache_block = ir::CollectIRNodesWithoutTensor(
auto find_cache_block = ir::ir_utils::CollectIRNodesWithoutTensor(
root,
[&](const Expr* x) {
return x->As<ir::ScheduleBlockRealize>() &&
......@@ -937,9 +1030,10 @@ Expr ScheduleImpl::CacheWrite(const Expr& block,
CHECK(info.write_tensor->buffer.defined());
// Replace buffer
auto all_tensors = ir::CollectIRNodesWithoutTensor(root, [&](const Expr* x) {
return x->as_tensor() && x->as_tensor()->buffer.defined();
});
auto all_tensors =
ir::ir_utils::CollectIRNodesWithoutTensor(root, [&](const Expr* x) {
return x->as_tensor() && x->as_tensor()->buffer.defined();
});
for (auto i : all_tensors) {
if (i.as_tensor()->name != info.write_tensor->name &&
......@@ -1007,7 +1101,7 @@ struct InsertExpr : public ir::IRMutator<> {
bool after_node_;
};
void ScheduleImpl::SyncThreads(const Expr& ir_node, bool after_node) {
void StScheduleImpl::SyncThreads(const Expr& ir_node, bool after_node) {
CHECK(ir_node.As<ScheduleBlockRealize>() || ir_node.As<ir::For>());
auto root = GetRootBlock(ir_node);
ChangeBodyToBlock::Change(&root);
......@@ -1016,60 +1110,7 @@ void ScheduleImpl::SyncThreads(const Expr& ir_node, bool after_node) {
return;
}
/**
* Replace a For node to another For node.
* @param src_sref The For node to be changed.
* @param tgt_stmt The For node we want.
*/
void ScheduleImpl::Replace(const Expr& src_sref, const Expr& tgt_stmt) {
CHECK(src_sref.As<ir::For>() || src_sref.As<ir::Block>() ||
src_sref.As<ir::ScheduleBlockRealize>());
CHECK(tgt_stmt.As<ir::For>() || tgt_stmt.As<ir::Block>() ||
tgt_stmt.As<ir::ScheduleBlockRealize>());
if (src_sref == tgt_stmt) {
return;
}
struct ForLoopMutator : public ir::IRMutator<> {
ForLoopMutator(const Expr& source, const Expr& target)
: source_(source), target_(target) {}
void operator()(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
void Visit(const ir::For* op, Expr* expr) override {
if (*expr == source_) {
*expr = target_;
return;
}
ir::IRMutator<>::Visit(op, expr);
}
void Visit(const ir::ScheduleBlockRealize* op, Expr* expr) override {
if (*expr == source_) {
*expr = target_;
return;
}
ir::IRMutator<>::Visit(op, expr);
}
void Visit(const ir::Block* op, Expr* expr) override {
if (*expr == source_) {
*expr = target_;
return;
}
ir::IRMutator<>::Visit(op, expr);
}
const Expr& source_;
const Expr& target_;
};
auto exprs = module_expr_.GetExprs();
ForLoopMutator mutator(src_sref, tgt_stmt);
for (auto& i : exprs) {
mutator(&i);
}
}
Expr ScheduleImpl::Reorder(const std::vector<Expr>& loops) {
Expr StScheduleImpl::Reorder(const std::vector<Expr>& loops) {
if (loops.size() <= 1) {
return Expr{nullptr};
}
......@@ -1088,8 +1129,8 @@ Expr ScheduleImpl::Reorder(const std::vector<Expr>& loops) {
return new_loop;
}
Expr ScheduleImpl::Reorder(const std::string& block_name,
const std::vector<int>& loops_index) {
Expr StScheduleImpl::Reorder(const std::string& block_name,
const std::vector<int>& loops_index) {
std::vector<Expr> all_loops = this->GetLoops(block_name);
std::vector<Expr> loops_expr;
loops_expr.reserve(loops_index.size());
......@@ -1102,8 +1143,8 @@ Expr ScheduleImpl::Reorder(const std::string& block_name,
return this->Reorder(loops_expr);
}
Expr ScheduleImpl::Reorder(const Expr& block,
const std::vector<int>& loops_index) {
Expr StScheduleImpl::Reorder(const Expr& block,
const std::vector<int>& loops_index) {
std::vector<Expr> all_loops = this->GetLoops(block);
std::vector<Expr> loops_expr;
loops_expr.reserve(loops_index.size());
......@@ -1116,25 +1157,9 @@ Expr ScheduleImpl::Reorder(const Expr& block,
return this->Reorder(loops_expr);
}
Expr ScheduleImpl::GetRootBlock(const Expr& expr) const {
Expr StScheduleImpl::GetRootBlock(const Expr& expr) const {
auto exprs = this->GetModule().GetExprs();
for (auto& it_expr : exprs) {
auto find_expr = ir::CollectIRNodesWithoutTensor(
it_expr,
[&](const Expr* x) {
return x->node_type() == expr.node_type() && *x == expr;
},
true);
if (!find_expr.empty()) {
CHECK(it_expr.As<ir::Block>());
CHECK_EQ(it_expr.As<ir::Block>()->stmts.size(), 1U);
CHECK(it_expr.As<ir::Block>()->stmts[0].As<ir::ScheduleBlockRealize>());
return it_expr.As<ir::Block>()->stmts[0];
}
}
LOG(FATAL) << "Didn't find expr \n"
<< expr << "in ScheduleImpl:\n"
<< exprs[0];
return analyzer::GetRootBlock(exprs, expr);
}
// The struct used to reconstruct the new For node to replace the old For node.
......@@ -1193,25 +1218,26 @@ struct LoopReconstructor : public ir::IRMutator<> {
loop_.As<ir::For>()->device_api,
std::move(loop_body));
}
new_loop_ = optim::IRCopy(loop_);
new_loop_ = ir::ir_utils::IRCopy(loop_);
// Replace the copied Tensor object with the original Tensor object,
// to ensure that the same Tensor in a AST is the same object.
std::unordered_map<std::string, ir::Expr> tensors_map;
ir::CollectIRNodesWithoutTensor(loop_, [&tensors_map](const Expr* x) {
if (x->as_tensor()) {
tensors_map.insert({x->as_tensor()->name, *x});
return true;
}
return false;
});
auto find_store = ir::CollectIRNodesWithoutTensor(
ir::ir_utils::CollectIRNodesWithoutTensor(
loop_, [&tensors_map](const Expr* x) {
if (x->as_tensor()) {
tensors_map.insert({x->as_tensor()->name, *x});
return true;
}
return false;
});
auto find_store = ir::ir_utils::CollectIRNodesWithoutTensor(
new_loop_, [](const Expr* x) { return x->As<ir::Store>(); });
for (auto store : find_store) {
store.As<ir::Store>()->tensor =
tensors_map.at(store.As<ir::Store>()->tensor.as_tensor()->name);
}
auto find_load = ir::CollectIRNodesWithoutTensor(
auto find_load = ir::ir_utils::CollectIRNodesWithoutTensor(
new_loop_, [](const Expr* x) { return x->As<ir::Load>(); });
for (auto load : find_load) {
load.As<ir::Load>()->tensor =
......@@ -1271,11 +1297,11 @@ struct FixLocalBufferSize : public ir::IRMutator<> {
std::string tensor_name_;
};
void ScheduleImpl::SetBuffer(Expr& block,
const std::string& memory_type,
bool fixed) {
void StScheduleImpl::SetBuffer(Expr& block,
const std::string& memory_type,
bool fixed) {
CHECK(block.As<ir::ScheduleBlockRealize>());
auto find_tensor = ir::CollectIRNodesWithoutTensor(
auto find_tensor = ir::ir_utils::CollectIRNodesWithoutTensor(
block, [&](const Expr* x) { return x->As<ir::Store>(); }, true);
CHECK_EQ(find_tensor.size(), 1U)
<< "One block should only have one Store node!(except for root block)";
......@@ -1286,7 +1312,7 @@ void ScheduleImpl::SetBuffer(Expr& block,
auto exprs = this->GetModule().GetExprs();
for (auto& it_expr : exprs) {
auto find_tensor =
ir::CollectIRNodesWithoutTensor(it_expr, [&](const Expr* x) {
ir::ir_utils::CollectIRNodesWithoutTensor(it_expr, [&](const Expr* x) {
return x->as_tensor() &&
(x->as_tensor()->name == tensor.as_tensor_ref()->name ||
x->as_tensor()->name ==
......@@ -1308,7 +1334,7 @@ void ScheduleImpl::SetBuffer(Expr& block,
}
}
void ScheduleImpl::MergeExprs() {
void StScheduleImpl::MergeExprs() {
auto exprs = this->GetModule().GetExprs();
if (exprs.size() == 1U) return;
CHECK(exprs[0].As<ir::Block>());
......@@ -1328,7 +1354,7 @@ void ScheduleImpl::MergeExprs() {
->body);
VLOG(3) << "Before merging, exprs[0] is : " << exprs[0];
for (int i = 1; i < exprs.size(); ++i) {
auto root_block = ir::CollectIRNodesWithoutTensor(
auto root_block = ir::ir_utils::CollectIRNodesWithoutTensor(
exprs[i],
[&](const Expr* x) {
return x->As<ir::ScheduleBlockRealize>() &&
......@@ -1358,9 +1384,9 @@ void ScheduleImpl::MergeExprs() {
this->SetExprs(exprs);
}
void ScheduleImpl::ComputeAt(const Expr& block,
const Expr& loop,
bool keep_unit_loops) {
void StScheduleImpl::ComputeAt(const Expr& block,
const Expr& loop,
bool keep_unit_loops) {
CHECK(block.As<ir::ScheduleBlockRealize>());
CHECK(loop.As<ir::For>());
Expr root = this->GetRootBlock(block);
......@@ -1386,7 +1412,7 @@ void ScheduleImpl::ComputeAt(const Expr& block,
VLOG(3) << "After SimpleComputeAt, ir is:\n" << reconstructor.new_loop_;
}
void ScheduleImpl::SimpleComputeAt(const Expr& block, const Expr& loop) {
void StScheduleImpl::SimpleComputeAt(const Expr& block, const Expr& loop) {
CHECK(block.As<ir::ScheduleBlockRealize>());
CHECK(loop.As<ir::For>());
std::vector<Expr> block_loops = this->GetLoops(block);
......@@ -1429,15 +1455,15 @@ void ScheduleImpl::SimpleComputeAt(const Expr& block, const Expr& loop) {
}
Expr result = loops.size() < block_loops.size()
? optim::IRCopy(block_loops[loops.size()])
: optim::IRCopy(this_block);
Expr new_loop = optim::IRCopy(this_loop);
? ir::ir_utils::IRCopy(block_loops[loops.size()])
: ir::ir_utils::IRCopy(this_block);
Expr new_loop = ir::ir_utils::IRCopy(this_loop);
// Get the body of block_loop under the same loops
auto body = block_loops.at(loops.size() - 1).As<ir::For>()->body;
// collect if
auto if_checker = [](const Expr* x) { return x->As<ir::IfThenElse>(); };
auto if_set = ir::CollectIRNodesWithoutTensor(body, if_checker);
auto if_set = ir::ir_utils::CollectIRNodesWithoutTensor(body, if_checker);
for (auto if_expr : if_set) {
auto checker = [block_name](const Expr* x) {
return x->As<ir::ScheduleBlockRealize>() &&
......@@ -1445,7 +1471,8 @@ void ScheduleImpl::SimpleComputeAt(const Expr& block, const Expr& loop) {
->schedule_block.As<ScheduleBlock>()
->name == block_name;
};
if (ir::CollectIRNodesWithoutTensor(if_expr, checker, true).size() > 0) {
if (ir::ir_utils::CollectIRNodesWithoutTensor(if_expr, checker, true)
.size() > 0) {
result =
IfThenElse::Make(if_expr.As<ir::IfThenElse>()->condition, result);
break;
......@@ -1498,9 +1525,9 @@ void ScheduleImpl::SimpleComputeAt(const Expr& block, const Expr& loop) {
VLOG(3) << "After SimpleComputeAt, ir is:\n" << new_loop;
}
void ScheduleImpl::ReverseComputeAt(const Expr& block,
const Expr& loop,
bool keep_unit_loops) {
void StScheduleImpl::ReverseComputeAt(const Expr& block,
const Expr& loop,
bool keep_unit_loops) {
CHECK(block.As<ir::ScheduleBlockRealize>());
CHECK(loop.As<ir::For>());
Expr root = this->GetRootBlock(block);
......@@ -1582,7 +1609,7 @@ bool ComputeInliner::BodyPatternAllowInline() {
return false;
}
CHECK(inlined_store_.As<Store>());
auto find_vars = ir::CollectIRNodesWithoutTensor(
auto find_vars = ir::ir_utils::CollectIRNodesWithoutTensor(
inlined_store_, [&](const Expr* x) { return x->as_var(); });
std::set<Var, CompVar> vars_set;
for (auto& i : find_vars) vars_set.insert(i.as_var_ref());
......@@ -1605,12 +1632,12 @@ void ComputeInliner::Visit(const ir::Load* expr, Expr* op) {
Expr ComputeInliner::ReplaceInlinedTensor(Expr* load) {
CHECK(load->As<ir::Load>());
SetIndexSubstitution(load->As<ir::Load>()->indices);
Expr value_copy = optim::IRCopy(inlined_store_.As<Store>()->value);
Expr value_copy = ir::ir_utils::IRCopy(inlined_store_.As<Store>()->value);
ReplaceExpr(&value_copy, idx_sub_var_, idx_sub_expr_);
return value_copy;
}
void ScheduleImpl::ComputeInline(const Expr& schedule_block) {
void StScheduleImpl::ComputeInline(const Expr& schedule_block) {
CHECK(schedule_block.As<ir::ScheduleBlockRealize>());
Expr root = this->GetRootBlock(schedule_block);
Expr store = CheckComputeInlineValidationAndGetStore(schedule_block, root);
......@@ -1650,7 +1677,7 @@ bool ReverseComputeInliner::BodyPatternAllowInline() {
CHECK(inlined_store_.As<Store>());
CHECK(inlined_load_.As<Load>());
CHECK(target_store_.As<Store>());
auto find_vars = ir::CollectIRNodesWithoutTensor(
auto find_vars = ir::ir_utils::CollectIRNodesWithoutTensor(
inlined_store_, [&](const Expr* x) { return x->as_var(); });
std::set<Var, CompVar> vars_set;
for (auto& i : find_vars) vars_set.insert(i.as_var_ref());
......@@ -1681,7 +1708,7 @@ void ReverseComputeInliner::Visit(const ir::Store* expr, Expr* op) {
Expr ReverseComputeInliner::ReplaceInlinedTensor(Expr* load) {
CHECK(load->As<ir::Load>());
SetIndexSubstitution(load->As<ir::Load>()->indices);
Expr value_copy = optim::IRCopy(inlined_store_.As<Store>()->value);
Expr value_copy = ir::ir_utils::IRCopy(inlined_store_.As<Store>()->value);
return value_copy;
}
......@@ -1696,12 +1723,12 @@ Expr ReverseComputeInliner::ReplaceTargetTensor(Expr* store) {
idx_sub_expr_.emplace_back(idx_vars_[i]);
}
Expr value_copy = optim::IRCopy(target_store_);
Expr value_copy = ir::ir_utils::IRCopy(target_store_);
ReplaceExpr(&value_copy, idx_sub_var_, idx_sub_expr_);
return value_copy;
}
void ScheduleImpl::ReverseComputeInline(const Expr& schedule_block) {
void StScheduleImpl::ReverseComputeInline(const Expr& schedule_block) {
Expr root = this->GetRootBlock(schedule_block);
auto exprs =
CheckReverseComputeInlineValidationAndGetExprs(schedule_block, root);
......@@ -1777,170 +1804,55 @@ struct FindBlockParent : public ir::IRMutator<> {
ir::Expr* target_{nullptr};
};
Expr ScheduleImpl::AddUnitLoop(const Expr& block) const {
Expr StScheduleImpl::AddUnitLoop(const Expr& block) const {
auto exprs = module_expr_.GetExprs();
CHECK(block.As<ir::ScheduleBlockRealize>());
CHECK(block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>());
std::string block_name = block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->name;
FindBlockParent visitor(block_name);
for (auto expr : exprs) {
visitor(&expr);
if (visitor.target_) {
break;
}
}
CHECK(visitor.target_) << ", block name : " << block_name << "\n" << exprs;
if (visitor.target_->As<ir::Block>()) {
for (auto& stmt : visitor.target_->As<ir::Block>()->stmts) {
if (stmt.As<ir::ScheduleBlockRealize>()) {
if (stmt.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->name == block_name) {
auto block = ir::Block::Make({GetBlock(block_name)});
auto loop = ir::For::Make(ir::Var(common::UniqName("ix")),
ir::Expr(0),
ir::Expr(1),
ir::ForType::Serial,
ir::DeviceAPI::UNK,
block);
stmt = loop;
return loop;
}
}
}
} else if (visitor.target_->As<ir::For>()) {
auto block = ir::Block::Make({visitor.target_->As<ir::For>()->body});
auto loop = ir::For::Make(ir::Var(common::UniqName("ix")),
ir::Expr(0),
ir::Expr(1),
ir::ForType::Serial,
ir::DeviceAPI::UNK,
block);
visitor.target_->As<ir::For>()->body = loop;
return loop;
} else if (visitor.target_->As<ir::ScheduleBlock>()) {
auto block =
ir::Block::Make({visitor.target_->As<ir::ScheduleBlock>()->body});
auto loop = ir::For::Make(ir::Var(common::UniqName("ix")),
ir::Expr(0),
ir::Expr(1),
ir::ForType::Serial,
ir::DeviceAPI::UNK,
block);
visitor.target_->As<ir::ScheduleBlock>()->body = loop;
return loop;
} else {
LOG(FATAL) << "Can't find block's parent!";
}
LOG(FATAL) << "Shouldn't reach code here in AddUnitLoop";
return Expr{nullptr};
return analyzer::AddUnitLoop(exprs, block);
}
std::vector<Expr> ScheduleImpl::GetLoops(const Expr& block) const {
std::vector<Expr> result;
std::vector<Expr> StScheduleImpl::GetLoops(const Expr& block) const {
auto exprs = module_expr_.GetExprs();
CHECK(block.As<ir::ScheduleBlockRealize>());
CHECK(block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>());
std::string block_name = block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->name;
for (auto& it_expr : exprs) {
ir::FindLoopsVisitor visitor(block);
auto find_loops = visitor(&it_expr);
if (!find_loops.empty()) {
if (!result.empty())
LOG(FATAL) << "Find block with name: \n"
<< block_name << " appeared in more than one AST!";
result = find_loops;
}
}
if (result.empty()) {
result.push_back(AddUnitLoop(block));
}
return result;
return analyzer::GetLoops(exprs, block);
}
std::vector<Expr> ScheduleImpl::GetLoops(const std::string& block_name) const {
Expr block = this->GetBlock(block_name);
std::vector<Expr> result = this->GetLoops(block);
return result;
std::vector<Expr> StScheduleImpl::GetLoops(
const std::string& block_name) const {
auto exprs = module_expr_.GetExprs();
return analyzer::GetLoops(exprs, block_name);
}
std::vector<Expr> ScheduleImpl::GetAllBlocks() const {
std::vector<Expr> result;
std::vector<Expr> StScheduleImpl::GetAllBlocks() const {
auto exprs = module_expr_.GetExprs();
for (auto& it_expr : exprs) {
ir::FindBlocksVisitor visitor;
auto find_blocks = visitor(&it_expr);
result.insert(result.end(), find_blocks.begin(), find_blocks.end());
}
for (auto& it_expr : exprs) {
VLOG(3) << "it_expr is : " << it_expr;
}
CHECK(!result.empty()) << "Didn't find blocks in expr.";
return result;
return analyzer::GetAllBlocks(exprs);
}
std::vector<Expr> ScheduleImpl::GetChildBlocks(const Expr& expr) const {
CHECK(expr.As<ir::ScheduleBlockRealize>() || expr.As<ir::For>());
ir::FindBlocksVisitor visitor;
std::vector<Expr> result = visitor(&expr);
return result;
std::vector<Expr> StScheduleImpl::GetChildBlocks(const Expr& expr) const {
return analyzer::GetChildBlocks(expr);
}
bool ScheduleImpl::HasBlock(const std::string& block_name) const {
bool StScheduleImpl::HasBlock(const std::string& block_name) const {
auto exprs = module_expr_.GetExprs();
for (auto& it_expr : exprs) {
ir::FindBlocksVisitor visitor(block_name);
auto find_blocks = visitor(&it_expr);
if (!find_blocks.empty()) {
CHECK_EQ(find_blocks.size(), 1U)
<< "There should not be more than 1 block with identical name!";
return true;
}
}
return false;
return analyzer::HasBlock(exprs, block_name);
}
Expr ScheduleImpl::GetBlock(const std::string& block_name) const {
Expr result;
Expr StScheduleImpl::GetBlock(const std::string& block_name) const {
auto exprs = module_expr_.GetExprs();
for (auto& it_expr : exprs) {
ir::FindBlocksVisitor visitor(block_name);
auto find_blocks = visitor(&it_expr);
if (!find_blocks.empty()) {
CHECK_EQ(find_blocks.size(), 1U)
<< "There should not be more than 1 block with identical name!";
result = find_blocks[0];
return result;
}
}
LOG(FATAL) << "Didn't find a block with name " << block_name
<< " in this ModuleExpr!";
return analyzer::GetBlock(exprs, block_name);
}
void ScheduleImpl::Annotate(const Expr& block,
const std::string& key,
const attr_t& value) {
void StScheduleImpl::Annotate(const Expr& block,
const std::string& key,
const attr_t& value) {
CHECK(block.As<ir::ScheduleBlockRealize>());
CHECK(block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>());
auto copied_block = optim::IRCopy(block);
auto copied_block = ir::ir_utils::IRCopy(block);
auto* schedule_block = copied_block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>();
schedule_block->attrs.emplace(key, value);
this->Replace(block, copied_block);
}
void ScheduleImpl::Unannotate(Expr& block, const std::string& ann_key) {
void StScheduleImpl::Unannotate(Expr& block, const std::string& ann_key) {
CHECK(block.As<ir::ScheduleBlockRealize>());
CHECK(block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>());
......@@ -1954,8 +1866,8 @@ void ScheduleImpl::Unannotate(Expr& block, const std::string& ann_key) {
}
}
void ScheduleImpl::FlattenLoops(const std::vector<Expr>& loops,
const bool flat_tensor) {
void StScheduleImpl::FlattenLoops(const std::vector<Expr>& loops,
const bool flat_tensor) {
CHECK_GT(loops.size(), 0) << "Loops can't be empty!";
VLOG(4) << "Before FlattenLoops, ir is:\n" << loops[0];
// compute loop
......@@ -2031,12 +1943,12 @@ void ScheduleImpl::FlattenLoops(const std::vector<Expr>& loops,
CHECK_EQ(iter.as_var_ref()->name, loop_vars[idx]->name)
<< "loops is not the same order with tensor!";
} else {
CHECK(iter.As<IntImm>());
CHECK(iter.As<IntImm>()) << iter.node_type() << " is not IntImm";
CHECK_EQ(iter.as_int32(), 0);
}
}
auto exprs = ir::CollectIRNodesInOrder(
auto exprs = ir::ir_utils::CollectIRNodesInOrder(
schedule_block->body,
[&](const Expr* x) { return x->As<ir::Store>() || x->As<ir::Load>(); });
// reverse exprs from last to first.
......@@ -2136,15 +2048,15 @@ void ScheduleImpl::FlattenLoops(const std::vector<Expr>& loops,
VLOG(4) << "After FlattenLoops, ir is:\n" << loop;
}
void ScheduleImpl::CopyTransformAndLoopInfo(
void StScheduleImpl::CopyTransformAndLoopInfo(
const std::string& block_name, const std::string& block_target_name) {
auto block = this->GetBlock(block_name);
auto block_target = this->GetBlock(block_target_name);
this->CopyTransformAndLoopInfo(block, block_target);
}
void ScheduleImpl::CopyTransformAndLoopInfo(const Expr& block,
const Expr& block_target) {
void StScheduleImpl::CopyTransformAndLoopInfo(const Expr& block,
const Expr& block_target) {
CHECK(block.As<ir::ScheduleBlockRealize>());
CHECK(block_target.As<ir::ScheduleBlockRealize>());
auto exprs = this->GetModule().GetExprs();
......@@ -2185,16 +2097,16 @@ void ScheduleImpl::CopyTransformAndLoopInfo(const Expr& block,
std::set<std::string> used_target_loop_vars;
for (auto& iter_val : new_iter_values) {
auto find_partial_loop =
ir::CollectIRNodesWithoutTensor(iter_val, [&](const Expr* x) {
ir::ir_utils::CollectIRNodesWithoutTensor(iter_val, [&](const Expr* x) {
if (x->as_var()) used_target_loop_vars.insert(x->as_var_ref()->name);
return x->as_var();
});
}
CHECK(!used_target_loop_vars.empty());
std::vector<Expr> used_target_loops;
auto expr_copy = optim::IRCopy(expr);
auto expr_copy = ir::ir_utils::IRCopy(expr);
for (auto& var : used_target_loop_vars) {
auto find_loop_var = ir::CollectIRNodesWithoutTensor(
auto find_loop_var = ir::ir_utils::CollectIRNodesWithoutTensor(
expr_copy,
[&](const Expr* x) {
return x->As<ir::For>() && x->As<ir::For>()->loop_var->name == var &&
......@@ -2217,12 +2129,12 @@ void ScheduleImpl::CopyTransformAndLoopInfo(const Expr& block,
VLOG(3) << "changed_loop_num is : " << changed_loop_num;
VLOG(3) << "old_iter_values.size() is : " << old_iter_values.size();
if (changed_loop_num >= static_cast<int>(old_iter_values.size())) {
new_loop = optim::IRCopy(block);
new_loop = ir::ir_utils::IRCopy(block);
new_loop.As<ir::ScheduleBlockRealize>()->iter_values = new_iter_values;
} else {
CHECK(old_iter_values[changed_loop_num].as_var());
auto old_var = old_iter_values[changed_loop_num].as_var_ref();
auto find_partial_loop = ir::CollectIRNodesWithoutTensor(
auto find_partial_loop = ir::ir_utils::CollectIRNodesWithoutTensor(
expr,
[&](const Expr* x) {
return x->As<ir::For>() &&
......@@ -2231,8 +2143,8 @@ void ScheduleImpl::CopyTransformAndLoopInfo(const Expr& block,
},
true);
CHECK_EQ(find_partial_loop.size(), 1U);
new_loop = optim::IRCopy(*find_partial_loop.begin());
auto find_schedule_block = ir::CollectIRNodesWithoutTensor(
new_loop = ir::ir_utils::IRCopy(*find_partial_loop.begin());
auto find_schedule_block = ir::ir_utils::CollectIRNodesWithoutTensor(
new_loop,
[&](const Expr* x) { return x->As<ir::ScheduleBlockRealize>(); },
true);
......@@ -2265,7 +2177,7 @@ void ScheduleImpl::CopyTransformAndLoopInfo(const Expr& block,
this->Replace(all_loops[0], res);
}
std::vector<Expr> ScheduleImpl::SamplePerfectTile(
std::vector<Expr> StScheduleImpl::SamplePerfectTile(
utils::LinearRandomEngine::StateType* rand_seed,
const Expr& loop,
int n,
......@@ -2296,7 +2208,7 @@ std::vector<Expr> ScheduleImpl::SamplePerfectTile(
return result_expr;
}
Expr ScheduleImpl::SampleCategorical(
Expr StScheduleImpl::SampleCategorical(
utils::LinearRandomEngine::StateType* rand_seed,
const std::vector<int>& candidates,
const std::vector<float>& probs) {
......@@ -2314,41 +2226,52 @@ IRSchedule::IRSchedule() {}
IRSchedule::IRSchedule(const ModuleExpr& module_expr,
utils::LinearRandomEngine::StateType rand_seed,
bool debug_flag,
utils::ErrorMessageLevel err_msg_level) {
impl_ =
std::make_unique<ScheduleImpl>(module_expr, debug_flag, err_msg_level);
utils::ErrorMessageLevel err_msg_level,
bool is_dynamic_shape)
: impl_(ScheduleBase::Make(
module_expr, debug_flag, err_msg_level, is_dynamic_shape)),
is_dynamic_shape_(is_dynamic_shape) {
this->InitSeed(rand_seed);
}
IRSchedule::IRSchedule(ir::ModuleExpr&& mod_expr,
ScheduleDesc&& trace,
utils::LinearRandomEngine::StateType rand_seed)
: impl_(std::make_unique<ScheduleImpl>(std::move(mod_expr))),
trace_(std::move(trace)) {
utils::LinearRandomEngine::StateType rand_seed,
bool is_dynamic_shape)
: impl_(ScheduleBase::Make(std::move(mod_expr), is_dynamic_shape)),
trace_(std::move(trace)),
is_dynamic_shape_(is_dynamic_shape) {
this->InitSeed(rand_seed);
}
IRSchedule::IRSchedule(const IRSchedule& other)
: impl_(std::make_unique<ScheduleImpl>(optim::IRCopy(other.GetModule()))),
trace_(other.trace_) {
: impl_(ScheduleBase::Make(ir::ir_utils::IRCopy(other.GetModule()),
other.IsDynamicShape())),
trace_(other.trace_),
is_dynamic_shape_(other.IsDynamicShape()) {
this->InitSeed(other.ForkSeed());
}
IRSchedule& IRSchedule::operator=(const IRSchedule& src) {
impl_ = std::make_unique<ScheduleImpl>(optim::IRCopy(src.GetModule()));
impl_ = ScheduleBase::Make(ir::ir_utils::IRCopy(src.GetModule()),
src.IsDynamicShape());
trace_ = src.trace_;
is_dynamic_shape_ = src.IsDynamicShape();
this->InitSeed(src.ForkSeed());
return *this;
}
IRSchedule::IRSchedule(IRSchedule&& other)
: impl_(std::move(other.impl_)), trace_(std::move(other.trace_)) {
: impl_(std::move(other.impl_)),
trace_(std::move(other.trace_)),
is_dynamic_shape_(other.IsDynamicShape()) {
this->InitSeed(other.ForkSeed());
}
IRSchedule& IRSchedule::operator=(IRSchedule&& src) {
impl_ = std::move(src.impl_);
trace_ = std::move(src.trace_);
is_dynamic_shape_ = src.IsDynamicShape();
this->InitSeed(src.ForkSeed());
return *this;
}
......@@ -2561,6 +2484,13 @@ void IRSchedule::SetBuffer(Expr& block,
{}));
}
Expr IRSchedule::AddUnitLoop(const Expr& block) {
Expr ret = impl_->AddUnitLoop(block);
trace_.Append(ScheduleDesc::Step(
"AddUnitLoop", {{"block", std::vector<Expr>({block})}}, {}, {ret}));
return ret;
}
Expr IRSchedule::Reorder(const std::vector<Expr>& loops) {
Expr ret = impl_->Reorder(loops);
trace_.Append(ScheduleDesc::Step("Reorder", {{"loops", loops}}, {}, {ret}));
......@@ -2643,6 +2573,15 @@ Expr IRSchedule::Rfactor(const Expr& rf_loop, int rf_axis) {
return result;
}
Expr IRSchedule::FactorizeReduction(const Expr& rf_loop, int rf_axis) {
auto result = impl_->FactorizeReduction(rf_loop, rf_axis);
trace_.Append(ScheduleDesc::Step("FactorizeReduction",
{{"rf_loop", std::vector<Expr>({rf_loop})}},
{{"rf_axis", rf_axis}},
{result}));
return result;
}
void IRSchedule::Annotate(const Expr& block,
const std::string& key,
const attr_t& value) {
......
......@@ -21,51 +21,23 @@
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/schedule/schedule_base.h"
#include "paddle/cinn/ir/schedule/schedule_desc.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/utils/error.h"
#include "paddle/cinn/utils/random_engine.h"
namespace cinn {
namespace ir {
/**
* A struct representing a module that contains Expr. This struct is only used
* in Schedule process.
*/
class ModuleExpr {
public:
ModuleExpr() = default;
ModuleExpr(const ModuleExpr& mod_expr) = default;
ModuleExpr(ModuleExpr&& mod_expr) = default;
ModuleExpr& operator=(const ModuleExpr& mod_expr) = default;
explicit ModuleExpr(const std::vector<Expr>& exprs) : exprs_(exprs) {}
explicit ModuleExpr(std::vector<Expr>&& exprs) : exprs_(std::move(exprs)) {}
//! Get all the Expr in this ModuleExpr.
std::vector<Expr> GetExprs() { return exprs_; }
std::vector<Expr> GetExprs() const { return exprs_; }
void SetExprs(const std::vector<Expr>& exprs) { exprs_ = exprs; }
private:
//! Exprs stored in ModuleExpr. Each one is an AST, representing a computation
//! kernel.
std::vector<Expr> exprs_;
};
/**
* A struct containing all the schedule primitives. Each shedule primitive is a
* member function of IRSchedule. Schedule primitves are implmented by
* ScheduleImpl manipulating the AST - IR(Expr). To support serializing and
* StScheduleImpl manipulating the AST - IR(Expr). To support serializing and
* replaying, each schedule primitive should append a ScheduleDesc::Step to the
* trace_ in its corresponding function implment.
*/
class ScheduleImpl;
class IRSchedule {
public:
IRSchedule();
......@@ -73,10 +45,12 @@ class IRSchedule {
utils::LinearRandomEngine::StateType rand_seed = -1,
bool debug_flag = false,
utils::ErrorMessageLevel err_msg_level =
utils::ErrorMessageLevel::kGeneral);
utils::ErrorMessageLevel::kGeneral,
bool is_dynamic = false);
IRSchedule(ir::ModuleExpr&& mod_expr,
ScheduleDesc&& trace,
utils::LinearRandomEngine::StateType rand_seed = -1);
utils::LinearRandomEngine::StateType rand_seed = -1,
bool is_dynamic = false);
IRSchedule(const IRSchedule& other);
IRSchedule& operator=(const IRSchedule& src);
IRSchedule(IRSchedule&& other);
......@@ -97,6 +71,8 @@ class IRSchedule {
//! Get the ScheduleDesc that traces the scheduling process
const ScheduleDesc& GetTraceDesc() const { return trace_; }
bool IsDynamicShape() const { return is_dynamic_shape_; }
/**
* \brief Get all the loops of specific Block stored in ModuleExpr.
* @param block The block we find loop in.
......@@ -244,7 +220,7 @@ class IRSchedule {
*/
void SyncThreads(const Expr& ir_node, bool after_node = true);
/*!
/**
* \brief Set a tensor's buffer type(memory_type)
* \param block The ScheduleBlockRealize corresponding to an unique tensor.
* \param memory_type The memory type we want to set. Should be "local",
......@@ -254,6 +230,13 @@ class IRSchedule {
const std::string& memory_type,
bool fixed = false); // NOLINT
/**
* \brief Create a new unit loop on top of the block.
* @param block The block to be added the new loop.
* @return The new unit loop.
*/
Expr AddUnitLoop(const Expr& block);
/**
* \brief Reorder the loops in the order of vector.
* @param loops The loops to be reordered.
......@@ -381,6 +364,46 @@ class IRSchedule {
*/
Expr Rfactor(const Expr& rf_loop, int rf_axis);
/**
* \brief Factorize the reduction block by the given loop. The block will be
* split into two blocks: reduction-factorized block and write-back block.
* @param rf_loop the reduce loop to be factorized.
* @param rf_axis The position where the new dimension is placed in the new rf
* tensor.
* @return The new created rf tensor.
*
* For example, input the block:
* \code
* for (i, 0, 10) // serial loop
* B_init[i] = 0
* for (j, 0, 20) // reduce loop
* for (k, 0, 30) // reduce loop
* B[i] = B[i] + A[i, j, k]
* \endcode
*
* If the rf loop is j and rf_axis is 0, the transformation is
* divided into 2 steps:
* 1. get the rf block where the reduce loop j is transformed to the
* serial loop with no accumalation and a new rf tensor is created.
* The axis j will be placed in the rf_axis of the new rf_tensor.
* The rf_block is as follows:
* \code
* for (i, 0, 10) // serial loop
* for (j, 0, 20) // rf loop j is transformed to the serial loop
* rf_B_init[j, i] = 0
* for (k, 0, 30) // reduce loop.
* rf_B[j, i] = rf_B[j, i] + A[i, j, k]
* \endcode
* 2. do reduction of the rf loop j to get the final result block:
* \code
* for (i, 0, 10) // serial loop
* B_init[i] = 0
* for (j, 0, 20) // rf reduction loop
* B[i] = B[i] + rf_B[j, i]
* \endcode
*/
Expr FactorizeReduction(const Expr& rf_loop, int rf_axis);
/*!
* \brief Annotate a block with a key-value pair to set as its attribute
* \param block The block to be annotated
......@@ -451,9 +474,10 @@ class IRSchedule {
utils::LinearRandomEngine::StateType ForkSeed() const;
private:
std::unique_ptr<ScheduleImpl> impl_;
std::unique_ptr<ScheduleBase> impl_;
mutable ScheduleDesc trace_; // trace the scheduling process
mutable utils::LinearRandomEngine::StateType rand_seed_;
bool is_dynamic_shape_;
};
/*!
......
......@@ -14,7 +14,7 @@
#include "paddle/cinn/ir/schedule/ir_schedule_error.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/ir_printer.h"
namespace cinn {
namespace ir {
......@@ -23,14 +23,14 @@ std::string IRScheduleErrorHandler::GeneralErrorMessage() const {
std::ostringstream os;
os << "[IRScheduleError] An error occurred in the scheduel primitive < "
<< this->primitive_ << " >. " << std::endl;
os << this->err_msg_;
os << indent_str_ << "[Error info] " << this->err_msg_;
return os.str();
}
std::string IRScheduleErrorHandler::DetailedErrorMessage() const {
std::ostringstream os;
os << GeneralErrorMessage();
os << "[Expr info] The Expr of current schedule is: "
os << indent_str_ << "[Expr info] The Expr of current schedule is:\n"
<< this->module_expr_.GetExprs() << std::endl;
return os.str();
}
......
......@@ -26,11 +26,11 @@
#include "paddle/cinn/common/cas.h"
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/ir_visitor.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/optim/ir_simplify.h"
#include "paddle/cinn/optim/replace_var_with_expr.h"
......@@ -40,7 +40,7 @@ namespace ir {
Tensor GetTensor(const Expr& block) {
CHECK(block.As<ir::ScheduleBlockRealize>());
auto find_tensor = ir::CollectIRNodesWithoutTensor(
auto find_tensor = ir::ir_utils::CollectIRNodesWithoutTensor(
block, [&](const Expr* x) { return x->As<ir::Store>(); }, true);
CHECK_EQ(find_tensor.size(), 1U)
<< "One block should only have one Store node!(except for root block)";
......@@ -52,13 +52,13 @@ Tensor GetTensor(const Expr& block) {
Tensor GetReadTensor(const Expr& block, int index) {
CHECK(block.As<ir::ScheduleBlockRealize>());
auto find_tensor = ir::CollectIRNodesWithoutTensor(
auto find_tensor = ir::ir_utils::CollectIRNodesWithoutTensor(
block, [&](const Expr* x) { return x->As<ir::Store>(); }, true);
CHECK_EQ(find_tensor.size(), 1U)
<< "One block should only have one Store node!(except for root block)";
std::vector<Tensor> res;
auto find_read_tensor =
ir::CollectIRNodesWithoutTensor(block, [&](const Expr* x) {
ir::ir_utils::CollectIRNodesWithoutTensor(block, [&](const Expr* x) {
if (x->As<ir::Load>())
res.push_back(x->As<ir::Load>()->tensor.as_tensor_ref());
return x->As<ir::Load>();
......@@ -86,41 +86,43 @@ void SetCudaAxisInfo(Expr* lowered_func) {
auto func_body = lowered_func->as_lowered_func_ref()->body;
CudaAxisInfo info;
auto block_nodes = ir::CollectIRNodes(func_body, [&](const Expr* x) {
if (x->As<ir::For>() && x->As<ir::For>()->bind_info().valid()) {
auto bind_info = x->As<ir::For>()->bind_info();
info.set_valid(true);
if (bind_info.for_type == ForType::GPUThread) {
CHECK(common::is_zero(x->As<ir::For>()->min));
CHECK(x->As<ir::For>()->extent.is_constant());
int range = x->As<ir::For>()->extent.get_constant();
range = range > info.block_dim(bind_info.offset)
? range
: info.block_dim(bind_info.offset);
VLOG(3) << "Set block dim[" << bind_info.offset << "] with range "
<< range;
info.set_block_dim(bind_info.offset, range);
} else if (bind_info.for_type == ForType::GPUBlock) {
CHECK(common::is_zero(x->As<ir::For>()->min));
CHECK(x->As<ir::For>()->extent.is_constant());
int range = x->As<ir::For>()->extent.get_constant();
range = range > info.grid_dim(bind_info.offset)
? range
: info.grid_dim(bind_info.offset);
info.set_grid_dim(bind_info.offset, range);
VLOG(3) << "Set grid dim[" << bind_info.offset << "] with range "
<< range;
} else {
LOG(FATAL) << "The for loop's bind info should be gpu block or thread!";
}
}
return (x->As<ir::For>() && x->As<ir::For>()->bind_info().valid());
});
auto block_nodes =
ir::ir_utils::CollectIRNodes(func_body, [&](const Expr* x) {
if (x->As<ir::For>() && x->As<ir::For>()->bind_info().valid()) {
auto bind_info = x->As<ir::For>()->bind_info();
info.set_valid(true);
if (bind_info.for_type == ForType::GPUThread) {
CHECK(common::is_zero(x->As<ir::For>()->min));
CHECK(x->As<ir::For>()->extent.is_constant());
int range = x->As<ir::For>()->extent.get_constant();
range = range > info.block_dim(bind_info.offset)
? range
: info.block_dim(bind_info.offset);
VLOG(3) << "Set block dim[" << bind_info.offset << "] with range "
<< range;
info.set_block_dim(bind_info.offset, range);
} else if (bind_info.for_type == ForType::GPUBlock) {
CHECK(common::is_zero(x->As<ir::For>()->min));
CHECK(x->As<ir::For>()->extent.is_constant());
int range = x->As<ir::For>()->extent.get_constant();
range = range > info.grid_dim(bind_info.offset)
? range
: info.grid_dim(bind_info.offset);
info.set_grid_dim(bind_info.offset, range);
VLOG(3) << "Set grid dim[" << bind_info.offset << "] with range "
<< range;
} else {
LOG(FATAL)
<< "The for loop's bind info should be gpu block or thread!";
}
}
return (x->As<ir::For>() && x->As<ir::For>()->bind_info().valid());
});
lowered_func->as_lowered_func_ref()->cuda_axis_info = info;
}
bool Contains(const Expr& container, const Expr& expr) {
auto find_expr = ir::CollectIRNodesWithoutTensor(
auto find_expr = ir::ir_utils::CollectIRNodesWithoutTensor(
container,
[&](const Expr* x) {
return (x->node_type() == expr.node_type() && *x == expr);
......@@ -219,6 +221,14 @@ void ReplaceExpr(Expr* source,
return;
}
void ReplaceExpr(Expr* source,
const std::map<Var, Expr, CompVar>& replacing_map) {
if (replacing_map.empty()) return;
MappingVarToExprMutator mapper(replacing_map);
mapper(source);
return;
}
std::vector<int> ValidateFactors(const std::vector<int>& factors,
int total_extent,
const ModuleExpr& module_expr) {
......@@ -283,13 +293,13 @@ void CHECKRfactorValidation(const Expr& rf_loop, int rf_axis) {
auto* rf_for = rf_loop.As<ir::For>();
CHECK(rf_for) << "Expr param of Rfactor must be For node! Please check.";
// check the rf_loop only has one schedule block
auto block_nodes = ir::CollectIRNodesWithoutTensor(
auto block_nodes = ir::ir_utils::CollectIRNodesWithoutTensor(
rf_loop,
[&](const Expr* x) { return x->As<ScheduleBlockRealize>(); },
true);
CHECK_EQ(block_nodes.size(), 1U)
<< "Rfactor Loop should only have one schedule block";
auto find_store = ir::CollectIRNodesWithoutTensor(
auto find_store = ir::ir_utils::CollectIRNodesWithoutTensor(
rf_loop, [&](const Expr* x) { return x->As<Store>(); }, true);
CHECK_EQ(find_store.size(), 1U);
auto indice = find_store.begin()->As<Store>()->indices;
......@@ -322,9 +332,9 @@ void CHECKRfactorValidation(const Expr& rf_loop, int rf_axis) {
}
std::vector<Expr> GetLoopsOfExpr(const Expr& expr, const Expr& root) {
auto loop_nodes = ir::CollectIRNodesWithoutTensor(root, [&](const Expr* x) {
return x->As<ir::For>() && Contains(*x, expr);
});
auto loop_nodes = ir::ir_utils::CollectIRNodesWithoutTensor(
root,
[&](const Expr* x) { return x->As<ir::For>() && Contains(*x, expr); });
std::vector<Expr> result(loop_nodes.begin(), loop_nodes.end());
if (result.empty())
LOG(FATAL) << "Didn't find expr's : \n"
......@@ -346,8 +356,8 @@ IterRange GetAccessedRange(const Expr& index,
var_maxs.emplace_back(range.min + range.extent - 1);
}
Expr indice_min = optim::IRCopy(index);
Expr indice_max = optim::IRCopy(index);
Expr indice_min = ir::ir_utils::IRCopy(index);
Expr indice_max = ir::ir_utils::IRCopy(index);
// replace the var by the corresponding iter_value
ReplaceExpr(&indice_min, iter_vars, var_mins);
ReplaceExpr(&indice_max, iter_vars, var_maxs);
......@@ -357,8 +367,16 @@ IterRange GetAccessedRange(const Expr& index,
Expr indice_extent;
Expr mod_extent(0);
if (indice_min.As<Mod>() && indice_min.As<Mod>()->b().is_constant())
if (indice_min.As<Mod>() && indice_min.As<Mod>()->b().is_constant()) {
Expr mod_right_min = indice_min.As<Mod>()->a();
Expr mod_right_max = indice_max.As<Mod>()->a();
Expr mod_right_extent =
common::AutoSimplify(mod_right_max - mod_right_min + 1);
mod_extent = indice_min.As<Mod>()->b();
if (mod_right_extent.get_constant() < mod_extent.get_constant()) {
mod_extent = mod_right_extent;
}
}
if (indice_min == indice_max) {
if (common::is_zero(mod_extent)) {
......@@ -406,7 +424,7 @@ std::vector<IterRange> CalculateTensorRegions(
std::vector<IterRange> result;
for (int i = 0; i < tensor_indices.size(); ++i) {
Expr binded_index = optim::IRCopy(tensor_indices[i]);
Expr binded_index = ir::ir_utils::IRCopy(tensor_indices[i]);
ReplaceExpr(&binded_index, iter_vars, iter_values);
auto range = GetAccessedRange(binded_index, loop_vars, loop_ranges);
......@@ -439,8 +457,8 @@ Expr GetNthAccessExpr(const Expr& block, int index, bool is_write) {
->body;
if (is_write) {
std::vector<Expr> find_store_vec;
auto find_store =
ir::CollectIRNodesWithoutTensor(compute_body, [&](const Expr* x) {
auto find_store = ir::ir_utils::CollectIRNodesWithoutTensor(
compute_body, [&](const Expr* x) {
if (x->As<ir::Store>()) find_store_vec.push_back(*x);
return x->As<ir::Store>();
});
......@@ -450,8 +468,8 @@ Expr GetNthAccessExpr(const Expr& block, int index, bool is_write) {
return store_index;
} else {
std::vector<Expr> find_load_vec;
auto find_load =
ir::CollectIRNodesWithoutTensor(compute_body, [&](const Expr* x) {
auto find_load = ir::ir_utils::CollectIRNodesWithoutTensor(
compute_body, [&](const Expr* x) {
if (x->As<ir::Load>()) find_load_vec.push_back(*x);
return x->As<ir::Load>();
});
......@@ -526,7 +544,7 @@ void FindInsertionPoint(const Expr& root, CacheBlockInfo* info, bool is_write) {
Expr find_tensor =
is_write ? Expr(info->write_tensor) : Expr(info->read_tensor);
auto find_produce_read =
ir::CollectIRNodesWithoutTensor(root, [&](const Expr* x) {
ir::ir_utils::CollectIRNodesWithoutTensor(root, [&](const Expr* x) {
return x->As<ir::Store>() && x->As<ir::Store>()->tensor == find_tensor;
});
......@@ -654,7 +672,7 @@ Expr ConstructOtherStmtChain(const std::vector<Expr>& stmts,
const std::vector<int> reordered_indices) {
Expr new_loop;
for (int i = reordered_indices.size() - 1; i >= 0; --i) {
Expr temp = optim::IRCopy(loops[reordered_indices[i]]);
Expr temp = ir::ir_utils::IRCopy(loops[reordered_indices[i]]);
CHECK(temp.defined());
CHECK(temp.As<ir::For>());
if (new_loop.defined()) {
......@@ -675,9 +693,9 @@ Expr ConstructNewLoopChain(const std::vector<Expr>& chain,
// In each IfThenElse node, find the vars its condition depends on.
for (auto& if_expr : if_nodes) {
CHECK(if_expr.As<IfThenElse>());
auto var_set =
ir::CollectIRNodes(if_expr.As<IfThenElse>()->condition,
[&](const Expr* x) { return x->as_var(); });
auto var_set = ir::ir_utils::CollectIRNodes(
if_expr.As<IfThenElse>()->condition,
[&](const Expr* x) { return x->as_var(); });
std::set<std::string> var_name_set;
for (auto& i : var_set) var_name_set.insert(i.as_var()->name);
condition_vars.push_back(var_name_set);
......@@ -693,10 +711,10 @@ Expr ConstructNewLoopChain(const std::vector<Expr>& chain,
Expr temp;
if (loop_set.count(loop_in_chain)) {
CHECK_GE(index, 0);
temp = optim::IRCopy(ordered_loops[index]);
temp = ir::ir_utils::IRCopy(ordered_loops[index]);
--index;
} else {
temp = optim::IRCopy(loop_in_chain);
temp = ir::ir_utils::IRCopy(loop_in_chain);
}
CHECK(temp.defined());
CHECK(temp.As<ir::For>());
......@@ -863,9 +881,9 @@ std::vector<Expr> GetProducers(const Expr& block, const Expr& root) {
std::string block_name = block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->name;
ir::CollectIRNodesWithoutTensor(
ir::ir_utils::CollectIRNodesWithoutTensor(
compute_body, [&producer_tensor_names, &block_name](const Expr* x) {
auto* load = x->As<ir::Load>();
const ir::Load* load = x->As<ir::Load>();
if (load) {
producer_tensor_names.insert(load->tensor.as_tensor()->name);
if (load->tensor.as_tensor()->name == block_name) {
......@@ -874,20 +892,36 @@ std::vector<Expr> GetProducers(const Expr& block, const Expr& root) {
}
return true;
}
const ir::Store* store = x->As<ir::Store>();
if (store) {
std::set<ir::Expr> call_nodes =
ir::ir_utils::CollectIRNodesWithoutTensor(
store->value,
[](const ir::Expr* x) { return x->As<ir::Call>(); });
for (ir::Expr call : call_nodes) {
const std::vector<ir::Expr>& read_args =
call.As<ir::Call>()->read_args;
for (const ir::Expr& arg : read_args) {
if (arg.as_tensor()) {
producer_tensor_names.insert(arg.as_tensor_ref()->name);
}
}
}
}
return false;
});
// traverse each of other blocks and filter those ones which contain at least
// one producer tensor;
auto find_blocks =
ir::CollectIRNodesWithoutTensor(root, [&block, &root](const Expr* x) {
auto find_blocks = ir::ir_utils::CollectIRNodesWithoutTensor(
root, [&block, &root](const Expr* x) {
return x->As<ir::ScheduleBlockRealize>() && *x != block && *x != root;
});
for (auto&& cur : find_blocks) {
auto* cur_block = cur.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>();
CHECK(cur_block) << "block result should be a ScheduleBlockRealize";
auto find_stores = ir::CollectIRNodesWithoutTensor(
auto find_stores = ir::ir_utils::CollectIRNodesWithoutTensor(
cur_block->body, [&producer_tensor_names](const Expr* x) {
return x->As<ir::Store>() &&
producer_tensor_names.count(
......@@ -905,32 +939,44 @@ std::vector<Expr> GetConsumers(const Expr& block, const Expr& root) {
std::string block_tensor = GetTensor(block)->name;
if (IsReduceInitTensorName(block_tensor)) {
std::string consumer_name = GetOriginalReduceTensorName(block_tensor);
auto consumer = ir::CollectIRNodesWithoutTensor(root, [&](const Expr* x) {
return x->As<ir::ScheduleBlockRealize>() &&
x->As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->name == consumer_name;
});
auto consumer =
ir::ir_utils::CollectIRNodesWithoutTensor(root, [&](const Expr* x) {
return x->As<ir::ScheduleBlockRealize>() &&
x->As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->name == consumer_name;
});
CHECK_EQ(consumer.size(), 1);
return {*consumer.begin()};
}
auto find_block = ir::CollectIRNodesWithoutTensor(root, [&](const Expr* x) {
return x->As<ir::ScheduleBlockRealize>() && *x != block && *x != root;
});
auto find_block =
ir::ir_utils::CollectIRNodesWithoutTensor(root, [&](const Expr* x) {
return x->As<ir::ScheduleBlockRealize>() && *x != block && *x != root;
});
for (auto& i : find_block) {
CHECK(i.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>());
auto block_body = i.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->body;
auto find_load =
ir::CollectIRNodesWithoutTensor(block_body, [&](const Expr* x) {
auto find_load_or_call = ir::ir_utils::CollectIRNodesWithoutTensor(
block_body, [&](const Expr* x) {
if (x->As<ir::Call>()) {
const std::vector<ir::Expr>& read_args =
x->As<ir::Call>()->read_args;
for (const ir::Expr& arg : read_args) {
if (arg.as_tensor() &&
arg.as_tensor_ref()->name == block_tensor) {
return true;
}
}
}
return x->As<ir::Load>() &&
x->As<ir::Load>()->tensor.as_tensor_ref()->name ==
block_tensor;
});
if (!find_load.empty()) consumers.emplace_back(i);
if (!find_load_or_call.empty()) consumers.emplace_back(i);
}
return consumers;
}
......@@ -938,7 +984,7 @@ std::vector<Expr> GetConsumers(const Expr& block, const Expr& root) {
void CheckComputeAtValidation(const Expr& block,
const Expr& loop,
const Expr& root) {
auto find_block = ir::CollectIRNodesWithoutTensor(
auto find_block = ir::ir_utils::CollectIRNodesWithoutTensor(
root,
[&](const Expr* x) {
return x->As<ir::ScheduleBlockRealize>() && *x == block;
......@@ -946,13 +992,13 @@ void CheckComputeAtValidation(const Expr& block,
true);
CHECK(!find_block.empty()) << "Didn't find block in root!";
auto find_loop = ir::CollectIRNodesWithoutTensor(
auto find_loop = ir::ir_utils::CollectIRNodesWithoutTensor(
root,
[&](const Expr* x) { return x->As<ir::For>() && *x == loop; },
true);
CHECK(!find_loop.empty()) << "Didn't find loop in root!";
auto find_block_in_loop = ir::CollectIRNodesWithoutTensor(
auto find_block_in_loop = ir::ir_utils::CollectIRNodesWithoutTensor(
loop,
[&](const Expr* x) {
return x->As<ir::ScheduleBlockRealize>() && *x == block;
......@@ -1005,10 +1051,10 @@ std::vector<IterRange> CalculateRequiredRegions(
std::set<Expr> provided_nodes;
if (is_store_provided) {
provided_nodes = ir::CollectIRNodesWithoutTensor(
provided_nodes = ir::ir_utils::CollectIRNodesWithoutTensor(
block, [&](const Expr* x) { return x->As<ir::Store>(); });
} else {
provided_nodes = ir::CollectIRNodesWithoutTensor(
provided_nodes = ir::ir_utils::CollectIRNodesWithoutTensor(
block, [&](const Expr* x) { return x->As<ir::Load>(); });
}
......@@ -1025,9 +1071,9 @@ std::vector<IterRange> CalculateRequiredRegions(
for (const Expr& req_block : required_blocks) {
CHECK(req_block.As<ir::ScheduleBlockRealize>());
Expr block_body =
optim::IRCopy(req_block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->body);
ir::ir_utils::IRCopy(req_block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->body);
auto iter_vars = req_block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->iter_vars;
......@@ -1036,7 +1082,7 @@ std::vector<IterRange> CalculateRequiredRegions(
// Notice that we look for For nodes in loop's body instead of loop
// itself.
auto find_loops = ir::CollectIRNodesWithoutTensor(
auto find_loops = ir::ir_utils::CollectIRNodesWithoutTensor(
loop.As<ir::For>()->body, [&](const Expr* x) {
return x->As<ir::For>() && Contains(*x, req_block);
});
......@@ -1052,15 +1098,15 @@ std::vector<IterRange> CalculateRequiredRegions(
std::set<Expr> required_nodes;
if (is_store_provided) {
required_nodes =
ir::CollectIRNodesWithoutTensor(block_body, [&](const Expr* x) {
required_nodes = ir::ir_utils::CollectIRNodesWithoutTensor(
block_body, [&](const Expr* x) {
return x->As<ir::Load>() &&
x->As<ir::Load>()->tensor.as_tensor_ref()->name ==
provided_tensor_name;
});
} else {
required_nodes =
ir::CollectIRNodesWithoutTensor(block_body, [&](const Expr* x) {
required_nodes = ir::ir_utils::CollectIRNodesWithoutTensor(
block_body, [&](const Expr* x) {
return x->As<ir::Store>() &&
x->As<ir::Store>()->tensor.as_tensor_ref()->name ==
provided_tensor_name;
......@@ -1105,7 +1151,7 @@ std::vector<IterRange> CalculateRequiredRegions(
block.As<ir::ScheduleBlockRealize>()->iter_values[i].is_constant());
if (block.As<ir::ScheduleBlockRealize>()->iter_values[i].as_var()) {
auto find_for_loops =
ir::CollectIRNodesWithoutTensor(root, [&](const Expr* x) {
ir::ir_utils::CollectIRNodesWithoutTensor(root, [&](const Expr* x) {
return x->As<ir::For>() &&
x->As<ir::For>()->loop_var->name ==
block.As<ir::ScheduleBlockRealize>()
......@@ -1134,13 +1180,13 @@ Expr CheckComputeInlineValidationAndGetStore(const Expr& schedule_block,
->schedule_block.As<ir::ScheduleBlock>()
->body;
// 1. Check the schedule block to be inlined is not a reduce tensor.
auto find_store = ir::CollectIRNodesWithoutTensor(
auto find_store = ir::ir_utils::CollectIRNodesWithoutTensor(
compute_body, [&](const Expr* x) { return x->As<ir::Store>(); }, true);
CHECK_EQ(find_store.size(), 1U);
Expr tensor = (*find_store.begin()).As<ir::Store>()->tensor;
CHECK(!tensor.as_tensor_ref()->is_reduce_tensor());
// 2. Check this schedule block is the only writer of the tensor.
find_store = ir::CollectIRNodesWithoutTensor(
find_store = ir::ir_utils::CollectIRNodesWithoutTensor(
root,
[&](const Expr* x) {
return x->As<ir::Store>() &&
......@@ -1151,8 +1197,8 @@ Expr CheckComputeInlineValidationAndGetStore(const Expr& schedule_block,
CHECK_EQ(find_store.size(), 1U);
// 3. Check there is no overlap between the buffers the schedule block reads
// and writes.
auto find_load =
ir::CollectIRNodesWithoutTensor(compute_body, [&](const Expr* x) {
auto find_load = ir::ir_utils::CollectIRNodesWithoutTensor(
compute_body, [&](const Expr* x) {
return x->As<ir::Load>() && x->As<ir::Load>()->tensor == tensor;
});
CHECK(find_load.empty());
......@@ -1166,14 +1212,14 @@ std::tuple<Expr, Expr, Expr> CheckReverseComputeInlineValidationAndGetExprs(
->schedule_block.As<ir::ScheduleBlock>()
->body;
// 1. Check the schedule block to be reverse inlined is not a reduce tensor.
auto find_inlined_load = ir::CollectIRNodesWithoutTensor(
auto find_inlined_load = ir::ir_utils::CollectIRNodesWithoutTensor(
compute_body, [&](const Expr* x) { return x->As<ir::Load>(); }, true);
CHECK_EQ(find_inlined_load.size(), 1U);
Expr tensor = (*find_inlined_load.begin()).As<ir::Load>()->tensor;
CHECK(!tensor.as_tensor_ref()->is_reduce_tensor());
auto inlined_load = *find_inlined_load.begin();
// 2. Check this schedule block is the only reader of the tensor.
auto find_load = ir::CollectIRNodesWithoutTensor(
auto find_load = ir::ir_utils::CollectIRNodesWithoutTensor(
root,
[&](const Expr* x) {
return x->As<ir::Load>() &&
......@@ -1184,20 +1230,20 @@ std::tuple<Expr, Expr, Expr> CheckReverseComputeInlineValidationAndGetExprs(
CHECK_EQ(find_load.size(), 1U);
// 3. Check there is no overlap between the buffers the schedule block reads
// and writes.
auto find_store =
ir::CollectIRNodesWithoutTensor(compute_body, [&](const Expr* x) {
auto find_store = ir::ir_utils::CollectIRNodesWithoutTensor(
compute_body, [&](const Expr* x) {
return x->As<ir::Store>() && x->As<ir::Store>()->tensor == tensor;
});
CHECK(find_store.empty());
// 4. Get store that will be inlined.
auto find_inlined_store =
ir::CollectIRNodesWithoutTensor(root, [&](const Expr* x) {
ir::ir_utils::CollectIRNodesWithoutTensor(root, [&](const Expr* x) {
return x->As<ir::Store>() && x->As<ir::Store>()->tensor == tensor;
});
CHECK_EQ(find_inlined_store.size(), 1U);
auto inlined_store = *find_inlined_store.begin();
// 5. Get target store.
auto find_target_store = ir::CollectIRNodesWithoutTensor(
auto find_target_store = ir::ir_utils::CollectIRNodesWithoutTensor(
compute_body, [&](const Expr* x) { return x->As<ir::Store>(); }, true);
CHECK_EQ(find_target_store.size(), 1U);
auto target_store = *find_target_store.begin();
......@@ -1206,7 +1252,7 @@ std::tuple<Expr, Expr, Expr> CheckReverseComputeInlineValidationAndGetExprs(
bool ContainVar(const std::vector<Expr>& exprs, const std::string& var_name) {
for (auto& expr : exprs) {
auto find_expr = ir::CollectIRNodesWithoutTensor(
auto find_expr = ir::ir_utils::CollectIRNodesWithoutTensor(
expr,
[&](const Expr* x) {
return x->As<_Var_>() && x->As<_Var_>()->name == var_name;
......
......@@ -22,9 +22,9 @@
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/schedule/ir_schedule_error.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/utils/random_engine.h"
#include "paddle/cinn/utils/string.h"
......@@ -193,7 +193,7 @@ Tensor GetReadTensor(const Expr& block, int index);
int GetLoopExtent(const Expr& loop);
/**
* \brief Given a vector of Exors, return whether they contain a var with
* \brief Given a vector of Exprs, return whether they contain a var with
* specific name.
* @param exprs The given vector of Exprs
* @param var_name The name of specific var
......@@ -241,6 +241,15 @@ void ReplaceExpr(Expr* source,
const std::vector<Var>& replaced,
const std::vector<Expr>& candidates);
/**
* Replace Vars in replaced to Exprs in candidates in source.
* @param source The Expr we will implement the change.
* @param replacing_map The one-to-one corresponded Vars -> Exprs to be
* replaced.
*/
void ReplaceExpr(Expr* source,
const std::map<Var, Expr, CompVar>& replacing_map);
/**
* Validate the factors param of Split. We will check if factors are validate
* and change -1 to positive integer.
......@@ -427,9 +436,11 @@ IterRange RangeUnion(const IterRange& range1, const IterRange& range2);
* \param loop The loop where we will insert the block under it
* @param root The root of the whole AST.
* \param required_blocks vector of ScheduleBlockRealize nodes that require the
* block \param is_store_provided Whether Store nodes of the block provide the
* block
* \param is_store_provided Whether Store nodes of the block provide the
* tensor, true means it is in compute_at case, otherwise false means in
* reverse_compuate_at case \return Each index's range of block's tensor.
* reverse_compuate_at case
* \return Each index's range and can_keep_loop flag of block's tensor.
* Indicating the buffer region being required.
*/
std::vector<IterRange> CalculateRequiredRegions(
......
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/ir/schedule/schedule_base.h"
namespace cinn {
namespace ir {
/**
* Replace a node to another node.
* @param src_sref The node to be changed.
* @param tgt_stmt The node we want.
*/
void ScheduleBase::Replace(const Expr& src_sref, const Expr& tgt_stmt) {
CHECK(src_sref.As<ir::For>() || src_sref.As<ir::Block>() ||
src_sref.As<ir::ScheduleBlockRealize>());
CHECK(tgt_stmt.As<ir::For>() || tgt_stmt.As<ir::Block>() ||
tgt_stmt.As<ir::ScheduleBlockRealize>());
if (src_sref == tgt_stmt) {
return;
}
struct ForLoopMutator : public ir::IRMutator<> {
ForLoopMutator(const Expr& source, const Expr& target)
: source_(source), target_(target) {}
void operator()(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
void Visit(const ir::For* op, Expr* expr) override {
if (*expr == source_) {
*expr = target_;
return;
}
ir::IRMutator<>::Visit(op, expr);
}
void Visit(const ir::ScheduleBlockRealize* op, Expr* expr) override {
if (*expr == source_) {
*expr = target_;
return;
}
ir::IRMutator<>::Visit(op, expr);
}
void Visit(const ir::Block* op, Expr* expr) override {
if (*expr == source_) {
*expr = target_;
return;
}
ir::IRMutator<>::Visit(op, expr);
}
const Expr& source_;
const Expr& target_;
};
auto exprs = module_expr_.GetExprs();
ForLoopMutator mutator(src_sref, tgt_stmt);
for (auto& i : exprs) {
mutator(&i);
}
}
} // namespace ir
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/utils/error.h"
#include "paddle/cinn/utils/random_engine.h"
PD_DECLARE_int32(cinn_error_message_level);
namespace cinn {
namespace ir {
/**
* A struct representing a module that contains Expr. This struct is only used
* in Schedule process.
*/
class ModuleExpr {
public:
ModuleExpr() = default;
ModuleExpr(const ModuleExpr& mod_expr) = default;
ModuleExpr(ModuleExpr&& mod_expr) = default;
ModuleExpr& operator=(const ModuleExpr& mod_expr) = default;
explicit ModuleExpr(const std::vector<Expr>& exprs) : exprs_(exprs) {}
explicit ModuleExpr(std::vector<Expr>&& exprs) : exprs_(std::move(exprs)) {}
//! Get all the Expr in this ModuleExpr.
std::vector<Expr> GetExprs() { return exprs_; }
std::vector<Expr> GetExprs() const { return exprs_; }
void SetExprs(const std::vector<Expr>& exprs) { exprs_ = exprs; }
private:
//! Exprs stored in ModuleExpr. Each one is an AST, representing a computation
//! kernel.
std::vector<Expr> exprs_;
};
/**
* Define the interface for scheduling primitives,
* with subclasses DyScheduleImpl and StScheduleImpl.
*/
class ScheduleBase {
public:
ScheduleBase() = delete;
explicit ScheduleBase(const ModuleExpr& module_expr,
bool debug_flag = false,
utils::ErrorMessageLevel err_msg_level =
utils::ErrorMessageLevel::kGeneral)
: module_expr_(module_expr), debug_flag_(debug_flag) {
err_msg_level_ = static_cast<utils::ErrorMessageLevel>(
FLAGS_cinn_error_message_level || static_cast<int>(err_msg_level));
}
explicit ScheduleBase(ModuleExpr&& module_expr)
: module_expr_(std::move(module_expr)) {}
static std::unique_ptr<ScheduleBase> Make(
const ModuleExpr& module_expr,
bool debug_flag = false,
utils::ErrorMessageLevel err_msg_level =
utils::ErrorMessageLevel::kGeneral,
bool is_dynamic = false);
static std::unique_ptr<ScheduleBase> Make(ModuleExpr&& module_expr,
bool is_dynamic = false);
void SetDebugFlag(bool debug_flag) { debug_flag_ = debug_flag; }
const ModuleExpr& GetModule() const { return module_expr_; }
void SetExprs(const std::vector<Expr>& exprs) {
module_expr_.SetExprs(exprs);
}
virtual void MergeExprs() = 0;
virtual bool HasBlock(const std::string& block_name) const = 0;
virtual std::vector<Expr> GetLoops(const Expr& block) const = 0;
virtual std::vector<Expr> GetLoops(const std::string& block_name) const = 0;
virtual std::vector<Expr> GetAllBlocks() const = 0;
virtual std::vector<Expr> GetChildBlocks(const Expr& expr) const = 0;
virtual Expr GetBlock(const std::string& block_name) const = 0;
virtual std::vector<Expr> Split(const Expr& loop,
const std::vector<int>& factors) = 0;
virtual std::vector<Expr> SamplePerfectTile(
utils::LinearRandomEngine::StateType* rand_seed,
const Expr& loop,
int n,
int max_innermost_factor) = 0;
virtual Expr Fuse(const std::vector<Expr>& loops) = 0;
virtual Expr Fuse(const std::string& block_name,
const std::vector<int>& loops_index) = 0;
virtual Expr Fuse(const Expr& block, const std::vector<int>& loops_index) = 0;
virtual void ComputeAt(const Expr& block,
const Expr& loop,
bool keep_unit_loops) = 0;
virtual void SimpleComputeAt(const Expr& block, const Expr& loop) = 0;
virtual void ReverseComputeAt(const Expr& block,
const Expr& loop,
bool keep_unit_loops) = 0;
virtual Expr GetRootBlock(const Expr& expr) const = 0;
virtual Expr CacheRead(const Expr& block,
int read_buffer_index,
const std::string& memory_type) = 0;
virtual Expr CacheWrite(const Expr& block,
int write_buffer_index,
const std::string& memory_type) = 0;
virtual void SyncThreads(const Expr& ir_node, bool after_node = true) = 0;
virtual void SetBuffer(Expr& block, // NOLINT
const std::string& memory_type,
bool fixed = false) = 0;
virtual Expr Reorder(const std::vector<Expr>& loops) = 0;
virtual Expr Reorder(const std::string& block_name,
const std::vector<int>& loops_index) = 0;
virtual Expr Reorder(const Expr& block,
const std::vector<int>& loops_index) = 0;
virtual DeviceAPI GetDeviceAPI() const = 0;
virtual void MutateForType(const Expr& loop,
ForType for_type,
int factor = -1) = 0;
virtual void Parallel(const Expr& loop) = 0;
virtual void Vectorize(const Expr& loop, int factor) = 0;
virtual void Unroll(const Expr& loop) = 0;
virtual void ComputeInline(const Expr& schedule_block) = 0;
virtual void ReverseComputeInline(const Expr& schedule_block) = 0;
virtual void Bind(const Expr& loop, const std::string& thread_axis) = 0;
virtual Expr Rfactor(const Expr& rf_loop, int rf_axis) = 0;
virtual Expr FactorizeReduction(const Expr& rf_loop, int rf_axis) = 0;
virtual Expr AddUnitLoop(const Expr& block) const = 0;
virtual void Annotate(const Expr& block,
const std::string& key,
const attr_t& value) = 0;
virtual void Unannotate(Expr& block, const std::string& key) = 0; // NOLINT
virtual void FlattenLoops(const std::vector<Expr>& loops,
const bool force_flat = false) = 0;
virtual void CopyTransformAndLoopInfo(const Expr& block,
const Expr& block_target) = 0;
virtual void CopyTransformAndLoopInfo(
const std::string& block_name, const std::string& block_target_name) = 0;
virtual Expr SampleCategorical(
utils::LinearRandomEngine::StateType* rand_seed,
const std::vector<int>& candidates,
const std::vector<float>& probs) = 0;
protected:
void Replace(const Expr& src_sref, const Expr& tgt_stmt);
ModuleExpr module_expr_;
bool debug_flag_{false};
utils::ErrorMessageLevel err_msg_level_ = utils::ErrorMessageLevel::kGeneral;
};
} // namespace ir
} // namespace cinn
......@@ -422,6 +422,12 @@ CINN_BUILD_STEP_KIND(SetBuffer)
.SetApplyFn(
APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::SetBuffer)));
CINN_BUILD_STEP_KIND(AddUnitLoop)
.Inputs({"block"})
.SetApplyFn(APPLY_FUNC_UNIFORM(
FREE_FUNCTION_CONVERTER(static_cast<Expr (IRSchedule::*)(const Expr&)>(
&IRSchedule::AddUnitLoop))));
CINN_BUILD_STEP_KIND(Reorder).Inputs({"loops"}).SetApplyFn(
APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(
static_cast<Expr (IRSchedule::*)(const std::vector<Expr>&)>(
......@@ -474,6 +480,12 @@ CINN_BUILD_STEP_KIND(Rfactor)
.SetApplyFn(
APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::Rfactor)));
CINN_BUILD_STEP_KIND(FactorizeReduction)
.Inputs({"rf_loop"})
.Attrs({"rf_axis"})
.SetApplyFn(APPLY_FUNC_UNIFORM(
FREE_FUNCTION_CONVERTER(&IRSchedule::FactorizeReduction)));
CINN_BUILD_STEP_KIND(MergeExprs)
.SetApplyFn(
APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::MergeExprs)));
......
......@@ -14,8 +14,8 @@
#include "paddle/cinn/ir/schedule_block_graph.h"
#include "paddle/cinn/common/dfs_topo_walker.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/schedule/ir_schedule_util.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
namespace cinn {
namespace ir {
......
......@@ -20,11 +20,9 @@
#include "paddle/cinn/common/graph_utils.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
using Group = cinn::hlir::framework::Graph::Group;
namespace cinn {
namespace ir {
......
......@@ -16,6 +16,7 @@
#include <cstring>
#include "paddle/cinn/ast_gen_ius/tensor_group.h"
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/common/arithmatic.h"
#include "paddle/cinn/common/axis.h"
......@@ -23,10 +24,10 @@
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/buffer.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/ir_visitor.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/operation.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/poly/isl_utils.h"
#include "paddle/cinn/poly/stage.h"
......@@ -52,6 +53,67 @@ Tensor _Tensor_::Make(const std::string &name,
return Tensor(n);
}
Tensor _Tensor_::Make(const std::string &name,
Type dtype,
const std::vector<Expr> &shape,
const std::vector<Expr> &domain,
const std::vector<Var> &reduce_axis) {
CHECK(!name.empty()) << "Cannot set empty Tensor name in Tensor::Make";
auto n = make_shared<_Tensor_>();
n->name = name;
n->shape = shape;
n->domain = domain;
n->reduce_axis = reduce_axis;
n->operation = PlaceholderOp::Make(n->name, n->shape, Float(32));
n->set_type(dtype);
n->InitAxis();
return Tensor(n);
}
Tensor _Tensor_::Make(const std::string &name,
Type dtype,
const std::vector<Dim> &sym_shape,
const std::vector<Expr> &domain,
FunctionRef fn,
const std::vector<Var> &reduce_axis) {
CHECK(!name.empty()) << "Tensor name is set empty";
auto n = make_shared<_Tensor_>();
n->name = name;
n->sym_shape = sym_shape;
n->shape.reserve(sym_shape.size());
for (int i = 0; i < sym_shape.size(); i++) {
n->shape[i] = sym_shape[i]->dim_expr;
}
n->domain = domain;
n->reduce_axis = reduce_axis;
n->set_type(dtype);
n->operation = fn;
n->InitAxis();
return Tensor(n);
}
Tensor _Tensor_::Make(const std::string &name,
Type dtype,
const std::vector<Dim> &sym_shape,
const std::vector<Expr> &domain,
const std::vector<Var> &reduce_axis) {
CHECK(!name.empty()) << "Cannot set empty Tensor name in Tensor::Make";
auto n = make_shared<_Tensor_>();
n->name = name;
n->sym_shape = sym_shape;
n->shape.reserve(sym_shape.size());
for (int i = 0; i < sym_shape.size(); i++) {
n->shape[i] = sym_shape[i]->dim_expr;
}
n->domain = domain;
n->reduce_axis = reduce_axis;
n->operation = PlaceholderOp::Make(n->name, n->shape, Float(32));
n->set_type(dtype);
n->InitAxis();
return Tensor(n);
}
size_t Tensor::ndims() const { return operator->()->shape.size(); }
......@@ -59,7 +121,7 @@ std::set<std::string> _Tensor_::GetDependTensorNames() const {
std::set<std::string> names;
auto add_depend_tensors_from_expr = [&](Expr expr) {
auto tensors = CollectIRNodes(expr, [&](const Expr *x) {
auto tensors = ir::ir_utils::CollectIRNodes(expr, [&](const Expr *x) {
return x->as_tensor() && x->as_tensor()->name != this->name;
});
for (auto &e : tensors) {
......@@ -514,7 +576,7 @@ bool _Tensor_::IsDependOnStatement(absl::string_view statement) {
std::set<std::string> _Tensor_::DependingTensorNames() {
std::set<std::string> res;
if (body().defined()) {
auto depend_tensors = ir::CollectIRNodes(
auto depend_tensors = ir::ir_utils::CollectIRNodes(
body(), [](const Expr *x) -> bool { return x->as_tensor(); });
for (const auto &x : depend_tensors) {
if (x.get() != this) {
......@@ -537,7 +599,7 @@ std::vector<Var> _Tensor_::axis_with_reduce() const {
}
bool _Tensor_::Uses(const Tensor &other) const {
auto loads = ir::CollectIRNodes(body(), [&](const Expr *x) {
auto loads = ir::ir_utils::CollectIRNodes(body(), [&](const Expr *x) {
auto *loadn = x->As<ir::Load>();
if (!loadn) return false;
return loadn->tensor.as_tensor()->name == other->name;
......
......@@ -25,36 +25,23 @@
#include <utility>
#include <vector>
#include "paddle/cinn/ast_gen_ius/tensor_group.h"
#include "paddle/cinn/common/graph_utils.h"
#include "paddle/cinn/ir/buffer.h"
#include "paddle/cinn/ir/dim.h"
#include "paddle/cinn/ir/function_base.h"
#include "paddle/cinn/lang/buffer.h"
#include "paddle/cinn/poly/stage.h"
namespace cinn {
namespace ir {
class Tensor;
} // namespace ir
namespace lang {
template <typename T>
struct Placeholder;
void InitReduceTensor(poly::StageMap stages,
const ir::Tensor& tensor,
const Target& target = common::DefaultHostTarget());
} // namespace lang
namespace ast_gen_ius {
class TensorGroup;
} // namespace ast_gen_ius
namespace ir {
namespace detail {
constexpr bool LE(int a, int b) { return a <= b; }
constexpr bool GE(int a, int b) { return a >= b; }
} // namespace detail
class _Tensor_;
class Tensor;
class Tensor : public ir::IrNodeRef {
public:
......@@ -84,8 +71,8 @@ class Tensor : public ir::IrNodeRef {
return operator()(std::vector<Expr>({a}));
}
template <typename... Args>
inline typename std::enable_if<detail::GE(sizeof...(Args), 2), Expr>::type
operator()(Args&&... args) const {
inline typename std::enable_if<sizeof...(Args) >= 2, Expr>::type operator()(
Args&&... args) const {
return operator()({std::forward<Args>(args)...});
}
// @}
......@@ -135,6 +122,8 @@ struct WriteCacheRelation;
*/
class _Tensor_ : public ExprNode<_Tensor_> {
public:
//! Symbolic Shape of this tensor(buffer).
std::vector<Dim> sym_shape;
//! Shape of this tensor(buffer).
std::vector<Expr> shape;
//! The domain of each axis(without reduce_axis)
......@@ -163,6 +152,28 @@ class _Tensor_ : public ExprNode<_Tensor_> {
FunctionRef fn,
const std::vector<Var>& reduce_axis = {});
// Manual tensor construction, no FunctionRef information
static Tensor Make(const std::string& name,
Type dtype,
const std::vector<Expr>& shape,
const std::vector<Expr>& domain,
const std::vector<Var>& reduce_axis = {});
//! (Symbolic Shape) Generate a tensor from a function.
static Tensor Make(const std::string& name,
Type dtype,
const std::vector<Dim>& sym_shape,
const std::vector<Expr>& domain,
FunctionRef fn,
const std::vector<Var>& reduce_axis = {});
// (Symbolic Shape) Manual tensor construction, no FunctionRef information
static Tensor Make(const std::string& name,
Type dtype,
const std::vector<Dim>& sym_shape,
const std::vector<Expr>& domain,
const std::vector<Var>& reduce_axis = {});
void Verify() const override;
bool IsReduceInited(poly::StageMap stages) const;
......@@ -288,12 +299,6 @@ class _Tensor_ : public ExprNode<_Tensor_> {
poly::StageMap stages,
const Target& target = common::DefaultHostTarget()) const;
private:
//! Initialize the axis field after the shape field is assigned.
void InitAxis() const;
isl::set GenerateIslDomain() const;
/**
* Create the initialization tensor.
* @param stages The stages.
......@@ -304,15 +309,17 @@ class _Tensor_ : public ExprNode<_Tensor_> {
poly::StageMap stages,
const Target& target = common::DefaultHostTarget()) const;
private:
//! Initialize the axis field after the shape field is assigned.
void InitAxis() const;
isl::set GenerateIslDomain() const;
//! The names of the tensors depend the same buffer and should schedule before
//! this.
std::set<std::string> buffer_depended_tensor_names_;
friend Shared<poly::Stage> CreateStage(Tensor tensor);
friend void lang::InitReduceTensor(poly::StageMap stages,
const ir::Tensor& tensor,
const Target& target);
};
Shared<poly::Stage> CreateStage(Tensor tensor);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment