Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
Paddle
Commits
992bec46
Commit
992bec46
authored
Oct 08, 2023
by
“yuguo”
Browse files
2.5
parent
0259837d
Changes
357
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2671 additions
and
0 deletions
+2671
-0
paddle/cinn/backends/llvm/execution_engine.cc
paddle/cinn/backends/llvm/execution_engine.cc
+275
-0
paddle/cinn/backends/llvm/execution_engine.h
paddle/cinn/backends/llvm/execution_engine.h
+111
-0
paddle/cinn/backends/llvm/execution_engine_test.cc
paddle/cinn/backends/llvm/execution_engine_test.cc
+352
-0
paddle/cinn/backends/llvm/generate_runtime_llvm_ir.py
paddle/cinn/backends/llvm/generate_runtime_llvm_ir.py
+64
-0
paddle/cinn/backends/llvm/ir_builder_mixin.h
paddle/cinn/backends/llvm/ir_builder_mixin.h
+308
-0
paddle/cinn/backends/llvm/llvm_intrin_rule.h
paddle/cinn/backends/llvm/llvm_intrin_rule.h
+194
-0
paddle/cinn/backends/llvm/llvm_optimizer.cc
paddle/cinn/backends/llvm/llvm_optimizer.cc
+174
-0
paddle/cinn/backends/llvm/llvm_optimizer.h
paddle/cinn/backends/llvm/llvm_optimizer.h
+43
-0
paddle/cinn/backends/llvm/llvm_util.cc
paddle/cinn/backends/llvm/llvm_util.cc
+149
-0
paddle/cinn/backends/llvm/llvm_util.h
paddle/cinn/backends/llvm/llvm_util.h
+59
-0
paddle/cinn/backends/llvm/runtime_symbol_registry.cc
paddle/cinn/backends/llvm/runtime_symbol_registry.cc
+69
-0
paddle/cinn/backends/llvm/runtime_symbol_registry.h
paddle/cinn/backends/llvm/runtime_symbol_registry.h
+116
-0
paddle/cinn/backends/llvm/simple_jit.cc
paddle/cinn/backends/llvm/simple_jit.cc
+143
-0
paddle/cinn/backends/llvm/simple_jit.h
paddle/cinn/backends/llvm/simple_jit.h
+86
-0
paddle/cinn/backends/modular.cc
paddle/cinn/backends/modular.cc
+131
-0
paddle/cinn/backends/modular.h
paddle/cinn/backends/modular.h
+41
-0
paddle/cinn/backends/nvrtc/CMakeLists.txt
paddle/cinn/backends/nvrtc/CMakeLists.txt
+5
-0
paddle/cinn/backends/nvrtc/header_generator.cc
paddle/cinn/backends/nvrtc/header_generator.cc
+45
-0
paddle/cinn/backends/nvrtc/header_generator.h
paddle/cinn/backends/nvrtc/header_generator.h
+49
-0
paddle/cinn/backends/nvrtc/nvrtc_util.cc
paddle/cinn/backends/nvrtc/nvrtc_util.cc
+257
-0
No files found.
Too many changes to show.
To preserve performance only
357 of 357+
files are displayed.
Plain diff
Email patch
paddle/cinn/backends/llvm/execution_engine.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/backends/llvm/execution_engine.h"
#include <absl/strings/string_view.h>
#include <llvm/ADT/Triple.h>
#include <llvm/AsmParser/Parser.h>
#include <llvm/Config/llvm-config.h>
#include <llvm/ExecutionEngine/JITSymbol.h>
#include <llvm/ExecutionEngine/Orc/Core.h>
#include <llvm/ExecutionEngine/Orc/ExecutionUtils.h>
#include <llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h>
#include <llvm/ExecutionEngine/Orc/LLJIT.h>
#include <llvm/ExecutionEngine/Orc/ThreadSafeModule.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/LLVMContext.h>
#include <llvm/IR/LegacyPassManager.h>
#include <llvm/IR/PassManager.h>
#include <llvm/IR/Verifier.h>
#include <llvm/IRReader/IRReader.h>
#include <llvm/InitializePasses.h>
#include <llvm/PassRegistry.h>
#include <llvm/Passes/PassBuilder.h>
#include <llvm/Support/Error.h>
#include <llvm/Support/Host.h>
#include <llvm/Support/MemoryBuffer.h>
#include <llvm/Support/SourceMgr.h>
#include <llvm/Support/TargetRegistry.h>
#include <llvm/Support/TargetSelect.h>
#include <llvm/Support/raw_ostream.h>
#include <llvm/Target/TargetMachine.h>
#include <llvm/Transforms/InstCombine/InstCombine.h>
#include <llvm/Transforms/Scalar.h>
#include <llvm/Transforms/Scalar/GVN.h>
#include <llvm/Transforms/Scalar/NewGVN.h>
#include <llvm/Transforms/Scalar/Reassociate.h>
#include <llvm/Transforms/Scalar/SimplifyCFG.h>
#include <cmath>
#include <memory>
#include <mutex> // NOLINT
#include <string>
#include <utility>
#include "paddle/cinn/backends/codegen_cuda_host.h"
#include "paddle/cinn/backends/llvm/cinn_runtime_llvm_ir.h"
#include "paddle/cinn/backends/llvm/codegen_llvm.h"
#include "paddle/cinn/backends/llvm/codegen_x86.h"
#include "paddle/cinn/backends/llvm/llvm_optimizer.h"
#include "paddle/cinn/backends/llvm/llvm_util.h"
#include "paddle/cinn/backends/llvm/runtime_symbol_registry.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/runtime/intrinsic.h"
#include "paddle/cinn/utils/profiler.h"
namespace
cinn
::
backends
{
namespace
{
void
InitializeLLVMPasses
()
{
llvm
::
InitializeNativeTarget
();
llvm
::
InitializeNativeTargetAsmPrinter
();
auto
&
registry
=
*
llvm
::
PassRegistry
::
getPassRegistry
();
llvm
::
initializeCore
(
registry
);
llvm
::
initializeTransformUtils
(
registry
);
llvm
::
initializeScalarOpts
(
registry
);
llvm
::
initializeIPO
(
registry
);
llvm
::
initializeInstCombine
(
registry
);
llvm
::
initializeAggressiveInstCombine
(
registry
);
llvm
::
initializeAnalysis
(
registry
);
llvm
::
initializeVectorization
(
registry
);
llvm
::
initializeSROALegacyPassPass
(
registry
);
// llvm::initializeCodeGen(registry);
// llvm::initializeTarget(registry);
// llvm::initializeCodeGenPreparePass(registry);
}
}
// namespace
// Stores a private copy of the freshly compiled object code, keyed by the
// identifier of the module it was compiled from. An existing entry with the
// same identifier is overwritten.
void NaiveObjectCache::notifyObjectCompiled(const llvm::Module *m,
                                            llvm::MemoryBufferRef obj_buffer) {
  auto object_copy = llvm::MemoryBuffer::getMemBufferCopy(
      obj_buffer.getBuffer(), obj_buffer.getBufferIdentifier());
  cached_objects_[m->getModuleIdentifier()] = std::move(object_copy);
}
// Returns the cached object code for `m`, looked up by module identifier, or
// nullptr when nothing has been cached for that identifier yet. The returned
// buffer is created from a reference to the cached entry's memory.
std::unique_ptr<llvm::MemoryBuffer> NaiveObjectCache::getObject(
    const llvm::Module *m) {
  const auto &module_id = m->getModuleIdentifier();
  auto cached = cached_objects_.find(module_id);
  if (cached != cached_objects_.end()) {
    VLOG(3) << "Object for " << module_id << " loaded from cache.";
    return llvm::MemoryBuffer::getMemBuffer(cached->second->getMemBufferRef());
  }

  VLOG(1) << "No object for " << module_id << " in cache. Compiling.";
  return nullptr;
}
/*static*/
std
::
unique_ptr
<
ExecutionEngine
>
ExecutionEngine
::
Create
(
const
ExecutionOptions
&
config
)
{
return
Create
(
config
,
{});
}
/*static*/
std
::
unique_ptr
<
ExecutionEngine
>
ExecutionEngine
::
Create
(
const
ExecutionOptions
&
config
,
RuntimeSymbols
&&
module_symbols
)
{
VLOG
(
1
)
<<
"===================== Create CINN ExecutionEngine begin "
"===================="
;
VLOG
(
1
)
<<
"initialize llvm config"
;
VLOG
(
1
)
<<
"llvm version: "
<<
LLVM_VERSION_STRING
;
VLOG
(
1
)
<<
"llvm default target triple: "
<<
LLVM_DEFAULT_TARGET_TRIPLE
;
static
std
::
once_flag
flag
;
std
::
call_once
(
flag
,
InitializeLLVMPasses
);
auto
engine
=
std
::
make_unique
<
ExecutionEngine
>
(
/*enable_object_cache=*/
true
,
std
::
move
(
module_symbols
));
auto
compile_layer_creator
=
[
&
engine
](
llvm
::
orc
::
JITTargetMachineBuilder
jtmb
)
->
llvm
::
Expected
<
std
::
unique_ptr
<
llvm
::
orc
::
IRCompileLayer
::
IRCompiler
>>
{
auto
machine
=
llvm
::
cantFail
(
jtmb
.
createTargetMachine
());
VLOG
(
1
)
<<
"create llvm compile layer"
;
VLOG
(
1
)
<<
"Target Name: "
<<
machine
->
getTarget
().
getName
();
VLOG
(
1
)
<<
"Target CPU: "
<<
machine
->
getTargetCPU
().
str
()
<<
std
::
endl
;
return
std
::
make_unique
<
llvm
::
orc
::
TMOwningSimpleCompiler
>
(
std
::
move
(
machine
),
engine
->
cache_
.
get
());
};
auto
object_layer_creator
=
[
&
](
llvm
::
orc
::
ExecutionSession
&
session
,
const
llvm
::
Triple
&
triple
)
{
auto
object_layer
=
std
::
make_unique
<
llvm
::
orc
::
RTDyldObjectLinkingLayer
>
(
session
,
[]()
{
return
std
::
make_unique
<
llvm
::
SectionMemoryManager
>
();
});
llvm
::
orc
::
JITDylib
*
main_jd
=
session
.
getJITDylibByName
(
"<main>"
);
if
(
!
main_jd
)
{
main_jd
=
&
llvm
::
cantFail
(
session
.
createJITDylib
(
"<main>"
));
}
return
object_layer
;
};
VLOG
(
2
)
<<
"create jit execution engine"
;
engine
->
jit_
=
llvm
::
cantFail
(
llvm
::
orc
::
LLJITBuilder
()
.
setCompileFunctionCreator
(
compile_layer_creator
)
.
setObjectLinkingLayerCreator
(
object_layer_creator
)
.
create
());
engine
->
jit_
->
getMainJITDylib
().
addGenerator
(
llvm
::
cantFail
(
llvm
::
orc
::
DynamicLibrarySearchGenerator
::
GetForCurrentProcess
(
engine
->
jit_
->
getDataLayout
().
getGlobalPrefix
())));
VLOG
(
2
)
<<
"register runtime call symbols"
;
engine
->
RegisterRuntimeSymbols
();
VLOG
(
2
)
<<
"===================== Create CINN ExecutionEngine end "
"===================="
;
return
engine
;
}
// Compiles `module` with the code generator `CodeGenT` and hands the result to
// the JIT.
//
// Steps: parse the embedded runtime LLVM IR into a fresh context, let CodeGenT
// emit `module` into it, verify, optimize with LLVMModuleOptimizer, verify
// again, append the emitted object code to `buffer_` (used by ExportObject),
// and finally add the LLVM module to the JIT through AddModule().
//
// Any failure along the way aborts through CHECK / llvm::cantFail.
template <typename CodeGenT>
void ExecutionEngine::Link(const ir::Module &module) {
  // NOTE(review): this constructs an unnamed temporary; if RecordEvent is a
  // scope-based recorder it ends on this very line -- confirm the intent.
  utils::RecordEvent("ExecutionEngine Link", utils::EventType::kOrdinary);

  llvm::SMDiagnostic error;
  auto ctx = std::make_unique<llvm::LLVMContext>();
  auto m = llvm::parseAssemblyString(
      AsStringRef(backends::kRuntimeLlvmIr), error, *ctx);
  // parseAssemblyString returns null on a parse error; fail with the
  // diagnostic instead of dereferencing a null module below.
  CHECK(m != nullptr) << "Failed to parse the embedded runtime LLVM IR: "
                      << error.getMessage().str();

  auto b = std::make_unique<llvm::IRBuilder<>>(*ctx);
  auto ir_emitter = std::make_unique<CodeGenT>(m.get(), b.get());
  VLOG(3) << "ir_emitter->Compile(module) Begin";
  ir_emitter->Compile(module);
  VLOG(3) << "ir_emitter->Compile(module) Succeed!";

  CHECK(!llvm::verifyModule(*m, &llvm::errs())) << "Invalid module found";

  auto machine = llvm::cantFail(
      llvm::cantFail(llvm::orc::JITTargetMachineBuilder::detectHost())
          .createTargetMachine());
  LLVMModuleOptimizer optimize(machine.get(), 3, {}, true);
  optimize(m.get());
  CHECK(!llvm::verifyModule(*m, &llvm::errs()))
      << "Invalid optimized module detected";
  for (auto &f : *m) {
    VLOG(5) << "function: " << DumpToString(f);
  }

  // Emit the object code into buffer_. The stream appends, so the buffer
  // grows with every Link() call.
  llvm::raw_svector_ostream rawstream(buffer_);
  llvm::legacy::PassManager pass_manager;
  // addPassesToEmitFile returns true when the target cannot emit this file
  // type; running the pass manager after that would produce no object code.
  CHECK(!machine->addPassesToEmitFile(
      pass_manager, rawstream, nullptr, llvm::CGFT_ObjectFile))
      << "The target machine cannot emit object files";
  pass_manager.run(*m);

  CHECK(AddModule(std::move(m), std::move(ctx)));

  if (VLOG_IS_ON(5)) {
    VLOG(5) << "======= dump jit execution session ======";
    std::string buffer;
    llvm::raw_string_ostream os(buffer);
    decltype(auto) es = jit_->getExecutionSession();
    es.dump(os);
    os.flush();
    VLOG(5) << buffer;
  }
}
// Takes ownership of an LLVM module and its context and adds them to the JIT.
// The module's data layout is overwritten with the JIT's data layout first.
// Always returns true; a failing addIRModule aborts through llvm::cantFail.
bool ExecutionEngine::AddModule(std::unique_ptr<llvm::Module> module,
                                std::unique_ptr<llvm::LLVMContext> context) {
  utils::RecordEvent("ExecutionEngine AddModule", utils::EventType::kOrdinary);

  module->setDataLayout(jit_->getDataLayout());

  if (VLOG_IS_ON(5)) {
    VLOG(5) << "======= dump jit lib ==========";
    std::string ir_text;
    llvm::raw_string_ostream ir_stream(ir_text);
    module->print(ir_stream, nullptr);
    // main_jd_->dump(os);
    ir_stream.flush();
    VLOG(5) << ir_text;
  }

  llvm::orc::ThreadSafeModule ts_module(
      std::move(module), llvm::orc::ThreadSafeContext(std::move(context)));
  llvm::cantFail(jit_->addIRModule(std::move(ts_module)));
  return true;
}
void
ExecutionEngine
::
ExportObject
(
const
std
::
string
&
path
)
{
FILE
*
of
=
fopen
(
path
.
c_str
(),
"w"
);
fwrite
(
buffer_
.
data
(),
1
,
buffer_
.
size
(),
of
);
fclose
(
of
);
}
// Resolves `name` in the JIT and returns the symbol's address, or nullptr when
// the lookup fails. The lookup is serialized through `mu_`.
void *ExecutionEngine::Lookup(absl::string_view name) {
  utils::RecordEvent("ExecutionEngine Lookup", utils::EventType::kOrdinary);
  std::lock_guard<std::mutex> lock(mu_);

  auto symbol = jit_->lookup(AsStringRef(name));
  if (!symbol) {
    // A failed llvm::Expected must have its error consumed before it is
    // destroyed; llvm::toString does that and gives us the reason to log.
    LOG(ERROR) << "Unknown symbol name[" << name
               << "]: " << llvm::toString(symbol.takeError());
    return nullptr;
  }
  return reinterpret_cast<void *>(symbol->getAddress());
}
void
ExecutionEngine
::
RegisterRuntimeSymbols
()
{
utils
::
RecordEvent
(
"ExecutionEngine RegisterRuntimeSymbols"
,
utils
::
EventType
::
kOrdinary
);
const
auto
&
registry
=
GlobalSymbolRegistry
::
Global
();
auto
*
session
=
&
jit_
->
getExecutionSession
();
for
(
const
auto
&
sym
:
registry
.
All
())
{
llvm
::
cantFail
(
jit_
->
define
(
llvm
::
orc
::
absoluteSymbols
(
{{
session
->
intern
(
sym
.
first
),
{
llvm
::
pointerToJITTargetAddress
(
sym
.
second
),
llvm
::
JITSymbolFlags
::
None
}}})));
}
for
(
const
auto
&
sym
:
module_symbols_
.
All
())
{
llvm
::
cantFail
(
jit_
->
define
(
llvm
::
orc
::
absoluteSymbols
(
{{
session
->
intern
(
sym
.
first
),
{
llvm
::
pointerToJITTargetAddress
(
sym
.
second
),
llvm
::
JITSymbolFlags
::
None
}}})));
}
}
// Explicit instantiations of Link() for the code generators used with this
// engine; the template body is defined in this translation unit.
template void ExecutionEngine::Link<CodeGenLLVM>(const ir::Module &);
template void ExecutionEngine::Link<CodeGenX86>(const ir::Module &);
template void ExecutionEngine::Link<CodeGenCUDA_Host>(const ir::Module &);
}
// namespace cinn::backends
paddle/cinn/backends/llvm/execution_engine.h
0 → 100644
View file @
992bec46
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <llvm/ADT/StringMap.h>
#include <llvm/ExecutionEngine/ExecutionEngine.h>
#include <llvm/ExecutionEngine/JITSymbol.h>
#include <llvm/ExecutionEngine/ObjectCache.h>
#include <llvm/ExecutionEngine/Orc/CompileUtils.h>
#include <llvm/ExecutionEngine/Orc/Core.h>
#include <llvm/ExecutionEngine/Orc/ExecutionUtils.h>
#include <llvm/ExecutionEngine/Orc/IRCompileLayer.h>
#include <llvm/ExecutionEngine/Orc/LLJIT.h>
#include <llvm/ExecutionEngine/Orc/LambdaResolver.h>
#include <llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h>
#include <llvm/ExecutionEngine/Orc/ThreadSafeModule.h>
#include <llvm/ExecutionEngine/SectionMemoryManager.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/LLVMContext.h>
#include <llvm/IR/Module.h>
#include <llvm/Support/Error.h>
#include <llvm/Support/MemoryBuffer.h>
#include <llvm/Support/SmallVectorMemoryBuffer.h>
#include <llvm/Support/TargetSelect.h>
#include <llvm/Support/raw_ostream.h>
#include <functional>
#include <memory>
#include <mutex> // NOLINT
#include <optional>
#include <string>
#include <vector>
#include "paddle/cinn/backends/llvm/codegen_x86.h"
#include "paddle/cinn/backends/llvm/llvm_util.h"
#include "paddle/cinn/backends/llvm/runtime_symbol_registry.h"
#include "paddle/cinn/ir/module.h"
namespace
cinn
::
backends
{
// A simple in-memory llvm::ObjectCache that keeps compiled object code in a
// map keyed by module identifier.
class NaiveObjectCache : public llvm::ObjectCache {
 public:
  // Called after a module has been compiled; receives the module and a
  // reference to its object code.
  void notifyObjectCompiled(const llvm::Module *m,
                            llvm::MemoryBufferRef obj_buffer) override;

  // Called before a module is compiled; returns its cached object code or a
  // null pointer.
  std::unique_ptr<llvm::MemoryBuffer> getObject(const llvm::Module *m) override;

 private:
  // Module identifier -> owned buffer holding the object code.
  llvm::StringMap<std::unique_ptr<llvm::MemoryBuffer>> cached_objects_;
};
// Options passed to ExecutionEngine::Create().
struct ExecutionOptions {
  // Optimization level; defaults to 3.
  int opt_level = 3;
  // Debug-info switch; defaults to off.
  bool enable_debug_info = false;
  // TODO(fc500110)
  // int num_compile_threads{1};
  // bool enable_fast_math;
};
// JIT execution engine built on llvm::orc::LLJIT. Instances are obtained
// through the static Create() factories; the constructor is protected.
class ExecutionEngine {
 public:
  // Creates an engine without module-specific runtime symbols.
  static std::unique_ptr<ExecutionEngine> Create(
      const ExecutionOptions &config);

  // Creates an engine that additionally registers `module_symbols`.
  static std::unique_ptr<ExecutionEngine> Create(
      const ExecutionOptions &config, RuntimeSymbols &&module_symbols);

  // Returns the address of the JIT symbol `name`.
  void *Lookup(absl::string_view name);

  // Compiles `module` with the code generator `CodeGenT` and adds the result
  // to the JIT.
  template <typename CodeGenT = CodeGenLLVM>
  void Link(const ir::Module &module);

  // Writes the emitted object code to the file at `path`.
  void ExportObject(const std::string &path);

  // Adds an already built LLVM module, together with the context that owns
  // it, to the JIT.
  bool AddModule(std::unique_ptr<llvm::Module> module,
                 std::unique_ptr<llvm::LLVMContext> context);

 protected:
  // `enable_object_cache` is not read here: an object cache is always
  // created.
  explicit ExecutionEngine(bool enable_object_cache,
                           RuntimeSymbols &&module_symbols)
      : cache_(std::make_unique<NaiveObjectCache>()),
        module_symbols_(std::move(module_symbols)) {}

  // Defines the global and the module-specific runtime symbols in the JIT.
  void RegisterRuntimeSymbols();

  bool SetupTargetTriple(llvm::Module *module);

  // Grants std::make_unique access to the protected constructor.
  // This may not be a compatible implementation.
  friend std::unique_ptr<ExecutionEngine> std::make_unique<ExecutionEngine>(
      bool &&, cinn::backends::RuntimeSymbols &&);

 private:
  mutable std::mutex mu_;
  llvm::SmallString<0> buffer_;
  std::unique_ptr<llvm::orc::LLJIT> jit_;
  std::unique_ptr<NaiveObjectCache> cache_;
  RuntimeSymbols module_symbols_;
};
}
// namespace cinn::backends
paddle/cinn/backends/llvm/execution_engine_test.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/backends/llvm/execution_engine.h"
#include <glog/logging.h>
#include <glog/raw_logging.h>
#include <gtest/gtest.h>
#include <llvm/AsmParser/Parser.h>
#include <llvm/IR/Argument.h>
#include <llvm/IR/BasicBlock.h>
#include <llvm/IR/Function.h>
#include <llvm/IRReader/IRReader.h>
#include <llvm/Support/FileSystem.h>
#include <llvm/Support/MemoryBuffer.h>
#include <llvm/Support/SourceMgr.h>
#include <llvm/Support/raw_ostream.h>
#include <algorithm>
#include <cmath>
#include <iomanip>
#include <memory>
#include <random>
#include <tuple>
#include <utility>
#include <vector>
#include "paddle/cinn/backends/llvm/cinn_runtime_llvm_ir.h"
#include "paddle/cinn/backends/llvm/codegen_llvm.h"
#include "paddle/cinn/backends/llvm/runtime_symbol_registry.h"
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/module.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/lang/placeholder.h"
#include "paddle/cinn/optim/optimize.h"
#include "paddle/cinn/runtime/cpu/host_intrinsics.h"
#include "paddle/cinn/runtime/cpu/use_extern_funcs.h"
namespace
cinn
{
namespace
backends
{
namespace
{
// Registers the libm functions sinf, sin, cosf and cos in the global symbol
// registry. Always returns true so the call can initialize a global.
bool RegisterKnownSymbols() {
  // Selects the double overload of sin/cos.
  using F64Unary = double (*)(double);

  decltype(auto) symbol_registry = GlobalSymbolRegistry::Global();
  symbol_registry.RegisterFn("sinf", reinterpret_cast<void *>(&sinf));
  symbol_registry.RegisterFn(
      "sin", reinterpret_cast<void *>(static_cast<F64Unary>(&sin)));
  symbol_registry.RegisterFn("cosf", reinterpret_cast<void *>(&cosf));
  symbol_registry.RegisterFn(
      "cos", reinterpret_cast<void *>(static_cast<F64Unary>(&cos)));
  return true;
}
// Runs RegisterKnownSymbols() while this translation unit's globals are
// initialized.
[[maybe_unused]] bool unused = RegisterKnownSymbols();

// Extents of the 2-D test tensors used by the tests in this file.
constexpr int kM{100};
constexpr int kN{32};
// Allocates three kM x kN float32 x86 buffers A, B and C and returns them as
// a tuple of raw pointers. A and B are filled with pseudo-random values in
// [0, 1]; C is allocated but left unwritten.
auto CreateTestBuffer() {
  auto *A = cinn_buffer_t::new_(
      cinn_device_kind_t::cinn_x86_device, cinn_float32_t(), {kM, kN}, 32);
  auto *B = cinn_buffer_t::new_(
      cinn_device_kind_t::cinn_x86_device, cinn_float32_t(), {kM, kN}, 32);
  auto *C = cinn_buffer_t::new_(
      cinn_device_kind_t::cinn_x86_device, cinn_float32_t(), {kM, kN}, 32);
  cinn_buffer_malloc(nullptr, A);
  cinn_buffer_malloc(nullptr, B);
  cinn_buffer_malloc(nullptr, C);

  float *Ad = reinterpret_cast<float *>(A->memory);
  float *Bd = reinterpret_cast<float *>(B->memory);
  for (int i = 0; i < A->num_elements(); i++) {
    Ad[i] = static_cast<float>(rand()) / RAND_MAX;  // NOLINT
    Bd[i] = static_cast<float>(rand()) / RAND_MAX;  // NOLINT
  }

  CHECK_EQ(C->num_elements(), A->num_elements());
  return std::make_tuple(A, B, C);
}
auto
CreateTestCinnModule
()
{
ir
::
Expr
M
(
kM
);
ir
::
Expr
N
(
kN
);
lang
::
Placeholder
<
float
>
A
(
"A"
,
{
M
,
N
});
lang
::
Placeholder
<
float
>
B
(
"B"
,
{
M
,
N
});
lang
::
Buffer
C_buf
(
Float
(
32
));
auto
C
=
lang
::
Compute
(
{
M
,
N
},
[
&
](
Var
i
,
Var
j
)
{
return
A
(
i
,
j
)
+
B
(
i
,
j
);
},
"C"
);
C
->
Bind
(
C_buf
);
common
::
Target
target
;
target
.
arch
=
common
::
Target
::
Arch
::
X86
;
target
.
bits
=
common
::
Target
::
Bit
::
k32
;
target
.
os
=
common
::
Target
::
OS
::
Linux
;
ir
::
Module
::
Builder
builder
(
"module1"
,
target
);
auto
stages
=
CreateStages
({
C
});
auto
funcs
=
lang
::
Lower
(
"elementwise_add"
,
stages
,
{
A
,
B
,
C
});
// auto func = optim::Optimize(funcs);
builder
.
AddFunction
(
ir
::
LoweredFunc
(
funcs
.
As
<
ir
::
_LoweredFunc_
>
()));
return
builder
.
Build
();
}
}
// namespace
// JIT-compiles the elementwise_add module and compares its output with a
// host-side addition.
// NOTE(review): the test is switched off by the unconditional `return` on its
// first line; everything below it is never executed.
TEST(llvm_test01, elementwise_add) {
  return;

  using KernelFn = void (*)(void *, int32_t);

  auto engine = backends::ExecutionEngine::Create({1});
  auto [a, b, c] = CreateTestBuffer();  // NOLINT

  auto module = CreateTestCinnModule();
  engine->Link(module);
  auto elementwise_add_addr = engine->Lookup("elementwise_add");

  return;

  auto elementwise_add = reinterpret_cast<KernelFn>(elementwise_add_addr);
  cinn_pod_value_t a_arg(a);
  cinn_pod_value_t b_arg(b);
  cinn_pod_value_t c_arg(c);
  cinn_pod_value_t args[3] = {a_arg, b_arg, c_arg};
  elementwise_add(args, 3);

  float *ad = reinterpret_cast<float *>(a->memory);
  float *bd = reinterpret_cast<float *>(b->memory);
  float *cd = reinterpret_cast<float *>(c->memory);
  for (int i = 0; i < c->num_elements(); i++) {
    EXPECT_EQ(ad[i] + bd[i], cd[i]);
  }
}
// Builds a module holding "elementwise_add" and a "main" function that calls
// it through CallLowered, prints the generated C code, then links the module
// with the JIT and looks up "cos" and "elementwise_add".
// NOTE(review): the `break` after the lookups skips the actual kernel call and
// the result comparison below it.
TEST(llvm, module_call_lowered_func) {
  ir::Module::Builder builder("some_module", common::DefaultHostTarget());
  ir::Expr M(kM);
  ir::Expr N(kN);
  {  // define fn
    lang::Placeholder<float> a("A", {M, N});
    lang::Placeholder<float> b("B", {M, N});
    auto c = lang::Compute(
        {M, N}, [&](auto i, auto j) { return a(i, j) + b(i, j); }, "C");

    auto stages = CreateStages({c});
    auto fn = lang::Lower("elementwise_add", stages, {a, b, c}, {});
    builder.AddFunction(fn);
  }

  {  // call fn
    lang::Placeholder<float> a("A", {M, N});
    lang::Placeholder<float> b("B", {M, N});

    std::vector<lang::ReturnType> ret_types(
        {lang::ReturnType{Float(32), {M, N}, "c_out"}});

    auto call_outs = lang::CallLowered("elementwise_add", {a, b}, ret_types);
    auto c = call_outs[0];

    // here we must call the output, so that it can output something.
    auto stages = CreateStages({c});
    auto main_fn = lang::Lower("main", stages, {a, b, c}, {});
    builder.AddFunction(main_fn);

    CodeGenC codegen(common::DefaultHostTarget());
    codegen.SetInlineBuiltinCodes(false);
    LOG(INFO) << "module:\n"
              << codegen.Compile(builder.Build(),
                                 CodeGenC::OutputKind::CImpl);
  }

  auto _ab_bb_cb_ = CreateTestBuffer();  // NOLINT
  auto &ab = std::get<0>(_ab_bb_cb_);
  auto &bb = std::get<1>(_ab_bb_cb_);
  auto &cb = std::get<2>(_ab_bb_cb_);
  do {  // call the function
    auto engine = backends::ExecutionEngine::Create({1});

    LOG(INFO) << "JIT Link the module";
    engine->Link(builder.Build());

    auto cos_fn = reinterpret_cast<double (*)(double)>(engine->Lookup("cos"));
    // Lookup() returns nullptr for an unknown symbol; fail the test instead
    // of calling through a null function pointer.
    ASSERT_TRUE(cos_fn != nullptr);
    LOG(INFO) << "=> LLVM JIT cos(0) = " << cos_fn(0);
    auto elementwise_add_addr = engine->Lookup("elementwise_add");
    auto elementwise_add =
        reinterpret_cast<void (*)(void *, int32_t)>(elementwise_add_addr);
    LOG(INFO) << "JIT get elementwise_add_addr";
    break;

    cinn_pod_value_t a_arg(ab), b_arg(bb), c_arg(cb);
    cinn_pod_value_t args[3] = {a_arg, b_arg, c_arg};

    elementwise_add(args, 3);

    auto *ad = reinterpret_cast<float *>(ab->memory);
    auto *bd = reinterpret_cast<float *>(bb->memory);
    for (int i = 0; i < kM; i++) {
      for (int j = 0; j < kN; j++) {
        auto *data = reinterpret_cast<float *>(cb->memory);
        ASSERT_NEAR(data[i * kN + j], ad[i * kN + j] + bd[i * kN + j], 1e-5);
      }
    }
  } while (false);
}
// Hand-builds an LLVM module with four wrappers, "_call_custom_<fn>" for
// cosf, cos, sinf and sin, each forwarding its argument to the external
// function <fn>. The module is added to the JIT and every wrapper is compared
// with the host math function for a fixed set of angles.
TEST(ExecutionEngine, custom_runtime_symbols) {
  auto context = std::make_unique<llvm::LLVMContext>();
  auto module =
      std::make_unique<llvm::Module>("test_llvm_cpu_runtime", *context);
  auto builder = std::make_unique<llvm::IRBuilder<>>(*context);

  // Emits `ty _call_custom_<name>(ty x) { return <name>(x); }` into `module`.
  auto call_custom_target = [&](const std::string &name, llvm::Type *ty) {
    llvm::FunctionType *fn_type = llvm::FunctionType::get(ty, {ty}, false);
    llvm::Function *function =
        llvm::Function::Create(fn_type,
                               llvm::Function::ExternalLinkage,
                               "_call_custom_" + name,
                               module.get());
    function->setCallingConv(llvm::CallingConv::C);
    llvm::BasicBlock *entry =
        llvm::BasicBlock::Create(module->getContext(), "entry", function);
    builder->SetInsertPoint(entry);
    llvm::Argument *arg = &*function->args().begin();
    // The callee is dereferenced unconditionally, so use the asserting cast
    // instead of a dyn_cast whose null result was never checked.
    llvm::Function *custom_function = llvm::cast<llvm::Function>(
        module->getOrInsertFunction(name, fn_type).getCallee());
    custom_function->setCallingConv(llvm::CallingConv::C);
    llvm::Value *ret = builder->CreateCall(custom_function, {arg});
    builder->CreateRet(ret);
  };

  llvm::Type *f32 = builder->getFloatTy();
  llvm::Type *f64 = builder->getDoubleTy();
  call_custom_target("cosf", f32);
  call_custom_target("cos", f64);
  call_custom_target("sinf", f32);
  call_custom_target("sin", f64);

  double pi = std::acos(-1);

  std::vector<double> angle = {0., pi / 6., pi / 4., pi / 3., pi / 2., pi};

  decltype(auto) registry = GlobalSymbolRegistry::Global();
  // registry.Register("dereference_f64_ptr", (void *)+[](double *x) { return
  // *x; });

  for (size_t i = 0; i < angle.size(); i++) {
    registry.RegisterVar("theta_" + std::to_string(i), angle[i]);
  }

  auto engine = cinn::backends::ExecutionEngine::Create({1});
  engine->AddModule(std::move(module), std::move(context));

  auto *call_cosf =
      reinterpret_cast<float (*)(float)>(engine->Lookup("_call_custom_cosf"));
  auto *call_cos =
      reinterpret_cast<double (*)(double)>(engine->Lookup("_call_custom_cos"));
  auto *call_sinf =
      reinterpret_cast<float (*)(float)>(engine->Lookup("_call_custom_sinf"));
  auto *call_sin =
      reinterpret_cast<double (*)(double)>(engine->Lookup("_call_custom_sin"));

  ASSERT_TRUE(call_cosf && call_cos && call_sinf && call_sin);

  for (auto theta : angle) {
    float theta_f = static_cast<float>(theta);
    ASSERT_NEAR(call_cosf(theta_f), cosf(theta_f), 1e-6);
    ASSERT_NEAR(call_cos(theta), cos(theta), 1e-6);
    ASSERT_NEAR(call_sinf(theta_f), sinf(theta_f), 1e-6);
    ASSERT_NEAR(call_sin(theta), sin(theta), 1e-6);
  }
}
// Lowers res(i, j) = tanh(x(i, j) + y(i, j)) with the addition inlined,
// JIT-compiles it as "comp", runs it on the test buffers and compares every
// element with the host tanh.
TEST(ExecutionEngine, call_extern) {
  ir::Expr M(kM);
  ir::Expr N(kN);

  Placeholder<float> x("x", {M, N});
  Placeholder<float> y("y", {M, N});

  auto add_out = Compute(
      {M, N}, [=](Var i, Var j) { return x(i, j) + y(i, j); }, "add_out");

  ir::Tensor res = Compute(
      {M, N},
      [&](Var i, Var j) -> Expr {
        return lang::CallExtern("tanh", {add_out(i, j)});
      },
      "res");

  auto stages = CreateStages({add_out, res});

  stages[add_out]->ComputeInline();
  auto func = Lower("comp", stages, {x, y, res});

  Module::Builder builder("module0", common::DefaultHostTarget());
  builder.AddFunction(func);

  auto engine = backends::ExecutionEngine::Create({1});

  engine->Link(builder.Build());

  auto _ab_bb_cb_ = CreateTestBuffer();  // NOLINT
  auto &ab = std::get<0>(_ab_bb_cb_);
  auto &bb = std::get<1>(_ab_bb_cb_);
  auto &cb = std::get<2>(_ab_bb_cb_);

  auto comp_addr = engine->Lookup("comp");
  // Lookup() returns nullptr for an unknown symbol; fail the test instead of
  // calling through a null function pointer.
  ASSERT_TRUE(comp_addr != nullptr);
  auto comp = reinterpret_cast<void (*)(void *, int32_t)>(comp_addr);

  cinn_pod_value_t a_arg(ab), b_arg(bb), c_arg(cb);
  cinn_pod_value_t args[3] = {a_arg, b_arg, c_arg};

  comp(args, 3);

  auto *ad = reinterpret_cast<float *>(ab->memory);
  auto *bd = reinterpret_cast<float *>(bb->memory);
  auto *cd = reinterpret_cast<float *>(cb->memory);
  for (int m = 0; m < kM; m++) {
    for (int n = 0; n < kN; n++) {
      ASSERT_NEAR(cd[m * kN + n], tanh(ad[m * kN + n] + bd[m * kN + n]), 1e-5);
    }
  }
}
}
// namespace backends
}
// namespace cinn
paddle/cinn/backends/llvm/generate_runtime_llvm_ir.py
0 → 100644
View file @
992bec46
#!/usr/bin/env python3
# Copyright (c) 2021 CINN Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
subprocess
import
sys
def main():
    """Generate a C++ header that embeds the runtime LLVM IR as a string.

    Command-line arguments:
        sys.argv[1]: path of the LLVM IR text file to embed.
        sys.argv[2]: path of the header file to write.
        sys.argv[3]: llvm-config executable, run with ``--version``.

    The header defines ``cinn::backends::kRuntimeLlvmIr`` (the file content in
    a raw string literal) and ``struct llvm_version`` whose ``kMajor`` /
    ``kMinor`` / ``kMicro`` constants hold the digits of the corresponding
    components of the reported version.
    """
    path = sys.argv[1]
    out_path = sys.argv[2]
    llvm_config = sys.argv[3]

    srcs = []
    srcs.append('#include <absl/strings/string_view.h>')
    # srcs.append('#include "paddle/cinn/backends/llvm/cinn_runtime_llvm_ir.h"\n')
    srcs.append('namespace cinn::backends {')
    srcs.append("static const absl::string_view kRuntimeLlvmIr(")
    srcs.append('R"ROC(')
    with open(path, 'r') as fr:
        srcs.append(fr.read())

    srcs.append(')ROC"')
    srcs.append(');\n')

    # Run the executable directly with an argument list instead of through a
    # shell: a path containing spaces or shell metacharacters is then passed
    # through untouched and cannot be interpreted as shell syntax.
    version = (
        subprocess.check_output([llvm_config, "--version"])
        .decode('utf-8')
        .strip()
        .split('.')
    )
    srcs.append("struct llvm_version {")
    # zip() stops at the shorter sequence, so a version with fewer than three
    # components simply produces fewer constants.
    for component, value in zip(["major", "minor", "micro"], version):
        srcs.append(
            " static constexpr int k{} = {};".format(
                component.title(), ''.join(filter(str.isdigit, value))
            )
        )
    srcs.append("};")

    srcs.append('} // namespace cinn::backends')
    with open(out_path, 'w') as fw:
        fw.write("\n".join(srcs))
def get_clang_version():
    """Placeholder without an implementation; always returns ``None``."""
    return None
# Generate the header only when this file is executed as a script.
if __name__ == "__main__":
    main()
paddle/cinn/backends/llvm/ir_builder_mixin.h
0 → 100644
View file @
992bec46
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/Value.h>
#include <utility>
namespace
cinn
{
namespace
backends
{
template
<
typename
Derived
>
class
IrBuilderMixin
{
protected:
/// \brief Forwards all arguments to CreateBinOp on the builder returned by
/// mixin_builder().
template <typename... Operands>
decltype(auto) BinOp(Operands &&...operands) {
  return mixin_builder()->CreateBinOp(std::forward<Operands>(operands)...);
}
/// \brief +
template
<
typename
...
Args
>
decltype
(
auto
)
Add
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateAdd
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
FAdd
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateFAdd
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
NSWAdd
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateNSWAdd
(
std
::
forward
<
Args
>
(
args
)...);
}
/// \brief -
template
<
typename
...
Args
>
decltype
(
auto
)
Sub
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateSub
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
FSub
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateFSub
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
NSWSub
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateNSWSub
(
std
::
forward
<
Args
>
(
args
)...);
}
/// \brief *
template
<
typename
...
Args
>
decltype
(
auto
)
Mul
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateMul
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
FMul
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateFMul
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
NSWMul
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateNSWMul
(
std
::
forward
<
Args
>
(
args
)...);
}
/// \brief /
template
<
typename
...
Args
>
decltype
(
auto
)
SDiv
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateSDiv
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
UDiv
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateUDiv
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
FDiv
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateFDiv
(
std
::
forward
<
Args
>
(
args
)...);
}
/// \brief %
template
<
typename
...
Args
>
decltype
(
auto
)
SRem
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateSRem
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
URem
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateURem
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
FRem
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateFRem
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
And
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateAnd
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
Or
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateOr
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
Not
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateNot
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
Neg
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateNeg
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
FNeg
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateFNeg
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
ICmpEQ
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateICmpEQ
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
FCmpOEQ
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateFCmpOEQ
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
FCmpUEQ
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateFCmpUEQ
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
ICmpNE
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateICmpNE
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
FCmpONE
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateFCmpONE
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
FCmpUNE
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateFCmpUNE
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
ICmpULE
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateICmpULE
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
FCmpOLE
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateFCmpOLE
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
ICmpULT
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateICmpULT
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
ICmpSLT
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateICmpSLT
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
FCmpOLT
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateFCmpOLT
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
ICmpUGE
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateICmpUGE
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
ICmpSGE
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateICmpSGE
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
FCmpOGE
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateFCmpOGE
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
ICmpUGT
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateICmpUGT
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
ICmpSGT
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateICmpSGT
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
FCmpOGT
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateFCmpOGT
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
BitCast
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateBitCast
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
IntCast
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateIntCast
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
FPCast
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateFPCast
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
PointerCast
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreatePointerCast
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
FPToSI
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateFPToSI
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
FPToUI
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateFPToUI
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
SIToFP
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateSIToFP
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
UIToFP
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateUIToFP
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
Select
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateSelect
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
Br
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateBr
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
CondBr
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateCondBr
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
Alloca
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateAlloca
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
Load
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateLoad
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
AlignedLoad
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateAlignedLoad
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
Store
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateStore
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
AlignedStore
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateAlignedStore
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
Call
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateCall
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
RetVoid
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateRetVoid
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
GEP
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateGEP
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
InBoundsGEP
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateInBoundsGEP
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
PHI
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreatePHI
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
InsertValue
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateInsertValue
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
ExtractValue
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateExtractValue
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
InsertElement
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateInsertElement
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
decltype
(
auto
)
ShuffleVector
(
Args
&&
...
args
)
{
return
mixin_builder
()
->
CreateShuffleVector
(
std
::
forward
<
Args
>
(
args
)...);
}
private:
llvm
::
IRBuilder
<>
*
mixin_builder
()
{
return
static_cast
<
Derived
*>
(
this
)
->
b
();
}
};
}
// namespace backends
}
// namespace cinn
paddle/cinn/backends/llvm/llvm_intrin_rule.h
0 → 100644
View file @
992bec46
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <absl/container/flat_hash_map.h>
#include <glog/logging.h>
#include <llvm/IR/Intrinsics.h>
#include <string>
#include <utility>
#include <vector>
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/ir/intrinsic_ops.h"
#include "paddle/cinn/ir/registry.h"
#include "paddle/cinn/lang/packed_func.h"
namespace
cinn
{
namespace
codegen
{
/// Rewrites a call node into an `ir::intrinsics::BuiltinIntrin`.
///
/// \tparam id               intrinsic id passed through to BuiltinIntrin::Make.
/// \tparam arg_nums         minimum number of read arguments the call must
///                          carry; also passed through to BuiltinIntrin::Make.
/// \tparam add_float_suffix when true the call must be float-typed and "f" is
///                          appended to the callee name; when false the name
///                          is used unchanged and no type check is made.
/// \param args packed arguments; args[0] must hold an `ir::Call` expression.
/// \param rv   receives the generated intrinsic.
template <int id, int arg_nums, bool add_float_suffix = true>
inline void MakeFloatIntrinOp(lang::Args args, lang::RetValue *rv) {
  CHECK_GE(args.size(), 1U);
  Expr call_expr = args[0];
  ir::Call *node = call_expr->as<ir::Call>();
  CHECK(node);
  CHECK_GE(node->read_args.size(), arg_nums);

  // Un-suffixed variant: forward the callee name as-is.
  if (!add_float_suffix) {
    *rv = ir::intrinsics::BuiltinIntrin::Make(
        node->name, node->read_args, id, arg_nums, node->type());
    return;
  }

  // Suffixed variant: only valid for float-typed calls.
  CHECK(node->type().is_float());
  *rv = ir::intrinsics::BuiltinIntrin::Make(
      node->name + "f", node->read_args, id, arg_nums, node->type());
}
void
RegisterCpuIntrinRule
()
{
#define __(intrin_name__, id) \
ir::Registry::Register("lower_cpu_intrinsic_" #intrin_name__, true) \
.SetBody(MakeFloatIntrinOp<id, 1>);
__
(
exp
,
::
llvm
::
Intrinsic
::
exp
)
__
(
exp2
,
::
llvm
::
Intrinsic
::
exp2
)
__
(
sqrt
,
::
llvm
::
Intrinsic
::
sqrt
)
__
(
log
,
::
llvm
::
Intrinsic
::
log
)
__
(
log2
,
::
llvm
::
Intrinsic
::
log2
)
__
(
log10
,
::
llvm
::
Intrinsic
::
log10
)
__
(
floor
,
::
llvm
::
Intrinsic
::
floor
)
__
(
ceil
,
::
llvm
::
Intrinsic
::
ceil
)
__
(
round
,
::
llvm
::
Intrinsic
::
round
)
__
(
trunc
,
::
llvm
::
Intrinsic
::
trunc
)
__
(
cos
,
::
llvm
::
Intrinsic
::
cos
)
__
(
sin
,
::
llvm
::
Intrinsic
::
sin
)
__
(
fabs
,
::
llvm
::
Intrinsic
::
fabs
)
#undef __
// set id -1 if not llvm intrinsics
#define RegisterBitwise(intrin_name__) \
ir::Registry::Register("lower_cpu_intrinsic_" #intrin_name__, true) \
.SetBody(MakeFloatIntrinOp<-1, 2, false>);
RegisterBitwise
(
bitwise_or
)
RegisterBitwise
(
bitwise_xor
)
RegisterBitwise
(
bitwise_and
)
RegisterBitwise
(
left_shift
)
RegisterBitwise
(
right_shift
)
#undef RegisterBitwise
ir
::
Registry
::
Register
(
"lower_cpu_intrinsic_fma"
,
true
)
.
SetBody
(
MakeFloatIntrinOp
<::
llvm
::
Intrinsic
::
fmuladd
,
3
,
false
>
);
ir
::
Registry
::
Register
(
"lower_cpu_intrinsic_bitwise_not"
,
true
)
.
SetBody
(
MakeFloatIntrinOp
<-
1
,
1
,
false
>
);
ir
::
Registry
::
Register
(
"lower_cpu_intrinsic_isnan"
,
true
)
.
SetBody
(
MakeFloatIntrinOp
<-
1
,
1
,
false
>
);
ir
::
Registry
::
Register
(
"lower_cpu_intrinsic_isfinite"
,
true
)
.
SetBody
([](
lang
::
Args
args
,
lang
::
RetValue
*
rv
)
{
CHECK_GE
(
args
.
size
(),
1U
);
Expr
arg0
=
args
[
0
];
ir
::
Call
*
node
=
arg0
->
as
<
ir
::
Call
>
();
CHECK
(
node
);
CHECK
(
!
node
->
read_args
.
empty
());
Expr
arg
=
node
->
read_args
[
0
];
*
rv
=
!
(
lang
::
IsInf
(
arg
))
&&
!
(
lang
::
IsNan
(
arg
));
});
ir
::
Registry
::
Register
(
"lower_cpu_intrinsic_isinf"
,
true
)
.
SetBody
([](
lang
::
Args
args
,
lang
::
RetValue
*
rv
)
{
CHECK_GE
(
args
.
size
(),
1U
);
Expr
arg0
=
args
[
0
];
ir
::
Call
*
node
=
arg0
->
as
<
ir
::
Call
>
();
CHECK
(
node
);
CHECK
(
!
node
->
read_args
.
empty
());
Expr
arg
=
node
->
read_args
[
0
];
Type
type
=
arg
->
type
();
if
(
type
.
is_int
()
||
type
.
is_uint
())
{
*
rv
=
common
::
make_bool
(
false
,
type
.
lanes
());
}
else
if
(
type
.
is_float
())
{
*
rv
=
ir
::
EQ
::
Make
(
lang
::
Abs
(
arg
),
lang
::
Infinity
(
type
))
&&
!
(
lang
::
IsNan
(
arg
));
}
});
ir
::
Registry
::
Register
(
"lower_cpu_intrinsic_rsqrt"
,
true
)
.
SetBody
([](
lang
::
Args
args
,
lang
::
RetValue
*
rv
)
{
CHECK_GE
(
args
.
size
(),
1U
);
Expr
arg0
=
args
[
0
];
ir
::
Call
*
node
=
arg0
->
as
<
ir
::
Call
>
();
CHECK
(
node
);
CHECK
(
!
node
->
read_args
.
empty
());
Expr
arg
=
node
->
read_args
[
0
];
*
rv
=
make_const
(
arg
->
type
(),
1
)
/
lang
::
Sqrt
(
arg
);
});
ir
::
Registry
::
Register
(
"lower_cpu_intrinsic_exp10"
,
true
)
.
SetBody
([](
lang
::
Args
args
,
lang
::
RetValue
*
rv
)
{
CHECK_GE
(
args
.
size
(),
1U
);
Expr
arg0
=
args
[
0
];
ir
::
Call
*
node
=
arg0
->
as
<
ir
::
Call
>
();
CHECK
(
node
);
CHECK
(
!
node
->
read_args
.
empty
());
Expr
arg
=
node
->
read_args
[
0
];
Expr
ln10
=
make_const
(
arg
->
type
(),
2.302585093
);
*
rv
=
lang
::
Exp
(
arg
*
ln10
);
});
ir
::
Registry
::
Register
(
"lower_cpu_intrinsic_tan"
,
true
)
.
SetBody
([](
lang
::
Args
args
,
lang
::
RetValue
*
rv
)
{
CHECK_GE
(
args
.
size
(),
1U
);
Expr
arg0
=
args
[
0
];
ir
::
Call
*
node
=
arg0
->
as
<
ir
::
Call
>
();
CHECK
(
node
);
CHECK
(
!
node
->
read_args
.
empty
());
Expr
arg
=
node
->
read_args
[
0
];
*
rv
=
lang
::
Sin
(
arg
)
/
lang
::
Cos
(
arg
);
});
ir
::
Registry
::
Register
(
"lower_cpu_intrinsic_tanh"
,
true
)
.
SetBody
([](
lang
::
Args
args
,
lang
::
RetValue
*
rv
)
{
CHECK_GE
(
args
.
size
(),
1U
);
Expr
arg0
=
args
[
0
];
ir
::
Call
*
node
=
arg0
->
as
<
ir
::
Call
>
();
CHECK
(
node
);
CHECK
(
!
node
->
read_args
.
empty
());
Expr
arg
=
node
->
read_args
[
0
];
Expr
zero
=
make_const
(
arg
->
type
(),
0
);
Expr
one
=
make_const
(
arg
->
type
(),
1
);
Expr
two
=
make_const
(
arg
->
type
(),
2
);
Expr
neg_two
=
make_const
(
arg
->
type
(),
-
2
);
Expr
exp_neg2x
=
lang
::
Exp
(
neg_two
*
arg
);
Expr
exp_pos2x
=
lang
::
Exp
(
two
*
arg
);
Expr
tanh_pos
=
(
one
-
exp_neg2x
)
/
(
one
+
exp_neg2x
);
Expr
tanh_neg
=
(
exp_pos2x
-
one
)
/
(
exp_pos2x
+
one
);
*
rv
=
ir
::
Select
::
Make
(
arg
>=
zero
,
tanh_pos
,
tanh_neg
);
});
ir
::
Registry
::
Register
(
"lower_cpu_intrinsic_cosh"
,
true
)
.
SetBody
([](
lang
::
Args
args
,
lang
::
RetValue
*
rv
)
{
CHECK_GE
(
args
.
size
(),
1U
);
Expr
arg0
=
args
[
0
];
ir
::
Call
*
node
=
arg0
->
as
<
ir
::
Call
>
();
CHECK
(
node
);
CHECK
(
!
node
->
read_args
.
empty
());
Expr
arg
=
node
->
read_args
[
0
];
*
rv
=
(
lang
::
Exp
(
arg
)
+
lang
::
Exp
(
arg
*
make_const
(
arg
->
type
(),
-
1
)))
/
make_const
(
arg
->
type
(),
2
);
});
ir
::
Registry
::
Register
(
"lower_cpu_intrinsic_sinh"
,
true
)
.
SetBody
([](
lang
::
Args
args
,
lang
::
RetValue
*
rv
)
{
CHECK_GE
(
args
.
size
(),
1U
);
Expr
arg0
=
args
[
0
];
ir
::
Call
*
node
=
arg0
->
as
<
ir
::
Call
>
();
CHECK
(
node
);
CHECK
(
!
node
->
read_args
.
empty
());
Expr
arg
=
node
->
read_args
[
0
];
*
rv
=
(
lang
::
Exp
(
arg
)
-
lang
::
Exp
(
arg
*
make_const
(
arg
->
type
(),
-
1
)))
/
make_const
(
arg
->
type
(),
2
);
});
}
}
// namespace codegen
}
// namespace cinn
paddle/cinn/backends/llvm/llvm_optimizer.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/backends/llvm/llvm_optimizer.h"
#include <glog/logging.h>
#include <llvm/ADT/Triple.h>
#include <llvm/Analysis/CGSCCPassManager.h>
#include <llvm/AsmParser/Parser.h>
#include <llvm/ExecutionEngine/ExecutionEngine.h>
#include <llvm/ExecutionEngine/JITSymbol.h>
#include <llvm/ExecutionEngine/Orc/CompileUtils.h>
#include <llvm/ExecutionEngine/Orc/Core.h>
#include <llvm/ExecutionEngine/Orc/ExecutionUtils.h>
#include <llvm/ExecutionEngine/Orc/IRCompileLayer.h>
#include <llvm/ExecutionEngine/Orc/LambdaResolver.h>
#include <llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h>
#include <llvm/ExecutionEngine/Orc/ThreadSafeModule.h>
#include <llvm/ExecutionEngine/SectionMemoryManager.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/LegacyPassManager.h>
#include <llvm/IR/Module.h>
#include <llvm/IR/PassManager.h>
#include <llvm/IRReader/IRReader.h>
#include <llvm/Passes/PassBuilder.h>
#include <llvm/Support/Error.h>
#include <llvm/Support/MemoryBuffer.h>
#include <llvm/Support/SmallVectorMemoryBuffer.h>
#include <llvm/Support/SourceMgr.h>
#include <llvm/Support/TargetRegistry.h>
#include <llvm/Support/TargetSelect.h>
#include <llvm/Support/raw_ostream.h>
#include <llvm/Target/TargetMachine.h>
#include <llvm/Target/TargetOptions.h>
#include <llvm/Transforms/IPO.h>
#include <llvm/Transforms/IPO/PassManagerBuilder.h>
#include <llvm/Transforms/InstCombine/InstCombine.h>
#include <llvm/Transforms/Scalar.h>
#include <llvm/Transforms/Scalar/GVN.h>
#include <llvm/Transforms/Scalar/NewGVN.h>
#include <llvm/Transforms/Scalar/Reassociate.h>
#include <llvm/Transforms/Scalar/SimplifyCFG.h>
#include <llvm/Transforms/Vectorize.h>
#include <algorithm>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include "llvm/Support/CodeGen.h"
namespace
cinn
::
backends
{
namespace
{
template
<
typename
PassManagerT
>
class
CustomPassManager
:
public
PassManagerT
{
public:
template
<
typename
...
Ts
>
explicit
CustomPassManager
(
bool
print_passes
,
Ts
&&
...
ts
)
:
PassManagerT
(
std
::
forward
<
Ts
>
(
ts
)...),
print_passes_
(
print_passes
)
{}
void
add
(
llvm
::
Pass
*
pass
)
override
{
if
(
print_passes_
)
{
if
(
is_function_pass_manager_
)
{
VLOG
(
1
)
<<
"llvm run function pass["
<<
std
::
string
(
pass
->
getPassName
())
<<
"]"
;
}
if
(
is_module_pass_manager_
)
{
VLOG
(
1
)
<<
"llvm run module pass["
<<
std
::
string
(
pass
->
getPassName
())
<<
"]"
;
}
}
// static bool add_pass = true;
// if (add_pass) {
// PassManagerT::add(pass);
//}
// if (std::string(pass->getPassName()) == "Loop Vectorization") {
// return;
//}
PassManagerT
::
add
(
pass
);
}
void
run
(
llvm
::
Function
&
f
)
{
// NOLINT
if
(
is_function_pass_manager_
)
{
PassManagerT
::
run
(
f
);
}
}
void
run
(
llvm
::
Module
&
m
)
{
// NOLINT
if
(
is_module_pass_manager_
)
{
PassManagerT
::
run
(
m
);
}
}
private:
static
constexpr
bool
is_function_pass_manager_
=
std
::
is_same
<
llvm
::
legacy
::
FunctionPassManager
,
PassManagerT
>::
value
;
static
constexpr
bool
is_module_pass_manager_
=
std
::
is_same
<
llvm
::
legacy
::
PassManager
,
PassManagerT
>::
value
;
bool
print_passes_
;
};
using
CustomFunctionPassManager
=
CustomPassManager
<
llvm
::
legacy
::
FunctionPassManager
>
;
using
CustomModulePassManager
=
CustomPassManager
<
llvm
::
legacy
::
PassManager
>
;
}
// namespace
/// Stores the optimizer configuration.
///
/// \param machine         target machine pointer, stored as-is (not owned).
/// \param opt_level       value later assigned to PassManagerBuilder::OptLevel.
/// \param fast_math_flags accepted for interface compatibility; currently
///                        unused.
/// \param print_passes    when true, every pass added to the pipeline is
///                        logged.
LLVMModuleOptimizer::LLVMModuleOptimizer(llvm::TargetMachine *machine,
                                         int opt_level,
                                         llvm::FastMathFlags fast_math_flags,
                                         bool print_passes)
    // Initializers follow the member declaration order (machine_, opt_level_,
    // print_passes_), which is the order in which they actually run.
    : machine_(machine), opt_level_(opt_level), print_passes_(print_passes) {
  (void)fast_math_flags;  // not consumed yet
}
/// Optimizes `m` in place with LLVM's legacy pass pipeline.
///
/// A function pass manager and a module pass manager are populated by a
/// `llvm::PassManagerBuilder` configured with `opt_level_`, the function
/// inliner, and loop/SLP vectorization. Function passes are run over every
/// function of the module first, then the module passes are run.
///
/// \param m module to optimize; must not be null.
void LLVMModuleOptimizer::operator()(llvm::Module *m) {
  // NOTE(review): the target machine is re-detected from the host here and
  // `machine_` (given to the constructor) is not consulted — confirm that is
  // intended.
  // cantFail() already yields a prvalue, so no std::move is needed.
  auto machine = llvm::cantFail(
      llvm::cantFail(llvm::orc::JITTargetMachineBuilder::detectHost())
          .createTargetMachine());

  auto fpm = std::make_unique<CustomFunctionPassManager>(print_passes_, m);
  auto mpm = std::make_unique<CustomModulePassManager>(print_passes_);

  // Give both managers the target's cost model before any transform pass.
  fpm->add(llvm::createTargetTransformInfoWrapperPass(
      machine->getTargetIRAnalysis()));
  mpm->add(llvm::createTargetTransformInfoWrapperPass(
      machine->getTargetIRAnalysis()));

  auto builder = std::make_unique<llvm::PassManagerBuilder>();
  builder->OptLevel = opt_level_;
  builder->Inliner = llvm::createFunctionInliningPass();
  builder->LoopVectorize = true;
  builder->SLPVectorize = true;
#if LLVM_VERSION_MAJOR >= 11
  machine->adjustPassManager(*builder);
#endif
  builder->populateFunctionPassManager(*fpm);
  builder->populateModulePassManager(*mpm);

  fpm->doInitialization();
  for (auto &fn : *m) {
    fpm->run(fn);
  }
  fpm->doFinalization();

  mpm->run(*m);
}
}
// namespace cinn::backends
paddle/cinn/backends/llvm/llvm_optimizer.h
0 → 100644
View file @
992bec46
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <llvm/IR/Instruction.h>
#include <llvm/IR/LegacyPassManager.h>
#include <llvm/IR/Module.h>
#include <llvm/Pass.h>
#include <llvm/Target/TargetMachine.h>
#include <functional>
namespace
cinn
::
backends
{
// TODO(fc500110): define class OptimizeOptions

/// Callable object that optimizes an LLVM module.
class LLVMModuleOptimizer final {
 public:
  /// \param machine         target machine pointer, kept but not owned.
  /// \param opt_level       optimization level for the pass pipeline.
  /// \param fast_math_flags fast-math configuration.
  /// \param print_passes    log the passes that are added to the pipeline.
  explicit LLVMModuleOptimizer(llvm::TargetMachine *machine,
                               int opt_level,
                               llvm::FastMathFlags fast_math_flags,
                               bool print_passes = false);

  /// Runs the optimization pipeline on `m`, modifying it in place.
  void operator()(llvm::Module *m);

 private:
  llvm::TargetMachine *machine_;
  int opt_level_ = 0;
  bool print_passes_ = false;
};
}
// namespace cinn::backends
paddle/cinn/backends/llvm/llvm_util.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/backends/llvm/llvm_util.h"
#include <glog/logging.h>
#include <llvm/Support/Alignment.h>
#include <atomic>
#include <mutex> //NOLINT
namespace
cinn
{
namespace
backends
{
using
cinn
::
common
::
bfloat16
;
using
cinn
::
common
::
float16
;
/// Maps a CINN type onto the corresponding LLVM type within module `m`.
///
/// \param type   CINN type to translate.
/// \param m      module whose context (and named struct types) are used.
/// \param is_vec for multi-lane types: true yields a fixed vector type,
///               false an array type.
/// \return the LLVM type; the process aborts via CHECK when no mapping exists.
llvm::Type *CinnTypeToLLVMType(common::Type type,
                               llvm::Module *m,
                               bool is_vec) {
  // TODO(fc500110): cpp-const qualified types get no special handling yet.
  auto &ctx = m->getContext();
  llvm::Type *const byte_ty = llvm::Type::getInt8Ty(ctx);

  // Untyped handles: void* -> i8*, void** -> i8**.
  if (type.is_void() && type.is_cpp_handle()) {
    return llvm::PointerType::getUnqual(byte_ty);
  }
  if (type.is_void() && type.is_cpp_handle2()) {
    return llvm::PointerType::getUnqual(llvm::PointerType::getUnqual(byte_ty));
  }

  // Element type. Signed and unsigned integers share the same LLVM type.
  llvm::Type *ir_type = nullptr;
  if (type.is_bool()) {
    ir_type = llvm::Type::getInt1Ty(ctx);
  } else if (type.is_int(8)) {
    ir_type = byte_ty;
  } else if (type.is_int(16)) {
    ir_type = llvm::Type::getInt16Ty(ctx);
  } else if (type.is_int(32)) {
    ir_type = llvm::Type::getInt32Ty(ctx);
  } else if (type.is_int(64)) {
    ir_type = llvm::Type::getInt64Ty(ctx);
  } else if (type.is_uint(8)) {
    ir_type = llvm::Type::getInt8Ty(ctx);
  } else if (type.is_uint(16)) {
    ir_type = llvm::Type::getInt16Ty(ctx);
  } else if (type.is_uint(32)) {
    ir_type = llvm::Type::getInt32Ty(ctx);
  } else if (type.is_uint(64)) {
    ir_type = llvm::Type::getInt64Ty(ctx);
  } else if (type.is_float(32)) {
    ir_type = llvm::Type::getFloatTy(ctx);
  } else if (type.is_float(64)) {
    ir_type = llvm::Type::getDoubleTy(ctx);
  } else if (type.is_bfloat16()) {
    ir_type = llvm::Type::getBFloatTy(ctx);
  } else if (type.is_float16()) {
    ir_type = llvm::Type::getHalfTy(ctx);
  } else if (type.is_void()) {
    ir_type = llvm::Type::getVoidTy(ctx);
  } else if (type.is_string()) {
    // NOTE(review): confirm getPrimitiveType() returns a usable type for
    // ArrayTyID; if it yields null the CHECK below rejects string types.
    ir_type = llvm::Type::getPrimitiveType(ctx, llvm::Type::ArrayTyID);
  } else if (type.is_customized_type()) {
    CHECK(!type.customized_type().empty());
    ir_type = m->getTypeByName("struct." + type.customized_type());
  }
  CHECK(ir_type) << "LLVM can't convert type: " << type;

  // Multiple lanes become either a vector or a C-style array.
  if (type.lanes() > 1) {
    ir_type = is_vec ? static_cast<llvm::Type *>(
                           llvm::FixedVectorType::get(ir_type, type.lanes()))
                     : static_cast<llvm::Type *>(
                           llvm::ArrayType::get(ir_type, type.lanes()));
  }

  // Wrap in one pointer level for a handle, two for a handle-to-handle.
  if (type.is_cpp_handle()) {
    ir_type = llvm::PointerType::getUnqual(ir_type);
  }
  if (type.is_cpp_handle2()) {
    ir_type = llvm::PointerType::getUnqual(ir_type);
    ir_type = llvm::PointerType::getUnqual(ir_type);
  }
  return ir_type;
}
// Explicit specializations of llvm_type_of<T>() for every supported C++ type.
// Each one converts T to its CINN type and delegates to CinnTypeToLLVMType().
// The helper macro has a regular name because identifiers containing a double
// underscore are reserved for the implementation.
#define CINN_DEFINE_LLVM_TYPE_OF(ty)                        \
  template <>                                               \
  llvm::Type *llvm_type_of<ty>(llvm::Module * m) {          \
    return CinnTypeToLLVMType(common::type_of<ty>(), m);    \
  }

CINN_DEFINE_LLVM_TYPE_OF(int8_t)
CINN_DEFINE_LLVM_TYPE_OF(int16_t)
CINN_DEFINE_LLVM_TYPE_OF(int32_t)
CINN_DEFINE_LLVM_TYPE_OF(int64_t)
CINN_DEFINE_LLVM_TYPE_OF(uint8_t)
CINN_DEFINE_LLVM_TYPE_OF(uint16_t)
CINN_DEFINE_LLVM_TYPE_OF(uint32_t)
CINN_DEFINE_LLVM_TYPE_OF(uint64_t)
CINN_DEFINE_LLVM_TYPE_OF(bfloat16)
CINN_DEFINE_LLVM_TYPE_OF(float16)
CINN_DEFINE_LLVM_TYPE_OF(float)
CINN_DEFINE_LLVM_TYPE_OF(double)
CINN_DEFINE_LLVM_TYPE_OF(cinn_buffer_t)
CINN_DEFINE_LLVM_TYPE_OF(cinn_buffer_t *)
CINN_DEFINE_LLVM_TYPE_OF(cinn_pod_value_t *)
CINN_DEFINE_LLVM_TYPE_OF(cinn_pod_value_t)
CINN_DEFINE_LLVM_TYPE_OF(void *)
CINN_DEFINE_LLVM_TYPE_OF(void **)

#undef CINN_DEFINE_LLVM_TYPE_OF
}
// namespace backends
}
// namespace cinn
paddle/cinn/backends/llvm/llvm_util.h
0 → 100644
View file @
992bec46
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <absl/strings/string_view.h>
#include <llvm/ADT/StringRef.h>
#include <llvm/ExecutionEngine/MCJIT.h>
#include <llvm/IR/Argument.h>
#include <llvm/IR/BasicBlock.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/Instruction.h>
#include <llvm/IR/Intrinsics.h>
#include <llvm/IR/Module.h>
#include <llvm/IR/Value.h>
#include <string>
#include <type_traits>
#include <utility>
#include "paddle/cinn/common/type.h"
namespace
cinn
{
namespace
backends
{
template
<
typename
T
>
std
::
string
DumpToString
(
const
T
&
entity
)
{
std
::
string
buffer
;
llvm
::
raw_string_ostream
os
(
buffer
);
entity
.
print
(
os
);
os
.
flush
();
return
buffer
;
// return "\033[33m" + buffer + "\033[0m"; // Green
}
/// Views an `absl::string_view` as an `llvm::StringRef` over the same bytes.
/// No copy is made, so the result is only valid while the viewed data lives.
inline llvm::StringRef AsStringRef(absl::string_view str) {
  const llvm::StringRef ref(str.data(), str.size());
  return ref;
}
/// Translates a CINN type into the matching LLVM type in module `m`.
///
/// \param type   CINN type to translate.
/// \param m      module providing the LLVM context.
/// \param is_vec selects a vector type (true) or an array type (false) for
///               types with more than one lane.
llvm::Type *CinnTypeToLLVMType(common::Type type,
                               llvm::Module *m,
                               bool is_vec = false);

/// Returns the LLVM type corresponding to the C++ type `T` in module `m`.
/// Only declared here; the supported `T` are provided as explicit
/// specializations in the implementation file.
template <typename T>
llvm::Type *llvm_type_of(llvm::Module *m);
}
// namespace backends
}
// namespace cinn
paddle/cinn/backends/llvm/runtime_symbol_registry.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/backends/llvm/runtime_symbol_registry.h"
#include <absl/strings/string_view.h>
#include <glog/raw_logging.h>
#include <iostream>
#include "gflags/gflags_declare.h"
#include "paddle/cinn/runtime/flags.h"
DECLARE_bool
(
verbose_function_register
);
namespace
cinn
{
namespace
backends
{
// Returns the shared RuntimeSymbols table. It is a function-local static, so
// it is constructed the first time this function is called.
RuntimeSymbols &GlobalSymbolRegistry::Global() {
  static RuntimeSymbols registry;
  return registry;
}
// Looks up `name` under the registry mutex and returns the registered
// address, or nullptr when the name is unknown.
void *RuntimeSymbols::Lookup(absl::string_view name) const {
  std::lock_guard<std::mutex> guard(mu_);
  const auto found = symbols_.find(std::string(name));
  return found == symbols_.end() ? nullptr : found->second;
}
// Records `address` under `name`. Registering the same name again is accepted
// only when the address is unchanged; otherwise the CHECK below fails.
void RuntimeSymbols::Register(const std::string &name, void *address) {
#ifdef CINN_WITH_DEBUG
  if (FLAGS_verbose_function_register) {
    RAW_LOG_INFO("JIT Register function [%s]: %p", name.c_str(), address);
  }
#endif  // CINN_WITH_DEBUG
  std::lock_guard<std::mutex> guard(mu_);
  const auto existing = symbols_.find(name);
  if (existing == symbols_.end()) {
    // First registration of this name.
    symbols_.insert({name, address});
    return;
  }
  CHECK_EQ(existing->second, address)
      << "Duplicate register symbol [" << name << "]";
}
// Drops every registered symbol together with the storage that backs
// registered scalars. Both containers are emptied under the registry mutex.
void RuntimeSymbols::Clear() {
  std::lock_guard<std::mutex> guard(mu_);
  symbols_.clear();
  scalar_holder_.clear();
}
}
// namespace backends
}
// namespace cinn
paddle/cinn/backends/llvm/runtime_symbol_registry.h
0 → 100644
View file @
992bec46
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <absl/strings/string_view.h>
#include <absl/types/any.h>
#include <absl/types/variant.h>
#include <glog/logging.h>
#include <map>
#include <mutex> // NOLINT
#include <string>
#include <vector>
#include "paddle/cinn/common/macros.h"
namespace
cinn
{
namespace
backends
{
/**
 * A table of named runtime symbols (function addresses and scalar values).
 * Registration takes `mu_`; note that All() returns the underlying map
 * without locking.
 */
class RuntimeSymbols {
 public:
  RuntimeSymbols() = default;

  RuntimeSymbols(const RuntimeSymbols &) = delete;

  // Takes over the two maps of `rhs`; the mutex itself is not movable and is
  // default-constructed in the new object.
  RuntimeSymbols(RuntimeSymbols &&rhs) {
    symbols_ = std::move(rhs.symbols_);
    scalar_holder_ = std::move(rhs.scalar_holder_);
  }

  /**
   * Register function address.
   * @param name Name of the symbol.
   * @param address Address of the function.
   */
  void RegisterFn(const std::string &name, void *address) {
    Register(name, address);
  }

  /**
   * Register scalar.
   * The value is copied byte-wise into storage owned by this registry and the
   * address of that copy is registered under `name`.
   * @tparam T Type of the scalar; must be a POD type because it is copied
   *           with memcpy.
   * @param name Name of the symbol.
   * @param val Scalar value.
   */
  // `::type` is required for the constraint to take effect: a bare
  // std::enable_if<false> is still a valid type and would accept any T.
  template <typename T,
            typename = typename std::enable_if<std::is_pod<T>::value>::type>
  void RegisterVar(const std::string &name, T val) {
    void *data_ptr = nullptr;
    {
      // The lock is released before Register(), which locks mu_ itself.
      std::lock_guard<std::mutex> lock(mu_);
      auto &data = scalar_holder_[name];
      data.resize(sizeof(T));
      memcpy(data.data(), &val, sizeof(T));
      data_ptr = static_cast<void *>(data.data());
    }
    Register(name, data_ptr);
  }

  /**
   * Lookup a symbol from the registry.
   * @param name Name of the symbol.
   * @return The address if it exists, otherwise nullptr.
   */
  void *Lookup(absl::string_view name) const;

  /**
   * Get all the symbols.
   */
  const std::map<std::string, void *> &All() const { return symbols_; }

  /**
   * Clear all the symbols.
   */
  void Clear();

 private:
  /**
   * Register an external symbol to the registry; the symbols in the registry
   * will finally be registered to the JIT.
   * @param name Name of the symbol in the JIT.
   * @param address The address of the variable in external space.
   */
  void Register(const std::string &name, void *address);

  mutable std::mutex mu_;
  // Symbol name -> address.
  std::map<std::string, void *> symbols_;
  // Symbol name -> bytes of a scalar registered through RegisterVar().
  std::map<std::string, std::vector<int8_t>> scalar_holder_;
};
/**
 * Registry for runtime symbols; these symbols will be inserted into the JIT.
 * The class only exposes the static accessor Global(); it cannot be
 * instantiated from outside because its constructor is private.
 */
class GlobalSymbolRegistry {
 public:
  // Returns the shared RuntimeSymbols table.
  static RuntimeSymbols &Global();

 private:
  GlobalSymbolRegistry() = default;
  CINN_DISALLOW_COPY_AND_ASSIGN(GlobalSymbolRegistry);
};
}
// namespace backends
}
// namespace cinn
paddle/cinn/backends/llvm/simple_jit.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/backends/llvm/simple_jit.h"
#include <llvm/AsmParser/Parser.h>
#include <llvm/ExecutionEngine/JITSymbol.h>
#include <llvm/ExecutionEngine/Orc/Core.h>
#include <llvm/IR/LegacyPassManager.h>
#include <llvm/IR/PassManager.h>
#include <llvm/IR/Verifier.h>
#include <llvm/IRReader/IRReader.h>
#include <llvm/Passes/PassBuilder.h>
#include <llvm/Support/SourceMgr.h>
#include <llvm/Support/TargetRegistry.h>
#include <llvm/Transforms/Scalar.h>
#include <llvm/Transforms/Scalar/GVN.h>
#include <llvm/Transforms/Scalar/Reassociate.h>
#include <llvm/Transforms/Scalar/SimplifyCFG.h>
#include <string>
#include <utility>
#include "paddle/cinn/backends/codegen_cuda_host.h"
#include "paddle/cinn/backends/llvm/cinn_runtime_llvm_ir.h"
#include "paddle/cinn/backends/llvm/codegen_llvm.h"
#include "paddle/cinn/backends/llvm/llvm_util.h"
#include "paddle/cinn/backends/llvm/runtime_symbol_registry.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/runtime/intrinsic.h"
namespace
cinn
{
namespace
backends
{
// Verifies `module`, optionally runs the default O3 per-module pipeline on
// it, and then adds it to the JIT wrapped in a ThreadSafeModule.
void SimpleJIT::AddModule(std::unique_ptr<llvm::Module> module, bool optimize) {
  // verifyModule() returns true when the module is broken.
  const bool module_is_broken = llvm::verifyModule(*module, &llvm::errs());
  CHECK(!module_is_broken)
      << "Transformation resulted in an invalid module\n\nmodule:\n";

  const bool debug = false;

  if (optimize) {
    llvm::PassBuilder builder;
    llvm::LoopAnalysisManager loop_am(debug);
    llvm::FunctionAnalysisManager function_am(debug);
    llvm::CGSCCAnalysisManager cgscc_am(debug);
    llvm::ModuleAnalysisManager module_am(debug);

    builder.registerModuleAnalyses(module_am);
    builder.registerCGSCCAnalyses(cgscc_am);
    builder.registerFunctionAnalyses(function_am);
    builder.registerLoopAnalyses(loop_am);
    builder.crossRegisterProxies(loop_am, function_am, cgscc_am, module_am);

    auto pipeline = builder.buildPerModuleDefaultPipeline(
        llvm::PassBuilder::OptimizationLevel::O3);
    pipeline.run(*module, module_am);
  }

  VLOG(3) << "jit target: " << jit_->getDataLayout().getStringRepresentation();
  VLOG(3) << "module target: "
          << module->getDataLayout().getStringRepresentation();

  llvm::orc::ThreadSafeModule thread_safe_module(std::move(module), context_);
  llvm::cantFail(jit_->addIRModule(std::move(thread_safe_module)));

  if (debug) {
    // Dump the execution session state to the log.
    std::string dump_text;
    llvm::raw_string_ostream dump_stream(dump_text);
    jit_->getExecutionSession().dump(dump_stream);
    dump_stream.flush();
    VLOG(3) << "compiled jit:\n" << dump_text;
  }
}
// Initializes the LLVM targets, creates the LLJIT instance, lets it resolve
// symbols from the current process, and defines every symbol of the global
// runtime symbol registry as an absolute symbol in the JIT.
SimpleJIT::SimpleJIT() : context_(std::make_unique<llvm::LLVMContext>()) {
  llvm::InitializeAllTargetInfos();
  llvm::InitializeAllTargets();
  llvm::InitializeAllTargetMCs();
  llvm::InitializeAllAsmParsers();
  llvm::InitializeAllAsmPrinters();

  jit_ = llvm::cantFail(llvm::orc::LLJITBuilder().create());
  CHECK(jit_) << "JIT create failed";

  const char global_prefix = jit_->getDataLayout().getGlobalPrefix();
  auto process_symbols = llvm::cantFail(
      llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(
          global_prefix));
  jit_->getMainJITDylib().addGenerator(std::move(process_symbols));

  llvm::orc::MangleAndInterner mangle(jit_->getExecutionSession(),
                                      jit_->getDataLayout());

  for (auto &symbol : GlobalSymbolRegistry::Global().All()) {
    VLOG(2) << "Insert [" << symbol.first << "] to SimpleJIT";
    const llvm::JITTargetAddress target =
        llvm::pointerToJITTargetAddress(symbol.second);
    llvm::cantFail(jit_->define(llvm::orc::absoluteSymbols(
        {{mangle(symbol.first), {target, llvm::JITSymbolFlags::None}}})));
  }
}
/**
 * Compile a CINN module on top of the runtime LLVM IR and add it to the JIT.
 * The runtime IR text (kRuntimeLlvmIr) is parsed into a fresh LLVM module,
 * `CodeGenT` emits the CINN module into it, and the result is verified and
 * passed to AddModule().
 * @tparam CodeGenT Code generator constructed from (llvm::Module*,
 *                  llvm::IRBuilder<>*) and providing Compile(ir::Module).
 * @param module The CINN module to compile.
 * @param optimize Forwarded to AddModule().
 */
template <typename CodeGenT>
void SimpleJIT::Link(ir::Module module, bool optimize) {
  std::string runtime_ir(backends::kRuntimeLlvmIr);
  llvm::SMDiagnostic error;
  auto m = llvm::parseAssemblyString(runtime_ir, error, context());
  // parseAssemblyString returns null on a parse error; fail with the
  // diagnostic instead of dereferencing a null module below.
  CHECK(m) << "Failed to parse the runtime LLVM IR: "
           << error.getMessage().str();
  m->setDataLayout(jit_->getDataLayout());

  auto b = std::make_unique<llvm::IRBuilder<>>(context());

  auto ir_emitter = std::make_unique<CodeGenT>(m.get(), b.get());
  ir_emitter->Compile(module);

  // verifyModule() returns true when the module is broken.
  CHECK(!llvm::verifyModule(*m, &llvm::errs())) << "Invalid module found";

  AddModule(std::move(m), optimize);
}
// Explicit instantiations of Link<> for the code generators used with it.
// The template body lives in this .cc file, so it is instantiated here.
template void SimpleJIT::Link<CodeGenLLVM>(ir::Module module, bool optimize);
template void SimpleJIT::Link<CodeGenCUDA_Host>(ir::Module module,
                                                bool optimize);
}
// namespace backends
}
// namespace cinn
paddle/cinn/backends/llvm/simple_jit.h
0 → 100644
View file @
992bec46
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <absl/strings/string_view.h>
#include <llvm/AsmParser/Parser.h>
#include <llvm/ExecutionEngine/ExecutionEngine.h>
#include <llvm/ExecutionEngine/JITSymbol.h>
#include <llvm/ExecutionEngine/Orc/CompileUtils.h>
#include <llvm/ExecutionEngine/Orc/ExecutionUtils.h>
#include <llvm/ExecutionEngine/Orc/IRCompileLayer.h>
#include <llvm/ExecutionEngine/Orc/LLJIT.h>
#include <llvm/ExecutionEngine/Orc/LambdaResolver.h>
#include <llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h>
#include <llvm/ExecutionEngine/Orc/ThreadSafeModule.h>
#include <llvm/ExecutionEngine/SectionMemoryManager.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/Module.h>
#include <llvm/Support/Error.h>
#include <llvm/Support/MemoryBuffer.h>
#include <llvm/Support/SmallVectorMemoryBuffer.h>
#include <llvm/Support/TargetSelect.h>
#include <llvm/Support/raw_ostream.h>
#include <functional>
#include <memory>
#include <utility>
#include <vector>
#include "paddle/cinn/backends/llvm/codegen_llvm.h"
#include "paddle/cinn/backends/llvm/llvm_util.h"
#include "paddle/cinn/backends/llvm/runtime_symbol_registry.h"
#include "paddle/cinn/ir/module.h"
#include "paddle/cinn/runtime/intrinsic.h"
namespace
cinn
{
namespace
backends
{
class
SimpleJIT
{
public:
static
std
::
unique_ptr
<
SimpleJIT
>
Create
()
{
return
std
::
unique_ptr
<
SimpleJIT
>
(
new
SimpleJIT
);
}
/**
* Runtime link to a module.
* @tparam CodeGenT a CodeGenLLVM implementation.
* @param module a CINN module.
* @param optimize whether to optimize.
*/
template
<
typename
CodeGenT
=
CodeGenLLVM
>
void
Link
(
ir
::
Module
module
,
bool
optimize
=
true
);
void
Link
(
llvm
::
orc
::
ThreadSafeModule
m
,
bool
optimize
=
true
)
{
llvm
::
cantFail
(
jit_
->
addIRModule
(
std
::
move
(
m
)));
}
llvm
::
JITTargetAddress
Lookup
(
absl
::
string_view
name
)
{
return
llvm
::
cantFail
(
jit_
->
lookup
(
AsStringRef
(
name
))).
getAddress
();
}
private:
void
AddModule
(
std
::
unique_ptr
<
llvm
::
Module
>
module
,
bool
optimize
);
llvm
::
LLVMContext
&
context
()
{
return
*
context_
.
getContext
();
}
SimpleJIT
();
std
::
unique_ptr
<
llvm
::
orc
::
LLJIT
>
jit_
;
llvm
::
orc
::
ThreadSafeContext
context_
;
};
}
// namespace backends
}
// namespace cinn
paddle/cinn/backends/modular.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/backends/modular.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
namespace
cinn
{
namespace
backends
{
/**
 * Evaluates an expression to a ModularEntry {base, coeff}.
 * NOTE(review): entries appear to describe values of the form
 * `coeff * k + base` (a constant c is {c, 0}, "anything" is {0, 1}) — confirm
 * against the callers.
 */
class ModularEvaluator : public ir::IRVisitorRequireReImpl<ModularEntry> {
 public:
  explicit ModularEvaluator(const std::map<Var, ModularEntry>& mod_map)
      : mod_map_(mod_map) {}

  // Entry point: dispatches on the node type of `e`.
  ModularEntry Eval(const Expr& e) {
    return ir::IRVisitorRequireReImpl<ModularEntry>::Visit(&e);
  }

  // A signed constant becomes {value, 0} when it fits in `int`; otherwise
  // nothing is assumed about it. The lower bound is checked too, so that a
  // value below INT_MIN is not silently truncated by the cast.
  ModularEntry Visit(const ir::IntImm* op) {
    if (op->value >= std::numeric_limits<int>::min() &&
        op->value < std::numeric_limits<int>::max()) {
      return ModularEntry{static_cast<int>(op->value), 0};
    }
    return ModularEntry::everything();
  }

  // An unsigned constant becomes {value, 0} when it fits in `int`. The bound
  // must be the `int` maximum (not the uint64_t maximum), because the value
  // is narrowed to `int` below.
  ModularEntry Visit(const ir::UIntImm* op) {
    if (op->value <
        static_cast<uint64_t>(std::numeric_limits<int>::max())) {
      return ModularEntry{static_cast<int>(op->value), 0};
    }
    return ModularEntry::everything();
  }

  // A variable evaluates to its entry in the map, or to everything() when the
  // map has no entry for it.
  ModularEntry Visit(const ir::_Var_* op) {
    Var var(&Reference(op));
    auto it = mod_map_.find(var);
    if (it != mod_map_.end()) return it->second;
    return ModularEntry::everything();
  }

  ModularEntry Visit(const ir::Add* op) {
    auto a = Eval(op->a());
    auto b = Eval(op->b());
    ModularEntry ret;
    ret.coeff = gcd(a.coeff, b.coeff);
    ret.base = BaseSimplify(a.base + b.base, ret.coeff);
    return ret;
  }

  ModularEntry Visit(const ir::Sub* op) {
    auto a = Eval(op->a());
    auto b = Eval(op->b());
    ModularEntry ret;
    ret.coeff = gcd(a.coeff, b.coeff);
    ret.base = BaseSimplify(a.base - b.base, ret.coeff);
    return ret;
  }

  ModularEntry Visit(const ir::Mul* op) {
    auto a = Eval(op->a());
    auto b = Eval(op->b());

    // Cross terms of (a.coeff * x + a.base) * (b.coeff * y + b.base).
    int pq = a.coeff * b.coeff;
    int pm = a.coeff * b.base;
    int qn = a.base * b.coeff;

    ModularEntry ret;
    ret.coeff = gcd(pq, gcd(pm, qn));
    ret.base = BaseSimplify(a.base * b.base, ret.coeff);
    return ret;
  }

  // Only `(c * k * x) / c -> k * x` is handled: the divisor must be a known
  // positive constant (b.coeff == 0, b.base > 0), the dividend must have no
  // offset (a.base == 0), and its coefficient must be divisible by the
  // constant. Every other case, including a divisor with base 0 (which would
  // otherwise be a division by zero), yields everything().
  ModularEntry Visit(const ir::Div* op) {
    auto a = Eval(op->a());
    auto b = Eval(op->b());
    if (b.coeff == 0 && b.base > 0 && a.base == 0 &&
        a.coeff % b.base == 0) {
      ModularEntry ret;
      ret.coeff = a.coeff / b.base;
      ret.base = 0;
      return ret;
    }
    return ModularEntry::everything();
  }

  // Reduces `base` modulo `coeff` and shifts a negative remainder up by
  // `coeff`. With coeff == 0 the base is returned unchanged.
  static int BaseSimplify(int base, int coeff) {
    if (coeff == 0) return base;
    base = base % coeff;
    if (base < 0) base += coeff;
    return base;
  }

  // Greatest common divisor by the Euclidean algorithm. Both arguments must
  // be non-negative (checked); gcd(x, 0) is x.
  static int gcd(int a, int b) {
    CHECK_GE(a, 0);
    CHECK_GE(b, 0);
    if (a < b) std::swap(a, b);
    if (b == 0) return a;
    while (a % b != 0) {
      a = a % b;
      std::swap(a, b);
    }
    return b;
  }

 private:
  // Known entries for variables; referenced, not owned.
  const std::map<Var, ModularEntry>& mod_map_;
};
// Combines two entries for a sum: the coefficient is the gcd of both
// coefficients and the base is the sum of both bases, simplified against that
// coefficient.
ModularEntry ModularEntry::Add(const ModularEntry& a, const ModularEntry& b) {
  const int merged_coeff = ModularEvaluator::gcd(a.coeff, b.coeff);
  const int merged_base =
      ModularEvaluator::BaseSimplify(a.base + b.base, merged_coeff);
  return ModularEntry(merged_base, merged_coeff);
}
}
// namespace backends
}
// namespace cinn
paddle/cinn/backends/modular.h
0 → 100644
View file @
992bec46
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include "paddle/cinn/ir/ir.h"
namespace
cinn
{
namespace
backends
{
// Borrowed from Halide and TVM.
/**
 * A {base, coeff} pair produced by modular analysis.
 * NOTE(review): presumably describes values of the form `coeff * k + base` —
 * confirm against the evaluator.
 * A default-constructed entry equals everything(), so no member is ever left
 * uninitialized.
 */
struct ModularEntry {
  int base{0};
  int coeff{1};

  ModularEntry() = default;
  ModularEntry(int base, int coeff) : base(base), coeff(coeff) {}

  // The entry {base = 0, coeff = 1}.
  static ModularEntry everything() { return ModularEntry{0, 1}; }

  // Entry for the sum of `a` and `b`; defined in the .cc file.
  static ModularEntry Add(const ModularEntry& a, const ModularEntry& b);
};
/**
 * Evaluate the modular entry of an expression. Declaration only.
 * @param e The expression to evaluate.
 * @param mod_map Entries for variables.
 *        NOTE(review): presumably used for variables appearing in `e` —
 *        confirm in the definition.
 * @return The ModularEntry computed for `e`.
 */
ModularEntry EvalModular(const Expr& e,
                         const std::map<Var, ModularEntry>& mod_map);
}
// namespace backends
}
// namespace cinn
paddle/cinn/backends/nvrtc/CMakeLists.txt
0 → 100644
View file @
992bec46
# NOTE(review): core_gather_headers, gather_srcs and cinn_nv_test are project
# macros defined elsewhere; the comments below describe only their arguments.

# Presumably collects the headers of this directory — confirm in the macro.
core_gather_headers()

# Adds the NVRTC backend sources to the cinnapi_src source list.
gather_srcs(cinnapi_src SRCS header_generator.cc nvrtc_util.cc)

# Test target test_nvrtc_util built from nvrtc_util_test.cc, depending on
# cinncore.
cinn_nv_test(test_nvrtc_util SRCS nvrtc_util_test.cc DEPS cinncore)
paddle/cinn/backends/nvrtc/header_generator.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/backends/nvrtc/header_generator.h"
#include "glog/logging.h"
#include "jitify.hpp" // NOLINT
namespace
cinn
{
namespace
backends
{
namespace
nvrtc
{
// Returns the single JitSafeHeaderGenerator, a function-local static that is
// constructed on the first call, through its base-class interface.
HeaderGeneratorBase& JitSafeHeaderGenerator::GetInstance() {
  static JitSafeHeaderGenerator generator;
  return generator;
}
// Number of header entries. The include-name list and the header list are
// expected to have the same length; a mismatch fails the CHECK.
const size_t JitSafeHeaderGenerator::size() const {
  const size_t name_count = include_names_.size();
  CHECK_EQ(name_count, headers_.size())
      << "Internal error in size of header files.";
  return name_count;
}
// Fills the two parallel lists from jitify's jit-safe header map: the key of
// each entry goes to include_names_ and the value to headers_. Only pointers
// to the map's string data are stored, not copies.
JitSafeHeaderGenerator::JitSafeHeaderGenerator() {
  const auto& jitsafe_headers = ::jitify::detail::get_jitsafe_headers_map();
  for (auto& entry : jitsafe_headers) {
    const auto& include_name = entry.first;
    const auto& header_source = entry.second;
    include_names_.emplace_back(include_name.data());
    headers_.emplace_back(header_source.data());
  }
}
}
// namespace nvrtc
}
// namespace backends
}
// namespace cinn
paddle/cinn/backends/nvrtc/header_generator.h
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <string>
#include <vector>
namespace
cinn
{
namespace
backends
{
/**
 * Interface for a provider of in-memory header files, exposed as two parallel
 * lists of C strings plus their count.
 */
class HeaderGeneratorBase {
 public:
  // Virtual so that destroying an implementation through a base pointer or
  // reference is well defined.
  virtual ~HeaderGeneratorBase() = default;

  // Number of header entries.
  virtual const size_t size() const = 0;
  // Header entries as C strings.
  virtual const std::vector<const char*>& headers() const = 0;
  // Include names as C strings.
  virtual const std::vector<const char*>& include_names() const = 0;
};
namespace
nvrtc
{
class
JitSafeHeaderGenerator
:
public
HeaderGeneratorBase
{
public:
static
HeaderGeneratorBase
&
GetInstance
();
const
size_t
size
()
const
;
const
std
::
vector
<
const
char
*>&
headers
()
const
override
{
return
headers_
;
}
const
std
::
vector
<
const
char
*>&
include_names
()
const
override
{
return
include_names_
;
}
private:
JitSafeHeaderGenerator
();
std
::
vector
<
const
char
*>
headers_
;
std
::
vector
<
const
char
*>
include_names_
;
};
}
// namespace nvrtc
}
// namespace backends
}
// namespace cinn
paddle/cinn/backends/nvrtc/nvrtc_util.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/backends/nvrtc/nvrtc_util.h"
#include <cuda.h>
#include <cuda_runtime.h>
#include <nvrtc.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#include <fstream>
#include <iostream>
#include "paddle/cinn/backends/cuda_util.h"
#include "paddle/cinn/backends/nvrtc/header_generator.h"
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/runtime/flags.h"
#include "paddle/cinn/utils/string.h"
DECLARE_string
(
cinn_nvcc_cmd_path
);
DECLARE_bool
(
nvrtc_compile_to_cubin
);
namespace
cinn
{
namespace
backends
{
namespace
nvrtc
{
// Compiles `code`. When the nvcc route is available it is used (and
// `include_headers` is not consulted); otherwise the source is compiled
// through NVRTC.
std::string Compiler::operator()(const std::string& code,
                                 bool include_headers) {
  return runtime::CanUseNvccCompiler()
             ? CompileWithNvcc(code)
             : CompileCudaSource(code, include_headers);
}
// Enables cubin output when FLAGS_nvrtc_compile_to_cubin is set and the code
// is built against CUDA 11.1 or newer; the decision is logged at VLOG(4).
Compiler::Compiler() {
#if CUDA_VERSION >= 11010
  if (FLAGS_nvrtc_compile_to_cubin) {
    compile_to_cubin_ = true;
  }
#endif
  VLOG(4) << "FLAGS_nvrtc_compile_to_cubin: " << FLAGS_nvrtc_compile_to_cubin
          << ", compile_to_cubin_: " << compile_to_cubin_;
}
// Returns whether this compiler produces cubin output (the value of
// compile_to_cubin_).
bool Compiler::compile_to_cubin() { return compile_to_cubin_; }
std
::
vector
<
std
::
string
>
Compiler
::
FindCUDAIncludePaths
()
{
const
std
::
string
delimiter
=
"/"
;
std
::
string
cuda_include_path
;
const
char
*
cuda_path_env
=
std
::
getenv
(
"CUDA_PATH"
);
if
(
cuda_path_env
!=
nullptr
)
{
cuda_include_path
+=
cuda_path_env
;
cuda_include_path
+=
delimiter
+
"include"
;
return
{
cuda_include_path
};
}
#if defined(__linux__)
struct
stat
st
;
cuda_include_path
=
"/usr/local/cuda/include"
;
if
(
stat
(
cuda_include_path
.
c_str
(),
&
st
)
==
0
)
{
return
{
cuda_include_path
};
}
#endif
LOG
(
FATAL
)
<<
"Cannot find cuda include path."
<<
"CUDA_PATH is not set or CUDA is not installed in the default "
"installation path."
<<
"In other than linux, it is necessary to set CUDA_PATH."
;
return
{
cuda_include_path
};
}
// Returns the CINN runtime include directories reported by the global
// Context.
// NOTE(review): the exact return type of runtime_include_dir() is not visible
// here; the braced return relies on it — confirm before restructuring.
std::vector<std::string> Compiler::FindCINNRuntimeIncludePaths() {
  return {Context::Global().runtime_include_dir()};
}
std
::
string
Compiler
::
CompileCudaSource
(
const
std
::
string
&
code
,
bool
include_headers
)
{
const
auto
&
header_gen
=
JitSafeHeaderGenerator
::
GetInstance
();
std
::
vector
<
std
::
string
>
compile_options
;
std
::
vector
<
const
char
*>
param_cstrings
{};
nvrtcProgram
prog
;
std
::
string
cc
=
"30"
;
int
major
,
minor
;
cudaError_t
e1
=
cudaDeviceGetAttribute
(
&
major
,
cudaDevAttrComputeCapabilityMajor
,
0
);
cudaError_t
e2
=
cudaDeviceGetAttribute
(
&
minor
,
cudaDevAttrComputeCapabilityMinor
,
0
);
if
(
e1
==
cudaSuccess
&&
e2
==
cudaSuccess
)
{
cc
=
std
::
to_string
(
major
)
+
std
::
to_string
(
minor
);
}
else
{
LOG
(
WARNING
)
<<
"cannot detect compute capability from your device, "
<<
"fall back to compute_30."
;
}
if
(
compile_to_cubin_
)
{
compile_options
.
push_back
(
"-arch=sm_"
+
cc
);
}
else
{
compile_options
.
push_back
(
"-arch=compute_"
+
cc
);
}
compile_options
.
push_back
(
"-std=c++14"
);
compile_options
.
push_back
(
"-default-device"
);
if
(
include_headers
)
{
// prepare include headers
auto
cuda_headers
=
FindCUDAIncludePaths
();
auto
cinn_headers
=
FindCINNRuntimeIncludePaths
();
std
::
vector
<
std
::
string
>
include_paths
;
for
(
auto
&
header
:
cuda_headers
)
{
include_paths
.
push_back
(
"--include-path="
+
header
);
}
for
(
auto
&
header
:
cinn_headers
)
{
include_paths
.
push_back
(
"--include-path="
+
header
);
}
compile_options
.
insert
(
std
::
end
(
compile_options
),
include_paths
.
begin
(),
include_paths
.
end
());
}
for
(
const
auto
&
option
:
compile_options
)
{
param_cstrings
.
push_back
(
option
.
c_str
());
}
VLOG
(
3
)
<<
"compile options: "
<<
utils
::
Join
(
compile_options
,
" "
);
NVRTC_CALL
(
nvrtcCreateProgram
(
&
prog
,
code
.
c_str
(),
nullptr
,
header_gen
.
size
(),
header_gen
.
headers
().
data
(),
header_gen
.
include_names
().
data
()));
nvrtcResult
compile_res
=
nvrtcCompileProgram
(
prog
,
param_cstrings
.
size
(),
param_cstrings
.
data
());
{
// get log
size_t
log_size
;
NVRTC_CALL
(
nvrtcGetProgramLogSize
(
prog
,
&
log_size
));
std
::
string
log
;
log
.
resize
(
log_size
);
NVRTC_CALL
(
nvrtcGetProgramLog
(
prog
,
&
log
[
0
]));
CHECK_EQ
(
compile_res
,
NVRTC_SUCCESS
)
<<
log
;
}
size_t
size
;
std
::
string
data
;
if
(
compile_to_cubin_
)
{
NVRTC_CALL
(
nvrtcGetCUBINSize
(
prog
,
&
size
));
data
.
resize
(
size
);
NVRTC_CALL
(
nvrtcGetCUBIN
(
prog
,
&
data
[
0
]));
}
else
{
NVRTC_CALL
(
nvrtcGetPTXSize
(
prog
,
&
size
));
data
.
resize
(
size
);
NVRTC_CALL
(
nvrtcGetPTX
(
prog
,
&
data
[
0
]));
}
NVRTC_CALL
(
nvrtcDestroyProgram
(
&
prog
));
return
data
;
}
// Compiles `cuda_c` with the nvcc command line tool. The source is written to
// ./source/<unique name>.cu, compiled to PTX and then to a cubin.
// Returns the path of the generated .cubin file (not its contents).
std::string Compiler::CompileWithNvcc(const std::string& cuda_c) {
  // Make sure the working directory for generated sources exists.
  std::string dir = "./source";
  if (access(dir.c_str(), F_OK) == -1) {
    // Mode 0755: the owner must be able to create files in the directory.
    // (The previous mode of decimal 7 is octal 0007, which gives the owner no
    // permission at all.)
    CHECK(mkdir(dir.c_str(), 0755) != -1) << "Fail to mkdir " << dir;
  }

  // Get a unique prefix name for the .cu/.ptx/.cubin files.
  prefix_name_ = dir + "/" + common::UniqName("rtc_tmp");

  auto cuda_c_file = prefix_name_ + ".cu";
  std::ofstream ofs(cuda_c_file, std::ios::out);
  CHECK(ofs.is_open()) << "Fail to open file " << cuda_c_file;
  ofs << cuda_c;
  // Close before invoking nvcc so the file content is flushed to disk.
  ofs.close();

  CompileToPtx();
  CompileToCubin();

  return prefix_name_ + ".cubin";
}
// std::string Compiler::GetPtx() { return ReadFile(prefix_name_ + ".ptx",
// std::ios::in); }
// Runs nvcc through the shell to turn <prefix>.cu into <prefix>.ptx. The
// runtime include directories are joined with ':' and passed after -I, and
// FLAGS_cinn_nvcc_cmd_path is prepended to PATH. A non-zero exit status of
// the command fails the CHECK.
void Compiler::CompileToPtx() {
  auto include_dirs = common::Context::Global().runtime_include_dir();
  std::string joined_dirs = "";
  for (const auto& dir : include_dirs) {
    if (!joined_dirs.empty()) {
      joined_dirs += ":";
    }
    joined_dirs += dir;
  }

  std::string command = std::string("export PATH=");
  command += FLAGS_cinn_nvcc_cmd_path;
  command += std::string(":$PATH && nvcc -std=c++14 --ptx -O3 -I ");
  command += joined_dirs;
  command += " -arch=" + GetDeviceArch();
  command += " -o " + prefix_name_ + ".ptx";
  command += " " + prefix_name_ + ".cu";

  VLOG(2) << "Nvcc Compile Options : " << command;
  const int exit_status = system(command.c_str());
  CHECK(exit_status == 0) << command;
}
// Runs nvcc through the shell to turn <prefix>.ptx into <prefix>.cubin, with
// FLAGS_cinn_nvcc_cmd_path prepended to PATH. A non-zero exit status of the
// command fails the CHECK.
void Compiler::CompileToCubin() {
  std::string command = std::string("export PATH=");
  command += FLAGS_cinn_nvcc_cmd_path;
  command += std::string(":$PATH && nvcc --cubin -O3");
  command += " -arch=" + GetDeviceArch();
  command += " -o " + prefix_name_ + ".cubin";
  command += " " + prefix_name_ + ".ptx";

  VLOG(2) << "Nvcc Compile Options : " << command;
  const int exit_status = system(command.c_str());
  CHECK(exit_status == 0) << command;
}
// Returns the "sm_<major><minor>" architecture string for device 0, or
// "sm_30" (with a warning) when the compute capability cannot be queried.
std::string Compiler::GetDeviceArch() {
  int cc_major = 0, cc_minor = 0;
  // The minor version is only queried when the major query succeeded.
  const bool detected =
      cudaDeviceGetAttribute(&cc_major, cudaDevAttrComputeCapabilityMajor, 0) ==
          cudaSuccess &&
      cudaDeviceGetAttribute(&cc_minor, cudaDevAttrComputeCapabilityMinor, 0) ==
          cudaSuccess;
  if (!detected) {
    LOG(WARNING) << "cannot detect compute capability from your device, "
                 << "fall back to compute_30.";
    return "sm_30";
  }
  return "sm_" + std::to_string(cc_major) + std::to_string(cc_minor);
}
// Reads the whole file `file_name`, opened with `mode`, into a string.
// Fails a CHECK when the file cannot be opened or its size cannot be
// determined.
std::string Compiler::ReadFile(const std::string& file_name,
                               std::ios_base::openmode mode) {
  // open cubin file
  std::ifstream ifs(file_name, mode);
  CHECK(ifs.is_open()) << "Fail to open file " << file_name;

  // Seek to the end to learn the file size. The two-argument overload is
  // required: seekg(std::ios::end) would treat the enum value as an absolute
  // offset instead of seeking to the end of the file.
  ifs.seekg(0, std::ios::end);
  const std::streamoff len = ifs.tellg();
  CHECK(len >= 0) << "Fail to get size of file " << file_name;
  ifs.seekg(0, std::ios::beg);

  // read cubin file
  std::string file_data(static_cast<size_t>(len), ' ');
  ifs.read(&file_data[0], len);
  // Keep only what was actually read.
  file_data.resize(static_cast<size_t>(ifs.gcount()));
  ifs.close();
  // Returned by value; no std::move, so copy elision can apply.
  return file_data;
}
}
// namespace nvrtc
}
// namespace backends
}
// namespace cinn
Prev
1
…
10
11
12
13
14
15
16
17
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment