Commit 01a10755 authored by yuguo-Jack's avatar yuguo-Jack
Browse files

2.5.2-dtk24.04

parent 63eb0da5
......@@ -12,13 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/optim/cast_simplify.h"
#include <gtest/gtest.h>
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/optim/ir_simplify.h"
namespace cinn::optim {
TEST(CastSimplify, same_type) {
......@@ -26,7 +24,7 @@ TEST(CastSimplify, same_type) {
Expr a = ir::Cast::Make(Int(32), n);
LOG(INFO) << n->type();
LOG(INFO) << a;
CastSimplify(&a);
SimplifyCast(&a);
ASSERT_EQ(utils::GetStreamCnt(a), "n");
}
......@@ -34,7 +32,7 @@ TEST(CastSimplify, Imm_int) {
Expr a = ir::Cast::Make(Int(64), Expr(1));
Expr c = ir::Cast::Make(Int(32), a);
LOG(INFO) << c;
CastSimplify(&c);
SimplifyCast(&c);
LOG(INFO) << c;
ASSERT_EQ(utils::GetStreamCnt(c), "1");
ASSERT_EQ(c.type(), Int(32));
......@@ -44,7 +42,7 @@ TEST(CastSimplify, Imm_double) {
Expr a = ir::Cast::Make(Float(64), Expr(2.33));
Expr c = ir::Cast::Make(Int(32), a);
LOG(INFO) << c;
CastSimplify(&c);
SimplifyCast(&c);
LOG(INFO) << c;
ASSERT_EQ(utils::GetStreamCnt(c), "2");
ASSERT_EQ(c.type(), Int(32));
......@@ -54,7 +52,7 @@ TEST(CastSimplify, Imm_uint) {
Expr a = ir::Cast::Make(UInt(64), Expr(1));
Expr c = ir::Cast::Make(UInt(32), a);
LOG(INFO) << c;
CastSimplify(&c);
SimplifyCast(&c);
LOG(INFO) << c;
ASSERT_EQ(utils::GetStreamCnt(c), "1");
ASSERT_EQ(c.type(), UInt(32));
......
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/optim/collect_undefined_vars.h"
#include <set>
#include "paddle/cinn/ir/utils/ir_mutator.h"
namespace cinn::optim {
namespace {
// IR walker that tracks which variable/tensor names are defined and used
// while traversing an expression, and records every use of a name that has no
// definition in scope at that point.
struct Mutator : public ir::IRMutator<> {
  using ir::IRMutator<>::Visit;

  // Names used while not present in `defined_vars`; a name is appended once
  // for every such use, in traversal order.
  std::vector<std::string> undefined_vars;
  // Names that currently count as defined.
  std::set<std::string> defined_vars;
  // Names that have been used so far.
  std::set<std::string> used_vars;

  // Registers `var` as defined. Aborts if it is already defined, or if it was
  // used before this definition.
  void CollectVarDef(const std::string& var) {
    const bool already_defined = defined_vars.find(var) != defined_vars.end();
    CHECK(!already_defined)
        << "var " << var << " has been defined, please check";
    const bool already_used = used_vars.find(var) != used_vars.end();
    CHECK(!already_used)
        << "var " << var << " is wrongly used before definition";
    defined_vars.insert(var);
  }

  // Forgets `var` completely (both its definition and its uses).
  void ClearVar(const std::string& var) {
    defined_vars.erase(var);
    used_vars.erase(var);
  }

  // Records a use of `var`; if no definition is known it is also appended to
  // `undefined_vars`.
  void CollectVarUse(const std::string& var) {
    used_vars.insert(var);
    const bool is_defined = defined_vars.find(var) != defined_vars.end();
    if (!is_defined) {
      undefined_vars.push_back(var);
    }
  }

  // A Let defines its symbol; only the body is traversed afterwards.
  void Visit(const ir::Let* op, Expr* expr) final {
    Expr let_symbol = op->symbol;
    auto let_var = let_symbol.as_var_ref();
    CHECK(let_var.defined());
    CollectVarDef(let_var->name);

    auto* let_node = expr->As<ir::Let>();
    Visit(&let_node->body, &let_node->body);
  }

  // A For defines its loop variable for min, extent and body, and the
  // variable is forgotten again once the loop has been traversed.
  void Visit(const ir::For* op, Expr* expr) final {
    CollectVarDef(op->loop_var->name);

    auto* for_node = expr->As<ir::For>();
    Visit(&for_node->min, &for_node->min);
    Visit(&for_node->extent, &for_node->extent);
    Visit(&for_node->body, &for_node->body);

    ClearVar(op->loop_var->name);
  }

  // A Load uses the tensor's name; its indices are traversed.
  void Visit(const ir::Load* op, Expr* expr) final {
    CollectVarUse(op->tensor.as_tensor_ref()->name);

    auto* load_node = expr->As<ir::Load>();
    for (auto& index : load_node->indices) {
      Visit(&index, &index);
    }
  }

  // A Store uses the tensor's name; its indices and then its value are
  // traversed.
  void Visit(const ir::Store* op, Expr* expr) final {
    CollectVarUse(op->tensor.as_tensor_ref()->name);

    auto* store_node = expr->As<ir::Store>();
    for (auto& index : store_node->indices) {
      Visit(&index, &index);
    }
    Visit(&store_node->value, &store_node->value);
  }

  // A variable reference is a use; its bounds are traversed when present.
  void Visit(const ir::_Var_* op, Expr* expr) final {
    CollectVarUse(op->name);

    auto* var_node = expr->As<ir::_Var_>();
    if (var_node->lower_bound.defined()) {
      Visit(&var_node->lower_bound, &var_node->lower_bound);
    }
    if (var_node->upper_bound.defined()) {
      Visit(&var_node->upper_bound, &var_node->upper_bound);
    }
  }

  // A Reduce defines each of its reduce axes, then its init (when present)
  // and body are traversed. The axes are not cleared afterwards.
  void Visit(const ir::Reduce* op, Expr* expr) final {
    for (auto& reduce_axis : op->reduce_axis) {
      CollectVarDef(reduce_axis->name);
    }

    auto* reduce_node = expr->As<ir::Reduce>();
    if (reduce_node->init.defined()) {
      Visit(&reduce_node->init, &reduce_node->init);
    }
    Visit(&reduce_node->body, &reduce_node->body);
  }
};
} // namespace
// Traverses `e` with a fresh Mutator and returns the names it recorded as
// used without a definition.
std::vector<std::string> CollectUndefinedVars(Expr* e) {
  Mutator collector;
  collector.Visit(e, e);
  return collector.undefined_vars;
}
} // namespace cinn::optim
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/cinn/ir/ir.h"
namespace cinn::optim {
/**
* Collect undefined vars in the scope.
*
* e.g.
*
* The expression:
* for i
* for j
* a[i, j] = b[i, j]
*
* here a, b are vars without definition
*/
std::vector<std::string> CollectUndefinedVars(Expr* e);
} // namespace cinn::optim
......@@ -18,8 +18,8 @@
#include <string>
#include "paddle/cinn/common/graph_utils.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/optim/replace_var_with_expr.h"
namespace cinn {
......@@ -150,7 +150,7 @@ struct TensorInlineExpandMutator : public ir::IRMutator<> {
}
ir::IRMutator<>::Visit(&node->tensor, &node->tensor);
for (int i = 0; i < node->indices.size(); i++) {
auto temp = optim::IRCopy(node->indices[i]);
auto temp = ir::ir_utils::IRCopy(node->indices[i]);
ir::IRMutator<>::Visit(&temp, &temp);
node->indices[i] = temp;
}
......@@ -159,7 +159,7 @@ struct TensorInlineExpandMutator : public ir::IRMutator<> {
} else {
ir::IRMutator<>::Visit(&node->tensor, &node->tensor);
for (int i = 0; i < node->indices.size(); i++) {
auto temp = optim::IRCopy(node->indices[i]);
auto temp = ir::ir_utils::IRCopy(node->indices[i]);
ir::IRMutator<>::Visit(&temp, &temp);
node->indices[i] = temp;
}
......@@ -167,7 +167,7 @@ struct TensorInlineExpandMutator : public ir::IRMutator<> {
} else {
ir::IRMutator<>::Visit(&node->tensor, &node->tensor);
for (int i = 0; i < node->indices.size(); i++) {
auto temp = optim::IRCopy(node->indices[i]);
auto temp = ir::ir_utils::IRCopy(node->indices[i]);
ir::IRMutator<>::Visit(&temp, &temp);
node->indices[i] = temp;
}
......@@ -225,7 +225,7 @@ void ComputeInlineExpand(Expr *expr,
poly::StageMap stages,
std::map<std::string, ir::Tensor> *all_tensor_map) {
// the inline tensors contained in the expression.
auto inline_tensors = ir::CollectIRNodes(*expr, [&](const Expr *x) {
auto inline_tensors = ir::ir_utils::CollectIRNodes(*expr, [&](const Expr *x) {
return x->as_tensor() && stages[x->as_tensor()]->inlined();
});
......@@ -240,9 +240,10 @@ void ComputeInlineExpand(Expr *expr,
TensorInlineExpandMutator(tensor->name, all_tensor_map, stages)(expr);
}
inline_tensors = ir::CollectLoadTensors(*expr, [&](const Expr *x) {
return x->as_tensor() && stages[x->as_tensor()]->inlined();
});
inline_tensors =
ir::ir_utils::CollectLoadTensors(*expr, [&](const Expr *x) {
return x->as_tensor() && stages[x->as_tensor()]->inlined();
});
}
}
......
......@@ -17,10 +17,10 @@
#include <tuple>
#include <vector>
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
#include "paddle/cinn/optim/ir_replace.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/ir_visitor.h"
#include "paddle/cinn/ir/utils/ir_replace.h"
namespace cinn {
namespace optim {
......@@ -36,9 +36,9 @@ struct EliminateBroadcastInForloop : public ir::IRMutator<Expr*> {
auto* node = expr->As<ir::Store>();
auto broadcasts = ir::CollectIRNodes(node->value, [&](const Expr* expr) {
return expr->As<ir::Broadcast>();
});
auto broadcasts = ir::ir_utils::CollectIRNodes(
node->value,
[&](const Expr* expr) { return expr->As<ir::Broadcast>(); });
std::vector<Expr> let_exprs;
Var tmp;
......@@ -54,7 +54,7 @@ struct EliminateBroadcastInForloop : public ir::IRMutator<Expr*> {
std::tie(let_expr, tmp) = CreateTmpLet(broadcast);
let_exprs.push_back(let_expr);
optim::IrReplace(expr, broadcast, tmp);
cinn::ir::ir_utils::IrReplace(expr, broadcast, tmp);
}
// insert the let expressions to the outer forloop.
......@@ -79,7 +79,7 @@ struct EliminateBroadcastInForloop : public ir::IRMutator<Expr*> {
}
bool ContainsLoopVar(Expr expr, Var loop_var) {
return !ir::CollectIRNodes(expr, [&](const Expr* e) -> bool {
return !ir::ir_utils::CollectIRNodes(expr, [&](const Expr* e) -> bool {
return e->As<ir::_Var_>() &&
e->As<ir::_Var_>()->name == loop_var->name;
}).empty();
......
......@@ -14,7 +14,7 @@
#include "paddle/cinn/optim/extern_call_process.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/ir_mutator.h"
namespace cinn {
namespace optim {
......
......@@ -17,8 +17,8 @@
#include <unordered_set>
#include <vector>
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/utils/string.h"
namespace cinn {
......
......@@ -19,8 +19,8 @@
#include <vector>
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/runtime/intrinsic.h"
#include "paddle/cinn/utils/string.h"
......
......@@ -24,18 +24,19 @@
#include "paddle/cinn/common/arithmatic.h"
#include "paddle/cinn/common/cas.h"
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/ir_visitor.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
#include "paddle/cinn/optim/cast_simplify.h"
#include "paddle/cinn/utils/string.h"
namespace cinn {
namespace optim {
using namespace ir; // NOLINT
using common::bfloat16;
using common::ExprToGinacConverter;
using common::float16;
using utils::GetStreamCnt;
using utils::Replace;
......@@ -53,9 +54,9 @@ void PartialSimplify(
}
//! Simplify the expression but Load.
struct SimplifyButStoreLoadMutator : public ir::IRMutator<ir::Expr*> {
struct SimplifyNoPureMathMutator : public ir::IRMutator<ir::Expr*> {
common::cas_intervals_t& var_intervals;
explicit SimplifyButStoreLoadMutator(
explicit SimplifyNoPureMathMutator(
common::cas_intervals_t& var_intervals) // NOLINT
: var_intervals(var_intervals) {}
......@@ -76,19 +77,6 @@ struct SimplifyButStoreLoadMutator : public ir::IRMutator<ir::Expr*> {
__(Max)
#undef __
void Visit(const Ramp* op, Expr* expr) override {
auto* node = expr->As<Ramp>();
CHECK(common::IsPureMath(node->base));
CHECK(common::IsPureMath(node->stride));
PartialSimplify(&node->base, var_intervals);
PartialSimplify(&node->stride, var_intervals);
}
void Visit(const Cast* op, Expr* expr) override {
auto* node = expr->As<Cast>();
Visit(&node->v(), &node->v());
}
void Visit(const PolyFor* op, Expr* expr) override {
auto* node = expr->As<ir::PolyFor>();
node->condition = common::SolveInequality(op->condition, op->iterator);
......@@ -138,7 +126,7 @@ struct SimplifyLoadMutator : public ir::IRMutator<ir::Expr*> {
if (common::IsPureMath(idx)) {
PartialSimplify(&idx, var_intervals_);
} else {
SimplifyButStoreLoadMutator mutator(var_intervals_);
SimplifyNoPureMathMutator mutator(var_intervals_);
mutator(&idx);
}
}
......@@ -176,7 +164,7 @@ struct SimplifyStoreMutator : public ir::IRMutator<ir::Expr*> {
if (common::IsPureMath(idx)) {
PartialSimplify(&idx, var_intervals_);
} else {
SimplifyButStoreLoadMutator mutator(var_intervals_);
SimplifyNoPureMathMutator mutator(var_intervals_);
mutator(&idx);
}
}
......@@ -215,8 +203,8 @@ struct SimplifyRampMutator : public ir::IRMutator<Expr*> {
CHECK(common::IsPureMath(node->stride))
<< node->stride << "is not a pure math!";
Simplify(&node->base);
Simplify(&node->stride);
PartialSimplify(&node->base);
PartialSimplify(&node->stride);
}
// ramp + ramp
void Visit(const Add* op, Expr* expr) override {
......@@ -317,6 +305,33 @@ struct SimplifyBlocksMutator : public ir::IRMutator<> {
expr->As<ir::Block>()->stmts = stmts;
}
}
void Visit(const ScheduleBlock* op, Expr* expr) override {
auto* node = expr->As<ScheduleBlock>();
CHECK(node);
for (auto& var : node->iter_vars) {
if (var->lower_bound.defined()) {
Visit(&var->lower_bound, &var->lower_bound);
}
if (var->upper_bound.defined()) {
Visit(&var->upper_bound, &var->upper_bound);
}
}
for (auto& buffer_region : node->read_buffers) {
Visit(&buffer_region, &buffer_region);
}
for (auto& buffer_region : node->write_buffers) {
Visit(&buffer_region, &buffer_region);
}
if (node->body.As<Block>()) {
if (node->body.As<Block>()->stmts.size() == 1) {
node->body = node->body.As<Block>()->stmts[0];
}
}
Visit(&(node->body), &(node->body));
}
};
struct SimplifyForLoopsMutator : public ir::IRMutator<> {
......@@ -359,23 +374,108 @@ struct SimplifyForLoopsMutator : public ir::IRMutator<> {
}
};
/**
 * Convert the constant `value` of type T to CastType, normalizing values that
 * a plain static_cast would not handle predictably.
 *
 *  - If either CastType or T is an unsigned type, no normalization is done
 *    and the result is a plain static_cast.
 *  - +inf / -inf map to +/- std::numeric_limits<CastType>::infinity(). The
 *    sign is preserved; for a CastType without an infinity representation
 *    std::numeric_limits yields a value-initialized CastType, so the result
 *    is zero either way.
 *  - NaN maps to std::numeric_limits<CastType>::signaling_NaN().
 *  - Values at or beyond the representable range of CastType are clamped to
 *    max() / lowest().
 *  - Everything else is a plain static_cast.
 */
template <typename CastType, typename T>
CastType NormCastValue(T value) {
  if (type_of<CastType>().is_uint() || type_of<T>().is_uint()) {
    // Unsigned types are not normalized.
    return static_cast<CastType>(value);
  }

  if (std::isinf(value)) {
    // Keep the sign of the infinity: returning +inf unconditionally would
    // turn a -inf constant into +inf.
    const CastType pos_inf = std::numeric_limits<CastType>::infinity();
    return value < static_cast<T>(0) ? static_cast<CastType>(-pos_inf)
                                     : pos_inf;
  } else if (std::isnan(value)) {
    return std::numeric_limits<CastType>::signaling_NaN();
  } else if (value >= static_cast<T>(std::numeric_limits<CastType>::max())) {
    return std::numeric_limits<CastType>::max();
  } else if (value <= static_cast<T>(std::numeric_limits<CastType>::lowest())) {
    return std::numeric_limits<CastType>::lowest();
  }
  return static_cast<CastType>(value);
}
// Mutator that simplifies ir::Cast nodes in place:
//   * the operand of the cast is visited first;
//   * a cast whose target type equals the operand's type is replaced by the
//     operand;
//   * when the operand is a constant, the cast is folded into an immediate of
//     the target type (FloatImm operands go through NormCastValue).
// A constant operand that is not an IntImm/FloatImm/UIntImm, or a target type
// outside the list below, hits CINN_NOT_IMPLEMENTED.
struct SimplifyCastMutator : public ir::IRMutator<> {
  void operator()(Expr* expr) { ir::IRMutator<ir::Expr*>::Visit(expr, expr); }

  void Visit(const ir::Cast* op, Expr* expr) override {
    auto* node = expr->As<ir::Cast>();

    // Simplify the operand before looking at this cast.
    ir::IRMutator<ir::Expr*>::Visit(&node->v(), &node->v());

    // Casting to the type the operand already has is a no-op.
    if (op->type() == op->v().type()) {
      *expr = op->v();
      return;
    }

// Replaces *expr by an immediate holding the operand's value converted to
// type__.
#define __CAST_TO_TYPE(type__)                                          \
  if (auto* i = op->v().As<ir::IntImm>()) {                             \
    *expr = Expr(static_cast<type__>(i->value));                        \
  } else if (auto* f = op->v().As<ir::FloatImm>()) {                    \
    *expr = Expr(static_cast<type__>(NormCastValue<type__>(f->value))); \
  } else if (auto* u = op->v().As<ir::UIntImm>()) {                     \
    *expr = Expr(static_cast<type__>(u->value));                        \
  } else {                                                              \
    CINN_NOT_IMPLEMENTED                                                \
  }

    if (op->v().is_constant()) {
      if (op->type() == type_of<int8_t>()) {
        __CAST_TO_TYPE(int8_t)
      } else if (op->type() == type_of<int16_t>()) {
        __CAST_TO_TYPE(int16_t)
      } else if (op->type() == type_of<int32_t>()) {
        __CAST_TO_TYPE(int32_t)
      } else if (op->type() == type_of<int64_t>()) {
        __CAST_TO_TYPE(int64_t)
      } else if (op->type() == type_of<uint8_t>()) {
        __CAST_TO_TYPE(uint8_t)
      } else if (op->type() == type_of<uint16_t>()) {
        __CAST_TO_TYPE(uint16_t)
      } else if (op->type() == type_of<uint32_t>()) {
        __CAST_TO_TYPE(uint32_t)
      } else if (op->type() == type_of<uint64_t>()) {
        __CAST_TO_TYPE(uint64_t)
      } else if (op->type() == type_of<float>()) {
        __CAST_TO_TYPE(float)
      } else if (op->type() == type_of<double>()) {
        __CAST_TO_TYPE(double)
      } else if (op->type() == type_of<bool>()) {
        __CAST_TO_TYPE(bool)
      } else if (op->type() == type_of<bfloat16>()) {
        __CAST_TO_TYPE(bfloat16)
      } else if (op->type() == type_of<float16>()) {
        __CAST_TO_TYPE(float16)
      } else {
        CINN_NOT_IMPLEMENTED
      }
    }
#undef __CAST_TO_TYPE
  }
};
} // namespace
void Simplify(Expr* expr) {
VLOG(3) << "Begin Simplify " << *expr;
optim::CastSimplify(expr);
SimplifyCastMutator()(expr);
SimplifyRampMutator()(expr);
SimplifyLoadMutator()(expr);
SimplifyStoreMutator()(expr);
SimplifyIfThenElseMutator()(expr);
common::cas_intervals_t var_intervals;
SimplifyButStoreLoadMutator mutator(var_intervals);
SimplifyNoPureMathMutator mutator(var_intervals);
mutator(expr);
ReplaceFracWithDivMutator()(expr);
}
// Runs SimplifyCastMutator over `expr` in place.
void SimplifyCast(Expr* expr) {
  SimplifyCastMutator cast_pass{};
  cast_pass(expr);
}
// Runs SimplifyForLoopsMutator over `expr` in place.
void SimplifyForLoops(Expr* expr) {
  SimplifyForLoopsMutator loop_pass{};
  loop_pass(expr);
}
// Runs SimplifyBlocksMutator over `expr` in place.
void SimplifyBlocks(Expr* expr) {
  SimplifyBlocksMutator block_pass{};
  block_pass(expr);
}
......
......@@ -30,6 +30,8 @@ namespace optim {
*/
void Simplify(Expr *expr);
void SimplifyCast(Expr *expr);
void SimplifyForLoops(Expr *expr);
void SimplifyBlocks(Expr *expr);
......
......@@ -17,7 +17,7 @@
#include <string>
#include <vector>
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/ir_mutator.h"
namespace cinn {
namespace optim {
......
......@@ -19,8 +19,8 @@
#include "paddle/cinn/backends/llvm/llvm_intrin_rule.h"
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/ir/intrinsic_ops.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/registry.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
namespace cinn {
namespace optim {
......
......@@ -16,7 +16,7 @@
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/hlir/op/op_util.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/runtime/cpu/host_intrinsics.h"
namespace cinn {
......
......@@ -14,12 +14,11 @@
#include "paddle/cinn/optim/optimize.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/schedule/ir_schedule_util.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/optim/call_arg_list_to_pod_value.h"
#include "paddle/cinn/optim/cast_bool_to_int8.h"
#include "paddle/cinn/optim/cast_simplify.h"
#include "paddle/cinn/optim/eliminate_broadcast_in_forloop.h"
#include "paddle/cinn/optim/extern_call_process.h"
#include "paddle/cinn/optim/fold_cinn_call_arguments.h"
......@@ -28,9 +27,9 @@
#include "paddle/cinn/optim/lower_function_call_bind_vars.h"
#include "paddle/cinn/optim/lower_intrin.h"
#include "paddle/cinn/optim/map_extern_call.h"
#include "paddle/cinn/optim/remove_nested_block.h"
#include "paddle/cinn/optim/remove_schedule_block.h"
#include "paddle/cinn/optim/replace_const_param_to_integer.h"
#include "paddle/cinn/optim/replace_cross_thread_reduction.h"
#include "paddle/cinn/optim/transform_gpu_forloop.h"
#include "paddle/cinn/optim/transform_polyfor_to_for.h"
#include "paddle/cinn/optim/unroll_loops.h"
......@@ -44,13 +43,14 @@ Expr Optimize(Expr e,
bool runtime_debug_info,
bool remove_gpu_for_loops) {
CHECK(e.defined());
auto copied = IRCopy(e);
auto copied = ir::ir_utils::IRCopy(e);
FoldCINNCallArguments(&copied);
TransformPolyForToFor(&copied);
ReplaceConstParamToInteger(&copied);
// Simplify already contains CastSimplify
Simplify(&copied);
ReplaceCrossThreadReduction(&copied);
UnrollLoop(&copied);
VLOG(4) << "After Optimize UnrollLoop:" << copied;
......@@ -66,8 +66,8 @@ Expr Optimize(Expr e,
CudaSyncThreadsDropIfThenElse(&copied);
#endif
RemoveNestedBlock(&copied);
VLOG(4) << "After Optimize RemoveNestedBlock:" << copied;
SimplifyBlocks(&copied);
VLOG(4) << "After SimplifyBlocks:" << copied;
MapExternCall(&copied, target);
VLOG(10) << "After Optimize MapExternCall:" << copied;
......@@ -86,7 +86,8 @@ Expr Optimize(Expr e,
}
ir::Module Optimize(const ir::Module& module, const Target& target) {
auto copied = IRCopy(Expr(module));
auto copied = ir::ir_utils::IRCopy(Expr(module));
ReplaceCrossThreadReduction(&copied);
UnrollLoop(&copied);
VectorizeLoops(&copied, Target());
VLOG(10) << "After VectorizeLoops:" << copied.as_module_ref();
......
......@@ -17,7 +17,7 @@
#include <gtest/gtest.h>
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/utils/string.h"
namespace cinn {
......
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/optim/remove_nested_block.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
namespace cinn {
namespace optim {
// Peels away blocks that hold exactly one statement: while the current
// expression is a Block with a single statement, descend into that statement.
// Returns the first expression that is not such a block (possibly `op`
// itself).
Expr GetExprInsideBlock(Expr op) {
  Expr current = op;
  for (auto* block = current.As<ir::Block>();
       block != nullptr && block->stmts.size() == 1;
       block = current.As<ir::Block>()) {
    Expr inner = block->stmts.front();
    current = inner;
  }
  return current;
}
// Replaces every block that holds exactly one statement by the expression
// GetExprInsideBlock finds inside it. This removes nested blocks, but it also
// removes the block wrapping a forloop's body.
struct NestedBlockSimplifer : public ir::IRMutator<Expr*> {
  void operator()(ir::Expr* expr) { Visit(expr); }

 private:
  void Visit(ir::Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }

  void Visit(const ir::Block* expr, Expr* op) override {
    const bool holds_single_stmt = op->As<ir::Block>()->stmts.size() == 1;
    if (!holds_single_stmt) {
      // Keep the block and continue with the default traversal.
      IRMutator::Visit(expr, op);
      return;
    }
    // Unwrap, then keep simplifying whatever was inside.
    *op = GetExprInsideBlock(*op);
    IRMutator::Visit(op, op);
  }
};
// Flattens blocks that appear directly as statements of another block: each
// child block is replaced, in position, by its own statements. Only one level
// is spliced per visited block; the default traversal then continues.
struct NestedBlockRemover : public ir::IRMutator<Expr*> {
  void operator()(ir::Expr* expr) { Visit(expr); }

 private:
  void Visit(ir::Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }

  void Visit(const ir::Block* expr, Expr* op) override {
    auto* node = op->As<ir::Block>();

    std::vector<ir::Expr> new_exprs;
    new_exprs.reserve(node->stmts.size());
    for (const auto& stmt : node->stmts) {
      if (const auto* block = stmt.As<ir::Block>()) {
        // Splice the child block's statements in place of the block itself.
        new_exprs.insert(
            new_exprs.end(), block->stmts.begin(), block->stmts.end());
      } else {
        new_exprs.push_back(stmt);
      }
    }
    // Install the flattened list without copying it again.
    node->stmts.swap(new_exprs);

    IRMutator::Visit(expr, op);
  }
};
// Makes sure the body of every For, PolyFor and _LoweredFunc_ is a Block,
// wrapping a non-block body into a one-statement block.
struct AddBlockToForloop : public ir::IRMutator<> {
  void operator()(ir::Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }

  void Visit(const ir::For* expr, Expr* op) override {
    WrapBodyInBlock(op->As<ir::For>());
    ir::IRMutator<>::Visit(expr, op);
  }

  void Visit(const ir::PolyFor* expr, Expr* op) override {
    WrapBodyInBlock(op->As<ir::PolyFor>());
    ir::IRMutator<>::Visit(expr, op);
  }

  void Visit(const ir::_LoweredFunc_* expr, Expr* op) override {
    WrapBodyInBlock(op->As<ir::_LoweredFunc_>());
    ir::IRMutator<>::Visit(expr, op);
  }

 private:
  // Wraps `node->body` into a Block unless it already is one.
  template <typename NodeT>
  static void WrapBodyInBlock(NodeT* node) {
    if (node->body.template As<ir::Block>()) {
      return;
    }
    node->body = ir::Block::Make({node->body});
  }
};
// Applies three passes to `e`, in this order: NestedBlockRemover,
// NestedBlockSimplifer, AddBlockToForloop.
void RemoveNestedBlock(Expr* e) {
  NestedBlockRemover remover{};
  remover(e);

  NestedBlockSimplifer simplifier{};
  simplifier(e);

  AddBlockToForloop block_adder{};
  block_adder(e);
}
} // namespace optim
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/**
* This file implements the strategy to remove the unnecessary nested block.
*/
#pragma once
#include <vector>
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/ir/ir.h"
namespace cinn {
namespace optim {
/**
* Remove the unnecessary nested block.
*/
void RemoveNestedBlock(Expr* e);
} // namespace optim
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/optim/remove_nested_block.h"
#include <gtest/gtest.h>
#include <string>
#include <vector>
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/utils/string.h"
namespace cinn {
namespace optim {
// Builds a block whose only statement is another block holding two float
// constants, checks its printed form, runs RemoveNestedBlock on it and checks
// that the printed result is a single block with the two constants.
TEST(RemoveNestedBlock, basic) {
auto block0 = ir::Block::Make({Expr(1.f), Expr(1.f)});
auto block1 = ir::Block::Make({block0});
auto e = Expr(block1);
// Printed form before the pass: an outer block wrapping the inner one.
std::string origin = utils::GetStreamCnt(e);
EXPECT_EQ(origin, utils::Trim(R"ROC(
{
{
1.00000000f
1.00000000f
}
}
)ROC"));
std::cout << "origin:\n" << e << std::endl;
RemoveNestedBlock(&e);
std::cout << "e:\n" << e << std::endl;
// Printed form after the pass: the nesting is gone.
EXPECT_EQ(utils::GetStreamCnt(e), utils::Trim(R"ROC(
{
1.00000000f
1.00000000f
}
)ROC"));
}
} // namespace optim
} // namespace cinn
......@@ -14,8 +14,8 @@
#include "paddle/cinn/optim/remove_schedule_block.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/optim/replace_var_with_expr.h"
namespace cinn {
......
......@@ -21,8 +21,8 @@
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/utils/string.h"
namespace cinn {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment