OpenDAS / Paddle / Commits

Commit 01a10755, authored Mar 04, 2024 by yuguo-Jack
Parent: 63eb0da5

    2.5.2-dtk24.04

Changes: 558 files. Showing 20 of the 558 changed files below, with 833 additions and 1312 deletions (+833, -1312).
paddle/cinn/hlir/framework/CMakeLists.txt            +9    -9
paddle/cinn/hlir/framework/accuracy_checker.cc       +1    -1
paddle/cinn/hlir/framework/accuracy_checker_test.cc  +1    -1
paddle/cinn/hlir/framework/compile_error.cc          +41   -0
paddle/cinn/hlir/framework/compile_error.h           +68   -0
paddle/cinn/hlir/framework/convert_to_dialect.cc     +0    -55
paddle/cinn/hlir/framework/convert_to_dialect.h      +0    -33
paddle/cinn/hlir/framework/graph.cc                  +15   -6
paddle/cinn/hlir/framework/graph.h                   +6    -3
paddle/cinn/hlir/framework/graph_compiler.cc         +99   -57
paddle/cinn/hlir/framework/graph_compiler.h          +25   -38
paddle/cinn/hlir/framework/graph_compiler_test.cc    +74   -8
paddle/cinn/hlir/framework/graph_compiler_util.cc    +289  -0
paddle/cinn/hlir/framework/graph_compiler_util.h     +149  -0
paddle/cinn/hlir/framework/graph_test.cc             +1    -1
paddle/cinn/hlir/framework/instruction.cc            +2    -2
paddle/cinn/hlir/framework/new_ir_compiler.cc        +0    -280
paddle/cinn/hlir/framework/new_ir_compiler.h         +0    -77
paddle/cinn/hlir/framework/op_lowering.cc            +0    -600
paddle/cinn/hlir/framework/op_lowering.h             +53   -141
paddle/cinn/hlir/framework/CMakeLists.txt

+add_subdirectory(pir)
+
 core_gather_headers()

 gather_srcs(
...
@@ -12,22 +13,21 @@ gather_srcs(
   program.cc
   parallel_compiler.cc
   graph_compiler.cc
+  graph_compiler_util.cc
   graph.cc
   node.cc
   pass.cc
   op_strategy.cc
-  op_lowering.cc
   op_lowering_util.cc
+  op_lowering_impl.cc
   accuracy_checker.cc
-  visualize_helper.cc)
+  visualize_helper.cc
+  compile_error.cc)

-# TODO(Aurelius84): new_ir_compiler depends on pd_dialect and could
+# TODO(Aurelius84): pir_compiler depends on op_dialect_vjp and could
 # not found under CINN_ONLY mode
 if(NOT CINN_ONLY)
-  cinn_cc_library(new_ir_compiler SRCS new_ir_compiler.cc DEPS cinnapi
-                  pd_dialect)
-  cinn_cc_library(convert_to_dialect SRCS convert_to_dialect.cc DEPS cinnapi
-                  cinn_dialect)
+  cinn_cc_library(pir_compiler SRCS pir_compiler.cc DEPS cinnapi op_dialect_vjp)
 endif()

 if(WITH_CUDA)
...
@@ -52,5 +52,5 @@ cinn_cc_test(test_hlir_framework_op SRCS op_test.cc DEPS cinncore)
 cinn_cc_test(test_hlir_framework_print_graph_pass SRCS print_graph_pass_test.cc
              DEPS cinncore)
 cinn_cc_test(test_hlir_framework_graph SRCS graph_test.cc DEPS cinncore)
-#cinn_cc_test(test_hlir_framework_graph_compiler SRCS graph_compiler_test.cc
-#             DEPS cinncore)
+cinn_cc_test(test_hlir_framework_graph_compiler SRCS graph_compiler_test.cc
+             DEPS cinncore)
paddle/cinn/hlir/framework/accuracy_checker.cc

@@ -21,7 +21,7 @@
 #include <cuda_runtime.h>
 #endif

-DECLARE_int64(cinn_self_check_accuracy_num);
+PD_DECLARE_int64(cinn_self_check_accuracy_num);

 namespace cinn {
 namespace hlir {
...
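Several files in this commit swap the bare gflags-style DECLARE_* macros for Paddle's PD_DECLARE_* wrappers. As a minimal sketch of the assumed declare/define pairing: the header path, default value, and helper function below are illustrative assumptions based on Paddle's flags facility, not code from this commit.

#include "paddle/utils/flags.h"  // assumed home of the PD_* flag macros

// Exactly one translation unit defines the flag:
PD_DEFINE_int64(cinn_self_check_accuracy_num,
                -1,  // illustrative default, not taken from this commit
                "How many elements to dump when self-checking accuracy.");

// Every other file that reads the flag declares it instead, which is
// what the PD_DECLARE_* lines in this diff do:
// PD_DECLARE_int64(cinn_self_check_accuracy_num);

int64_t SelfCheckNum() { return FLAGS_cinn_self_check_accuracy_num; }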
paddle/cinn/hlir/framework/accuracy_checker_test.cc

@@ -25,7 +25,7 @@
 #include "paddle/cinn/hlir/framework/instruction.h"
 #include "paddle/cinn/hlir/framework/op_strategy.h"

-DECLARE_string(cinn_self_check_accuracy);
+PD_DECLARE_string(cinn_self_check_accuracy);

 namespace cinn {
 namespace hlir {
...
paddle/cinn/hlir/framework/compile_error.cc (new file, mode 100644)

// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/hlir/framework/compile_error.h"

#include "paddle/cinn/utils/enum_string.h"

namespace cinn {
namespace hlir {
namespace framework {

std::string CompileErrorHandler::GeneralErrorMessage() const {
  std::ostringstream os;
  os << "[CompileError] An error occurred during compilation with the error "
        "code: "
     << utils::Enum2String(status_) << std::endl;
  os << "(at " << file_ << " : " << line_ << ")" << std::endl;
  os << indent_str_ << "[Error info] " << this->err_msg_ << std::endl;
  return os.str();
}

std::string CompileErrorHandler::DetailedErrorMessage() const {
  std::ostringstream os;
  os << GeneralErrorMessage();
  os << indent_str_ << "[Detail info] " << detail_info_ << std::endl;
  return os.str();
}

}  // namespace framework
}  // namespace hlir
}  // namespace cinn
paddle/cinn/hlir/framework/compile_error.h (new file, mode 100644)

// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/utils/error.h"

namespace cinn {
namespace hlir {
namespace framework {

/**
 * This handler is used to deal with the errors during the compilation process
 */
class CompileErrorHandler : public utils::ErrorHandler {
 public:
  /**
   * \brief constructor
   * \param err_msg the error message
   */
  explicit CompileErrorHandler(const CompilationStatus& status,
                               const std::string& err_msg,
                               const std::string& detail_info,
                               const char* file,
                               int line)
      : status_(status),
        err_msg_(err_msg),
        detail_info_(detail_info),
        file_(file),
        line_(line) {}

  /**
   * \brief Returns a short error message corresponding to the kGeneral error
   * level.
   */
  std::string GeneralErrorMessage() const;

  /**
   * \brief Returns a detailed error message corresponding to the kDetailed
   * error level.
   */
  std::string DetailedErrorMessage() const;

  CompilationStatus Status() const { return status_; }

 private:
  CompilationStatus status_;
  std::string err_msg_;
  std::string detail_info_;
  const char* file_;
  int line_;
};

}  // namespace framework
}  // namespace hlir
}  // namespace cinn
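For orientation, a minimal sketch of how the new handler might be invoked at a failure site. The status value, message strings, and the std::cerr sink are illustrative assumptions, not code from this commit:

#include <iostream>
#include "paddle/cinn/hlir/framework/compile_error.h"

using cinn::hlir::framework::CompilationStatus;
using cinn::hlir::framework::CompileErrorHandler;

void ReportLoweringFailure() {
  CompileErrorHandler handler(CompilationStatus::LOWERING_FAIL,
                              "lowering failed for fusion group 3",   // short message (illustrative)
                              "shape mismatch between operands",      // detail (illustrative)
                              __FILE__,
                              __LINE__);
  // Short form: error code, file:line, and the [Error info] line.
  std::cerr << handler.GeneralErrorMessage();
  // Long form: the above plus the [Detail info] line.
  std::cerr << handler.DetailedErrorMessage();
}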
paddle/cinn/hlir/framework/convert_to_dialect.cc (deleted file, mode 100644 → 0)

// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/hlir/framework/convert_to_dialect.h"

#include <string>
#include <unordered_map>

#include "paddle/cinn/hlir/dialect/jit_kernel_op.h"
#include "paddle/cinn/hlir/dialect/runtime_dialect.h"
#include "paddle/cinn/hlir/framework/program.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/program.h"

namespace cinn {
namespace hlir {
namespace framework {

std::unique_ptr<::ir::Program> ConvertToRuntimeDialect(
    const hlir::framework::Program& program) {
  ::ir::IrContext* ctx = ::ir::IrContext::Instance();
  ctx->GetOrRegisterDialect<cinn::dialect::RuntimeDialect>();
  auto ir_program = std::make_unique<::ir::Program>(ctx);

  std::string jit_op_name = dialect::JitKernelOp::name();
  ::ir::OpInfo op_info = ctx->GetRegisteredOpInfo(jit_op_name);

  auto& instrs = program.GetRunInstructions();
  for (auto& instr : instrs) {
    std::unordered_map<std::string, ::ir::Attribute> op_attrs{
        {dialect::JitKernelOp::kAttrName,
         ::ir::PointerAttribute::get(ctx, instr.get())},
    };

    ::ir::Operation* cinn_op =
        ::ir::Operation::Create({}, op_attrs, {}, op_info);
    ir_program->block()->push_back(cinn_op);
  }
  return std::move(ir_program);
}

}  // namespace framework
}  // namespace hlir
}  // namespace cinn
paddle/cinn/hlir/framework/convert_to_dialect.h (deleted file, mode 100644 → 0)

// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>

namespace ir {
class Program;
}  // namespace ir

namespace cinn {
namespace hlir {
namespace framework {
class Program;

std::unique_ptr<::ir::Program> ConvertToRuntimeDialect(
    const hlir::framework::Program& program);

}  // namespace framework
}  // namespace hlir
}  // namespace cinn
paddle/cinn/hlir/framework/graph.cc

@@ -18,10 +18,14 @@
 #include <sstream>

 #include "paddle/cinn/hlir/framework/visualize_helper.h"
+#ifdef CINN_WITH_CUDA
+#include "paddle/cinn/runtime/cuda/cuda_util.h"
+#endif
+#include "paddle/cinn/adt/m_expr.h"
 #include "paddle/cinn/runtime/flags.h"
 #include "paddle/cinn/utils/string.h"

-DECLARE_string(cinn_fusion_groups_graphviz_dir);
+PD_DECLARE_string(cinn_fusion_groups_graphviz_dir);

 namespace cinn {
 namespace hlir {
...
@@ -309,15 +313,20 @@ void Graph::VisualizeGroupedGraph(
   }

   // Dump debug info for each group
-  LOG(INFO) << "Dump graph debug info to: "
-            << FLAGS_cinn_fusion_groups_graphviz_dir;
+  VLOG(4) << "Dump graph debug info to: "
+          << FLAGS_cinn_fusion_groups_graphviz_dir;
   const auto& groups = RemoveAccCheckGroups(origin_groups);
   const auto& group_dots = VisualizeGroups(groups, fetch_var_ids);
   for (int idx = 0; idx < groups.size(); ++idx) {
     // Create fusion_group_x folder
+    int device_id = 0;
+#ifdef CINN_WITH_CUDA
+    cudaGetDevice(&device_id);
+#endif
     auto group_path =
-        utils::StringFormat("%s/fusion_group_%d",
-                            FLAGS_cinn_fusion_groups_graphviz_dir.c_str(),
-                            idx);
+        utils::StringFormat("%s/device_%d/fusion_group_%d",
+                            FLAGS_cinn_fusion_groups_graphviz_dir.c_str(),
+                            device_id,
+                            idx);
     if (!MakeDirectory(group_path,
                        S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH)) {
...
@@ -468,7 +477,7 @@ std::vector<std::string> Graph::VisualizeGroups(
   return dot_vec;
 }

-std::unordered_set<NodeData*> Graph::Group::GetInputNodeDatas() {
+std::unordered_set<NodeData*> Graph::Group::GetInputNodeDatas() const {
   std::unordered_set<NodeData*> group_inputs;

   // count all node's input data
...
@@ -498,7 +507,7 @@ std::unordered_set<NodeData*> Graph::Group::GetInputNodeDatas() {
   return group_inputs;
 }

-std::unordered_set<NodeData*> Graph::Group::GetOutputNodeDatas() {
+std::unordered_set<NodeData*> Graph::Group::GetOutputNodeDatas() const {
   std::unordered_set<NodeData*> group_outputs;
   for (auto node : this->output_nodes) {
...
paddle/cinn/hlir/framework/graph.h

@@ -26,6 +26,7 @@
 #include "paddle/cinn/hlir/framework/node.h"

 namespace cinn {
 namespace hlir {
 namespace framework {
...
@@ -59,6 +60,8 @@ class Graph : public cinn::common::Graph {
   std::vector<std::vector<Node*>> groups;

   struct Group {
     Group() = default;
+    Group(const Group&) = delete;
+    Group(Group&&) = delete;
     explicit Group(const Graph* graph) : graph_(graph) {}
...
@@ -109,7 +112,7 @@ class Graph : public cinn::common::Graph {
       }
     };

-    std::vector<Node*> CollectNodes() {
+    std::vector<Node*> CollectNodes() const {
       if (fused_sub_groups.size()) {
         std::vector<Node*> tmp_nodes;
         for (auto& group : fused_sub_groups) {
...
@@ -144,8 +147,8 @@ class Graph : public cinn::common::Graph {
       return node_set;
     }

-    std::unordered_set<NodeData*> GetInputNodeDatas();
-    std::unordered_set<NodeData*> GetOutputNodeDatas();
+    std::unordered_set<NodeData*> GetInputNodeDatas() const;
+    std::unordered_set<NodeData*> GetOutputNodeDatas() const;

     std::string GetFuncName() { return "fn_" + group_id + unique_id; }
...
paddle/cinn/hlir/framework/graph_compiler.cc

@@ -29,8 +29,11 @@
 #include "paddle/cinn/lang/lower.h"
 #include "paddle/cinn/optim/transform_gpu_forloop.h"
 #include "paddle/cinn/poly/stage.h"
+#include "paddle/cinn/utils/enum_string.h"
 #include "paddle/cinn/utils/profiler.h"
+#include "paddle/cinn/ast_gen_ius/tensor_group.h"

 namespace cinn {
 namespace hlir {
 namespace framework {
...
@@ -40,90 +43,124 @@ using cinn::common::float16;
 std::unique_ptr<Program> GraphCompiler::Build(const std::string& code) {
   utils::RecordEvent("GraphCompiler::Build", utils::EventType::kGraph);
-  GraphCompiler::CompileOptions options;
-  options.attached_code = code;
-  options.with_instantiate_variables = true;
-
-  auto&& result = Build(options);
-  return std::move(result.runtime_program);
-}
-
-void GraphCompiler::CompileOptions::Apply(
-    const auto_schedule::TuningResult& tuning_result) {
-  // assign options with TuningResult directly
-  groups.assign(tuning_result.subgraphs.begin(),
-                tuning_result.subgraphs.end());
-  lowered_funcs.assign(tuning_result.function_groups.begin(),
-                       tuning_result.function_groups.end());
-}
-
-GraphCompiler::CompilationResult GraphCompiler::Build(
-    const GraphCompiler::CompileOptions& options,
-    std::unordered_set<std::string>&& fetch_var_ids,
-    void* stream) {
+  compilation_context_.ApplySourceCode(code);
+  compilation_context_.with_instantiate_variables = true;
+  auto&& result = Build(&compilation_context_);
+  return result.RuntimeProgram();
+}
+
+CompilationResult GraphCompiler::Build(CompilationContext* context) {
   Context::Global().ResetNameId();

   // write group's information into FLAGS_cinn_fusion_groups_graphviz_dir
-  graph_->VisualizeGroupedGraph(fetch_var_ids.empty() ? fetch_var_ids_
-                                                      : fetch_var_ids);
+  context->graph->VisualizeGroupedGraph(context->fetch_var_ids);

-  if (options.with_instantiate_variables) {
-    InstantiateVariables();
+  if (context->with_instantiate_variables) {
+    InstantiateVariables(context);
   }

   VLOG(2) << "Compile With Parallel Compiler!";
   utils::RecordEvent("GraphCompiler CompileResult",
                      utils::EventType::kOrdinary);

-  ParallelCompiler::CompileOptions option;
-  option.lowered_funcs = options.lowered_funcs;
-
-  parallel_compiler_ =
-      std::make_shared<ParallelCompiler>(scope_, graph_, option, target_);
-  auto result = (*parallel_compiler_.get())();
+  parallel_compiler_ = std::make_shared<ParallelCompiler>(context);
+  CompilationResult result = (*parallel_compiler_.get())();
+
+  // Dump compilation result
+  backends::CompilationInfoDumper dumper(result);
+  if (context->stage != CompilationStage::DEFAULT || !result.IsSuccess()) {
+    return result;
+  }

-  if (options.remove_unused_variables) {
-    RemoveInvalidVariables(result.instructions);
+  if (context->remove_unused_variables) {
+    RemoveInvalidVariables(context, result.RuntimeInstructions());
   }

-  if (options.with_buffer_handle_instruction_inserted) {
+  if (context->with_buffer_handle_instruction_inserted) {
     VLOG(3) << "option.with_buffer_handle_instruction_inserted enable";
-    InsertBufferHandlers(&result.instructions);
+    InsertBufferHandlers(context, &result.instructions_);
   }
   VLOG(2) << "Compile With Parallel Compiler Done!";

-  GraphCompiler::CompilationResult compilation_result;
-  compilation_result.runtime_program.reset(
-      new Program(scope_, std::move(result.instructions)));
-  return compilation_result;
+  result.SetRuntimeProgram(std::make_unique<Program>(
+      context->scope, std::move(result.instructions_)));
+  return result;
 }

+CompilationResult GraphCompiler::Lowering() {
+  return Lowering(&compilation_context_);
+}
+
+CompilationResult GraphCompiler::Lowering(CompilationContext* context) {
+  // Global setting
+  Context::Global().ResetNameId();
+  // Setting compile options
+  VLOG(2) << "Compile With Parallel Compiler! But just lowering!";
+  context->stage = CompilationStage::LOWERING;
+  // Compile with parallel compiler
+  parallel_compiler_ = std::make_shared<ParallelCompiler>(context);
+  CompilationResult result = (*parallel_compiler_.get())();
+  return result;
+}
+
+CompilationResult GraphCompiler::CodegenAndJit() {
+  return CodegenAndJit(&compilation_context_);
+}
+
+CompilationResult GraphCompiler::CodegenAndJit(CompilationContext* context) {
+  // Global setting
+  Context::Global().ResetNameId();
+  // Setting compile options
+  VLOG(2) << "Compile With Parallel Compiler! But just codegen and jit!";
+  context->stage = CompilationStage::CODEGEN_AND_JIT;
+  // Compile with parallel compiler
+  parallel_compiler_ = std::make_shared<ParallelCompiler>(context);
+  CompilationResult result = (*parallel_compiler_.get())();
+  return result;
+}
+
+CompilationResult GraphCompiler::BuildInstruction() {
+  return BuildInstruction(&compilation_context_);
+}
+
+CompilationResult GraphCompiler::BuildInstruction(CompilationContext* context) {
+  // Global setting
+  Context::Global().ResetNameId();
+  // Setting compile options
+  VLOG(2) << "Compile With Parallel Compiler! But just build instruction!";
+  context->stage = CompilationStage::BUILD_INSTRUCTION;
+  // Compile with parallel compiler
+  parallel_compiler_ = std::make_shared<ParallelCompiler>(context);
+  CompilationResult result = (*parallel_compiler_.get())();
+  return result;
+}
+
-void GraphCompiler::InstantiateVariables() {
+void GraphCompiler::InstantiateVariables(CompilationContext* context) {
   VLOG(3) << "Instantiate all variables on compile-time";
   utils::RecordEvent("GraphCompiler MutableData", utils::EventType::kOrdinary);
   // All variables reside in scope_, so traverse it to instantiate each one
-  for (auto& name : scope_->var_names()) {
-    auto* var = scope_->Var<Tensor>(std::string({name.data(), name.size()}));
+  for (auto& name : context->scope->var_names()) {
+    auto* var =
+        context->scope->Var<Tensor>(std::string({name.data(), name.size()}));
     auto& tensor = absl::get<Tensor>(*var);
-    if (reuse_vars_map_.count(name)) {
-      auto src_var_name = reuse_vars_map_.at(name);
-      auto* src_var = scope_->Var<Tensor>(src_var_name);
+    if (context->reuse_vars_map.count(name)) {
+      auto src_var_name = context->reuse_vars_map.at(name);
+      auto* src_var = context->scope->Var<Tensor>(src_var_name);
       auto& src_tensor = absl::get<Tensor>(*src_var);
       tensor->set_buffer(src_tensor->get_buffer());
     } else {
-      tensor->mutable_data(target_, tensor->type());
+      tensor->mutable_data(context->target, tensor->type());
     }
   }
 }

 void GraphCompiler::RemoveInvalidVariables(
+    CompilationContext* context,
     const std::vector<std::unique_ptr<Instruction>>& instructions) {
   // mark all variables are invalid initially
   utils::RecordEvent("GraphCompiler RemoveInvalidVariables",
                      utils::EventType::kOrdinary);
   std::unordered_set<std::string> invalid_variables;
-  auto var_names = scope_->var_names();
+  auto var_names = context->scope->var_names();
   invalid_variables.reserve(var_names.size());
   std::transform(var_names.begin(),
...
@@ -162,8 +199,8 @@ void GraphCompiler::RemoveInvalidVariables(
           << " invalid variables to be removed from scope";
   std::for_each(invalid_variables.begin(),
                 invalid_variables.end(),
-                [this](const std::string& var_name) {
-                  scope_->EraseVar(var_name);
+                [context](const std::string& var_name) {
+                  context->scope->EraseVar(var_name);
                   VLOG(3) << "Variable(" << var_name << ") is erased";
                 });
 }
...
@@ -222,6 +259,7 @@ void GraphCompiler::AnalyzeVariableLifeTime(
 }

 void GraphCompiler::InsertBufferHandlers(
+    CompilationContext* context,
     std::vector<std::unique_ptr<Instruction>>* instructions) {
   utils::RecordEvent("GraphCompiler InsertBufferHandlers",
                      utils::EventType::kOrdinary);
...
@@ -240,7 +278,7 @@ void GraphCompiler::InsertBufferHandlers(
     auto function_name = "malloc_buffer_instruction_" + std::to_string(step);
     auto malloc_instr =
         std::make_unique<Instruction>(common::DefaultHostTarget(),
-                                      scope_.get(),
+                                      context->scope.get(),
                                       malloc_var_names,
                                       std::vector<std::string>({}),
                                       function_name);
...
@@ -263,7 +301,7 @@ void GraphCompiler::InsertBufferHandlers(
     auto function_name = "free_buffer_instruction_" + std::to_string(step);
     auto free_instr =
         std::make_unique<Instruction>(common::DefaultHostTarget(),
-                                      scope_.get(),
+                                      context->scope.get(),
                                       std::vector<std::string>({}),
                                       free_var_names,
                                       function_name);
...
@@ -336,14 +374,17 @@ std::vector<ir::LoweredFunc> GetFuncFromImpl(
   poly::StageMap stages = C.back();
   std::string func_name_prefix = "fn_";
-  auto funcs = lang::LowerVec(func_name_prefix + node_id,
-                              stages,
-                              all_arg_tensors,
-                              {},
-                              {},
-                              nullptr,
-                              target,
-                              true);
+  ast_gen_ius::TensorGroup tensor_group =
+      ast_gen_ius::ConvertStageMapToTensorGroup(stages);
+  auto funcs = lang::LowerToAstVec(
+      func_name_prefix + node_id, all_arg_tensors, &tensor_group, target);

   VLOG(4) << "Lower op: " << node_id << ", get " << funcs.size()
           << " LoweredFunc:\n";
   for (auto fun : funcs) {
     VLOG(4) << fun;
   }

   std::vector<common::CINNValue> schedule_inputs;
   for (int i = 0; i < C.size() - 1; ++i) {
...
@@ -390,7 +431,8 @@ std::vector<ir::LoweredFunc> GetFuncFromImpl(
     optim::OptimizeExprGPU(&(funcs_after_schedule[i]->body));
 #endif
     auto temp_buffers = lang::GetTempBuffers(
-        all_arg_tensors, stages, funcs_after_schedule[i]->body);
+        all_arg_tensors, tensor_group, funcs_after_schedule[i]->body);
     funcs_after_schedule[i]->temp_bufs = temp_buffers;
     funcs_after_schedule[i] =
         ir::_LoweredFunc_::Make(funcs_after_schedule[i]->name,
...
paddle/cinn/hlir/framework/graph_compiler.h

@@ -28,6 +28,7 @@
 #include "paddle/cinn/backends/cuda_util.h"
 #include "paddle/cinn/common/macros.h"
 #include "paddle/cinn/hlir/framework/graph.h"
+#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
 #include "paddle/cinn/hlir/framework/instruction.h"
 #include "paddle/cinn/hlir/framework/op_strategy.h"
 #include "paddle/cinn/hlir/framework/parallel_compiler.h"
...
@@ -46,48 +47,41 @@ namespace framework {
  */
 class GraphCompiler final {
  public:
-  GraphCompiler(Target target,
-                const std::shared_ptr<Scope>& scope,
-                const std::shared_ptr<Graph>& graph)
-      : target_(std::move(target)), scope_(scope), graph_(graph) {}
-
-  struct CompilationResult {
-    std::unique_ptr<Program> runtime_program;
-  };
-
-  struct CompileOptions {
-    std::string attached_code = "";
-    bool with_instantiate_variables = false;
-    bool with_buffer_handle_instruction_inserted = false;
-    bool remove_unused_variables = true;
-    // nodes group, it may come from the result of op fusion or graph tuning.
-    // nodes in a group will be built into an Instruction
-    std::vector<std::shared_ptr<Graph::Group>> groups;
-    // corresponding LoweredFuncs of above grouped nodes,
-    // if it is empty then graph_compiler will generate for them
-    std::vector<std::vector<ir::LoweredFunc>> lowered_funcs;
-
-    // apply results of auto-tune to compile
-    void Apply(const auto_schedule::TuningResult& tuning_result);
-  };
+  GraphCompiler(CompilationContext context) : compilation_context_(context) {}

   // Compile with a packing option and result, to be extended easily.
-  CompilationResult Build(const CompileOptions& options,
-                          std::unordered_set<std::string>&& fetch_var_ids = {},
-                          void* stream = nullptr);
+  CompilationResult Build(CompilationContext* context);
   std::unique_ptr<Program> Build(const std::string& code = "");

-  const std::shared_ptr<Scope>& GetScope() const { return scope_; }
+  CompilationResult Lowering();
+  CompilationResult Lowering(CompilationContext* context);
+
+  CompilationResult CodegenAndJit();
+  CompilationResult CodegenAndJit(CompilationContext* context);
+
+  CompilationResult BuildInstruction();
+  CompilationResult BuildInstruction(CompilationContext* context);
+
+  const std::shared_ptr<Scope>& GetScope() const {
+    return compilation_context_.scope;
+  }
+
+  CompilationContext& GetCompilationContext() { return compilation_context_; }
+
+  void SetCompilationContext(const CompilationContext& context) {
+    compilation_context_ = context;
+  }

  private:
   // instantiate all variables on compile time
-  void InstantiateVariables();
+  void InstantiateVariables(CompilationContext* context);

   // some variables are eliminated by optimized passes(such as OpFusion),
   // we can filter out them according to arguments of the built instructions,
   // and erase them from the scope to avoid unnecessary buffer allocation
   void RemoveInvalidVariables(
+      CompilationContext* context,
       const std::vector<std::unique_ptr<Instruction>>& instructions);

   // find the first and last instruction where a variable used, and mark the
...
@@ -102,21 +96,14 @@ class GraphCompiler final {
   // firstly used in the next instruction, and insert a buffer free instruction
   // applying on variables after no instruction will use them anymore
   void InsertBufferHandlers(
+      CompilationContext* context,
       std::vector<std::unique_ptr<Instruction>>* instructions);

  private:
   // parallel compiler
   std::shared_ptr<ParallelCompiler> parallel_compiler_;

-  Target target_;
-  std::shared_ptr<Graph> graph_;
-  std::shared_ptr<Scope> scope_;
-  // fetch var ids in cinn and the corresponding var nodes will not be fused so
-  // as to get the result
-  std::unordered_set<std::string> fetch_var_ids_;
-
-  // map dst reuse var to the src var sharing buffer
-  absl::flat_hash_map<std::string, std::string> reuse_vars_map_;
+  CompilationContext compilation_context_;

   CINN_DISALLOW_COPY_AND_ASSIGN(GraphCompiler);
 };
...
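Taken together, the header change replaces the (target, scope, graph) constructor and the nested CompileOptions/CompilationResult structs with a single CompilationContext plus the standalone CompilationResult from graph_compiler_util.h. A minimal before/after sketch of a call site, with variable names as illustrative placeholders:

// Old API (removed in this commit):
//   GraphCompiler gc(target, scope, graph);
//   GraphCompiler::CompileOptions options;
//   options.with_instantiate_variables = true;
//   auto result = gc.Build(options);
//   auto program = std::move(result.runtime_program);

// New API:
CompilationContext context(graph, scope, target);
context.with_instantiate_variables = true;
GraphCompiler gc(context);
CompilationResult result = gc.Build(&context);
std::unique_ptr<Program> program = result.RuntimeProgram();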
paddle/cinn/hlir/framework/graph_compiler_test.cc

@@ -19,6 +19,7 @@
 #include "paddle/cinn/frontend/net_builder.h"
 #include "paddle/cinn/frontend/optimize.h"
 #include "paddle/cinn/frontend/program_pass.h"
+#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
 #include "paddle/cinn/hlir/framework/pass.h"
 #include "paddle/cinn/hlir/framework/scope.h"
 #include "paddle/cinn/hlir/op/use_ops.h"
...
@@ -48,7 +49,8 @@ TEST(GraphCompilerTest, TestRemoveInvaildVariables) {
   ASSERT_EQ(scope->var_names().size(), 6);
   EXPECT_NE(scope->FindVar(c->id), nullptr);

-  GraphCompiler gc(target, scope, graph);
+  CompilationContext context(graph, scope, target);
+  GraphCompiler gc(context);
   auto runtime_program = gc.Build();
   ASSERT_EQ(scope->var_names().size(), 3);
   EXPECT_EQ(scope->FindVar(c->id), nullptr);
...
@@ -69,10 +71,11 @@ TEST(GraphCompilerTest, TestInsertBufferHandlers) {
   auto graph = Optimize(&program, {}, target);
   auto scope = BuildScope(target, graph);

-  GraphCompiler gc_disable(target, scope, graph);
-  GraphCompiler::CompileOptions options;
+  CompilationContext context_disable(graph, scope, target);
+  GraphCompiler gc_disable(context_disable);
   // disable with_buffer_handle_instruction_inserted: only 1 instruction
-  auto runtime_program_disable = gc_disable.Build(options).runtime_program;
+  auto runtime_program_disable =
+      gc_disable.Build(&context_disable).RuntimeProgram();
   ASSERT_EQ(runtime_program_disable->size(), 1);
   const auto& computation_instr_disable =
       runtime_program_disable->GetRunInstructions().front();
...
@@ -80,9 +83,11 @@ TEST(GraphCompilerTest, TestInsertBufferHandlers) {
   // enable with_buffer_handle_instruction_inserted: 3 instructions, 1st ->
   // malloc instruction(a, b, d), 2nd -> the real computation
   // instruction(add + relu) and 3rd -> free instruction
-  GraphCompiler gc_enable(target, scope, graph);
-  options.with_buffer_handle_instruction_inserted = true;
-  auto runtime_program_enable = gc_enable.Build(options).runtime_program;
+  CompilationContext context_enable(graph, scope, target);
+  context_enable.with_buffer_handle_instruction_inserted = true;
+  GraphCompiler gc_enable(context_enable);
+  auto runtime_program_enable =
+      gc_enable.Build(&context_enable).RuntimeProgram();
   const auto& instructions = runtime_program_enable->GetRunInstructions();
   ASSERT_EQ(instructions.size(), 3);
...
@@ -193,7 +198,8 @@ void RunCublas(
   hlir::framework::ApplyPass(graph.get(), "OpFusionPass");
   auto scope = BuildScope(target, graph);

-  GraphCompiler gc(target, scope, graph);
+  CompilationContext context(graph, scope, target);
+  GraphCompiler gc(context);
   auto exe_program = gc.Build();

   auto data_a = scope->GetTensor("A");
...
@@ -231,6 +237,66 @@ TEST(GraphCompilerTest, TestCublas) {
   RunCublas(64, 128, 128, true, true);
 }

+TEST(GraphCompilerTest, TestLowering) {
+  frontend::NetBuilder builder("test_lowering_on_graph_compiler");
+  auto a = builder.CreateInput(Float(32), {1, 64, 112, 112}, "A");
+  auto b = builder.CreateInput(Float(32), {64}, "B");
+  auto c = builder.Add(a, b, 1);
+  auto d = builder.Relu(c);
+
+  auto target = common::DefaultNVGPUTarget();
+  auto program = builder.Build();
+  auto graph = Optimize(&program, {}, target);
+  auto scope = BuildScope(target, graph);
+
+  CompilationContext context(graph, scope, target);
+  GraphCompiler gc(context);
+  CompilationResult result = gc.Lowering();
+  ASSERT_EQ(result.Status(), CompilationStatus::SUCCESS);
+}
+
+TEST(GraphCompilerTest, TestCodegenAndJit) {
+  frontend::NetBuilder builder("test_codegen_and_jit_on_graph_compiler");
+  auto a = builder.CreateInput(Float(32), {1, 64, 112, 112}, "A");
+  auto b = builder.CreateInput(Float(32), {64}, "B");
+  auto c = builder.Add(a, b, 1);
+  auto d = builder.Relu(c);
+
+  auto target = common::DefaultNVGPUTarget();
+  auto program = builder.Build();
+  auto graph = Optimize(&program, {}, target);
+  auto scope = BuildScope(target, graph);
+
+  CompilationContext context(graph, scope, target);
+  GraphCompiler gc(context);
+  CompilationResult result = gc.CodegenAndJit();
+  ASSERT_EQ(result.Status(), CompilationStatus::SUCCESS);
+}
+
+TEST(GraphCompilerTest, TestBuildInstruction) {
+  frontend::NetBuilder builder("test_build_instruction_on_graph_compiler");
+  auto a = builder.CreateInput(Float(32), {1, 64, 112, 112}, "A");
+  auto b = builder.CreateInput(Float(32), {64}, "B");
+  auto c = builder.Add(a, b, 1);
+  auto d = builder.Relu(c);
+
+  auto target = common::DefaultNVGPUTarget();
+  auto program = builder.Build();
+  auto graph = Optimize(&program, {}, target);
+  auto scope = BuildScope(target, graph);
+
+  CompilationContext context(graph, scope, target);
+  GraphCompiler gc(context);
+  CompilationResult result = gc.BuildInstruction();
+  ASSERT_EQ(result.Status(), CompilationStatus::SUCCESS);
+}
 #endif
 }  // namespace framework
...
paddle/cinn/hlir/framework/graph_compiler_util.cc (new file, mode 100644)

// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/hlir/framework/graph_compiler_util.h"

#include "paddle/cinn/utils/error.h"

namespace cinn {
namespace hlir {
namespace framework {

void CompilationContext::ApplyTuningResult(
    const auto_schedule::TuningResult& tuning_result) {
  // assign options with TuningResult directly
  groups.assign(tuning_result.subgraphs.begin(),
                tuning_result.subgraphs.end());
  lowered_funcs.assign(tuning_result.function_groups.begin(),
                       tuning_result.function_groups.end());
}

void CompilationContext::ApplySourceCode(const std::string& code) {
  attached_source_code = code;
}

void CompilationResult::InitCompilationResult(int group_size) {
  size_ = group_size;
  status_.resize(group_size, CompilationStatus::SUCCESS);
  messages_.resize(group_size);
  for (int idx = 0; idx < group_size; ++idx) {
    messages_[idx] =
        "Group Idx: " + std::to_string(idx) + ", Compile Success.\n";
  }
  lowered_funcs_.resize(group_size, std::nullopt);
  source_codes_.resize(group_size, std::nullopt);
  source_ptxs_.resize(group_size, std::nullopt);
  instructions_.resize(group_size);
}

void CompilationResult::SetStatus(int idx, const CompilationStatus& status) {
  if (idx < status_.size()) {
    status_[idx] = status;
  }
}

void CompilationResult::SetMessage(int idx, const std::string& message) {
  if (idx < messages_.size()) {
    messages_[idx] = message;
  }
}

void CompilationResult::SetLoweredFuncs(
    int idx, const std::vector<ir::LoweredFunc>& funcs) {
  if (idx < lowered_funcs_.size()) {
    lowered_funcs_[idx] = funcs;
  }
}

void CompilationResult::SetSourceCode(int idx,
                                      const std::string& source_code) {
  if (idx < source_codes_.size()) {
    source_codes_[idx] = source_code;
  }
}

void CompilationResult::SetSourcePtx(int idx, const std::string& source_ptx) {
  if (idx < source_ptxs_.size()) {
    source_ptxs_[idx] = source_ptx;
  }
}

void CompilationResult::SetInstruction(
    int idx, std::unique_ptr<Instruction> instruction) {
  if (idx < instructions_.size()) {
    instructions_[idx] = std::move(instruction);
  }
}

void CompilationResult::SetRuntimeProgram(
    std::unique_ptr<Program> runtime_program) {
  runtime_program_ = std::move(runtime_program);
}

bool CompilationResult::IsSuccess() const {
  for (const CompilationStatus& s : status_) {
    if (s != CompilationStatus::SUCCESS) {
      return false;
    }
  }
  return true;
}

CompilationStatus CompilationResult::Status() const {
  CompilationStatus worst_status = CompilationStatus::SUCCESS;
  for (const CompilationStatus& s : status_) {
    if (s < worst_status) {
      worst_status = s;
    }
  }
  return worst_status;
}

CompilationStatus CompilationResult::Status(int idx) const {
  if (idx >= status_.size()) {
    return CompilationStatus::UNKNOWN_FAIL;
  }
  return status_[idx];
}

std::string CompilationResult::Message() const {
  std::string res;
  for (int idx = 0; idx < messages_.size(); ++idx) {
    res += messages_[idx];
  }
  return res;
}

std::string CompilationResult::Message(int idx) const {
  if (idx >= messages_.size()) {
    std::stringstream ss;
    ss << "The index(" << idx
       << ") is expected to be less than the size of group("
       << lowered_funcs_.size() << ").";
    CINN_THROW(ss.str());
  }
  return messages_[idx];
}

std::vector<std::vector<ir::LoweredFunc>> CompilationResult::LoweredFuncs()
    const {
  std::vector<std::vector<ir::LoweredFunc>> res(lowered_funcs_.size());
  for (int idx = 0; idx < lowered_funcs_.size(); ++idx) {
    if (lowered_funcs_[idx].has_value()) {
      res[idx] = lowered_funcs_[idx].value();
    } else {
      std::stringstream ss;
      ss << "LoweredFuncs of group[" << idx << "] is not generated.\n"
         << "Some errors may have occurred during or before the lower "
            "process.\n"
         << Message();
      CINN_THROW(ss.str());
    }
  }
  return res;
}

std::vector<ir::LoweredFunc> CompilationResult::LoweredFuncs(int idx) const {
  if (idx >= lowered_funcs_.size()) {
    std::stringstream ss;
    ss << "The index(" << idx
       << ") is expected to be less than the size of group("
       << lowered_funcs_.size() << ").";
    CINN_THROW(ss.str());
  }
  if (!lowered_funcs_[idx].has_value()) {
    std::stringstream ss;
    ss << "LoweredFuncs of group[" << idx << "] is not generated.\n"
       << "Some errors may have occurred during or before the lower process.\n"
       << Message();
    CINN_THROW(ss.str());
  }
  return lowered_funcs_[idx].value();
}

std::vector<std::string> CompilationResult::SourceCodes() const {
  std::vector<std::string> res(source_codes_.size());
  for (int idx = 0; idx < source_codes_.size(); ++idx) {
    if (source_codes_[idx].has_value()) {
      res[idx] = source_codes_[idx].value();
    } else {
      std::stringstream ss;
      ss << "Source Code of group[" << idx << "] is not generated.\n"
         << "Some errors may have occurred during or before the codegen "
            "process.\n"
         << Message();
      CINN_THROW(ss.str());
    }
  }
  return res;
}

std::string CompilationResult::SourceCode(int idx) const {
  if (idx >= source_codes_.size()) {
    std::stringstream ss;
    ss << "The index(" << idx
       << ") is expected to be less than the size of group("
       << lowered_funcs_.size() << ").";
    CINN_THROW(ss.str());
  }
  if (!source_codes_[idx].has_value()) {
    std::stringstream ss;
    ss << "Source Code of group[" << idx << "] is not generated.\n"
       << "Some errors may have occurred during or before the codegen "
          "process.\n"
       << Message();
    CINN_THROW(ss.str());
  }
  return source_codes_[idx].value();
}

std::vector<std::string> CompilationResult::SourcePtxs() const {
  std::vector<std::string> res(source_ptxs_.size());
  for (int idx = 0; idx < source_ptxs_.size(); ++idx) {
    if (source_ptxs_[idx].has_value()) {
      res[idx] = source_ptxs_[idx].value();
    } else {
      std::stringstream ss;
      ss << "Source PTX of group[" << idx << "] is not generated.\n"
         << "Some errors may have occurred during or before the nvrtc compile "
            "process.\n"
         << Message();
      CINN_THROW(ss.str());
    }
  }
  return res;
}

std::string CompilationResult::SourcePtx(int idx) const {
  if (idx >= source_ptxs_.size()) {
    std::stringstream ss;
    ss << "The index(" << idx
       << ") is expected to be less than the size of group("
       << lowered_funcs_.size() << ").";
    CINN_THROW(ss.str());
  }
  if (!source_ptxs_[idx].has_value()) {
    std::stringstream ss;
    ss << "Source PTX of group[" << idx << "] is not generated.\n"
       << "Some errors may have occurred during or before the nvrtc compile "
          "process.\n"
       << Message();
    CINN_THROW(ss.str());
  }
  return source_ptxs_[idx].value();
}

const std::vector<std::unique_ptr<Instruction>>&
CompilationResult::RuntimeInstructions() const {
  if (runtime_program_ != nullptr) {
    return runtime_program_->GetRunInstructions();
  }
  for (int idx = 0; idx < instructions_.size(); ++idx) {
    if (instructions_[idx] == nullptr) {
      std::stringstream ss;
      ss << "Instruction of group[" << idx << "] is not generated.\n"
         << "Some errors may have occurred during or before the build "
            "instruction process.\n"
         << Message();
      CINN_THROW(ss.str());
    }
  }
  return instructions_;
}

const std::unique_ptr<Instruction>& CompilationResult::RuntimeInstruction(
    int idx) const {
  const std::vector<std::unique_ptr<Instruction>>& insts =
      runtime_program_ ? runtime_program_->GetRunInstructions()
                       : instructions_;
  if (idx >= insts.size()) {
    std::stringstream ss;
    ss << "The index(" << idx
       << ") is expected to be less than the size of group(" << insts.size()
       << ").";
    CINN_THROW(ss.str());
  }
  return insts[idx];
}

std::unique_ptr<Program> CompilationResult::RuntimeProgram() {
  if (runtime_program_ == nullptr) {
    std::stringstream ss;
    ss << "Runtime program is not generated.\n"
       << "Some errors may have occurred during the compilation process.\n"
       << Message();
    CINN_THROW(ss.str());
  }
  return std::move(runtime_program_);
}

}  // namespace framework
}  // namespace hlir
}  // namespace cinn
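Because each getter above raises via CINN_THROW when its artifact is missing, a caller that wants diagnostics rather than an exception can check the per-group status first. A minimal caller-side sketch, assuming CINN_THROW surfaces as a C++ exception and using std::cerr as an illustrative sink:

#include <iostream>

void InspectResult(cinn::hlir::framework::CompilationResult* result) {
  using cinn::hlir::framework::CompilationStatus;
  if (result->IsSuccess()) {
    // Safe: every group compiled, so the runtime program exists.
    auto program = result->RuntimeProgram();
    return;
  }
  // On failure, read per-group status/message instead of the throwing
  // getters (LoweredFuncs/SourceCode/SourcePtx would CINN_THROW here).
  for (int i = 0; i < result->Size(); ++i) {
    if (result->Status(i) != CompilationStatus::SUCCESS) {
      std::cerr << result->Message(i);
    }
  }
}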
paddle/cinn/hlir/framework/graph_compiler_util.h (new file, mode 100644)

// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/cinn/auto_schedule/tuning.h"
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/program.h"
#include "paddle/cinn/ir/lowered_func.h"

namespace cinn {
namespace hlir {
namespace framework {

// An enum class used to control the compilation stage.
enum class CompilationStage {
  // Fully compiled by default, the following compilation result can be
  // obtained: lowered_function, source_code, source_ptx, instruction and
  // runtime_program.
  DEFAULT = 0,
  // Just do lowering, we can only get lowered_function from compilation
  // result.
  LOWERING = 1,
  // Stop after codegen and jit, we can get: lowered_function, source_code and
  // source_ptx from compilation result.
  CODEGEN_AND_JIT = 2,
  // Stop after build instruction, we can get: lowered_function, source_code,
  // source_ptx and runtime_program from compilation result.
  BUILD_INSTRUCTION = 3,
};

// An enum class used to represent the compilation status.
enum class CompilationStatus {
  // An unknown error occurred during compilation.
  UNKNOWN_FAIL = 0,
  // An error occurred during lowering.
  LOWERING_FAIL = 1,
  // An error occurred during codegen and jit.
  CODEGEN_JIT_FAIL = 2,
  // An error occurred during build instruction.
  INSTUCTION_FAIL = 3,
  // An error occurred during build runtime program.
  PROGRAM_FAIL = 4,
  // Compile successfully.
  SUCCESS = 5,
};

struct CompilationContext {
  CompilationContext() = default;
  CompilationContext(const std::shared_ptr<Graph>& graph,
                     const std::shared_ptr<Scope>& scope,
                     const Target& target)
      : graph(graph), scope(scope), target(target) {}

  std::string attached_source_code = "";

  // Compile options.
  bool with_instantiate_variables = false;
  bool with_buffer_handle_instruction_inserted = false;
  bool remove_unused_variables = true;

  // Compile stage, full compile by default.
  CompilationStage stage = CompilationStage::DEFAULT;

  // Compile target.
  Target target;
  // Computation graph.
  std::shared_ptr<Graph> graph;
  // Variable scope
  std::shared_ptr<Scope> scope;
  // Fetch var ids in cinn and the corresponding var nodes will not be fused
  // so as to get the result.
  std::unordered_set<std::string> fetch_var_ids;
  // Map dst reuse var to the src var sharing buffer
  absl::flat_hash_map<std::string, std::string> reuse_vars_map;
  // Nodes group, it may come from the result of op fusion or graph tuning.
  // Nodes in a group will be built into an Instruction.
  std::vector<std::shared_ptr<Graph::Group>> groups;
  // Corresponding lowered functions of above grouped nodes,
  // if it is empty then graph_compiler will generate for them.
  std::vector<std::vector<ir::LoweredFunc>> lowered_funcs;
  // CUDA stream.
  void* stream = nullptr;

  // Set attached source code, if code is not empty, these codes will replace
  // the device_module code after SplitCudaAndHostModule.
  void ApplySourceCode(const std::string& code);
  // Apply results of auto-tune to compile.
  // Compilation will start from CompilationStage::CODEGEN_AND_JIT when tuning
  // results are applied.
  void ApplyTuningResult(const auto_schedule::TuningResult& tuning_result);
};

class GraphCompiler;

class CompilationResult {
  friend class GraphCompiler;

 public:
  void InitCompilationResult(int group_size);

  // Setters
  void SetStatus(int idx, const CompilationStatus& status);
  void SetMessage(int idx, const std::string& message);
  void SetLoweredFuncs(int idx, const std::vector<ir::LoweredFunc>& funcs);
  void SetSourceCode(int idx, const std::string& source_code);
  void SetSourcePtx(int idx, const std::string& source_ptx);
  void SetInstruction(int idx, std::unique_ptr<Instruction> instruction);
  void SetRuntimeProgram(std::unique_ptr<Program> runtime_program);

  // Getters
  bool IsSuccess() const;
  int Size() const { return size_; }
  CompilationStatus Status() const;
  CompilationStatus Status(int idx) const;
  std::string Message() const;
  std::string Message(int idx) const;
  std::vector<std::vector<ir::LoweredFunc>> LoweredFuncs() const;
  std::vector<ir::LoweredFunc> LoweredFuncs(int idx) const;
  std::vector<std::string> SourceCodes() const;
  std::string SourceCode(int idx) const;
  std::vector<std::string> SourcePtxs() const;
  std::string SourcePtx(int idx) const;
  const std::vector<std::unique_ptr<Instruction>>& RuntimeInstructions() const;
  const std::unique_ptr<Instruction>& RuntimeInstruction(int idx) const;
  std::unique_ptr<Program> RuntimeProgram();

 private:
  std::vector<CompilationStatus> status_;
  std::vector<std::string> messages_;
  std::vector<std::optional<std::vector<ir::LoweredFunc>>> lowered_funcs_;
  std::vector<std::optional<std::string>> source_codes_;
  std::vector<std::optional<std::string>> source_ptxs_;
  std::vector<std::unique_ptr<Instruction>> instructions_;
  std::unique_ptr<Program> runtime_program_;
  int size_;
};

}  // namespace framework
}  // namespace hlir
}  // namespace cinn
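A minimal sketch of driving the staged pipeline this header defines, stopping after lowering and then running the full compile. The graph/scope/target setup mirrors the new tests in graph_compiler_test.cc; the variable names here are illustrative:

using namespace cinn::hlir::framework;

CompilationContext context(graph, scope, target);
GraphCompiler gc(context);

// Stop after lowering (CompilationStage::LOWERING): only lowered
// functions are produced; codegen/jit and instructions are skipped.
CompilationResult lowered = gc.Lowering();
if (lowered.Status() == CompilationStatus::SUCCESS) {
  auto funcs = lowered.LoweredFuncs();  // one vector per fused group
}

// Full pipeline (CompilationStage::DEFAULT): everything up to the
// runtime program is produced.
CompilationResult full = gc.Build(&context);
std::unique_ptr<Program> program = full.RuntimeProgram();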
paddle/cinn/hlir/framework/graph_test.cc

@@ -20,7 +20,7 @@
 #include "paddle/cinn/hlir/framework/pass.h"
 #include "paddle/cinn/hlir/pass/use_pass.h"

-DECLARE_string(cinn_fusion_groups_graphviz_dir);
+PD_DECLARE_string(cinn_fusion_groups_graphviz_dir);

 namespace cinn {
 namespace hlir {
...
paddle/cinn/hlir/framework/instruction.cc

@@ -22,8 +22,8 @@
 #include "paddle/cinn/runtime/flags.h"
 #include "paddle/cinn/utils/profiler.h"

-DECLARE_bool(cinn_sync_run);
-DECLARE_string(cinn_self_check_accuracy);
+PD_DECLARE_bool(cinn_sync_run);
+PD_DECLARE_string(cinn_self_check_accuracy);

 namespace cinn {
 namespace hlir {
...
paddle/cinn/hlir/framework/new_ir_compiler.cc
deleted
100644 → 0
View file @
63eb0da5
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/hlir/framework/new_ir_compiler.h"
#include <absl/types/variant.h>
#include "paddle/cinn/hlir/framework/op_strategy.h"
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/lang/placeholder.h"
#include "paddle/cinn/utils/attribute_util.h"
#include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/ir/core/builtin_type.h"
namespace
cinn
{
namespace
hlir
{
namespace
framework
{
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>
CompatibleInfo
::
OP_NAMES
=
{
{
"pd.full"
,
"fill_constant"
},
{
"pd.matmul"
,
"matmul"
}};
// TODO(Aurelius84): Need abstract this logic to implement Proxy for
// the co-existance with GraphCompiler.
std
::
unique_ptr
<
Program
>
NewIRCompiler
::
Build
()
{
m_builder_
.
Clear
();
// NOTE(Aurelius84): Currently only support each op for one group
std
::
vector
<
std
::
vector
<::
ir
::
Operation
*>>
groups
;
for
(
auto
it
=
program_
.
block
()
->
begin
();
it
!=
program_
.
block
()
->
end
();
++
it
)
{
groups
.
push_back
({
*
it
});
}
VLOG
(
4
)
<<
"Groups size: "
<<
groups
.
size
();
std
::
vector
<
std
::
vector
<
ir
::
LoweredFunc
>>
lowered_funcs
;
for
(
int
i
=
0
;
i
<
groups
.
size
();
++
i
)
{
lowered_funcs
.
emplace_back
(
GetOpFunc
(
*
groups
[
i
][
0
],
i
));
}
for
(
auto
&&
lowered_func
:
lowered_funcs
)
{
ProcessFunction
(
lowered_func
);
}
compiler_
=
backends
::
Compiler
::
Create
(
target_
);
auto
build_module
=
m_builder_
.
Build
();
compiler_
->
Build
(
build_module
,
""
);
auto
instructions
=
BuildInstructions
(
groups
);
// TODO(Aurelius84): Instantiate all tensors on compile-time, which is
// controlled by 'options.with_instantiate_variables' in GraphCompiler.
// Moreover, it's better to implement InsertBufferHandlers() logic
// to automatically insert Malloc and Free instructions.
for
(
auto
&
name
:
scope_
->
var_names
())
{
std
::
string
var_name
({
name
.
data
(),
name
.
size
()});
VLOG
(
4
)
<<
"Instantiate "
<<
var_name
<<
" on compile-time"
;
auto
*
var
=
scope_
->
Var
<
Tensor
>
(
var_name
);
auto
&
tensor
=
absl
::
get
<
Tensor
>
(
*
var
);
tensor
->
mutable_data
(
target_
,
tensor
->
type
());
}
return
std
::
make_unique
<
Program
>
(
scope_
,
std
::
move
(
instructions
));
}
std
::
vector
<
ir
::
LoweredFunc
>
NewIRCompiler
::
GetOpFunc
(
const
::
ir
::
Operation
&
op
,
int
idx
)
{
std
::
vector
<
ir
::
Tensor
>
inputs
;
std
::
vector
<
common
::
CINNValue
>
cinn_inputs
;
auto
op_name
=
op
.
name
();
VLOG
(
4
)
<<
"GetOpFunc for op: "
<<
op_name
;
// step 1: Deal with Oprands
for
(
int
i
=
0
;
i
<
op
.
num_operands
();
++
i
)
{
auto
in_value
=
op
.
operand_source
(
i
);
// TODO(Aurelius84): For now, use addr as name but it's not wise.
std
::
string
input_id
=
CompatibleInfo
::
kInputPrefix
+
std
::
to_string
(
std
::
hash
<::
ir
::
Value
>
()(
in_value
));
auto
type_info
=
in_value
.
type
().
dyn_cast
<
paddle
::
dialect
::
DenseTensorType
>
();
auto
in_shape
=
phi
::
vectorize
<
int
>
(
type_info
.
dims
());
auto
dtype
=
type_info
.
dtype
();
ir
::
Tensor
temp
=
lang
::
CreatePlaceHolder
(
in_shape
,
utils
::
ConvertIRType
(
dtype
),
input_id
);
inputs
.
push_back
(
temp
);
cinn_inputs
.
push_back
(
common
::
CINNValue
(
temp
));
}
for
(
auto
out_name
:
OpGetOutputNames
(
op
))
{
cinn_inputs
.
push_back
(
common
::
CINNValue
(
out_name
));
}
VLOG
(
4
)
<<
"inputs.size(): "
<<
inputs
.
size
();
// step 2: Deal with OpResult
std
::
vector
<
Type
>
out_types
;
std
::
vector
<
std
::
vector
<
int
>>
out_shapes
;
for
(
int
i
=
0
;
i
<
op
.
num_results
();
++
i
)
{
auto
out_value
=
op
.
result
(
i
);
auto
type_info
=
out_value
.
type
().
dyn_cast
<
paddle
::
dialect
::
DenseTensorType
>
();
out_types
.
push_back
(
utils
::
ConvertIRType
(
type_info
.
dtype
()));
auto
out_shape
=
phi
::
vectorize
<
int
>
(
type_info
.
dims
());
out_shapes
.
push_back
(
std
::
move
(
out_shape
));
}
VLOG
(
4
)
<<
"out_types.size(): "
<<
out_types
.
size
();
NodeAttr
node_attrs
;
{
VLOG
(
4
)
<<
"op.attributes():"
<<
op
.
attributes
().
size
();
auto
attrs
=
utils
::
ConvertAttributes
(
op
.
attributes
());
node_attrs
.
node_name
=
CompatibleInfo
::
OP_NAMES
.
at
(
op_name
);
node_attrs
.
attr_store
=
std
::
move
(
attrs
);
}
auto
&
strategy
=
Operator
::
GetAttrs
<
StrategyFunction
>
(
"CINNStrategy"
);
// NOTE(Aurelius84): Do we need replace all hlir::framework Operator with
// ::ir::Program ?
const
hlir
::
framework
::
Operator
*
cinn_op
=
Operator
::
Get
(
CompatibleInfo
::
OP_NAMES
.
at
(
op_name
));
auto
impl
=
OpStrategy
::
SelectImpl
(
strategy
[
cinn_op
](
node_attrs
,
inputs
,
out_types
,
out_shapes
,
target_
));
common
::
CINNValuePack
C
=
impl
->
fcompute
(
common
::
CINNValuePack
{
cinn_inputs
});
poly
::
StageMap
stages
=
C
.
back
();
// make sure all the tensors in the stages before schedule launch.
for
(
int
i
=
0
;
i
<
C
->
size
()
-
1
;
i
++
)
{
ir
::
Expr
temp
=
C
[
i
];
stages
->
InsertLazily
(
temp
.
as_tensor_ref
());
}
C
=
impl
->
fschedule
(
C
);
for
(
int
i
=
0
;
i
<
C
->
size
()
-
1
;
i
++
)
{
ir
::
Expr
temp
=
C
[
i
];
// checkout whether the tensor is with buffer.
if
((
!
temp
.
as_tensor_ref
()
->
buffer
.
defined
()
||
this
->
target_
!=
common
::
DefaultNVGPUTarget
())
&&
!
stages
[
temp
.
as_tensor_ref
()]
->
inlined
())
{
inputs
.
push_back
(
temp
.
as_tensor_ref
());
}
}
auto
func
=
lang
::
LowerVec
(
GenOpFuncName
(
op
,
idx
),
stages
,
inputs
,
{},
{},
nullptr
,
target_
);
return
func
;
}
void
NewIRCompiler
::
ProcessFunction
(
const
std
::
vector
<
ir
::
LoweredFunc
>&
lowered_funcs
)
{
for
(
auto
&&
func
:
lowered_funcs
)
{
for
(
auto
&&
arg
:
func
->
args
)
{
std
::
string
arg_name
=
arg
.
name
();
if
(
arg_name
[
0
]
==
'_'
)
arg_name
=
arg_name
.
substr
(
1
);
auto
*
var
=
scope_
->
FindVar
(
arg_name
);
// For argument buffer not in scope, create it.
if
(
!
var
&&
arg
.
is_buffer
())
{
auto
*
new_var
=
scope_
->
Var
<
Tensor
>
(
arg_name
);
auto
&
tensor
=
absl
::
get
<
Tensor
>
(
*
new_var
);
std
::
vector
<
Shape
::
dim_t
>
shape
;
for
(
auto
&
shape_dim
:
arg
.
buffer_arg
()
->
shape
)
{
CHECK
(
shape_dim
.
is_constant
());
shape
.
push_back
(
static_cast
<
int
>
(
shape_dim
.
get_constant
()));
}
tensor
->
Resize
(
Shape
{
shape
});
tensor
->
set_type
(
arg
.
buffer_arg
()
->
dtype
);
}
}
m_builder_
.
AddFunction
(
func
);
}
}
std
::
vector
<
std
::
unique_ptr
<
Instruction
>>
NewIRCompiler
::
BuildInstructions
(
const
std
::
vector
<
std
::
vector
<::
ir
::
Operation
*>>&
groups
)
{
std
::
vector
<
std
::
unique_ptr
<
Instruction
>>
instructions
;
for
(
int
idx
=
0
;
idx
<
groups
.
size
();
++
idx
)
{
// TODO(Aurelius84): only support single op in groups
auto
&
op
=
*
groups
[
idx
][
0
];
auto
instr_name
=
op
.
name
();
auto
instr
=
std
::
unique_ptr
<
Instruction
>
(
new
Instruction
(
target_
,
scope_
.
get
(),
OpGetInputNames
(
op
),
OpGetOutputNames
(
op
),
instr_name
));
auto
&
op_func_name
=
GenOpFuncName
(
op
,
idx
);
auto
*
fn_ptr
=
compiler_
->
Lookup
(
op_func_name
);
CHECK
(
fn_ptr
);
instr
->
SetLoweredFunc
(
reinterpret_cast
<
void
*>
(
fn_ptr
),
op_func_name
);
// As some instruction like reduce, will generate more than one kernel.
// So try to find the rest kernel, if it exists.
// SetSubKernels(instr.get(), op_func_name);
instr
->
Finalize
();
instructions
.
push_back
(
std
::
move
(
instr
));
}
return
instructions
;
}
const std::string& NewIRCompiler::GenOpFuncName(const ::ir::Operation& op,
                                                int idx) {
  // TODO(Aurelius84): the '.' in names like pd.xxx will raise a compiler
  // error; we need a more elegant way to generate function names.
  std::string op_name = op.name().substr(3) + "_" + std::to_string(idx);
  std::string func_name = Context::Global().NewName("fn_" + op_name);
  func_names_.try_emplace(op_name, func_name);
  return func_names_.at(op_name);
}
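The substr(3) above assumes every op name carries a three-character dialect prefix such as "pd."; the '.' has to be dropped because it is not legal in a C function symbol. A standalone sketch of the resulting naming rule (the uniquifying counter added by Context::Global().NewName() is omitted here):

#include <iostream>
#include <string>

// Hypothetical helper mirroring GenOpFuncName: strip the "pd." prefix and
// append the op's index within its group.
std::string OpFuncName(const std::string& dialect_name, int idx) {
  return "fn_" + dialect_name.substr(3) + "_" + std::to_string(idx);
}

int main() {
  std::cout << OpFuncName("pd.tanh", 0) << "\n";  // prints: fn_tanh_0
  return 0;
}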
std::vector<std::string> NewIRCompiler::OpGetInputNames(
    const ::ir::Operation& op) {
  std::vector<std::string> names;
  std::unordered_set<std::string> repeat;
  for (int i = 0; i < op.num_operands(); ++i) {
    auto value = op.operand_source(i);
    std::string name = CompatibleInfo::kInputPrefix +
                       std::to_string(std::hash<::ir::Value>()(value));
    if (repeat.count(name)) {
      continue;
    }
    repeat.insert(name);
    names.push_back(name);
  }
  return names;
}
std::vector<std::string> NewIRCompiler::OpGetOutputNames(
    const ::ir::Operation& op) {
  std::vector<std::string> names;
  for (int i = 0; i < op.num_results(); ++i) {
    auto value = op.result(i);
    std::string name = CompatibleInfo::kOutputPrefix +
                       std::to_string(std::hash<::ir::Value>()(value));
    names.push_back(std::move(name));
  }
  return names;
}
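Both name builders above derive a stable variable name from std::hash<::ir::Value>, so the same SSA value always maps to the same scope variable, and OpGetInputNames additionally de-duplicates repeated operands. A self-contained sketch of that scheme, with int standing in for ::ir::Value:

#include <functional>
#include <string>
#include <unordered_set>
#include <vector>

std::vector<std::string> HashedNames(const std::vector<int>& values,
                                     const std::string& prefix) {
  std::vector<std::string> names;
  std::unordered_set<std::string> seen;
  for (int v : values) {
    std::string name = prefix + std::to_string(std::hash<int>()(v));
    if (seen.insert(name).second) {  // skip a repeated operand
      names.push_back(name);
    }
  }
  return names;
}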
std::shared_ptr<Scope> BuildScope(const Target& target,
                                  const ::ir::Program& program) {
  std::unordered_set<::ir::Value> visited;
  auto scope = std::make_shared<Scope>();

  auto create_var = [&](const std::string& name_prefix, ::ir::Value value) {
    if (visited.count(value) > 0) return;
    visited.emplace(value);

    std::string name =
        name_prefix + std::to_string(std::hash<::ir::Value>()(value));
    auto type_info = value.type().dyn_cast<paddle::dialect::DenseTensorType>();
    auto* var = scope->Var<Tensor>(name);
    auto& tensor = absl::get<Tensor>(*var);
    // NOTE: can be replaced with phi::vectorized ?
    std::vector<Shape::dim_t> shape;
    for (auto i = 0; i < type_info.dims().size(); ++i) {
      shape.push_back(Shape::dim_t(type_info.dims()[i]));
    }
    tensor->Resize(Shape{shape});
    tensor->set_type(utils::ConvertIRType(type_info.dtype()));
  };

  for (auto it = program.block()->begin(); it != program.block()->end();
       ++it) {
    for (auto i = 0; i < (*it)->num_operands(); ++i) {
      auto in_value = (*it)->operand_source(i);
      create_var(CompatibleInfo::kInputPrefix, in_value);
    }
    for (auto i = 0; i < (*it)->num_results(); ++i) {
      auto out_value = (*it)->result(i);
      create_var(CompatibleInfo::kOutputPrefix, out_value);
    }
  }
  return scope;
}

}  // namespace framework
}  // namespace hlir
}  // namespace cinn
paddle/cinn/hlir/framework/new_ir_compiler.h
deleted 100644 → 0
View file @ 63eb0da5
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <unordered_map>
#include "paddle/cinn/common/macros.h"
#include "paddle/ir/core/program.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
namespace cinn {
namespace hlir {
namespace framework {

struct CompatibleInfo {
  static constexpr const char* kInputPrefix = "input_";
  static constexpr const char* kOutputPrefix = "output_";
  // TODO(Aurelius): Need to add name-mapping logic in the REGISTER_CINN_OP
  // macros, or attempt to unify op names between Paddle and CINN.
  static const std::unordered_map<std::string, std::string> OP_NAMES;
};
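The out-of-line definition of OP_NAMES lives in new_ir_compiler.cc; its shape is a plain dialect-name to CINN-op-name table. A hypothetical definition, for illustration only (the entries are assumptions, not the committed table):

// Hypothetical entries mapping pd dialect ops to CINN operator names.
const std::unordered_map<std::string, std::string> CompatibleInfo::OP_NAMES = {
    {"pd.full", "fill_constant"},
    {"pd.tanh", "tanh"},
};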
// TODO(Aurelius84): Need to abstract this logic to implement a Proxy for
// coexistence with GraphCompiler.
class NewIRCompiler final {
 public:
  NewIRCompiler(const ::ir::Program& prog,
                const Target& target,
                const std::shared_ptr<Scope>& scope)
      : program_(prog),
        m_builder_("NewIR", target),
        target_(target),
        scope_(scope) {}

  std::unique_ptr<Program> Build();

  std::vector<ir::LoweredFunc> GetOpFunc(const ::ir::Operation& op, int idx);

  void ProcessFunction(const std::vector<ir::LoweredFunc>& lowered_funcs);

  std::vector<std::unique_ptr<Instruction>> BuildInstructions(
      const std::vector<std::vector<::ir::Operation*>>& groups);

 protected:
  const std::string& GenOpFuncName(const ::ir::Operation& op, int idx);

  std::vector<std::string> OpGetInputNames(const ::ir::Operation& op);

  std::vector<std::string> OpGetOutputNames(const ::ir::Operation& op);

 private:
  CINN_DISALLOW_COPY_AND_ASSIGN(NewIRCompiler);

  const ::ir::Program& program_;
  ir::Module::Builder m_builder_;
  std::unique_ptr<backends::Compiler> compiler_{nullptr};
  Target target_;
  std::shared_ptr<Scope> scope_;
  std::unordered_map<std::string, std::string> func_names_;
};

std::shared_ptr<Scope> BuildScope(const Target&, const ::ir::Program&);

}  // namespace framework
}  // namespace hlir
}  // namespace cinn
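Putting the pieces of this header together, a minimal usage sketch (assuming a ::ir::Program named program built elsewhere and a CUDA-enabled build; error handling omitted):

auto target = cinn::common::DefaultNVGPUTarget();
auto scope = cinn::hlir::framework::BuildScope(target, program);
cinn::hlir::framework::NewIRCompiler compiler(program, target, scope);
std::unique_ptr<cinn::hlir::framework::Program> runtime_program =
    compiler.Build();
runtime_program->Execute();  // runs the compiled instructions on `scope`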
paddle/cinn/hlir/framework/op_lowering.cc
deleted 100644 → 0
View file @ 63eb0da5
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/hlir/framework/op_lowering.h"
#include "paddle/cinn/hlir/framework/op_lowering_util.h"
#include "paddle/cinn/hlir/op/external_api_registry.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/optim/transform_gpu_forloop.h"
DECLARE_bool(cinn_use_cuda_vectorize);
namespace cinn {
namespace hlir {
namespace framework {

using common::bfloat16;
using common::float16;

using framework::Node;
using framework::NodeData;
using framework::OpPatternKind;
using framework::shape_t;
using framework::StrategyFunction;

using common::Type;
using cinn::hlir::op::ExternalApiRegistry;
OpLowerer::OpLowerer(
    const absl::flat_hash_map<std::string, Type>& type_dict,
    const absl::flat_hash_map<std::string, shape_t>& shape_dict,
    const Target& target)
    : type_dict_(type_dict), shape_dict_(shape_dict), target_(target) {}
std::vector<ir::LoweredFunc> OpLowerer::Lower(const GroupPtr& group,
                                              bool apply_op_schedule,
                                              bool apply_group_schedule) {
  VLOG(3) << "Lowering Group : " << group->group_id
          << " , Op Pattern : " << group->op_pattern_kind;
  group->input_names.clear();
  group->output_names.clear();
  switch (group->op_pattern_kind) {
    case framework::kElementWise:
    case framework::kBroadcast:
    case framework::kInjective:
      return LowerGroup(group,
                        apply_op_schedule,
                        apply_group_schedule,
                        &OpLowerer::ElementwiseScheduleDetermineFunction);
    case framework::kReduction:
      return LowerGroup(group,
                        apply_op_schedule,
                        apply_group_schedule,
                        &OpLowerer::ReduceScheduleDetermineFunction);
    case framework::kOutFusible:
      LOG(FATAL) << "Group Pattern Kind kOutFusible Is Not Implemented!";
    case framework::kNonFusible:
      return LowerGroup(group,
                        apply_op_schedule,
                        apply_group_schedule,
                        &OpLowerer::NonFusibleScheduleDetermineFunction);
    default:
      LOG(FATAL) << "Group Pattern Kind Is Unknown!";
  }
}
bool OpLowerer::ElementwiseScheduleDetermineFunction(Node* node) {
  return true;
}

bool OpLowerer::ReduceScheduleDetermineFunction(Node* node) {
  auto& op_pattern_dict = Operator::GetAttrs<OpPatternKind>("OpPattern");
  return op_pattern_dict[node->op()] == framework::kReduction;
}

bool OpLowerer::NonFusibleScheduleDetermineFunction(Node* node) { return true; }
std::vector<ir::LoweredFunc> OpLowerer::LowerGroup(
    const GroupPtr& group,
    bool apply_op_schedule,
    bool apply_group_schedule,
    ScheduleDetermineFunction schedule_determine_func) {
  // 1.Do compute, lower and schedule for each op.
  VLOG(3) << "group->fused_sub_groups.size() is : "
          << group->fused_sub_groups.size();
  std::vector<Node*> nodes = group->CollectNodes();
  if (nodes.size() == 1 && nodes[0]->op()->name == "custom_call") {
    return LowerCustomCall(group);
  }
  std::vector<ir::Tensor> group_func_arg_tensors;
  std::unordered_map<std::string, ir::Tensor> tensor_map;
  bool do_op_schedule = apply_group_schedule || apply_op_schedule;
  std::vector<ir::Expr> func_bodies = LowerOps(nodes,
                                               do_op_schedule,
                                               schedule_determine_func,
                                               &group_func_arg_tensors,
                                               &tensor_map);

  // 2.Do group schedule.
  ir::ModuleExpr mod_expr(func_bodies);
  ir::IRSchedule ir_sch(mod_expr);
  ir_sch.MergeExprs();
  VLOG(3) << "After lower, ir is: \n" << ir_sch.GetModule().GetExprs().at(0);
  if (apply_group_schedule) {
    DoGroupSchedule(ir_sch, group, tensor_map);
    VLOG(3) << "After group schedule, ir is: \n"
            << ir_sch.GetModule().GetExprs().at(0);
  }

  // 3.Do post-processing,
  // including preparing function args and temporary variables,
  // applying low-level optimization passes, etc.
  return PostProcess(
      group, tensor_map, do_op_schedule, &ir_sch, &group_func_arg_tensors);
}
std::vector<ir::LoweredFunc> OpLowerer::LowerCustomCall(const GroupPtr& group) {
  std::vector<Node*> nodes = group->CollectNodes();
  CHECK_EQ(nodes.size(), 1);
  Node* node = nodes[0];
  std::vector<ir::Tensor> op_func_arg_tensors;
  std::unordered_map<std::string, ir::Tensor> tensor_map;
  for (auto& node_data : GetInputNodeData(node)) {
    CHECK(node_data);
    ir::Tensor tensor;
    if (!tensor_map.count(node_data->id())) {
      tensor = GetTensor(node_data, this->type_dict_, this->shape_dict_);
      // record tensor.
      tensor_map[node_data->id()] = tensor;
      // input name.
      group->input_names.push_back(node_data->id());
    } else {
      tensor = tensor_map[node_data->id()];
    }
    op_func_arg_tensors.push_back(tensor);
  }

  std::vector<Type> out_types;
  std::vector<std::vector<int>> out_shapes;
  auto node_datas = GetAllNodeData(node);
  for (auto node_data : node_datas) {
    group->output_names.push_back(node_data->id());
    out_types.push_back(this->type_dict_.at(node_data->id()));
    out_shapes.push_back(this->shape_dict_.at(node_data->id()));
  }
  auto& cinn_strategy = Operator::GetAttrs<StrategyFunction>("CINNStrategy");
  auto impl = OpStrategy::SelectImpl(cinn_strategy[node->op()](
      node->attrs, op_func_arg_tensors, out_types, out_shapes, target_));
  std::string external_api;
  if (node->attrs.attr_store.count("custom_call")) {
    external_api =
        absl::get<std::string>(node->attrs.attr_store.at("custom_call"));
  } else {
    external_api = ExternalApiRegistry::Global()->GetExternalApi(node, target_);
  }
  std::vector<common::CINNValue> compute_args = {
      common::CINNValue(group->GetFuncName()), common::CINNValue(external_api)};
  common::CINNValuePack pack =
      impl->fcompute(common::CINNValuePack{compute_args});
  CHECK_EQ(pack.size(), 1UL);
  // Reset input names, as extern-api input args cannot be de-duplicated.
  group->input_names.clear();
  for (auto& inode : node->inlinks_in_order()) {
    group->input_names.push_back(inode->source()->as<NodeData>()->id());
  }
  return {pack[0].operator ir::Expr().as_lowered_func_ref()};
}
std::vector<ir::LoweredFunc> OpLowerer::PostProcess(
    const GroupPtr& group,
    const std::unordered_map<std::string, ir::Tensor>& tensor_map,
    bool done_op_schedule,
    ir::IRSchedule* ir_sch,
    std::vector<ir::Tensor>* group_func_arg_tensors) {
  // 1.Prepare function args
  group->input_names.clear();
  std::vector<ir::Argument> group_func_args;
  std::unordered_set<std::string> arg_name_set;
  for (auto& arg_tensor : *group_func_arg_tensors) {
    // input node data name.
    group->input_names.push_back(arg_tensor->name);
    // input args
    group_func_args.emplace_back(arg_tensor->buffer, ir::Argument::IO::kInput);
    arg_name_set.insert(arg_tensor->buffer->name);
  }

  group->output_names.clear();
  for (auto& node : group->output_nodes) {
    // collect all output tensors.
    for (auto node_data : GetAllNodeData(node)) {
      std::string output_node_data_name = node_data->id();
      group->output_names.push_back(output_node_data_name);
      // CHECK(tensor_map.count(output_node_data_name)) << "Can't find output
      // tensor " << output_node_data_name;
      if (tensor_map.count(output_node_data_name) == 0) {
        continue;
      }
      auto tensor = tensor_map.at(output_node_data_name);
      if (arg_name_set.count(tensor->buffer->name) != 0) {
        continue;
      }
      // output arg tensors
      group_func_arg_tensors->push_back(tensor);
      // output args
      group_func_args.emplace_back(tensor->buffer, ir::Argument::IO::kOutput);
      arg_name_set.insert(tensor->buffer->name);
    }
  }

  if (!done_op_schedule) {
    std::unordered_set<std::string> args_set;
    for (auto arg : group_func_args) {
      args_set.insert(arg.name());
    }
    for (auto& tensor_pair : tensor_map) {
      if (args_set.count("_" + tensor_pair.second->name)) {
        continue;
      }
      group_func_arg_tensors->push_back(tensor_pair.second);
      // Use the underlying tensor name to be consistent with the argument
      // name in the lowered function.
      group->output_names.push_back(tensor_pair.second->name);
      group_func_args.emplace_back(tensor_pair.second->buffer,
                                   ir::Argument::IO::kOutput);
    }
  }

  auto func_body = ir_sch->GetModule().GetExprs().at(0);
#ifdef CINN_WITH_CUDA
  optim::OptimizeExprGPU(&(func_body));
#endif

  // 2.Prepare temp buffers
  poly::StageMap stages;
  auto temp_buffers =
      lang::GetTempBuffers(*group_func_arg_tensors, stages, func_body);
  // 3.Building LoweredFunc
  auto func = ir::_LoweredFunc_::Make(group->GetFuncName(),
                                      group_func_args,
                                      ir_sch->GetModule().GetExprs().at(0),
                                      temp_buffers);
  if (!done_op_schedule) {
    func->PrepareBufferCastExprs();
  }
  // 4.Apply low level pass
  func = optim::Optimize(Expr(func), target_, false).as_lowered_func_ref();
  return {func};
}
std::vector<ir::Expr> OpLowerer::LowerOps(
    const std::vector<Node*>& nodes,
    bool apply_op_schedule,
    ScheduleDetermineFunction schedule_determine_func,
    std::vector<ir::Tensor>* group_func_arg_tensors,
    std::unordered_map<std::string, ir::Tensor>* tensor_map) {
  auto& strategy = Operator::GetAttrs<StrategyFunction>("CINNStrategy");
  std::vector<Expr> func_bodies;
  for (Node* node : nodes) {
    // 1.Select Op impl
    std::vector<Type> out_types;
    std::vector<std::vector<int>> out_shapes;
    std::vector<NodeData*> node_datas = GetAllNodeData(node);
    for (const auto& node_data : node_datas) {
      out_types.push_back(this->type_dict_.at(node_data->id()));
      out_shapes.push_back(this->shape_dict_.at(node_data->id()));
    }
    std::vector<ir::Tensor> op_func_arg_tensors =
        std::move(CollectInputTensor(node,
                                     this->type_dict_,
                                     this->shape_dict_,
                                     group_func_arg_tensors,
                                     tensor_map));
    auto op_impl = OpStrategy::SelectImpl(strategy[node->op()](node->attrs,
                                                               op_func_arg_tensors,
                                                               out_types,
                                                               out_shapes,
                                                               this->target_));

    // 2.Perform the lower process of Op
    std::vector<ir::LoweredFunc> funcs =
        DoOpLower(op_impl, node, tensor_map, &op_func_arg_tensors);

    if (apply_op_schedule && (this->*schedule_determine_func)(node)) {
      // 3.Perform the schedule of Op
      func_bodies.push_back(DoOpSchedule(op_impl, op_func_arg_tensors, funcs));
    } else {
      for (const ir::LoweredFunc& func : funcs) {
        func_bodies.push_back(func->body);
      }
    }
  }
  return func_bodies;
}
std::vector<ir::LoweredFunc> OpLowerer::DoOpLower(
    std::shared_ptr<hlir::framework::OpImpl> op_impl,
    Node* node,
    std::unordered_map<std::string, ir::Tensor>* tensor_map,
    std::vector<ir::Tensor>* op_func_arg_tensors) {
  VLOG(4) << "Do lower with Compute, op: " << node->op()->name;
  std::vector<common::CINNValue> cinn_inputs;
  for (const ir::Tensor& tensor : *op_func_arg_tensors) {
    cinn_inputs.push_back(common::CINNValue(ir::Expr(tensor)));
  }
  // set tensor name = node data name
  std::vector<NodeData*> node_datas = GetAllNodeData(node);
  for (const NodeData* node_data : node_datas) {
    cinn_inputs.push_back(common::CINNValue(node_data->id()));
  }

  // 1.Do compute
  common::CINNValuePack pack =
      op_impl->fcompute(common::CINNValuePack{cinn_inputs});

  poly::StageMap tmp_stages = pack.back();
  std::string post = "";
  for (int idx = 0; idx < pack.size() - 1; ++idx) {
    Expr expr = pack[idx];
    // Insert the output tensor defined by Compute into the tensor_map
    if (pack.size() - 1 > node_datas.size()) {
      // Some nodes may output multiple temp tensors in their Compute
      // definition but have only one output node_data in the graph, so we
      // use id + "_0"/"_1" as the key.
      (*tensor_map)[node_datas[0]->id() + post] = expr.as_tensor_ref();
      post = "_" + std::to_string(idx);
    } else {
      // If the number of output tensors defined by Compute is no more than
      // the number of output node_data on the graph, there is a one-to-one
      // correspondence, and any redundant output node_data stays empty.
      (*tensor_map)[node_datas[idx]->id()] = expr.as_tensor_ref();
    }

    // Insert output tensors into function args
    if (!expr.as_tensor_ref()->buffer.defined() ||
        this->target_ != common::DefaultNVGPUTarget()) {
      op_func_arg_tensors->push_back(expr.as_tensor_ref());
      expr.as_tensor_ref()->WithBuffer();
    }
  }

  // 2.Do lower
  std::vector<ir::LoweredFunc> funcs = lang::LowerVec("fn_" + node->id(),
                                                      tmp_stages,
                                                      *op_func_arg_tensors,
                                                      {},
                                                      {},
                                                      nullptr,
                                                      this->target_,
                                                      true);
  VLOG(4) << "Lower op: " << node->op()->name << ", get " << funcs.size()
          << " LoweredFunc:\n";

  op_func_arg_tensors->clear();
  for (int idx = 0; idx < pack.size() - 1; ++idx) {
    CHECK(pack[idx].is_tensor());
    op_func_arg_tensors->push_back(
        pack[idx].operator ir::Expr().as_tensor_ref());
  }
  return funcs;
}
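The tensor_map keying rule above is subtle: when Compute yields more tensors than the node has graph outputs, the extras reuse the first output's id with an "_<idx>" suffix. A standalone sketch of exactly that keying behavior, with int standing in for the real ir::Tensor:

#include <iostream>
#include <map>
#include <string>

std::map<std::string, int> KeyTensors(const std::string& id, int num_tensors) {
  std::map<std::string, int> tensor_map;
  std::string post = "";
  for (int idx = 0; idx < num_tensors; ++idx) {
    tensor_map[id + post] = idx;  // idx stands in for the real tensor
    post = "_" + std::to_string(idx);
  }
  return tensor_map;
}

int main() {
  for (auto& kv : KeyTensors("var", 3)) {
    std::cout << kv.first << "\n";  // prints: var, var_0, var_1
  }
  return 0;
}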
ir::Expr OpLowerer::DoOpSchedule(
    std::shared_ptr<hlir::framework::OpImpl> op_impl,
    const std::vector<ir::Tensor>& op_func_arg_tensors,
    const std::vector<ir::LoweredFunc>& lowered_funcs) {
  VLOG(4) << "Do op schedule";
  std::vector<common::CINNValue> schedule_inputs;
  // 1.Collect tensors
  for (const ir::Tensor& op_func_arg_tensor : op_func_arg_tensors) {
    schedule_inputs.push_back(common::CINNValue(op_func_arg_tensor));
  }
  // 2.Collect bodies to be scheduled
  for (const ir::LoweredFunc& func : lowered_funcs) {
    schedule_inputs.push_back(common::CINNValue(func->body));
  }
  // 3.Do schedule on AST
  common::CINNValuePack expr_pack =
      op_impl->fschedule(common::CINNValuePack{schedule_inputs});
  VLOG(4) << "After op schedule: " << expr_pack[0].operator ir::Expr();

  return expr_pack[0].operator ir::Expr();
}
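The fschedule contract implied here is purely positional: the pack carries the argument tensors first, then each lowered function body, and the scheduled body comes back as element 0. A trivial restatement of that packing order, with strings standing in for CINNValues:

#include <string>
#include <vector>

std::vector<std::string> PackScheduleInputs(
    const std::vector<std::string>& tensors,
    const std::vector<std::string>& bodies) {
  std::vector<std::string> pack(tensors);  // tensors first
  pack.insert(pack.end(), bodies.begin(), bodies.end());  // then bodies
  return pack;  // fschedule consumes this and returns the scheduled body first
}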
// group schedule
ir::Expr OpLowerer::DoGroupSchedule(
    ir::IRSchedule& ir_sch,
    const GroupPtr& group,
    const std::unordered_map<std::string, ir::Tensor>& tensor_map) {
  // topological order.
  auto nodes_set = group->NodeSet();
  auto v_consumers = BuildVirtualConsumer(group, this->shape_dict_);
  auto nodes_in_order =
      BFSTopologicalOrderWithPriority(group, v_consumers, this->shape_dict_);
  // find reducer.
  std::unordered_set<Node*> nodes_inline;
  auto greducer = FindGlobalReducer(nodes_in_order);
  auto& op_pattern_dict = Operator::GetAttrs<OpPatternKind>("OpPattern");

  // do schedule
  for (auto node : nodes_in_order) {
    VLOG(4) << "Try FUSION " << node->op()->name;
    // consumers.
    auto consumers = GetConsumersInSet(node, nodes_set);
    const Node* reducer =
        greducer ? FindNearestReducer(node, nodes_set) : greducer;
    if (!reducer && greducer) {
      reducer =
          v_consumers.count(node) ? v_consumers.find(node)->second : reducer;
      if (reducer && op_pattern_dict[reducer->op()] != framework::kReduction) {
        reducer = nullptr;
      }
    }

    auto masters = GetMasters(node, nodes_inline, nodes_set);
    // the node can be inlined.
    if (CanbeInline(node,
                    consumers,
                    reducer,
                    masters,
                    group,
                    nodes_set,
                    this->shape_dict_)) {
      VLOG(3) << "Before compute inline, ir is: \n"
              << ir_sch.GetModule().GetExprs().at(0);
      auto block = ir_sch.GetBlock(GetNodeData(node)->id());
      ir::ComputeInlineChecker checker(ir_sch, block);
      if (!checker.Check()) {
        checker.BuildDataDependency();
        continue;
      }

      // if a global reduce node exists.
      if (greducer) {
        auto loops = ir_sch.GetLoops(GetNodeData(node)->id());
        if (op_pattern_dict[node->op()] == framework::kElementWise) {
          ir_sch.FlattenLoops(loops, true);
        } else {
          ir_sch.FlattenLoops(loops, false);
        }
      }

      ir_sch.ComputeInline(block);
      nodes_inline.insert(node);
      VLOG(3) << "After compute inline, ir is: \n"
              << ir_sch.GetModule().GetExprs().at(0);
      continue;
    }
    // find the master to compute-at.
    auto master = GetMasterToComputeAt(node,
                                       nodes_in_order,
                                       nodes_inline,
                                       nodes_set,
                                       v_consumers,
                                       this->shape_dict_);
    // assign to reducer/master loop.
    if (reducer) {
      VLOG(3) << "Before assign node " << node->id()
              << " into vertical link reducer " << reducer->id()
              << ", ir is: \n"
              << ir_sch.GetModule().GetExprs().at(0);
      // if the node is vertical with reduce, assign its loop to the reducer.
      LoopAssignReduce(
          ir_sch, node, reducer, this->target_, tensor_map, this->shape_dict_);
    } else if (greducer) {
      auto greducer_out_shape =
          this->shape_dict_.at(greducer->outlinks_in_order()[0]->sink()->id());
      auto node_out_shape =
          this->shape_dict_.at(node->outlinks_in_order()[0]->sink()->id());
      if (std::accumulate(greducer_out_shape.begin(),
                          greducer_out_shape.end(),
                          1,
                          std::multiplies<int>()) !=
          std::accumulate(node_out_shape.begin(),
                          node_out_shape.end(),
                          1,
                          std::multiplies<int>())) {
        LoopAssignReduce(ir_sch,
                         node,
                         greducer,
                         this->target_,
                         tensor_map,
                         this->shape_dict_);
      } else {
        VLOG(3) << "Before assign node " << node->id()
                << " into horizontal link reducer " << greducer->id()
                << ", ir is: \n"
                << ir_sch.GetModule().GetExprs().at(0);
        // if the node is horizontal with reduce, or is itself a reduce,
        // assign its loop to the master.
        auto loops = ir_sch.GetLoops(GetNodeData(node)->id());
        if (op_pattern_dict[node->op()] == framework::kElementWise) {
          ir_sch.FlattenLoops(loops, true);
        } else if (op_pattern_dict[node->op()] != framework::kReduction) {
          ir_sch.FlattenLoops(loops, false);
        }

        if (master && op_pattern_dict[node->op()] != framework::kReduction) {
          auto master_loops = ir_sch.GetLoops(GetNodeData(master)->id());
          std::vector<int> splits;
          for (auto loop : master_loops) {
            splits.push_back(loop.As<ir::For>()->extent.as_int32());
          }
          loops = ir_sch.GetLoops(GetNodeData(node)->id());
          ir_sch.Split(loops[0], splits);
        }
      }
    }
    VLOG(3) << "Before loop fusion, ir is: \n"
            << ir_sch.GetModule().GetExprs().at(0);
    // do loop fusion.
    LoopComputeAt(ir_sch,
                  node,
                  master ? master : nodes_in_order.front(),
                  group,
                  this->shape_dict_,
                  tensor_map);
    VLOG(3) << "After loop fusion, ir is: \n"
            << ir_sch.GetModule().GetExprs().at(0);
  }

  // do vectorization
  auto all_blocks = ir_sch.GetAllBlocks();
  VLOG(4) << "Size of blocks: " << all_blocks.size();
  VLOG(4) << "Op Pattern : " << group->op_pattern_kind;

  // only the first block is supported?
  auto block = all_blocks[0];
  CHECK(block->as<ir::ScheduleBlockRealize>());
  CHECK(block->as<ir::ScheduleBlockRealize>()
            ->schedule_block->as<ir::ScheduleBlock>());
  auto is_tensor_block = true;
  auto tensor_name = block->as<ir::ScheduleBlockRealize>()
                         ->schedule_block->as<ir::ScheduleBlock>()
                         ->name;
  if (!tensor_map.count(tensor_name)) {
    is_tensor_block = false;
  }
  if (FLAGS_cinn_use_cuda_vectorize && is_tensor_block &&
      (group->op_pattern_kind == framework::kElementWise ||
       group->op_pattern_kind == framework::kInjective ||
       group->op_pattern_kind == framework::kBroadcast)) {
    // auto loops = ir_sch.GetLoops(GetNodeData(node)->id());
    auto loops = ir_sch.GetLoops(block);
    VLOG(4) << "Op Pattern : " << loops.size();
    if (loops.size() >= 1) {
      VLOG(4) << "Before vectorize, ir is: \n"
              << ir_sch.GetModule().GetExprs().at(0);
      auto loop_inner = loops.back();
      int vector_width = 1;
      auto psize = ir::GetLoopExtent(loop_inner);
      // get the dtype of the vectorized var
      auto dtype = this->type_dict_.at(tensor_name);
      VLOG(4) << tensor_name << " dtype " << dtype;
      if (psize % 8 == 0 && (dtype.is_float16() || dtype.is_bfloat16())) {
        vector_width = 8;
      } else if (psize % 4 == 0) {
        vector_width = 4;
      } else if (psize % 2 == 0) {
        vector_width = 2;
      }
      if (vector_width > 1) {
        auto splited = ir_sch.Split(loop_inner, {-1, vector_width});
        splited[0].As<ir::For>()->set_bind_info(
            loop_inner.As<ir::For>()->bind_info());
        splited[1].As<ir::For>()->set_serial();
        ir_sch.Vectorize(splited[1], vector_width);
      }
      VLOG(4) << "After vectorize, ir is: \n"
              << ir_sch.GetModule().GetExprs().at(0);
    }
  }

  VLOG(3) << "Before Sync IRLowerOp schedule, ir is: \n"
          << ir_sch.GetModule().GetExprs().at(0);
  SyncThreadWithShared(
      ir_sch, group, nodes_inline, nodes_set, this->shape_dict_, tensor_map);
  VLOG(4) << "After IRSchedule, ir is: \n"
          << ir_sch.GetModule().GetExprs().at(0);
  return ir_sch.GetModule().GetExprs().at(0);
}

}  // namespace framework
}  // namespace hlir
}  // namespace cinn
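The vectorization branch above encodes a small width-selection policy: the widest of {8, 4, 2} that divides the innermost loop extent wins, and width 8 is reserved for 16-bit float types. Restated as a standalone function for clarity:

// Mirrors the width selection in DoGroupSchedule; extent is the innermost
// loop's trip count.
int ChooseVectorWidth(int extent, bool is_fp16_or_bf16) {
  if (extent % 8 == 0 && is_fp16_or_bf16) return 8;
  if (extent % 4 == 0) return 4;
  if (extent % 2 == 0) return 2;
  return 1;  // 1 means: leave the loop scalar
}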
paddle/cinn/hlir/framework/op_lowering.h
View file @ 01a10755
-// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
 ...
@@ -13,166 +13,78 @@
 // limitations under the License.
 
 #pragma once
 
 #include <memory>
 #include <string>
 #include <vector>
 
 #include "paddle/cinn/common/target.h"
 #include "paddle/cinn/hlir/framework/graph.h"
-#include "paddle/cinn/hlir/framework/instruction.h"
-#include "paddle/cinn/hlir/framework/op_strategy.h"
-#include "paddle/cinn/ir/lowered_func.h"
-#include "paddle/cinn/ir/schedule/ir_schedule.h"
-#include "paddle/cinn/ir/schedule/ir_schedule_util.h"
+#include "paddle/cinn/hlir/framework/op_lowering_impl.h"
+#include "paddle/cinn/hlir/framework/op_lowering_impl_base.h"
+#include "paddle/cinn/ir/group_schedule/base_group_scheduler.h"
 #include "paddle/cinn/lang/packed_func.h"
 
-// Fusion Op lowering, there are four kinds of lowering function:
-// Elementwise/Broadcast/Injective, Reduce, OutEWiseFusable, NonFusible.
-// Elementwise/Broadcast/Injective ops share the same schedule.
-// Reduce, OutEWiseFusable and NonFusible use different schedules.
+#ifndef CINN_WITH_ONLY
+#include "paddle/cinn/hlir/framework/pir/op_lowering_impl.h"
+#endif
 
 namespace cinn {
 namespace hlir {
 namespace framework {
 
-using GroupPtr = std::shared_ptr<Graph::Group>;
 using common::Target;
+using GroupPtr = std::shared_ptr<hlir::framework::Graph::Group>;
 
-class OpLowerer;
-typedef bool (OpLowerer::*ScheduleDetermineFunction)(Node*);
-
+template <typename T>
 class OpLowerer {
  public:
-  OpLowerer(const absl::flat_hash_map<std::string, Type>&,
-            const absl::flat_hash_map<std::string, shape_t>&,
-            const Target&);
+  explicit OpLowerer(OpLowererImplBase<T>* impl) { impl_.reset(impl); }
+  ~OpLowerer() {}
 
   /**
    * @brief Lower a group to CINN IR.
    * @param group The group to be lowered.
    * @param apply_op_schedule Whether to schedule at Op level.
    * @param apply_group_schedule Whether to schedule at group level.
   * @return The lowered funcs.
   */
-  std::vector<ir::LoweredFunc> Lower(const GroupPtr& group,
+  std::vector<ir::LoweredFunc> Lower(const T& group,
                                      bool apply_op_schedule = true,
-                                     bool apply_group_schedule = true);
+                                     bool apply_group_schedule = true,
+                                     bool apply_pass = true) {
+    return impl_->Lower(
+        group, apply_op_schedule, apply_group_schedule, apply_pass);
+  }
+
+  std::vector<std::pair<ir::SymbolicPredicate, ir::LoweredFunc>> BucketLower(
+      const T& group,
+      bool apply_op_schedule = false,
+      bool apply_group_schedule = true,
+      bool apply_pass = true) {
+    return impl_->BucketLower(
+        group, apply_op_schedule, apply_group_schedule, apply_pass);
+  }
 
  private:
-  /**
-   * @brief Lower a group to CINN IR.
-   * @param group The group to be lowered.
-   * @param apply_op_schedule Whether to schedule at Op level.
-   * @param apply_group_schedule Whether to schedule at group level.
-   * @param schedule_determine_func Function used to determine which Ops to
-   * schedule.
-   * @return The lowered funcs.
-   */
-  std::vector<ir::LoweredFunc> LowerGroup(
-      const GroupPtr& group,
-      bool apply_op_schedule,
-      bool apply_group_schedule,
-      ScheduleDetermineFunction schedule_determine_func);
-
-  /**
-   * @brief Lower a group composed of a CustomCall Op.
-   * @param group The group to be lowered.
-   * @return The lowered funcs.
-   */
-  std::vector<ir::LoweredFunc> LowerCustomCall(const GroupPtr& group);
-
-  /**
-   * @brief Post processing, including preparing function args and temporary
-   * variables, applying low-level optimization passes, etc.
-   * @param group The group to be lowered.
-   * @param tensor_map All tensors used for calculating the group.
-   * @param done_op_schedule Mark whether the Op level schedule has been
-   * applied.
-   * @param ir_sch The IRSchedule object of the group.
-   * @param group_func_arg_tensors Tensors used as the group function
-   * arguments.
-   * @return The lowered funcs after the post processing.
-   */
-  std::vector<ir::LoweredFunc> PostProcess(
-      const GroupPtr& group,
-      const std::unordered_map<std::string, ir::Tensor>& tensor_map,
-      bool done_op_schedule,
-      ir::IRSchedule* ir_sch,
-      std::vector<ir::Tensor>* group_func_arg_tensors);
-
-  /**
-   * @brief Lower an Op set to CINN IR.
-   * Compute, Lower and optional Schedule will be performed one by one
-   * for each Op.
-   * @param nodes The Op nodes to be lowered.
-   * @param apply_op_schedule Whether to schedule at Op level.
-   * @param schedule_determine_func Function used to determine which Ops to
-   * schedule.
-   * @param group_func_arg_tensors Tensors used as the group function
-   * arguments.
-   * @param tensor_map All tensors used for calculating the group.
-   * @return The lowered func bodies of the Op set.
-   */
-  std::vector<ir::Expr> LowerOps(
-      const std::vector<Node*>& nodes,
-      bool apply_op_schedule,
-      ScheduleDetermineFunction schedule_determine_func,
-      std::vector<ir::Tensor>* group_func_arg_tensors,
-      std::unordered_map<std::string, ir::Tensor>* tensor_map);
-
-  /**
-   * @brief Lower an Op to CINN IR. The Compute and Lower processes will be
-   * called sequentially.
-   * @param op_impl The Op implementation defining Compute and Schedule.
-   * @param node The Op node to be lowered.
-   * @param tensor_map All tensors used for calculating the group.
-   * @param op_func_arg_tensors Tensors used as the Op function arguments.
-   * @return The lowered func of the Op node.
-   */
-  std::vector<ir::LoweredFunc> DoOpLower(
-      std::shared_ptr<hlir::framework::OpImpl> op_impl,
-      Node* node,
-      std::unordered_map<std::string, ir::Tensor>* tensor_map,
-      std::vector<ir::Tensor>* op_func_arg_tensors);
-
-  /**
-   * @brief Apply schedule on an Op.
-   * @param op_impl The Op implementation defining Compute and Schedule.
-   * @param op_func_arg_tensors Tensors used as the Op function arguments.
-   * @param lowered_funcs The lowered funcs of an Op to be scheduled.
-   * @return The lowered func body after schedule of the Op.
-   */
-  ir::Expr DoOpSchedule(std::shared_ptr<hlir::framework::OpImpl> op_impl,
-                        const std::vector<ir::Tensor>& op_func_arg_tensors,
-                        const std::vector<ir::LoweredFunc>& lowered_funcs);
-
-  /**
-   * @brief Apply schedule on a group.
-   * @param ir_sch The IRSchedule containing the entire group's lowered func
-   * bodies.
-   * @param group The group to be scheduled.
-   * @param tensor_map All tensors used for calculating the group.
-   * @return The lowered func body after schedule of the group.
-   */
-  ir::Expr DoGroupSchedule(
-      ir::IRSchedule& ir_sch,  // NOLINT
-      const GroupPtr& group,
-      const std::unordered_map<std::string, ir::Tensor>& tensor_map);
-
-  // Functions used to determine which Ops to schedule at op level; they
-  // define a policy for each type of group.
-  inline bool ReduceScheduleDetermineFunction(Node* node);
-  inline bool ElementwiseScheduleDetermineFunction(Node* node);
-  inline bool NonFusibleScheduleDetermineFunction(Node* node);
-
- private:
-  Target target_;
-  const absl::flat_hash_map<std::string, Type>& type_dict_;
-  const absl::flat_hash_map<std::string, shape_t>& shape_dict_;
-
-  // function name prefix
-  const std::string func_name_prefix = "fn_";
+  std::shared_ptr<OpLowererImplBase<T>> impl_;
 };
 
+template <typename T = GroupPtr>
+OpLowerer<T> CreateOpLowerer(const absl::flat_hash_map<std::string, Type>&,
+                             const absl::flat_hash_map<std::string, shape_t>&,
+                             const Target&);
+
+template <>
+inline OpLowerer<GroupPtr> CreateOpLowerer(
+    const absl::flat_hash_map<std::string, Type>& type_dict,
+    const absl::flat_hash_map<std::string, shape_t>& shape_dict,
+    const Target& target) {
+  auto* impl_base = new OpLowererImpl(type_dict, shape_dict, target);
+  return OpLowerer<GroupPtr>(impl_base);
+}
+
+#ifndef CINN_WITH_ONLY
+template <typename T = pir::GroupPtr>
+OpLowerer<T> CreateOpLowerer(const Target&);
+
+template <>
+inline OpLowerer<pir::GroupPtr> CreateOpLowerer(const Target& target) {
+  auto* impl_base = new pir::OpLowererImpl(target);
+  return OpLowerer<pir::GroupPtr>(impl_base);
+}
+#endif
+
 }  // namespace framework
 }  // namespace hlir
 }  // namespace cinn
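With the header now a thin template over OpLowererImplBase, call sites pick an implementation through CreateOpLowerer; a minimal sketch for the legacy graph path (assuming type_dict, shape_dict and target are prepared by the caller):

auto op_lowerer =
    cinn::hlir::framework::CreateOpLowerer(type_dict, shape_dict, target);
std::vector<cinn::ir::LoweredFunc> funcs = op_lowerer.Lower(group);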