Commit 01a10755 authored by yuguo-Jack's avatar yuguo-Jack
Browse files

2.5.2-dtk24.04

parent 63eb0da5
......@@ -14,9 +14,9 @@
#include "paddle/cinn/optim/replace_call_with_expr.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/optim/replace_var_with_expr.h"
namespace cinn {
......@@ -36,7 +36,7 @@ struct ReplaceCallWithExprModifier : public ir::IRMutator<> {
VLOG(3) << "Processing Call node " << *op;
if (statement_ != node->name) return;
Expr expr_candidate = IRCopy(candidate_);
Expr expr_candidate = ir::ir_utils::IRCopy(candidate_);
VLOG(3) << "Original candidate expr: " << candidate_;
VLOG(3) << "Copied candidate expr: " << expr_candidate;
......@@ -62,7 +62,7 @@ void ReplaceIslCallWithExpr(Expr *e,
const Expr &candidate,
const std::map<std::string, Expr> &axis_map) {
VLOG(3) << "ReplaceCallWithExpr, original expression: " << candidate;
Expr copied = IRCopy(candidate);
Expr copied = ir::ir_utils::IRCopy(candidate);
// update the axis in the copied expression.
// we treat the Store node as the normal statement, the others like Call node
......
......@@ -17,8 +17,8 @@
#include <gtest/gtest.h>
#include "paddle/cinn/ir/buffer.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/lang/placeholder.h"
#include "paddle/cinn/poly/ast_gen.h"
......
......@@ -14,7 +14,7 @@
#include "paddle/cinn/optim/replace_const_param_to_integer.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/poly/ast_gen.h"
#include "paddle/cinn/utils/string.h"
......
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/**
 * This file implements the pass that replaces cross-thread reductions with
 * calls to external block-reduce functions.
 */
#pragma once
#include "paddle/cinn/optim/replace_cross_thread_reduction.h"
#include <vector>
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/hlir/pe/reduction.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/schedule/ir_schedule_util.h"
#include "paddle/cinn/lang/compute.h"
namespace cinn {
namespace optim {
namespace {
// Orders ir::Buffer objects by name so they can key a std::set.
//
// BUG FIX: the previous implementation returned `true` whenever the names
// differed (i.e. `a->name != b->name`). That is not a strict weak ordering —
// comp(a, b) and comp(b, a) were both true for distinct names — which is
// undefined behavior for std::set. Compare by name with operator< instead;
// equality of names still means equivalence, so set semantics are unchanged.
struct BufferCmp {
  bool operator()(const ir::Buffer& a, const ir::Buffer& b) const {
    return a->name < b->name;
  }
};

// Shared-memory scratch buffers collected while rewriting reductions; they
// are flushed into the enclosing lowered function's temp_bufs afterwards.
// thread_local so concurrent compilations do not share state.
thread_local std::set<ir::Buffer, BufferCmp> shm_buffer_;
/**
 * IR mutator that rewrites a reduction update whose reduce loops are bound to
 * GPU threads into a call of an external cross-thread reduce function, and
 * records the 32-element shared-memory scratch buffer each call needs.
 */
struct CrossThreadReductionReplacer : public ir::IRMutator<> {
  void operator()(ir::Expr* expr) { Visit(expr); }

 private:
  // Returns true iff the reduction in `block_realize` can be replaced by an
  // external cross-thread reduce call.
  bool CanReplace(const ir::ScheduleBlockRealize* block_realize) {
    const ir::ScheduleBlock* schedule_block =
        block_realize->schedule_block.As<ir::ScheduleBlock>();
    CHECK_NOTNULL(schedule_block);

    // The root block is never a reduction update itself.
    if (block_realize->schedule_block.As<ir::ScheduleBlock>()->name.substr(
            0, 4) == "root") {
      return false;
    }

    const std::vector<ir::Expr>& iter_values = block_realize->iter_values;
    const std::vector<ir::Var>& iter_vars = schedule_block->iter_vars;
    ir::Expr body = schedule_block->body;

    // Collect the names of loop vars feeding the block's reduce axes. The
    // collector predicate always returns false: it is used purely for its
    // side effect of recording var names.
    std::unordered_set<std::string> reduce_var_names;
    for (int i = 0; i < iter_values.size(); ++i) {
      if (!iter_vars[i]->is_reduce_axis) {
        continue;
      }
      ir::ir_utils::CollectIRNodesWithoutTensor(
          iter_values[i], [&](const ir::Expr* x) {
            if (x->as_var()) {
              reduce_var_names.insert(x->as_var()->name);
            }
            return false;
          });
    }

    // Among the enclosing loops, find the reduce loops bound to GPU threads.
    // Bail out if any such loop iterates more than 1024 times (the per-block
    // thread limit the external reduce function assumes).
    std::vector<int> thread_binded_reduce_loop_indices;
    for (int i = 0; i < cur_loops_.size(); ++i) {
      if (reduce_var_names.count(cur_loops_[i].As<ir::For>()->loop_var->name) >
          0) {
        if (cur_loops_[i].As<ir::For>()->is_gpu_thread_binded()) {
          if (ir::GetLoopExtent(cur_loops_[i]) > 1024) {
            return false;
          }
          thread_binded_reduce_loop_indices.push_back(i);
        }
      }
    }
    // The thread-bound reduce loops must end at the innermost loop and be
    // contiguous; otherwise the reduction cannot be collapsed into one call.
    if (thread_binded_reduce_loop_indices.size() == 0 ||
        thread_binded_reduce_loop_indices.back() != cur_loops_.size() - 1) {
      return false;
    }
    for (int i = 1; i < thread_binded_reduce_loop_indices.size(); ++i) {
      if (thread_binded_reduce_loop_indices[i - 1] + 1 !=
          thread_binded_reduce_loop_indices[i]) {
        return false;
      }
    }

    return true;
  }

  void Visit(ir::Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }

  // After traversing a lowered function, append the collected shared-memory
  // scratch buffers to its temp_bufs — but only if none of them is already
  // present (matched by name). The collection set is cleared either way so
  // buffers never leak into the next function.
  void Visit(const ir::_LoweredFunc_* expr, ir::Expr* op) override {
    ir::IRMutator<>::Visit(expr, op);
    if (std::find_if(op->as_lowered_func()->temp_bufs.begin(),
                     op->as_lowered_func()->temp_bufs.end(),
                     [&](const ir::Buffer& buf) -> bool {
                       for (auto& tmp_buf : shm_buffer_) {
                         if (buf->name == tmp_buf->name) return true;
                       }
                       return false;
                     }) == op->as_lowered_func()->temp_bufs.end())
      op->as_lowered_func()->temp_bufs.insert(
          op->as_lowered_func()->temp_bufs.end(),
          shm_buffer_.begin(),
          shm_buffer_.end());
    shm_buffer_.clear();
  }

  void Visit(const ir::ScheduleBlockRealize* expr, ir::Expr* op) override {
    if (!CanReplace(expr)) {
      VLOG(6) << "Can't replace cross thread reduction: " << *op;
      IRMutator::Visit(expr, op);
      return;
    }
    VLOG(6) << "Can replace cross thread reduction: " << *op;

    const ir::ScheduleBlock* schedule_block =
        expr->schedule_block.As<ir::ScheduleBlock>();
    CHECK_NOTNULL(schedule_block);

    // The update statement is either the block body itself (a Store) or the
    // single statement inside a one-statement Block.
    ir::Expr original_update_body = schedule_block->body;
    ir::Expr original_update_stmt;
    CHECK(original_update_body.As<ir::Block>() ||
          original_update_body.As<ir::Store>());
    if (original_update_body.As<ir::Block>()) {
      CHECK_EQ(original_update_body.As<ir::Block>()->stmts.size(), 1);
      original_update_stmt = original_update_body.As<ir::Block>()->stmts[0];
    } else if (original_update_body.As<ir::Store>()) {
      original_update_stmt = original_update_body;
    }

// For a store whose value is `dst Op load`, replace the value with a call to
// the matching external cross-thread reduce function, passing a fresh
// 32-element GPU shared-memory scratch buffer that is also recorded in
// shm_buffer_. (Kept as a macro so it can be stamped out per operator type.)
#define REPLACE_TO_EXTERNAL_CALL(Op)                                       \
  if (original_update_stmt.As<ir::Store>()->value.As<Op>()) {              \
    auto* node = original_update_stmt.As<ir::Store>()->value.As<Op>();     \
    CHECK(node);                                                           \
    auto& operand = node->b();                                             \
    std::string reduce_func_name =                                         \
        hlir::pe::CrossThreadReduceExternalFuncName(                       \
            original_update_stmt.As<ir::Store>()->value,                   \
            operand.As<ir::Load>()->tensor);                               \
    auto tmp_dtype = operand.As<ir::Load>()->tensor.as_tensor()->type();   \
    auto tmp_buffer = ir::_Buffer_::Make(                                  \
        "shm32_" + hlir::pe::Type2StrForReduce(tmp_dtype) + "_reduce",     \
        {ir::Expr(32)});                                                   \
    tmp_buffer->dtype = tmp_dtype;                                         \
    tmp_buffer->memory_type = ir::MemoryType::GPUShared;                   \
    shm_buffer_.insert(tmp_buffer);                                        \
    original_update_stmt.As<ir::Store>()->value =                          \
        lang::CallExtern(reduce_func_name, {node->b(), tmp_buffer});       \
  }

    REPLACE_TO_EXTERNAL_CALL(ir::Add)
    REPLACE_TO_EXTERNAL_CALL(ir::Mul)
    REPLACE_TO_EXTERNAL_CALL(ir::Max)
    REPLACE_TO_EXTERNAL_CALL(ir::Min)
    REPLACE_TO_EXTERNAL_CALL(ir::And)
    REPLACE_TO_EXTERNAL_CALL(ir::Or)
#undef REPLACE_TO_EXTERNAL_CALL

    VLOG(6) << "Replace cross thread reduction: " << *op;
    IRMutator::Visit(expr, op);
  }

  // Maintain the stack of enclosing For loops (outermost first) so that
  // CanReplace can inspect the loop nest around each schedule block.
  void Visit(const ir::For* expr, ir::Expr* op) override {
    cur_loops_.push_back(*op);
    IRMutator::Visit(expr, op);
    cur_loops_.pop_back();
  }

 private:
  std::vector<ir::Expr> cur_loops_;  // enclosing loops during traversal
};
} // namespace
void ReplaceCrossThreadReduction(Expr* e) { CrossThreadReductionReplacer()(e); }
} // namespace optim
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/**
 * This file declares the pass that replaces cross-thread reductions with
 * calls to external block-reduce functions.
 */
#pragma once
#include <vector>
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/ir/ir.h"
namespace cinn {
namespace optim {
/**
* Replace cross thread reduction to external call.
*/
void ReplaceCrossThreadReduction(Expr* e);
} // namespace optim
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/optim/replace_cross_thread_reduction.h"
#include <gtest/gtest.h>
#include <string>
#include <vector>
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/utils/string.h"
namespace cinn {
namespace optim {
// Verifies that a reduce_sum whose reduce loop is bound to threadIdx.x is
// rewritten into a call of cinn_block_reduce_sum_fp32_internal_shm with a
// 32-element shared-memory scratch buffer. Only compiled with CUDA support.
TEST(CrossThreadReductionReplacer, basic) {
#ifdef CINN_WITH_CUDA
  Context::Global().ResetNameId();
  Placeholder<float> A("A", {Expr(64), Expr(128)});
  Target target = common::DefaultNVGPUTarget();
  Module::Builder builder("reduce_sum", target);
  Var reduce_j(128, "reduce_j");
  // B[i] = sum over reduce_j of A[i, reduce_j]
  ir::Tensor B = Compute(
      {Expr(64)},
      [&](Var i) { return lang::ReduceSum(A(i, reduce_j), {reduce_j}); },
      "B");
  ast_gen_ius::TensorGroup tensor_group({A, B});
  auto func = lang::LowerToAst("reduce_sum", {A, B}, &tensor_group);
  VLOG(6) << "original func\n" << func;
  // Bind the outer loop to blocks and the reduce loop to threads, making the
  // reduction cross threads within a block.
  ir::ModuleExpr mod_expr({func->body});
  ir::IRSchedule ir_sch(mod_expr);
  ir_sch.Bind(ir_sch.GetLoops("B")[0], "blockIdx.x");
  ir_sch.Bind(ir_sch.GetLoops("B")[1], "threadIdx.x");
  ir::Expr new_func = ir_sch.GetModule().GetExprs()[0];
  VLOG(6) << "After Bind: " << new_func;
  // Run the pass under test and compare the printed IR verbatim.
  ReplaceCrossThreadReduction(&new_func);
  VLOG(6) << "After ReplaceCrossThreadReduction: " << new_func;
  EXPECT_EQ(utils::GetStreamCnt(new_func), utils::Trim(R"ROC({
ScheduleBlock(root)
{
thread_bind[blockIdx.x] for (i, 0, 64)
{
ScheduleBlock(B__reduce_init)
{
i0 = axis.bind(i)
B__reduce_init[i0] = 0.00000000f
}
thread_bind[threadIdx.x] for (reduce_j, 0, 128)
{
ScheduleBlock(B)
{
i0_0, i1 = axis.bind(i, reduce_j)
B[i0_0] = cinn_block_reduce_sum_fp32_internal_shm(A[i0_0, i1], _Buffer_<cinn_buffer_t*: 32>(shm32__fp32_reduce))
}
}
}
}
}
)ROC"));
#endif
}
} // namespace optim
} // namespace cinn
......@@ -16,11 +16,11 @@
#include "paddle/cinn/common/cas.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/optim/ir_simplify.h"
#include "paddle/cinn/optim/replace_const_param_to_integer.h"
......@@ -41,7 +41,7 @@ struct ReplaceVarWithExprMutator : public ir::IRMutator<> {
private:
void Visit(const ir::_Var_* expr, Expr* op) override {
if (expr->name == var_->name && (do_replace_ || visit_all_)) {
auto copied = IRCopy(expr_);
auto copied = ir::ir_utils::IRCopy(expr_);
*op = copied;
}
}
......
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/optim/tensor_write_tell.h"
namespace cinn {
namespace optim {} // namespace optim
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <set>
#include <string>
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
namespace cinn {
namespace optim {
/**
 * Walks an expression tree and records which tensors are written: the
 * destination of every Store node, plus any tensor that is a call node.
 */
struct TensorWriteTeller : public ir::IRMutator<const Expr*> {
  //! Collect the write info in \p op.
  void Collect(const Expr* op) { Visit(op, op); }

  //! Returns true iff \p tensor_name was recorded as a write target.
  bool IsWrite(const std::string& tensor_name) const {
    return tensor_written.count(tensor_name) > 0;
  }

 private:
  std::set<std::string> tensor_written;  // names of tensors seen as written

  void Visit(const Expr* expr, const Expr* op) override {
    IRMutator::Visit(expr, op);
  }

  // A Store writes its destination tensor; keep recursing into operands.
  void Visit(const ir::Store* expr, const Expr* op) override {
    const auto* store = op->As<ir::Store>();
    CHECK(store);
    const auto* dest = store->tensor.As<ir::_Tensor_>();
    CHECK(dest);
    tensor_written.insert(dest->name);
    IRMutator::Visit(expr, op);
  }

  // A tensor appearing as a call node counts as written.
  void Visit(const ir::_Tensor_* op, const Expr* expr) override {
    const auto* tensor = expr->As<ir::_Tensor_>();
    if (tensor->is_call_node()) {
      tensor_written.insert(tensor->name);
    }
  }
};
} // namespace optim
} // namespace cinn
......@@ -24,9 +24,9 @@
#include "paddle/cinn/common/cas.h"
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/optim/ir_simplify.h"
#include "paddle/cinn/optim/replace_var_with_expr.h"
#include "paddle/cinn/poly/isl_utils.h"
......@@ -185,7 +185,7 @@ class RestructureVarNodes : public ir::IRMutator<> {
void Visit(const ir::Load *load, Expr *op) override {
std::vector<ir::Expr> indices_copied;
for (const ir::Expr &indice : load->indices) {
indices_copied.push_back(IRCopy(indice));
indices_copied.push_back(ir::ir_utils::IRCopy(indice));
}
op->As<ir::Load>()->indices = indices_copied;
......@@ -195,7 +195,7 @@ class RestructureVarNodes : public ir::IRMutator<> {
void Visit(const ir::Store *store, Expr *op) override {
std::vector<ir::Expr> indices_copied;
for (const ir::Expr &indice : store->indices) {
indices_copied.push_back(IRCopy(indice));
indices_copied.push_back(ir::ir_utils::IRCopy(indice));
}
op->As<ir::Store>()->indices = indices_copied;
......@@ -396,7 +396,7 @@ class ReplaceLoopVarToGpu : public ir::IRMutator<> {
auto bind_info = for_ir->bind_info();
std::string var_name = "";
if (bind_info.offset == 0)
if (bind_info.offset <= 0)
var_name = "x";
else if (bind_info.offset == 1)
var_name = "y";
......@@ -585,8 +585,8 @@ class ResizeBufferSizeVisitor : public ir::IRMutator<> {
}
int BufferSize(ir::Expr indice) {
auto copy = IRCopy(indice);
auto vars = ir::CollectIRNodesInOrder(
auto copy = ir::ir_utils::IRCopy(indice);
auto vars = ir::ir_utils::CollectIRNodesInOrder(
copy, [](const ir::Expr *expr) { return expr->As<ir::_Var_>(); });
int max_range = 1;
......@@ -598,7 +598,7 @@ class ResizeBufferSizeVisitor : public ir::IRMutator<> {
auto extent = loop_2_extent_.find(var->name)->second;
for (int idx = 0; idx < extent; ++idx) {
auto tmp = IRCopy(index);
auto tmp = ir::ir_utils::IRCopy(index);
ReplaceVarWithExpr(&tmp, var, Expr(idx));
if (deep == vars.size() - 1) {
......
......@@ -21,11 +21,11 @@
#include "paddle/cinn/common/cas.h"
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/common/type.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/ir_visitor.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
#include "paddle/cinn/optim/ir_simplify.h"
namespace cinn {
......
......@@ -17,11 +17,11 @@
#include <utility>
#include <vector>
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/optim/ir_replace.h"
#include "paddle/cinn/ir/utils/ir_replace.h"
namespace cinn {
namespace optim {
......@@ -94,8 +94,8 @@ struct UnrollMutator : public ir::IRMutator<Expr*> {
for (int i = min->value; i < extent->value; i++) {
Expr start = op->min + i;
body.push_back(optim::IRCopy(op->body));
optim::IrReplace(&body.back(), op->loop_var, start);
body.push_back(ir::ir_utils::IRCopy(op->body));
cinn::ir::ir_utils::IrReplace(&body.back(), op->loop_var, start);
}
*expr = ir::Block::Make(body);
......
......@@ -17,8 +17,8 @@
#include <absl/container/flat_hash_map.h>
#include "paddle/cinn/common/cas.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
namespace cinn::optim {
......
......@@ -25,13 +25,12 @@
#include "paddle/cinn/common/cas.h"
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/optim/ir_replace.h"
#include "paddle/cinn/ir/utils/ir_replace.h"
#include "paddle/cinn/optim/ir_simplify.h"
#include "paddle/cinn/optim/tensor_write_tell.h"
#include "paddle/cinn/optim/unroll_loops.h"
#include "paddle/cinn/utils/functional.h"
......@@ -130,7 +129,8 @@ class TensorVectorizeTeller : public ir::IRMutator<const Expr *> {
// the iter val must appear in the last index
if (indices.empty() ||
ir::CollectIRNodes(indices.back(), find_matched_var_fn).empty()) {
ir::ir_utils::CollectIRNodes(indices.back(), find_matched_var_fn)
.empty()) {
VLOG(5) << "Loop var:" << iter_var_->name
<< " is not used in the last index";
return false;
......@@ -138,7 +138,8 @@ class TensorVectorizeTeller : public ir::IRMutator<const Expr *> {
// the iter val can't appear in mulitple indices
for (int i = 0; i < indices.size() - 1; ++i) {
auto repeat_found = ir::CollectIRNodes(indices[i], find_matched_var_fn);
auto repeat_found =
ir::ir_utils::CollectIRNodes(indices[i], find_matched_var_fn);
if (!repeat_found.empty()) {
VLOG(5) << "Loop var:" << iter_var_->name
<< " is used at more than last index, current:" << i;
......@@ -147,12 +148,12 @@ class TensorVectorizeTeller : public ir::IRMutator<const Expr *> {
}
// check tensor accessed sequentially by comparing index one by one
Expr first_idx = optim::IRCopy(indices.back());
optim::IrReplace(&first_idx, Expr(iter_var_), Expr(0));
Expr first_idx = ir::ir_utils::IRCopy(indices.back());
cinn::ir::ir_utils::IrReplace(&first_idx, Expr(iter_var_), Expr(0));
const auto &interval = var_intervals_->at(iter_var_->name);
for (int i = 1; i < interval.r; ++i) {
Expr next_idx = optim::IRCopy(indices.back());
optim::IrReplace(&next_idx, Expr(iter_var_), Expr(i));
Expr next_idx = ir::ir_utils::IRCopy(indices.back());
cinn::ir::ir_utils::IrReplace(&next_idx, Expr(iter_var_), Expr(i));
auto gap = common::AutoSimplify(Expr(next_idx - first_idx));
if (!gap.As<IntImm>() || gap.as_int32() != i) {
VLOG(5) << "Tensor:" << tensor->name
......@@ -185,7 +186,7 @@ class CudaVectorizer : public IRMutator<Expr *> {
const Var iter_var_; // the loop var of the vecotrized loop
const int factor_; // the factor for vectorize
TensorWriteTeller write_teller_;
std::set<std::string> write_teller_;
TensorVectorizeTeller vectorized_teller_;
absl::flat_hash_map<std::string, Var> tensor2vectorized_vars_;
......@@ -215,7 +216,7 @@ class CudaVectorizer : public IRMutator<Expr *> {
}
void Visit(Expr *expr) {
write_teller_.Collect(expr);
write_teller_ = ir::ir_utils::CollectTensorNeedsWrite(expr);
vectorized_teller_.Collect(expr);
IRMutator<Expr *>::Visit(expr, expr);
}
......@@ -289,7 +290,7 @@ class CudaVectorizer : public IRMutator<Expr *> {
const std::vector<Expr> &indices,
bool is_store) {
auto *node = tensor.As<ir::_Tensor_>();
bool is_const = !write_teller_.IsWrite(node->name);
bool is_const = !write_teller_.count(node->name);
// generate the corresponding vector type
Type scalar_type = tensor->type().ElementOf();
......@@ -309,7 +310,8 @@ class CudaVectorizer : public IRMutator<Expr *> {
// generate a get_addr expr to get the address of the tensor
Expr converted_tensor = Load::Make(tensor, indices);
optim::IrReplace(&converted_tensor, iter_var_, Expr(int32_t(0)));
cinn::ir::ir_utils::IrReplace(
&converted_tensor, iter_var_, Expr(int32_t(0)));
auto get_addr = ir::intrinsics::GetAddr::Make(converted_tensor);
// generate a let expression to cast the tensor into the local vector
......@@ -798,7 +800,7 @@ struct VectorizeLoops_ : public IRMutator<Expr *> {
cuda_vectorizer.Visit(&new_forloop->body);
// unroll the new forloop to compute each element of the vector
// iteratively
auto copied_loop = optim::IRCopy(_new_forloop);
auto copied_loop = ir::ir_utils::IRCopy(_new_forloop);
copied_loop.As<ir::For>()->set_unrolled();
optim::UnrollLoop(&copied_loop);
// add cast exprs of vector type in the front of vectorized forloop,
......@@ -881,13 +883,14 @@ struct VectorizeLoops_ : public IRMutator<Expr *> {
Var new_iterator_outer(
common::UniqName(outer_for->loop_var->name + "_s"));
Expr inner_for_b = Block::Make({For::Make(new_iterator_inner,
inner_for->min,
b,
ForType::Serial,
DeviceAPI::UNK,
IRCopy(inner_for->body))});
optim::IrReplace(
Expr inner_for_b =
Block::Make({For::Make(new_iterator_inner,
inner_for->min,
b,
ForType::Serial,
DeviceAPI::UNK,
ir::ir_utils::IRCopy(inner_for->body))});
cinn::ir::ir_utils::IrReplace(
&inner_for_b, inner_for->loop_var, Expr(new_iterator_inner));
Expr out_for_b = For::Make(new_iterator_outer,
......@@ -897,7 +900,7 @@ struct VectorizeLoops_ : public IRMutator<Expr *> {
outer_for->device_api,
inner_for_b,
outer_for->vectorize_info());
optim::IrReplace(
cinn::ir::ir_utils::IrReplace(
&out_for_b, outer_for->loop_var, Expr(new_iterator_outer));
*expr = Block::Make({out_for_a, out_for_b});
VLOG(2) << *expr;
......@@ -959,7 +962,8 @@ struct VectorizeLoops_ : public IRMutator<Expr *> {
} else {
new_index = Expr(forloop->loop_var) * factor + Expr(new_iterator);
}
optim::IrReplace(&forloop->body, forloop->loop_var, new_index);
cinn::ir::ir_utils::IrReplace(
&forloop->body, forloop->loop_var, new_index);
auto new_forloop = For::Make(new_iterator,
forloop->min,
make_const(factor),
......
......@@ -14,7 +14,7 @@
#pragma once
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/ir_mutator.h"
namespace cinn {
namespace optim {
......
......@@ -20,7 +20,7 @@
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/poly/domain_add_unit_loop_mutator.h"
#include "paddle/cinn/poly/isl_utils.h"
......
......@@ -22,7 +22,7 @@
#include <vector>
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/lang/builtin.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/lang/placeholder.h"
......
......@@ -14,7 +14,7 @@
#include "paddle/cinn/poly/dim.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/optim/ir_simplify.h"
#include "paddle/cinn/utils/string.h"
......
......@@ -23,7 +23,7 @@
#include <unordered_set>
#include "paddle/cinn/common/context.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
#include "paddle/cinn/ir/ir_visitor.h"
#include "paddle/cinn/utils/string.h"
namespace cinn {
......@@ -70,8 +70,8 @@ void Domain::ExtractParams() {
std::unordered_set<std::string> var_names;
auto collect_param_fn = [&](Expr& e) {
if (!e.is_constant()) {
auto vars =
ir::CollectIRNodes(e, [](const Expr* e) { return e->is_var(); });
auto vars = ir::ir_utils::CollectIRNodes(
e, [](const Expr* e) { return e->is_var(); });
for (auto& var : vars) var_names.insert(var.As<ir::_Var_>()->name);
}
};
......
......@@ -20,7 +20,7 @@
#include <vector>
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/utils/string.h"
namespace cinn {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment