Commit 01a10755 authored by yuguo-Jack's avatar yuguo-Jack
Browse files

2.5.2-dtk24.04

parent 63eb0da5
......@@ -18,10 +18,11 @@
#include "paddle/cinn/ir/ir.h"
namespace cinn {
namespace optim {
namespace ir {
namespace ir_utils {
//! Replace the variable \p v to expression \p e in expression \p expr.
void IrReplace(ir::Expr* expr, ir::Expr from, ir::Expr to);
} // namespace optim
} // namespace ir_utils
} // namespace ir
} // namespace cinn
......@@ -14,10 +14,13 @@
#include "paddle/cinn/ir/utils/ir_verify.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
namespace cinn::ir {
namespace cinn {
namespace ir {
namespace ir_utils {
namespace {
struct IrVerifyVisitor : public ir::IRMutator<> {
using ir::IRMutator<>::Visit;
......@@ -30,10 +33,11 @@ struct IrVerifyVisitor : public ir::IRMutator<> {
NODETY_FORALL(__)
#undef __
};
} // namespace
void IrVerify(Expr e) {
IrVerifyVisitor visitor;
visitor.Visit(&e, &e);
}
} // namespace cinn::ir
} // namespace ir_utils
} // namespace ir
} // namespace cinn
......@@ -15,8 +15,11 @@
#pragma once
#include "paddle/cinn/ir/ir.h"
namespace cinn::ir {
namespace cinn {
namespace ir {
namespace ir_utils {
void IrVerify(Expr e);
} // namespace cinn::ir
} // namespace ir_utils
} // namespace ir
} // namespace cinn
......@@ -7,6 +7,8 @@ gather_srcs(
compute.cc
placeholder.cc
lower.cc
lower_impl.cc
lower_tensor_group.cc
builtin.cc
lower_impl.cc
packed_func.cc)
......
......@@ -22,14 +22,16 @@
#include <utility>
#include "paddle/cinn/ir/buffer.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/lang/lower_impl.h"
#include "paddle/cinn/lang/lower_tensor_group.h"
#include "paddle/cinn/optim/optimize.h"
#include "paddle/cinn/utils/string.h"
namespace cinn {
namespace lang {
using ast_gen_ius::TensorGroup;
using ir::Tensor;
using poly::Stage;
......@@ -38,7 +40,7 @@ std::vector<ir::Argument> GetArgs(
std::vector<ir::Argument> res;
std::map<std::string, std::set<const ir::Load*>> name2loads;
std::map<std::string, std::set<const ir::Store*>> name2stores;
auto load_or_store_nodes = ir::CollectIRNodesWithoutTensor(
auto load_or_store_nodes = ir::ir_utils::CollectIRNodesWithoutTensor(
func_body,
[&](const Expr* x) { return x->As<ir::Store>() || x->As<ir::Load>(); });
......@@ -84,6 +86,49 @@ std::vector<ir::Argument> GetArgs(
return res;
}
//! Collect the temporary tensors from a computational graph.
//! A "temp buffer" is the storage of a tensor that is materialized inside the
//! function body but is not bound to one of the function's arguments.
std::vector<ir::Buffer> GetTempBuffers(const std::vector<Tensor>& tensor_args,
const TensorGroup& tensor_group,
Expr body) {
// Names of the argument tensors and of the buffers they are bound to; these
// are excluded from the temporary set collected below.
std::unordered_set<std::string> tensor_arg_names;
std::unordered_set<std::string> buffer_arg_names;
for (auto& tensor : tensor_args) {
tensor_arg_names.insert(tensor->name);
if (tensor->buffer.defined()) {
buffer_arg_names.insert(tensor->buffer->name);
}
}
std::map<std::string, ir::Buffer>
name_to_buffer;  // used to avoid duplication.
// A tensor counts as temporary when it has a buffer and either (a) it is not
// tracked by the tensor group, (b) neither the tensor nor its buffer is a
// function argument, or (c) its buffer name ends with "temp_buffer".
auto all_temp_tensors =
ir::ir_utils::CollectIRNodesWithoutTensor(body, [&](const Expr* x) {
return x->as_tensor() && x->as_tensor()->buffer.defined() &&
(!tensor_group.Contain(x->as_tensor()->name) ||
((!buffer_arg_names.count(x->as_tensor()->buffer->name) &&
!tensor_arg_names.count(x->as_tensor()->name)) ||
utils::Endswith(x->as_tensor()->buffer->name, "temp_buffer")));
});
// Deduplicate by buffer name; when several tensors share a buffer name the
// buffer with the smaller element count wins.
for (auto& e : all_temp_tensors) {
auto buffer_name = e.as_tensor()->buffer->name;
if (!name_to_buffer.count(buffer_name)) {
name_to_buffer[buffer_name] = e.as_tensor()->buffer;
} else {
// Just copy from old code, but why?
// NOTE(review): preferring the smaller-numel buffer is inherited behavior
// whose rationale is unknown -- confirm before relying on it.
if (e.as_tensor()->buffer->numel() <
name_to_buffer[buffer_name]->numel()) {
name_to_buffer[buffer_name] = e.as_tensor()->buffer;
}
}
}
// Flatten the deduplicated map into the result vector.
std::vector<ir::Buffer> temp_buffers;
for (auto& i : name_to_buffer) {
temp_buffers.push_back(i.second);
}
return temp_buffers;
}
//! Collect the temporary tensors from a computational graph.
std::vector<ir::Buffer> GetTempBuffers(const std::vector<Tensor>& tensor_args,
const poly::StageMap& stage_map,
......@@ -100,7 +145,7 @@ std::vector<ir::Buffer> GetTempBuffers(const std::vector<Tensor>& tensor_args,
name_to_buffer; // used to avoid duplication.
auto all_temp_tensors =
ir::CollectIRNodesWithoutTensor(body, [&](const Expr* x) {
ir::ir_utils::CollectIRNodesWithoutTensor(body, [&](const Expr* x) {
return x->as_tensor() && x->as_tensor()->buffer.defined() &&
(!stage_map->Lookup(x->as_tensor()->name) ||
!stage_map[x->as_tensor()]->inlined()) &&
......@@ -120,7 +165,8 @@ std::vector<ir::Buffer> GetTempBuffers(const std::vector<Tensor>& tensor_args,
}
}
// visit the ir body and update the map of name_to_buffer
auto update_map = ir::CollectIRNodesWithoutTensor(body, [&](const Expr* x) {
auto update_map =
ir::ir_utils::CollectIRNodesWithoutTensor(body, [&](const Expr* x) {
if (x->as_tensor() && x->as_tensor()->buffer.defined()) {
auto buffer_name = x->as_tensor()->buffer->name;
if (name_to_buffer.count(buffer_name) &&
......@@ -150,7 +196,7 @@ std::vector<ir::Buffer> GetTempBuffers(const std::vector<ir::Argument>& args,
name_to_buffer; // used to avoid duplication.
auto all_temp_tensors =
ir::CollectIRNodesWithoutTensor(body, [&](const Expr* x) {
ir::ir_utils::CollectIRNodesWithoutTensor(body, [&](const Expr* x) {
return x->as_tensor() && x->as_tensor()->buffer.defined() &&
(!buffer_arg_names.count(x->as_tensor()->buffer->name) ||
utils::Endswith(x->as_tensor()->buffer->name, "temp_buffer"));
......@@ -167,7 +213,8 @@ std::vector<ir::Buffer> GetTempBuffers(const std::vector<ir::Argument>& args,
}
}
// visit the ir body and update the map of name_to_buffer
auto update_map = ir::CollectIRNodesWithoutTensor(body, [&](const Expr* x) {
auto update_map =
ir::ir_utils::CollectIRNodesWithoutTensor(body, [&](const Expr* x) {
if (x->as_tensor() && x->as_tensor()->buffer.defined()) {
auto buffer_name = x->as_tensor()->buffer->name;
if (name_to_buffer.count(buffer_name) &&
......@@ -205,7 +252,7 @@ void InitReduceTensor(StageMap stages,
tensor->InitReduction(stages, target);
}
auto uninited_reduce_tensors =
ir::CollectIRNodes(tensor->body(), [&](const Expr* x) {
ir::ir_utils::CollectIRNodes(tensor->body(), [&](const Expr* x) {
return x && x->defined() && x->as_tensor() &&
x->as_tensor()->is_reduce_tensor() &&
!x->as_tensor()->IsReduceInited(stages);
......@@ -216,6 +263,57 @@ void InitReduceTensor(StageMap stages,
}
}
//! Gather every tensor that appears as a control dependency of some tensor in
//! \p tensor_group, excluding tensors that are already function arguments.
std::set<ir::Tensor> CollectTempTensorsFromCtrlDepends(
    ast_gen_ius::TensorGroup* tensor_group,
    const std::vector<Tensor>& tensor_args) {
  std::set<ir::Tensor> ctrl_dep_tensors;
  // Union the control-dependency sets of all tensors in the group.
  for (const ir::Tensor& tensor : tensor_group->GetAllTensors()) {
    for (const ir::Tensor& dep : tensor_group->GetCrtlDepTensors(tensor->name)) {
      ctrl_dep_tensors.insert(dep);
    }
  }
  // Argument tensors are materialized by the caller, so they are not temps.
  for (const ir::Tensor& arg : tensor_args) {
    ctrl_dep_tensors.erase(arg);
  }
  return ctrl_dep_tensors;
}
//! Lower a tensor group to exactly one LoweredFunc AST; it is a checked
//! single-result convenience wrapper around LowerToAstVec.
ir::LoweredFunc LowerToAst(const std::string& name,
                           const std::vector<Tensor>& tensor_args,
                           ast_gen_ius::TensorGroup* tensor_group,
                           const Target& target) {
  auto funcs = LowerToAstVec(name, tensor_args, tensor_group, target);
  CHECK_EQ(funcs.size(), 1UL) << "LowerToAst contains not only 1 LoweredFunc, "
                                 "use LowerToAstVec instead.";
  return funcs.front();
}
std::vector<ir::LoweredFunc> LowerToAstVec(
const std::string& name,
const std::vector<Tensor>& tensor_args,
ast_gen_ius::TensorGroup* tensor_group,
const Target& target) {
std::set<ir::Tensor> ctrl_deps =
CollectTempTensorsFromCtrlDepends(tensor_group, tensor_args);
auto lower_instance = detail::LowerTensorGroup(
name,
tensor_args,
{},
tensor_group,
std::vector<Tensor>(ctrl_deps.begin(), ctrl_deps.end()),
target);
std::vector<ir::LoweredFunc> result = lower_instance();
for (auto& res : result) {
if (target == common::DefaultNVGPUTarget()) {
res->device_api = ir::DeviceAPI::GPU;
}
}
return result;
}
ir::LoweredFunc Lower(const std::string& name,
StageMap stages,
const std::vector<Tensor>& tensor_args,
......
......@@ -20,6 +20,7 @@
#include <string>
#include <vector>
#include "paddle/cinn/ast_gen_ius/tensor_group.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/lowered_func.h"
#include "paddle/cinn/ir/module.h"
......@@ -73,6 +74,22 @@ std::vector<ir::LoweredFunc> LowerVec(
const Target &target = common::DefaultHostTarget(),
bool support_ir_schedule = false);
ir::LoweredFunc LowerToAst(const std::string &name,
const std::vector<Tensor> &tensor_args,
ast_gen_ius::TensorGroup *tensor_group,
const Target &target = common::DefaultHostTarget());
std::vector<ir::LoweredFunc> LowerToAstVec(
const std::string &name,
const std::vector<Tensor> &tensor_args,
ast_gen_ius::TensorGroup *tensor_group,
const Target &target = common::DefaultHostTarget());
std::vector<ir::Buffer> GetTempBuffers(
const std::vector<Tensor> &tensor_args,
const ast_gen_ius::TensorGroup &tensor_group,
Expr body);
std::vector<ir::Argument> GetArgs(
const Expr &func_body, const std::vector<std::string> &input_output_nodes);
......
......@@ -23,9 +23,9 @@
#include "paddle/cinn/common/context.h"
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/optim/remove_nested_block.h"
#include "paddle/cinn/optim/ir_simplify.h"
#include "paddle/cinn/optim/replace_var_with_expr.h"
#include "paddle/cinn/optim/transform_polyfor_to_for.h"
#include "paddle/cinn/poly/stage.h"
......@@ -35,7 +35,7 @@ namespace lang {
namespace detail {
void CheckNoIslCallRemains(Expr* expr) {
auto isl_calls = ir::CollectIRNodes(*expr, [](const Expr* expr) {
auto isl_calls = ir::ir_utils::CollectIRNodes(*expr, [](const Expr* expr) {
return expr->As<ir::Call>() && expr->As<ir::Call>()->is_isl_call();
});
#ifdef CINN_DEBUG
......@@ -223,7 +223,7 @@ void CreateCompGraphWithInlineTensors(common::Graph* graph,
// collect dependency tensors of t
// here we just collect the tensors in Load nodes
// NOTE there may be some other cases.
auto deps = ir::CollectLoadTensors(
auto deps = ir::ir_utils::CollectLoadTensors(
t->body(), [](const Expr* x) { return x->as_tensor(); });
for (const auto& dep : deps) {
auto e_tensor = dep.as_tensor_ref();
......@@ -342,8 +342,7 @@ std::vector<ir::Argument> LowerImpl::GenerateFunctionArgumentList(
CheckArgsUnique();
std::vector<ir::Argument> args;
optim::TensorWriteTeller teller;
teller.Collect(&fn_body);
auto teller = ir::ir_utils::CollectTensorNeedsWrite(&fn_body);
std::set<std::string> arg_names;
......@@ -358,7 +357,7 @@ std::vector<ir::Argument> LowerImpl::GenerateFunctionArgumentList(
for (auto& tensor : tensor_args_) {
auto* tensor_node = tensor.As<ir::_Tensor_>();
bool is_output = teller.IsWrite(tensor->name);
bool is_output = teller.count(tensor->name);
VLOG(1) << "tensor argument " << tensor->name << " buffer "
<< tensor->buffer->name;
......@@ -396,8 +395,7 @@ std::vector<ir::Argument> LowerImpl::GenFuncArgForSplitKernel(
std::vector<ir::Argument> in_args;
std::vector<ir::Argument> out_args;
optim::TensorWriteTeller teller;
teller.Collect(&func_iterator);
auto teller = ir::ir_utils::CollectTensorNeedsWrite(&func_iterator);
std::set<std::string> arg_names;
std::set<std::string> all_tensor_names;
......@@ -410,11 +408,12 @@ std::vector<ir::Argument> LowerImpl::GenFuncArgForSplitKernel(
in_args.emplace_back(scalar, ir::Argument::IO::kInput);
}
auto all_tensors = ir::CollectIRNodes(func_iterator, [&](const Expr* x) {
auto all_tensors =
ir::ir_utils::CollectIRNodes(func_iterator, [&](const Expr* x) {
return x->as_tensor() && !stages_[x->as_tensor()]->inlined();
});
auto all_vars = ir::CollectIRNodes(
auto all_vars = ir::ir_utils::CollectIRNodes(
func_iterator, [&](const Expr* x) { return x->as_var(); });
for (auto& i : all_tensors) {
......@@ -448,7 +447,7 @@ std::vector<ir::Argument> LowerImpl::GenFuncArgForSplitKernel(
VLOG(3) << "In tensor_args_, it has : " << tensor->name;
if (temp_tensor_names.count(tensor->name) > 0) continue;
if (all_tensor_names.count(tensor->name) == 0) continue;
bool is_output = teller.IsWrite(tensor->name);
bool is_output = teller.count(tensor->name);
VLOG(3) << "tensor argument " << tensor->name << " buffer "
<< tensor->buffer->name;
......@@ -485,7 +484,7 @@ std::vector<ir::Argument> LowerImpl::GenFuncArgForSplitKernel(
VLOG(3) << "Tensor " << tensor->name;
if (tensor->buffer.defined() && !arg_names.count(tensor->buffer->name)) {
bool is_output =
teller.IsWrite(tensor->name) && teller.IsWrite(tensor->name);
teller.count(tensor->name) && teller.count(tensor->name);
if (is_output)
out_args.emplace_back(tensor->buffer, ir::Argument::IO::kOutput);
}
......@@ -590,7 +589,7 @@ std::vector<ir::LoweredFunc> LowerImpl::operator()() {
Reference(&arg)->buffer = tensor_map.at(arg->name)->buffer;
}
}
auto store_exprs = ir::CollectIRNodes(
auto store_exprs = ir::ir_utils::CollectIRNodes(
func_iterator, [](const Expr* x) { return x->As<ir::Store>(); });
std::vector<ir::Tensor> new_temp_tensors;
for (auto& expr : store_exprs) {
......@@ -655,7 +654,7 @@ std::vector<ir::LoweredFunc> LowerImpl::operator()() {
if (support_ir_schedule_) {
optim::TransformPolyForToFor(&func->body);
optim::RemoveNestedBlock(&func->body);
optim::SimplifyBlocks(&func->body);
func->body = ir::Block::Make({func->body});
result.push_back(ir::LoweredFunc(func.get()));
num_func++;
......
......@@ -27,14 +27,13 @@
#include "paddle/cinn/common/graph_utils.h"
#include "paddle/cinn/ir/buffer.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/optim/buffer_assign.h"
#include "paddle/cinn/optim/compute_inline_expand.h"
#include "paddle/cinn/optim/fold_cinn_call_arguments.h"
#include "paddle/cinn/optim/optimize.h"
#include "paddle/cinn/optim/remove_nested_block.h"
#include "paddle/cinn/optim/replace_call_with_expr.h"
#include "paddle/cinn/optim/tensor_write_tell.h"
#include "paddle/cinn/optim/transform_gpu_forloop.h"
#include "paddle/cinn/optim/transform_polyfor_to_for.h"
#include "paddle/cinn/poly/ast_gen.h"
......
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/lang/lower_tensor_group.h"
#include <algorithm>
#include <queue>
#include <string>
#include <unordered_set>
#include "paddle/cinn/ast_gen_ius/ast_gen.h"
#include "paddle/cinn/ast_gen_ius/tensor_group.h"
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/common/context.h"
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/optim/ir_simplify.h"
#include "paddle/cinn/optim/replace_var_with_expr.h"
#include "paddle/cinn/optim/transform_polyfor_to_for.h"
#include "paddle/cinn/poly/stage.h"
namespace cinn {
namespace lang {
namespace detail {
// Stores the lowering inputs for later use by operator().
// NOTE: the member-initializer list follows the member declaration order in
// lower_tensor_group.h (temp_tensor_args_ is declared before tensor_group_).
// Members are always initialized in declaration order, so a list written in a
// different order is misleading and triggers -Wreorder.
LowerTensorGroup::LowerTensorGroup(
    const std::string& fn_name,
    const std::vector<ir::Tensor>& tensor_args,
    const std::vector<ir::Var>& scalar_args,
    ast_gen_ius::TensorGroup* tensor_group,
    const std::vector<ir::Tensor>& temp_tensor_args,
    const Target& target)
    : fn_name_(fn_name),
      tensor_args_(tensor_args),
      scalar_args_(scalar_args),
      temp_tensor_args_(temp_tensor_args),
      tensor_group_(tensor_group),
      target_(target) {}
// Lower the tensor group into one or more LoweredFuncs: generate AST bodies,
// assign buffers to tensors, collect temporary buffers, build the argument
// list, then make and clean up each function.
// Fix: the "Didn't find arg tensor" log line was missing the space before
// "in tensor_map", producing e.g. "tensor Xin tensor_map".
std::vector<ir::LoweredFunc> LowerTensorGroup::operator()() {
  std::vector<ir::LoweredFunc> result;
  int num_func = 0;
  // 1. Generate function body
  std::vector<ir::Expr> func_bodies = GenerateFunctionBody(tensor_group_);
  for (ir::Expr& func_body : func_bodies) {
    // Wrap each body in a "root" ScheduleBlockRealize so later IR schedule
    // passes have a single root block to anchor on.
    func_body = ir::ScheduleBlockRealize::Make(
        {},
        ir::ScheduleBlock::Make(
            {}, {}, {}, common::UniqName("root"), func_body));
    // 2. Assign buffer to tensors
    auto tensor_map = tensor_group_->AllocateBuffers();
    // copy the tensor(with buffer assigned) back to func's args.
    for (auto& arg : tensor_args_) {
      if (arg->is_placeholder_node() || arg->buffer.defined()) {
        continue;
      }
      if (arg->body().As<ir::Call>() && arg->body().type().is_void()) {
        continue;  // extern call
      }
      if (tensor_map.find(arg->name) == tensor_map.end()) {
        LOG(INFO) << "Didn't find arg tensor " << arg->name
                  << " in tensor_map.\n"
                  << "The function is " << fn_name_
                  << "\nAnd all the arg tensors are:\n";
        for (auto& i : tensor_args_) {
          LOG(INFO) << i->name;
        }
        LOG(FATAL) << "Fatal Error!";
      }
      Reference(&arg)->buffer = tensor_map.at(arg->name)->buffer;
    }
    // 3. Collect temp tensor buffers
    std::set<std::string> temp_tensor_names;
    for (auto& t : temp_tensor_args_) {
      temp_tensor_names.insert(t->name);
    }
    // Some store tensors are also temp tensors;
    auto store_exprs = ir::ir_utils::CollectIRNodes(
        func_body, [](const Expr* x) { return x->As<ir::Store>(); });
    for (auto& expr : store_exprs) {
      auto* store_node = expr.As<ir::Store>();
      CHECK(store_node);
      auto* tensor = store_node->tensor.As<ir::_Tensor_>();
      CHECK(tensor);
      VLOG(3) << "In store_exprs, its name is : " << tensor->name;
      CHECK(tensor->buffer.defined());
      // Heap buffers become function arguments; every other memory type
      // (e.g. GPU shared/local) is a temporary owned by the function.
      if (tensor->buffer->memory_type != ir::MemoryType::Heap) {
        temp_tensor_names.insert(store_node->tensor.as_tensor_ref()->name);
      }
    }
    // Deduplicate temp buffers by buffer name.
    std::vector<ir::Buffer> temp_buffers;
    std::unordered_set<std::string> buffer_name_set;
    for (const std::string& name : temp_tensor_names) {
      if (!tensor_map.count(name)) {
        continue;
      }
      ir::Tensor& t = tensor_map[name];
      if (t->buffer.defined() && !buffer_name_set.count(t->buffer->name)) {
        temp_buffers.push_back(t->buffer);
        buffer_name_set.insert(t->buffer->name);
      }
    }
    // 4. Handle function args
    std::vector<ir::Argument> func_args =
        GenerateFunctionArgumentList(func_body);
    // 5. Actual function make
    // Secondary functions get a numeric suffix to keep names unique.
    std::string actual_fn_name = fn_name_;
    if (num_func > 0) {
      actual_fn_name += "_" + std::to_string(num_func);
      VLOG(3) << "Making func :" << actual_fn_name;
    }
    for (auto& i : func_args) {
      VLOG(3) << "func_args is : " << i.name();
    }
    for (auto& i : temp_buffers) {
      VLOG(3) << "temp_buffers is : " << i->name;
    }
    ir::LoweredFunc func = ir::_LoweredFunc_::Make(
        actual_fn_name, func_args, func_body, temp_buffers);
    // 6. Final clean up
    optim::SimplifyBlocks(&func->body);
    func->body = ir::Block::Make({func->body});
    result.push_back(ir::LoweredFunc(func.get()));
    num_func++;
  }
  return result;
}
// Build the function argument list: scalar args first (always kInput), then
// the buffers of the tensor args, classified as kInput/kOutput depending on
// whether the function body writes to the tensor.
// Fix: the buffer.defined() guard now runs BEFORE the VLOG(6) line; the
// original streamed tensor->buffer->name first, dereferencing an undefined
// buffer whenever verbose logging (v>=6) was enabled.
std::vector<ir::Argument> LowerTensorGroup::GenerateFunctionArgumentList(
    Expr fn_body) {
  std::vector<ir::Argument> args;
  // Names of tensors the body writes; a written tensor becomes an output arg.
  auto teller = ir::ir_utils::CollectTensorNeedsWrite(&fn_body);
  std::set<std::string> arg_names;

  for (auto& scalar : scalar_args_) {
    CHECK(!arg_names.count(scalar->name));
    auto* scalar_node = scalar.As<ir::_Var_>();
    CHECK(scalar_node->type().valid());
    arg_names.insert(scalar->name);
    args.emplace_back(scalar, ir::Argument::IO::kInput);
  }

  for (auto& tensor : tensor_args_) {
    auto* tensor_node = tensor.As<ir::_Tensor_>();
    // Tensors without a buffer cannot become arguments; skip them before any
    // use of tensor->buffer below.
    if (!tensor_node->buffer.defined()) {
      continue;
    }
    bool is_output = teller.count(tensor->name);
    VLOG(6) << "tensor argument " << tensor->name << ", buffer "
            << tensor->buffer->name << ", is output: " << is_output;
    // avoid duplicate
    // if a argument is already marked as kInput, mark it as kOutput and move
    // it to the back.
    if (arg_names.count(tensor_node->buffer->name)) {
      auto it =
          std::find_if(args.begin(), args.end(), [&](const ir::Argument& x) {
            return x.name() == tensor_node->buffer->name;
          });
      CHECK(it != args.end());
      if (it->is_input()) {
        args.erase(it);
      } else if (it->is_output()) {
        continue;
      }
    }
    arg_names.insert(tensor_node->buffer->name);
    auto io = is_output ? ir::Argument::IO::kOutput : ir::Argument::IO::kInput;
    VLOG(6) << "Collect " << (is_output ? "W" : "R") << " argument "
            << tensor->buffer->name;
    args.emplace_back(tensor_node->buffer, io);
  }
  return args;
}
// Generate the function bodies for the tensors in \p tensor_group, visiting
// them in topological order. On the default NVGPU target the accumulated
// bodies are flushed into a separate result entry (i.e. a separate function)
// each time a tensor that is NOT in GPU shared/local memory is reached; on
// other targets a single body containing everything is returned.
std::vector<ir::Expr> LowerTensorGroup::GenerateFunctionBody(
    ast_gen_ius::TensorGroup* tensor_group) {
  // TODO(zhhsplendid): GetGenFuncTopoOrder() may remove args
  std::vector<ir::Tensor> ordered_tensors = tensor_group->GetGenFuncTopoOrder();
  std::vector<ir::Expr> result;
  // Bodies accumulated since the last flush; cleared whenever they are
  // pushed into `result`.
  std::vector<ir::Expr> bodies;
  for (const ir::Tensor& tensor : ordered_tensors) {
    VLOG(6) << "tensor_name = " << tensor->name;
    // Placeholders and expression-less tensors produce no statements.
    if (!tensor->is_placeholder_node() && tensor->has_expression()) {
      VLOG(6) << "ast_gen_ius::AstGen::Build for Tensor " << tensor;
      bodies.emplace_back(ast_gen_ius::AstGen::Build(tensor, tensor_group));
      bool gpu_local =
          tensor->buffer.defined() &&
          (tensor->buffer->memory_type == ir::MemoryType::GPUShared ||
           tensor->buffer->memory_type == ir::MemoryType::GPULocal);
      // A non-GPU-local tensor on NVGPU marks a kernel boundary: flush the
      // accumulated bodies as one function body.
      if (target_ == common::DefaultNVGPUTarget() && !gpu_local) {
        result.push_back(bodies.size() == 1 ? bodies[0]
                                            : ir::Block::Make(bodies));
        bodies.clear();
      }
    }
  }
  // Whatever remains forms the final (or, off-GPU, the only) body.
  if (!bodies.empty()) {
    result.push_back(bodies.size() == 1 ? bodies[0] : ir::Block::Make(bodies));
    bodies.clear();
  }
  return result;
}
} // namespace detail
} // namespace lang
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <absl/container/flat_hash_map.h>
#include <iostream>
#include <map>
#include <memory>
#include <set>
#include <stack>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/cinn/ast_gen_ius/tensor_group.h"
#include "paddle/cinn/common/graph_utils.h"
#include "paddle/cinn/ir/buffer.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/optim/buffer_assign.h"
#include "paddle/cinn/optim/compute_inline_expand.h"
#include "paddle/cinn/optim/fold_cinn_call_arguments.h"
#include "paddle/cinn/optim/optimize.h"
#include "paddle/cinn/optim/replace_call_with_expr.h"
#include "paddle/cinn/optim/transform_gpu_forloop.h"
#include "paddle/cinn/optim/transform_polyfor_to_for.h"
#include "paddle/cinn/poly/ast_gen.h"
namespace cinn {
namespace lang {
namespace detail {
// Lowers an ast_gen_ius::TensorGroup into one or more ir::LoweredFunc:
// generates AST bodies, assigns buffers, and builds the argument list.
// Construct it with the function name, argument tensors/scalars, the tensor
// group, and optional temporary tensors, then invoke operator().
class LowerTensorGroup {
 public:
  LowerTensorGroup(const std::string& fn_name,
                   const std::vector<ir::Tensor>& tensor_args,
                   const std::vector<ir::Var>& scalar_args,
                   ast_gen_ius::TensorGroup* tensor_group,
                   const std::vector<ir::Tensor>& temp_tensor_args = {},
                   const Target& target = common::DefaultHostTarget());

  // Performs the actual lowering; may return several functions on GPU.
  std::vector<ir::LoweredFunc> operator()();

  // Generates the function bodies in topological tensor order.
  std::vector<ir::Expr> GenerateFunctionBody(
      ast_gen_ius::TensorGroup* tensor_group);

  // Builds the argument list (scalars + tensor buffers) for a function body.
  std::vector<ir::Argument> GenerateFunctionArgumentList(ir::Expr fn_body);

 private:
  // Stored by value, NOT by const reference: callers (e.g. LowerToAstVec)
  // pass temporaries such as `{}` for scalar_args, and a reference member
  // would dangle as soon as the constructor's full-expression ends, making
  // the later operator() call read freed storage.
  std::string fn_name_;
  std::vector<ir::Tensor> tensor_args_;
  std::vector<Var> scalar_args_;
  std::vector<ir::Tensor> temp_tensor_args_;
  // Not owned; must outlive this object.
  ast_gen_ius::TensorGroup* tensor_group_;
  Target target_;
};
} // namespace detail
} // namespace lang
} // namespace cinn
......@@ -18,6 +18,7 @@
#include <set>
#include "paddle/cinn/ast_gen_ius/tensor_group.h"
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/lang/buffer.h"
#include "paddle/cinn/lang/compute.h"
......@@ -27,6 +28,10 @@
namespace cinn {
namespace lang {
#define TEST_SOUTPUT(x, out) \
LOG(INFO) << "\n" << x << std::endl; \
EXPECT_EQ(utils::GetStreamCnt(x), utils::Trim(out));
TEST(lower, basic) {
auto M = Expr(100);
auto N = Expr(15);
......@@ -42,10 +47,6 @@ TEST(lower, basic) {
LOG(INFO) << "lower_size " << lower_funcs;
#define TEST_SOUTPUT(x, out) \
std::cout << "\n" << x << std::endl; \
EXPECT_EQ(utils::GetStreamCnt(x), utils::Trim(out));
auto out = R"ROC(
{
serial for (i, 0, 100)
......@@ -77,7 +78,7 @@ TEST(lower, more_complex) {
auto lower_funcs = Lower("cal_C", stages, {A, B, C});
std::cout << "func:\n" << Expr(lower_funcs->self()) << std::endl;
LOG(INFO) << "func:\n" << Expr(lower_funcs->self()) << std::endl;
}
//! To support training, the dynamic shape support is vital. We test the
......@@ -157,5 +158,135 @@ TEST(lower, temp_buffer_collects) {
}
}
TEST(lower_to_ast, basic) {
Context::Global().ResetNameId();
auto M = Expr(100);
auto N = Expr(15);
Placeholder<float> A("A", {Expr(M), Expr(N)});
ir::Tensor B = Compute(
{M, N}, [=](Var i, Var j) -> Expr { return A(i, j) + 1.f; }, "B");
ast_gen_ius::TensorGroup tensor_group({B});
ir::LoweredFunc lower_func = LowerToAst("cal_B", {A, B}, &tensor_group);
LOG(INFO) << "lower_func " << lower_func;
auto out = R"ROC(
function cal_B (_A, _B)
{
ScheduleBlock(root)
{
serial for (i, 0, 100)
{
serial for (j, 0, 15)
{
ScheduleBlock(B)
{
i0, i1 = axis.bind(i, j)
B[i0, i1] = (A[i0, i1] + 1.00000000f)
}
}
}
}
}
)ROC";
TEST_SOUTPUT(lower_func, out);
}
TEST(lower_to_ast, three_dim) {
Context::Global().ResetNameId();
Expr M(100);
Expr N(15);
Expr K(200);
Placeholder<float> A("A", {Expr(M), Expr(N)});
Placeholder<float> B("B", {Expr(N), Expr(K)});
auto C = Compute(
{M, N, K},
[=](Var i, Var j, Var k) -> Expr { return A(i, j) * B(j, k); },
"C");
ast_gen_ius::TensorGroup tensor_group({C});
ir::LoweredFunc lower_func = LowerToAst("cal_C", {A, B, C}, &tensor_group);
LOG(INFO) << "func:\n" << lower_func << std::endl;
auto out = R"ROC(
function cal_C (_A, _B, _C)
{
ScheduleBlock(root)
{
serial for (i, 0, 100)
{
serial for (j, 0, 15)
{
serial for (k, 0, 200)
{
ScheduleBlock(C)
{
i0, i1, i2 = axis.bind(i, j, k)
C[i0, i1, i2] = (A[i0, i1] * B[i1, i2])
}
}
}
}
}
}
)ROC";
TEST_SOUTPUT(lower_func, out);
}
TEST(lower_to_ast, matmul_with_reduce_sum) {
Context::Global().ResetNameId();
Placeholder<float> A("A", {Expr(100), Expr(20)});
Placeholder<float> B("B", {Expr(20), Expr(50)});
Target target{};
// C = A * B
Var k(20, "k0");
Tensor C = Compute(
{Expr(100), Expr(50)},
[&](Var i, Var j) { return lang::ReduceSum(A(i, k) * B(k, j), {k}); },
"C");
ast_gen_ius::TensorGroup tensor_group({C});
ir::LoweredFunc lower_func = LowerToAst("matmul", {A, B, C}, &tensor_group);
LOG(INFO) << "func:\n" << lower_func << std::endl;
auto out = R"ROC(
function matmul (_A, _B, _C)
{
ScheduleBlock(root)
{
serial for (i, 0, 100)
{
serial for (j, 0, 50)
{
ScheduleBlock(C__reduce_init)
{
i0, i1 = axis.bind(i, j)
C__reduce_init[i0, i1] = 0.00000000f
}
serial for (k0, 0, 20)
{
ScheduleBlock(C)
{
i0_0, i1_0, i2 = axis.bind(i, j, k0)
C[i0_0, i1_0] = (C[i0_0, i1_0] + (A[i0_0, i2] * B[i2, i1_0]))
}
}
}
}
}
}
)ROC";
TEST_SOUTPUT(lower_func, out);
}
} // namespace lang
} // namespace cinn
......@@ -16,8 +16,8 @@
#include <gtest/gtest.h>
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/utils/string.h"
namespace cinn {
......
......@@ -19,9 +19,9 @@
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/ir/buffer.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/operation.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/runtime/intrinsic.h"
namespace cinn {
......
......@@ -16,7 +16,7 @@
#include <gtest/gtest.h>
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/ir_printer.h"
namespace cinn {
namespace lang {
......
......@@ -3,11 +3,8 @@ core_gather_headers()
gather_srcs(
cinnapi_src
SRCS
remove_nested_block.cc
replace_call_with_expr.cc
ir_replace.cc
replace_var_with_expr.cc
tensor_write_tell.cc
ir_simplify.cc
optimize.cc
vectorize_loops.cc
......@@ -23,19 +20,16 @@ gather_srcs(
compute_inline_expand.cc
buffer_assign.cc
replace_const_param_to_integer.cc
cast_simplify.cc
lower_intrin.cc
cast_bool_to_int8.cc
collect_undefined_vars.cc
var_mod_simplify.cc
remove_schedule_block.cc)
remove_schedule_block.cc
replace_cross_thread_reduction.cc)
if(WITH_CUDA)
gather_srcs(cinnapi_src SRCS transform_gpu_forloop.cc)
endif()
cinn_cc_test(test_remove_nested_block SRCS remove_nested_block_test.cc DEPS
cinncore)
cinn_cc_test(test_ir_simplify SRCS ir_simplify_test.cc DEPS cinncore)
cinn_cc_test(test_replace_call_with_expr SRCS replace_call_with_expr_test.cc
DEPS cinncore)
......@@ -62,3 +56,5 @@ cinn_cc_test(test_cast_simplify SRCS cast_simplify_test.cc DEPS cinncore)
cinn_cc_test(test_remove_schedule_block SRCS remove_schedule_block_test.cc DEPS
cinncore)
cinn_cc_test(test_unroll_loops SRCS unroll_loops_test.cc DEPS cinncore)
cinn_cc_test(test_replace_cross_thread_reduction SRCS
replace_cross_thread_reduction_test.cc DEPS cinncore)
......@@ -15,10 +15,10 @@
#include "paddle/cinn/optim/buffer_assign.h"
#include "paddle/cinn/common/union_find.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_replace.h"
#include "paddle/cinn/lang/lower_impl.h"
#include "paddle/cinn/optim/ir_replace.h"
namespace cinn {
namespace optim {
......@@ -73,7 +73,7 @@ std::map<std::string, ir::Tensor> InitialAssignBuffer(
// unify all the tensor occurance with a global one, e.g. there are multiple
// tensor B exists in the expression, replace them with a shared one.
ir::CollectIRNodes(*expr, [&](const Expr* x) -> bool {
ir::ir_utils::CollectIRNodes(*expr, [&](const Expr* x) -> bool {
auto* t = x->as_tensor();
if (t && !stages[t]->inlined()) {
Reference(x) = Expr(all_tensor_map.at(t->name));
......
......@@ -19,7 +19,7 @@
#include <vector>
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/runtime/intrinsic.h"
namespace cinn {
......
......@@ -16,7 +16,7 @@
#include <glog/logging.h>
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/ir_mutator.h"
namespace cinn::optim {
......
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/optim/cast_simplify.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
namespace cinn::optim {
using cinn::common::bfloat16;
using cinn::common::float16;
namespace {
// Cast `value` to CastType, normalizing special floating-point values:
// inf/NaN map to the target type's infinity/signaling NaN, and out-of-range
// values saturate at the target type's max/lowest.
template <typename CastType, typename T>
CastType NormCastValue(T value) {
  using Limits = std::numeric_limits<CastType>;
  // not support uint
  if (type_of<CastType>().is_uint() || type_of<T>().is_uint()) {
    return static_cast<CastType>(value);
  }
  if (std::isinf(value)) return Limits::infinity();
  if (std::isnan(value)) return Limits::signaling_NaN();
  if (value >= static_cast<T>(Limits::max())) return Limits::max();
  if (value <= static_cast<T>(Limits::lowest())) return Limits::lowest();
  return static_cast<CastType>(value);
}
// Folds constant Cast nodes and removes no-op casts (cast to the same type).
// Fix: the original dispatch chain repeated the uint32_t and uint64_t
// branches after the float/double/bool branches; the duplicates were
// unreachable dead code and have been removed.
struct Mutator : ir::IRMutator<> {
  using ir::IRMutator<>::Visit;

  void Visit(const ir::Cast* op, Expr* expr) {
    auto* node = expr->As<ir::Cast>();
    // Simplify the operand first so nested casts collapse bottom-up.
    Visit(&node->v(), &node->v());

    // Pattern 1: source and target types match -- drop the Cast node.
    if (op->type() == op->v().type()) {
      *expr = op->v();
      return;
    }

    // Pattern 2: cast of a constant immediate -- replace the whole Cast node
    // with an immediate of the target type. Float sources go through
    // NormCastValue to normalize inf/NaN and saturate out-of-range values.
#define __CAST_TO_TYPE(type__)                                          \
  if (auto* i = op->v().As<ir::IntImm>()) {                             \
    *expr = Expr(static_cast<type__>(i->value));                        \
  } else if (auto* f = op->v().As<ir::FloatImm>()) {                    \
    *expr = Expr(static_cast<type__>(NormCastValue<type__>(f->value))); \
  } else if (auto* u = op->v().As<ir::UIntImm>()) {                     \
    *expr = Expr(static_cast<type__>(u->value));                        \
  } else {                                                              \
    CINN_NOT_IMPLEMENTED                                                \
  }

    if (op->v().is_constant()) {
      if (op->type() == type_of<int8_t>()) {
        __CAST_TO_TYPE(int8_t)
      } else if (op->type() == type_of<int16_t>()) {
        __CAST_TO_TYPE(int16_t)
      } else if (op->type() == type_of<int32_t>()) {
        __CAST_TO_TYPE(int32_t)
      } else if (op->type() == type_of<int64_t>()) {
        __CAST_TO_TYPE(int64_t)
      } else if (op->type() == type_of<uint8_t>()) {
        __CAST_TO_TYPE(uint8_t)
      } else if (op->type() == type_of<uint16_t>()) {
        __CAST_TO_TYPE(uint16_t)
      } else if (op->type() == type_of<uint32_t>()) {
        __CAST_TO_TYPE(uint32_t)
      } else if (op->type() == type_of<uint64_t>()) {
        __CAST_TO_TYPE(uint64_t)
      } else if (op->type() == type_of<float>()) {
        __CAST_TO_TYPE(float)
      } else if (op->type() == type_of<double>()) {
        __CAST_TO_TYPE(double)
      } else if (op->type() == type_of<bool>()) {
        __CAST_TO_TYPE(bool)
      } else if (op->type() == type_of<bfloat16>()) {
        // Cannot simplify!!! pass
        __CAST_TO_TYPE(bfloat16)
      } else if (op->type() == type_of<float16>()) {
        // Cannot simplify!!! pass
        __CAST_TO_TYPE(float16)
      } else {
        CINN_NOT_IMPLEMENTED
      }
    }
#undef __CAST_TO_TYPE
  }
};
} // namespace
//! Entry point: run the cast-simplifying mutator over expression \p e.
void CastSimplify(Expr* e) {
  Mutator m;
  m.Visit(e, e);
}
} // namespace cinn::optim
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/ir/ir.h"
namespace cinn::optim {
/**
* Simplify the Cast nodes.
*
* There are several patterns:
* 1. the source and target type are the same, drop the Cast node
* 2. for intermediate numbers, just replace the Cast node with a Node of the
* target type
*/
void CastSimplify(Expr* e);
} // namespace cinn::optim
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment