Commit 992bec46 authored by “yuguo”'s avatar “yuguo”
Browse files

2.5

parent 0259837d
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/backends/llvm/execution_engine.h"
#include <absl/strings/string_view.h>
#include <llvm/ADT/Triple.h>
#include <llvm/AsmParser/Parser.h>
#include <llvm/Config/llvm-config.h>
#include <llvm/ExecutionEngine/JITSymbol.h>
#include <llvm/ExecutionEngine/Orc/Core.h>
#include <llvm/ExecutionEngine/Orc/ExecutionUtils.h>
#include <llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h>
#include <llvm/ExecutionEngine/Orc/LLJIT.h>
#include <llvm/ExecutionEngine/Orc/ThreadSafeModule.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/LLVMContext.h>
#include <llvm/IR/LegacyPassManager.h>
#include <llvm/IR/PassManager.h>
#include <llvm/IR/Verifier.h>
#include <llvm/IRReader/IRReader.h>
#include <llvm/InitializePasses.h>
#include <llvm/PassRegistry.h>
#include <llvm/Passes/PassBuilder.h>
#include <llvm/Support/Error.h>
#include <llvm/Support/Host.h>
#include <llvm/Support/MemoryBuffer.h>
#include <llvm/Support/SourceMgr.h>
#include <llvm/Support/TargetRegistry.h>
#include <llvm/Support/TargetSelect.h>
#include <llvm/Support/raw_ostream.h>
#include <llvm/Target/TargetMachine.h>
#include <llvm/Transforms/InstCombine/InstCombine.h>
#include <llvm/Transforms/Scalar.h>
#include <llvm/Transforms/Scalar/GVN.h>
#include <llvm/Transforms/Scalar/NewGVN.h>
#include <llvm/Transforms/Scalar/Reassociate.h>
#include <llvm/Transforms/Scalar/SimplifyCFG.h>
#include <cmath>
#include <memory>
#include <mutex> // NOLINT
#include <string>
#include <utility>
#include "paddle/cinn/backends/codegen_cuda_host.h"
#include "paddle/cinn/backends/llvm/cinn_runtime_llvm_ir.h"
#include "paddle/cinn/backends/llvm/codegen_llvm.h"
#include "paddle/cinn/backends/llvm/codegen_x86.h"
#include "paddle/cinn/backends/llvm/llvm_optimizer.h"
#include "paddle/cinn/backends/llvm/llvm_util.h"
#include "paddle/cinn/backends/llvm/runtime_symbol_registry.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/runtime/intrinsic.h"
#include "paddle/cinn/utils/profiler.h"
namespace cinn::backends {
namespace {
void InitializeLLVMPasses() {
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
auto &registry = *llvm::PassRegistry::getPassRegistry();
llvm::initializeCore(registry);
llvm::initializeTransformUtils(registry);
llvm::initializeScalarOpts(registry);
llvm::initializeIPO(registry);
llvm::initializeInstCombine(registry);
llvm::initializeAggressiveInstCombine(registry);
llvm::initializeAnalysis(registry);
llvm::initializeVectorization(registry);
llvm::initializeSROALegacyPassPass(registry);
// llvm::initializeCodeGen(registry);
// llvm::initializeTarget(registry);
// llvm::initializeCodeGenPreparePass(registry);
}
} // namespace
void NaiveObjectCache::notifyObjectCompiled(const llvm::Module *m,
llvm::MemoryBufferRef obj_buffer) {
cached_objects_[m->getModuleIdentifier()] =
llvm::MemoryBuffer::getMemBufferCopy(obj_buffer.getBuffer(),
obj_buffer.getBufferIdentifier());
}
std::unique_ptr<llvm::MemoryBuffer> NaiveObjectCache::getObject(
const llvm::Module *m) {
auto it = cached_objects_.find(m->getModuleIdentifier());
if (it == cached_objects_.end()) {
VLOG(1) << "No object for " << m->getModuleIdentifier()
<< " in cache. Compiling.";
return nullptr;
}
VLOG(3) << "Object for " << m->getModuleIdentifier() << " loaded from cache.";
return llvm::MemoryBuffer::getMemBuffer(it->second->getMemBufferRef());
}
/*static*/ std::unique_ptr<ExecutionEngine> ExecutionEngine::Create(
const ExecutionOptions &config) {
return Create(config, {});
}
/*static*/ std::unique_ptr<ExecutionEngine> ExecutionEngine::Create(
const ExecutionOptions &config, RuntimeSymbols &&module_symbols) {
VLOG(1) << "===================== Create CINN ExecutionEngine begin "
"====================";
VLOG(1) << "initialize llvm config";
VLOG(1) << "llvm version: " << LLVM_VERSION_STRING;
VLOG(1) << "llvm default target triple: " << LLVM_DEFAULT_TARGET_TRIPLE;
static std::once_flag flag;
std::call_once(flag, InitializeLLVMPasses);
auto engine = std::make_unique<ExecutionEngine>(/*enable_object_cache=*/true,
std::move(module_symbols));
auto compile_layer_creator =
[&engine](llvm::orc::JITTargetMachineBuilder jtmb)
-> llvm::Expected<
std::unique_ptr<llvm::orc::IRCompileLayer::IRCompiler>> {
auto machine = llvm::cantFail(jtmb.createTargetMachine());
VLOG(1) << "create llvm compile layer";
VLOG(1) << "Target Name: " << machine->getTarget().getName();
VLOG(1) << "Target CPU: " << machine->getTargetCPU().str() << std::endl;
return std::make_unique<llvm::orc::TMOwningSimpleCompiler>(
std::move(machine), engine->cache_.get());
};
auto object_layer_creator = [&](llvm::orc::ExecutionSession &session,
const llvm::Triple &triple) {
auto object_layer = std::make_unique<llvm::orc::RTDyldObjectLinkingLayer>(
session,
[]() { return std::make_unique<llvm::SectionMemoryManager>(); });
llvm::orc::JITDylib *main_jd = session.getJITDylibByName("<main>");
if (!main_jd) {
main_jd = &llvm::cantFail(session.createJITDylib("<main>"));
}
return object_layer;
};
VLOG(2) << "create jit execution engine";
engine->jit_ =
llvm::cantFail(llvm::orc::LLJITBuilder()
.setCompileFunctionCreator(compile_layer_creator)
.setObjectLinkingLayerCreator(object_layer_creator)
.create());
engine->jit_->getMainJITDylib().addGenerator(llvm::cantFail(
llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(
engine->jit_->getDataLayout().getGlobalPrefix())));
VLOG(2) << "register runtime call symbols";
engine->RegisterRuntimeSymbols();
VLOG(2) << "===================== Create CINN ExecutionEngine end "
"====================";
return engine;
}
template <typename CodeGenT>
void ExecutionEngine::Link(const ir::Module &module) {
utils::RecordEvent("ExecutionEngine Link", utils::EventType::kOrdinary);
llvm::SMDiagnostic error;
auto ctx = std::make_unique<llvm::LLVMContext>();
auto m = llvm::parseAssemblyString(
AsStringRef(backends::kRuntimeLlvmIr), error, *ctx);
auto b = std::make_unique<llvm::IRBuilder<>>(*ctx);
auto ir_emitter = std::make_unique<CodeGenT>(m.get(), b.get());
VLOG(3) << "ir_emitter->Compile(module) Begin";
ir_emitter->Compile(module);
VLOG(3) << "ir_emitter->Compile(module) Succeed!";
CHECK(!llvm::verifyModule(*m, &llvm::errs())) << "Invalid module found";
auto machine = std::move(llvm::cantFail(
llvm::cantFail(llvm::orc::JITTargetMachineBuilder::detectHost())
.createTargetMachine()));
LLVMModuleOptimizer optimize(machine.get(), 3, {}, true);
optimize(m.get());
CHECK(!llvm::verifyModule(*m, &llvm::errs()))
<< "Invalid optimized module detected";
for (auto &f : *m) {
VLOG(5) << "function: " << DumpToString(f);
}
llvm::raw_svector_ostream rawstream(buffer_);
llvm::legacy::PassManager pass_manager;
machine->addPassesToEmitFile(
pass_manager, rawstream, nullptr, llvm::CGFT_ObjectFile);
pass_manager.run(*m);
CHECK(AddModule(std::move(m), std::move(ctx)));
if (VLOG_IS_ON(5)) {
VLOG(5) << "======= dump jit execution session ======";
std::string buffer;
llvm::raw_string_ostream os(buffer);
decltype(auto) es = jit_->getExecutionSession();
es.dump(os);
os.flush();
VLOG(5) << buffer;
}
}
bool ExecutionEngine::AddModule(std::unique_ptr<llvm::Module> module,
std::unique_ptr<llvm::LLVMContext> context) {
utils::RecordEvent("ExecutionEngine AddModule", utils::EventType::kOrdinary);
module->setDataLayout(jit_->getDataLayout());
if (VLOG_IS_ON(5)) {
VLOG(5) << "======= dump jit lib ==========";
std::string buffer;
llvm::raw_string_ostream os(buffer);
module->print(os, {});
// main_jd_->dump(os);
os.flush();
VLOG(5) << buffer;
}
llvm::orc::ThreadSafeContext tsc(std::move(context));
llvm::orc::ThreadSafeModule tsm(std::move(module), std::move(tsc));
llvm::cantFail(jit_->addIRModule(std::move(tsm)));
return true;
}
void ExecutionEngine::ExportObject(const std::string &path) {
FILE *of = fopen(path.c_str(), "w");
fwrite(buffer_.data(), 1, buffer_.size(), of);
fclose(of);
}
void *ExecutionEngine::Lookup(absl::string_view name) {
utils::RecordEvent("ExecutionEngine Lookup", utils::EventType::kOrdinary);
std::lock_guard<std::mutex> lock(mu_);
if (auto symbol = jit_->lookup(AsStringRef(name))) {
return reinterpret_cast<void *>(symbol->getAddress());
}
LOG(ERROR) << "Unknown symbol name[" << name << "]";
return nullptr;
}
void ExecutionEngine::RegisterRuntimeSymbols() {
utils::RecordEvent("ExecutionEngine RegisterRuntimeSymbols",
utils::EventType::kOrdinary);
const auto &registry = GlobalSymbolRegistry::Global();
auto *session = &jit_->getExecutionSession();
for (const auto &sym : registry.All()) {
llvm::cantFail(jit_->define(llvm::orc::absoluteSymbols(
{{session->intern(sym.first),
{llvm::pointerToJITTargetAddress(sym.second),
llvm::JITSymbolFlags::None}}})));
}
for (const auto &sym : module_symbols_.All()) {
llvm::cantFail(jit_->define(llvm::orc::absoluteSymbols(
{{session->intern(sym.first),
{llvm::pointerToJITTargetAddress(sym.second),
llvm::JITSymbolFlags::None}}})));
}
}
template void ExecutionEngine::Link<CodeGenLLVM>(const ir::Module &module);
template void ExecutionEngine::Link<CodeGenX86>(const ir::Module &module);
template void ExecutionEngine::Link<CodeGenCUDA_Host>(const ir::Module &module);
} // namespace cinn::backends
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <llvm/ADT/StringMap.h>
#include <llvm/ExecutionEngine/ExecutionEngine.h>
#include <llvm/ExecutionEngine/JITSymbol.h>
#include <llvm/ExecutionEngine/ObjectCache.h>
#include <llvm/ExecutionEngine/Orc/CompileUtils.h>
#include <llvm/ExecutionEngine/Orc/Core.h>
#include <llvm/ExecutionEngine/Orc/ExecutionUtils.h>
#include <llvm/ExecutionEngine/Orc/IRCompileLayer.h>
#include <llvm/ExecutionEngine/Orc/LLJIT.h>
#include <llvm/ExecutionEngine/Orc/LambdaResolver.h>
#include <llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h>
#include <llvm/ExecutionEngine/Orc/ThreadSafeModule.h>
#include <llvm/ExecutionEngine/SectionMemoryManager.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/LLVMContext.h>
#include <llvm/IR/Module.h>
#include <llvm/Support/Error.h>
#include <llvm/Support/MemoryBuffer.h>
#include <llvm/Support/SmallVectorMemoryBuffer.h>
#include <llvm/Support/TargetSelect.h>
#include <llvm/Support/raw_ostream.h>
#include <functional>
#include <memory>
#include <mutex> // NOLINT
#include <optional>
#include <string>
#include <vector>
#include "paddle/cinn/backends/llvm/codegen_x86.h"
#include "paddle/cinn/backends/llvm/llvm_util.h"
#include "paddle/cinn/backends/llvm/runtime_symbol_registry.h"
#include "paddle/cinn/ir/module.h"
namespace cinn::backends {
class NaiveObjectCache : public llvm::ObjectCache {
public:
void notifyObjectCompiled(const llvm::Module *,
llvm::MemoryBufferRef) override;
std::unique_ptr<llvm::MemoryBuffer> getObject(const llvm::Module *) override;
private:
llvm::StringMap<std::unique_ptr<llvm::MemoryBuffer>> cached_objects_;
};
struct ExecutionOptions {
int opt_level{3};
bool enable_debug_info{false};
// TODO(fc500110)
// int num_compile_threads{1};
// bool enable_fast_math;
};
class ExecutionEngine {
public:
static std::unique_ptr<ExecutionEngine> Create(
const ExecutionOptions &config);
static std::unique_ptr<ExecutionEngine> Create(
const ExecutionOptions &config, RuntimeSymbols &&module_symbols);
void *Lookup(absl::string_view name);
template <typename CodeGenT = CodeGenLLVM>
void Link(const ir::Module &module);
void ExportObject(const std::string &path);
bool AddModule(std::unique_ptr<llvm::Module> module,
std::unique_ptr<llvm::LLVMContext> context);
protected:
explicit ExecutionEngine(bool enable_object_cache,
RuntimeSymbols &&module_symbols)
: cache_(std::make_unique<NaiveObjectCache>()),
module_symbols_(std::move(module_symbols)) {}
void RegisterRuntimeSymbols();
bool SetupTargetTriple(llvm::Module *module);
// This may not be a compatible implementation.
friend std::unique_ptr<ExecutionEngine> std::make_unique<ExecutionEngine>(
bool &&, cinn::backends::RuntimeSymbols &&);
private:
mutable std::mutex mu_;
llvm::SmallString<0> buffer_;
std::unique_ptr<llvm::orc::LLJIT> jit_;
std::unique_ptr<NaiveObjectCache> cache_;
RuntimeSymbols module_symbols_;
};
} // namespace cinn::backends
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/backends/llvm/execution_engine.h"
#include <glog/logging.h>
#include <glog/raw_logging.h>
#include <gtest/gtest.h>
#include <llvm/AsmParser/Parser.h>
#include <llvm/IR/Argument.h>
#include <llvm/IR/BasicBlock.h>
#include <llvm/IR/Function.h>
#include <llvm/IRReader/IRReader.h>
#include <llvm/Support/FileSystem.h>
#include <llvm/Support/MemoryBuffer.h>
#include <llvm/Support/SourceMgr.h>
#include <llvm/Support/raw_ostream.h>
#include <algorithm>
#include <cmath>
#include <iomanip>
#include <memory>
#include <random>
#include <tuple>
#include <utility>
#include <vector>
#include "paddle/cinn/backends/llvm/cinn_runtime_llvm_ir.h"
#include "paddle/cinn/backends/llvm/codegen_llvm.h"
#include "paddle/cinn/backends/llvm/runtime_symbol_registry.h"
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/module.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/lang/placeholder.h"
#include "paddle/cinn/optim/optimize.h"
#include "paddle/cinn/runtime/cpu/host_intrinsics.h"
#include "paddle/cinn/runtime/cpu/use_extern_funcs.h"
namespace cinn {
namespace backends {
namespace {
bool RegisterKnownSymbols() {
decltype(auto) registry = GlobalSymbolRegistry::Global();
registry.RegisterFn("sinf", reinterpret_cast<void *>(&sinf));
registry.RegisterFn(
"sin", reinterpret_cast<void *>(static_cast<double (*)(double)>(&sin)));
registry.RegisterFn("cosf", reinterpret_cast<void *>(&cosf));
registry.RegisterFn(
"cos", reinterpret_cast<void *>(static_cast<double (*)(double)>(&cos)));
return true;
}
[[maybe_unused]] bool unused = RegisterKnownSymbols();
constexpr int kM = 100;
constexpr int kN = 32;
auto CreateTestBuffer() {
auto *A = cinn_buffer_t::new_(
cinn_device_kind_t::cinn_x86_device, cinn_float32_t(), {kM, kN}, 32);
auto *B = cinn_buffer_t::new_(
cinn_device_kind_t::cinn_x86_device, cinn_float32_t(), {kM, kN}, 32);
auto *C = cinn_buffer_t::new_(
cinn_device_kind_t::cinn_x86_device, cinn_float32_t(), {kM, kN}, 32);
cinn_buffer_malloc(nullptr, A);
cinn_buffer_malloc(nullptr, B);
cinn_buffer_malloc(nullptr, C);
float *Ad = reinterpret_cast<float *>(A->memory);
float *Bd = reinterpret_cast<float *>(B->memory);
for (int i = 0; i < A->num_elements(); i++) {
Ad[i] = static_cast<float>(rand()) / RAND_MAX; // NOLINT
Bd[i] = static_cast<float>(rand()) / RAND_MAX; // NOLINT
}
float *Cd = reinterpret_cast<float *>(C->memory);
CHECK_EQ(C->num_elements(), A->num_elements());
return std::make_tuple(A, B, C);
}
auto CreateTestCinnModule() {
ir::Expr M(kM);
ir::Expr N(kN);
lang::Placeholder<float> A("A", {M, N});
lang::Placeholder<float> B("B", {M, N});
lang::Buffer C_buf(Float(32));
auto C = lang::Compute(
{M, N}, [&](Var i, Var j) { return A(i, j) + B(i, j); }, "C");
C->Bind(C_buf);
common::Target target;
target.arch = common::Target::Arch::X86;
target.bits = common::Target::Bit::k32;
target.os = common::Target::OS::Linux;
ir::Module::Builder builder("module1", target);
auto stages = CreateStages({C});
auto funcs = lang::Lower("elementwise_add", stages, {A, B, C});
// auto func = optim::Optimize(funcs);
builder.AddFunction(ir::LoweredFunc(funcs.As<ir::_LoweredFunc_>()));
return builder.Build();
}
} // namespace
TEST(llvm_test01, elementwise_add) {
return;
auto engine = backends::ExecutionEngine::Create({1});
auto _a_b_c_ = CreateTestBuffer(); // NOLINT
auto &a = std::get<0>(_a_b_c_);
auto &b = std::get<1>(_a_b_c_);
auto &c = std::get<2>(_a_b_c_);
auto module = CreateTestCinnModule();
engine->Link(module);
auto elementwise_add_addr = engine->Lookup("elementwise_add");
return;
auto elementwise_add =
reinterpret_cast<void (*)(void *, int32_t)>(elementwise_add_addr);
cinn_pod_value_t a_arg(a), b_arg(b), c_arg(c);
cinn_pod_value_t args[3] = {a_arg, b_arg, c_arg};
elementwise_add(args, 3);
float *ad = reinterpret_cast<float *>(a->memory);
float *bd = reinterpret_cast<float *>(b->memory);
float *cd = reinterpret_cast<float *>(c->memory);
for (int i = 0; i < c->num_elements(); i++) {
EXPECT_EQ(ad[i] + bd[i], cd[i]);
}
}
TEST(llvm, module_call_lowered_func) {
ir::Module::Builder builder("some_module", common::DefaultHostTarget());
ir::Expr M(kM);
ir::Expr N(kN);
{ // define fn
lang::Placeholder<float> a("A", {M, N});
lang::Placeholder<float> b("B", {M, N});
auto c = lang::Compute(
{M, N}, [&](auto i, auto j) { return a(i, j) + b(i, j); }, "C");
auto stages = CreateStages({c});
auto fn = lang::Lower("elementwise_add", stages, {a, b, c}, {});
builder.AddFunction(fn);
}
{ // call fn
lang::Placeholder<float> a("A", {M, N});
lang::Placeholder<float> b("B", {M, N});
std::vector<lang::ReturnType> ret_types(
{lang::ReturnType{Float(32), {M, N}, "c_out"}});
auto call_outs = lang::CallLowered("elementwise_add", {a, b}, ret_types);
auto c = call_outs[0];
// here we must call the output, so that it cal output something.
auto stages = CreateStages({c});
auto main_fn = lang::Lower("main", stages, {a, b, c}, {});
builder.AddFunction(main_fn);
CodeGenC codegen(common::DefaultHostTarget());
codegen.SetInlineBuiltinCodes(false);
LOG(INFO) << "module:\n"
<< codegen.Compile(builder.Build(), CodeGenC::OutputKind::CImpl);
}
auto _ab_bb_cb_ = CreateTestBuffer(); // NOLINT
auto &ab = std::get<0>(_ab_bb_cb_);
auto &bb = std::get<1>(_ab_bb_cb_);
auto &cb = std::get<2>(_ab_bb_cb_);
do { // call the function
auto engine = backends::ExecutionEngine::Create({1});
LOG(INFO) << "JIT Link the module";
engine->Link(builder.Build());
auto cos_fn = (double (*)(double))engine->Lookup("cos");
LOG(INFO) << "=> LLVM JIT cos(0) = " << cos_fn(0);
auto elementwise_add_addr = engine->Lookup("elementwise_add");
auto elementwise_add =
reinterpret_cast<void (*)(void *, int32_t)>(elementwise_add_addr);
LOG(INFO) << "JIT get elementwise_add_addr";
break;
cinn_pod_value_t a_arg(ab), b_arg(bb), c_arg(cb);
cinn_pod_value_t args[3] = {a_arg, b_arg, c_arg};
elementwise_add(args, 3);
auto *ad = reinterpret_cast<float *>(ab->memory);
auto *bd = reinterpret_cast<float *>(bb->memory);
for (int i = 0; i < kM; i++) {
for (int j = 0; j < kN; j++) {
auto *data = reinterpret_cast<float *>(cb->memory);
ASSERT_NEAR(data[i * kN + j], ad[i * kN + j] + bd[i * kN + j], 1e-5);
}
}
} while (false);
}
TEST(ExecutionEngine, custom_runtime_symbols) {
auto context = std::make_unique<llvm::LLVMContext>();
auto module =
std::make_unique<llvm::Module>("test_llvm_cpu_runtime", *context);
auto builder = std::make_unique<llvm::IRBuilder<>>(*context);
auto call_custom_target = [&](std::string name, llvm::Type *ty) {
llvm::FunctionType *fn_type = llvm::FunctionType::get(ty, {ty}, false);
llvm::Function *function =
llvm::Function::Create(fn_type,
llvm::Function::ExternalLinkage,
"_call_custom_" + name,
module.get());
function->setCallingConv(llvm::CallingConv::C);
llvm::BasicBlock *entry =
llvm::BasicBlock::Create(module->getContext(), "entry", function);
builder->SetInsertPoint(entry);
llvm::Argument *arg = &*function->args().begin();
llvm::Function *custom_function = llvm::dyn_cast<llvm::Function>(
module->getOrInsertFunction(name, fn_type).getCallee());
custom_function->setCallingConv(llvm::CallingConv::C);
llvm::Value *ret = builder->CreateCall(custom_function, {arg});
builder->CreateRet(ret);
};
llvm::Type *f32 = builder->getFloatTy();
llvm::Type *f64 = builder->getDoubleTy();
call_custom_target("cosf", f32);
call_custom_target("cos", f64);
call_custom_target("sinf", f32);
call_custom_target("sin", f64);
double pi = std::acos(-1);
std::vector<double> angle = {0., pi / 6., pi / 4., pi / 3., pi / 2., pi};
std::random_device rd;
std::mt19937 mt(rd());
std::uniform_int_distribution<int> dis(-100, 100);
int random_x = dis(mt);
int random_y = dis(mt);
decltype(auto) registry = GlobalSymbolRegistry::Global();
// registry.Register("dereference_f64_ptr", (void *)+[](double *x) { return
// *x; });
for (size_t i = 0; i < angle.size(); i++) {
registry.RegisterVar("theta_" + std::to_string(i), angle[i]);
}
auto engine = cinn::backends::ExecutionEngine::Create({1});
engine->AddModule(std::move(module), std::move(context));
auto *call_cosf =
reinterpret_cast<float (*)(float)>(engine->Lookup("_call_custom_cosf"));
auto *call_cos =
reinterpret_cast<double (*)(double)>(engine->Lookup("_call_custom_cos"));
auto *call_sinf =
reinterpret_cast<float (*)(float)>(engine->Lookup("_call_custom_sinf"));
auto *call_sin =
reinterpret_cast<double (*)(double)>(engine->Lookup("_call_custom_sin"));
ASSERT_TRUE(call_cosf && call_cos && call_sinf && call_sin);
for (auto theta : angle) {
float theta_f = static_cast<float>(theta);
ASSERT_NEAR(call_cosf(theta_f), cosf(theta_f), 1e-6);
ASSERT_NEAR(call_cos(theta), cos(theta), 1e-6);
ASSERT_NEAR(call_sinf(theta_f), sinf(theta_f), 1e-6);
ASSERT_NEAR(call_sin(theta), sin(theta), 1e-6);
}
}
TEST(ExecutionEngine, call_extern) {
ir::Expr M(kM);
ir::Expr N(kN);
Placeholder<float> x("x", {M, N});
Placeholder<float> y("y", {M, N});
auto add_out = Compute(
{M, N}, [=](Var i, Var j) { return x(i, j) + y(i, j); }, "add_out");
ir::Tensor res = Compute(
{M, N},
[&](Var i, Var j) -> Expr {
return lang::CallExtern("tanh", {add_out(i, j)});
},
"res");
auto stages = CreateStages({add_out, res});
stages[add_out]->ComputeInline();
auto func = Lower("comp", stages, {x, y, res});
Module::Builder builder("module0", common::DefaultHostTarget());
builder.AddFunction(func);
auto engine = backends::ExecutionEngine::Create({1});
engine->Link(builder.Build());
auto _ab_bb_cb_ = CreateTestBuffer(); // NOLINT
auto &ab = std::get<0>(_ab_bb_cb_);
auto &bb = std::get<1>(_ab_bb_cb_);
auto &cb = std::get<2>(_ab_bb_cb_);
auto comp_addr = engine->Lookup("comp");
auto comp = reinterpret_cast<void (*)(void *, int32_t)>(comp_addr);
cinn_pod_value_t a_arg(ab), b_arg(bb), c_arg(cb);
cinn_pod_value_t args[3] = {a_arg, b_arg, c_arg};
comp(args, 3);
auto *ad = reinterpret_cast<float *>(ab->memory);
auto *bd = reinterpret_cast<float *>(bb->memory);
auto *cd = reinterpret_cast<float *>(cb->memory);
for (int m = 0; m < kM; m++) {
for (int n = 0; n < kN; n++) {
ASSERT_NEAR(cd[m * kN + n], tanh(ad[m * kN + n] + bd[m * kN + n]), 1e-5);
}
}
}
} // namespace backends
} // namespace cinn
#!/usr/bin/env python3
# Copyright (c) 2021 CINN Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import subprocess
import sys
def main():
path = sys.argv[1]
out_path = sys.argv[2]
llvm_config = sys.argv[3]
srcs = []
srcs.append('#include <absl/strings/string_view.h>')
# srcs.append('#include "paddle/cinn/backends/llvm/cinn_runtime_llvm_ir.h"\n')
srcs.append('namespace cinn::backends {')
srcs.append("static const absl::string_view kRuntimeLlvmIr(")
srcs.append('R"ROC(')
with open(path, 'r') as fr:
srcs.append(fr.read())
srcs.append(')ROC"')
srcs.append(');\n')
cmd = f"{llvm_config} --version"
version = (
subprocess.check_output(cmd, shell=True)
.decode('utf-8')
.strip()
.split('.')
)
srcs.append("struct llvm_version {")
for v, n in zip(["major", "minor", "micro"], version):
srcs.append(
" static constexpr int k{} = {};".format(
v.title(), ''.join(filter(str.isdigit, n))
)
)
srcs.append("};")
srcs.append('} // namespace cinn::backends')
with open(out_path, 'w') as fw:
fw.write("\n".join(srcs))
def get_clang_version():
pass
if __name__ == "__main__":
main()
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/Value.h>
#include <utility>
namespace cinn {
namespace backends {
template <typename Derived>
class IrBuilderMixin {
protected:
template <typename... Args>
decltype(auto) BinOp(Args &&...args) {
return mixin_builder()->CreateBinOp(std::forward<Args>(args)...);
}
/// \brief +
template <typename... Args>
decltype(auto) Add(Args &&...args) {
return mixin_builder()->CreateAdd(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) FAdd(Args &&...args) {
return mixin_builder()->CreateFAdd(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) NSWAdd(Args &&...args) {
return mixin_builder()->CreateNSWAdd(std::forward<Args>(args)...);
}
/// \brief -
template <typename... Args>
decltype(auto) Sub(Args &&...args) {
return mixin_builder()->CreateSub(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) FSub(Args &&...args) {
return mixin_builder()->CreateFSub(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) NSWSub(Args &&...args) {
return mixin_builder()->CreateNSWSub(std::forward<Args>(args)...);
}
/// \brief *
template <typename... Args>
decltype(auto) Mul(Args &&...args) {
return mixin_builder()->CreateMul(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) FMul(Args &&...args) {
return mixin_builder()->CreateFMul(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) NSWMul(Args &&...args) {
return mixin_builder()->CreateNSWMul(std::forward<Args>(args)...);
}
/// \brief /
template <typename... Args>
decltype(auto) SDiv(Args &&...args) {
return mixin_builder()->CreateSDiv(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) UDiv(Args &&...args) {
return mixin_builder()->CreateUDiv(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) FDiv(Args &&...args) {
return mixin_builder()->CreateFDiv(std::forward<Args>(args)...);
}
/// \brief %
template <typename... Args>
decltype(auto) SRem(Args &&...args) {
return mixin_builder()->CreateSRem(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) URem(Args &&...args) {
return mixin_builder()->CreateURem(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) FRem(Args &&...args) {
return mixin_builder()->CreateFRem(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) And(Args &&...args) {
return mixin_builder()->CreateAnd(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) Or(Args &&...args) {
return mixin_builder()->CreateOr(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) Not(Args &&...args) {
return mixin_builder()->CreateNot(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) Neg(Args &&...args) {
return mixin_builder()->CreateNeg(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) FNeg(Args &&...args) {
return mixin_builder()->CreateFNeg(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) ICmpEQ(Args &&...args) {
return mixin_builder()->CreateICmpEQ(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) FCmpOEQ(Args &&...args) {
return mixin_builder()->CreateFCmpOEQ(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) FCmpUEQ(Args &&...args) {
return mixin_builder()->CreateFCmpUEQ(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) ICmpNE(Args &&...args) {
return mixin_builder()->CreateICmpNE(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) FCmpONE(Args &&...args) {
return mixin_builder()->CreateFCmpONE(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) FCmpUNE(Args &&...args) {
return mixin_builder()->CreateFCmpUNE(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) ICmpULE(Args &&...args) {
return mixin_builder()->CreateICmpULE(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) FCmpOLE(Args &&...args) {
return mixin_builder()->CreateFCmpOLE(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) ICmpULT(Args &&...args) {
return mixin_builder()->CreateICmpULT(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) ICmpSLT(Args &&...args) {
return mixin_builder()->CreateICmpSLT(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) FCmpOLT(Args &&...args) {
return mixin_builder()->CreateFCmpOLT(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) ICmpUGE(Args &&...args) {
return mixin_builder()->CreateICmpUGE(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) ICmpSGE(Args &&...args) {
return mixin_builder()->CreateICmpSGE(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) FCmpOGE(Args &&...args) {
return mixin_builder()->CreateFCmpOGE(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) ICmpUGT(Args &&...args) {
return mixin_builder()->CreateICmpUGT(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) ICmpSGT(Args &&...args) {
return mixin_builder()->CreateICmpSGT(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) FCmpOGT(Args &&...args) {
return mixin_builder()->CreateFCmpOGT(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) BitCast(Args &&...args) {
return mixin_builder()->CreateBitCast(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) IntCast(Args &&...args) {
return mixin_builder()->CreateIntCast(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) FPCast(Args &&...args) {
return mixin_builder()->CreateFPCast(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) PointerCast(Args &&...args) {
return mixin_builder()->CreatePointerCast(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) FPToSI(Args &&...args) {
return mixin_builder()->CreateFPToSI(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) FPToUI(Args &&...args) {
return mixin_builder()->CreateFPToUI(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) SIToFP(Args &&...args) {
return mixin_builder()->CreateSIToFP(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) UIToFP(Args &&...args) {
return mixin_builder()->CreateUIToFP(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) Select(Args &&...args) {
return mixin_builder()->CreateSelect(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) Br(Args &&...args) {
return mixin_builder()->CreateBr(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) CondBr(Args &&...args) {
return mixin_builder()->CreateCondBr(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) Alloca(Args &&...args) {
return mixin_builder()->CreateAlloca(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) Load(Args &&...args) {
return mixin_builder()->CreateLoad(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) AlignedLoad(Args &&...args) {
return mixin_builder()->CreateAlignedLoad(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) Store(Args &&...args) {
return mixin_builder()->CreateStore(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) AlignedStore(Args &&...args) {
return mixin_builder()->CreateAlignedStore(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) Call(Args &&...args) {
return mixin_builder()->CreateCall(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) RetVoid(Args &&...args) {
return mixin_builder()->CreateRetVoid(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) GEP(Args &&...args) {
return mixin_builder()->CreateGEP(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) InBoundsGEP(Args &&...args) {
return mixin_builder()->CreateInBoundsGEP(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) PHI(Args &&...args) {
return mixin_builder()->CreatePHI(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) InsertValue(Args &&...args) {
return mixin_builder()->CreateInsertValue(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) ExtractValue(Args &&...args) {
return mixin_builder()->CreateExtractValue(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) InsertElement(Args &&...args) {
return mixin_builder()->CreateInsertElement(std::forward<Args>(args)...);
}
template <typename... Args>
decltype(auto) ShuffleVector(Args &&...args) {
return mixin_builder()->CreateShuffleVector(std::forward<Args>(args)...);
}
private:
llvm::IRBuilder<> *mixin_builder() {
return static_cast<Derived *>(this)->b();
}
};
} // namespace backends
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <absl/container/flat_hash_map.h>
#include <glog/logging.h>
#include <llvm/IR/Intrinsics.h>
#include <string>
#include <utility>
#include <vector>
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/ir/intrinsic_ops.h"
#include "paddle/cinn/ir/registry.h"
#include "paddle/cinn/lang/packed_func.h"
namespace cinn {
namespace codegen {
template <int id, int arg_nums, bool add_float_suffix = true>
inline void MakeFloatIntrinOp(lang::Args args, lang::RetValue *rv) {
CHECK_GE(args.size(), 1U);
Expr arg = args[0];
ir::Call *node = arg->as<ir::Call>();
CHECK(node);
CHECK_GE(node->read_args.size(), arg_nums);
if (add_float_suffix) {
CHECK(node->type().is_float());
*rv = ir::intrinsics::BuiltinIntrin::Make(
node->name + "f", node->read_args, id, arg_nums, node->type());
} else {
*rv = ir::intrinsics::BuiltinIntrin::Make(
node->name, node->read_args, id, arg_nums, node->type());
}
}
void RegisterCpuIntrinRule() {
#define __(intrin_name__, id) \
ir::Registry::Register("lower_cpu_intrinsic_" #intrin_name__, true) \
.SetBody(MakeFloatIntrinOp<id, 1>);
__(exp, ::llvm::Intrinsic::exp)
__(exp2, ::llvm::Intrinsic::exp2)
__(sqrt, ::llvm::Intrinsic::sqrt)
__(log, ::llvm::Intrinsic::log)
__(log2, ::llvm::Intrinsic::log2)
__(log10, ::llvm::Intrinsic::log10)
__(floor, ::llvm::Intrinsic::floor)
__(ceil, ::llvm::Intrinsic::ceil)
__(round, ::llvm::Intrinsic::round)
__(trunc, ::llvm::Intrinsic::trunc)
__(cos, ::llvm::Intrinsic::cos)
__(sin, ::llvm::Intrinsic::sin)
__(fabs, ::llvm::Intrinsic::fabs)
#undef __
// set id -1 if not llvm intrinsics
#define RegisterBitwise(intrin_name__) \
ir::Registry::Register("lower_cpu_intrinsic_" #intrin_name__, true) \
.SetBody(MakeFloatIntrinOp<-1, 2, false>);
RegisterBitwise(bitwise_or) RegisterBitwise(bitwise_xor) RegisterBitwise(
bitwise_and) RegisterBitwise(left_shift) RegisterBitwise(right_shift)
#undef RegisterBitwise
ir::Registry::Register("lower_cpu_intrinsic_fma", true)
.SetBody(MakeFloatIntrinOp<::llvm::Intrinsic::fmuladd, 3, false>);
ir::Registry::Register("lower_cpu_intrinsic_bitwise_not", true)
.SetBody(MakeFloatIntrinOp<-1, 1, false>);
ir::Registry::Register("lower_cpu_intrinsic_isnan", true)
.SetBody(MakeFloatIntrinOp<-1, 1, false>);
ir::Registry::Register("lower_cpu_intrinsic_isfinite", true)
.SetBody([](lang::Args args, lang::RetValue *rv) {
CHECK_GE(args.size(), 1U);
Expr arg0 = args[0];
ir::Call *node = arg0->as<ir::Call>();
CHECK(node);
CHECK(!node->read_args.empty());
Expr arg = node->read_args[0];
*rv = !(lang::IsInf(arg)) && !(lang::IsNan(arg));
});
ir::Registry::Register("lower_cpu_intrinsic_isinf", true)
.SetBody([](lang::Args args, lang::RetValue *rv) {
CHECK_GE(args.size(), 1U);
Expr arg0 = args[0];
ir::Call *node = arg0->as<ir::Call>();
CHECK(node);
CHECK(!node->read_args.empty());
Expr arg = node->read_args[0];
Type type = arg->type();
if (type.is_int() || type.is_uint()) {
*rv = common::make_bool(false, type.lanes());
} else if (type.is_float()) {
*rv = ir::EQ::Make(lang::Abs(arg), lang::Infinity(type)) &&
!(lang::IsNan(arg));
}
});
ir::Registry::Register("lower_cpu_intrinsic_rsqrt", true)
.SetBody([](lang::Args args, lang::RetValue *rv) {
CHECK_GE(args.size(), 1U);
Expr arg0 = args[0];
ir::Call *node = arg0->as<ir::Call>();
CHECK(node);
CHECK(!node->read_args.empty());
Expr arg = node->read_args[0];
*rv = make_const(arg->type(), 1) / lang::Sqrt(arg);
});
ir::Registry::Register("lower_cpu_intrinsic_exp10", true)
.SetBody([](lang::Args args, lang::RetValue *rv) {
CHECK_GE(args.size(), 1U);
Expr arg0 = args[0];
ir::Call *node = arg0->as<ir::Call>();
CHECK(node);
CHECK(!node->read_args.empty());
Expr arg = node->read_args[0];
Expr ln10 = make_const(arg->type(), 2.302585093);
*rv = lang::Exp(arg * ln10);
});
ir::Registry::Register("lower_cpu_intrinsic_tan", true)
.SetBody([](lang::Args args, lang::RetValue *rv) {
CHECK_GE(args.size(), 1U);
Expr arg0 = args[0];
ir::Call *node = arg0->as<ir::Call>();
CHECK(node);
CHECK(!node->read_args.empty());
Expr arg = node->read_args[0];
*rv = lang::Sin(arg) / lang::Cos(arg);
});
ir::Registry::Register("lower_cpu_intrinsic_tanh", true)
.SetBody([](lang::Args args, lang::RetValue *rv) {
CHECK_GE(args.size(), 1U);
Expr arg0 = args[0];
ir::Call *node = arg0->as<ir::Call>();
CHECK(node);
CHECK(!node->read_args.empty());
Expr arg = node->read_args[0];
Expr zero = make_const(arg->type(), 0);
Expr one = make_const(arg->type(), 1);
Expr two = make_const(arg->type(), 2);
Expr neg_two = make_const(arg->type(), -2);
Expr exp_neg2x = lang::Exp(neg_two * arg);
Expr exp_pos2x = lang::Exp(two * arg);
Expr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x);
Expr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one);
*rv = ir::Select::Make(arg >= zero, tanh_pos, tanh_neg);
});
ir::Registry::Register("lower_cpu_intrinsic_cosh", true)
.SetBody([](lang::Args args, lang::RetValue *rv) {
CHECK_GE(args.size(), 1U);
Expr arg0 = args[0];
ir::Call *node = arg0->as<ir::Call>();
CHECK(node);
CHECK(!node->read_args.empty());
Expr arg = node->read_args[0];
*rv = (lang::Exp(arg) + lang::Exp(arg * make_const(arg->type(), -1))) /
make_const(arg->type(), 2);
});
ir::Registry::Register("lower_cpu_intrinsic_sinh", true)
.SetBody([](lang::Args args, lang::RetValue *rv) {
CHECK_GE(args.size(), 1U);
Expr arg0 = args[0];
ir::Call *node = arg0->as<ir::Call>();
CHECK(node);
CHECK(!node->read_args.empty());
Expr arg = node->read_args[0];
*rv = (lang::Exp(arg) - lang::Exp(arg * make_const(arg->type(), -1))) /
make_const(arg->type(), 2);
});
}
} // namespace codegen
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/backends/llvm/llvm_optimizer.h"
#include <glog/logging.h>
#include <llvm/ADT/Triple.h>
#include <llvm/Analysis/CGSCCPassManager.h>
#include <llvm/AsmParser/Parser.h>
#include <llvm/ExecutionEngine/ExecutionEngine.h>
#include <llvm/ExecutionEngine/JITSymbol.h>
#include <llvm/ExecutionEngine/Orc/CompileUtils.h>
#include <llvm/ExecutionEngine/Orc/Core.h>
#include <llvm/ExecutionEngine/Orc/ExecutionUtils.h>
#include <llvm/ExecutionEngine/Orc/IRCompileLayer.h>
#include <llvm/ExecutionEngine/Orc/LambdaResolver.h>
#include <llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h>
#include <llvm/ExecutionEngine/Orc/ThreadSafeModule.h>
#include <llvm/ExecutionEngine/SectionMemoryManager.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/LegacyPassManager.h>
#include <llvm/IR/Module.h>
#include <llvm/IR/PassManager.h>
#include <llvm/IRReader/IRReader.h>
#include <llvm/Passes/PassBuilder.h>
#include <llvm/Support/Error.h>
#include <llvm/Support/MemoryBuffer.h>
#include <llvm/Support/SmallVectorMemoryBuffer.h>
#include <llvm/Support/SourceMgr.h>
#include <llvm/Support/TargetRegistry.h>
#include <llvm/Support/TargetSelect.h>
#include <llvm/Support/raw_ostream.h>
#include <llvm/Target/TargetMachine.h>
#include <llvm/Target/TargetOptions.h>
#include <llvm/Transforms/IPO.h>
#include <llvm/Transforms/IPO/PassManagerBuilder.h>
#include <llvm/Transforms/InstCombine/InstCombine.h>
#include <llvm/Transforms/Scalar.h>
#include <llvm/Transforms/Scalar/GVN.h>
#include <llvm/Transforms/Scalar/NewGVN.h>
#include <llvm/Transforms/Scalar/Reassociate.h>
#include <llvm/Transforms/Scalar/SimplifyCFG.h>
#include <llvm/Transforms/Vectorize.h>
#include <algorithm>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include "llvm/Support/CodeGen.h"
namespace cinn::backends {
namespace {
template <typename PassManagerT>
class CustomPassManager : public PassManagerT {
public:
template <typename... Ts>
explicit CustomPassManager(bool print_passes, Ts &&...ts)
: PassManagerT(std::forward<Ts>(ts)...), print_passes_(print_passes) {}
void add(llvm::Pass *pass) override {
if (print_passes_) {
if (is_function_pass_manager_) {
VLOG(1) << "llvm run function pass[" << std::string(pass->getPassName())
<< "]";
}
if (is_module_pass_manager_) {
VLOG(1) << "llvm run module pass[" << std::string(pass->getPassName())
<< "]";
}
}
// static bool add_pass = true;
// if (add_pass) {
// PassManagerT::add(pass);
//}
// if (std::string(pass->getPassName()) == "Loop Vectorization") {
// return;
//}
PassManagerT::add(pass);
}
void run(llvm::Function &f) { // NOLINT
if (is_function_pass_manager_) {
PassManagerT::run(f);
}
}
void run(llvm::Module &m) { // NOLINT
if (is_module_pass_manager_) {
PassManagerT::run(m);
}
}
private:
static constexpr bool is_function_pass_manager_ =
std::is_same<llvm::legacy::FunctionPassManager, PassManagerT>::value;
static constexpr bool is_module_pass_manager_ =
std::is_same<llvm::legacy::PassManager, PassManagerT>::value;
bool print_passes_;
};
using CustomFunctionPassManager =
CustomPassManager<llvm::legacy::FunctionPassManager>;
using CustomModulePassManager = CustomPassManager<llvm::legacy::PassManager>;
} // namespace
LLVMModuleOptimizer::LLVMModuleOptimizer(llvm::TargetMachine *machine,
int opt_level,
llvm::FastMathFlags fast_math_flags,
bool print_passes)
: opt_level_(opt_level), print_passes_(print_passes), machine_(machine) {}
void LLVMModuleOptimizer::operator()(llvm::Module *m) {
auto machine = std::move(llvm::cantFail(
llvm::cantFail(llvm::orc::JITTargetMachineBuilder::detectHost())
.createTargetMachine()));
auto fpm = std::make_unique<CustomFunctionPassManager>(print_passes_, m);
// fpm->add(llvm::createTargetTransformInfoWrapperPass(llvm::TargetIRAnalysis()));
// fpm->add(llvm::createInstructionCombiningPass());
// fpm->add(llvm::createReassociatePass());
// fpm->add(llvm::createGVNPass());
// fpm->add(llvm::createCFGSimplificationPass());
// fpm->add(llvm::createSROAPass());
// fpm->add(llvm::createEarlyCSEPass());
// fpm->add(llvm::createLowerExpectIntrinsicPass());
// fpm->add(llvm::createCallSiteSplittingPass());
// fpm->add(llvm::createLoopVectorizePass());
// fpm->add(llvm::createSLPVectorizerPass());
// fpm->add(llvm::createLoadStoreVectorizerPass());
// fpm->add(llvm::createLoopUnrollPass());
auto mpm = std::make_unique<CustomModulePassManager>(print_passes_);
// mpm->add(llvm::createTargetTransformInfoWrapperPass(llvm::TargetIRAnalysis()));
// LOG(INFO) << "llvm run pass: target machine: name[" <<
// machine_->getTarget().getName() << "]"; LOG(INFO) << "llvm run pass: target
// machine: cpu[" << machine_->getTargetCPU().str() << "]";
fpm->add(llvm::createTargetTransformInfoWrapperPass(
machine->getTargetIRAnalysis()));
mpm->add(llvm::createTargetTransformInfoWrapperPass(
machine->getTargetIRAnalysis()));
auto builder = std::make_unique<llvm::PassManagerBuilder>();
builder->OptLevel = opt_level_;
builder->Inliner = llvm::createFunctionInliningPass();
builder->LoopVectorize = true;
builder->SLPVectorize = true;
#if LLVM_VERSION_MAJOR >= 11
machine->adjustPassManager(*builder);
#endif
builder->populateFunctionPassManager(*fpm);
builder->populateModulePassManager(*mpm);
fpm->doInitialization();
std::for_each(m->begin(), m->end(), [&fpm](auto &fn) { fpm->run(fn); });
fpm->doFinalization();
mpm->run(*m);
}
} // namespace cinn::backends
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <llvm/IR/Instruction.h>
#include <llvm/IR/LegacyPassManager.h>
#include <llvm/IR/Module.h>
#include <llvm/Pass.h>
#include <llvm/Target/TargetMachine.h>
#include <functional>
namespace cinn::backends {
// TODO(fc500110): define class OptimizeOptions
// llvm module optimizer
class LLVMModuleOptimizer final {
public:
explicit LLVMModuleOptimizer(llvm::TargetMachine *machine,
int opt_level,
llvm::FastMathFlags fast_math_flags,
bool print_passes = false);
void operator()(llvm::Module *m);
private:
llvm::TargetMachine *machine_;
int opt_level_{};
bool print_passes_{};
};
} // namespace cinn::backends
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/backends/llvm/llvm_util.h"
#include <glog/logging.h>
#include <llvm/Support/Alignment.h>
#include <atomic>
#include <mutex> //NOLINT
namespace cinn {
namespace backends {
using cinn::common::bfloat16;
using cinn::common::float16;
llvm::Type *CinnTypeToLLVMType(common::Type type,
llvm::Module *m,
bool is_vec) {
llvm::Type *ir_type = nullptr;
if (type.is_cpp_const()) {
// TODO(fc500110) support it latter.
}
llvm::Type *v = llvm::Type::getVoidTy(m->getContext());
llvm::Type *i1 = llvm::Type::getInt1Ty(m->getContext());
llvm::Type *i8 = llvm::Type::getInt8Ty(m->getContext());
llvm::Type *i16 = llvm::Type::getInt16Ty(m->getContext());
llvm::Type *i32 = llvm::Type::getInt32Ty(m->getContext());
llvm::Type *i64 = llvm::Type::getInt64Ty(m->getContext());
llvm::Type *u8 = llvm::Type::getInt8Ty(m->getContext());
llvm::Type *u16 = llvm::Type::getInt16Ty(m->getContext());
llvm::Type *u32 = llvm::Type::getInt32Ty(m->getContext());
llvm::Type *u64 = llvm::Type::getInt64Ty(m->getContext());
llvm::Type *bf16 = llvm::Type::getBFloatTy(m->getContext());
llvm::Type *f16 = llvm::Type::getHalfTy(m->getContext());
llvm::Type *f32 = llvm::Type::getFloatTy(m->getContext());
llvm::Type *f64 = llvm::Type::getDoubleTy(m->getContext());
llvm::Type *arr =
llvm::Type::getPrimitiveType(m->getContext(), llvm::Type::ArrayTyID);
if (type.is_void() && type.is_cpp_handle()) {
return llvm::PointerType::getUnqual(i8);
}
if (type.is_void() && type.is_cpp_handle2()) {
return llvm::PointerType::getUnqual(llvm::PointerType::getUnqual(i8));
}
if (type.is_bool()) {
ir_type = i1;
} else if (type.is_int(8)) {
ir_type = i8;
} else if (type.is_int(16)) {
ir_type = i16;
} else if (type.is_int(32)) {
ir_type = i32;
} else if (type.is_int(64)) {
ir_type = i64;
} else if (type.is_uint(8)) {
ir_type = u8;
} else if (type.is_uint(16)) {
ir_type = u16;
} else if (type.is_uint(32)) {
ir_type = u32;
} else if (type.is_uint(64)) {
ir_type = u64;
} else if (type.is_float(32)) {
ir_type = f32;
} else if (type.is_float(64)) {
ir_type = f64;
} else if (type.is_bfloat16()) {
ir_type = bf16;
} else if (type.is_float16()) {
ir_type = f16;
} else if (type.is_void()) {
ir_type = v;
} else if (type.is_string()) {
ir_type = arr;
} else if (type.is_customized_type()) {
CHECK(!type.customized_type().empty());
ir_type = m->getTypeByName("struct." + type.customized_type());
}
CHECK(ir_type) << "LLVM can't convert type: " << type;
// C array / vector.
if (type.lanes() > 1) {
if (is_vec) {
ir_type = llvm::FixedVectorType::get(ir_type, type.lanes());
} else {
ir_type = llvm::ArrayType::get(ir_type, type.lanes());
}
}
if (type.is_cpp_handle()) {
ir_type = llvm::PointerType::getUnqual(ir_type);
}
if (type.is_cpp_handle2()) {
ir_type = llvm::PointerType::getUnqual(ir_type);
ir_type = llvm::PointerType::getUnqual(ir_type);
}
return ir_type;
}
#define __(ty__) \
template <> \
llvm::Type *llvm_type_of<ty__>(llvm::Module * m) { \
return CinnTypeToLLVMType(common::type_of<ty__>(), m); \
}
__(int8_t)
__(int16_t)
__(int32_t)
__(int64_t)
__(uint8_t)
__(uint16_t)
__(uint32_t)
__(uint64_t)
__(bfloat16)
__(float16)
__(float)
__(double)
__(cinn_buffer_t)
__(cinn_buffer_t *)
__(cinn_pod_value_t *)
__(cinn_pod_value_t)
__(void *)
__(void **)
#undef __
} // namespace backends
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <absl/strings/string_view.h>
#include <llvm/ADT/StringRef.h>
#include <llvm/ExecutionEngine/MCJIT.h>
#include <llvm/IR/Argument.h>
#include <llvm/IR/BasicBlock.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/Instruction.h>
#include <llvm/IR/Intrinsics.h>
#include <llvm/IR/Module.h>
#include <llvm/IR/Value.h>
#include <string>
#include <type_traits>
#include <utility>
#include "paddle/cinn/common/type.h"
namespace cinn {
namespace backends {
template <typename T>
std::string DumpToString(const T &entity) {
std::string buffer;
llvm::raw_string_ostream os(buffer);
entity.print(os);
os.flush();
return buffer;
// return "\033[33m" + buffer + "\033[0m"; // Green
}
inline llvm::StringRef AsStringRef(absl::string_view str) {
return llvm::StringRef(str.data(), str.size());
}
llvm::Type *CinnTypeToLLVMType(common::Type t,
llvm::Module *m,
bool is_vec = false);
template <typename T>
llvm::Type *llvm_type_of(llvm::Module *m);
} // namespace backends
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/backends/llvm/runtime_symbol_registry.h"
#include <absl/strings/string_view.h>
#include <glog/raw_logging.h>
#include <iostream>
#include "gflags/gflags_declare.h"
#include "paddle/cinn/runtime/flags.h"
DECLARE_bool(verbose_function_register);
namespace cinn {
namespace backends {
RuntimeSymbols &GlobalSymbolRegistry::Global() {
static RuntimeSymbols symbols;
return symbols;
}
void *RuntimeSymbols::Lookup(absl::string_view name) const {
std::lock_guard<std::mutex> lock(mu_);
auto it = symbols_.find(std::string(name));
if (it != symbols_.end()) {
return it->second;
}
return nullptr;
}
void RuntimeSymbols::Register(const std::string &name, void *address) {
#ifdef CINN_WITH_DEBUG
if (FLAGS_verbose_function_register) {
RAW_LOG_INFO("JIT Register function [%s]: %p", name.c_str(), address);
}
#endif // CINN_WITH_DEBUG
std::lock_guard<std::mutex> lock(mu_);
auto it = symbols_.find(name);
if (it != symbols_.end()) {
CHECK_EQ(it->second, address)
<< "Duplicate register symbol [" << name << "]";
return;
}
symbols_.insert({name, reinterpret_cast<void *>(address)});
}
void RuntimeSymbols::Clear() {
std::lock_guard<std::mutex> lock(mu_);
symbols_.clear();
scalar_holder_.clear();
}
} // namespace backends
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <absl/strings/string_view.h>
#include <absl/types/any.h>
#include <absl/types/variant.h>
#include <glog/logging.h>
#include <map>
#include <mutex> // NOLINT
#include <string>
#include <vector>
#include "paddle/cinn/common/macros.h"
namespace cinn {
namespace backends {
class RuntimeSymbols {
public:
RuntimeSymbols() = default;
RuntimeSymbols(const RuntimeSymbols &) = delete;
RuntimeSymbols(RuntimeSymbols &&rhs) {
symbols_ = std::move(rhs.symbols_);
scalar_holder_ = std::move(rhs.scalar_holder_);
}
/**
* Register function address.
* @param name Name of the symbol.
* @param address Address of the function.
*/
void RegisterFn(const std::string &name, void *address) {
Register(name, address);
}
/**
* Register scalar.
* @tparam T Type of the scalar.
* @param name Name of the symbol.
* @param val Scalar value.
*/
template <typename T, typename = std::enable_if<std::is_pod<T>::value>>
void RegisterVar(const std::string &name, T val) {
void *data_ptr = nullptr;
{
std::lock_guard<std::mutex> lock(mu_);
auto &data = scalar_holder_[name];
data.resize(sizeof(T));
memcpy(data.data(), &val, sizeof(T));
data_ptr = reinterpret_cast<void *>(data.data());
}
Register(name, data_ptr);
}
/**
* Lookup a symbol from the registry.
* @param name Name of the symbol.
* @return The address if existes, or nullptr will return.
*/
void *Lookup(absl::string_view name) const;
/**
* Get all the symbols.
*/
const std::map<std::string, void *> &All() const { return symbols_; }
/**
* Clear all the symbols.
*/
void Clear();
private:
/**
* Register external symbol to the registry, the symbols in the registry will
* finally registered to JIT .
* @param name Name of the symbol in the JIT.
* @param address The address of the variable in external space.
*/
void Register(const std::string &name, void *address);
mutable std::mutex mu_;
std::map<std::string, void *> symbols_;
std::map<std::string, std::vector<int8_t>> scalar_holder_;
};
/**
* Registry for runtime symbols, these symbols will be inserted into JIT.
*/
class GlobalSymbolRegistry {
public:
static RuntimeSymbols &Global();
private:
GlobalSymbolRegistry() = default;
CINN_DISALLOW_COPY_AND_ASSIGN(GlobalSymbolRegistry);
};
} // namespace backends
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/backends/llvm/simple_jit.h"
#include <llvm/AsmParser/Parser.h>
#include <llvm/ExecutionEngine/JITSymbol.h>
#include <llvm/ExecutionEngine/Orc/Core.h>
#include <llvm/IR/LegacyPassManager.h>
#include <llvm/IR/PassManager.h>
#include <llvm/IR/Verifier.h>
#include <llvm/IRReader/IRReader.h>
#include <llvm/Passes/PassBuilder.h>
#include <llvm/Support/SourceMgr.h>
#include <llvm/Support/TargetRegistry.h>
#include <llvm/Transforms/Scalar.h>
#include <llvm/Transforms/Scalar/GVN.h>
#include <llvm/Transforms/Scalar/Reassociate.h>
#include <llvm/Transforms/Scalar/SimplifyCFG.h>
#include <string>
#include <utility>
#include "paddle/cinn/backends/codegen_cuda_host.h"
#include "paddle/cinn/backends/llvm/cinn_runtime_llvm_ir.h"
#include "paddle/cinn/backends/llvm/codegen_llvm.h"
#include "paddle/cinn/backends/llvm/llvm_util.h"
#include "paddle/cinn/backends/llvm/runtime_symbol_registry.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/runtime/intrinsic.h"
namespace cinn {
namespace backends {
void SimpleJIT::AddModule(std::unique_ptr<llvm::Module> module, bool optimize) {
/*
for (auto &fn : module->functions()) {
LOG(INFO) << "fn:\n" << DumpToString(fn);
}
*/
CHECK(!llvm::verifyModule(*module, &llvm::errs()))
<< "Transformation resulted in an invalid module\n\nmodule:\n";
bool debug = false;
if (optimize) {
llvm::PassBuilder pass_builder;
llvm::LoopAnalysisManager loop_analysis_manager(debug);
llvm::FunctionAnalysisManager function_analysis_manager(debug);
llvm::CGSCCAnalysisManager cgscc_analysis_manager(debug);
llvm::ModuleAnalysisManager module_analysis_manager(debug);
pass_builder.registerModuleAnalyses(module_analysis_manager);
pass_builder.registerCGSCCAnalyses(cgscc_analysis_manager);
pass_builder.registerFunctionAnalyses(function_analysis_manager);
pass_builder.registerLoopAnalyses(loop_analysis_manager);
pass_builder.crossRegisterProxies(loop_analysis_manager,
function_analysis_manager,
cgscc_analysis_manager,
module_analysis_manager);
llvm::ModulePassManager module_pass_manager =
pass_builder.buildPerModuleDefaultPipeline(
llvm::PassBuilder::OptimizationLevel::O3);
module_pass_manager.run(*module, module_analysis_manager);
}
VLOG(3) << "jit target: " << jit_->getDataLayout().getStringRepresentation();
VLOG(3) << "module target: "
<< module->getDataLayout().getStringRepresentation();
llvm::orc::ThreadSafeModule tsm(std::move(module), context_);
llvm::cantFail(jit_->addIRModule(std::move(tsm)));
if (debug) {
std::string buffer;
llvm::raw_string_ostream os(buffer);
jit_->getExecutionSession().dump(os);
os.flush();
VLOG(3) << "compiled jit:\n" << buffer;
}
}
SimpleJIT::SimpleJIT() : context_(std::make_unique<llvm::LLVMContext>()) {
llvm::InitializeAllTargetInfos();
llvm::InitializeAllTargets();
llvm::InitializeAllTargetMCs();
llvm::InitializeAllAsmParsers();
llvm::InitializeAllAsmPrinters();
jit_ = llvm::cantFail(llvm::orc::LLJITBuilder().create());
CHECK(jit_) << "JIT create failed";
auto proc_symbols_generator = llvm::cantFail(
llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(
jit_->getDataLayout().getGlobalPrefix()));
jit_->getMainJITDylib().addGenerator(std::move(proc_symbols_generator));
llvm::orc::MangleAndInterner mangle(jit_->getExecutionSession(),
jit_->getDataLayout());
for (auto &item : GlobalSymbolRegistry::Global().All()) {
VLOG(2) << "Insert [" << item.first << "] to SimpleJIT";
llvm::cantFail(jit_->define(llvm::orc::absoluteSymbols(
{{mangle(item.first),
{llvm::pointerToJITTargetAddress(item.second),
llvm::JITSymbolFlags::None}}})));
}
}
template <typename CodeGenT>
void SimpleJIT::Link(ir::Module module, bool optimize) {
std::string runtime_ir(backends::kRuntimeLlvmIr);
llvm::SMDiagnostic error;
auto m = llvm::parseAssemblyString(runtime_ir, error, context());
m->setDataLayout(jit_->getDataLayout());
auto b = std::make_unique<llvm::IRBuilder<>>(context());
auto ir_emitter = std::make_unique<CodeGenT>(m.get(), b.get());
ir_emitter->Compile(module);
CHECK(!llvm::verifyModule(*m, &llvm::errs())) << "Invalid module found";
AddModule(std::move(m), optimize);
}
template void SimpleJIT::Link<CodeGenLLVM>(ir::Module module, bool optimize);
template void SimpleJIT::Link<CodeGenCUDA_Host>(ir::Module module,
bool optimize);
} // namespace backends
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <absl/strings/string_view.h>
#include <llvm/AsmParser/Parser.h>
#include <llvm/ExecutionEngine/ExecutionEngine.h>
#include <llvm/ExecutionEngine/JITSymbol.h>
#include <llvm/ExecutionEngine/Orc/CompileUtils.h>
#include <llvm/ExecutionEngine/Orc/ExecutionUtils.h>
#include <llvm/ExecutionEngine/Orc/IRCompileLayer.h>
#include <llvm/ExecutionEngine/Orc/LLJIT.h>
#include <llvm/ExecutionEngine/Orc/LambdaResolver.h>
#include <llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h>
#include <llvm/ExecutionEngine/Orc/ThreadSafeModule.h>
#include <llvm/ExecutionEngine/SectionMemoryManager.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/Module.h>
#include <llvm/Support/Error.h>
#include <llvm/Support/MemoryBuffer.h>
#include <llvm/Support/SmallVectorMemoryBuffer.h>
#include <llvm/Support/TargetSelect.h>
#include <llvm/Support/raw_ostream.h>
#include <functional>
#include <memory>
#include <utility>
#include <vector>
#include "paddle/cinn/backends/llvm/codegen_llvm.h"
#include "paddle/cinn/backends/llvm/llvm_util.h"
#include "paddle/cinn/backends/llvm/runtime_symbol_registry.h"
#include "paddle/cinn/ir/module.h"
#include "paddle/cinn/runtime/intrinsic.h"
namespace cinn {
namespace backends {
class SimpleJIT {
public:
static std::unique_ptr<SimpleJIT> Create() {
return std::unique_ptr<SimpleJIT>(new SimpleJIT);
}
/**
* Runtime link to a module.
* @tparam CodeGenT a CodeGenLLVM implementation.
* @param module a CINN module.
* @param optimize whether to optimize.
*/
template <typename CodeGenT = CodeGenLLVM>
void Link(ir::Module module, bool optimize = true);
void Link(llvm::orc::ThreadSafeModule m, bool optimize = true) {
llvm::cantFail(jit_->addIRModule(std::move(m)));
}
llvm::JITTargetAddress Lookup(absl::string_view name) {
return llvm::cantFail(jit_->lookup(AsStringRef(name))).getAddress();
}
private:
void AddModule(std::unique_ptr<llvm::Module> module, bool optimize);
llvm::LLVMContext &context() { return *context_.getContext(); }
SimpleJIT();
std::unique_ptr<llvm::orc::LLJIT> jit_;
llvm::orc::ThreadSafeContext context_;
};
} // namespace backends
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/backends/modular.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
namespace cinn {
namespace backends {
class ModularEvaluator : public ir::IRVisitorRequireReImpl<ModularEntry> {
public:
explicit ModularEvaluator(const std::map<Var, ModularEntry>& mod_map)
: mod_map_(mod_map) {}
ModularEntry Eval(const Expr& e) {
return ir::IRVisitorRequireReImpl<ModularEntry>::Visit(&e);
}
ModularEntry Visit(const ir::IntImm* op) {
if (op->value < std::numeric_limits<int>::max()) {
return ModularEntry{static_cast<int>(op->value), 0};
}
return ModularEntry::everything();
}
ModularEntry Visit(const ir::UIntImm* op) {
if (op->value < std::numeric_limits<uint64_t>::max()) {
return ModularEntry{static_cast<int>(op->value), 0};
}
return ModularEntry::everything();
}
ModularEntry Visit(const ir::_Var_* op) {
Var var(&Reference(op));
auto it = mod_map_.find(var);
if (it != mod_map_.end()) return it->second;
return ModularEntry::everything();
}
ModularEntry Visit(const ir::Add* op) {
auto a = Eval(op->a());
auto b = Eval(op->b());
ModularEntry ret;
ret.coeff = gcd(a.coeff, b.coeff);
ret.base = BaseSimplify(a.base + b.base, ret.coeff);
return ret;
}
ModularEntry Visit(const ir::Sub* op) {
auto a = Eval(op->a());
auto b = Eval(op->b());
ModularEntry ret;
ret.coeff = gcd(a.coeff, b.coeff);
ret.base = BaseSimplify(a.base - b.base, ret.coeff);
return ret;
}
ModularEntry Visit(const ir::Mul* op) {
auto a = Eval(op->a());
auto b = Eval(op->b());
int pq = a.coeff * b.coeff;
int pm = a.coeff * b.base;
int qn = a.base * b.coeff;
ModularEntry ret;
ret.coeff = gcd(pq, gcd(pm, qn));
ret.base = BaseSimplify(a.base * b.base, ret.coeff);
return ret;
}
ModularEntry Visit(const ir::Div* op) {
auto a = Eval(op->a());
auto b = Eval(op->b());
if (b.coeff % b.base == 0) {
ModularEntry ret;
ret.coeff = a.coeff / b.base;
ret.base = 0;
return ret;
}
return ModularEntry::everything();
}
static int BaseSimplify(int base, int coeff) {
if (coeff == 0) return base;
base = base % coeff;
if (base < 0) base += coeff;
return base;
}
static int gcd(int a, int b) {
CHECK_GE(a, 0);
CHECK_GE(b, 0);
if (a < b) std::swap(a, b);
if (b == 0) return a;
while (a % b != 0) {
a = a % b;
std::swap(a, b);
}
return b;
}
private:
const std::map<Var, ModularEntry>& mod_map_;
};
ModularEntry ModularEntry::Add(const ModularEntry& a, const ModularEntry& b) {
ModularEntry ret;
ret.coeff = ModularEvaluator::gcd(a.coeff, b.coeff);
ret.base = ModularEvaluator::BaseSimplify(a.base + b.base, ret.coeff);
return ret;
}
} // namespace backends
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include "paddle/cinn/ir/ir.h"
namespace cinn {
namespace backends {
// borrowed from Halide and TVM.
struct ModularEntry {
int base;
int coeff;
ModularEntry() = default;
ModularEntry(int base, int coeff) : base(base), coeff(coeff) {}
static ModularEntry everything() { return ModularEntry{0, 1}; }
static ModularEntry Add(const ModularEntry& a, const ModularEntry& b);
};
ModularEntry EvalModular(const Expr& e,
const std::map<Var, ModularEntry>& mod_map);
} // namespace backends
} // namespace cinn
core_gather_headers()
gather_srcs(cinnapi_src SRCS header_generator.cc nvrtc_util.cc)
cinn_nv_test(test_nvrtc_util SRCS nvrtc_util_test.cc DEPS cinncore)
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/backends/nvrtc/header_generator.h"
#include "glog/logging.h"
#include "jitify.hpp" // NOLINT
namespace cinn {
namespace backends {
namespace nvrtc {
HeaderGeneratorBase& JitSafeHeaderGenerator::GetInstance() {
static JitSafeHeaderGenerator instance;
return instance;
}
const size_t JitSafeHeaderGenerator::size() const {
CHECK_EQ(include_names_.size(), headers_.size())
<< "Internal error in size of header files.";
return include_names_.size();
}
JitSafeHeaderGenerator::JitSafeHeaderGenerator() {
const auto& headers_map = ::jitify::detail::get_jitsafe_headers_map();
for (auto& pair : headers_map) {
include_names_.emplace_back(pair.first.data());
headers_.emplace_back(pair.second.data());
}
}
} // namespace nvrtc
} // namespace backends
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <string>
#include <vector>
namespace cinn {
namespace backends {
class HeaderGeneratorBase {
public:
virtual const size_t size() const = 0;
virtual const std::vector<const char*>& headers() const = 0;
virtual const std::vector<const char*>& include_names() const = 0;
};
namespace nvrtc {
class JitSafeHeaderGenerator : public HeaderGeneratorBase {
public:
static HeaderGeneratorBase& GetInstance();
const size_t size() const;
const std::vector<const char*>& headers() const override { return headers_; }
const std::vector<const char*>& include_names() const override {
return include_names_;
}
private:
JitSafeHeaderGenerator();
std::vector<const char*> headers_;
std::vector<const char*> include_names_;
};
} // namespace nvrtc
} // namespace backends
} // namespace cinn
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/backends/nvrtc/nvrtc_util.h"
#include <cuda.h>
#include <cuda_runtime.h>
#include <nvrtc.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#include <fstream>
#include <iostream>
#include "paddle/cinn/backends/cuda_util.h"
#include "paddle/cinn/backends/nvrtc/header_generator.h"
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/runtime/flags.h"
#include "paddle/cinn/utils/string.h"
DECLARE_string(cinn_nvcc_cmd_path);
DECLARE_bool(nvrtc_compile_to_cubin);
namespace cinn {
namespace backends {
namespace nvrtc {
std::string Compiler::operator()(const std::string& code,
bool include_headers) {
if (runtime::CanUseNvccCompiler()) {
return CompileWithNvcc(code);
}
return CompileCudaSource(code, include_headers);
}
Compiler::Compiler() {
if (FLAGS_nvrtc_compile_to_cubin) {
#if CUDA_VERSION >= 11010
compile_to_cubin_ = true;
#endif
}
VLOG(4) << "FLAGS_nvrtc_compile_to_cubin: " << FLAGS_nvrtc_compile_to_cubin
<< ", compile_to_cubin_: " << compile_to_cubin_;
}
bool Compiler::compile_to_cubin() { return compile_to_cubin_; }
std::vector<std::string> Compiler::FindCUDAIncludePaths() {
const std::string delimiter = "/";
std::string cuda_include_path;
const char* cuda_path_env = std::getenv("CUDA_PATH");
if (cuda_path_env != nullptr) {
cuda_include_path += cuda_path_env;
cuda_include_path += delimiter + "include";
return {cuda_include_path};
}
#if defined(__linux__)
struct stat st;
cuda_include_path = "/usr/local/cuda/include";
if (stat(cuda_include_path.c_str(), &st) == 0) {
return {cuda_include_path};
}
#endif
LOG(FATAL) << "Cannot find cuda include path."
<< "CUDA_PATH is not set or CUDA is not installed in the default "
"installation path."
<< "In other than linux, it is necessary to set CUDA_PATH.";
return {cuda_include_path};
}
std::vector<std::string> Compiler::FindCINNRuntimeIncludePaths() {
return {Context::Global().runtime_include_dir()};
}
std::string Compiler::CompileCudaSource(const std::string& code,
bool include_headers) {
const auto& header_gen = JitSafeHeaderGenerator::GetInstance();
std::vector<std::string> compile_options;
std::vector<const char*> param_cstrings{};
nvrtcProgram prog;
std::string cc = "30";
int major, minor;
cudaError_t e1 =
cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, 0);
cudaError_t e2 =
cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, 0);
if (e1 == cudaSuccess && e2 == cudaSuccess) {
cc = std::to_string(major) + std::to_string(minor);
} else {
LOG(WARNING) << "cannot detect compute capability from your device, "
<< "fall back to compute_30.";
}
if (compile_to_cubin_) {
compile_options.push_back("-arch=sm_" + cc);
} else {
compile_options.push_back("-arch=compute_" + cc);
}
compile_options.push_back("-std=c++14");
compile_options.push_back("-default-device");
if (include_headers) { // prepare include headers
auto cuda_headers = FindCUDAIncludePaths();
auto cinn_headers = FindCINNRuntimeIncludePaths();
std::vector<std::string> include_paths;
for (auto& header : cuda_headers) {
include_paths.push_back("--include-path=" + header);
}
for (auto& header : cinn_headers) {
include_paths.push_back("--include-path=" + header);
}
compile_options.insert(
std::end(compile_options), include_paths.begin(), include_paths.end());
}
for (const auto& option : compile_options) {
param_cstrings.push_back(option.c_str());
}
VLOG(3) << "compile options: " << utils::Join(compile_options, " ");
NVRTC_CALL(nvrtcCreateProgram(&prog,
code.c_str(),
nullptr,
header_gen.size(),
header_gen.headers().data(),
header_gen.include_names().data()));
nvrtcResult compile_res =
nvrtcCompileProgram(prog, param_cstrings.size(), param_cstrings.data());
{ // get log
size_t log_size;
NVRTC_CALL(nvrtcGetProgramLogSize(prog, &log_size));
std::string log;
log.resize(log_size);
NVRTC_CALL(nvrtcGetProgramLog(prog, &log[0]));
CHECK_EQ(compile_res, NVRTC_SUCCESS) << log;
}
size_t size;
std::string data;
if (compile_to_cubin_) {
NVRTC_CALL(nvrtcGetCUBINSize(prog, &size));
data.resize(size);
NVRTC_CALL(nvrtcGetCUBIN(prog, &data[0]));
} else {
NVRTC_CALL(nvrtcGetPTXSize(prog, &size));
data.resize(size);
NVRTC_CALL(nvrtcGetPTX(prog, &data[0]));
}
NVRTC_CALL(nvrtcDestroyProgram(&prog));
return data;
}
std::string Compiler::CompileWithNvcc(const std::string& cuda_c) {
// read dir source
std::string dir = "./source";
if (access(dir.c_str(), 0) == -1) {
CHECK(mkdir(dir.c_str(), 7) != -1) << "Fail to mkdir " << dir;
}
// get unqiue prefix name
prefix_name_ = dir + "/" + common::UniqName("rtc_tmp");
auto cuda_c_file = prefix_name_ + ".cu";
std::ofstream ofs(cuda_c_file, std::ios::out);
CHECK(ofs.is_open()) << "Fail to open file " << cuda_c_file;
ofs << cuda_c;
ofs.close();
CompileToPtx();
CompileToCubin();
return prefix_name_ + ".cubin";
}
// std::string Compiler::GetPtx() { return ReadFile(prefix_name_ + ".ptx",
// std::ios::in); }
void Compiler::CompileToPtx() {
auto include_dir = common::Context::Global().runtime_include_dir();
std::string include_dir_str = "";
for (auto dir : include_dir) {
if (include_dir_str.empty()) {
include_dir_str = dir;
} else {
include_dir_str += ":" + dir;
}
}
std::string options = std::string("export PATH=") + FLAGS_cinn_nvcc_cmd_path +
std::string(":$PATH && nvcc -std=c++14 --ptx -O3 -I ") +
include_dir_str;
options += " -arch=" + GetDeviceArch();
options += " -o " + prefix_name_ + ".ptx";
options += " " + prefix_name_ + ".cu";
VLOG(2) << "Nvcc Compile Options : " << options;
CHECK(system(options.c_str()) == 0) << options;
}
void Compiler::CompileToCubin() {
std::string options = std::string("export PATH=") + FLAGS_cinn_nvcc_cmd_path +
std::string(":$PATH && nvcc --cubin -O3");
options += " -arch=" + GetDeviceArch();
options += " -o " + prefix_name_ + ".cubin";
options += " " + prefix_name_ + ".ptx";
VLOG(2) << "Nvcc Compile Options : " << options;
CHECK(system(options.c_str()) == 0) << options;
}
std::string Compiler::GetDeviceArch() {
int major = 0, minor = 0;
if (cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, 0) ==
cudaSuccess &&
cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, 0) ==
cudaSuccess) {
return "sm_" + std::to_string(major) + std::to_string(minor);
} else {
LOG(WARNING) << "cannot detect compute capability from your device, "
<< "fall back to compute_30.";
return "sm_30";
}
}
std::string Compiler::ReadFile(const std::string& file_name,
std::ios_base::openmode mode) {
// open cubin file
std::ifstream ifs(file_name, mode);
CHECK(ifs.is_open()) << "Fail to open file " << file_name;
ifs.seekg(std::ios::end);
auto len = ifs.tellg();
ifs.seekg(0);
// read cubin file
std::string file_data(len, ' ');
ifs.read(&file_data[0], len);
ifs.close();
return std::move(file_data);
}
} // namespace nvrtc
} // namespace backends
} // namespace cinn
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment