Commit 992bec46 authored by “yuguo”'s avatar “yuguo”
Browse files

2.5

parent 0259837d
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef CINN_WITH_CUDA
#if defined(__linux__)
#include <sys/stat.h>
#endif
#include <glog/logging.h>
#include <string>
#include <vector>
namespace cinn {
namespace backends {
namespace nvrtc {
/**
* A helper class that calls NVRTC: it takes CUDA device source code and returns
* the compiled PTX string.
*/
class Compiler {
public:
// Default constructor; its definition is not part of this header.
Compiler();
/**
* Compile the \p code and get PTX string.
* @param code The CUDA source code.
* @param include_headers Whether to include the headers of CUDA and CINN
* runtime modules. Defaults to true.
* @return Compiled PTX code string.
*/
std::string operator()(const std::string& code, bool include_headers = true);
/**
* Query whether the compiler produces CUBIN instead of PTX.
* @return Compile into cubin or not.
* NOTE(review): presumably returns compile_to_cubin_ -- confirm in the
* implementation file.
*/
bool compile_to_cubin();
private:
/**
* Get the directories of CUDA's header files.
* @return list of header file directories.
*/
std::vector<std::string> FindCUDAIncludePaths();
/**
* Get the directories of CINN runtime's header files.
* @return list of header file directories.
*/
std::vector<std::string> FindCINNRuntimeIncludePaths();
/**
* Compile CUDA source code and get PTX or CUBIN.
* @param code source code string.
* @param include_headers Whether to include the headers of CUDA and CINN
* runtime modules.
* @return PTX or CUBIN string.
*/
std::string CompileCudaSource(const std::string& code, bool include_headers);
/**
* whether to compile the source code into cubin, only works with cuda version
* > 11.1
*/
bool compile_to_cubin_{false};
// compile with nvcc
// NOTE(review): the unnamed parameter is presumably the CUDA source code and
// the result the compiled PTX/CUBIN -- confirm in the implementation file.
std::string CompileWithNvcc(const std::string&);
// compile to ptx
void CompileToPtx();
// compile to cubin
void CompileToCubin();
// NOTE(review): presumably returns the target GPU architecture as a string --
// confirm in the implementation file.
std::string GetDeviceArch();
// Read a file using the given open mode.
// NOTE(review): presumably the first argument is the path and the return
// value the file contents -- confirm. Also note that std::ios_base is used
// here while this header only includes <string>, <vector> and glog directly.
std::string ReadFile(const std::string&, std::ios_base::openmode);
// NOTE(review): presumably the base name of the intermediate files used by
// the nvcc code path -- confirm in the implementation file.
std::string prefix_name_{""};
};
} // namespace nvrtc
} // namespace backends
} // namespace cinn
#endif // CINN_WITH_CUDA
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/backends/nvrtc/nvrtc_util.h"
#include <gtest/gtest.h>
namespace cinn {
namespace backends {
namespace nvrtc {
// Smoke test: compile a plain saxpy kernel and log the compiler output.
TEST(Compiler, basic) {
  const std::string saxpy_kernel = R"ROC(
extern "C" __global__
void saxpy(float a, float *x, float *y, float *out, size_t n)
{
size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < n) {
out[tid] = a * x[tid] + y[tid];
}
}
)ROC";

  Compiler nvrtc_compiler;
  const auto compiled = nvrtc_compiler(saxpy_kernel);
  LOG(INFO) << "ptx:\n" << compiled;
}
// Smoke test: compile a kernel that casts float to cinn::common::float16
// (pulls in "float16.h") and log the compiler output.
TEST(Compiler, float16) {
  const std::string cast_kernel = R"(
#include <cstdint>
#define CINN_WITH_CUDA
#include "float16.h"
using cinn::common::float16;
extern "C" __global__
void cast_fp32_to_fp16_cuda_kernel(const float* input, const int num, float16* out) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < num) {
out[idx] = float16(input[idx]);
}
}
)";

  Compiler nvrtc_compiler;
  const auto compiled = nvrtc_compiler(cast_kernel);
  LOG(INFO) << "ptx:\n" << compiled;
}
// Smoke test: compile a kernel that casts float to cinn::common::bfloat16
// (pulls in "bfloat16.h") and log the compiler output.
TEST(Compiler, bfloat16) {
  const std::string cast_kernel = R"(
#include <cstdint>
#define CINN_WITH_CUDA
#include "bfloat16.h"
using cinn::common::bfloat16;
extern "C" __global__
void cast_fp32_to_bf16_cuda_kernel(const float* input, const int num, bfloat16* out) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < num) {
out[idx] = bfloat16(input[idx]);
}
}
)";

  Compiler nvrtc_compiler;
  const auto compiled = nvrtc_compiler(cast_kernel);
  LOG(INFO) << "ptx:\n" << compiled;
}
} // namespace nvrtc
} // namespace backends
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/backends/outputs.h"
namespace cinn {
namespace lang {} // namespace lang
// Return a copy of this Outputs whose object file name is \p name; the
// receiver itself is left untouched.
backends::Outputs backends::Outputs::object(const std::string &name) const {
  Outputs copy(*this);
  copy.object_name = name;
  return copy;
}
// Return a copy of this Outputs whose bitcode file name is \p name; the
// receiver itself is left untouched.
backends::Outputs backends::Outputs::bitcode(const std::string &name) const {
  Outputs copy(*this);
  copy.bitcode_name = name;
  return copy;
}
// Return a copy of this Outputs whose C header file name is \p name; the
// receiver itself is left untouched.
backends::Outputs backends::Outputs::c_header(const std::string &name) const {
  Outputs copy(*this);
  copy.c_header_name = name;
  return copy;
}
// Return a copy of this Outputs whose C source file name is \p name; the
// receiver itself is left untouched.
backends::Outputs backends::Outputs::c_source(const std::string &name) const {
  Outputs copy(*this);
  copy.c_source_name = name;
  return copy;
}
// Return a copy of this Outputs whose CUDA source file name is \p name; the
// receiver itself is left untouched.
backends::Outputs backends::Outputs::cuda_source(
    const std::string &name) const {
  Outputs copy(*this);
  copy.cuda_source_name = name;
  return copy;
}
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
namespace cinn {
namespace backends {
/**
* A struct specifying a collection of outputs.
*/
struct Outputs {
//! The name of the emitted object file. Empty if no object file is desired.
std::string object_name;
//! The name of the emitted llvm bitcode. Empty if no bitcode file is desired.
std::string bitcode_name;
//! The name of the emitted C header file.
std::string c_header_name;
//! The name of the emitted C source file.
std::string c_source_name;
//! The name of the emitted CUDA source file.
std::string cuda_source_name;
// The methods below are const and return an Outputs by value, one per field
// above; each takes the file name to use for that output.
//! Outputs with the object file name given by \p name.
Outputs object(const std::string& name) const;
//! Outputs with the llvm bitcode file name given by \p name.
Outputs bitcode(const std::string& name) const;
//! Outputs with the C header file name given by \p name.
Outputs c_header(const std::string& name) const;
//! Outputs with the C source file name given by \p name.
Outputs c_source(const std::string& name) const;
//! Outputs with the CUDA source file name given by \p name.
Outputs cuda_source(const std::string& name) const;
};
} // namespace backends
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/cinn/backends/cuda_util.h"
#include "paddle/cinn/utils/timer.h"
// Element-wise C = A + B. Each thread handles the element at
// 1024 * blockIdx.x + threadIdx.x; threads whose block or thread index is
// 1024 or larger do nothing.
__global__ void elementwise_add_kernel(const float* __restrict__ A,
                                       const float* __restrict__ B,
                                       float* __restrict__ C) {
  if (blockIdx.x >= 1024 || threadIdx.x >= 1024) {
    return;
  }
  const unsigned int offset = (1024 * blockIdx.x) + threadIdx.x;
  C[offset] = A[offset] + B[offset];
}
// Measures the average latency of elementwise_add_kernel over 1000 launches
// on a 1024 x 1024 float problem and logs it.
//
// The input buffers are intentionally left uninitialized: only the launch
// latency is of interest, not the computed values.
TEST(raw_cuda, basic) {
  const int M = 1024;
  const int N = 1024;
  // allocate CUDA buffer
  float* Ag = nullptr;
  float* Bg = nullptr;
  float* Cg = nullptr;
  const size_t num_bytes = static_cast<size_t>(M) * N * sizeof(float);
  // Fail loudly if an allocation fails instead of launching on a bad pointer.
  CUDA_CALL(cudaMalloc(&Ag, num_bytes));
  CUDA_CALL(cudaMalloc(&Bg, num_bytes));
  CUDA_CALL(cudaMalloc(&Cg, num_bytes));

  cinn::utils::Timer timer;
  timer.Start();
  for (int i = 0; i < 1000; i++) {
    elementwise_add_kernel<<<1024, 1024>>>(Ag, Bg, Cg);
  }
  // Kernel launches do not return a status; query it explicitly.
  CUDA_CALL(cudaGetLastError());
  CUDA_CALL(cudaDeviceSynchronize());
  float latency = timer.Stop();
  LOG(INFO) << "latency: " << latency / 1000;

  // Release the device buffers so the test does not leak GPU memory.
  CUDA_CALL(cudaFree(Ag));
  CUDA_CALL(cudaFree(Bg));
  CUDA_CALL(cudaFree(Cg));
}
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/**
* This file exposes some internal APIs to global cinn namespace to make usage
* more friendly.
*/
#pragma once
#include "paddle/cinn/backends/codegen_c.h"
#include "paddle/cinn/backends/codegen_c_x86.h"
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/lang/builtin.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/lang/placeholder.h"
#include "paddle/cinn/optim/optimize.h"
// Re-export selected names from the backends, ir, lang, optim, poly and common
// sub-namespaces so that they can be referred to directly as cinn::<name>.
namespace cinn {
// Code generators and output description.
using backends::CodeGenC;
using backends::CodeGenCX86;
using backends::Outputs;
// IR types.
using ir::Module;
using ir::Var;
// Front-end (lang) helpers.
using lang::Buffer;
using lang::CallExtern;
using lang::CallLowered;
using lang::Compute;
using lang::Lower;
using lang::Placeholder;
using lang::ReduceAll;
using lang::ReduceAny;
using lang::ReduceMax;
using lang::ReduceMin;
using lang::ReduceMul;
using lang::ReduceSum;
using optim::Optimize;
using poly::CreateStages;
using lang::logic_and;
using lang::logic_or;
using common::Target;
} // namespace cinn
# Build description for the CINN `common` directory.
# NOTE(review): core_gather_headers, gather_srcs, cinn_cc_test and cinn_nv_test
# are project-defined helpers declared elsewhere; the comments below describe
# only what the call sites show.
core_gather_headers()
# Sources of this directory passed to gather_srcs under the name cinnapi_src.
gather_srcs(
cinnapi_src
SRCS
shared.cc
cinn_value.cc
type.cc
target.cc
object.cc
debug_manager.cc
info_registry.cc
graph_utils.cc
context.cc
axis.cc
ir_util.cc
test_helper.cc
# cuda_test_helper.cc
arithmatic.cc
cas.cc
union_find.cc
python_interpreter_guard.cc)
# Print the current value of cinnapi_src at configure time.
message(STATUS "srcs: ${cinnapi_src}")
# Unit tests that depend only on gtest and glog.
cinn_cc_test(test_dfs_walker SRCS dfs_walker_test.cc DEPS gtest glog)
cinn_cc_test(test_dfs_topo_walker SRCS dfs_topo_walker_test.cc DEPS gtest glog)
cinn_cc_test(test_is_reachable_predicator SRCS is_reachable_predicator_test.cc
DEPS gtest glog)
cinn_cc_test(test_topo_walker SRCS topo_walker_test.cc DEPS gtest glog)
# Unit tests that depend on cinncore.
cinn_cc_test(test_cinn_value SRCS cinn_value_test.cc DEPS cinncore)
cinn_cc_test(test_shared SRCS shared_test.cc DEPS cinncore)
cinn_cc_test(test_graph_utils SRCS graph_utils_test.cc DEPS cinncore)
cinn_cc_test(test_arithmatic SRCS arithmatic_test.cc DEPS cinncore)
cinn_cc_test(test_cas SRCS cas_test.cc DEPS cinncore)
cinn_cc_test(test_type SRCS type_test.cc DEPS cinncore)
cinn_cc_test(test_axis SRCS axis_test.cc DEPS cinncore)
cinn_cc_test(test_fp16_bf16_host SRCS float16_bfloat16_host_test.cc DEPS gtest
glog)
# The CUDA variant of the fp16/bf16 test is only registered when WITH_CUDA is
# set.
if(WITH_CUDA)
cinn_nv_test(test_fp16_bf16_cuda SRCS float16_bfloat16_cuda_test.cu DEPS
gtest glog)
endif()
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/common/arithmatic.h"
#include <map>
#include <mutex>
#include <numeric>
#include <set>
#include <string>
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
#include "paddle/cinn/utils/string.h"
namespace cinn {
namespace common {
using utils::GetStreamCnt;
using utils::Join;
using utils::Replace;
using utils::Split;
using namespace ir; // NOLINT
#ifdef As
#undef As
#endif
/**
 * Build the textual key that identifies \p expr when it is treated as an
 * opaque symbol.
 *
 * - Load / Broadcast / Mod / Min / Max / Div / FracOp nodes: the printed form
 *   of the expression with brackets, operators and spaces replaced by
 *   identifier-friendly tokens.
 * - _Var_ nodes: the printed form of the variable, unchanged.
 * - Anything else: the empty string.
 */
std::string ExprToGinacConverter::Repr(const ir::Expr& expr) {
  const bool needs_mangling = expr.As<Load>() || expr.As<Broadcast>() ||
                              expr.As<Mod>() || expr.As<Min>() ||
                              expr.As<Max>() || expr.As<Div>() ||
                              expr.As<FracOp>();
  if (!needs_mangling) {
    if (expr.As<_Var_>()) {
      return utils::GetStreamCnt(expr);
    }
    return "";
  }

  // Applied in this exact order, one pass per rule.
  static const char* const kSubstitutions[][2] = {
      {"[", "lsq_"},
      {"]", "_rsq"},
      {"(", "lb_"},
      {")", "_rb"},
      {"+", "_add_"},
      {"-", "_sub_"},
      {":", "_ref_"},
      {"*", "_mul_"},
      {"/", "_div_"},
  };
  std::string mangled = GetStreamCnt(expr);
  for (const auto& rule : kSubstitutions) {
    Replace(&mangled, rule[0], rule[1]);
  }
  // remove the spaces
  return utils::Join(utils::Split(mangled, " "), "_");
}
// Remember \p expr under its Repr() key so that a GiNaC symbol with that name
// can later be mapped back to the original expression.
void ExprToGinacConverter::RecordExpr(const ir::Expr& expr) {
  const std::string key = Repr(expr);
  repr_to_expr_[key] = expr;
}
/**
 * Recursively translate \p expr into a GiNaC expression.
 *
 * Nodes that are not translated structurally (Load, _Var_, Broadcast, Mod,
 * Min, Max, and additionally Div / FracOp when the expression type is an
 * integer) are recorded via RecordExpr() and replaced by a GiNaC symbol named
 * after Repr(expr). Immediates become numbers, and Add / Sub / Mul / Div /
 * FracOp / Minus are translated by recursing into their operands. Any other
 * node type hits CINN_NOT_IMPLEMENTED.
 */
GiNaC::ex ExprToGinacConverter::BuildHelper(ir::Expr expr) {
// Probe the node type once; at most one of these is expected to be non-null.
auto* load_n = expr.As<Load>();
auto* var_n = expr.As<_Var_>();
auto* int_n = expr.As<IntImm>();
auto* float_n = expr.As<FloatImm>();
auto* add_n = expr.As<Add>();
auto* sub_n = expr.As<Sub>();
auto* mul_n = expr.As<Mul>();
auto* div_n = expr.As<Div>();
auto* minus_n = expr.As<Minus>();
auto* broadcast_n = expr.As<Broadcast>();
auto* mod_n = expr.As<Mod>();
auto* frac_n = expr.As<FracOp>();
auto* min_n = expr.As<Min>();
auto* max_n = expr.As<Max>();
bool is_integer_math = expr.type().is_int();
// Nodes that are always turned into opaque symbols.
bool is_invalid_arith =
load_n || var_n || broadcast_n || mod_n || min_n || max_n;
if (is_integer_math)
is_invalid_arith = is_invalid_arith || div_n ||
frac_n; // GiNac can't deal with integer division.
// The opaque-symbol case must be tested first: it takes priority over the
// structural Div / FracOp branches below.
if (is_invalid_arith) {
RecordExpr(expr);
std::string repr = Repr(expr);
return CreateGinacSymbol(repr);
} else if (int_n) {
return int_n->value;
} else if (float_n) {
return float_n->value;
} else if (add_n) {
auto a = BuildHelper(add_n->a());
auto b = BuildHelper(add_n->b());
// NOTE(review): the multiplication by 1 looks redundant; its purpose is not
// evident from this file.
return (a + b) * 1;
} else if (sub_n) {
return (BuildHelper(sub_n->a()) - BuildHelper(sub_n->b()));
} else if (mul_n) {
return (BuildHelper(mul_n->a()) * BuildHelper(mul_n->b()));
} else if (div_n) {
// Only reached for non-integer expressions (see is_invalid_arith above).
return (BuildHelper(div_n->a()) / BuildHelper(div_n->b()));
} else if (frac_n) {
return (BuildHelper(frac_n->a()) / BuildHelper(frac_n->b()));
} else if (minus_n) {
return -BuildHelper(minus_n->v());
} else {
CINN_NOT_IMPLEMENTED
}
}
/**
 * Convert the CINN expression \p expr to a GiNaC expression.
 *
 * The expression is first scanned for node types listed below (blocks, loops,
 * comparisons, logical operators, calls, memory operations, ...); if any is
 * found the CHECK fails. Otherwise the conversion is delegated to
 * BuildHelper().
 */
GiNaC::ex ExprToGinacConverter::operator()(Expr expr) {
// TODO(Superjomn) Replace this with common::IsPureMath(
// Collect every sub-node whose type is on the rejection list.
auto complex_nodes = CollectIRNodes(expr, [](const Expr* n) {
return n->As<Block>() || //
n->As<PolyFor>() || //
n->As<EQ>() || //
n->As<NE>() || //
n->As<LT>() || //
n->As<LE>() || //
n->As<GT>() || //
n->As<GE>() || //
n->As<And>() || //
n->As<Or>() || //
n->As<Not>() || //
n->As<Let>() || //
n->As<Call>() || //
n->As<Select>() || //
n->As<Store>() || //
n->As<Alloc>() || //
n->As<Free>() || //
n->As<IfThenElse>();
});
CHECK(complex_nodes.empty()) << "Ginac converter can only deal with simple "
"math expression, but get some complex nodes"
<< expr;
return BuildHelper(expr);
}
// Return the GiNaC symbol registered under \p repr, creating and caching a new
// one on first use. \p repr must not be empty.
GiNaC::symbol ExprToGinacConverter::CreateGinacSymbol(const std::string& repr) {
  CHECK(!repr.empty());
  const auto cached = repr_to_ginac_.find(repr);
  if (cached == repr_to_ginac_.end()) {
    GiNaC::symbol fresh(repr);
    repr_to_ginac_[repr] = fresh;
    return fresh;
  }
  return cached->second;
}
// Overload for variables: \p var must be a _Var_ node; the symbol is looked up
// (or created) under the variable's Repr() key.
GiNaC::symbol ExprToGinacConverter::CreateGinacSymbol(const ir::Expr& var) {
  CHECK(var.As<_Var_>());
  const std::string key = Repr(var);
  return CreateGinacSymbol(key);
}
/**
 * GiNaC visitor that rebuilds a CINN expression from a GiNaC expression.
 *
 * Each visit() overload stores the CINN expression for the visited node in
 * `cur`; composite nodes visit their operands one by one and combine the
 * intermediate results.
 */
class GiNaCToExprVisitor : public GiNaC::symbol::visitor,
public GiNaC::numeric::visitor,
public GiNaC::add::visitor,
public GiNaC::mul::visitor,
public GiNaC::power::visitor,
public GiNaC::basic::visitor,
public GiNaC::visitor {
// Maps a symbol name back to the CINN expression it stands for. Held by
// reference, so the map must outlive the visitor.
std::map<std::string, ir::Expr>& repr_to_expr;
// Result of the most recent visit() call.
ir::Expr cur;
public:
explicit GiNaCToExprVisitor(
std::map<std::string, ir::Expr>& repr_to_expr) // NOLINT
: repr_to_expr(repr_to_expr) {}
// Convert \p ex and return the resulting CINN expression.
Expr operator()(GiNaC::ex ex) {
ex.accept(*this);
return cur;
}
// Symbols are looked up by name; an unknown name fails the CHECK.
void visit(const GiNaC::symbol& node) override {
auto it = repr_to_expr.find(node.get_name());
CHECK(it != repr_to_expr.end())
<< "node [" << node.get_name() << "] not found";
cur = it->second;
}
// Integers become int immediates, every other number a float immediate.
void visit(const GiNaC::numeric& node) override {
if (node.is_integer()) {
cur = Expr(static_cast<int>(node.to_int()));
} else {
cur = Expr(static_cast<float>(node.to_double()));
}
}
// Sum: fold the operands left to right with `+`.
void visit(const GiNaC::add& node) override {
node.op(0).accept(*this);
Expr res = cur;
for (int i = 1; i < node.nops(); i++) {
node.op(i).accept(*this);
res = res + cur;
}
cur = res;
}
// Power: only an integer exponent equal to -1 is accepted (enforced by the
// CHECKs); the result is emitted as 1 / base.
void visit(const GiNaC::power& node) override {
node.op(0).accept(*this);
Expr a = cur;
node.op(1).accept(*this);
auto* intv = cur.As<IntImm>();
CHECK(intv);
CHECK_EQ(intv->value, -1);
cur = Div::Make(Expr(1), a);
}
// Product: fold the operands left to right with `*`.
void visit(const GiNaC::mul& node) override {
node.op(0).accept(*this);
Expr res = cur;
for (int i = 1; i < node.nops(); i++) {
node.op(i).accept(*this);
res = res * cur;
}
cur = res;
}
// Any other GiNaC node type is unsupported.
void visit(const GiNaC::basic& basic) override { CINN_NOT_IMPLEMENTED }
};
// Translate the GiNaC expression \p ex back to a CINN expression, resolving
// symbols through the expressions recorded by this converter.
Expr ExprToGinacConverter::GinacToExpr(const GiNaC::ex& ex) {
  GiNaCToExprVisitor to_expr(repr_to_expr_);
  Expr converted = to_expr(ex);
  return converted;
}
// Tell whether \p expr consists only of variables, immediates and the
// arithmetic node types listed below. Any node of another type makes the
// result false.
bool IsPureMath(Expr expr) {
  const std::set<IrNodeTy> math_node_types = {
      IrNodeTy::_Var_,
      IrNodeTy::IntImm,
      IrNodeTy::Sum,
      IrNodeTy::Product,
      IrNodeTy::FracOp,
      IrNodeTy::FloatImm,
      IrNodeTy::Add,
      IrNodeTy::Sub,
      IrNodeTy::Div,
      IrNodeTy::Mul,
      IrNodeTy::Mod,
      IrNodeTy::Minus,
  };

  // Gather every node whose type is outside the allowed set.
  auto non_math_nodes = ir::CollectIRNodes(expr, [&](const Expr* n) {
    return math_node_types.count(n->node_type()) == 0;
  });
#ifdef CINN_DEBUG
  for (auto& node : non_math_nodes) {
    VLOG(3) << "Found " << node->node_type() << " " << Expr(node);
  }
#endif
  return non_math_nodes.empty();
}
// Tell whether \p expr depends on \p symbol: the expression is converted to
// GiNaC and differentiated with respect to the symbol; a non-zero derivative
// means the symbol is present. A symbol unknown to the converter yields false.
bool MathContainsSymbol(Expr expr, Var symbol) {
  ExprToGinacConverter converter;
  const auto ginac_expr = converter(expr);
  if (converter.HasSymbol(symbol->name)) {
    const auto derivative =
        ginac::diff(ginac_expr, converter.GetSymbol(symbol->name));
    return !derivative.is_zero();
  }
  return false;
}
// lhs >= rhs.
/**
 * Solve the linear equation \p lhs == \p rhs for \p var using GiNaC.
 *
 * @return a tuple of (the solved value of \p var, a flag that is true when
 * d(lhs - rhs)/d(var) compares greater than zero).
 *
 * The CHECKs require exactly one solution and a non-zero derivative of
 * (lhs - rhs) with respect to \p var.
 */
std::tuple<Expr, bool /*positive*/> Solve(Expr lhs, Expr rhs, Var var) {
// All calls are serialized through one process-wide mutex.
// NOTE(review): presumably because GiNaC is not safe to use concurrently --
// confirm.
static std::mutex ginac_mutex;
std::lock_guard<std::mutex> guard(ginac_mutex);
VLOG(4) << "Solve: " << lhs << "=" << rhs << " in " << var;
ExprToGinacConverter converter;
auto lhs_ex = converter(lhs);
auto rhs_ex = converter(rhs);
ginac::lst eqs{lhs_ex == rhs_ex};
VLOG(4) << "eqs: " << eqs;
// GetSymbol uses map::at, so this throws if `var` does not occur in either
// side of the equation.
const auto& symbol = converter.GetSymbol(var->name);
ginac::lst vars{symbol};
ginac::ex res = ginac::lsolve(eqs, vars);
// Expect a single `symbol == value` entry; op(1) is its right-hand side.
CHECK_EQ(res.nops(), 1);
auto item = res.op(0);
CHECK_EQ(item.nops(), 2);
Expr value = converter.GinacToExpr(item.op(1));
// tell the symbol
auto diff = lhs_ex - rhs_ex;
auto diff_res = ginac::diff(diff, symbol);
CHECK(!diff_res.is_zero());
return std::make_tuple(value, diff_res > 0);
}
// Tell whether \p expr evaluates to zero. Expressions that are not pure math
// (see IsPureMath) are reported as non-zero without being converted.
bool MathIsZero(Expr expr) {
  if (IsPureMath(expr)) {
    ExprToGinacConverter converter;
    const auto ginac_expr = converter(expr);
    return ginac_expr.is_zero();
  }
  return false;
}
} // namespace common
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/**
* This file includes some arithmetic utilities, such as simplifying/solving a
* math equation/CINN expression.
*/
#pragma once
#include "paddle/cinn/ir/ir.h"
#include <ginac/ginac.h>
#include <limits>
#include <map>
#include <set>
#include <string>
#include <tuple>
#ifdef As
#undef As
#endif
namespace cinn {
namespace common {
namespace ginac = GiNaC;
//! Tell whether the expression \p expr contains only simple math calculations,
//! like i*32+j is true, while Load(buf, i)+1 is not due to the Load Node is not
//! math related.
bool IsPureMath(Expr expr);
//! Tell whether the expression \p expr contains the symbol \p symbol, e.g.
//! i*32+32 contains `i`, it also contains `i+1`.
bool MathContainsSymbol(Expr expr, Var symbol);
//! Solve the equation \p lhs == \p rhs on symbol \p symbol.
std::tuple<Expr, bool /*positive*/> Solve(Expr lhs, Expr rhs, Var symbol);
//! Determine whether this expression \p expr calculates to be a zero.
bool MathIsZero(Expr expr);
int gcd(int a, int b);
/**
* Helper to convert cinn::Expr to GiNaC::expr for some symbolic math analysis.
*/
struct ExprToGinacConverter {
//! Convert CINN expression \p expr to GiNaC ex.
ginac::ex operator()(Expr expr);
//! Convert GiNaC ex back to CINN expression, should call operator() first.
Expr GinacToExpr(const GiNaC::ex& ex);
//! Tell whether a GiNaC symbol has been registered under \p name.
bool HasSymbol(const std::string& name) const {
return repr_to_ginac_.count(name);
}
//! Get the GiNaC symbol registered under \p name. Uses map::at, so an unknown
//! name throws std::out_of_range; check HasSymbol() first when unsure.
const ginac::symbol& GetSymbol(const std::string& name) const {
return repr_to_ginac_.at(name);
}
private:
//! Textual key used to name the symbol that stands for \p expr.
std::string Repr(const Expr& expr);
//! Get the symbol registered under \p repr, creating it if necessary.
ginac::symbol CreateGinacSymbol(const std::string& repr);
//! Same as above, keyed by the Repr() of the expression \p var.
ginac::symbol CreateGinacSymbol(const ir::Expr& var);
//! Recursive worker behind operator().
ginac::ex BuildHelper(ir::Expr expr);
//! Store \p expr in repr_to_expr_ under its Repr() key.
void RecordExpr(const ir::Expr& expr);
private:
//! Symbol name -> the CINN expression the symbol stands for.
std::map<std::string, ir::Expr> repr_to_expr_;
//! Symbol name -> the GiNaC symbol created for it.
std::map<std::string, ginac::symbol> repr_to_ginac_;
};
} // namespace common
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/common/arithmatic.h"
#include <ginac/ginac.h>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/utils/string.h"
namespace cinn {
namespace common {
using utils::GetStreamCnt;
using utils::Join;
using utils::Trim;
using namespace ir; // NOLINT
// Build a linear expression in x and y and log the form GiNaC prints for it.
TEST(GiNaC, simplify) {
  using namespace GiNaC;  // NOLINT
  symbol x("x"), y("y");
  ex combined = x * 0 + 1 + 2 + 3 - 100 + 30 * y - y * 21 + 0 * x;
  LOG(INFO) << "e: " << combined;
}
// Differentiate (x + 1) and (y + 1) with respect to x and log both results.
TEST(GiNaC, diff) {
  using namespace GiNaC;  // NOLINT
  symbol x("x");
  symbol y("y");
  ex x_plus_one = (x + 1);
  ex y_plus_one = (y + 1);
  x_plus_one = diff(x_plus_one, x);
  y_plus_one = diff(y_plus_one, x);
  LOG(INFO) << "e: " << eval(x_plus_one);
  LOG(INFO) << "e1: " << eval(y_plus_one);
}
// Solve 2*x + 3 == 19 for x with lsolve and log the solution together with the
// derivative of the left-hand side.
TEST(GiNaC, solve) {
  using namespace GiNaC;  // NOLINT
  symbol x("x");
  symbol y("y");
  lst equations{2 * x + 3 == 19};
  lst unknowns{x};
  LOG(INFO) << "solve: " << lsolve(equations, unknowns);
  LOG(INFO) << diff(2 * x + 3, x);
}
// i * 2 == 2 * 200 must solve to i == 200 with a positive coefficient.
TEST(Solve, basic) {
  Var i("i", Int(32));
  Expr lhs = Expr(i) * 2;
  Expr rhs = Expr(2) * Expr(200);

  auto solution = Solve(lhs, rhs, i);
  Expr res = std::get<0>(solution);
  bool is_positive = std::get<1>(solution);

  LOG(INFO) << "res: " << res;
  EXPECT_TRUE(is_positive);
  EXPECT_TRUE(res == Expr(200));
}
// i * 2 == 2 * 200 + 3 * i must solve to i == -400 with a non-positive
// coefficient.
TEST(Solve, basic1) {
  Var i("i", Int(32));
  Expr lhs = Expr(i) * 2;
  Expr rhs = Expr(2) * Expr(200) + 3 * Expr(i);

  auto solution = Solve(lhs, rhs, i);
  Expr res = std::get<0>(solution);
  bool is_positive = std::get<1>(solution);

  LOG(INFO) << "res " << res;
  EXPECT_TRUE(res == Expr(-400));
  EXPECT_FALSE(is_positive);
}
} // namespace common
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/common/axis.h"
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/poly/dim.h"
#include "paddle/cinn/poly/domain.h"
#include "paddle/cinn/poly/stage.h"
namespace cinn {
namespace common {
// Base axis names; the position in the list is the loop level
// (level 0 is "i", level 1 is "j", ..., level 21 is "v").
static const std::vector<std::string> kAxises = {
    "i", "j", "k", "a", "b", "c", "d", "e", "f", "g", "h",
    "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v"};

// Return the axis name for loop level \p level.
//
// Levels below kAxises.size() map directly to a base name. Higher levels wrap
// around and repeat the base name, e.g. level 22 gives "ii" and level 44 gives
// "iii".
std::string axis_name(int level) {
  const std::size_t base_count = kAxises.size();
  // The level is compared and divided in unsigned arithmetic, exactly as a
  // direct comparison against size() does.
  const std::size_t position = static_cast<std::size_t>(level);
  if (position >= base_count) {
    const int copies = 1 + (position / base_count);
    const std::string& unit = kAxises[position % base_count];
    std::string repeated;
    for (int count = 0; count < copies; ++count) {
      repeated.append(unit);
    }
    return repeated;
  }
  return kAxises[position];
}
std::vector<ir::Var> GenDefaultAxis(int naxis) {
std::vector<ir::Var> axis;
for (int i = 0; i < naxis; i++) {
axis.emplace_back(common::axis_name(i));
CHECK(axis.back()->type().valid());
}
return axis;
}
std::vector<ir::Expr> GenDefaultAxisAsExpr(int naxis) {
auto vars = GenDefaultAxis(naxis);
std::vector<Expr> res;
for (auto& v : vars) {
res.push_back(Expr(v));
}
return res;
}
// Set view of the base axis names, built once on first use.
static const std::set<std::string>& axis_set() {
  static const std::set<std::string> base_names(kAxises.cbegin(),
                                                kAxises.cend());
  return base_names;
}
// Tell whether \p x is a name that axis_name() can produce: either one of the
// base axis names or a base axis character repeated several times
// (e.g. "i", "ii", "jjj"). The empty string is not reserved.
bool IsAxisNameReserved(const std::string& x) {
  // axis should not be empty
  if (x.empty()) return false;

  const std::set<std::string>& base_names = axis_set();
  if (base_names.count(x) != 0) return true;

  // The first character must itself be a base axis name ...
  if (base_names.count(std::string(1, x[0])) == 0) return false;

  // ... and every following character must repeat it.
  return x.find_first_not_of(x[0], 1) == std::string::npos;
}
} // namespace common
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <glog/logging.h>
#include <functional>
#include <string>
#include <utility>
#include <vector>
namespace cinn {
namespace ir {
struct Var;
struct Expr;
} // namespace ir
} // namespace cinn
namespace cinn {
namespace common {
//! Get the predefined axis name for the given loop level.
std::string axis_name(int level);
//! Generate `naxis` axis using the global names (i,j,k...).
std::vector<ir::Var> GenDefaultAxis(int naxis);
std::vector<ir::Expr> GenDefaultAxisAsExpr(int naxis);
bool IsAxisNameReserved(const std::string& x);
} // namespace common
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/common/axis.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <string>
#include "paddle/cinn/utils/string.h"
namespace cinn {
namespace common {
// axis_name() maps low levels to single-letter names and wraps around by
// repeating the letter once the base names are exhausted.
TEST(AXISNAME, BASE) {
// Levels inside the base name list.
ASSERT_EQ(axis_name(0), std::string("i"));
ASSERT_EQ(axis_name(1), std::string("j"));
// Levels past the base name list: the first name again, repeated.
ASSERT_EQ(axis_name(22), std::string("ii"));
ASSERT_EQ(axis_name(44), std::string("iii"));
}
// IsAxisNameReserved() accepts base names and repetitions of one base
// character, and rejects everything else.
TEST(AXISNAME, CHECK_RESERVED) {
// Base names and repeated base characters are reserved.
ASSERT_TRUE(IsAxisNameReserved("i"));
ASSERT_TRUE(IsAxisNameReserved("j"));
ASSERT_TRUE(IsAxisNameReserved("ii"));
ASSERT_TRUE(IsAxisNameReserved("iiiiiiiiii"));
// Mixed characters are not reserved.
ASSERT_FALSE(IsAxisNameReserved("ijk"));
ASSERT_FALSE(IsAxisNameReserved("iiiiiiiiiij"));
// "x" is not one of the base axis names.
ASSERT_FALSE(IsAxisNameReserved("x"));
}
} // namespace common
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef CINN_COMMON_BFLOAT16_H
#define CINN_COMMON_BFLOAT16_H
#ifdef __cplusplus
#pragma once
#endif // __cplusplus
#include <stdint.h>
#include <cmath>
#include <cstring>
#ifdef CINN_WITH_CUDA
#include <cuda.h>
#if (defined(__CUDACC__) || defined(__CUDACC_RTC__)) && CUDA_VERSION >= 11000
#define CINN_CUDA_BF16
#include <cuda_bf16.h>
#endif // __CUDACC__
#endif // CINN_WITH_CUDA
#ifdef __cplusplus
#ifndef _WIN32
#define CINN_ALIGN(x) __attribute__((aligned(x)))
#else // _WIN32
#define CINN_ALIGN(x) __declspec(align(x))
#endif // _WIN32
#else // __cplusplus
#define CINN_ALIGN(x)
#endif // __cplusplus
// The `HOST` macro definition is not used here, it has a potential
// conflict with the enumeration `kHOST` representing the backend.
#ifndef __host__
#define __host__
#endif
#ifndef __device__
#define __device__
#endif
#ifdef __cplusplus
namespace cinn {
namespace common {
#endif // __cplusplus
// Use CINN_ALIGN(2) to ensure that each bfloat16 will be allocated
// and aligned at least on a 2-byte boundary, which leads to efficient
// memory access of the bfloat16 struct and also makes bfloat16 compatible
// with the CUDA 16-bit types.
struct CINN_ALIGN(2) bfloat16 {
uint16_t x;
#ifdef __cplusplus
// Constructors
bfloat16() = default;
bfloat16(const bfloat16& o) = default;
bfloat16& operator=(const bfloat16& o) = default;
bfloat16(bfloat16&& o) = default;
bfloat16& operator=(bfloat16&& o) = default;
~bfloat16() = default;
__host__ __device__ inline explicit bfloat16(float val) {
#if defined(CINN_CUDA_BF16)
__nv_bfloat16 tmp = __float2bfloat16(val);
x = *reinterpret_cast<uint16_t*>(&tmp);
#else
std::memcpy(&x, reinterpret_cast<char*>(&val) + 2, 2);
#endif
}
#if defined(CINN_CUDA_BF16)
__host__ __device__ inline explicit bfloat16(const __nv_bfloat16& val) {
x = *reinterpret_cast<const unsigned short*>(&val); // NOLINT
}
#endif
template <class T>
__host__ __device__ inline explicit bfloat16(const T& val)
: x(bfloat16(static_cast<float>(val)).x) {}
// Assignment operators
#if defined(CINN_CUDA_BF16)
__host__ __device__ inline bfloat16& operator=(const __nv_bfloat16& val) {
x = *reinterpret_cast<const unsigned short*>(&val); // NOLINT
return *this;
}
#endif
__host__ __device__ inline bfloat16& operator=(bool b) {
x = b ? 0x3f80 : 0;
return *this;
}
__host__ __device__ inline bfloat16& operator=(int8_t val) {
x = bfloat16(val).x;
return *this;
}
__host__ __device__ inline bfloat16& operator=(uint8_t val) {
x = bfloat16(val).x;
return *this;
}
__host__ __device__ inline bfloat16& operator=(int16_t val) {
x = bfloat16(val).x;
return *this;
}
__host__ __device__ inline bfloat16& operator=(uint16_t val) {
x = bfloat16(val).x;
return *this;
}
__host__ __device__ inline bfloat16& operator=(int32_t val) {
x = bfloat16(val).x;
return *this;
}
__host__ __device__ inline bfloat16& operator=(uint32_t val) {
x = bfloat16(val).x;
return *this;
}
__host__ __device__ inline bfloat16& operator=(int64_t val) {
x = bfloat16(val).x;
return *this;
}
__host__ __device__ inline bfloat16& operator=(uint64_t val) {
x = bfloat16(val).x;
return *this;
}
__host__ __device__ inline bfloat16& operator=(float val) {
x = bfloat16(val).x;
return *this;
}
__host__ __device__ inline bfloat16& operator=(double val) {
x = bfloat16(val).x;
return *this;
}
// Conversion opertors
__host__ __device__ inline operator float() const {
#ifdef CINN_CUDA_BF16
return __bfloat162float(*reinterpret_cast<const __nv_bfloat16*>(&x));
#else
float val = 0.f;
uint16_t temp = x;
std::memcpy(
reinterpret_cast<char*>(&val) + 2, reinterpret_cast<char*>(&temp), 2);
return val;
#endif
}
#ifdef CINN_CUDA_BF16
__host__ __device__ inline __nv_bfloat16 to_nv_bfloat16() const {
return *reinterpret_cast<const __nv_bfloat16*>(&x);
}
#endif
__host__ __device__ inline explicit operator bool() const {
return (x & 0x7fff) != 0;
}
__host__ __device__ inline explicit operator int8_t() const {
return static_cast<int8_t>(static_cast<float>(*this));
}
__host__ __device__ inline explicit operator uint8_t() const {
return static_cast<uint8_t>(static_cast<float>(*this));
}
__host__ __device__ inline explicit operator int16_t() const {
return static_cast<int16_t>(static_cast<float>(*this));
}
__host__ __device__ inline explicit operator uint16_t() const {
return static_cast<uint16_t>(static_cast<float>(*this));
}
__host__ __device__ inline explicit operator int32_t() const {
return static_cast<int32_t>(static_cast<float>(*this));
}
__host__ __device__ inline explicit operator uint32_t() const {
return static_cast<uint32_t>(static_cast<float>(*this));
}
__host__ __device__ inline explicit operator int64_t() const {
return static_cast<int64_t>(static_cast<float>(*this));
}
__host__ __device__ inline explicit operator uint64_t() const {
return static_cast<uint64_t>(static_cast<float>(*this));
}
__host__ __device__ inline operator double() const {
return static_cast<double>(static_cast<float>(*this));
}
#endif // __cplusplus
};
// Sum of two bfloat16 values. Uses the native __hadd intrinsic when compiling
// device code for __CUDA_ARCH__ >= 800; otherwise adds in float and converts
// the result back.
__host__ __device__ inline bfloat16 operator+(const bfloat16& a,
                                              const bfloat16& b) {
#if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  const __nv_bfloat16 lhs = a.to_nv_bfloat16();
  const __nv_bfloat16 rhs = b.to_nv_bfloat16();
  return bfloat16(__hadd(lhs, rhs));
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return bfloat16(lhs + rhs);
#endif
}
// Difference of two bfloat16 values. Uses the native __hsub intrinsic when
// compiling device code for __CUDA_ARCH__ >= 800; otherwise subtracts in
// float and converts the result back.
__host__ __device__ inline bfloat16 operator-(const bfloat16& a,
                                              const bfloat16& b) {
#if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  const __nv_bfloat16 lhs = a.to_nv_bfloat16();
  const __nv_bfloat16 rhs = b.to_nv_bfloat16();
  return bfloat16(__hsub(lhs, rhs));
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return bfloat16(lhs - rhs);
#endif
}
// Product of two bfloat16 values. Uses the native __hmul intrinsic when
// compiling device code for __CUDA_ARCH__ >= 800; otherwise multiplies in
// float and converts the result back.
__host__ __device__ inline bfloat16 operator*(const bfloat16& a,
                                              const bfloat16& b) {
#if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  const __nv_bfloat16 lhs = a.to_nv_bfloat16();
  const __nv_bfloat16 rhs = b.to_nv_bfloat16();
  return bfloat16(__hmul(lhs, rhs));
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return bfloat16(lhs * rhs);
#endif
}
// Quotient of two bfloat16 values. Uses the native __hdiv intrinsic when
// compiling device code for __CUDA_ARCH__ >= 800; otherwise divides in float
// and converts the result back.
__host__ __device__ inline bfloat16 operator/(const bfloat16& a,
                                              const bfloat16& b) {
#if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  const __nv_bfloat16 lhs = a.to_nv_bfloat16();
  const __nv_bfloat16 rhs = b.to_nv_bfloat16();
  return bfloat16(__hdiv(lhs, rhs));
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return bfloat16(lhs / rhs);
#endif
}
// Unary minus: returns a copy of `a` with the sign bit (0x8000) flipped.
__host__ __device__ inline bfloat16 operator-(const bfloat16& a) {
  bfloat16 negated;
  negated.x = a.x;
  negated.x ^= 0x8000;
  return negated;
}
// In-place addition: stores a + b into `a` and returns `a`.
__host__ __device__ inline bfloat16& operator+=(bfloat16& a,  // NOLINT
                                                const bfloat16& b) {
  const bfloat16 sum = a + b;
  a = sum;
  return a;
}
// In-place subtraction: stores a - b into `a` and returns `a`.
__host__ __device__ inline bfloat16& operator-=(bfloat16& a,  // NOLINT
                                                const bfloat16& b) {
  const bfloat16 difference = a - b;
  a = difference;
  return a;
}
// In-place multiplication: stores a * b into `a` and returns `a`.
__host__ __device__ inline bfloat16& operator*=(bfloat16& a,  // NOLINT
                                                const bfloat16& b) {
  const bfloat16 product = a * b;
  a = product;
  return a;
}
// In-place division: stores a / b into `a` and returns `a`.
__host__ __device__ inline bfloat16& operator/=(bfloat16& a,  // NOLINT
                                                const bfloat16& b) {
  const bfloat16 quotient = a / b;
  a = quotient;
  return a;
}
// Wraps a raw 16-bit pattern in a bfloat16 without any numeric conversion:
// the bits of `a` become the stored representation.
__host__ __device__ inline bfloat16 raw_uint16_to_bfloat16(uint16_t a) {
  bfloat16 wrapped;
  wrapped.x = a;
  return wrapped;
}
// Comparison operators
// Equality. Uses __heq when compiling device code for __CUDA_ARCH__ >= 800;
// otherwise compares the float values.
__host__ __device__ inline bool operator==(const bfloat16& a,
                                           const bfloat16& b) {
#if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  const __nv_bfloat16 lhs = a.to_nv_bfloat16();
  const __nv_bfloat16 rhs = b.to_nv_bfloat16();
  return __heq(lhs, rhs);
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs == rhs;
#endif
}
// Inequality. Uses __hne when compiling device code for __CUDA_ARCH__ >= 800;
// otherwise compares the float values.
__host__ __device__ inline bool operator!=(const bfloat16& a,
                                           const bfloat16& b) {
#if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  const __nv_bfloat16 lhs = a.to_nv_bfloat16();
  const __nv_bfloat16 rhs = b.to_nv_bfloat16();
  return __hne(lhs, rhs);
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs != rhs;
#endif
}
// Less-than. Uses __hlt when compiling device code for __CUDA_ARCH__ >= 800;
// otherwise compares the float values.
__host__ __device__ inline bool operator<(const bfloat16& a,
                                          const bfloat16& b) {
#if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  const __nv_bfloat16 lhs = a.to_nv_bfloat16();
  const __nv_bfloat16 rhs = b.to_nv_bfloat16();
  return __hlt(lhs, rhs);
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs < rhs;
#endif
}
// Less-than-or-equal. Uses __hle when compiling device code for
// __CUDA_ARCH__ >= 800; otherwise compares the float values.
__host__ __device__ inline bool operator<=(const bfloat16& a,
                                           const bfloat16& b) {
#if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  const __nv_bfloat16 lhs = a.to_nv_bfloat16();
  const __nv_bfloat16 rhs = b.to_nv_bfloat16();
  return __hle(lhs, rhs);
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs <= rhs;
#endif
}
// Greater-than. Uses __hgt when compiling device code for
// __CUDA_ARCH__ >= 800; otherwise compares the float values.
__host__ __device__ inline bool operator>(const bfloat16& a,
                                          const bfloat16& b) {
#if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  const __nv_bfloat16 lhs = a.to_nv_bfloat16();
  const __nv_bfloat16 rhs = b.to_nv_bfloat16();
  return __hgt(lhs, rhs);
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs > rhs;
#endif
}
// Greater-than-or-equal. Uses __hge when compiling device code for
// __CUDA_ARCH__ >= 800; otherwise compares the float values.
__host__ __device__ inline bool operator>=(const bfloat16& a,
                                           const bfloat16& b) {
#if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  const __nv_bfloat16 lhs = a.to_nv_bfloat16();
  const __nv_bfloat16 rhs = b.to_nv_bfloat16();
  return __hge(lhs, rhs);
#else
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs >= rhs;
#endif
}
// NaN test. The name is parenthesised so that a function-like `isnan` macro
// cannot expand here. On the host path a value is NaN when, ignoring the sign
// bit, its bit pattern is strictly above 0x7F80 (exponent all ones and a
// non-zero mantissa).
__host__ __device__ inline bool(isnan)(const bfloat16& a) {
#if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  return __hisnan(a.to_nv_bfloat16());
#else
  const uint16_t magnitude_bits = a.x & 0x7FFF;
  return magnitude_bits > 0x7F80;
#endif
}
// Infinity test. The name is parenthesised so that a function-like `isinf`
// macro cannot expand here.
//
// On the host path a value is infinite only when, ignoring the sign bit, its
// bit pattern is exactly 0x7F80 (exponent all ones, mantissa zero). The
// mantissa bits must take part in the comparison: checking the exponent bits
// alone would also classify every NaN as infinite.
__host__ __device__ inline bool(isinf)(const bfloat16& a) {
#if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  return __hisinf(a.to_nv_bfloat16());
#else
  return (a.x & 0x7FFF) == 0x7F80;
#endif
}
// Finite test: true when the value is neither NaN nor infinite. The names are
// parenthesised so that function-like macros of the same name cannot expand.
__host__ __device__ inline bool(isfinite)(const bfloat16& a) {
  const bool nan_or_inf = (isnan)(a) || (isinf)(a);
  return !nan_or_inf;
}
// Absolute value. Uses __habs when compiling device code for
// __CUDA_ARCH__ >= 800; otherwise applies std::abs to the float value and
// converts the result back.
__host__ __device__ inline bfloat16(abs)(const bfloat16& a) {
#if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  const __nv_bfloat16 magnitude = __habs(a.to_nv_bfloat16());
  return bfloat16(magnitude);
#else
  const float magnitude = std::abs(static_cast<float>(a));
  return bfloat16(magnitude);
#endif
}
#ifdef __cplusplus
} // namespace common
} // namespace cinn
#endif // __cplusplus
// for runtime calls
#if defined(__cplusplus) && defined(CINN_CUDA_BF16)
// Overload of __shfl_sync for cinn::common::bfloat16: forwards to the
// __nv_bfloat16 overload and wraps the result.
__device__ inline cinn::common::bfloat16 __shfl_sync(unsigned mask,
                                                     cinn::common::bfloat16 var,
                                                     int srcLane,
                                                     int width = warpSize) {
  const __nv_bfloat16 shuffled =
      __shfl_sync(mask, var.to_nv_bfloat16(), srcLane, width);
  return cinn::common::bfloat16(shuffled);
}
// Overload of __shfl_up_sync for cinn::common::bfloat16: forwards to the
// __nv_bfloat16 overload and wraps the result.
__device__ inline cinn::common::bfloat16 __shfl_up_sync(
    unsigned mask,
    cinn::common::bfloat16 var,
    unsigned int delta,
    int width = warpSize) {
  const __nv_bfloat16 shuffled =
      __shfl_up_sync(mask, var.to_nv_bfloat16(), delta, width);
  return cinn::common::bfloat16(shuffled);
}
// Overload of __shfl_down_sync for cinn::common::bfloat16: forwards to the
// __nv_bfloat16 overload and wraps the result.
__device__ inline cinn::common::bfloat16 __shfl_down_sync(
    unsigned mask,
    cinn::common::bfloat16 var,
    unsigned int delta,
    int width = warpSize) {
  const __nv_bfloat16 shuffled =
      __shfl_down_sync(mask, var.to_nv_bfloat16(), delta, width);
  return cinn::common::bfloat16(shuffled);
}
// Overload of __shfl_xor_sync for cinn::common::bfloat16: forwards to the
// __nv_bfloat16 overload and wraps the result.
__device__ inline cinn::common::bfloat16 __shfl_xor_sync(
    unsigned mask,
    cinn::common::bfloat16 var,
    int laneMask,
    int width = warpSize) {
  const __nv_bfloat16 shuffled =
      __shfl_xor_sync(mask, var.to_nv_bfloat16(), laneMask, width);
  return cinn::common::bfloat16(shuffled);
}
// Returns `a` when a > b, otherwise `b` (so `b` is returned on ties and when
// the comparison is false for any other reason).
__host__ __device__ inline cinn::common::bfloat16 max(
    const cinn::common::bfloat16& a, const cinn::common::bfloat16& b) {
  if (a > b) {
    return a;
  }
  return b;
}
// Returns `a` when a < b, otherwise `b` (so `b` is returned on ties and when
// the comparison is false for any other reason).
__host__ __device__ inline cinn::common::bfloat16 min(
    const cinn::common::bfloat16& a, const cinn::common::bfloat16& b) {
  if (a < b) {
    return a;
  }
  return b;
}
#endif // __cplusplus && CINN_CUDA_BF16
#endif // CINN_COMMON_BFLOAT16_H
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <array>
#include <functional>
#include <queue>
#include <unordered_set>
namespace cinn {
namespace common {
// Breadth-first search visitor.
//
// A BfsWalker is built from a `VisitNextNodes` callback that, given a node,
// reports that node's successors by invoking the handler it receives once per
// successor. Calling the walker with one start node (or a range of start
// nodes) invokes `NodeHandler` on every reachable node exactly once, in
// breadth-first order. NodeType must be hashable and equality-comparable
// because visited nodes are tracked in a std::unordered_set.
template <typename NodeType>
class BfsWalker final {
 public:
  BfsWalker(const BfsWalker&) = delete;
  BfsWalker(BfsWalker&&) = delete;

  // Callback applied to a single node.
  using NodeHandlerType = std::function<void(NodeType)>;
  // Callback that enumerates the successors of a node by calling the given
  // handler on each of them.
  using NodesVisitorType =
      std::function<void(NodeType, const NodeHandlerType&)>;

  // Kept implicit (non-explicit) so existing call sites that rely on the
  // conversion keep compiling.
  BfsWalker(const NodesVisitorType& VisitNextNodes)
      : VisitNextNodes_(VisitNextNodes) {}

  // Walks the graph starting from a single node.
  void operator()(NodeType node, const NodeHandlerType& NodeHandler) const {
    std::array<NodeType, 1> nodes{node};
    (*this)(nodes.begin(), nodes.end(), NodeHandler);
  }

  // Walks the graph starting from every node in [begin, end). Duplicate start
  // nodes are visited once.
  template <typename NodeIt>
  void operator()(NodeIt begin,
                  NodeIt end,
                  const NodeHandlerType& NodeHandler) const {
    std::queue<NodeType> node_queue;
    std::unordered_set<NodeType> queued_nodes;
    // Built once as a NodeHandlerType so that passing it to VisitNextNodes_
    // does not construct a fresh std::function for every visited node.
    const NodeHandlerType TryEnqueueNode = [&](NodeType node) {
      // insert() reports whether the node was new; one hash lookup instead of
      // count() followed by insert().
      if (queued_nodes.insert(node).second) {
        node_queue.push(node);
      }
    };
    for (NodeIt iter = begin; iter != end; ++iter) {
      TryEnqueueNode(*iter);
    }
    while (!node_queue.empty()) {
      NodeType node = node_queue.front();
      node_queue.pop();
      NodeHandler(node);
      VisitNextNodes_(node, TryEnqueueNode);
    }
  }

 private:
  NodesVisitorType VisitNextNodes_;
};
} // namespace common
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/common/cas.h"
#include <algorithm>
#include <cmath>
#include <string>
#include <utility>
#include "paddle/cinn/common/arithmatic.h"
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
#include "paddle/cinn/optim/cast_simplify.h"
#include "paddle/cinn/utils/string.h"
namespace cinn {
namespace common {
using namespace ir; // NOLINT
// Simplifies `u` with the CAS: converts the expression (and any symbolic
// interval bounds) to CAS form, runs CasSimplify, and converts the result
// back to CINN form.
//
// @param u The expression to simplify.
// @param var_intervals Known intervals keyed by variable name. An entry whose
//        expression bounds (e_l, e_r) are both defined is forwarded with
//        those bounds converted to CAS form; any other entry is forwarded
//        using its `l` / `r` members.
// @return The simplified expression.
Expr AutoSimplify(
    Expr u,
    const absl::flat_hash_map<std::string, CasInterval>& var_intervals) {
  VLOG(7) << "Begin AutoSimplify: " << u;
  u = detail::ConvertCinnToCAS(u);

  absl::flat_hash_map<std::string, CasInterval> cas_intervals;
  for (auto& entry : var_intervals) {
    const auto& interval = entry.second;
    const bool has_expr_bounds =
        interval.e_l.defined() && interval.e_r.defined();
    if (!has_expr_bounds) {
      cas_intervals.emplace(entry.first, CasInterval(interval.l, interval.r));
      continue;
    }
    Expr lower = detail::ConvertCinnToCAS(interval.e_l);
    Expr upper = detail::ConvertCinnToCAS(interval.e_r);
    cas_intervals.emplace(entry.first, CasInterval(lower, upper));
  }

  u = CasSimplify(u, cas_intervals);
  u = detail::ConvertCasToCinn(u);
  VLOG(7) << "End AutoSimplify " << u;
  return u;
}
// Greatest common divisor of two integers.
//
// If either argument is 0 the other argument is returned unchanged (sign
// included), since everything divides 0. Otherwise the result is the positive
// gcd of |a| and |b|.
//
// Implemented with the iterative Euclidean algorithm: O(log(min(|a|, |b|)))
// steps and no recursion, so large or very unbalanced operands cannot exhaust
// the stack. Magnitudes are computed in 64 bits so that negating INT_MIN is
// well-defined.
int gcd(int a, int b) {
  // Everything divides 0
  if (a == 0) return b;
  if (b == 0) return a;

  int64_t x = a < 0 ? -static_cast<int64_t>(a) : static_cast<int64_t>(a);
  int64_t y = b < 0 ? -static_cast<int64_t>(b) : static_cast<int64_t>(b);
  while (y != 0) {
    const int64_t remainder = x % y;
    x = y;
    y = remainder;
  }
  return static_cast<int>(x);
}
//////// All the following symbolic computation methods are implemented
/// with reference to the book <Computer Algebra and
/// Symbolic Computation - Joel S. Cohen>
// Returns a copy of `vs` without its first element.
// An empty input yields an empty result (advancing begin() past end() on an
// empty vector would be undefined behaviour).
template <typename T>
std::vector<T> EraseFront(const std::vector<T>& vs) {
  if (vs.empty()) return {};
  return std::vector<T>(vs.begin() + 1, vs.end());
}
// Returns a new vector holding the elements of `as` followed by the elements
// of `bs`; neither input is modified.
template <typename T>
std::vector<T> Concat(const std::vector<T>& as, const std::vector<T>& bs) {
  std::vector<T> joined;
  joined.reserve(as.size() + bs.size());
  joined.insert(joined.end(), as.begin(), as.end());
  joined.insert(joined.end(), bs.begin(), bs.end());
  return joined;
}
// Extracts the constant factors of a product.
//   3*x*2*y => 3*2
//   x       => 1
// A non-Product expression, or a Product without constant operands, yields
// the constant 1 of u's type. A single constant operand is returned as is;
// several are wrapped in a new Product.
Expr ProductGetConstantPart(Expr u) {
  auto* product = u.As<Product>();
  if (!product) return make_const(u->type(), 1);

  std::vector<Expr> constants;
  for (auto& operand : product->operands()) {
    if (operand.is_constant()) constants.push_back(operand);
  }

  switch (constants.size()) {
    case 0:
      return make_const(u->type(), 1);
    case 1:
      return constants.front();
    default:
      return Product::Make(constants);
  }
}
// Extracts the non-constant factors of a product.
//   3*x*2*y => x*y
//   x       => x
// A non-Product expression is returned unchanged. A Product made only of
// constants yields the constant 1 of u's type. A single non-constant operand
// is returned as is; several are wrapped in a new Product.
Expr ProductGetNonConstantPart(Expr u) {
  auto* product = u.As<Product>();
  if (!product) return u;

  std::vector<Expr> symbolic;
  for (auto& operand : product->operands()) {
    if (!operand.is_constant()) symbolic.push_back(operand);
  }

  switch (symbolic.size()) {
    case 0:
      return make_const(u->type(), 1);
    case 1:
      return symbolic.front();
    default:
      return Product::Make(symbolic);
  }
}
namespace detail {
// Whether a is divisible by b.
// @{
// True when `a` is an exact multiple of `b`. `b` must be non-zero (enforced
// with CHECK_NE).
bool IsDivisible(int64_t a, int64_t b) {
  CHECK_NE(b, 0);
  const int64_t remainder = a % b;
  return remainder == 0;
}
bool IsDivisible(const Sum* a, int b);

// True when the non-negative integer `a` is an exact multiple of at least one
// positive integer-immediate operand of product `b`. A negative `a` is never
// considered divisible.
bool IsDivisible(int a, const Product* b) {
  if (a < 0) return false;
  for (auto& item : b->operands()) {
    auto* imm = item.As<IntImm>();
    if (!imm) continue;
    if (imm->value <= 0) continue;
    if (IsDivisible(a, imm->value)) return true;
  }
  return false;
}
bool IsDivisible(const Product* a, int b) {
for (auto& item : a->operands()) {
if (item.As<IntImm>() && IsDivisible(item.As<IntImm>()->value, b)) {
return true;
}
if (item.As<Sum>() && IsDivisible(item.As<Sum>(), b)) return true;
}
return false;
}
// True when every operand of sum `a` is divisible by `b`. An operand counts
// as divisible only if it is an integer immediate that is a multiple of `b`
// or a Product divisible by `b`; any other kind of operand makes the whole
// sum non-divisible.
bool IsDivisible(const Sum* a, int b) {
  for (auto& item : a->operands()) {
    auto* imm = item.As<IntImm>();
    auto* prod = item.As<Product>();
    const bool term_divisible =
        (imm && IsDivisible(imm->value, b)) || (prod && IsDivisible(prod, b));
    if (!term_divisible) return false;
  }
  return true;
}
// Dispatches on the node kind of `a`: integer immediates, Sums and Products
// are forwarded to the matching overload; every other expression is reported
// as not divisible.
bool IsDivisible(Expr a, int b) {
  auto* imm = a.As<IntImm>();
  auto* sum = a.As<Sum>();
  auto* prod = a.As<Product>();
  if (imm) {
    return IsDivisible(imm->value, b);
  }
  if (sum) {
    return IsDivisible(sum, b);
  }
  if (prod) {
    return IsDivisible(prod, b);
  }
  return false;
}
// @}
//! Divide a by b, NOTE that a should be divisible by b.
// @{
Expr Divide(const Product* a, int b);

// Divides every term of sum `a` by `b` and returns the resulting Sum.
// Integer-immediate terms are divided directly, Product terms are delegated
// to Divide(const Product*, int); any other term kind is not implemented.
Expr Divide(const Sum* a, int b) {
  std::vector<Expr> terms;
  for (auto& item : a->operands()) {
    auto* imm = item.As<IntImm>();
    auto* prod = item.As<Product>();
    if (imm) {
      terms.push_back(make_const(item.type(), imm->value / b));
    } else if (prod) {
      terms.push_back(Divide(prod, b));
    } else {
      CINN_NOT_IMPLEMENTED
    }
  }
  return Sum::Make(terms);
}
// Divides product `a` by the integer `b`.
//
// Case divisible:     a = 8x, b = 4  =>  2x
//   The first integer-immediate operand that is a multiple of `b` is replaced
//   by its quotient (dropped entirely when the quotient is 1).
// Case not divisible: a = 2x, b = 8  =>  x/4
//   Every integer-immediate operand that divides the (running) value of `b`
//   is cancelled out of both sides; the remaining operands form the
//   numerator of a FracOp whose denominator is what is left of `b`.
Expr Divide(const Product* a, int b) {
  // Locate the first integer operand that `b` divides exactly.
  int hit = -1;
  int quotient = -1;
  for (int k = 0; k < a->operands().size(); k++) {
    auto* imm = a->operand(k).As<IntImm>();
    if (imm && imm->value % b == 0) {
      quotient = imm->value / b;
      hit = k;
      break;
    }
  }

  std::vector<Expr> factors;
  if (hit >= 0) {
    // NOTE that a should be divisible by b.
    if (quotient != 1) {
      factors.push_back(make_const(a->type(), quotient));
    }
    for (int k = 0; k < a->operands().size(); k++) {
      if (k != hit) factors.push_back(a->operand(k));
    }
    return Product::Make(factors);
  }

  for (int k = 0; k < a->operands().size(); k++) {
    auto* imm = a->operand(k).As<IntImm>();
    const bool cancels = imm && b % imm->value == 0;
    if (cancels) {
      b = b / imm->value;
    } else {
      factors.push_back(a->operand(k));
    }
  }
  return FracOp::Make(Product::Make(factors), Expr(b));
}
// @}
// Integer quotient of n / d using C++ integer division (truncation toward
// zero). `d` must be non-zero.
inline int Iquot(int n, int d) {
  const int quotient = n / d;
  return quotient;
}
// Integer remainder of n / d, defined as n - d * (n / d) with truncating
// division, so the result takes the sign of `n`. `d` must be non-zero.
inline int Irem(int n, int d) {
  const int truncated_quotient = n / d;
  const int whole_part = d * truncated_quotient;
  return n - whole_part;
}
// Reduces a fraction of two integer immediates.
//
// Non-FracOp input is returned unchanged. For a FracOp both numerator and
// denominator must be IntImm (checked). If the division is exact the integer
// quotient is returned as a constant of u's type; otherwise numerator and
// denominator are divided by their gcd and, when the denominator is not
// positive, both are negated first.
Expr CasSimplifyMutator::SimplifyRationalNumber(Expr u) {
  auto* frac_n = u.As<FracOp>();
  if (!frac_n) return u;

  Expr n = frac_n->a();
  Expr d = frac_n->b();
  auto* ni = n.As<IntImm>();
  auto* di = d.As<IntImm>();
  CHECK(ni && di);

  int nv = ni->value;
  int dv = di->value;
  if (Irem(nv, dv) == 0) {
    return Expr(make_const(u.type(), Iquot(nv, dv)));
  }

  int g = gcd(nv, dv);
  const bool positive_denominator = dv > 0;
  const int numerator = positive_denominator ? nv : -nv;
  const int denominator = positive_denominator ? dv : -dv;
  return FracOp::Make(make_const(Iquot(numerator, g)),
                      make_const(Iquot(denominator, g)));
}
// Unwraps single-operand Sum / Product nodes: as long as `u` is a Product or
// a Sum with exactly one operand, it is replaced by that operand. The first
// expression that is not such a wrapper is returned.
Expr SumOrProductGetSingleElementsRec(Expr u) {
  while (true) {
    auto* product = u.As<Product>();
    auto* sum = u.As<Sum>();
    const bool lone_product = product && product->operands().size() == 1;
    const bool lone_sum = sum && sum->operands().size() == 1;
    if (!lone_product && !lone_sum) {
      return u;
    }
    // Copy the child handle first so it stays alive while `u` is rebound.
    Expr inner = u->operands.front();
    u = inner;
  }
}
// Order, reference to Page 85.
// Strict ordering predicate used to put operands of sums/products into a
// canonical order: returns true when `a` should be placed before `b`.
// The rules below are tried in sequence and the first one that applies
// decides; if none applies the result is false.
bool ExprPosCmp::operator()(const Expr& a, const Expr& b) {
// O-1, 1 <| 2
VLOG(7) << "Begin ExprPosCmp, a: " << a << ", b: " << b;
// Two constants are ordered by their numeric value.
if (a.is_constant() && b.is_constant()) {
return a.get_constant() < b.get_constant();
}
// O-2, both are symbols, compare by the lexicographical order.
if (a.As<_Var_>() && b.As<_Var_>()) {
return a.As<_Var_>()->name < b.As<_Var_>()->name;
}
// O-3, if a and b are either both products or both sums, compare by each
// element similar to lexicographical order.
// NOTE(review): the second test checks for `Add`, not `Sum`, although the
// comment above talks about sums — confirm which node type is intended.
if ((a.As<Product>() && b.As<Product>()) || (a.As<Add>() && b.As<Add>())) {
auto& aoprs = a->operands;
auto& boprs = b->operands;
int m = std::min(aoprs.size(), boprs.size());
// Operands are compared pairwise starting from the LAST operand of each
// side; the first differing pair decides.
for (int i = 0; i < m; i++) {
// ugly compare representation in string.
auto& aopr = aoprs[aoprs.size() - 1 - i];
auto& bopr = boprs[boprs.size() - 1 - i];
if (aopr != bopr) return operator()(aopr, bopr);
}
// All compared operands are equal: the shorter operand list comes first.
return aoprs.size() < boprs.size();
}
// customized case, if both are mod
// Two Mod nodes are ordered by their second operand first, then by their
// first operand.
{
auto* am = a.As<Mod>();
auto* bm = b.As<Mod>();
if (am && bm) {
if (am->b() != bm->b()) {
return operator()(am->b(), bm->b());
}
return operator()(am->a(), bm->a());
}
}
// O-7, if a is an integer or fraction and v is any other type, 1 < x
// Numbers (IntImm / FloatImm / FracOp) sort before everything else.
if (a.As<IntImm>() || a.As<FloatImm>() || a.As<FracOp>()) {
if (!(b.As<IntImm>() || b.As<FloatImm>() || b.As<FracOp>())) return true;
}
if (b.As<IntImm>() || b.As<FloatImm>() || b.As<FracOp>()) {
if (!(a.As<IntImm>() || a.As<FloatImm>() || a.As<FracOp>())) return false;
}
// O-8, if a is a product, v is a sum, fractional, or symbol
// `b` is wrapped in a single-operand Product so the comparison is made
// against a Product on the right-hand side.
{
auto* ap = a.As<Product>();
if (ap && (b.As<Sum>() || b.As<Call>() || b.As<_Var_>() || b.As<Mod>())) {
return operator()(a, Product::Make({b}));
}
}
// A Mod on the left against a non-Mod on the right: `b` is rewritten as
// `b % (b + 1)` so that the Mod-vs-Mod rule above can be reused.
{
if (a.As<Mod>()) {
if (!b.As<Mod>()) {
// Todo: may be wrong especially for negative value
return operator()(a, Mod::Make(b, Sum::Make({b, Expr(1)})));
}
}
}
// O-10, if a is a sum, b is a function, or symbol
// Only the symbol case is handled here: the first operand of the sum is
// compared against `b`.
{
if (a.As<Sum>()) {
if (b.As<_Var_>()) {
return operator()(a.As<Sum>()->operand(0), {b});
}
}
}
return false;
}
std::vector<Expr> CasSimplifyMutator::MergeProduct(const std::vector<Expr>& p,
const std::vector<Expr>& q) {
return MergeExprs(p,
q,
std::bind(&CasSimplifyMutator::SimplifyBinaryProduct,
this,
std::placeholders::_1,
std::placeholders::_2));
}
// Simplifies the product of two expressions and returns the resulting operand
// list (one element when the pair was folded into a single expression, two
// when the operands are merely kept or reordered, possibly more when operand
// lists of Products are merged).
//
// SPRDREC-1 handles the case where neither side is a Product:
//   - two constants of the same kind are multiplied;
//   - a constant times cinn_min/cinn_max is distributed into the min/max
//     (swapping min and max when the constant is negative);
//   - float FracOps are combined into a single fraction;
//   - multiplication by 1 is dropped;
//   - a Sum operand is distributed over;
//   - otherwise the pair is put into ExprPosCmp order.
// SPRDREC-2 handles the case where at least one side is a Product by merging
// the operand lists with MergeProduct.
std::vector<Expr> CasSimplifyMutator::SimplifyBinaryProduct(Expr left,
                                                            Expr right) {
  // SPRDREC-1
  if (!left.As<Product>() && !right.As<Product>()) {
    auto a = left;
    auto b = right;
    auto* ai = a.As<IntImm>();
    auto* af = a.As<FloatImm>();
    auto* bi = b.As<IntImm>();
    auto* bf = b.As<FloatImm>();
    // case 1, both are constants
    // Both pointers are checked so that a mixed pair of constant kinds cannot
    // dereference a null pointer; such a pair falls through to the rules
    // below.
    if (a.is_constant() && b.is_constant()) {
      if (ai && bi) return {make_const(a.type(), ai->value * bi->value)};
      if (af && bf) return {make_const(a.type(), af->value * bf->value)};
    }
    if (a.As<Max>() || a.As<Min>() || b.As<Max>() || b.As<Min>()) {
      // cinn_min/cinn_max(a, b) * 2 = cinn_min/cinn_max(2*a, 2*b)
      // 2 * cinn_min/cinn_max(a, b) = cinn_min/cinn_max(2*a, 2*b)
      // cinn_min/cinn_max(a, b) * -2 = cinn_max/cinn_min(-2*b, -2*a)
      // -2 * cinn_min/cinn_max(a, b) = cinn_max/cinn_min(-2*b, -2*a)
      Expr const_oper;
      Expr cmp_oper;
      // Held as double and zero-initialised: it is only meaningful when
      // const_oper is defined, and a float factor such as 0.5 must keep its
      // sign and non-zero-ness (an int would truncate it to 0).
      double const_value = 0;
      if (ai) {
        const_oper = a;
        cmp_oper = b;
        const_value = ai->value;
      }
      if (af) {
        const_oper = a;
        cmp_oper = b;
        const_value = af->value;
      }
      if (bi) {
        const_oper = b;
        cmp_oper = a;
        const_value = bi->value;
      }
      if (bf) {
        const_oper = b;
        cmp_oper = a;
        const_value = bf->value;
      }
      // Only fold to zero when a constant factor was actually found.
      if (const_oper.defined() && const_value == 0) {
        return {make_const(a->type(), 0)};
      }
      if (cmp_oper.defined() && const_oper.defined()) {
        auto cmp_min = cmp_oper.As<Min>();
        auto cmp_max = cmp_oper.As<Max>();
        if (const_value > 0) {
          if (cmp_min) {
            return {CasSimplify(
                Min::Make(CasSimplify(Product::Make({cmp_min->a(), const_oper}),
                                      var_intervals),
                          CasSimplify(Product::Make({cmp_min->b(), const_oper}),
                                      var_intervals)),
                var_intervals)};
          }
          if (cmp_max) {
            return {CasSimplify(
                Max::Make(CasSimplify(Product::Make({cmp_max->a(), const_oper}),
                                      var_intervals),
                          CasSimplify(Product::Make({cmp_max->b(), const_oper}),
                                      var_intervals)),
                var_intervals)};
          }
        } else {
          // Negative factor: min and max swap roles.
          if (cmp_min) {
            return {CasSimplify(
                Max::Make(CasSimplify(Product::Make({cmp_min->b(), const_oper}),
                                      var_intervals),
                          CasSimplify(Product::Make({cmp_min->a(), const_oper}),
                                      var_intervals)),
                var_intervals)};
          }
          if (cmp_max) {
            return {CasSimplify(
                Min::Make(CasSimplify(Product::Make({cmp_max->b(), const_oper}),
                                      var_intervals),
                          CasSimplify(Product::Make({cmp_max->a(), const_oper}),
                                      var_intervals)),
                var_intervals)};
          }
        }
      }
    }
    {  // FracOp related constants.
      // NOTE the integer division is weird in C language, 1/2 = 0, that is
      // hugely different from a real CAS. Hence only float-typed fractions
      // are combined here.
      auto* a_frac = a.As<FracOp>();
      auto* b_frac = b.As<FracOp>();
      // 1/2 * 2/3
      if (a_frac && b_frac && a->type().is_float()) {
        return {CasSimplify(
            FracOp::Make(Product::Make({a_frac->a(), b_frac->a()}),
                         Product::Make({a_frac->b(), b_frac->b()})),
            var_intervals)};
      }
      if (a_frac && !b_frac && a->type().is_float()) {
        return {CasSimplify(
            FracOp::Make(Product::Make({a_frac->a(), b}), a_frac->b()),
            var_intervals)};
      }
      if (!a_frac && b_frac && a->type().is_float()) {
        return {CasSimplify(
            FracOp::Make(Product::Make({b_frac->a(), a}), b_frac->b()),
            var_intervals)};
      }
    }
    // case 2
    // 1*x -> x
    if (ai && ai->value == 1) return {b};
    if (af && af->value == 1.f) return {b};
    // x*1 -> x
    if (bi && bi->value == 1) return {a};
    if (bf && bf->value == 1.f) return {a};
    {
      // Distribute the other operand over a Sum: a*(v1+v2) = a*v1 + a*v2.
      auto* a_sum = a.As<Sum>();
      auto* b_sum = b.As<Sum>();
      if (b_sum) {
        std::vector<Expr> args;
        for (auto& v : b_sum->operands()) {
          args.push_back(CasSimplify(Product::Make({a, v}), var_intervals));
        }
        return {SimplifySum(Sum::Make(args))};
      }
      if (a_sum) {
        std::vector<Expr> args;
        for (auto& v : a_sum->operands()) {
          args.push_back(CasSimplify(Product::Make({b, v}), var_intervals));
        }
        return {SimplifySum(Sum::Make(args))};
      }
    }
    // case 4, b <| a
    {
      if (ExprPosCmp()(b, a)) {
        return {b, a};
      }
    }
    return {left, right};
  }
  // SPRDREC-2, Page 101
  if (left.As<Product>() || right.As<Product>()) {
    auto a = left;
    auto b = right;
    auto* a_product = a.As<Product>();
    auto* b_product = b.As<Product>();
    // case 1
    if (a_product && b_product) {
      return MergeProduct(a_product->operands(), b_product->operands());
    }
    // case 2
    if (a_product) {
      return MergeProduct(a_product->operands(), {b});
    }
    // case 3
    if (b_product) {
      return MergeProduct({a}, b_product->operands());
    }
  }
  return {left, right};
}
std::vector<Expr> CasSimplifyMutator::SimplifyProductRec(
const std::vector<Expr>& operands) {
if (operands.size() < 2)
return {CasSimplify(operands.front(), var_intervals)};
auto mid_it = operands.begin() + operands.size() / 2;
auto&& left = SimplifyProductRec(std::vector<Expr>(operands.begin(), mid_it));
auto&& right = SimplifyProductRec(std::vector<Expr>(mid_it, operands.end()));
return MergeProduct(left, right);
}
// Simplifies a Product expression.
//
// Single-operand Sum/Product wrappers are stripped first; anything that is
// not a Product afterwards is returned as is. Otherwise every operand is
// simplified, then:
//   SPRD-2: a zero operand (int or float) makes the whole product 0;
//   SPRD-3: a single operand is returned directly;
//   SPRD-4: the operands are merged with SimplifyProductRec and wrapped in a
//           new Product.
Expr CasSimplifyMutator::SimplifyProduct(Expr a) {
  a = SumOrProductGetSingleElementsRec(a);
  // We reuse the Mul node for production.
  auto* prod = a.As<Product>();
  if (!prod) return a;
  const auto& _operands = prod->operands();
  std::vector<Expr> operands;
  for (auto& e : _operands) operands.push_back(CasSimplify(e, var_intervals));
#ifdef CINN_DEBUG
  {
    std::stringstream ss;
    for (auto& v : operands) {
      ss << v << " ";
    }
    VLOG(7) << "operands: " << ss.str();
  }
#endif
  // SPRD-2
  // 0*x... = 0
  for (auto& opr : operands) {
    auto* opri = opr.As<IntImm>();
    auto* oprf = opr.As<FloatImm>();
    if (opri && opri->value == 0) return make_const(a.type(), 0);
    if (oprf && oprf->value == 0) return make_const(a.type(), 0);
  }
  // SPRD-3
  // prod(x) = x, single number.
  if (operands.size() == 1) {
    return operands[0];
  }
  // SPRD-4
  return Product::Make(SimplifyProductRec(operands));
}
// Simplifies a Sum expression.
// Single-operand Sum/Product wrappers are stripped first; the result must
// still be a Sum (checked). SimplifySpecificSum is tried next and, if it
// produced something that is no longer a Sum, that result is returned.
// Otherwise the operands are merged with SimplifySumRec: an empty result
// becomes the constant 0 of u's type, a single result is returned directly,
// and several results are wrapped in a new Sum.
Expr CasSimplifyMutator::SimplifySum(Expr u) {
u = SumOrProductGetSingleElementsRec(u);
auto* sum = u.As<Sum>();
CHECK(sum);
// NOTE(review): `operands` is bound by reference to what sum->operands()
// returns; if that is a mutable reference, the assignment further down
// overwrites the operand list stored in `sum` itself — confirm intended.
auto& operands = sum->operands();
auto temp = SimplifySpecificSum(u);
// If temp has been simplified, return it.
if (!temp.As<Sum>()) return temp;
operands = temp.As<Sum>()->operands();
auto args = SimplifySumRec(operands);
if (args.empty()) return make_const(u.type(), 0);
if (args.size() == 1) return args[0];
return Sum::Make(args);
}
std::vector<Expr> CasSimplifyMutator::MergeExprs(
const std::vector<Expr>& p,
const std::vector<Expr>& q,
const std::function<std::vector<Expr>(Expr, Expr)>& binary_merge) {
std::vector<Expr> res;
int li = 0, lj = 0;
while (li < p.size() && lj < q.size()) {
auto&& p1 = p[li];
auto&& q1 = q[lj];
auto&& h = binary_merge(p1, q1);
if (h.size() == 2 && h[0] == p1 && h[1] == q1) {
++li;
res.emplace_back(std::move(h.front()));
} else if (h.size() == 2 && h[0] == q1 && h[1] == p1) {
++lj;
res.emplace_back(std::move(h.front()));
} else {
++li;
++lj;
std::move(h.begin(), h.end(), std::back_inserter(res));
}
}
if (li < p.size()) res.insert(res.end(), p.begin() + li, p.end());
if (lj < q.size()) res.insert(res.end(), q.begin() + lj, q.end());
return std::move(res);
}
// Counterpart of MergeProduct for sums: merges two operand lists with
// MergeExprs, combining pairs through SimplifyBinarySum. A pair that folds to
// the single constant 0 contributes nothing to the merged list.
std::vector<Expr> CasSimplifyMutator::MergeSum(const std::vector<Expr>& p,
                                               const std::vector<Expr>& q) {
#ifdef CINN_DEBUG
  {
    std::stringstream ss;
    for (auto& x : p) ss << x << " ";
    VLOG(7) << "MergeSum p(" << ss.str() << ")";
    ss.str("");
    for (auto& x : q) ss << x << " ";
    VLOG(7) << "MergeSum q(" << ss.str() << ")";
    ss.str("");
  }
#endif
  auto combine_pair = [this](Expr left, Expr right) -> std::vector<Expr> {
    std::vector<Expr> merged =
        SimplifyBinarySum(std::move(left), std::move(right));
    const bool folded_to_zero = merged.size() == 1 &&
                                merged[0].is_constant() &&
                                merged[0].get_constant() == 0;
    if (folded_to_zero) {
      return {};
    }
    return merged;
  };
  return MergeExprs(p, q, combine_pair);
}
// Simplifies the sum of two expressions and returns the resulting operand
// list (one element when the pair was folded into a single expression, two
// when the operands are merely kept or reordered, possibly more or fewer
// when operand lists of Sums are merged).
//
// The first branch handles the case where neither side is a Sum; the second
// handles the case where at least one side is a Sum by merging operand lists
// with MergeSum.
std::vector<Expr> CasSimplifyMutator::SimplifyBinarySum(Expr left, Expr right) {
// SPRDREC-1
if (!left.As<Sum>() && !right.As<Sum>()) {
auto a = left;
auto b = right;
auto* ai = a.As<IntImm>();
auto* af = a.As<FloatImm>();
auto* bi = b.As<IntImm>();
auto* bf = b.As<FloatImm>();
// case 1, both are constants
// NOTE(review): bi / bf are dereferenced without a null check; this relies
// on both constants being of the same kind — confirm.
if (a.is_constant() && b.is_constant()) {
if (ai) return {make_const(a.type(), ai->value + bi->value)};
if (af) return {make_const(a.type(), af->value + bf->value)};
}
// cinn_min/cinn_max(a, b)+c = cinn_min/cinn_max(a+c, b+c)
// c + cinn_min/cinn_max(a, b) = cinn_min/cinn_max(a+c, b+c)
auto* a_min = a.As<Min>();
auto* a_max = a.As<Max>();
auto* b_min = b.As<Min>();
auto* b_max = b.As<Max>();
// The other operand is added to both arguments of the min/max, each sum is
// simplified, and the rebuilt min/max is simplified again.
if (a_min) {
return {CasSimplify(
Min::Make(CasSimplify(Sum::Make({a_min->a(), b}), var_intervals),
CasSimplify(Sum::Make({a_min->b(), b}), var_intervals)),
var_intervals)};
}
if (a_max) {
return {CasSimplify(
Max::Make(CasSimplify(Sum::Make({a_max->a(), b}), var_intervals),
CasSimplify(Sum::Make({a_max->b(), b}), var_intervals)),
var_intervals)};
}
if (b_min) {
return {CasSimplify(
Min::Make(CasSimplify(Sum::Make({b_min->a(), a}), var_intervals),
CasSimplify(Sum::Make({b_min->b(), a}), var_intervals)),
var_intervals)};
}
if (b_max) {
return {CasSimplify(
Max::Make(CasSimplify(Sum::Make({b_max->a(), a}), var_intervals),
CasSimplify(Sum::Make({b_max->b(), a}), var_intervals)),
var_intervals)};
}
// case 2
// 0 + x -> x
if (ai && ai->value == 0) return {b};
if (af && af->value == 0.f) return {b};
// x + 0 -> x
if (bi && bi->value == 0) return {a};
if (bf && bf->value == 0.f) return {a};
// customized case for Mod
// Two Mod nodes with the same second operand, whose first operands share the
// same non-constant part, are combined into a single Mod of the summed
// first operands.
{
auto* am = a.As<Mod>();
auto* bm = b.As<Mod>();
if (am && bm) {
if (am->b() == bm->b() && ProductGetNonConstantPart(am->a()) ==
ProductGetNonConstantPart(bm->a())) {
return {CasSimplify(Mod::Make(Sum::Make({am->a(), bm->a()}), am->b()),
var_intervals)};
}
}
}
// case 3
// Here is different from SimplifySumRec, to deal with cases like 3x + (-2x)
// = x
// Terms with the same non-constant part are combined by summing their
// constant parts and multiplying the result with the shared non-constant
// part.
auto a_non_constant = ProductGetNonConstantPart(a);
auto b_non_constant = ProductGetNonConstantPart(b);
if (a_non_constant.defined() && b_non_constant.defined() &&
a_non_constant == b_non_constant) {
VLOG(7) << "a " << a;
VLOG(7) << "b " << b;
Expr s = SimplifySum(
Sum::Make({ProductGetConstantPart(a), ProductGetConstantPart(b)}));
Expr p = Product::Make({s, ProductGetNonConstantPart(a)});
return {CasSimplify(p, var_intervals)};
}
// case 4, b <| a
{
if (ExprPosCmp()(b, a)) {
return {b, a};
}
}
return {left, right};
}
// SPRDREC-2, Page 101
if (left.As<Sum>() || right.As<Sum>()) {
auto a = left;
auto b = right;
auto* a_sum = a.As<Sum>();
auto* b_sum = b.As<Sum>();
// case 1
if (a_sum && b_sum) {
return MergeSum(a_sum->operands(), b_sum->operands());
}
// case 2
if (a_sum) {
return MergeSum(a_sum->operands(), {b});
}
// case 3
if (b_sum) {
return MergeSum({a}, b_sum->operands());
}
}
return {left, right};
}
// The implementation is similar to SimplifyProductRec
std::vector<Expr> CasSimplifyMutator::SimplifySumRec(
const std::vector<Expr>& operands) {
#ifdef CINN_DEBUG
{
std::stringstream ss;
for (auto& o : operands) {
ss << o.node_type() << " " << o << " ";
}
VLOG(7) << "SimplifySumRec operands: " << ss.str();
}
#endif
CHECK(!operands.empty());
if (operands.size() < 2)
return {CasSimplify(operands.front(), var_intervals)};
auto mid_it = operands.begin() + operands.size() / 2;
auto&& left = SimplifySumRec(std::vector<Expr>(operands.begin(), mid_it));
auto&& right = SimplifySumRec(std::vector<Expr>(mid_it, operands.end()));
return MergeSum(left, right);
}
// Accumulates `bound` into `*base` and simplifies the result in place: an
// undefined `*base` is replaced by `bound`, a defined one by `*base + bound`;
// either way the outcome is run through CasSimplify.
void CasSimplifyMutator::AddBaseAndSimplify(Expr* base, Expr bound) {
  if (!base->defined()) {
    *base = bound;
  } else {
    *base = Sum::Make({*base, bound});
  }
  Expr simplified = CasSimplify(*base, var_intervals);
  *base = simplified;
}
// Adds the bounds of variable `var` to `*lower_bound` / `*upper_bound`
// (through AddBaseAndSimplify).
//
// - If `var` has a recorded interval with both expression bounds defined,
//   those expressions are added.
// - Else, if it has a recorded interval and `unfold_const_bound` is set, the
//   interval's `l` / `r` members are added.
// - Else, when `unfold_const_bound` is not set, `var` itself is added to both
//   bounds.
// - A variable without a recorded interval while `unfold_const_bound` is set
//   is a fatal error.
//
// `lower_bound`, `upper_bound` must be non-null and `var` must be a _Var_
// (all checked).
void CasSimplifyMutator::UnfoldBound(Expr* lower_bound,
                                     Expr* upper_bound,
                                     Expr var,
                                     bool unfold_const_bound) {
  CHECK(lower_bound);
  CHECK(upper_bound);
  auto v_var = var.As<_Var_>();
  CHECK(v_var);

  auto found = var_intervals.find(v_var->name);
  if (found == var_intervals.end()) {
    if (!unfold_const_bound) {
      // not get var's bound for var simplification
      AddBaseAndSimplify(lower_bound, var);
      AddBaseAndSimplify(upper_bound, var);
    } else {
      LOG(FATAL) << "can't get the bound";
    }
    return;
  }

  auto& interval = found->second;
  const bool has_expr_bounds =
      interval.e_l.defined() && interval.e_r.defined();
  if (has_expr_bounds) {
    AddBaseAndSimplify(lower_bound, interval.e_l);
    AddBaseAndSimplify(upper_bound, interval.e_r);
  } else if (unfold_const_bound) {
    // unfold var's const bound
    AddBaseAndSimplify(lower_bound, Expr(interval.l));
    AddBaseAndSimplify(upper_bound, Expr(interval.r));
  } else {
    // no unfold var's const bound for var simplification
    AddBaseAndSimplify(lower_bound, var);
    AddBaseAndSimplify(upper_bound, var);
  }
}
/**
 * Add the bound of a "variable-like" operand onto \p lower_bound /
 * \p upper_bound.
 *
 * Supported shapes of \p var:
 *  - a plain variable (with a known interval, or any variable when
 *    \p unfold_const_bound is false);
 *  - a product `c * x` of a constant and a variable with a known interval;
 *  - a fraction `x / c` of a variable with a known interval and a constant.
 * For a positive constant the variable's bounds are scaled in order; otherwise
 * the lower and upper bounds are swapped.
 *
 * The constant part is verified with is_constant() before get_constant() is
 * called, so shapes such as `x / y` are rejected (return false) instead of
 * querying a constant value from a non-constant expression.
 *
 * @return true when a bound was added, false when \p var is unsupported (in
 *         which case the accumulators are left untouched).
 */
bool CasSimplifyMutator::GetVarBound(Expr* lower_bound,
                                     Expr* upper_bound,
                                     Expr var,
                                     bool unfold_const_bound) {
  CHECK(lower_bound);
  CHECK(upper_bound);
  auto v_var = var.As<_Var_>();
  auto v_product = var.As<Product>();
  auto v_frac = var.As<FracOp>();
  if (v_var && (var_intervals.count(v_var->name) || !unfold_const_bound)) {
    UnfoldBound(lower_bound, upper_bound, var, unfold_const_bound);
    return true;
  } else if (v_product) {
    // only deal with 2*x
    Expr const_oper = ProductGetConstantPart(var);
    Expr non_const_oper = ProductGetNonConstantPart(var);
    // Both parts must be usable before they are inspected: the non-constant
    // part may be undefined and the constant part may not be a literal.
    if (!non_const_oper.defined() || !const_oper.defined() ||
        !const_oper.is_constant()) {
      return false;
    }
    auto factor_var = non_const_oper.As<_Var_>();
    if (factor_var && var_intervals.count(factor_var->name)) {
      Expr v_lower, v_upper;
      UnfoldBound(&v_lower, &v_upper, non_const_oper, unfold_const_bound);
      auto const_v = const_oper.get_constant();
      CHECK(v_lower.defined() && v_upper.defined());
      Expr p_lower_bound;
      Expr p_upper_bound;
      if (const_v > 0) {
        p_lower_bound = Product::Make({const_oper, v_lower});
        p_upper_bound = Product::Make({const_oper, v_upper});
      } else {
        // A non-positive factor reverses the ordering of the bounds.
        p_lower_bound = Product::Make({const_oper, v_upper});
        p_upper_bound = Product::Make({const_oper, v_lower});
      }
      AddBaseAndSimplify(lower_bound, p_lower_bound);
      AddBaseAndSimplify(upper_bound, p_upper_bound);
      return true;
    }
  } else if (v_frac) {
    // only deal with x/2
    Expr non_const_oper = v_frac->a();
    Expr const_oper = v_frac->b();
    // The divisor has to be a literal; `x / y` is not supported here.
    if (!const_oper.is_constant()) {
      return false;
    }
    auto numerator_var = non_const_oper.As<_Var_>();
    if (numerator_var && var_intervals.count(numerator_var->name)) {
      Expr v_lower, v_upper;
      UnfoldBound(&v_lower, &v_upper, non_const_oper, unfold_const_bound);
      auto const_v = const_oper.get_constant();
      CHECK(v_lower.defined() && v_upper.defined());
      Expr p_lower_bound;
      Expr p_upper_bound;
      if (const_v > 0) {
        p_lower_bound = FracOp::Make(v_lower, const_oper);
        p_upper_bound = FracOp::Make(v_upper, const_oper);
      } else {
        // A non-positive divisor reverses the ordering of the bounds.
        p_lower_bound = FracOp::Make(v_upper, const_oper);
        p_upper_bound = FracOp::Make(v_lower, const_oper);
      }
      AddBaseAndSimplify(lower_bound, p_lower_bound);
      AddBaseAndSimplify(upper_bound, p_upper_bound);
      return true;
    }
  }
  return false;
}
// Adds the bound of one simple operand onto the accumulators. Only an integer
// immediate, a variable, or the variable shapes accepted by GetVarBound are
// supported; anything else yields false.
bool CasSimplifyMutator::GetOperandBound(Expr* lower_bound,
                                         Expr* upper_bound,
                                         Expr v,
                                         bool unfold_const_bound) {
  CHECK(lower_bound);
  CHECK(upper_bound);
  if (v.As<IntImm>()) {
    // An integer literal is its own lower and upper bound.
    AddBaseAndSimplify(lower_bound, v);
    AddBaseAndSimplify(upper_bound, v);
    return true;
  }
  return GetVarBound(lower_bound, upper_bound, v, unfold_const_bound);
}
// Computes the bound of a Sum by accumulating the bound of every operand.
// Only sums of ints, vars and the var shapes handled by GetOperandBound are
// supported. The outputs are written only when every operand could be
// bounded; a non-Sum input yields false.
bool CasSimplifyMutator::GetSumBound(Expr* lower_bound,
                                     Expr* upper_bound,
                                     Expr sum,
                                     bool unfold_const_bound) {
  CHECK(lower_bound);
  CHECK(upper_bound);
  auto* sum_node = sum.As<Sum>();
  if (!sum_node) {
    return false;
  }
  // Accumulate into locals so a failure halfway leaves the outputs untouched.
  Expr total_lower;
  Expr total_upper;
  for (Expr& term : sum_node->operands()) {
    if (!GetOperandBound(
            &total_lower, &total_upper, term, unfold_const_bound)) {
      return false;
    }
  }
  *lower_bound = total_lower;
  *upper_bound = total_upper;
  return true;
}
// Dispatches the bound computation on the kind of `expr`: Sum, Min and Max go
// to their dedicated helpers, everything else is treated as a simple operand
// (int, var, or var's product/fraction with an int).
bool CasSimplifyMutator::GetExprBound(Expr* lower_bound,
                                      Expr* upper_bound,
                                      Expr expr,
                                      bool unfold_const_bound) {
  if (expr.As<Sum>()) {
    return GetSumBound(lower_bound, upper_bound, expr, unfold_const_bound);
  }
  if (expr.As<Min>()) {
    return GetMinBound(lower_bound, upper_bound, expr, unfold_const_bound);
  }
  if (expr.As<Max>()) {
    return GetMaxBound(lower_bound, upper_bound, expr, unfold_const_bound);
  }
  return GetOperandBound(lower_bound, upper_bound, expr, unfold_const_bound);
}
// Bound of min(a, b): lower = min(lower(a), lower(b)) and
// upper = min(upper(a), upper(b)), each simplified. The operands may be a sum,
// int, var, var's product with int, or nested min/max. The outputs are only
// written when both operands could be bounded.
bool CasSimplifyMutator::GetMinBound(Expr* lower_bound,
                                     Expr* upper_bound,
                                     Expr min,
                                     bool unfold_const_bound) {
  auto* min_node = min.As<Min>();
  CHECK(min_node);
  Expr lhs_lower, lhs_upper, rhs_lower, rhs_upper;
  if (!GetExprBound(
          &lhs_lower, &lhs_upper, min_node->a(), unfold_const_bound)) {
    return false;
  }
  if (!GetExprBound(
          &rhs_lower, &rhs_upper, min_node->b(), unfold_const_bound)) {
    return false;
  }
  *lower_bound = CasSimplify(Min::Make(lhs_lower, rhs_lower), var_intervals);
  *upper_bound = CasSimplify(Min::Make(lhs_upper, rhs_upper), var_intervals);
  return true;
}
// Bound of max(a, b): lower = max(lower(a), lower(b)) and
// upper = max(upper(a), upper(b)), each simplified. The outputs are only
// written when both operands could be bounded.
bool CasSimplifyMutator::GetMaxBound(Expr* lower_bound,
                                     Expr* upper_bound,
                                     Expr max,
                                     bool unfold_const_bound) {
  auto* max_node = max.As<Max>();
  CHECK(max_node);
  Expr first_lower, first_upper, second_lower, second_upper;
  const bool bounded =
      GetExprBound(
          &first_lower, &first_upper, max_node->a(), unfold_const_bound) &&
      GetExprBound(
          &second_lower, &second_upper, max_node->b(), unfold_const_bound);
  if (!bounded) {
    return false;
  }
  *lower_bound =
      CasSimplify(Max::Make(first_lower, second_lower), var_intervals);
  *upper_bound =
      CasSimplify(Max::Make(first_upper, second_upper), var_intervals);
  return true;
}
/**
 * Try to distribute a modulo over a sum for a few specific patterns.
 *
 * Handled patterns (\p a must be a Sum, \p b an integer immediate):
 *  - (c*x + y) % b  ->  (c*x) % b + y % b   when |c| divides |b| and the
 *    interval of y satisfies |l| < |c| and |r| < |c|;
 *  - case1: (32+(-x))%33 = 32-x%33 (0<=x<=32)
 *  - case2: (x-32)%33 = x%33 - 32%33 (0<=x<=32)
 *    (case1/case2 are compiled out when CINN_WITH_CUDA is defined).
 *
 * Note: the first operand of the sum is simplified in place and the operands
 * of its product may be swapped, i.e. \p a's node is mutated.
 *
 * @param result receives the simplified expression on success.
 * @return true when a simplification was written to \p result.
 */
bool CasSimplifyMutator::SimplifySpecificSumMod(Expr* result, Expr a, Expr b) {
  auto a_sum = a.As<Sum>();
  auto b_i = b.As<IntImm>();
  if (!a_sum || !b_i) {
    return false;
  }
  // if 0 < b < 3, (3a+b) % 6 = (3a % 6) + (b % 6)
  if (a_sum->operands().size() == 2) {
    a_sum->operands()[0] = CasSimplify(a_sum->operands()[0], var_intervals);
    auto sum_a_prod = a_sum->operands()[0].As<Product>();
    auto sum_b_var = a_sum->operands()[1].As<_Var_>();
    // operand(0)/operand(1) are read below, so require two operands.
    if (sum_a_prod && sum_a_prod->operands().size() >= 2 && sum_b_var &&
        var_intervals.count(sum_b_var->name)) {
      // Normalize so that an integer factor, if any, sits at operand(0).
      auto sum_a_prod_b_int = sum_a_prod->operand(1).As<IntImm>();
      if (sum_a_prod_b_int)
        std::swap(sum_a_prod->operand(0), sum_a_prod->operand(1));
      auto sum_a_prod_a_int = sum_a_prod->operand(0).As<IntImm>();
      // The factor must be checked BEFORE its value is read, and it must be
      // non-zero because it is used as a modulus.
      if (sum_a_prod_a_int && sum_a_prod_a_int->value != 0) {
        auto& interval = var_intervals.at(sum_b_var->name);
        int b_abs = std::abs(b_i->value);
        int sum_prod_a_abs = std::abs(sum_a_prod_a_int->value);
        if (b_abs % sum_prod_a_abs == 0 &&
            std::abs(interval.l) < sum_prod_a_abs &&
            std::abs(interval.r) < sum_prod_a_abs) {
          *result = CasSimplify(
              Sum::Make({CasSimplify(Mod::Make(a_sum->operands()[0], b),
                                     var_intervals),
                         CasSimplify(Mod::Make(a_sum->operands()[1], b),
                                     var_intervals)}),
              var_intervals);
          return true;
        }
      }
    }
  }
#ifdef CINN_WITH_CUDA
  return false;
#else
  int const_value = 0;
  Expr lower_bound;
  Expr upper_bound;
  Expr rest_oper;
  bool can_simplify = true;
  bool has_int = false;
  // fold only the expr bound(may contains the var) and try to simplify the var
  for (Expr& v : a_sum->operands()) {
    auto* v_int = v.As<IntImm>();
    if (v_int) {
      const_value += v_int->value;
      has_int = true;
    } else if (GetVarBound(&lower_bound, &upper_bound, v, false)) {
      // Bounds accumulate across operands; the non-constant rest is rebuilt.
      AddBaseAndSimplify(&rest_oper, v);
    } else {
      can_simplify = false;
      break;
    }
  }
  can_simplify = can_simplify && has_int &&
                 std::abs(const_value) % b_i->value == b_i->value - 1 &&
                 lower_bound.defined() && upper_bound.defined() &&
                 rest_oper.defined();
  if (!can_simplify) {
    return false;
  }
  // further infer the vars' bound by the intervals infos, try to get the
  // constant
  {
    Expr refined_lower, unused_upper;
    GetExprBound(&refined_lower, &unused_upper, lower_bound);
    Expr unused_lower, refined_upper;
    GetExprBound(&unused_lower, &refined_upper, upper_bound);
    if (refined_lower.defined()) {
      lower_bound = refined_lower;
    }
    if (refined_upper.defined()) {
      upper_bound = refined_upper;
    }
  }
  // case1: (32+(-x))%33 = 32-x%33 (0<=x<=32)
  // case2: (x-32)%33 = x%33 - 32%33 (0<=x<=32)
  can_simplify = lower_bound.is_constant();
  bool case1 = can_simplify && const_value >= 0 &&
               lower_bound.get_constant() >= -const_value &&
               upper_bound.is_constant() && upper_bound.get_constant() <= 0;
  bool case2 = can_simplify && const_value <= 0 &&
               lower_bound.get_constant() >= 0 && upper_bound.is_constant() &&
               upper_bound.get_constant() <= -const_value;
  can_simplify = can_simplify && (case1 || case2);
  if (can_simplify) {
    Expr const_expr = make_const(b->type(), const_value % b_i->value);
    *result = CasSimplify(
        Sum::Make(
            {const_expr, CasSimplify(Mod::Make(rest_oper, b), var_intervals)}),
        var_intervals);
    return true;
  }
  return false;
#endif
}
// Tells whether `var_name` has a recorded interval whose lower bound `l` is
// greater than or equal to zero.
inline bool IsVarNonnegative(
    const absl::flat_hash_map<std::string, CasInterval>& var_intervals,
    const std::string& var_name) {
  auto entry = var_intervals.find(var_name);
  return entry != var_intervals.end() && entry->second.l >= 0;
}
// Tells whether the variable name starts with "threadIdx" or "blockIdx", i.e.
// the variable is bound to a CUDA thread or block index (which implies it is
// non-negative).
inline bool IsVarBinded(const std::string& var_name) {
  if (utils::Startswith(var_name, "threadIdx")) {
    return true;
  }
  return utils::Startswith(var_name, "blockIdx");
}
/**
 * Fold one more expression into the running "all operands are non-negative
 * vars" flag.
 * @param all_nonnegative_var the flag accumulated over the previous exprs.
 * @param arg_var the current expr as a var, or nullptr when it is not a var.
 * @param var_intervals intervals of each var.
 * @return true only if the previous flag was true, the current expr is a var,
 * and that var either has a non-negative interval or is bound to a CUDA
 * thread/block index.
 */
inline bool IsVarAllNonnegative(
    bool all_nonnegative_var,
    _Var_* arg_var,
    const absl::flat_hash_map<std::string, CasInterval>& var_intervals) {
  // A previous failure or a non-var operand ends the property immediately.
  if (!all_nonnegative_var || !arg_var) {
    return false;
  }
  return IsVarNonnegative(var_intervals, arg_var->name) ||
         IsVarBinded(arg_var->name);
}
/**
 * Simplify a Mod expression `a % b`.
 *
 * Both operands are simplified first, then the following rewrites are tried in
 * order (most of them require `b` to be an integer immediate):
 *  - constant folding: 7 % 3;
 *  - x % 1 = 0;
 *  - (x * 6) % 2 = 0 and (x * y * 2) % 6 = ((x * y) % 3) * 2;
 *  - (x % 16) % 4 = x % 4 (only when the inner modulus is an integer
 *    immediate that is a multiple of the outer one);
 *  - 0 % x = 0, 1 % x = 1;
 *  - x % b = x when the interval of x lies strictly inside (-|b|, |b|);
 *  - a product divisible by b gives 0;
 *  - in a sum, operands divisible by b are dropped:
 *    (2x+y+z) % 2 = (y+z) % 2;
 *  - the specific sum patterns handled by SimplifySpecificSumMod.
 * When nothing applies, Mod(a, b) of the simplified operands is returned.
 *
 * @param u must be a Mod expression.
 */
Expr CasSimplifyMutator::SimplifyMod(Expr u) {
  VLOG(4) << "SimplifyMod:" << u;
  auto* node = u.As<Mod>();
  CHECK(node);
  auto a = CasSimplify(node->a(), var_intervals);
  auto b = CasSimplify(node->b(), var_intervals);
  auto* a_i = a.As<IntImm>();
  auto* a_product = a.As<Product>();
  auto* a_sum = a.As<Sum>();
  auto* a_var = a.As<_Var_>();
  auto* a_mod = a.As<Mod>();
  auto* b_i = b.As<IntImm>();
  // 7 % 3
  if (a_i && b_i) {
    return make_const(a_i->type(), a_i->value % b_i->value);
  }
  // x % 1 = 0
  if (b_i && b_i->value == 1) return make_const(b_i->type(), 0);
  // handle cases:
  // (x * 6) % 2 = 0
  // (x * 2) % 6 = (x % 3) * 2
  if (b_i && a_product && b_i->value > 0) {
    for (size_t i = 0; i < a_product->operands().size(); i++) {
      auto a_op_i = a_product->operand(i);
      if (a_op_i.As<IntImm>() && a_op_i.As<IntImm>()->value > 0) {
        int a_op_int = a_op_i.As<IntImm>()->value;
        // case: (x * 6) % 2 = 0
        if (a_op_int % b_i->value == 0) return make_const(a_product->type(), 0);
        // case: (x * y * 2) % 6 = ((x * y) % 3) * 2
        if (b_i->value % a_op_int == 0) {
          int new_b = b_i->value / a_op_int;
          std::vector<Expr> a_operands = a_product->operands();
          a_operands.erase(a_operands.begin() + i);
          return Product::Make(
              {SimplifyMod(Mod::Make(Product::Make(a_operands), Expr(new_b))),
               Expr(a_op_int)});
        }
      }
    }
  }
  // (x % 16) % 4 = x % 4
  if (a_mod && b_i) {
    VLOG(4) << "Simplify sequential mod";
    auto* a_b_i = a_mod->b().As<IntImm>();
    // The inner modulus may be a non-constant expression (e.g. (x % y) % 4),
    // in which case a_b_i is null and the rule does not apply. The outer
    // modulus is also required to be non-zero before it is used as a divisor.
    if (a_b_i && b_i->value != 0 && a_b_i->value != 0 &&
        a_b_i->value % b_i->value == 0) {
      auto e = SimplifyMod(Mod::Make(a_mod->a(), b_i));
      VLOG(4) << "Reduce Mod from " << u << " to " << e;
      return e;
    }
  }
  // 0 % x = 0, 1 % x = 1
  if (a_i && (a_i->value == 0 || a_i->value == 1)) return a;
  if (b_i && a_var && var_intervals.count(a_var->name)) {
    auto& interval = var_intervals.at(a_var->name);
    int b_abs = std::abs(b_i->value);
    // x\in[1, 3] % 4 = x
    if (std::abs(interval.l) < b_abs && std::abs(interval.r) < b_abs) return a;
    // [3,3] % 3 = 0
    if (interval.l == interval.r && interval.l % b_abs == 0)
      return make_const(b_i->type(), 0);
  }
  if (a_product && b_i) {
    if (IsDivisible(a_product, b_i->value)) {
      return make_const(Int(32), 0);
    }
  }
  // (4*x + k*y)%2 = (k*y) %2
  // (2x+y+z) % 2 = (y+z) % 2
  if (a_sum && b_i) {
    VLOG(4) << "A SUM ";
    // Keep only the operands that are not divisible by b.
    std::vector<Expr> sum_args;
    for (auto& v : a_sum->operands()) {
      if (!IsDivisible(v, b_i->value)) {
        VLOG(4) << v;
        sum_args.push_back(v);
      }
    }
    if (sum_args.empty()) return make_const(b_i->type(), 0);
    // handle the case: (2x+y+z) % 2 = (y+z) % 2 when y>=0 and z>=0
    if (sum_args.size() == 1) {
      return SimplifyMod(Mod::Make(sum_args[0], b));
    } else if (sum_args.size() < a_sum->operands().size()) {
      bool all_nonnegative_var = true;
      bool all_nonnegative_int = true;
      for (size_t i = 0; i < sum_args.size(); i++) {
        auto* arg_var = sum_args[i].As<_Var_>();
        all_nonnegative_var =
            IsVarAllNonnegative(all_nonnegative_var, arg_var, var_intervals);
        auto* arg_int = sum_args[i].As<IntImm>();
        all_nonnegative_int =
            all_nonnegative_int && arg_int && arg_int->value >= 0;
      }
      VLOG(4) << all_nonnegative_var << " " << all_nonnegative_int;
      if (all_nonnegative_var)
        return SimplifyMod(Mod::Make(Sum::Make(sum_args), b));
      if (all_nonnegative_int) {
        int sum_value = 0;
        for (auto& i : sum_args) sum_value += i.As<IntImm>()->value;
        return make_const(a_sum->type(), sum_value % b_i->value);
      }
      return SimplifyMod(Mod::Make(Sum::Make(sum_args), b));
    } else if (sum_args.size() == a_sum->operands().size()) {
      if (b_i->value > 0 && !var_intervals.empty()) {
        // case1: (32+(-x))%33 = 32-x%33 (0<=x<=32)
        // case2: (x-32))%33 = x%33 - 32%33 (0<=x<=32)
        Expr result;
        if (SimplifySpecificSumMod(&result, a, b)) {
          return result;
        }
      }
      return Mod::Make(a, b);
    }
  }
  return Mod::Make(a, b);
}
/**
 * Simplify a Min or Max expression.
 *
 * Both operands are simplified first. Two constants are folded directly. When
 * exactly one operand is constant, the bound of the other operand is queried
 * (first with constant interval bounds unfolded, then without unfolding, since
 * a var may be eliminated in the calculation) and the Min/Max is replaced by
 * one of its operands whenever the bound proves which one wins.
 *
 * Every bound query uses its own fresh lower/upper accumulators: the bound
 * helpers add onto an already-defined accumulator, so reusing the result of
 * the first query in the second one would sum the two bounds together.
 *
 * @param u the expression to simplify; returned unchanged when it is neither
 *          Min nor Max.
 */
Expr CasSimplifyMutator::SimplifyMinAndMax(Expr u) {
  // simplify min/max
  auto* u_max = u.As<Max>();
  auto* u_min = u.As<Min>();
  if (u_max) {
    Expr a = CasSimplify(u_max->a(), var_intervals);
    Expr b = CasSimplify(u_max->b(), var_intervals);
    bool is_a_const = a.is_constant();
    bool is_b_const = b.is_constant();
    if (is_a_const && is_b_const) {
      return a.get_constant() >= b.get_constant() ? a : b;
    }
    Expr const_operand, non_const_operand;
    if (is_a_const) {
      const_operand = a;
      non_const_operand = b;
    }
    if (is_b_const) {
      const_operand = b;
      non_const_operand = a;
    }
    if (const_operand.defined() && non_const_operand.defined()) {
      auto const_size = const_operand.get_constant();
      // unfold var with bounds
      {
        Expr lower_bound, upper_bound;
        if (GetExprBound(&lower_bound, &upper_bound, non_const_operand, true)) {
          // if non_const_operand's lower_bound is larger than const_operand,
          // then non_const_operand must be larger than const_operand
          if (lower_bound.is_constant() &&
              const_size <= lower_bound.get_constant()) {
            return non_const_operand;
          }
          // if non_const_operand's upper_bound is smaller than a, then
          // const_operand must be larger than non_const_operand
          if (upper_bound.is_constant() &&
              const_size >= upper_bound.get_constant()) {
            return const_operand;
          }
        }
      }
      // not unfold var for var may be eliminated in the caculation
      {
        Expr lower_bound, upper_bound;
        if (GetExprBound(
                &lower_bound, &upper_bound, non_const_operand, false)) {
          // if non_const_operand's lower_bound is larger than const_operand,
          // then non_const_operand must be larger than const_operand
          lower_bound = CasSimplify(lower_bound, var_intervals);
          upper_bound = CasSimplify(upper_bound, var_intervals);
          if (lower_bound.is_constant() &&
              const_size <= lower_bound.get_constant()) {
            return non_const_operand;
          }
          // if non_const_operand's upper_bound is smaller than a, then
          // const_operand must be larger than non_const_operand
          if (upper_bound.is_constant() &&
              const_size >= upper_bound.get_constant()) {
            return const_operand;
          }
        }
      }
    }
    return ir::Max::Make(a, b);
  }
  if (u_min) {
    Expr a = CasSimplify(u_min->a(), var_intervals);
    Expr b = CasSimplify(u_min->b(), var_intervals);
    bool is_a_const = a.is_constant();
    bool is_b_const = b.is_constant();
    if (is_a_const && is_b_const) {
      return a.get_constant() <= b.get_constant() ? a : b;
    }
    Expr const_operand, non_const_operand;
    if (is_a_const) {
      const_operand = a;
      non_const_operand = b;
    }
    if (is_b_const) {
      const_operand = b;
      non_const_operand = a;
    }
    if (const_operand.defined() && non_const_operand.defined()) {
      auto const_size = const_operand.get_constant();
      // unfold var with bounds
      {
        Expr lower_bound, upper_bound;
        if (GetExprBound(&lower_bound, &upper_bound, non_const_operand, true)) {
          // the constant is not above the operand's lower bound, so the
          // constant is the minimum
          if (lower_bound.is_constant() &&
              const_size <= lower_bound.get_constant()) {
            return const_operand;
          }
          // the constant is not below the operand's upper bound, so the
          // operand is the minimum
          if (upper_bound.is_constant() &&
              const_size >= upper_bound.get_constant()) {
            return non_const_operand;
          }
        }
      }
      // not unfold var for var may be eliminated in the calculation
      {
        Expr lower_bound, upper_bound;
        if (GetExprBound(
                &lower_bound, &upper_bound, non_const_operand, false)) {
          // the constant is not above the operand's lower bound, so the
          // constant is the minimum
          if (lower_bound.is_constant() &&
              const_size <= lower_bound.get_constant()) {
            return const_operand;
          }
          // the constant is not below the operand's upper bound, so the
          // operand is the minimum
          if (upper_bound.is_constant() &&
              const_size >= upper_bound.get_constant()) {
            return non_const_operand;
          }
        }
      }
    }
    return ir::Min::Make(a, b);
  }
  return u;
}
// Folds a comparison whose two operands both simplify to constants into a
// boolean expression. In every other situation the original expression `u` is
// returned as is (the simplified operands are not written back).
Expr CasSimplifyMutator::SimplifyCmp(Expr u) {
  Expr lhs = operator()(u->operand(0));
  Expr rhs = operator()(u->operand(1));
  if (!lhs.is_constant() || !rhs.is_constant()) {
    return u;
  }
  const auto lhs_value = lhs.get_constant();
  const auto rhs_value = rhs.get_constant();
  switch (u->node_type()) {
    case ir::IrNodeTy::LT:
      return Expr(lhs_value < rhs_value);
    case ir::IrNodeTy::LE:
      return Expr(lhs_value <= rhs_value);
    case ir::IrNodeTy::GT:
      return Expr(lhs_value > rhs_value);
    case ir::IrNodeTy::GE:
      return Expr(lhs_value >= rhs_value);
    case ir::IrNodeTy::EQ:
      return Expr(lhs_value == rhs_value);
    case ir::IrNodeTy::NE:
      return Expr(lhs_value != rhs_value);
    default:
      break;
  }
  return u;
}
/**
 * Simplify the div-mod addition that shows up in index computations. This is
 * a temporary solution and does not cover all situations.
 *  case 1: (m / n) * n + m % n = m (m, n's type is int)
 *  case 2: (m / n1) * n3 + (n2 * m) % n3 = n2 * m if n3 = n1 * n2
 *          (m, n1, n2, n3's type is int)
 * Only the first two operands of the sum are matched against the pattern; on a
 * match the remaining operands are added back onto the simplified result. Any
 * input that does not match is returned unchanged.
 */
Expr CasSimplifyMutator::SimplifySpecificSum(Expr tmp) {
  auto sum = tmp.As<Sum>();
  if (!sum) {
    return tmp;
  }
  if (sum->operands().size() == 1U) return sum->operand(0);

  Expr first = sum->operand(0);
  Expr second = sum->operand(1);

  // Default orientation: product (or division) on the left, mod on the right.
  auto mul_node = first.As<Product>();
  auto div_node = first.As<FracOp>();
  auto mod_node = second.As<Mod>();
  // Mirrored orientation: mod on the left, product (or division) on the right.
  auto first_as_mod = first.As<Mod>();
  if (first_as_mod && second.As<Product>()) {
    mul_node = second.As<Product>();
    mod_node = first_as_mod;
  }
  if (first_as_mod && second.As<FracOp>()) {
    div_node = second.As<FracOp>();
    mod_node = first_as_mod;
  }
  if (!mod_node || (!mul_node && !div_node)) {
    return tmp;
  }

  CHECK_GE(mod_node->operands().size(), 2U);
  Expr mod_left = mod_node->operand(0);
  Expr mod_right = mod_node->operand(1);
  if (!mod_left->type().is_integer() || !mod_right->type().is_integer()) {
    return tmp;
  }
  // Only the product form is rewritten; a bare division is left untouched.
  if (!mul_node) {
    return tmp;
  }

  CHECK_GE(mul_node->operands().size(), 2U);
  Expr quotient_part = mul_node->operand(0);
  Expr scale_part = mul_node->operand(1);
  // handle the case1 : n * (m / n) + m % n = (m / n) * n + m % n = m
  // handle the case2 : n3 * (m / n1) + (n2 * m) % n3 = (m / n1) * n3 + (n2 *
  // m) % n3 = n2 * m if n3 = n1 * n2
  if (MathEqual(mod_right, quotient_part)) {
    quotient_part = mul_node->operand(1);
    scale_part = mul_node->operand(0);
  } else if (!MathEqual(mod_right, scale_part)) {
    return tmp;
  }

  auto div = quotient_part.As<FracOp>();
  if (!div) {
    return tmp;
  }
  CHECK_GE(div->operands().size(), 2U);
  Expr div_left = div->operand(0);
  Expr div_right = div->operand(1);
  if (!div_left->type().is_integer() || !div_right->type().is_integer()) {
    return tmp;
  }
  if (!MathEqual(div_left * mod_right, mod_left * div_right)) {
    return tmp;
  }

  // Pattern matched: the first two operands collapse to the mod's dividend;
  // append whatever other operands the sum had.
  Expr rebuilt = mod_left;
  for (int i = 2; i < sum->operands().size(); i++) {
    rebuilt = rebuilt + sum->operand(i);
  }
  return rebuilt;
}
// Entry point of the mutator: dispatches `u` to the simplifier that matches
// its node kind. The order of the checks matters and is kept as is:
// Min/Max, single-element unwrapping, constants/vars, FracOp, Product, Sum,
// Mod, comparisons, and finally And/Or/Not. Anything else is returned
// unchanged.
Expr CasSimplifyMutator::operator()(Expr u) {
  if (u.As<Min>() || u.As<Max>()) return SimplifyMinAndMax(u);

  u = detail::SumOrProductGetSingleElementsRec(u);
  if (u.is_constant() || u.As<_Var_>()) return u;

  if (u.As<FracOp>()) {
    u = SimplifyFracOp(u);
    auto refined = FurtherSimplifyFracWithInterval(u, var_intervals);
    // Run the whole pipeline again only when the interval step changed the
    // expression.
    if (refined.same_as(u)) return u;
    return operator()(refined);
  }

  if (u.As<Product>()) {
    Expr product = SimplifyProduct(u);
    return detail::SumOrProductGetSingleElementsRec(product);
  }

  if (u.As<Sum>()) {
    Expr sum = detail::SumOrProductGetSingleElementsRec(SimplifySum(u));
    // deal with index's div-mod add simplification, tempory solution, not
    // cover all situations.
    // case 1: (m / n) * n + m % n = m (m, n's type is int)
    // case 2: (m / n1) * n3 + (n2 * m) % n3 = n2 * m if n3 = n1 * n2
    //         (m, n1, n2, n3's type is int)
    // case 3: m / n2 + (n1 * m) % n3 = n1 * m if n3 = n1 * n2
    //         (m, n1, n2, n3's type is int)
    return SimplifySpecificSum(sum);
  }

  if (u.As<Mod>()) {
    Expr mod = SimplifyMod(u);
    return detail::SumOrProductGetSingleElementsRec(mod);
  }

  if (u.is_cmp()) return SimplifyCmp(u);

  const auto kind = u.node_type();
  if (kind == ir::IrNodeTy::And || kind == ir::IrNodeTy::Or ||
      kind == ir::IrNodeTy::Not) {
    return SimplifyCond(u);
  }
  return u;
}
// Tells whether `expr` is a Load, a _Var_ or a Broadcast node.
bool CASasSymbol(Expr expr) {
  if (expr.As<Load>()) return true;
  if (expr.As<_Var_>()) return true;
  return expr.As<Broadcast>() != nullptr;
}
// Converts a copy of a CINN expression into the CAS representation:
//   Add   -> Sum            (x + 0 and 0 + x collapse to x)
//   Mul   -> Product        (x * 0 -> 0, x * 1 -> x)
//   Sub   -> Sum with the right side multiplied by -1
//   Div   -> FracOp         (0 / x -> 0, x / 1 -> x)
//   Minus -> Product with -1 (-0 -> 0)
// The input expression itself is left untouched.
Expr ConvertCinnToCAS(Expr expr) {
  VLOG(7) << "Begin ConvertCinnToCAS " << expr;
  Expr copied = optim::IRCopy(expr);
  struct Mutator : public ir::IRMutator<ir::Expr*> {
    void operator()(Expr* expr) { Visit(expr); }
    void Visit(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }

   private:
    void Visit(const Add* op, Expr* expr) override {
      auto lhs = op->a();
      auto rhs = op->b();
      Visit(&lhs);
      Visit(&rhs);
      // 0 + x = x
      if (lhs.is_constant() && lhs.get_constant() == 0) {
        *expr = rhs;
        return;
      }
      // x + 0 = x
      if (rhs.is_constant() && rhs.get_constant() == 0) {
        *expr = lhs;
        return;
      }
      *expr = Sum::Make({lhs, rhs});
    }

    void Visit(const Mul* op, Expr* expr) override {
      auto lhs = op->a();
      auto rhs = op->b();
      Visit(&lhs);
      Visit(&rhs);
      if (lhs.is_constant()) {
        // 0 * x = 0
        if (lhs.get_constant() == 0) {
          *expr = make_const(lhs->type(), 0);
          return;
        }
        // 1 * x = x
        if (lhs.get_constant() == 1) {
          *expr = rhs;
          return;
        }
      }
      if (rhs.is_constant()) {
        // x * 0 = 0
        if (rhs.get_constant() == 0) {
          *expr = make_const(rhs->type(), 0);
          return;
        }
        // x * 1 = x
        if (rhs.get_constant() == 1) {
          *expr = lhs;
          return;
        }
      }
      *expr = Product::Make({lhs, rhs});
    }

    void Visit(const Sub* op, Expr* expr) override {
      auto lhs = op->a();
      auto rhs = op->b();
      Visit(&lhs);
      Visit(&rhs);
      // 0 - x = -1 * x
      if (lhs.is_constant() && lhs.get_constant() == 0) {
        *expr = Product::Make({make_const(rhs->type(), -1), rhs});
        return;
      }
      // x - 0 = x
      if (rhs.is_constant() && rhs.get_constant() == 0) {
        *expr = lhs;
        return;
      }
      // x - y = x + (-1 * y)
      Expr negated_rhs = Product::Make({make_const(rhs->type(), -1), rhs});
      *expr = Sum::Make({lhs, negated_rhs});
    }

    void Visit(const Div* op, Expr* expr) override {
      auto numerator = op->a();
      auto denominator = op->b();
      Visit(&numerator);
      Visit(&denominator);
      CHECK(!is_zero(denominator)) << "Dividend should not be zero";
      // 0 / x = 0
      if (numerator.is_constant() && numerator.get_constant() == 0) {
        *expr = make_const(numerator->type(), 0);
        return;
      }
      // x / 1 = x
      if (denominator.is_constant() && denominator.get_constant() == 1) {
        *expr = numerator;
        return;
      }
      // int division, NOTE that 3/2 = 1, 3./2 = 1.5
      *expr = FracOp::Make(numerator, denominator);
    }

    void Visit(const Minus* op, Expr* expr) override {
      auto operand = op->v();
      Visit(&operand);
      // -0 = 0
      if (operand.is_constant() && operand.get_constant() == 0) {
        *expr = make_const(operand->type(), 0);
        return;
      }
      *expr = Product::Make({make_const(operand->type(), -1), operand});
    }
  };
  Mutator()(&copied);
  return copied;
}
/**
 * @brief Visit a copy of the given expr. Whenever an ir::Min has exactly one
 * constant operand and one non-constant operand, the Min is replaced by a copy
 * of the constant operand. For example, if a < min(5, b), then we get a < 5
 * and a < b. Using a < 5 to simplify the condition ensures correctness, though
 * it is not sufficient. The constant operand is required to be of integer
 * type. A Min with two constant or two non-constant operands is left as is.
 */
Expr ReplaceMinToConstant(Expr expr) {
  Expr copied = optim::IRCopy(expr);
  struct Mutator : public ir::IRMutator<ir::Expr*> {
    void operator()(Expr* expr) { Visit(expr); }
    void Visit(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }

   private:
    void Visit(const Min* op, Expr* expr) override {
      // Recurse into both operands through local handles.
      // NOTE(review): a replacement of the visited operand itself only updates
      // the local handle, while the decision below re-reads op->a()/op->b() -
      // confirm this is intended.
      auto visited_lhs = op->a();
      auto visited_rhs = op->b();
      Visit(&visited_lhs);
      Visit(&visited_rhs);

      auto lhs = op->a();
      auto rhs = op->b();
      const bool lhs_const = lhs.is_constant();
      const bool rhs_const = rhs.is_constant();
      // Exactly one of the two operands has to be constant.
      if (lhs_const == rhs_const) {
        return;
      }
      Expr picked = lhs_const ? lhs : rhs;
      CHECK(picked->type().is_integer());
      *expr = optim::IRCopy(picked);
    }
  };
  Mutator()(&copied);
  return copied;
}
/**
 * @brief Visit a copy of the given expr. Whenever an ir::Max has exactly one
 * constant operand and one non-constant operand, the Max is replaced by a copy
 * of the constant operand. The constant operand is required to be of integer
 * type. A Max with two constant or two non-constant operands is left as is.
 */
Expr ReplaceMaxToConstant(Expr expr) {
  Expr copied = optim::IRCopy(expr);
  struct Mutator : public ir::IRMutator<ir::Expr*> {
    void operator()(Expr* expr) { Visit(expr); }
    void Visit(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }

   private:
    void Visit(const Max* op, Expr* expr) override {
      // Recurse into both operands through local handles.
      // NOTE(review): a replacement of the visited operand itself only updates
      // the local handle, while the decision below re-reads op->a()/op->b() -
      // confirm this is intended.
      auto visited_lhs = op->a();
      auto visited_rhs = op->b();
      Visit(&visited_lhs);
      Visit(&visited_rhs);

      auto lhs = op->a();
      auto rhs = op->b();
      const bool lhs_const = lhs.is_constant();
      const bool rhs_const = rhs.is_constant();
      // Exactly one of the two operands has to be constant.
      if (lhs_const == rhs_const) {
        return;
      }
      Expr picked = lhs_const ? lhs : rhs;
      CHECK(picked->type().is_integer());
      *expr = optim::IRCopy(picked);
    }
  };
  Mutator()(&copied);
  return copied;
}
// Converts a copy of a CAS expression back to the CINN representation:
//   Product -> nested Mul, Sum -> nested Add, FracOp -> Div,
//   and `a + (-1 * b)` -> `a - b`.
// The input expression itself is left untouched.
Expr ConvertCasToCinn(Expr expr) {
VLOG(7) << "Begin ConvertCasToCinn : " << expr;
Expr copied = optim::IRCopy(expr);
struct Mutator : ir::IRMutator<Expr*> {
void operator()(Expr* expr) { Visit(expr); }
void Visit(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
private:
// Product{a, b, c, ...} -> Mul(a, Mul(b, ...)). Each operand is converted
// with a fresh Mutator before the Mul chain is built.
void Visit(const Product* op, Expr* expr) override {
std::vector<Expr> operands;
auto* node = expr->As<Product>();
for (auto& v : node->operands()) {
auto c = v;
Mutator()(&c);
operands.push_back(c);
}
CHECK(!operands.empty());
if (operands.size() == 1) {
*expr = operands[0];
} else if (operands.size() == 2) {
*expr = Mul::Make(operands[0], operands[1]);
} else {
// More than two operands: peel off the first one and convert the remaining
// operands as a smaller Product.
auto a = operands[0];
auto b = Product::Make(EraseFront(operands));
Mutator()(&b);
*expr = Mul::Make(a, b);
}
// process the Mul
Visit(expr);
}
// Sum{a, b, c, ...} -> Add(a, Add(b, ...)). Each operand is converted with a
// fresh Mutator before the Add chain is built.
void Visit(const Sum* op, Expr* expr) override {
std::vector<Expr> operands;
auto* node = expr->As<Sum>();
for (auto& v : node->operands()) {
auto c = v;
Mutator()(&c);
operands.push_back(c);
}
CHECK(!operands.empty());
if (operands.size() == 1) {
*expr = operands[0];
} else if (operands.size() == 2) {
*expr = Add::Make(operands[0], operands[1]);
} else {
// More than two operands: peel off the first one and convert the remaining
// operands as a smaller Sum.
auto a = operands[0];
auto b = Sum::Make(EraseFront(operands));
Mutator()(&b);
*expr = Add::Make(a, b);
}
// process the sum
Visit(expr);
}
// FracOp(a, b) -> Div(a, b); the divisor must not be zero.
void Visit(const FracOp* op, Expr* expr) override {
auto a = op->a();
auto b = op->b();
Visit(&a);
Visit(&b);
CHECK(!is_zero(b)) << "Dividend should not be zero";
*expr = Div::Make(a, b);
// Re-visit the freshly built Div node.
Visit(expr);
}
// a + -1*b -> a-b
void Visit(const Add* op, Expr* expr) override {
auto a = op->a();
auto b = op->b();
Visit(&a);
Visit(&b);
// Only a Mul whose FIRST operand is the constant -1 is folded into a Sub.
auto* bp = b.As<ir::Mul>();
if (bp && bp->a().is_constant() && bp->a().get_constant() == -1.f) {
*expr = Sub::Make(a, bp->b());
} else {
*expr = Add::Make(a, b);
}
}
};
Mutator()(&copied);
return copied;
}
// An expression is CAS compatible when it contains no Add, Sub, Mul or Div
// node (these have to be converted to their CAS counterparts first).
bool IsExprCasCompatible(Expr expr) {
  auto is_plain_arithmetic = [](const Expr* candidate) {
    return candidate->As<Add>() || candidate->As<Sub>() ||
           candidate->As<Mul>() || candidate->As<Div>();
  };
  auto offending_nodes = ir::CollectIRNodes(expr, is_plain_arithmetic);
  return offending_nodes.empty();
}
// Partially divide a by b. e.g. (2x+y)/2 => x + y/2
// Operands that can be divided (products divisible by b or dividing b, and
// integer immediates divisible by b) are divided and moved outside; the other
// operands stay inside a single FracOp over b. When no operand can be divided
// the sum is returned unchanged.
Expr DividePartially(Sum* a, int b) {
  std::vector<Expr> divided_terms;
  std::vector<Expr> remaining_terms;
  for (auto& term : a->operands()) {
    auto* term_product = term.As<Product>();
    auto* term_int = term.As<IntImm>();
    if (term_product &&
        (IsDivisible(term_product, b) || IsDivisible(b, term_product))) {
      divided_terms.push_back(Divide(term_product, b));
    } else if (term_int && IsDivisible(term_int->value, b)) {
      divided_terms.push_back(make_const(term.type(), term_int->value / b));
    } else {
      remaining_terms.push_back(term);
    }
  }

  // Nothing could be divided: hand back the original sum.
  if (divided_terms.empty()) {
    return Expr(a);
  }
  // Everything was divided: no fraction is left.
  if (remaining_terms.empty()) {
    return Sum::Make(divided_terms);
  }
  Expr inner =
      remaining_terms.size() == 1 ? remaining_terms[0] : Sum::Make(remaining_terms);
  Expr rest_fraction = FracOp::Make(inner, make_const(a->type(), b));
  return Sum::Make(Concat(divided_terms, {rest_fraction}));
}
// Returns true when `u` is the variable `v` itself (matched by name), or a
// Product that has, possibly nested inside further Products, an operand that
// is `v`. Every other expression yields false.
bool IsMonotonical(Expr u, Var v) {
  if (auto* as_var = u.As<_Var_>()) {
    if (as_var->name == v->name) return true;
  }
  auto* as_product = u.As<Product>();
  if (!as_product) {
    return false;
  }
  for (auto& factor : as_product->operands()) {
    if (IsMonotonical(factor, v)) return true;
  }
  return false;
}
// Should be called after SimplifyFracOp. Uses the variable intervals to fold
// a fraction to zero:
//   - y / c, where the interval of y satisfies |l| < |c| and |r| < |c|
//     (e.g. y is integer and $y\in \[0, 3\]$, then y/4=0);
//   - c / y, where every value in the interval of y has a magnitude larger
//     than |c| (e.g. 1/y, y\in(2, 100)). The interval must lie entirely on one
//     side of zero: an interval such as [-5, 5] has large end points but still
//     contains 1, for which 1/y is not 0.
// Expressions that are not a FracOp, or that match neither rule, are returned
// unchanged.
Expr CasSimplifyMutator::FurtherSimplifyFracWithInterval(
    Expr expr,
    const absl::flat_hash_map<std::string, CasInterval>& var_intervals) {
  auto* node = expr.As<FracOp>();
  if (!node) return expr;
  auto a = CasSimplify(node->a(), var_intervals);
  auto b = CasSimplify(node->b(), var_intervals);
  auto* ai = a.As<IntImm>();
  auto* bi = b.As<IntImm>();
  auto* av = a.As<_Var_>();
  auto* bv = b.As<_Var_>();
  // case: y / 4, y\in[0,3]
  if (bi && av) {
    auto it = var_intervals.find(av->name);
    if (it != var_intervals.end() &&
        std::abs(it->second.r) < std::abs(bi->value) &&
        std::abs(it->second.l) < std::abs(bi->value))
      return make_const(a.type(), 0);
  }
  // case: 1/y, y\in(2, 100)
  if (ai && bv) {
    auto it = var_intervals.find(bv->name);
    auto ai_abs = std::abs(ai->value);
    if (it != var_intervals.end()) {
      VLOG(7) << "found " << bv->name << " " << it->second << " "
              << " ai " << ai_abs;
      // The whole interval has to be strictly positive or strictly negative,
      // otherwise small divisors (or zero) are still possible.
      const bool one_sided = it->second.l > 0 || it->second.r < 0;
      if (one_sided && std::abs(it->second.r) > ai_abs &&
          std::abs(it->second.l) > ai_abs) {
        return make_const(a.type(), 0);
      }
    }
  }
  return expr;
}
// Folds a fraction whose numerator is an integer, unsigned integer or float
// immediate into a single constant of the numerator's type. The denominator
// must be an immediate of the same kind as the numerator (checked). Integer
// kinds use integer division. Any other numerator kind is not implemented.
Expr SimplifyConstantFrac(FracOp* node) {
  auto* ai = node->a().As<ir::IntImm>();
  auto* au = node->a().As<ir::UIntImm>();
  auto* af = node->a().As<ir::FloatImm>();
  if (ai) {
    auto* bi = node->b().As<ir::IntImm>();
    CHECK(bi);
    return make_const(ai->type(), ai->value / bi->value);
  }
  if (au) {
    auto* bu = node->b().As<ir::UIntImm>();
    CHECK(bu);
    return make_const(au->type(), au->value / bu->value);
  }
  if (af) {
    auto* bf = node->b().As<ir::FloatImm>();
    // Validate the denominator (not the numerator, which is known to be
    // non-null here) before it is dereferenced.
    CHECK(bf);
    return make_const(af->type(), af->value / bf->value);
  }
  CINN_NOT_IMPLEMENTED
  return Expr();
}
/**
 * Simplify a FracOp (division) node.
 *
 * Both operands are simplified first. The rewrites below are then tried in
 * order and the first one that applies produces the result:
 *  1. constant / constant is folded;
 *  2. sum / int or product / int is divided exactly when divisible;
 *  3. (k * p + v) / c is split into (k * p) / c + v / c when |c| % |k| == 0
 *     and both interval ends of variable v are smaller than |k| in magnitude;
 *  4. min(a, b) / c and max(a, b) / c with a positive constant c distribute
 *     the division over the two operands;
 *  5. var / int folds to 0 when both interval ends of the variable are
 *     smaller than |int| in magnitude;
 *  6. (32x + y) / 32 pulls the divisible terms out of the fraction;
 *  7. x / x on integer variables folds to 1.
 *
 * @param expr An expression holding a FracOp node.
 * @return The simplified expression, or the rebuilt FracOp when no rewrite
 *         applies.
 */
Expr CasSimplifyMutator::SimplifyFracOp(Expr expr) {
  VLOG(7) << "CAS simplify Frac " << expr;
  auto* node = expr.As<FracOp>();
  auto a = CasSimplify(node->a(), var_intervals);
  auto b = CasSimplify(node->b(), var_intervals);
  // update frac op node
  expr = ir::FracOp::Make(a, b);
  node = expr.As<FracOp>();

  // ap / bp are only referenced by the disabled product / product reduction
  // further down.
  auto* ap = a.As<Product>();
  auto* bp = b.As<Product>();
  auto* as = a.As<Sum>();
  auto* bi = b.As<IntImm>();
  auto* bf = b.As<FloatImm>();
  auto* av = a.As<_Var_>();
  auto* bv = b.As<_Var_>();

  // case 1
  // integer constant division: 64/3
  if (node->is_constant()) {
    if (int_compute_) {
      return SimplifyConstantFrac(node);
    } else {
      return SimplifyRationalNumber(expr);
    }
  }

  // case 2
  // sum/x or product/x is divisible
  if (bi) {
    auto* a_sum = a.As<Sum>();
    auto* a_product = a.As<Product>();
    // divisible
    if (a_sum && IsDivisible(a_sum, bi->value)) return Divide(a_sum, bi->value);
    if (a_product) {
      if (IsDivisible(a_product, bi->value) ||
          IsDivisible(bi->value, a_product)) {
        return Divide(a_product, bi->value);
      } else {
        return FracOp::Make(a, b);
      }
    }

    // if 0 < b < 3, (3a+b) / 6 = (3a / 6) + (b / 6)
    if (a_sum && a_sum->operands().size() == 2) {
      a_sum->operands()[0] = CasSimplify(a_sum->operands()[0], var_intervals);
      auto sum_a_prod = a_sum->operands()[0].As<Product>();
      auto sum_b_var = a_sum->operands()[1].As<_Var_>();
      if (sum_a_prod && sum_b_var && var_intervals.count(sum_b_var->name)) {
        // Move an integer coefficient found in slot 1 to slot 0.
        auto sum_a_prod_b_int = sum_a_prod->operand(1).As<IntImm>();
        if (sum_a_prod_b_int)
          std::swap(sum_a_prod->operand(0), sum_a_prod->operand(1));
        auto sum_a_prod_a_int = sum_a_prod->operand(0).As<IntImm>();
        // The coefficient must exist (neither product operand may be an
        // integer immediate) and be non-zero before it is dereferenced and
        // used as a modulus.
        if (sum_a_prod_a_int && sum_a_prod_a_int->value != 0) {
          auto& interval = var_intervals.at(sum_b_var->name);
          int b_abs = std::abs(bi->value);
          int sum_prod_a_abs = std::abs(sum_a_prod_a_int->value);
          if (b_abs % sum_prod_a_abs == 0 &&
              std::abs(interval.l) < sum_prod_a_abs &&
              std::abs(interval.r) < sum_prod_a_abs) {
            return CasSimplify(
                Sum::Make({CasSimplify(FracOp::Make(a_sum->operands()[0], b),
                                       var_intervals),
                           CasSimplify(FracOp::Make(a_sum->operands()[1], b),
                                       var_intervals)}),
                var_intervals);
          }
        }
      }
    }

    // not divisible
    /*
    if (a_sum) {
      auto expr = DividePartially(a_sum, bi->value);
      return expr;
    }
    */
  }

  // cinn_min/cinn_max(a, b)/2 = cinn_min/cinn_max(a/2, b/2)
  if ((bi && bi->value > 0) || (bf && bf->value > 0)) {
    auto cmp_min = a.As<Min>();
    auto cmp_max = a.As<Max>();
    if (cmp_min) {
      return {CasSimplify(
          Min::Make(CasSimplify(FracOp::Make(cmp_min->a(), b), var_intervals),
                    CasSimplify(FracOp::Make(cmp_min->b(), b), var_intervals)),
          var_intervals)};
    }
    if (cmp_max) {
      return {CasSimplify(
          Max::Make(CasSimplify(FracOp::Make(cmp_max->a(), b), var_intervals),
                    CasSimplify(FracOp::Make(cmp_max->b(), b), var_intervals)),
          var_intervals)};
    }
  }

  // var / int with a known interval: 0 when the variable is always smaller
  // than the divisor in magnitude, otherwise left as a plain fraction.
  if (av && bi) {
    if (var_intervals.count(av->name)) {
      auto& interval = var_intervals.at(av->name);
      int b_abs = std::abs(bi->value);
      if (std::abs(interval.l) < b_abs && std::abs(interval.r) < b_abs)
        return make_const(bi->type(), 0);
      return FracOp::Make(a, b);
    }
  }

  // (32x+y)/32 = x + y/32
  if (as && bi) {
    std::vector<Expr> external_sum_args;  // terms divided exactly
    std::vector<Expr> internal_sum_args;  // terms left under the fraction
    for (auto& e : as->operands()) {
      if (IsDivisible(e, bi->value)) {
        if (e.As<Sum>())
          external_sum_args.push_back(Divide(e.As<Sum>(), bi->value));
        if (e.As<IntImm>())
          external_sum_args.push_back(
              make_const(bi->type(), e.As<IntImm>()->value / bi->value));
        if (e.As<Product>())
          external_sum_args.push_back(Divide(e.As<Product>(), bi->value));
      } else {
        internal_sum_args.push_back(e);
      }
    }

    Expr external_sum, internal_sum;
    if (!external_sum_args.empty()) {
      if (external_sum_args.size() == 1)
        external_sum = external_sum_args.front();
      else
        external_sum = Sum::Make(external_sum_args);
    }
    if (!internal_sum_args.empty()) {
      internal_sum = FracOp::Make(Sum::Make(internal_sum_args), b);
    }
    if (external_sum.defined() && internal_sum.defined()) {
      return CasSimplify(Sum::Make({external_sum, internal_sum}),
                         var_intervals);
    }
    if (external_sum.defined()) return CasSimplify(external_sum, var_intervals);
    return internal_sum;
  }

  // solve the case: 2abc / b
  // Both avs and bvs should be sorted first.
  // NOTE: currently not invoked; see the disabled block right after it.
  auto reduce_product_div_product = [](const std::vector<Expr>& avs,
                                       const std::vector<Expr>& bvs) {
    std::vector<Expr> avs1, bvs1;
    size_t i = 0;
    size_t j = 0;
    ExprPosCmp cmp;

    while (i < avs.size() && j < bvs.size()) {
      auto& a = avs[i];
      auto& b = bvs[j];
      if (a.is_constant() && b.is_constant()) {
        auto* ai = a.As<IntImm>();
        auto* bi = b.As<IntImm>();
        auto* af = a.As<FloatImm>();
        auto* bf = b.As<FloatImm>();
        if (ai) {
          CHECK(bi);
          int g = gcd(ai->value, bi->value);
          int a_d = ai->value / g;
          int b_d = bi->value / g;

          avs1.push_back(make_const(a.type(), a_d));
          if (b_d != 1) bvs1.push_back(make_const(b.type(), b_d));
        } else if (af && bf) {
          // Both sides must be float immediates before either is read.
          double value = af->value / bf->value;
          avs1.push_back(make_const(af->type(), value));
        } else {
          avs1.push_back(a);
          bvs1.push_back(b);
        }

        // CHECK(!af) << a << " " << b;
        i++;
        j++;
      } else if (avs[i] == bvs[j]) {
        i++;
        j++;
      } else {
        // <
        if (cmp(avs[i], bvs[j])) {
          avs1.push_back(avs[i++]);
        } else {
          bvs1.push_back(bvs[j++]);
        }
      }
    }
    while (i < avs.size()) {
      avs1.push_back(avs[i++]);
    }
    while (j < bvs.size()) {
      bvs1.push_back(bvs[j++]);
    }
    if (avs1.empty()) return make_const(avs[0].type(), 1);
    if (bvs1.empty()) return Product::Make(avs1);

    return FracOp::Make(Product::Make(avs1), Product::Make(bvs1));
  };

  {
    // TODO(SunNy820828449): fix in future.
    // std::vector<Expr> a_args, b_args;
    // if (ap)
    //   a_args = ap->operands();
    // else
    //   a_args.push_back(a);
    // if (bp)
    //   b_args = bp->operands();
    // else
    //   b_args.push_back(b);
    // return reduce_product_div_product(a_args, b_args);
  }

  // x / x
  if (a.type().is_int() && b.type().is_int() && av && bv) {
    if (a == b) return make_const(a.type(), 1);
  }

  if (node->a().same_as(a) && node->b().same_as(b)) return expr;
  return FracOp::Make(a, b);
}
/**
 * Simplify a logical expression (Not / And / Or).
 *
 * - Not: the operand is simplified first; an integer immediate is folded, a
 *   double negation is removed, and a negated comparison is replaced by the
 *   opposite comparison.
 * - And / Or: both operands are simplified; when at least one of them is a
 *   constant the expression is short-circuited where possible.
 *
 * @param u The condition to simplify.
 * @return The simplified condition; any other node kind is returned as is.
 */
Expr CasSimplifyMutator::SimplifyCond(Expr u) {
  switch (u->node_type()) {
      // -------------------------- NOT -----------------------------
    case ir::IrNodeTy::Not: {
      auto* node = u.As<ir::Not>();
      Expr v = operator()(node->v());
      switch (v.node_type()) {
          // Not 1 = (1 == 0)
        case ir::IrNodeTy::IntImm:
          return Expr(v.As<IntImm>()->value == 0);
          // Not Not w = w. Here v is the inner Not node, so the result is
          // v's own operand; returning v itself would leave one negation.
        case ir::IrNodeTy::Not:
          return v.As<ir::Not>()->v();
          // Not <= is >
        case ir::IrNodeTy::LE:
          return ir::GT::Make(v->operand(0), v->operand(1));
          // Not < is >=
        case ir::IrNodeTy::LT:
          return ir::GE::Make(v->operand(0), v->operand(1));
          // Not >= is <
        case ir::IrNodeTy::GE:
          return ir::LT::Make(v->operand(0), v->operand(1));
          // Not > is <=
        case ir::IrNodeTy::GT:
          return ir::LE::Make(v->operand(0), v->operand(1));
        default:
          return ir::Not::Make(v);
      }
    } break;
      // -------------------------- AND OR -----------------------------
    case ir::IrNodeTy::And:
    case ir::IrNodeTy::Or: {
      Expr a = operator()(u->operand(0));
      Expr b = operator()(u->operand(1));
      if (a.is_constant() || b.is_constant()) {
        if (u.As<ir::And>()) {
          // 1 && b is b
          if (a.As<ir::UIntImm>()) {
            return a.As<ir::UIntImm>()->value ? b : Expr(false);
          }
          // a && 1 is a
          if (b.As<ir::UIntImm>()) {
            return b.As<ir::UIntImm>()->value ? a : Expr(false);
          }
          return ir::And::Make(a, b);
        }
        if (u.As<ir::Or>()) {
          // 1 || b is 1
          if (a.As<ir::UIntImm>()) {
            return a.As<ir::UIntImm>()->value ? a : b;
          }
          // a || 1 is 1
          if (b.As<ir::UIntImm>()) {
            return b.As<ir::UIntImm>()->value ? b : a;
          }
        }
        return ir::Or::Make(a, b);
      }
      // Neither operand is constant: the input is returned unchanged.
      return u;
    }
    default:
      return u;
  }
}
} // namespace detail
//! Run the CAS simplification mutator over \p u, using \p var_intervals as
//! the known variable ranges, and return the result.
Expr CasSimplify(
    Expr u,
    const absl::flat_hash_map<std::string, CasInterval>& var_intervals) {
  detail::CasSimplifyMutator simplifier(var_intervals);
  return simplifier(u);
}
/**
 * Solve an inequality for one variable.
 *
 * The input is simplified and must then be an LE, LT, GT or GE node. Its two
 * sides are handed to common::Solve; the solution is simplified and cast to
 * the type of \p val. The comparison keeps its direction when Solve reports a
 * positive result flag and is mirrored otherwise.
 *
 * @param inequality The inequality, e.g. 2x - 1 < 3.
 * @param val The variable to solve for.
 * @return A comparison with \p val on the left-hand side, e.g. x < 2.
 */
Expr SolveInequality(Expr inequality, Var val) {
  auto copied = AutoSimplify(inequality);

  auto* le_n = copied.As<ir::LE>();
  auto* lt_n = copied.As<ir::LT>();
  auto* gt_n = copied.As<ir::GT>();
  auto* ge_n = copied.As<ir::GE>();

  // Pick up both sides from whichever comparison node matched.
  Expr a, b;
  if (le_n) {
    a = le_n->a();
    b = le_n->b();
  }
  if (lt_n) {
    a = lt_n->a();
    b = lt_n->b();
  }
  if (gt_n) {
    a = gt_n->a();
    b = gt_n->b();
  }
  if (ge_n) {
    a = ge_n->a();
    b = ge_n->b();
  }
  CHECK(a.defined() && b.defined())
      << "SolveInequality expects an LE/LT/GT/GE expression, but got "
      << copied;

  auto _res_positive_ = common::Solve(a, b, val);  // NOLINT
  auto& res = std::get<0>(_res_positive_);
  auto& positive = std::get<1>(_res_positive_);
  // Simplify it with CAS to avoid random result from GiNac.
  res = AutoSimplify(res);
  res = common::cast(res, val->type());

  if (le_n) {
    if (positive) return ir::LE::Make(val, res);
    return ir::GE::Make(val, res);
  }
  if (lt_n) {
    if (positive) return ir::LT::Make(val, res);
    return ir::GT::Make(val, res);
  }
  if (ge_n) {
    if (positive) return ir::GE::Make(val, res);
    return ir::LE::Make(val, res);
  }
  if (gt_n) {
    if (positive) return ir::GT::Make(val, res);
    return ir::LT::Make(val, res);
  }
  return Expr();
}
} // namespace common
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <absl/container/flat_hash_map.h>
#include <functional>
#include <string>
#include <vector>
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/optim/ir_simplify.h"
namespace cinn {
namespace common {
namespace detail {
// Declared ahead of CasInterval, whose Expr-based constructor applies
// ReplaceMinToConstant to the upper bound and ReplaceMaxToConstant to the
// lower bound before simplifying them.
// NOTE(review): presumably each picks the constant operand of a Min/Max that
// mixes a constant with a non-constant value -- confirm in the .cc file.
Expr ReplaceMinToConstant(Expr expr);
Expr ReplaceMaxToConstant(Expr expr);
} // namespace detail
/**
 * Range of values a _Var_ may take.
 *
 * The range is held either numerically in (l, r) or symbolically in
 * (e_l, e_r), depending on which constructor was used and on whether the
 * bounds could be reduced to constants.
 */
struct CasInterval {
  //! Numeric interval [l, r]; the check fails when l > r.
  template <typename T>
  CasInterval(T l, T r) : l(l), r(r) {
    CHECK_LE(l, r) << "left should not be larger than right";
  }

  /**
   * @brief Interval given by two bound expressions.
   *
   * The upper bound goes through ReplaceMinToConstant and the lower bound
   * through ReplaceMaxToConstant, then both are simplified. For example,
   * with expr_l = max(x, 1) and expr_r = min(y, 5) the range
   * max(x, 1) <= i <= min(y, 5) is relaxed to 1 <= i <= 5.
   *
   * When both bounds end up constant they must be integers and are stored in
   * (l, r); otherwise the expressions are kept in (e_l, e_r) and (l, r) are
   * not assigned.
   */
  CasInterval(Expr expr_l, Expr expr_r) {
    VLOG(2) << "CasInterval is : [" << expr_l << ", " << expr_r << "].";
    expr_r = detail::ReplaceMinToConstant(expr_r);
    expr_l = detail::ReplaceMaxToConstant(expr_l);
    optim::Simplify(&expr_l);
    optim::Simplify(&expr_r);
    VLOG(2) << "After simplify, CasInterval is : [" << expr_l << ", " << expr_r
            << "].";

    const bool both_constant = expr_l.is_constant() && expr_r.is_constant();
    if (!both_constant) {
      // Symbolic bounds: keep the expressions.
      e_l = expr_l;
      e_r = expr_r;
      return;
    }

    CHECK(expr_l->type().is_integer());
    CHECK(expr_r->type().is_integer());
    l = expr_l.as_int32();
    r = expr_r.as_int32();
  }

