Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Paddle
Commits
01a10755
Commit
01a10755
authored
Mar 04, 2024
by
yuguo-Jack
Browse files
2.5.2-dtk24.04
parent
63eb0da5
Changes
558
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
833 additions
and
1312 deletions
+833
-1312
paddle/cinn/hlir/framework/CMakeLists.txt
paddle/cinn/hlir/framework/CMakeLists.txt
+9
-9
paddle/cinn/hlir/framework/accuracy_checker.cc
paddle/cinn/hlir/framework/accuracy_checker.cc
+1
-1
paddle/cinn/hlir/framework/accuracy_checker_test.cc
paddle/cinn/hlir/framework/accuracy_checker_test.cc
+1
-1
paddle/cinn/hlir/framework/compile_error.cc
paddle/cinn/hlir/framework/compile_error.cc
+41
-0
paddle/cinn/hlir/framework/compile_error.h
paddle/cinn/hlir/framework/compile_error.h
+68
-0
paddle/cinn/hlir/framework/convert_to_dialect.cc
paddle/cinn/hlir/framework/convert_to_dialect.cc
+0
-55
paddle/cinn/hlir/framework/convert_to_dialect.h
paddle/cinn/hlir/framework/convert_to_dialect.h
+0
-33
paddle/cinn/hlir/framework/graph.cc
paddle/cinn/hlir/framework/graph.cc
+15
-6
paddle/cinn/hlir/framework/graph.h
paddle/cinn/hlir/framework/graph.h
+6
-3
paddle/cinn/hlir/framework/graph_compiler.cc
paddle/cinn/hlir/framework/graph_compiler.cc
+99
-57
paddle/cinn/hlir/framework/graph_compiler.h
paddle/cinn/hlir/framework/graph_compiler.h
+25
-38
paddle/cinn/hlir/framework/graph_compiler_test.cc
paddle/cinn/hlir/framework/graph_compiler_test.cc
+74
-8
paddle/cinn/hlir/framework/graph_compiler_util.cc
paddle/cinn/hlir/framework/graph_compiler_util.cc
+289
-0
paddle/cinn/hlir/framework/graph_compiler_util.h
paddle/cinn/hlir/framework/graph_compiler_util.h
+149
-0
paddle/cinn/hlir/framework/graph_test.cc
paddle/cinn/hlir/framework/graph_test.cc
+1
-1
paddle/cinn/hlir/framework/instruction.cc
paddle/cinn/hlir/framework/instruction.cc
+2
-2
paddle/cinn/hlir/framework/new_ir_compiler.cc
paddle/cinn/hlir/framework/new_ir_compiler.cc
+0
-280
paddle/cinn/hlir/framework/new_ir_compiler.h
paddle/cinn/hlir/framework/new_ir_compiler.h
+0
-77
paddle/cinn/hlir/framework/op_lowering.cc
paddle/cinn/hlir/framework/op_lowering.cc
+0
-600
paddle/cinn/hlir/framework/op_lowering.h
paddle/cinn/hlir/framework/op_lowering.h
+53
-141
No files found.
Too many changes to show.
To preserve performance only
558 of 558+
files are displayed.
Plain diff
Email patch
paddle/cinn/hlir/framework/CMakeLists.txt
View file @
01a10755
add_subdirectory
(
pir
)
core_gather_headers
()
core_gather_headers
()
gather_srcs
(
gather_srcs
(
...
@@ -12,22 +13,21 @@ gather_srcs(
...
@@ -12,22 +13,21 @@ gather_srcs(
program.cc
program.cc
parallel_compiler.cc
parallel_compiler.cc
graph_compiler.cc
graph_compiler.cc
graph_compiler_util.cc
graph.cc
graph.cc
node.cc
node.cc
pass.cc
pass.cc
op_strategy.cc
op_strategy.cc
op_lowering.cc
op_lowering_util.cc
op_lowering_util.cc
op_lowering_impl.cc
accuracy_checker.cc
accuracy_checker.cc
visualize_helper.cc
)
visualize_helper.cc
compile_error.cc
)
# TODO(Aurelius84):
new_
ir_compiler depends on p
d
_dialect and could
# TODO(Aurelius84):
p
ir_compiler depends on
o
p_dialect
_vjp
and could
# not found under CINN_ONLY mode
# not found under CINN_ONLY mode
if
(
NOT CINN_ONLY
)
if
(
NOT CINN_ONLY
)
cinn_cc_library
(
new_ir_compiler SRCS new_ir_compiler.cc DEPS cinnapi
cinn_cc_library
(
pir_compiler SRCS pir_compiler.cc DEPS cinnapi op_dialect_vjp
)
pd_dialect
)
cinn_cc_library
(
convert_to_dialect SRCS convert_to_dialect.cc DEPS cinnapi
cinn_dialect
)
endif
()
endif
()
if
(
WITH_CUDA
)
if
(
WITH_CUDA
)
...
@@ -52,5 +52,5 @@ cinn_cc_test(test_hlir_framework_op SRCS op_test.cc DEPS cinncore)
...
@@ -52,5 +52,5 @@ cinn_cc_test(test_hlir_framework_op SRCS op_test.cc DEPS cinncore)
cinn_cc_test
(
test_hlir_framework_print_graph_pass SRCS print_graph_pass_test.cc
cinn_cc_test
(
test_hlir_framework_print_graph_pass SRCS print_graph_pass_test.cc
DEPS cinncore
)
DEPS cinncore
)
cinn_cc_test
(
test_hlir_framework_graph SRCS graph_test.cc DEPS cinncore
)
cinn_cc_test
(
test_hlir_framework_graph SRCS graph_test.cc DEPS cinncore
)
cinn_cc_test
(
test_hlir_framework_graph_compiler SRCS graph_compiler_test.cc
#cinn_cc_test(test_hlir_framework_graph_compiler SRCS graph_compiler_test.cc
DEPS cinncore)
DEPS cinncore
)
paddle/cinn/hlir/framework/accuracy_checker.cc
View file @
01a10755
...
@@ -21,7 +21,7 @@
...
@@ -21,7 +21,7 @@
#include <cuda_runtime.h>
#include <cuda_runtime.h>
#endif
#endif
DECLARE_int64
(
cinn_self_check_accuracy_num
);
PD_
DECLARE_int64
(
cinn_self_check_accuracy_num
);
namespace
cinn
{
namespace
cinn
{
namespace
hlir
{
namespace
hlir
{
...
...
paddle/cinn/hlir/framework/accuracy_checker_test.cc
View file @
01a10755
...
@@ -25,7 +25,7 @@
...
@@ -25,7 +25,7 @@
#include "paddle/cinn/hlir/framework/instruction.h"
#include "paddle/cinn/hlir/framework/instruction.h"
#include "paddle/cinn/hlir/framework/op_strategy.h"
#include "paddle/cinn/hlir/framework/op_strategy.h"
DECLARE_string
(
cinn_self_check_accuracy
);
PD_
DECLARE_string
(
cinn_self_check_accuracy
);
namespace
cinn
{
namespace
cinn
{
namespace
hlir
{
namespace
hlir
{
...
...
paddle/cinn/hlir/framework/compile_error.cc
0 → 100644
View file @
01a10755
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/hlir/framework/compile_error.h"
#include "paddle/cinn/utils/enum_string.h"
namespace
cinn
{
namespace
hlir
{
namespace
framework
{
std
::
string
CompileErrorHandler
::
GeneralErrorMessage
()
const
{
std
::
ostringstream
os
;
os
<<
"[CompileError] An error occurred during compilation with the error "
"code: "
<<
utils
::
Enum2String
(
status_
)
<<
std
::
endl
;
os
<<
"(at "
<<
file_
<<
" : "
<<
line_
<<
")"
<<
std
::
endl
;
os
<<
indent_str_
<<
"[Error info] "
<<
this
->
err_msg_
<<
std
::
endl
;
return
os
.
str
();
}
std
::
string
CompileErrorHandler
::
DetailedErrorMessage
()
const
{
std
::
ostringstream
os
;
os
<<
GeneralErrorMessage
();
os
<<
indent_str_
<<
"[Detail info] "
<<
detail_info_
<<
std
::
endl
;
return
os
.
str
();
}
}
// namespace framework
}
// namespace hlir
}
// namespace cinn
paddle/cinn/hlir/framework/compile_error.h
0 → 100644
View file @
01a10755
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/utils/error.h"
namespace
cinn
{
namespace
hlir
{
namespace
framework
{
/**
* This handler is used to deal with the errors during the compilation process
*/
class
CompileErrorHandler
:
public
utils
::
ErrorHandler
{
public:
/**
* \brief constructor
* \param err_msg the error message
*/
explicit
CompileErrorHandler
(
const
CompilationStatus
&
status
,
const
std
::
string
&
err_msg
,
const
std
::
string
&
detail_info
,
const
char
*
file
,
int
line
)
:
status_
(
status
),
err_msg_
(
err_msg
),
detail_info_
(
detail_info
),
file_
(
file
),
line_
(
line
)
{}
/**
* \brief Returns a short error message corresponding to the kGeneral error
* level.
*/
std
::
string
GeneralErrorMessage
()
const
;
/**
* \brief Returns a detailed error message corresponding to the kDetailed
* error level.
*/
std
::
string
DetailedErrorMessage
()
const
;
CompilationStatus
Status
()
const
{
return
status_
;
}
private:
CompilationStatus
status_
;
std
::
string
err_msg_
;
std
::
string
detail_info_
;
const
char
*
file_
;
int
line_
;
};
}
// namespace framework
}
// namespace hlir
}
// namespace cinn
paddle/cinn/hlir/framework/convert_to_dialect.cc
deleted
100644 → 0
View file @
63eb0da5
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/hlir/framework/convert_to_dialect.h"
#include <string>
#include <unordered_map>
#include "paddle/cinn/hlir/dialect/jit_kernel_op.h"
#include "paddle/cinn/hlir/dialect/runtime_dialect.h"
#include "paddle/cinn/hlir/framework/program.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/program.h"
namespace
cinn
{
namespace
hlir
{
namespace
framework
{
std
::
unique_ptr
<::
ir
::
Program
>
ConvertToRuntimeDialect
(
const
hlir
::
framework
::
Program
&
program
)
{
::
ir
::
IrContext
*
ctx
=
::
ir
::
IrContext
::
Instance
();
ctx
->
GetOrRegisterDialect
<
cinn
::
dialect
::
RuntimeDialect
>
();
auto
ir_program
=
std
::
make_unique
<::
ir
::
Program
>
(
ctx
);
std
::
string
jit_op_name
=
dialect
::
JitKernelOp
::
name
();
::
ir
::
OpInfo
op_info
=
ctx
->
GetRegisteredOpInfo
(
jit_op_name
);
auto
&
instrs
=
program
.
GetRunInstructions
();
for
(
auto
&
instr
:
instrs
)
{
std
::
unordered_map
<
std
::
string
,
::
ir
::
Attribute
>
op_attrs
{
{
dialect
::
JitKernelOp
::
kAttrName
,
::
ir
::
PointerAttribute
::
get
(
ctx
,
instr
.
get
())},
};
::
ir
::
Operation
*
cinn_op
=
::
ir
::
Operation
::
Create
({},
op_attrs
,
{},
op_info
);
ir_program
->
block
()
->
push_back
(
cinn_op
);
}
return
std
::
move
(
ir_program
);
}
}
// namespace framework
}
// namespace hlir
}
// namespace cinn
paddle/cinn/hlir/framework/convert_to_dialect.h
deleted
100644 → 0
View file @
63eb0da5
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
namespace
ir
{
class
Program
;
}
// namespace ir
namespace
cinn
{
namespace
hlir
{
namespace
framework
{
class
Program
;
std
::
unique_ptr
<::
ir
::
Program
>
ConvertToRuntimeDialect
(
const
hlir
::
framework
::
Program
&
program
);
}
// namespace framework
}
// namespace hlir
}
// namespace cinn
paddle/cinn/hlir/framework/graph.cc
View file @
01a10755
...
@@ -18,10 +18,14 @@
...
@@ -18,10 +18,14 @@
#include <sstream>
#include <sstream>
#include "paddle/cinn/hlir/framework/visualize_helper.h"
#include "paddle/cinn/hlir/framework/visualize_helper.h"
#ifdef CINN_WITH_CUDA
#include "paddle/cinn/runtime/cuda/cuda_util.h"
#endif
#include "paddle/cinn/adt/m_expr.h"
#include "paddle/cinn/runtime/flags.h"
#include "paddle/cinn/runtime/flags.h"
#include "paddle/cinn/utils/string.h"
#include "paddle/cinn/utils/string.h"
DECLARE_string
(
cinn_fusion_groups_graphviz_dir
);
PD_
DECLARE_string
(
cinn_fusion_groups_graphviz_dir
);
namespace
cinn
{
namespace
cinn
{
namespace
hlir
{
namespace
hlir
{
...
@@ -309,15 +313,20 @@ void Graph::VisualizeGroupedGraph(
...
@@ -309,15 +313,20 @@ void Graph::VisualizeGroupedGraph(
}
}
// Dump debug info for each group
// Dump debug info for each group
LOG
(
INFO
)
<<
"Dump graph debug info to: "
V
LOG
(
4
)
<<
"Dump graph debug info to: "
<<
FLAGS_cinn_fusion_groups_graphviz_dir
;
<<
FLAGS_cinn_fusion_groups_graphviz_dir
;
const
auto
&
groups
=
RemoveAccCheckGroups
(
origin_groups
);
const
auto
&
groups
=
RemoveAccCheckGroups
(
origin_groups
);
const
auto
&
group_dots
=
VisualizeGroups
(
groups
,
fetch_var_ids
);
const
auto
&
group_dots
=
VisualizeGroups
(
groups
,
fetch_var_ids
);
for
(
int
idx
=
0
;
idx
<
groups
.
size
();
++
idx
)
{
for
(
int
idx
=
0
;
idx
<
groups
.
size
();
++
idx
)
{
// Create fusion_group_x folder
// Create fusion_group_x folder
int
device_id
=
0
;
#ifdef CINN_WITH_CUDA
cudaGetDevice
(
&
device_id
);
#endif
auto
group_path
=
auto
group_path
=
utils
::
StringFormat
(
"%s/fusion_group_%d"
,
utils
::
StringFormat
(
"%s/
device_%d/
fusion_group_%d"
,
FLAGS_cinn_fusion_groups_graphviz_dir
.
c_str
(),
FLAGS_cinn_fusion_groups_graphviz_dir
.
c_str
(),
device_id
,
idx
);
idx
);
if
(
!
MakeDirectory
(
group_path
,
if
(
!
MakeDirectory
(
group_path
,
S_IRWXU
|
S_IRGRP
|
S_IXGRP
|
S_IROTH
|
S_IXOTH
))
{
S_IRWXU
|
S_IRGRP
|
S_IXGRP
|
S_IROTH
|
S_IXOTH
))
{
...
@@ -468,7 +477,7 @@ std::vector<std::string> Graph::VisualizeGroups(
...
@@ -468,7 +477,7 @@ std::vector<std::string> Graph::VisualizeGroups(
return
dot_vec
;
return
dot_vec
;
}
}
std
::
unordered_set
<
NodeData
*>
Graph
::
Group
::
GetInputNodeDatas
()
{
std
::
unordered_set
<
NodeData
*>
Graph
::
Group
::
GetInputNodeDatas
()
const
{
std
::
unordered_set
<
NodeData
*>
group_inputs
;
std
::
unordered_set
<
NodeData
*>
group_inputs
;
// count all node's input data
// count all node's input data
...
@@ -498,7 +507,7 @@ std::unordered_set<NodeData*> Graph::Group::GetInputNodeDatas() {
...
@@ -498,7 +507,7 @@ std::unordered_set<NodeData*> Graph::Group::GetInputNodeDatas() {
return
group_inputs
;
return
group_inputs
;
}
}
std
::
unordered_set
<
NodeData
*>
Graph
::
Group
::
GetOutputNodeDatas
()
{
std
::
unordered_set
<
NodeData
*>
Graph
::
Group
::
GetOutputNodeDatas
()
const
{
std
::
unordered_set
<
NodeData
*>
group_outputs
;
std
::
unordered_set
<
NodeData
*>
group_outputs
;
for
(
auto
node
:
this
->
output_nodes
)
{
for
(
auto
node
:
this
->
output_nodes
)
{
...
...
paddle/cinn/hlir/framework/graph.h
View file @
01a10755
...
@@ -26,6 +26,7 @@
...
@@ -26,6 +26,7 @@
#include "paddle/cinn/hlir/framework/node.h"
#include "paddle/cinn/hlir/framework/node.h"
namespace
cinn
{
namespace
cinn
{
namespace
hlir
{
namespace
hlir
{
namespace
framework
{
namespace
framework
{
...
@@ -59,6 +60,8 @@ class Graph : public cinn::common::Graph {
...
@@ -59,6 +60,8 @@ class Graph : public cinn::common::Graph {
std
::
vector
<
std
::
vector
<
Node
*>>
groups
;
std
::
vector
<
std
::
vector
<
Node
*>>
groups
;
struct
Group
{
struct
Group
{
Group
()
=
default
;
Group
()
=
default
;
Group
(
const
Group
&
)
=
delete
;
Group
(
Group
&&
)
=
delete
;
explicit
Group
(
const
Graph
*
graph
)
:
graph_
(
graph
)
{}
explicit
Group
(
const
Graph
*
graph
)
:
graph_
(
graph
)
{}
...
@@ -109,7 +112,7 @@ class Graph : public cinn::common::Graph {
...
@@ -109,7 +112,7 @@ class Graph : public cinn::common::Graph {
}
}
};
};
std
::
vector
<
Node
*>
CollectNodes
()
{
std
::
vector
<
Node
*>
CollectNodes
()
const
{
if
(
fused_sub_groups
.
size
())
{
if
(
fused_sub_groups
.
size
())
{
std
::
vector
<
Node
*>
tmp_nodes
;
std
::
vector
<
Node
*>
tmp_nodes
;
for
(
auto
&
group
:
fused_sub_groups
)
{
for
(
auto
&
group
:
fused_sub_groups
)
{
...
@@ -144,8 +147,8 @@ class Graph : public cinn::common::Graph {
...
@@ -144,8 +147,8 @@ class Graph : public cinn::common::Graph {
return
node_set
;
return
node_set
;
}
}
std
::
unordered_set
<
NodeData
*>
GetInputNodeDatas
();
std
::
unordered_set
<
NodeData
*>
GetInputNodeDatas
()
const
;
std
::
unordered_set
<
NodeData
*>
GetOutputNodeDatas
();
std
::
unordered_set
<
NodeData
*>
GetOutputNodeDatas
()
const
;
std
::
string
GetFuncName
()
{
return
"fn_"
+
group_id
+
unique_id
;
}
std
::
string
GetFuncName
()
{
return
"fn_"
+
group_id
+
unique_id
;
}
...
...
paddle/cinn/hlir/framework/graph_compiler.cc
View file @
01a10755
...
@@ -29,8 +29,11 @@
...
@@ -29,8 +29,11 @@
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/optim/transform_gpu_forloop.h"
#include "paddle/cinn/optim/transform_gpu_forloop.h"
#include "paddle/cinn/poly/stage.h"
#include "paddle/cinn/poly/stage.h"
#include "paddle/cinn/utils/enum_string.h"
#include "paddle/cinn/utils/profiler.h"
#include "paddle/cinn/utils/profiler.h"
#include "paddle/cinn/ast_gen_ius/tensor_group.h"
namespace
cinn
{
namespace
cinn
{
namespace
hlir
{
namespace
hlir
{
namespace
framework
{
namespace
framework
{
...
@@ -40,90 +43,124 @@ using cinn::common::float16;
...
@@ -40,90 +43,124 @@ using cinn::common::float16;
std
::
unique_ptr
<
Program
>
GraphCompiler
::
Build
(
const
std
::
string
&
code
)
{
std
::
unique_ptr
<
Program
>
GraphCompiler
::
Build
(
const
std
::
string
&
code
)
{
utils
::
RecordEvent
(
"GraphCompiler::Build"
,
utils
::
EventType
::
kGraph
);
utils
::
RecordEvent
(
"GraphCompiler::Build"
,
utils
::
EventType
::
kGraph
);
GraphCompiler
::
CompileOptions
options
;
compilation_context_
.
ApplySourceCode
(
code
);
options
.
attached_code
=
code
;
compilation_context_
.
with_instantiate_variables
=
true
;
options
.
with_instantiate_variables
=
true
;
auto
&&
result
=
Build
(
options
);
return
std
::
move
(
result
.
runtime_program
);
}
void
GraphCompiler
::
CompileOptions
::
Apply
(
auto
&&
result
=
Build
(
&
compilation_context_
);
const
auto_schedule
::
TuningResult
&
tuning_result
)
{
return
result
.
RuntimeProgram
();
// assign options with TuningResult directly
groups
.
assign
(
tuning_result
.
subgraphs
.
begin
(),
tuning_result
.
subgraphs
.
end
());
lowered_funcs
.
assign
(
tuning_result
.
function_groups
.
begin
(),
tuning_result
.
function_groups
.
end
());
}
}
GraphCompiler
::
CompilationResult
GraphCompiler
::
Build
(
CompilationResult
GraphCompiler
::
Build
(
CompilationContext
*
context
)
{
const
GraphCompiler
::
CompileOptions
&
options
,
std
::
unordered_set
<
std
::
string
>&&
fetch_var_ids
,
void
*
stream
)
{
Context
::
Global
().
ResetNameId
();
Context
::
Global
().
ResetNameId
();
// write group's information into FLAGS_cinn_fusion_groups_graphviz_dir
// write group's information into FLAGS_cinn_fusion_groups_graphviz_dir
graph_
->
VisualizeGroupedGraph
(
fetch_var_ids
.
empty
()
?
fetch_var_ids_
context
->
graph
->
VisualizeGroupedGraph
(
context
->
fetch_var_ids
);
:
fetch_var_ids
);
if
(
options
.
with_instantiate_variables
)
{
if
(
context
->
with_instantiate_variables
)
{
InstantiateVariables
();
InstantiateVariables
(
context
);
}
}
VLOG
(
2
)
<<
"Compile With Parallel Compiler!"
;
VLOG
(
2
)
<<
"Compile With Parallel Compiler!"
;
utils
::
RecordEvent
(
"GraphCompiler CompileResult"
,
utils
::
RecordEvent
(
"GraphCompiler CompileResult"
,
utils
::
EventType
::
kOrdinary
);
utils
::
EventType
::
kOrdinary
);
ParallelCompiler
::
CompileOptions
option
;
option
.
lowered_funcs
=
options
.
lowered_funcs
;
parallel_compiler_
=
parallel_compiler_
=
std
::
make_shared
<
ParallelCompiler
>
(
context
);
std
::
make_shared
<
ParallelCompiler
>
(
scope_
,
graph_
,
option
,
target_
);
CompilationResult
result
=
(
*
parallel_compiler_
.
get
())();
auto
result
=
(
*
parallel_compiler_
.
get
())();
// Dump compilation result
if
(
context
->
stage
!=
CompilationStage
::
DEFAULT
||
!
result
.
IsSuccess
())
{
backends
::
CompilationInfoDumper
dumper
(
result
);
return
result
;
}
if
(
options
.
remove_unused_variables
)
{
if
(
context
->
remove_unused_variables
)
{
RemoveInvalidVariables
(
result
.
i
nstructions
);
RemoveInvalidVariables
(
context
,
result
.
RuntimeI
nstructions
()
);
}
}
if
(
options
.
with_buffer_handle_instruction_inserted
)
{
if
(
context
->
with_buffer_handle_instruction_inserted
)
{
VLOG
(
3
)
<<
"option.with_buffer_handle_instruction_inserted enable"
;
VLOG
(
3
)
<<
"option.with_buffer_handle_instruction_inserted enable"
;
InsertBufferHandlers
(
&
result
.
instructions
);
InsertBufferHandlers
(
context
,
&
result
.
instructions
_
);
}
}
VLOG
(
2
)
<<
"Compile With Parallel Compiler Done!"
;
VLOG
(
2
)
<<
"Compile With Parallel Compiler Done!"
;
GraphCompiler
::
CompilationResult
compilation_result
;
result
.
SetRuntimeProgram
(
std
::
make_unique
<
Program
>
(
compilation_result
.
runtime_program
.
reset
(
context
->
scope
,
std
::
move
(
result
.
instructions_
)));
new
Program
(
scope_
,
std
::
move
(
result
.
instructions
)));
return
result
;
return
compilation_result
;
}
CompilationResult
GraphCompiler
::
Lowering
()
{
return
Lowering
(
&
compilation_context_
);
}
}
void
GraphCompiler
::
InstantiateVariables
()
{
CompilationResult
GraphCompiler
::
Lowering
(
CompilationContext
*
context
)
{
// Global setting
Context
::
Global
().
ResetNameId
();
// Setting compile options
VLOG
(
2
)
<<
"Compile With Parallel Compiler! But just lowering!"
;
context
->
stage
=
CompilationStage
::
LOWERING
;
// Compile with parallel compiler
parallel_compiler_
=
std
::
make_shared
<
ParallelCompiler
>
(
context
);
CompilationResult
result
=
(
*
parallel_compiler_
.
get
())();
return
result
;
}
CompilationResult
GraphCompiler
::
CodegenAndJit
()
{
return
CodegenAndJit
(
&
compilation_context_
);
}
CompilationResult
GraphCompiler
::
CodegenAndJit
(
CompilationContext
*
context
)
{
// Global setting
Context
::
Global
().
ResetNameId
();
// Setting compile options
VLOG
(
2
)
<<
"Compile With Parallel Compiler! But just codegen and jit!"
;
context
->
stage
=
CompilationStage
::
CODEGEN_AND_JIT
;
// Compile with parallel compiler
parallel_compiler_
=
std
::
make_shared
<
ParallelCompiler
>
(
context
);
CompilationResult
result
=
(
*
parallel_compiler_
.
get
())();
return
result
;
}
CompilationResult
GraphCompiler
::
BuildInstruction
()
{
return
BuildInstruction
(
&
compilation_context_
);
}
CompilationResult
GraphCompiler
::
BuildInstruction
(
CompilationContext
*
context
)
{
// Global setting
Context
::
Global
().
ResetNameId
();
// Setting compile options
VLOG
(
2
)
<<
"Compile With Parallel Compiler! But just build instruction!"
;
context
->
stage
=
CompilationStage
::
BUILD_INSTRUCTION
;
// Compile with parallel compiler
parallel_compiler_
=
std
::
make_shared
<
ParallelCompiler
>
(
context
);
CompilationResult
result
=
(
*
parallel_compiler_
.
get
())();
return
result
;
}
void
GraphCompiler
::
InstantiateVariables
(
CompilationContext
*
context
)
{
VLOG
(
3
)
<<
"Instantiate all variables on compile-time"
;
VLOG
(
3
)
<<
"Instantiate all variables on compile-time"
;
utils
::
RecordEvent
(
"GraphCompiler MutableData"
,
utils
::
EventType
::
kOrdinary
);
utils
::
RecordEvent
(
"GraphCompiler MutableData"
,
utils
::
EventType
::
kOrdinary
);
// All variables reside in scope_, so traverse it to instantiate each one
// All variables reside in scope_, so traverse it to instantiate each one
for
(
auto
&
name
:
scope_
->
var_names
())
{
for
(
auto
&
name
:
context
->
scope
->
var_names
())
{
auto
*
var
=
scope_
->
Var
<
Tensor
>
(
std
::
string
({
name
.
data
(),
name
.
size
()}));
auto
*
var
=
context
->
scope
->
Var
<
Tensor
>
(
std
::
string
({
name
.
data
(),
name
.
size
()}));
auto
&
tensor
=
absl
::
get
<
Tensor
>
(
*
var
);
auto
&
tensor
=
absl
::
get
<
Tensor
>
(
*
var
);
if
(
reuse_vars_map
_
.
count
(
name
))
{
if
(
context
->
reuse_vars_map
.
count
(
name
))
{
auto
src_var_name
=
reuse_vars_map
_
.
at
(
name
);
auto
src_var_name
=
context
->
reuse_vars_map
.
at
(
name
);
auto
*
src_var
=
scope
_
->
Var
<
Tensor
>
(
src_var_name
);
auto
*
src_var
=
context
->
scope
->
Var
<
Tensor
>
(
src_var_name
);
auto
&
src_tensor
=
absl
::
get
<
Tensor
>
(
*
src_var
);
auto
&
src_tensor
=
absl
::
get
<
Tensor
>
(
*
src_var
);
tensor
->
set_buffer
(
src_tensor
->
get_buffer
());
tensor
->
set_buffer
(
src_tensor
->
get_buffer
());
}
else
{
}
else
{
tensor
->
mutable_data
(
target
_
,
tensor
->
type
());
tensor
->
mutable_data
(
context
->
target
,
tensor
->
type
());
}
}
}
}
}
}
void
GraphCompiler
::
RemoveInvalidVariables
(
void
GraphCompiler
::
RemoveInvalidVariables
(
CompilationContext
*
context
,
const
std
::
vector
<
std
::
unique_ptr
<
Instruction
>>&
instructions
)
{
const
std
::
vector
<
std
::
unique_ptr
<
Instruction
>>&
instructions
)
{
// mark all variables are invalid initially
// mark all variables are invalid initially
utils
::
RecordEvent
(
"GraphCompiler RemoveInvalidVariables"
,
utils
::
RecordEvent
(
"GraphCompiler RemoveInvalidVariables"
,
utils
::
EventType
::
kOrdinary
);
utils
::
EventType
::
kOrdinary
);
std
::
unordered_set
<
std
::
string
>
invalid_variables
;
std
::
unordered_set
<
std
::
string
>
invalid_variables
;
auto
var_names
=
scope
_
->
var_names
();
auto
var_names
=
context
->
scope
->
var_names
();
invalid_variables
.
reserve
(
var_names
.
size
());
invalid_variables
.
reserve
(
var_names
.
size
());
std
::
transform
(
std
::
transform
(
var_names
.
begin
(),
var_names
.
begin
(),
...
@@ -162,8 +199,8 @@ void GraphCompiler::RemoveInvalidVariables(
...
@@ -162,8 +199,8 @@ void GraphCompiler::RemoveInvalidVariables(
<<
" invalid variables to be removed from scope"
;
<<
" invalid variables to be removed from scope"
;
std
::
for_each
(
invalid_variables
.
begin
(),
std
::
for_each
(
invalid_variables
.
begin
(),
invalid_variables
.
end
(),
invalid_variables
.
end
(),
[
this
](
const
std
::
string
&
var_name
)
{
[
context
](
const
std
::
string
&
var_name
)
{
scope
_
->
EraseVar
(
var_name
);
context
->
scope
->
EraseVar
(
var_name
);
VLOG
(
3
)
<<
"Variable("
<<
var_name
<<
") is erased"
;
VLOG
(
3
)
<<
"Variable("
<<
var_name
<<
") is erased"
;
});
});
}
}
...
@@ -222,6 +259,7 @@ void GraphCompiler::AnalyzeVariableLifeTime(
...
@@ -222,6 +259,7 @@ void GraphCompiler::AnalyzeVariableLifeTime(
}
}
void
GraphCompiler
::
InsertBufferHandlers
(
void
GraphCompiler
::
InsertBufferHandlers
(
CompilationContext
*
context
,
std
::
vector
<
std
::
unique_ptr
<
Instruction
>>*
instructions
)
{
std
::
vector
<
std
::
unique_ptr
<
Instruction
>>*
instructions
)
{
utils
::
RecordEvent
(
"GraphCompiler InsertBufferHandlers"
,
utils
::
RecordEvent
(
"GraphCompiler InsertBufferHandlers"
,
utils
::
EventType
::
kOrdinary
);
utils
::
EventType
::
kOrdinary
);
...
@@ -240,7 +278,7 @@ void GraphCompiler::InsertBufferHandlers(
...
@@ -240,7 +278,7 @@ void GraphCompiler::InsertBufferHandlers(
auto
function_name
=
"malloc_buffer_instruction_"
+
std
::
to_string
(
step
);
auto
function_name
=
"malloc_buffer_instruction_"
+
std
::
to_string
(
step
);
auto
malloc_instr
=
auto
malloc_instr
=
std
::
make_unique
<
Instruction
>
(
common
::
DefaultHostTarget
(),
std
::
make_unique
<
Instruction
>
(
common
::
DefaultHostTarget
(),
scope
_
.
get
(),
context
->
scope
.
get
(),
malloc_var_names
,
malloc_var_names
,
std
::
vector
<
std
::
string
>
({}),
std
::
vector
<
std
::
string
>
({}),
function_name
);
function_name
);
...
@@ -263,7 +301,7 @@ void GraphCompiler::InsertBufferHandlers(
...
@@ -263,7 +301,7 @@ void GraphCompiler::InsertBufferHandlers(
auto
function_name
=
"free_buffer_instruction_"
+
std
::
to_string
(
step
);
auto
function_name
=
"free_buffer_instruction_"
+
std
::
to_string
(
step
);
auto
free_instr
=
auto
free_instr
=
std
::
make_unique
<
Instruction
>
(
common
::
DefaultHostTarget
(),
std
::
make_unique
<
Instruction
>
(
common
::
DefaultHostTarget
(),
scope
_
.
get
(),
context
->
scope
.
get
(),
std
::
vector
<
std
::
string
>
({}),
std
::
vector
<
std
::
string
>
({}),
free_var_names
,
free_var_names
,
function_name
);
function_name
);
...
@@ -336,14 +374,17 @@ std::vector<ir::LoweredFunc> GetFuncFromImpl(
...
@@ -336,14 +374,17 @@ std::vector<ir::LoweredFunc> GetFuncFromImpl(
poly
::
StageMap
stages
=
C
.
back
();
poly
::
StageMap
stages
=
C
.
back
();
std
::
string
func_name_prefix
=
"fn_"
;
std
::
string
func_name_prefix
=
"fn_"
;
auto
funcs
=
lang
::
LowerVec
(
func_name_prefix
+
node_id
,
stages
,
ast_gen_ius
::
TensorGroup
tensor_group
=
all_arg_tensors
,
ast_gen_ius
::
ConvertStageMapToTensorGroup
(
stages
);
{},
auto
funcs
=
lang
::
LowerToAstVec
(
{},
func_name_prefix
+
node_id
,
all_arg_tensors
,
&
tensor_group
,
target
);
nullptr
,
target
,
VLOG
(
4
)
<<
"Lower op: "
<<
node_id
<<
", get "
<<
funcs
.
size
()
true
);
<<
" LoweredFunc:
\n
"
;
for
(
auto
fun
:
funcs
)
{
VLOG
(
4
)
<<
fun
;
}
std
::
vector
<
common
::
CINNValue
>
schedule_inputs
;
std
::
vector
<
common
::
CINNValue
>
schedule_inputs
;
for
(
int
i
=
0
;
i
<
C
.
size
()
-
1
;
++
i
)
{
for
(
int
i
=
0
;
i
<
C
.
size
()
-
1
;
++
i
)
{
...
@@ -390,7 +431,8 @@ std::vector<ir::LoweredFunc> GetFuncFromImpl(
...
@@ -390,7 +431,8 @@ std::vector<ir::LoweredFunc> GetFuncFromImpl(
optim
::
OptimizeExprGPU
(
&
(
funcs_after_schedule
[
i
]
->
body
));
optim
::
OptimizeExprGPU
(
&
(
funcs_after_schedule
[
i
]
->
body
));
#endif
#endif
auto
temp_buffers
=
lang
::
GetTempBuffers
(
auto
temp_buffers
=
lang
::
GetTempBuffers
(
all_arg_tensors
,
stages
,
funcs_after_schedule
[
i
]
->
body
);
all_arg_tensors
,
tensor_group
,
funcs_after_schedule
[
i
]
->
body
);
funcs_after_schedule
[
i
]
->
temp_bufs
=
temp_buffers
;
funcs_after_schedule
[
i
]
->
temp_bufs
=
temp_buffers
;
funcs_after_schedule
[
i
]
=
funcs_after_schedule
[
i
]
=
ir
::
_LoweredFunc_
::
Make
(
funcs_after_schedule
[
i
]
->
name
,
ir
::
_LoweredFunc_
::
Make
(
funcs_after_schedule
[
i
]
->
name
,
...
...
paddle/cinn/hlir/framework/graph_compiler.h
View file @
01a10755
...
@@ -28,6 +28,7 @@
...
@@ -28,6 +28,7 @@
#include "paddle/cinn/backends/cuda_util.h"
#include "paddle/cinn/backends/cuda_util.h"
#include "paddle/cinn/common/macros.h"
#include "paddle/cinn/common/macros.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/instruction.h"
#include "paddle/cinn/hlir/framework/instruction.h"
#include "paddle/cinn/hlir/framework/op_strategy.h"
#include "paddle/cinn/hlir/framework/op_strategy.h"
#include "paddle/cinn/hlir/framework/parallel_compiler.h"
#include "paddle/cinn/hlir/framework/parallel_compiler.h"
...
@@ -46,48 +47,41 @@ namespace framework {
...
@@ -46,48 +47,41 @@ namespace framework {
*/
*/
class
GraphCompiler
final
{
class
GraphCompiler
final
{
public:
public:
GraphCompiler
(
Target
target
,
GraphCompiler
(
CompilationContext
context
)
:
compilation_context_
(
context
)
{}
const
std
::
shared_ptr
<
Scope
>&
scope
,
const
std
::
shared_ptr
<
Graph
>&
graph
)
:
target_
(
std
::
move
(
target
)),
scope_
(
scope
),
graph_
(
graph
)
{}
struct
CompilationResult
{
std
::
unique_ptr
<
Program
>
runtime_program
;
};
struct
CompileOptions
{
std
::
string
attached_code
=
""
;
bool
with_instantiate_variables
=
false
;
bool
with_buffer_handle_instruction_inserted
=
false
;
bool
remove_unused_variables
=
true
;
// nodes group, it may come from the result of op fusion or graph tuning.
// nodes in a group will be built into an Instruction
std
::
vector
<
std
::
shared_ptr
<
Graph
::
Group
>>
groups
;
// corresponding LoweredFuncs of above grouped nodes,
// if it is empty then graph_compiler will generate for them
std
::
vector
<
std
::
vector
<
ir
::
LoweredFunc
>>
lowered_funcs
;
// apply results of auto-tune to compile
void
Apply
(
const
auto_schedule
::
TuningResult
&
tuning_result
);
};
// Compile with a packing option and result, to be extended easily.
// Compile with a packing option and result, to be extended easily.
CompilationResult
Build
(
const
CompileOptions
&
options
,
CompilationResult
Build
(
CompilationContext
*
context
);
std
::
unordered_set
<
std
::
string
>&&
fetch_var_ids
=
{},
void
*
stream
=
nullptr
);
std
::
unique_ptr
<
Program
>
Build
(
const
std
::
string
&
code
=
""
);
std
::
unique_ptr
<
Program
>
Build
(
const
std
::
string
&
code
=
""
);
const
std
::
shared_ptr
<
Scope
>&
GetScope
()
const
{
return
scope_
;
}
CompilationResult
Lowering
();
CompilationResult
Lowering
(
CompilationContext
*
context
);
CompilationResult
CodegenAndJit
();
CompilationResult
CodegenAndJit
(
CompilationContext
*
context
);
CompilationResult
BuildInstruction
();
CompilationResult
BuildInstruction
(
CompilationContext
*
context
);
const
std
::
shared_ptr
<
Scope
>&
GetScope
()
const
{
return
compilation_context_
.
scope
;
}
CompilationContext
&
GetCompilationContext
()
{
return
compilation_context_
;
}
void
SetCompilationContext
(
const
CompilationContext
&
context
)
{
compilation_context_
=
context
;
}
private:
private:
// instantiate all variables on compile time
// instantiate all variables on compile time
void
InstantiateVariables
();
void
InstantiateVariables
(
CompilationContext
*
context
);
// some variables are eliminated by optimized passes(such as OpFusion),
// some variables are eliminated by optimized passes(such as OpFusion),
// we can filter out them according to arguments of the built instructions,
// we can filter out them according to arguments of the built instructions,
// and erase them from the scope to avoid unnecessary buffer allocation
// and erase them from the scope to avoid unnecessary buffer allocation
void
RemoveInvalidVariables
(
void
RemoveInvalidVariables
(
CompilationContext
*
context
,
const
std
::
vector
<
std
::
unique_ptr
<
Instruction
>>&
instructions
);
const
std
::
vector
<
std
::
unique_ptr
<
Instruction
>>&
instructions
);
// find the first and last instruction where a variable used, and mark the
// find the first and last instruction where a variable used, and mark the
...
@@ -102,21 +96,14 @@ class GraphCompiler final {
...
@@ -102,21 +96,14 @@ class GraphCompiler final {
// firstly used in the next instruction, and insert a buffer free instruction
// firstly used in the next instruction, and insert a buffer free instruction
// applying on variables after no instruction will use them anymore
// applying on variables after no instruction will use them anymore
void
InsertBufferHandlers
(
void
InsertBufferHandlers
(
CompilationContext
*
context
,
std
::
vector
<
std
::
unique_ptr
<
Instruction
>>*
instructions
);
std
::
vector
<
std
::
unique_ptr
<
Instruction
>>*
instructions
);
private:
private:
// parallel compiler
// parallel compiler
std
::
shared_ptr
<
ParallelCompiler
>
parallel_compiler_
;
std
::
shared_ptr
<
ParallelCompiler
>
parallel_compiler_
;
Target
target_
;
CompilationContext
compilation_context_
;
std
::
shared_ptr
<
Graph
>
graph_
;
std
::
shared_ptr
<
Scope
>
scope_
;
// fetch var ids in cinn and the corresponding var nodes will not be fused so
// as to get the result
std
::
unordered_set
<
std
::
string
>
fetch_var_ids_
;
// map dst reuse var to the src var sharing buffer
absl
::
flat_hash_map
<
std
::
string
,
std
::
string
>
reuse_vars_map_
;
CINN_DISALLOW_COPY_AND_ASSIGN
(
GraphCompiler
);
CINN_DISALLOW_COPY_AND_ASSIGN
(
GraphCompiler
);
};
};
...
...
paddle/cinn/hlir/framework/graph_compiler_test.cc
View file @
01a10755
...
@@ -19,6 +19,7 @@
...
@@ -19,6 +19,7 @@
#include "paddle/cinn/frontend/net_builder.h"
#include "paddle/cinn/frontend/net_builder.h"
#include "paddle/cinn/frontend/optimize.h"
#include "paddle/cinn/frontend/optimize.h"
#include "paddle/cinn/frontend/program_pass.h"
#include "paddle/cinn/frontend/program_pass.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/framework/scope.h"
#include "paddle/cinn/hlir/framework/scope.h"
#include "paddle/cinn/hlir/op/use_ops.h"
#include "paddle/cinn/hlir/op/use_ops.h"
...
@@ -48,7 +49,8 @@ TEST(GraphCompilerTest, TestRemoveInvaildVariables) {
...
@@ -48,7 +49,8 @@ TEST(GraphCompilerTest, TestRemoveInvaildVariables) {
ASSERT_EQ
(
scope
->
var_names
().
size
(),
6
);
ASSERT_EQ
(
scope
->
var_names
().
size
(),
6
);
EXPECT_NE
(
scope
->
FindVar
(
c
->
id
),
nullptr
);
EXPECT_NE
(
scope
->
FindVar
(
c
->
id
),
nullptr
);
GraphCompiler
gc
(
target
,
scope
,
graph
);
CompilationContext
context
(
graph
,
scope
,
target
);
GraphCompiler
gc
(
context
);
auto
runtime_program
=
gc
.
Build
();
auto
runtime_program
=
gc
.
Build
();
ASSERT_EQ
(
scope
->
var_names
().
size
(),
3
);
ASSERT_EQ
(
scope
->
var_names
().
size
(),
3
);
EXPECT_EQ
(
scope
->
FindVar
(
c
->
id
),
nullptr
);
EXPECT_EQ
(
scope
->
FindVar
(
c
->
id
),
nullptr
);
...
@@ -69,10 +71,11 @@ TEST(GraphCompilerTest, TestInsertBufferHandlers) {
...
@@ -69,10 +71,11 @@ TEST(GraphCompilerTest, TestInsertBufferHandlers) {
auto
graph
=
Optimize
(
&
program
,
{},
target
);
auto
graph
=
Optimize
(
&
program
,
{},
target
);
auto
scope
=
BuildScope
(
target
,
graph
);
auto
scope
=
BuildScope
(
target
,
graph
);
Graph
Compil
er
gc
_disable
(
target
,
scope
,
graph
);
Compil
ationContext
context
_disable
(
graph
,
scope
,
target
);
GraphCompiler
::
CompileOptions
options
;
GraphCompiler
gc_disable
(
context_disable
)
;
// disable with_buffer_handle_instruction_inserted: only 1 instruction
// disable with_buffer_handle_instruction_inserted: only 1 instruction
auto
runtime_program_disable
=
gc_disable
.
Build
(
options
).
runtime_program
;
auto
runtime_program_disable
=
gc_disable
.
Build
(
&
context_disable
).
RuntimeProgram
();
ASSERT_EQ
(
runtime_program_disable
->
size
(),
1
);
ASSERT_EQ
(
runtime_program_disable
->
size
(),
1
);
const
auto
&
computation_instr_disable
=
const
auto
&
computation_instr_disable
=
runtime_program_disable
->
GetRunInstructions
().
front
();
runtime_program_disable
->
GetRunInstructions
().
front
();
...
@@ -80,9 +83,11 @@ TEST(GraphCompilerTest, TestInsertBufferHandlers) {
...
@@ -80,9 +83,11 @@ TEST(GraphCompilerTest, TestInsertBufferHandlers) {
// enable with_buffer_handle_instruction_inserted: 3 instructions, 1st ->
// enable with_buffer_handle_instruction_inserted: 3 instructions, 1st ->
// malloc instruction(a, b, d), 2nd -> the real computation
// malloc instruction(a, b, d), 2nd -> the real computation
// instruction(add + relu) and 3rd -> free instruction
// instruction(add + relu) and 3rd -> free instruction
GraphCompiler
gc_enable
(
target
,
scope
,
graph
);
CompilationContext
context_enable
(
graph
,
scope
,
target
);
options
.
with_buffer_handle_instruction_inserted
=
true
;
context_enable
.
with_buffer_handle_instruction_inserted
=
true
;
auto
runtime_program_enable
=
gc_enable
.
Build
(
options
).
runtime_program
;
GraphCompiler
gc_enable
(
context_enable
);
auto
runtime_program_enable
=
gc_enable
.
Build
(
&
context_enable
).
RuntimeProgram
();
const
auto
&
instructions
=
runtime_program_enable
->
GetRunInstructions
();
const
auto
&
instructions
=
runtime_program_enable
->
GetRunInstructions
();
ASSERT_EQ
(
instructions
.
size
(),
3
);
ASSERT_EQ
(
instructions
.
size
(),
3
);
...
@@ -193,7 +198,8 @@ void RunCublas(
...
@@ -193,7 +198,8 @@ void RunCublas(
hlir
::
framework
::
ApplyPass
(
graph
.
get
(),
"OpFusionPass"
);
hlir
::
framework
::
ApplyPass
(
graph
.
get
(),
"OpFusionPass"
);
auto
scope
=
BuildScope
(
target
,
graph
);
auto
scope
=
BuildScope
(
target
,
graph
);
GraphCompiler
gc
(
target
,
scope
,
graph
);
CompilationContext
context
(
graph
,
scope
,
target
);
GraphCompiler
gc
(
context
);
auto
exe_program
=
gc
.
Build
();
auto
exe_program
=
gc
.
Build
();
auto
data_a
=
scope
->
GetTensor
(
"A"
);
auto
data_a
=
scope
->
GetTensor
(
"A"
);
...
@@ -231,6 +237,66 @@ TEST(GraphCompilerTest, TestCublas) {
...
@@ -231,6 +237,66 @@ TEST(GraphCompilerTest, TestCublas) {
RunCublas
(
64
,
128
,
128
,
true
,
true
);
RunCublas
(
64
,
128
,
128
,
true
,
true
);
}
}
// Verifies that GraphCompiler can stop after the lowering stage and report
// success for a simple add + relu graph.
TEST(GraphCompilerTest, TestLowering) {
  frontend::NetBuilder builder("test_lowering_on_graph_compiler");
  auto input_a = builder.CreateInput(Float(32), {1, 64, 112, 112}, "A");
  auto input_b = builder.CreateInput(Float(32), {64}, "B");
  auto add_out = builder.Add(input_a, input_b, 1);
  auto relu_out = builder.Relu(add_out);

  auto target = common::DefaultNVGPUTarget();
  auto program = builder.Build();
  auto graph = Optimize(&program, {}, target);
  auto scope = BuildScope(target, graph);

  // Drive only the lowering stage through the new CompilationContext API.
  CompilationContext context(graph, scope, target);
  GraphCompiler gc(context);
  CompilationResult result = gc.Lowering();
  ASSERT_EQ(result.Status(), CompilationStatus::SUCCESS);
}
// Verifies that GraphCompiler can stop after codegen + JIT and report success
// for a simple add + relu graph.
TEST(GraphCompilerTest, TestCodegenAndJit) {
  frontend::NetBuilder builder("test_codegen_and_jit_on_graph_compiler");
  auto input_a = builder.CreateInput(Float(32), {1, 64, 112, 112}, "A");
  auto input_b = builder.CreateInput(Float(32), {64}, "B");
  auto add_out = builder.Add(input_a, input_b, 1);
  auto relu_out = builder.Relu(add_out);

  auto target = common::DefaultNVGPUTarget();
  auto program = builder.Build();
  auto graph = Optimize(&program, {}, target);
  auto scope = BuildScope(target, graph);

  // Drive compilation up to (and including) codegen and JIT.
  CompilationContext context(graph, scope, target);
  GraphCompiler gc(context);
  CompilationResult result = gc.CodegenAndJit();
  ASSERT_EQ(result.Status(), CompilationStatus::SUCCESS);
}
// Verifies that GraphCompiler can stop after instruction building and report
// success for a simple add + relu graph.
TEST(GraphCompilerTest, TestBuildInstruction) {
  frontend::NetBuilder builder("test_build_instruction_on_graph_compiler");
  auto input_a = builder.CreateInput(Float(32), {1, 64, 112, 112}, "A");
  auto input_b = builder.CreateInput(Float(32), {64}, "B");
  auto add_out = builder.Add(input_a, input_b, 1);
  auto relu_out = builder.Relu(add_out);

  auto target = common::DefaultNVGPUTarget();
  auto program = builder.Build();
  auto graph = Optimize(&program, {}, target);
  auto scope = BuildScope(target, graph);

  // Drive compilation up to (and including) instruction building.
  CompilationContext context(graph, scope, target);
  GraphCompiler gc(context);
  CompilationResult result = gc.BuildInstruction();
  ASSERT_EQ(result.Status(), CompilationStatus::SUCCESS);
}
#endif
#endif
}
// namespace framework
}
// namespace framework
...
...
paddle/cinn/hlir/framework/graph_compiler_util.cc
0 → 100644
View file @
01a10755
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/utils/error.h"
namespace
cinn
{
namespace
hlir
{
namespace
framework
{
// Overwrites the context's node groups and lowered functions with the
// auto-tuning result, so subsequent compilation can reuse the tuned
// functions instead of lowering the graph again.
void CompilationContext::ApplyTuningResult(
    const auto_schedule::TuningResult& tuning_result) {
  // assign options with TuningResult directly
  groups.assign(tuning_result.subgraphs.begin(), tuning_result.subgraphs.end());
  lowered_funcs.assign(tuning_result.function_groups.begin(),
                       tuning_result.function_groups.end());
}
// Attaches externally-provided source code to the context. Per the header,
// a non-empty attached code replaces the device-module code after
// SplitCudaAndHostModule.
void CompilationContext::ApplySourceCode(const std::string& code) {
  attached_source_code = code;
}
// Resets all per-group result slots to hold `group_size` entries. Every
// group starts out marked SUCCESS; later compilation stages overwrite the
// status/message of groups that fail.
void CompilationResult::InitCompilationResult(int group_size) {
  size_ = group_size;
  status_.resize(group_size, CompilationStatus::SUCCESS);
  messages_.resize(group_size);
  for (int idx = 0; idx < group_size; ++idx) {
    messages_[idx] =
        "Group Idx: " + std::to_string(idx) + ", Compile Success.\n";
  }
  // std::nullopt marks "not generated yet" for stage-dependent artifacts.
  lowered_funcs_.resize(group_size, std::nullopt);
  source_codes_.resize(group_size, std::nullopt);
  source_ptxs_.resize(group_size, std::nullopt);
  instructions_.resize(group_size);
}
// Records the compilation status of group `idx`.
// Out-of-range indices (including negative ones) are silently ignored; the
// original relied on signed->unsigned wraparound in `idx < size()` to reject
// negative values, which this spells out explicitly.
void CompilationResult::SetStatus(int idx, const CompilationStatus& status) {
  if (idx >= 0 && static_cast<size_t>(idx) < status_.size()) {
    status_[idx] = status;
  }
}
// Records the human-readable compile message of group `idx`.
// Out-of-range indices (including negative ones) are silently ignored,
// with an explicit bounds check instead of signed->unsigned wraparound.
void CompilationResult::SetMessage(int idx, const std::string& message) {
  if (idx >= 0 && static_cast<size_t>(idx) < messages_.size()) {
    messages_[idx] = message;
  }
}
// Records the lowered functions of group `idx`.
// Out-of-range indices (including negative ones) are silently ignored,
// with an explicit bounds check instead of signed->unsigned wraparound.
void CompilationResult::SetLoweredFuncs(
    int idx, const std::vector<ir::LoweredFunc>& funcs) {
  if (idx >= 0 && static_cast<size_t>(idx) < lowered_funcs_.size()) {
    lowered_funcs_[idx] = funcs;
  }
}
// Records the generated source code of group `idx`.
// Out-of-range indices (including negative ones) are silently ignored,
// with an explicit bounds check instead of signed->unsigned wraparound.
void CompilationResult::SetSourceCode(int idx, const std::string& source_code) {
  if (idx >= 0 && static_cast<size_t>(idx) < source_codes_.size()) {
    source_codes_[idx] = source_code;
  }
}
// Records the compiled PTX of group `idx`.
// Out-of-range indices (including negative ones) are silently ignored,
// with an explicit bounds check instead of signed->unsigned wraparound.
void CompilationResult::SetSourcePtx(int idx, const std::string& source_ptx) {
  if (idx >= 0 && static_cast<size_t>(idx) < source_ptxs_.size()) {
    source_ptxs_[idx] = source_ptx;
  }
}
// Takes ownership of the instruction built for group `idx`.
// Out-of-range indices (including negative ones) are silently ignored,
// with an explicit bounds check instead of signed->unsigned wraparound.
void CompilationResult::SetInstruction(
    int idx, std::unique_ptr<Instruction> instruction) {
  if (idx >= 0 && static_cast<size_t>(idx) < instructions_.size()) {
    instructions_[idx] = std::move(instruction);
  }
}
// Stores the fully-built runtime program; ownership is taken from the caller.
void CompilationResult::SetRuntimeProgram(
    std::unique_ptr<Program> runtime_program) {
  runtime_program_ = std::move(runtime_program);
}
bool
CompilationResult
::
IsSuccess
()
const
{
for
(
const
CompilationStatus
&
s
:
status_
)
{
if
(
s
!=
CompilationStatus
::
SUCCESS
)
{
return
false
;
}
}
return
true
;
}
// Returns the worst status across all groups. Smaller enum values denote
// worse outcomes (SUCCESS is the largest), so the minimum element wins.
CompilationStatus CompilationResult::Status() const {
  CompilationStatus worst = CompilationStatus::SUCCESS;
  for (size_t i = 0; i < status_.size(); ++i) {
    if (status_[i] < worst) {
      worst = status_[i];
    }
  }
  return worst;
}
// Returns the status of group `idx`, or UNKNOWN_FAIL when `idx` is out of
// range (negative values also land here via the unsigned comparison).
CompilationStatus CompilationResult::Status(int idx) const {
  const bool out_of_range = idx >= status_.size();
  return out_of_range ? CompilationStatus::UNKNOWN_FAIL : status_[idx];
}
// Concatenates the per-group compile messages into one report string.
std::string CompilationResult::Message() const {
  std::string all_messages;
  for (const std::string& msg : messages_) {
    all_messages += msg;
  }
  return all_messages;
}
// Returns the compile message of group `idx`.
// Throws (via CINN_THROW) when `idx` is out of range.
// Fix: the error text previously reported lowered_funcs_.size() — a
// copy-paste slip — instead of the container actually being indexed.
std::string CompilationResult::Message(int idx) const {
  if (idx >= messages_.size()) {
    std::stringstream ss;
    ss << "The index(" << idx
       << ") is expected to be less than the size of group("
       << messages_.size() << ").";
    CINN_THROW(ss.str());
  }
  return messages_[idx];
}
// Returns the lowered functions of every group.
// Throws (via CINN_THROW) if any group's lowered functions were never
// generated, appending the accumulated compile messages for diagnosis.
std::vector<std::vector<ir::LoweredFunc>> CompilationResult::LoweredFuncs()
    const {
  std::vector<std::vector<ir::LoweredFunc>> res(lowered_funcs_.size());
  for (int idx = 0; idx < lowered_funcs_.size(); ++idx) {
    if (lowered_funcs_[idx].has_value()) {
      res[idx] = lowered_funcs_[idx].value();
    } else {
      std::stringstream ss;
      ss << "LoweredFuncs of group[" << idx << "] is not generated.\n"
         << "Some errors may have occurred during or before the lower "
            "process.\n"
         << Message();
      CINN_THROW(ss.str());
    }
  }
  return res;
}
// Returns the lowered functions of group `idx`.
// Throws (via CINN_THROW) if `idx` is out of range or the group's lowered
// functions were never generated.
std::vector<ir::LoweredFunc> CompilationResult::LoweredFuncs(int idx) const {
  if (idx >= lowered_funcs_.size()) {
    std::stringstream ss;
    ss << "The index(" << idx
       << ") is expected to be less than the size of group("
       << lowered_funcs_.size() << ").";
    CINN_THROW(ss.str());
  }
  if (!lowered_funcs_[idx].has_value()) {
    std::stringstream ss;
    ss << "LoweredFuncs of group[" << idx << "] is not generated.\n"
       << "Some errors may have occurred during or before the lower process.\n"
       << Message();
    CINN_THROW(ss.str());
  }
  return lowered_funcs_[idx].value();
}
// Returns the generated source code of every group.
// Throws (via CINN_THROW) if any group's source code was never generated,
// appending the accumulated compile messages for diagnosis.
std::vector<std::string> CompilationResult::SourceCodes() const {
  std::vector<std::string> res(source_codes_.size());
  for (int idx = 0; idx < source_codes_.size(); ++idx) {
    if (source_codes_[idx].has_value()) {
      res[idx] = source_codes_[idx].value();
    } else {
      std::stringstream ss;
      ss << "Source Code of group[" << idx << "] is not generated.\n"
         << "Some errors may have occurred during or before the codegen "
            "process.\n"
         << Message();
      CINN_THROW(ss.str());
    }
  }
  return res;
}
std
::
string
CompilationResult
::
SourceCode
(
int
idx
)
const
{
if
(
idx
>=
source_codes_
.
size
())
{
std
::
stringstream
ss
;
ss
<<
"The index("
<<
idx
<<
") is expected to be less than the size of group("
<<
lowered_funcs_
.
size
()
<<
")."
;
CINN_THROW
(
ss
.
str
());
}
if
(
!
source_codes_
[
idx
].
has_value
())
{
std
::
stringstream
ss
;
ss
<<
"Source Code of group["
<<
idx
<<
"] is not generated.
\n
"
<<
"Some errors may have occurred during or before the codegen "
"process.
\n
"
<<
Message
();
CINN_THROW
(
ss
.
str
());
}
return
source_codes_
[
idx
].
value
();
}
// Returns the compiled PTX of every group.
// Throws (via CINN_THROW) if any group's PTX was never generated, appending
// the accumulated compile messages for diagnosis.
std::vector<std::string> CompilationResult::SourcePtxs() const {
  std::vector<std::string> res(source_ptxs_.size());
  for (int idx = 0; idx < source_ptxs_.size(); ++idx) {
    if (source_ptxs_[idx].has_value()) {
      res[idx] = source_ptxs_[idx].value();
    } else {
      std::stringstream ss;
      ss << "Source PTX of group[" << idx << "] is not generated.\n"
         << "Some errors may have occurred during or before the nvrtc compile "
            "process.\n"
         << Message();
      CINN_THROW(ss.str());
    }
  }
  return res;
}
std
::
string
CompilationResult
::
SourcePtx
(
int
idx
)
const
{
if
(
idx
>=
source_ptxs_
.
size
())
{
std
::
stringstream
ss
;
ss
<<
"The index("
<<
idx
<<
") is expected to be less than the size of group("
<<
lowered_funcs_
.
size
()
<<
")."
;
CINN_THROW
(
ss
.
str
());
}
if
(
!
source_ptxs_
[
idx
].
has_value
())
{
std
::
stringstream
ss
;
ss
<<
"Source PTX of group["
<<
idx
<<
"] is not generated.
\n
"
<<
"Some errors may have occurred during or before the nvrtc compile "
"process.
\n
"
<<
Message
();
CINN_THROW
(
ss
.
str
());
}
return
source_ptxs_
[
idx
].
value
();
}
// Returns the runtime instructions: those owned by the built runtime program
// when it exists, otherwise the per-group instructions collected so far.
// In the latter case, throws (via CINN_THROW) if any group's instruction is
// missing, appending the accumulated compile messages for diagnosis.
const std::vector<std::unique_ptr<Instruction>>&
CompilationResult::RuntimeInstructions() const {
  if (runtime_program_ != nullptr) {
    return runtime_program_->GetRunInstructions();
  }
  for (int idx = 0; idx < instructions_.size(); ++idx) {
    if (instructions_[idx] == nullptr) {
      std::stringstream ss;
      ss << "Instruction of group[" << idx << "] is not generated.\n"
         << "Some errors may have occurred during or before the build "
            "instruction process.\n"
         << Message();
      CINN_THROW(ss.str());
    }
  }
  return instructions_;
}
// Returns the runtime instruction of group `idx`, preferring the runtime
// program's instruction list when the program has been built.
// Throws (via CINN_THROW) when `idx` is out of range.
const std::unique_ptr<Instruction>& CompilationResult::RuntimeInstruction(
    int idx) const {
  // Pick the authoritative instruction list: the built program's, if any.
  const std::vector<std::unique_ptr<Instruction>>& insts =
      runtime_program_ ? runtime_program_->GetRunInstructions()
                       : instructions_;
  if (idx >= insts.size()) {
    std::stringstream ss;
    ss << "The index(" << idx
       << ") is expected to be less than the size of group(" << insts.size()
       << ").";
    CINN_THROW(ss.str());
  }
  return insts[idx];
}
// Transfers ownership of the built runtime program to the caller.
// Throws (via CINN_THROW) when no program was generated. NOTE: because the
// member is moved out, runtime_program_ is left null afterwards, so a second
// call will throw.
std::unique_ptr<Program> CompilationResult::RuntimeProgram() {
  if (runtime_program_ == nullptr) {
    std::stringstream ss;
    ss << "Runtime program is not generated.\n"
       << "Some errors may have occurred during the compilation process.\n"
       << Message();
    CINN_THROW(ss.str());
  }
  return std::move(runtime_program_);
}
}
// namespace framework
}
// namespace hlir
}
// namespace cinn
paddle/cinn/hlir/framework/graph_compiler_util.h
0 → 100644
View file @
01a10755
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/auto_schedule/tuning.h"
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/program.h"
#include "paddle/cinn/ir/lowered_func.h"
namespace
cinn
{
namespace
hlir
{
namespace
framework
{
// An enum class used to control the compilation stage.
enum
class
CompilationStage
{
// Fully compiled by default, the following compilation result can be
// obtained: lowered_function, source_code, source_ptx, instruction and
// runtime_program.
DEFAULT
=
0
,
// Just do lowering, we can only get lowered_function from compilation result.
LOWERING
=
1
,
// Stop after codegen and jit, we can get: lowered_function, source_code and
// source_ptx from compilation result.
CODEGEN_AND_JIT
=
2
,
// Stop after build instruction, we can get: lowered_function, source_code,
// source_ptx and runtime_program from compilation result.
BUILD_INSTRUCTION
=
3
,
};
// An enum class used to represent the compilation status.
enum
class
CompilationStatus
{
// An unknown error occurred during compilation.
UNKNOWN_FAIL
=
0
,
// An error occurred during lowering.
LOWERING_FAIL
=
1
,
// An error occurred during codegen and jit.
CODEGEN_JIT_FAIL
=
2
,
// An error occurred during build instruction.
INSTUCTION_FAIL
=
3
,
// An error occurred during build runtime program.
PROGRAM_FAIL
=
4
,
// Compile successfully.
SUCCESS
=
5
,
};
// Aggregates everything GraphCompiler needs for one compilation: the graph,
// the variable scope, the target, option flags, and optional pre-supplied
// artifacts (groups, lowered functions, attached source code).
struct CompilationContext {
  CompilationContext() = default;
  CompilationContext(const std::shared_ptr<Graph>& graph,
                     const std::shared_ptr<Scope>& scope,
                     const Target& target)
      : graph(graph), scope(scope), target(target) {}

  // Externally supplied source code; see ApplySourceCode below.
  std::string attached_source_code = "";
  // Compile options.
  bool with_instantiate_variables = false;
  bool with_buffer_handle_instruction_inserted = false;
  bool remove_unused_variables = true;
  // Compile stage, full compile by default.
  CompilationStage stage = CompilationStage::DEFAULT;
  // Compile target.
  Target target;
  // Computation graph.
  std::shared_ptr<Graph> graph;
  // Variable scope.
  std::shared_ptr<Scope> scope;
  // Fetch var ids in cinn and the corresponding var nodes will not be fused
  // so as to get the result.
  std::unordered_set<std::string> fetch_var_ids;
  // Map dst reuse var to the src var sharing buffer.
  absl::flat_hash_map<std::string, std::string> reuse_vars_map;
  // Nodes group, it may come from the result of op fusion or graph tuning.
  // Nodes in a group will be built into an Instruction.
  std::vector<std::shared_ptr<Graph::Group>> groups;
  // Corresponding lowered functions of above grouped nodes,
  // if it is empty then graph_compiler will generate for them.
  std::vector<std::vector<ir::LoweredFunc>> lowered_funcs;
  // CUDA stream.
  void* stream = nullptr;

  // Set attached source code, if code is not empty, these codes will replace
  // the device_module code after SplitCudaAndHostModule.
  void ApplySourceCode(const std::string& code);
  // Apply results of auto-tune to compile.
  // Compilation will start from CompilationStage::CODEGEN_AND_JIT when tuning
  // results are applied.
  void ApplyTuningResult(const auto_schedule::TuningResult& tuning_result);
};
class
GraphCompiler
;
class
CompilationResult
{
friend
class
GraphCompiler
;
public:
void
InitCompilationResult
(
int
group_size
);
// Setters
void
SetStatus
(
int
idx
,
const
CompilationStatus
&
status
);
void
SetMessage
(
int
idx
,
const
std
::
string
&
message
);
void
SetLoweredFuncs
(
int
idx
,
const
std
::
vector
<
ir
::
LoweredFunc
>&
funcs
);
void
SetSourceCode
(
int
idx
,
const
std
::
string
&
source_code
);
void
SetSourcePtx
(
int
idx
,
const
std
::
string
&
source_ptx
);
void
SetInstruction
(
int
idx
,
std
::
unique_ptr
<
Instruction
>
instruction
);
void
SetRuntimeProgram
(
std
::
unique_ptr
<
Program
>
runtime_program
);
// Getters
bool
IsSuccess
()
const
;