  // Numeric bounds.
  int l;
  int r;
  // Note: not verify l <= r and (e_l, e_r) has higher priority than (l, r)
  Expr e_l;
  Expr e_r;

  //! Prints the symbolic form when both e_l and e_r are defined, otherwise
  //! the numeric form.
  friend std::ostream& operator<<(std::ostream& os, const CasInterval& i) {
    const bool symbolic = i.e_l.defined() && i.e_r.defined();
    if (!symbolic) {
      return os << "Int l Interval[" << i.l << ", " << i.r << "]";
    }
    return os << "Expr e_l Interval[" << i.e_l << ", " << i.e_r << "]";
  }
};
//! Map from variable name to the interval known for that variable.
using cas_intervals_t = absl::flat_hash_map<std::string, CasInterval>;
//! Simplify an expression, optionally using known variable intervals.
//! NOTE(review): judging by the unit tests this accepts ordinary CINN
//! arithmetic (Add, Mul, Div, ...) and returns CINN arithmetic again, whereas
//! CasSimplify works on the CAS form (Sum, Product, ...) -- confirm in cas.cc.
Expr AutoSimplify(
Expr u,
const absl::flat_hash_map<std::string, CasInterval>& var_intervals = {});
//! Simplify a CAS expression.
Expr CasSimplify(
Expr u,
const absl::flat_hash_map<std::string, CasInterval>& var_intervals = {});
/**
* \brief Solve an inequality.
* Currently this is a naive implementation using GiNaC.
*
* @param inequality The inequality expression containing an LE or LT or GT or
* GE, such as 2x-1<3
* @param val The target variable.
* @return a new expression that looks like x < 100.
*/
Expr SolveInequality(Expr inequality, Var val);
//! NOTE(review): presumably the integer-only variant of SolveInequality; the
//! definition is not visible here -- confirm before relying on it.
Expr SolveInequalityInt(Expr inequality, Var val);
namespace detail {
//! Whether to treat this expression as a symbol. e.g. Load, Min, Max are
//! treated as symbol to avoid confusing the CAS.
bool CASasSymbol(Expr expr);
//! Convert some nodes to CAS representation, e.g. convert Mul, Add to Product
//! and Sum.
Expr ConvertCinnToCAS(Expr expr);
//! Convert the CAS representation to CINN expression, e.g. convert Product and
//! Sum to Mul and Add.
Expr ConvertCasToCinn(Expr expr);
//! Tell whether this expression is acceptable by CAS.
bool IsExprCasCompatible(Expr expr);
//! Ordering predicate over expressions: operator() tells whether \p a is
//! positioned before \p b. The unit tests expect, for instance, that a
//! constant orders before a variable. The definition lives outside this
//! header.
struct ExprPosCmp {
bool operator()(const Expr& a, const Expr& b);
};
/**
 * Mutator that applies the CAS simplification rules to an expression.
 *
 * An instance holds its own copy of the variable intervals it was built with;
 * operator() is the entry point and the Simplify* members handle the
 * individual node kinds.
 */
struct CasSimplifyMutator {
  //! Taken by const reference so that only the member initialisation copies
  //! the map (a by-value parameter would copy it twice).
  explicit CasSimplifyMutator(
      const absl::flat_hash_map<std::string, CasInterval>& var_intervals)
      : var_intervals(var_intervals) {}

  //! Simplify \p u and return the result.
  Expr operator()(Expr u);

  // Node-kind specific simplifications. The name indicates the node kind
  // handled; see the definitions for the exact rules.
  Expr SimplifyRationalNumber(Expr u);
  Expr SimplifyPower(Expr u);
  Expr SimplifySum(Expr u);
  Expr SimplifyProduct(Expr a);
  Expr SimplifyMinAndMax(Expr a);
  Expr SimplifyCmp(Expr a);
  std::vector<Expr> SimplifyProductRec(const std::vector<Expr>& operands);
  std::vector<Expr> SimplifySumRec(const std::vector<Expr>& operands);
  Expr SimplifyMod(Expr u);
  //! Simplify a division node.
  Expr SimplifyFracOp(Expr expr);
  //! Simplify a logical Not / And / Or node.
  Expr SimplifyCond(Expr u);

  Expr FurtherSimplifyFracWithInterval(
      Expr expr,
      const absl::flat_hash_map<std::string, CasInterval>& var_intervals);
  Expr SimplifyIntegerPower(Expr u);

  // Bound helpers.
  // NOTE(review): presumably these derive lower/upper bounds of an expression
  // from var_intervals and report success through the bool result -- confirm
  // against the definitions.
  void AddBaseAndSimplify(Expr* base, Expr bound);
  void UnfoldBound(Expr* lower_bound,
                   Expr* upper_bound,
                   Expr var,
                   bool unfold_const_bound = true);
  bool GetVarBound(Expr* lower_bound,
                   Expr* upper_bound,
                   Expr var,
                   bool unfold_const_bound = true);
  bool GetOperandBound(Expr* lower_bound,
                       Expr* upper_bound,
                       Expr var,
                       bool unfold_const_bound = true);
  bool GetSumBound(Expr* lower_bound,
                   Expr* upper_bound,
                   Expr sum,
                   bool unfold_const_bound = true);
  bool GetMinBound(Expr* lower_bound,
                   Expr* upper_bound,
                   Expr min,
                   bool unfold_const_bound = true);
  bool GetMaxBound(Expr* lower_bound,
                   Expr* upper_bound,
                   Expr max,
                   bool unfold_const_bound = true);
  bool GetExprBound(Expr* lower_bound,
                    Expr* upper_bound,
                    Expr min,
                    bool unfold_const_bound = true);
  bool SimplifySpecificSumMod(Expr* u, Expr a, Expr b);
  Expr SimplifySpecificSum(Expr u);

 private:
  std::vector<Expr> SimplifyBinaryProduct(Expr left, Expr right);
  std::vector<Expr> MergeProduct(const std::vector<Expr>& p,
                                 const std::vector<Expr>& q);

  std::vector<Expr> SimplifyBinarySum(Expr left, Expr right);
  std::vector<Expr> MergeSum(const std::vector<Expr>& p,
                             const std::vector<Expr>& q);
  std::vector<Expr> MergeExprs(
      const std::vector<Expr>& p,
      const std::vector<Expr>& q,
      const std::function<std::vector<Expr>(Expr, Expr)>& binary_merge);

  //! Known variable ranges, keyed by variable name (owned copy).
  const absl::flat_hash_map<std::string, CasInterval> var_intervals;

  // Computation based on integer if set true(1/2 get 0), false if treat as
  // rational number in mathematics(1/2 is still 1/2), currently it only works
  // with true.
  bool int_compute_{true};
};
} // namespace detail
} // namespace common
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/common/cas.h"
#include <gtest/gtest.h>
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/utils/string.h"
namespace cinn {
namespace common {
using common::make_const;
using utils::GetStreamCnt;
using utils::Join;
using utils::Trim;
using namespace ir; // NOLINT
// Smoke test: builds 1 * 100 * -1 + 0 + 1001 as a Sum over a Product and
// logs it. Nothing is asserted.
TEST(CAS, number_cal) {
  auto product = Product::Make({Expr(1), Expr(100), Expr(-1)});
  auto total = Sum::Make({product, Expr(0), Expr(1001)});
  LOG(INFO) << total;
}
// Checks the ordering produced by detail::ExprPosCmp.
TEST(CAS, cmp) {
  detail::ExprPosCmp before;

  Var x = ir::_Var_::Make("x", Int(32));
  Var y = ir::_Var_::Make("y", Int(32));
  Var z = ir::_Var_::Make("z", Int(32));

  // A constant orders before a variable, not the other way round.
  EXPECT_FALSE(before(x, Expr(1)));
  EXPECT_TRUE(before(Expr(1), x));

  // x * y * z > x * y
  EXPECT_FALSE(
      before(ir::Product::Make({x, y, z}), ir::Product::Make({x, y})));

  // x * y * z > 10 * y * z
  EXPECT_FALSE(before(ir::Product::Make({x, y, z}),
                      ir::Product::Make({Expr(10), y, z})));

  // 1 * y * z < 10 * y * z
  EXPECT_TRUE(before(ir::Product::Make({Expr(1), y, z}),
                     ir::Product::Make({Expr(10), y, z})));
}
// Sum simplification: zero terms vanish, constants are merged, operands are
// put into canonical order and opposite terms cancel.
TEST(CAS, SimplifySum) {
  Var x = ir::_Var_::Make("x", Int(32));
  Var y = ir::_Var_::Make("y", Int(32));
  Var z = ir::_Var_::Make("z", Int(32));

  // x + y + z + 0
  auto with_zero = Sum::Make({x, y, z, make_const(0)});
  EXPECT_EQ(GetStreamCnt(CasSimplify(with_zero)), "(x + y + z)");

  // x*1 + y + z + 0
  auto with_unit_factor =
      Sum::Make({Product::Make({x, Expr(1)}), y, z, make_const(0)});
  EXPECT_EQ(GetStreamCnt(CasSimplify(with_unit_factor)), "(x + y + z)");

  // z + 1 + y + x + zx
  auto reordered =
      CasSimplify(Sum::Make({z, Expr(1), y, x, Product::Make({z, x})}));
  EXPECT_EQ(GetStreamCnt(reordered), "(1 + x + y + z + (x * z))");

  // z + 1 + y + 3 + x + 0 + zx
  auto merged_constants = CasSimplify(
      Sum::Make({z, Expr(1), y, Expr(3), x, Expr(0), Product::Make({z, x})}));
  EXPECT_EQ(GetStreamCnt(merged_constants), "(4 + x + y + z + (x * z))");

  // x2 + 3zy + -3*yz + -2x + 1
  auto cancelled = CasSimplify(Sum::Make({Product::Make({x, Expr(2)}),
                                          Product::Make({z, y, Expr(3)}),
                                          Product::Make({Expr(-3), y, z}),
                                          Product::Make({Expr(-2), x}),
                                          Expr(1)}));
  EXPECT_EQ(GetStreamCnt(cancelled), "1");
}
// Product simplification moves the constant factor to the front and sorts
// the variables.
TEST(CAS, SimplifyProduct) {
  Var x = ir::_Var_::Make("x", Int(32));
  Var y = ir::_Var_::Make("y", Int(32));
  Var z = ir::_Var_::Make("z", Int(32));

  // zyx*(-1)
  auto negated_product = CasSimplify(Product::Make({z, y, x, Expr(-1)}));
  EXPECT_EQ(GetStreamCnt(negated_product), "(-1 * x * y * z)");
}
// Basic Mod simplification on CAS expressions.
TEST(CAS, SimplifyMod) {
  Var x = ir::_Var_::Make("x", Int(32));
  Var y = ir::_Var_::Make("y", Int(32));
  Var z = ir::_Var_::Make("z", Int(32));

  // 2*x % 2 = 0
  auto even_multiple =
      CasSimplify(Mod::Make(Product::Make({x, Expr(2)}), Expr(2)));
  EXPECT_EQ(GetStreamCnt(even_multiple), "0");

  // (x+y+z) % 2 is kept as is
  auto sum_mod = CasSimplify(Mod::Make(Sum::Make({x, y, z}), Expr(2)));
  EXPECT_EQ(GetStreamCnt(sum_mod), "((x + y + z) % 2)");

  // x%2 + 1%2 + x%2
  auto mod_terms = CasSimplify(Sum::Make({Mod::Make(x, Expr(2)),
                                          Mod::Make(Expr(1), Expr(2)),
                                          Mod::Make(x, Expr(2))}));
  EXPECT_EQ(GetStreamCnt(mod_terms), "1");
}
// Nested Mod chain as produced for vectorized indexing.
TEST(CAS, SimplifyModForVectorize) {
  Var x = ir::_Var_::Make("x", Int(32));
  Var y = ir::_Var_::Make("y", Int(32));

  // (((8*x + 1024*y) % 802816) % 7168) %64
  // = (8*x + 1024*y) %64         // since 7168 and 802816 is k*64
  // = (8*x) % 64                 // since 1024 is k*64
  // = (8*x - ((8*x) // 64) * 64  // since mod definition a%b = a - (a//b)*b
  // = (8*x) - (x//8)*64
  // = (8*x) - (x//8)*(8*8)
  // = 8*(x-(x//8)*8)             // since mod definition
  // = 8*(x%8)
  auto linear_index = Sum::Make(
      {Product::Make({x, Expr(8)}), Product::Make({y, Expr(1024)})});
  auto mod_802816 = Mod::Make(linear_index, Expr(802816));
  auto mod_7168 = Mod::Make(mod_802816, Expr(7168));
  auto simplified = CasSimplify(Mod::Make(mod_7168, Expr(64)));

  std::cout << GetStreamCnt(simplified);
  EXPECT_EQ(GetStreamCnt(simplified), "((x % 8) * 8)");
}
// Round trip: CINN expression -> CAS form -> simplified -> CINN expression.
TEST(CAS, ConvertCinnToCAS) {
  Placeholder<float> A("A", {10, 10});
  Placeholder<float> B("B", {10, 10});

  auto C = Compute(
      {Expr(10), Expr(10)},
      [&](Expr i, Expr j) {
        return A(i, j) + 0.f + 1.f + 2.f * B(i, j) + 0.f * B(i, j) * A(i, j);
      },
      "C");

  Expr tensor_body = C->body();
  LOG(INFO) << "body " << tensor_body;

  // CAS form after simplification: zero terms are gone, n-ary Sum printed.
  tensor_body = CasSimplify(detail::ConvertCinnToCAS(tensor_body));
  EXPECT_EQ(GetStreamCnt(tensor_body),
            "(1.00000000f + A[i, j] + (2.00000000f * B[i, j]))");

  // Back in CINN form the sum is nested binary Adds.
  tensor_body = detail::ConvertCasToCinn(tensor_body);
  EXPECT_EQ(GetStreamCnt(tensor_body),
            "(1.00000000f + (A[i, j] + (2.00000000f * B[i, j])))");
}
// Integer division handling in AutoSimplify, with and without intervals.
TEST(CAS, FracOp) {
  Var x = ir::_Var_::Make("x", Int(32));
  Var y = ir::_Var_::Make("y", Int(32));
  Var z = ir::_Var_::Make("z", Int(32));

  // (1 / x) * x is not cancelled for integers.
  auto reciprocal_times_x = AutoSimplify(Div::Make(Expr(1), x) * x);
  EXPECT_EQ(GetStreamCnt(reciprocal_times_x), "((1 / x) * x)");

  // 64x/32 + y + 64/32
  auto exact_divisions =
      AutoSimplify(Expr(64) * x / Expr(32) + y + Expr(64) / Expr(32));
  ASSERT_EQ(GetStreamCnt(exact_divisions), "(2 + ((2 * x) + y))");

  // 1/32 * y * z * 32768 * 2
  auto zero_factor = AutoSimplify(Expr(1) / Expr(32) * y * z * 32768 * 2);
  EXPECT_EQ(GetStreamCnt(zero_factor), "0");

  // 32768 * ((32x + y) / 32)
  auto scaled_split =
      AutoSimplify(Expr(32768) * (((Expr(32) * x) + y) / 32));
  EXPECT_EQ(GetStreamCnt(scaled_split), "((32768 * (y / 32)) + (32768 * x))");

  // With y in [0, 31] the remainder term y / 32 folds away.
  common::cas_intervals_t var_intervals;
  var_intervals.emplace("y", common::CasInterval(0, 31));

  auto with_interval = AutoSimplify((Expr(x) * 32 + y) / 32, var_intervals);
  EXPECT_EQ(GetStreamCnt(with_interval), "x");

  with_interval = AutoSimplify((Expr(x) * 33 + y) / 32, var_intervals);
  EXPECT_EQ(GetStreamCnt(with_interval), "(((33 * x) + y) / 32)");

  with_interval = AutoSimplify(Expr(125) / 8 - 1);
  EXPECT_EQ(GetStreamCnt(with_interval), "14");
}
// Test shorthand: expects that the printed form of the local variable `u`
// (which must be in scope at the point of use) equals s__. The expansion
// already ends with a semicolon.
#define OUTPUT_EQUAL(s__) EXPECT_EQ(GetStreamCnt(u), s__);
// Mod simplification through AutoSimplify, with and without intervals.
TEST(CAS, Mod) {
  Var x = ir::_Var_::Make("x", Int(32));
  Var y = ir::_Var_::Make("y", Int(32));
  Var z = ir::_Var_::Make("z", Int(32));
  Var k = ir::_Var_::Make("k", Int(32));

  // Every variable is limited to [0, 3] in var_intervals0.
  absl::flat_hash_map<std::string, CasInterval> var_intervals0, var_intervals1;
  var_intervals0.emplace("x", CasInterval{0, 3});
  var_intervals0.emplace("y", CasInterval{0, 3});
  var_intervals0.emplace("z", CasInterval{0, 3});
  var_intervals0.emplace("k", CasInterval{0, 3});

  Expr u;

  u = AutoSimplify(x % 5);
  EXPECT_EQ(GetStreamCnt(u), "(x % 5)");
  EXPECT_EQ(GetStreamCnt(u), "(x % 5)");

  u = AutoSimplify((5 + x) % 5);
  EXPECT_EQ(GetStreamCnt(u), "(x % 5)");

  u = AutoSimplify((x + 5 * y + 1 + 1 + 3 - z * 3) % 5);
  EXPECT_EQ(GetStreamCnt(u), "((x + (-3 * z)) % 5)");

  // u = AutoSimplify((x + 5) % 5, var_intervals0);
  // OUTPUT_EQUAL("x")

  // u = AutoSimplify((x + y + 5) % 5, var_intervals0);
  // OUTPUT_EQUAL("((x + y) % 5)")

  // u = AutoSimplify((x + 20 * y + 5) % 5, var_intervals0);
  // OUTPUT_EQUAL("x")

  u = AutoSimplify(
      (x % 32) + ((32768 * (x / 32)) + ((32768 * y) + ((32 * z) + (128 * k)))));
  EXPECT_EQ(
      GetStreamCnt(u),
      "((32768 * (x / 32)) + ((x % 32) + ((128 * k) + ((32768 * y) + (32 * "
      "z)))))");

  u = AutoSimplify(
      (x % 32) + ((32768 * (x / 32)) + ((32768 * y) + ((32 * z) + (128 * k)))),
      var_intervals0);
  EXPECT_EQ(GetStreamCnt(u), "((128 * k) + (x + ((32768 * y) + (32 * z))))");

  // (2x+y+z) % 2 = (y+z) % 2
  u = AutoSimplify((2 * x + y + z) % 2, var_intervals0);
  EXPECT_EQ(GetStreamCnt(u), "((y + z) % 2)");

  // 0 % x = 0
  u = AutoSimplify(0 % x);
  EXPECT_EQ(GetStreamCnt(u), "0");

  // 1 % x = 1
  u = AutoSimplify(1 % x);
  EXPECT_EQ(GetStreamCnt(u), "1");

  // (x * 6) % 2 = 0
  u = AutoSimplify((x * 6) % 2);
  EXPECT_EQ(GetStreamCnt(u), "0");

  // (x * 2) % 6 = (x % 3) * 2
  u = AutoSimplify((x * 2) % 6);
  EXPECT_EQ(GetStreamCnt(u), "((x % 3) * 2)");

  // 7 % 3 = 1
  u = AutoSimplify(Expr(7) % Expr(3));
  EXPECT_EQ(GetStreamCnt(u), "1");

  // x % 1 = 0
  u = AutoSimplify(x % 1);
  EXPECT_EQ(GetStreamCnt(u), "0");

  // (m / n) * n + m % n = m (m, n's type is int)
  u = AutoSimplify((x / 10) * 10 + x % 10);
  EXPECT_EQ(GetStreamCnt(u), "x");

  u = AutoSimplify(((x + y * 2) / 10) * 10 + (x + y * 2) % 10 + 3 * z);
  EXPECT_EQ(GetStreamCnt(u), "(x + ((2 * y) + (3 * z)))");
}
// Corner cases of integer division.
TEST(CAS, IntConnerCase) {
  Var x = ir::_Var_::Make("x", Int(32));
  Var y = ir::_Var_::Make("y", Int(32));
  Var z = ir::_Var_::Make("z", Int(32));

  // Constant fractions truncate towards zero.
  auto one_over_32 = AutoSimplify(Expr(1) / 32);
  EXPECT_EQ(GetStreamCnt(one_over_32), "0");

  auto partial_split = AutoSimplify(x / 32 + (x * 32 + 64) / 32);
  EXPECT_EQ(GetStreamCnt(partial_split), "((x / 32) + (2 + x))");

  // (32x+y)/32 * 1024 * 32
  auto scaled = AutoSimplify((((((32 * x) + y) / 32) * 1024) * 32));
  EXPECT_EQ(GetStreamCnt(scaled), "((32768 * (y / 32)) + (32768 * x))");

  auto one_over_3 = AutoSimplify(Expr(1) / 3);
  EXPECT_EQ(GetStreamCnt(one_over_3), "0");

  // y in [2, 3] versus y in [0, 3].
  absl::flat_hash_map<std::string, CasInterval> var_intervals0, var_intervals1;
  var_intervals0.emplace("y", CasInterval{2, 3});
  var_intervals1.emplace("y", CasInterval{0, 3});

  auto one_over_y = AutoSimplify(Expr(1) / y, var_intervals0);
  EXPECT_EQ(GetStreamCnt(one_over_y), "0");

  auto y_over_4 = AutoSimplify(y / 4, var_intervals0);
  EXPECT_EQ(GetStreamCnt(y_over_4), "0");

  // With a lower bound of 0 the fraction cannot be folded.
  auto one_over_y_wide = AutoSimplify(1 / y, var_intervals1);
  EXPECT_EQ(GetStreamCnt(one_over_y_wide), "(1 / y)");

  auto minus_one_over_y_wide = AutoSimplify(-1 / y, var_intervals1);
  EXPECT_EQ(GetStreamCnt(minus_one_over_y_wide), "(-1 / y)");
}
// Solves several inequalities for x and compares the printed result.
TEST(SolveInequality, basic) {
Var x("x", Int(32));
Var y("y", Int(32));
// TEST_SOLVE(expr, str): expects SolveInequality(expr, x) to print as str.
// The macro refers to the local `x` and stays defined after this test.
#define TEST_SOLVE(expr__, str__) \
EXPECT_EQ(GetStreamCnt(SolveInequality(expr__, x)), str__);
// A negative coefficient on x mirrors the comparison.
TEST_SOLVE(x * -1 + 20 < 0, "(x > 20)");
TEST_SOLVE(x * 2 + 3 < x * 10 - 20, "(x > 2)");
TEST_SOLVE(x * -1 < -1, "(x > 1)");
TEST_SOLVE(Expr(2) * x * -1 - x < x + 200, "(x > -50)");
// Other variables stay symbolic; the result is cast to the type of x.
TEST_SOLVE(Expr(2) * x + 30 - x * 3 + y * 23 < 2,
"(x > int32((28 + (23 * y))))");
TEST_SOLVE(x + ir::Min::Make(Expr(2), Expr(3) * y) < 100,
"(x < int32(cinn_max((100 + (-3 * y)), 98)))");
}
// Mod nested inside products and sums.
TEST(CAS, SimplifyCompoundMod) {
  {  // (-x % 4) * (-1)
    Var x = ir::_Var_::Make("x", Int(32));
    auto input = ir::Product::Make({ir::Mod::Make(-x, Expr(4)), Expr(-1)});
    LOG(INFO) << "p0 " << input;
    auto output = AutoSimplify(input);
    LOG(INFO) << "simplified " << output;
    EXPECT_EQ(GetStreamCnt(output), "(-1 * ((-1 * x) % 4))");
  }
  {  // 33 + (x % 4 + (-33))
    Var x = ir::_Var_::Make("x", Int(32));
    auto inner = ir::Sum::Make({ir::Mod::Make(x, Expr(4)), Expr(-33)});
    auto input = ir::Sum::Make({Expr(33), inner});
    LOG(INFO) << "p0 " << input;
    auto output = AutoSimplify(input);
    LOG(INFO) << "simplified " << output;
    EXPECT_EQ(GetStreamCnt(output), "(x % 4)");
  }
  {  // 33 + (x % 2 + (-1 * 16))
    Var x = ir::_Var_::Make("x", Int(32));
    auto inner = ir::Sum::Make({ir::Mod::Make(x, Expr(2)),
                                ir::Product::Make({Expr(-1), Expr(16)})});
    auto input = ir::Sum::Make({Expr(33), inner});
    LOG(INFO) << "p0 " << input;
    auto output = AutoSimplify(input);
    LOG(INFO) << "simplified " << output;
    EXPECT_EQ(GetStreamCnt(output), "(17 + (x % 2))");
  }
  {  // (32 - x1 - 16 * x2) % 33
    Var x1 = ir::_Var_::Make("x1", Int(32));
    Var x2 = ir::_Var_::Make("x2", Int(32));
    auto input =
        ir::Mod::Make(ir::Sum::Make({Expr(32), -x1, Expr(16) * -x2}), Expr(33));
    LOG(INFO) << "p0 " << input;

    absl::flat_hash_map<std::string, CasInterval> var_intervals;
    var_intervals.emplace("x1", CasInterval{0, 15});
    var_intervals.emplace("x2", CasInterval{0, 1});

    auto output = AutoSimplify(input, var_intervals);
    LOG(INFO) << "simplified " << output;
    // The expected form depends on whether CINN is built with CUDA.
#ifdef CINN_WITH_CUDA
    EXPECT_EQ(GetStreamCnt(output), "((32 + ((-1 * x1) + (-16 * x2))) % 33)");
#else
    EXPECT_EQ(GetStreamCnt(output), "(32 + (((-1 * x1) + (-16 * x2)) % 33))");
#endif
  }
}
// Negation inside a fraction and on a constant.
TEST(CAS, SimplifyNegtive) {
  {  // (-1*x) /2
    Var x = ir::_Var_::Make("x", Int(32));
    auto input = ir::FracOp::Make(-x, Expr(2));
    LOG(INFO) << "p0 " << input;
    auto output = AutoSimplify(input);
    LOG(INFO) << "simplified " << output;
    EXPECT_EQ(GetStreamCnt(output), "((-1 * x) / 2)");
  }
  {  // minus(1)
    auto input = ir::Minus::Make(Expr(1));
    LOG(INFO) << "p0 " << input;
    auto output = AutoSimplify(input);
    LOG(INFO) << "simplified " << output;
    EXPECT_EQ(GetStreamCnt(output), "-1");
  }
}
// Arithmetic is distributed over cinn_min / cinn_max.
TEST(CAS, SimplifyMinMax) {
  {  // 1+cinn_min(15, x)
    Var x = ir::_Var_::Make("x", Int(32));
    auto input = ir::Sum::Make({Expr(1), ir::Min::Make(Expr(15), x)});
    LOG(INFO) << "p0 " << input;
    auto output = CasSimplify(input);
    LOG(INFO) << "simplified " << output;
    EXPECT_EQ(GetStreamCnt(output), "cinn_min(16, (1 + x))");
  }
  {  // 2*cinn_min(15, x)
    Var x = ir::_Var_::Make("x", Int(32));
    auto input = ir::Product::Make({Expr(2), ir::Min::Make(Expr(15), x)});
    LOG(INFO) << "p0 " << input;
    auto output = CasSimplify(input);
    LOG(INFO) << "simplified " << output;
    EXPECT_EQ(GetStreamCnt(output), "cinn_min(30, (2 * x))");
  }
  {  // cinn_min(15, x)/2
    Var x = ir::_Var_::Make("x", Int(32));
    auto input = ir::FracOp::Make(ir::Min::Make(Expr(15), x), Expr(2));
    LOG(INFO) << "p0 " << input;
    auto output = CasSimplify(input);
    LOG(INFO) << "simplified " << output;
    EXPECT_EQ(GetStreamCnt(output), "cinn_min(7, (x / 2))");
  }
  {  // -(cinn_min(16, 3400-x-1)-1)/2 + x
    Var x = ir::_Var_::Make("x", Int(32));
    auto input =
        ir::FracOp::Make(ir::Min::Make(Expr(16), 3400 - x - 1) - 1, Expr(2));
    input = -input + x;
    LOG(INFO) << "p0 " << input;
    auto output = AutoSimplify(input);
    LOG(INFO) << "simplified " << output;
    EXPECT_EQ(GetStreamCnt(output),
              "cinn_max((-1699 + ((-1 * ((-1 * x) / 2)) + x)), (-7 + x))");
  }
  {  // cinn_max((-1 * (3399 + (-16 * x))), -15)
    Var x = ir::_Var_::Make("x", Int(32));
    auto input = ir::Max::Make(
        ir::Product::Make(
            {Expr(-1), ir::Sum::Make({Expr(3399), Expr(-16) * x})}),
        Expr(-15));
    LOG(INFO) << "p0 " << input;
    auto output = AutoSimplify(input);
    LOG(INFO) << "simplified " << output;
    EXPECT_EQ(GetStreamCnt(output), "cinn_max((-3399 + (16 * x)), -15)");
  }
}
// Constant folding of comparisons and short-circuiting of And.
TEST(CAS, cond) {
  {  // 2 > 1 folds to true
    Expr always_true = Expr(2) > Expr(1);
    EXPECT_EQ(GetStreamCnt(CasSimplify(always_true)), "true");
  }
  {  // true && c is c
    Var a("a");
    Expr guarded = (Expr(2) > Expr(1)) && (a < 20);
    EXPECT_EQ(GetStreamCnt(CasSimplify(guarded)), "(a < 20)");
  }
  {  // false && c is false
    Var a("a");
    Expr never = (Expr(2) < Expr(1)) && (a < 20);
    EXPECT_EQ(GetStreamCnt(CasSimplify(never)), "false");
  }
}
// Chained constant divisions: integers truncate, floats divide exactly.
TEST(CAS, SimplifyFracOp) {
  Expr int_chain = Expr(1) / Expr(7) / Expr(6) / Expr(5) / Expr(4);
  EXPECT_EQ(GetStreamCnt(AutoSimplify(int_chain)), "0");

  Expr float_chain = Expr(20.0f) / Expr(2.0f) / Expr(1.0f) / Expr(5.0f);
  EXPECT_EQ(GetStreamCnt(AutoSimplify(float_chain)), "2.00000000f");
}
} // namespace common
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/common/cinn_value.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/poly/stage.h"
#include "paddle/cinn/runtime/cinn_runtime.h"
namespace cinn {
namespace ir {
class Expr;
class Var;
} // namespace ir
namespace common {
//! Implement the type_code for all the supported types.
// @{
// Defines the CINNValue::TypeCode<T>() specialization returning code__.
#define CINN_VALUE_DEFINE_TYPE_CODE(T, code__) \
  template <>                                   \
  int CINNValue::TypeCode<T>() {                \
    return code__;                              \
  }

CINN_VALUE_DEFINE_TYPE_CODE(std::nullptr_t, -1);
// start from a larger number to avoid duplicate id with cinn_pod_value_t
CINN_VALUE_DEFINE_TYPE_CODE(char *, 20);
CINN_VALUE_DEFINE_TYPE_CODE(char const *, 21);
CINN_VALUE_DEFINE_TYPE_CODE(ir::Expr, 22);
CINN_VALUE_DEFINE_TYPE_CODE(ir::Var, 23);
CINN_VALUE_DEFINE_TYPE_CODE(CINNValuePack, 24);
CINN_VALUE_DEFINE_TYPE_CODE(poly::StageMap, 25);
CINN_VALUE_DEFINE_TYPE_CODE(std::string, 26);

#undef CINN_VALUE_DEFINE_TYPE_CODE
//@}
//! Implement ToValue.
// @{
// Integral and boolean inputs are stored in the v_int64 field.
template <>
cinn_value_t ToValue<bool>(bool v) {
  cinn_value_t packed;
  packed.v_int64 = v;
  return packed;
}
template <>
cinn_value_t ToValue<int>(int v) {
  cinn_value_t packed;
  packed.v_int64 = v;
  return packed;
}
template <>
cinn_value_t ToValue<int64_t>(int64_t v) {
  cinn_value_t packed;
  packed.v_int64 = v;
  return packed;
}

// Floating-point inputs are stored in the v_float64 field.
template <>
cinn_value_t ToValue<float>(float v) {
  cinn_value_t packed;
  packed.v_float64 = v;
  return packed;
}
template <>
cinn_value_t ToValue<double>(double v) {
  cinn_value_t packed;
  packed.v_float64 = v;
  return packed;
}
template <>
cinn_value_t ToValue<bfloat16>(bfloat16 v) {
  cinn_value_t packed;
  packed.v_float64 = static_cast<double>(v);
  return packed;
}
template <>
cinn_value_t ToValue<float16>(float16 v) {
  cinn_value_t packed;
  packed.v_float64 = static_cast<double>(v);
  return packed;
}

// C strings are stored as a pointer in v_str; the characters are not copied.
template <>
cinn_value_t ToValue<char *>(char *v) {
  cinn_value_t packed;
  packed.v_str = v;
  return packed;
}
template <>
cinn_value_t ToValue<char const *>(char const *v) {
  cinn_value_t packed;
  // v_str is assigned a non-const pointer, so constness is cast away here.
  packed.v_str = const_cast<char *>(v);
  return packed;
}
// @}
//! True when the stored type code is the one registered for std::string.
bool CINNValue::is_string() const {
  const int string_code = TypeCode<std::string>();
  return type_code_ == string_code;
}

//! True when the stored type code is the one registered for ir::Var.
bool CINNValue::is_var() const {
  const int var_code = TypeCode<ir::Var>();
  return type_code_ == var_code;
}

//! True when an ir::Expr is stored and it is not a tensor.
bool CINNValue::is_expr() const {
  if (type_code_ != TypeCode<ir::Expr>()) return false;
  return !absl::any_cast<Expr>(shared_).as_tensor();
}

//! True when the stored type code is the one registered for poly::StageMap.
bool CINNValue::is_stagemap() const {
  const int stagemap_code = TypeCode<poly::StageMap>();
  return type_code_ == stagemap_code;
}

//! True when an ir::Expr is stored and it is a tensor.
bool CINNValue::is_tensor() const {
  if (type_code_ != TypeCode<ir::Expr>()) return false;
  return static_cast<bool>(absl::any_cast<Expr>(shared_).as_tensor());
}
// Conversion operators for the types kept in shared_. Each one first checks
// that the stored type code matches the requested type and then extracts a
// copy of the held object with absl::any_cast.
CINNValue::operator std::string() const {
CHECK_EQ(type_code_, TypeCode<std::string>());
return absl::any_cast<std::string>(shared_);
}
CINNValue::operator ir::Var() const {
CHECK_EQ(type_code_, TypeCode<ir::Var>());
return absl::any_cast<ir::Var>(shared_);
}
// Note: is_expr() and is_tensor() both accept this type code, so this
// conversion succeeds for tensors as well.
CINNValue::operator ir::Expr() const {
CHECK_EQ(type_code_, TypeCode<ir::Expr>());
return absl::any_cast<Expr>(shared_);
}
CINNValue::operator CINNValuePack() const {
CHECK_EQ(type_code_, TypeCode<CINNValuePack>());
return absl::any_cast<CINNValuePack>(shared_);
}
// Uses the type_code() accessor where the operators above read type_code_.
CINNValue::operator poly::StageMap() const {
CHECK_EQ(type_code(), TypeCode<poly::StageMap>());
return absl::any_cast<poly::StageMap>(shared_);
}
// A C string is stored directly in the POD value via ToValue.
CINNValue::CINNValue(char *value)
: cinn_pod_value_t(ToValue(value), TypeCode<char *>()) {}
// The constructors below leave the POD payload default-constructed, record
// the type code, and keep the object itself in shared_.
CINNValue::CINNValue(const std::string &value)
: cinn_pod_value_t(cinn_value_t(), TypeCode<std::string>()) {
shared_ = value;
}
// From here on the argument must be defined (checked before it is stored).
CINNValue::CINNValue(const Var &value)
: cinn_pod_value_t(cinn_value_t(), TypeCode<Var>()) {
CHECK(value.defined());
shared_ = value;
}
CINNValue::CINNValue(const Expr &value)
: cinn_pod_value_t(cinn_value_t(), TypeCode<Expr>()) {
CHECK(value.defined());
shared_ = value;
}
CINNValue::CINNValue(const CINNValuePack &value)
: cinn_pod_value_t(cinn_value_t(), TypeCode<CINNValuePack>()) {
CHECK(value.defined());
shared_ = value;
}
CINNValue::CINNValue(const poly::StageMap &value)
: cinn_pod_value_t(cinn_value_t(), TypeCode<poly::StageMap>()) {
CHECK(value.defined());
shared_ = value;
}
//! Create a pack holding copies of the values in \p array, in the same order.
//! NOTE(review): the node is allocated with new and handed to CINNValuePack,
//! which presumably takes ownership -- confirm in the header.
CINNValuePack _CINNValuePack_::Make(const std::vector<CINNValue> &array) {
  auto *pack_node = new _CINNValuePack_;
  for (auto it = array.begin(); it != array.end(); ++it) {
    pack_node->AddValue(*it);
  }
  return CINNValuePack(pack_node);
}
// Mutable access to the element at `offset`.
// Aborts through CHECK when `offset` is negative or not below size().
CINNValue &_CINNValuePack_::operator[](int offset) {
  // Reject negative offsets explicitly rather than relying on the implicit
  // signed-to-unsigned conversion in the bound check below.
  CHECK_GE(offset, 0);
  CHECK_LT(static_cast<size_t>(offset), size());
  return values_[offset];
}
// Read-only access to the element at `offset`.
// Aborts through CHECK when `offset` is negative or not below size().
const CINNValue &_CINNValuePack_::operator[](int offset) const {
  // Reject negative offsets explicitly rather than relying on the implicit
  // signed-to-unsigned conversion in the bound check below.
  CHECK_GE(offset, 0);
  CHECK_LT(static_cast<size_t>(offset), size());
  return values_[offset];
}
// Appends a copy of `value` to the tail of the pack; `value` must be defined.
void _CINNValuePack_::AddValue(const CINNValue &value) {
  CHECK(value.defined());
  values_.emplace_back(value);
}
// Drops every value held by the pack.
void _CINNValuePack_::Clear() {
  values_.clear();
}
// Returns the type-name string stored in __type_info__.
const char *_CINNValuePack_::type_info() const {
  return __type_info__;
}
// Replaces the current content with a value built from a bool.
CINNValue &CINNValue::operator=(bool value) {
  return *this = CINNValue(value);
}
// Replaces the current content with a value built from an int32_t.
CINNValue &CINNValue::operator=(int32_t value) {
  return *this = CINNValue(value);
}
// Replaces the current content with a value built from an int64_t.
CINNValue &CINNValue::operator=(int64_t value) {
  return *this = CINNValue(value);
}
// Replaces the current content with a value built from a float.
CINNValue &CINNValue::operator=(float value) {
  return *this = CINNValue(value);
}
// Replaces the current content with a value built from a double.
CINNValue &CINNValue::operator=(double value) {
  return *this = CINNValue(value);
}
// Replaces the current content with a value built from a bfloat16.
CINNValue &CINNValue::operator=(bfloat16 value) {
  return *this = CINNValue(value);
}
// Replaces the current content with a value built from a float16.
CINNValue &CINNValue::operator=(float16 value) {
  return *this = CINNValue(value);
}
// Replaces the current content with a value built from a mutable C string
// pointer.
CINNValue &CINNValue::operator=(char *value) {
  return *this = CINNValue(value);
}
// Replaces the current content with a value built from a cinn_buffer_t
// pointer.
CINNValue &CINNValue::operator=(cinn_buffer_t *value) {
  return *this = CINNValue(value);
}
// Replaces the current content with a value built from an untyped pointer.
CINNValue &CINNValue::operator=(void *value) {
  return *this = CINNValue(value);
}
// Replaces the current content with a value built from a const C string
// pointer.
CINNValue &CINNValue::operator=(const char *value) {
  return *this = CINNValue(value);
}
// Replaces the current content with a value built from a CINNValuePack.
CINNValue &CINNValue::operator=(const CINNValuePack &value) {
  return *this = CINNValue(value);
}
// Replaces the current content with a value built from a std::string.
CINNValue &CINNValue::operator=(const std::string &value) {
  return *this = CINNValue(value);
}
// Replaces the current content with a value built from an ir::Var.
CINNValue &CINNValue::operator=(const ir::Var &value) {
  return *this = CINNValue(value);
}
// Replaces the current content with a value built from an ir::Expr.
CINNValue &CINNValue::operator=(const ir::Expr &value) {
  return *this = CINNValue(value);
}
// Replaces the current content with a value built from a poly::StageMap.
CINNValue &CINNValue::operator=(const poly::StageMap &value) {
  return *this = CINNValue(value);
}
} // namespace common
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <absl/types/any.h>
#include <glog/logging.h>
#include <vector>
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/common/macros.h"
#include "paddle/cinn/common/object.h"
#include "paddle/cinn/common/type.h"
#include "paddle/cinn/runtime/cinn_runtime.h"
struct cinn_buffer_t;
namespace cinn {
// Forward declaration of poly::StageMap; this header only names the type in
// function declarations (a constructor, a conversion operator and an
// assignment operator of CINNValue).
namespace poly {
struct StageMap;
} // namespace poly
// Forward declarations of ir::Expr and ir::Var; this header only names the
// types in function declarations of CINNValue.
namespace ir {
class Expr;
class Var;
} // namespace ir
namespace common {
//! Pack a value of type T into a cinn_value_t.
//! Only declared here; the definition(s) live outside this header.
template <typename T>
cinn_value_t ToValue(T v);
// Forward declarations. CINNValuePack is declared with `struct` to match the
// class-key used by its definition further down in this header.
class CINNValue;
struct CINNValuePack;
/**
 * A _CINNValuePack_ is a shared Array of multiple CINNValue.
 *
 * Instances are created through Make(); the default constructor is private
 * and copying is disallowed.
 */
struct _CINNValuePack_ : public common::Object {
  /**
   * Create a new CINNValuePack instance.
   * @param array The list of CINNValues.
   * @return a CINNValuePack.
   */
  static CINNValuePack Make(const std::vector<CINNValue>& array);

  //! Get i-th element in mutable mode.
  CINNValue& operator[](int offset);
  //! Get i-th element in readonly mode.
  const CINNValue& operator[](int offset) const;

  //! Add one \p value to the tail.
  void AddValue(const CINNValue& value);
  //! Remove all the values.
  void Clear();

  //! Number of values currently held.
  size_t size() const { return values_.size(); }
  //! Whether no value is held.
  bool empty() const { return values_.empty(); }

  CINN_DISALLOW_COPY_AND_ASSIGN(_CINNValuePack_);

  //! Type-name string of this object.
  const char* type_info() const override;

 private:
  _CINNValuePack_() = default;

  //! The values of the pack, in insertion order.
  std::vector<CINNValue> values_;

  // A string literal is an array of const char, so the pointer has to be
  // pointer-to-const; binding it to a plain `char*` is ill-formed since C++11.
  static constexpr const char* __type_info__ = "CINNValuePack";
};
/**
 * Handle to a _CINNValuePack_, derived from Shared<_CINNValuePack_>.
 * Element access and size queries are forwarded to the pointed-to pack.
 */
struct CINNValuePack : public Shared<_CINNValuePack_> {
  //! Wrap an existing pack node.
  explicit CINNValuePack(_CINNValuePack_* ptr) : Shared<_CINNValuePack_>(ptr) {}
  //! Create a pack from \p array through _CINNValuePack_::Make.
  explicit CINNValuePack(const std::vector<CINNValue>& array)
      : Shared<_CINNValuePack_>(_CINNValuePack_::Make(array)) {}

  //! Access the underlying pack node.
  _CINNValuePack_* operator->() { return get(); }
  const _CINNValuePack_* operator->() const { return get(); }

  //! Get i-th element in mutable mode.
  CINNValue& operator[](int offset) { return (*this)->operator[](offset); }
  //! Get i-th element in readonly mode.
  const CINNValue& operator[](int offset) const {
    return (*this)->operator[](offset);
  }

  //! Number of values in the pack.
  size_t size() const { return (*this)->size(); }
  //! Whether the pack holds no value.
  bool empty() const { return (*this)->empty(); }

  //! Last element in mutable mode; the pack must not be empty.
  CINNValue& back() {
    CHECK_GT((*operator->()).size(), 0);
    return (*this)->operator[](size() - 1);
  }
  //! Last element in readonly mode; the pack must not be empty.
  const CINNValue& back() const {
    CHECK_GT((*operator->()).size(), 0);
    return (*this)->operator[](size() - 1);
  }
};
/**
 * Handler for value types in CINN system. It supports two kinds of values: the
 * POD and Shared.
 *
 * POD values are handed to the cinn_pod_value_t base class; the remaining
 * supported types (std::string, ir::Var, ir::Expr, CINNValuePack,
 * poly::StageMap) have out-of-line constructors and conversion operators, and
 * the class carries an absl::any member, shared_, alongside the POD base.
 * The inherited type_code_ tells which type is currently held; kNull marks an
 * undefined value.
 */
class CINNValue : public cinn_pod_value_t {
public:
//! Type code of an undefined value; see defined().
static constexpr int kNull = -1;
//! Build an undefined value: value-initialized payload, type code kNull.
CINNValue() : cinn_pod_value_t(cinn_value_t(), kNull) {}
//! Build a value from a raw payload and an explicit type code.
CINNValue(cinn_value_t value, int type_code)
: cinn_pod_value_t(value, type_code) {}
//! Constructors for the scalar types. Each forwards the value to the base
//! class and then sets type_code_ to ::cinn_type_code<T>() for that type.
// @{
explicit CINNValue(bool value) : cinn_pod_value_t(value) {
type_code_ = ::cinn_type_code<bool>();
}
explicit CINNValue(int32_t value) : cinn_pod_value_t(value) {
type_code_ = ::cinn_type_code<int32_t>();
}
explicit CINNValue(int64_t value) : cinn_pod_value_t(value) {
type_code_ = ::cinn_type_code<int64_t>();
}
explicit CINNValue(float value) : cinn_pod_value_t(value) {
type_code_ = ::cinn_type_code<float>();
}
explicit CINNValue(bfloat16 value) : cinn_pod_value_t(value) {
type_code_ = ::cinn_type_code<bfloat16>();
}
explicit CINNValue(float16 value) : cinn_pod_value_t(value) {
type_code_ = ::cinn_type_code<float16>();
}
explicit CINNValue(double value) : cinn_pod_value_t(value) {
type_code_ = ::cinn_type_code<double>();
}
// @}
//! Constructor for char*; defined out of line.
explicit CINNValue(char* value);
//! Constructors for pointer types; these only forward to the base class, so
//! the type code is whatever the base-class constructor assigns.
// @{
explicit CINNValue(cinn_buffer_t* value) : cinn_pod_value_t(value) {}
explicit CINNValue(void* value) : cinn_pod_value_t(value) {}
explicit CINNValue(const char* value) : cinn_pod_value_t(value) {}
// @}
//! Constructors for the non-POD types; defined out of line.
// @{
explicit CINNValue(const std::string&);
explicit CINNValue(const ir::Var& value);
explicit CINNValue(const ir::Expr& value);
explicit CINNValue(const CINNValuePack& value);
explicit CINNValue(const poly::StageMap& value);
// @}
//! Whether a value is held, i.e. the type code is not kNull.
bool defined() const { return type_code_ != kNull; }
//! The value getters for the supported types.
//! The POD getters are re-exported from cinn_pod_value_t; the non-POD getters
//! are declared here and defined out of line. All of them are implicit
//! conversion operators.
// @{
using cinn_pod_value_t::operator double;
using cinn_pod_value_t::operator float;
using cinn_pod_value_t::operator cinn::common::bfloat16;
using cinn_pod_value_t::operator cinn::common::float16;
using cinn_pod_value_t::operator bool;
using cinn_pod_value_t::operator int32_t;
using cinn_pod_value_t::operator int64_t;
using cinn_pod_value_t::operator void*;
using cinn_pod_value_t::operator cinn_buffer_t*;
using cinn_pod_value_t::operator char*;
operator std::string() const;
operator ir::Var() const;
operator ir::Expr() const;
operator CINNValuePack() const;
operator poly::StageMap() const;
// @}
//! Type queries; defined out of line.
// @{
bool is_string() const;
bool is_var() const;
bool is_expr() const;
bool is_stagemap() const;
bool is_tensor() const;
// @}
//! Assign operators
//! One overload per supported type; all are defined out of line.
// @{
CINNValue& operator=(bool value);
CINNValue& operator=(int32_t value);
CINNValue& operator=(int64_t value);
CINNValue& operator=(float value);
CINNValue& operator=(double value);
CINNValue& operator=(bfloat16 value);
CINNValue& operator=(float16 value);
CINNValue& operator=(char* value);
CINNValue& operator=(const std::string& value);
CINNValue& operator=(const ir::Var& value);
CINNValue& operator=(const ir::Expr& value);
CINNValue& operator=(cinn_buffer_t* value);
CINNValue& operator=(void* value);
CINNValue& operator=(const CINNValuePack& value);
CINNValue& operator=(const char* value);
CINNValue& operator=(const poly::StageMap& value);
// @}
// Earlier `if constexpr` form of Set(), kept for reference; the tag-dispatch
// overloads below replace it.
// //! Set the value.
// template <typename T>
// void Set(T v) {
// if constexpr (std::is_same_v<std::decay_t<T>, CINNValue>) {
// *this = v;
// } else {
// *this = CINNValue(v);
// }
// }
//! Set() helper selected when T decays to CINNValue: assign \p v as is.
template <typename T>
inline void _Set(T v, std::true_type) {
*this = v;
}
//! Set() helper selected for every other T: wrap \p v in a CINNValue first.
template <typename T>
inline void _Set(T v, std::false_type) {
*this = CINNValue(v);
}
// using tag-dispatch instead of constexpr if
//! Set the value from \p v, which is either a CINNValue or any type that a
//! CINNValue can be constructed from.
template <typename T>
void Set(T v) {
_Set(v, std::is_same<std::decay_t<T>, CINNValue>{});
}
/**
* Get the type code for a specific POD type.
* @param T some data type.
* @return an integer representing the type code.
*/
template <typename T>
static int TypeCode();
protected:
//! Type-erased storage member that accompanies the POD base-class payload.
absl::any shared_;
};
} // namespace common
} // namespace cinn
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment