Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Paddle
Commits
dbe08e9b
Commit
dbe08e9b
authored
Jun 12, 2023
by
yuguo960516yuguo
Browse files
2.4.2
parent
b5499578
Changes
302
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1041 additions
and
137 deletions
+1041
-137
paddle/fluid/framework/dlpack_tensor.h
paddle/fluid/framework/dlpack_tensor.h
+3
-1
paddle/fluid/framework/infershape_utils.cc
paddle/fluid/framework/infershape_utils.cc
+21
-3
paddle/fluid/framework/infershape_utils.h
paddle/fluid/framework/infershape_utils.h
+2
-0
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+1
-0
paddle/fluid/framework/ir/auto_mixed_precision_pass.cc
paddle/fluid/framework/ir/auto_mixed_precision_pass.cc
+824
-0
paddle/fluid/framework/ir/auto_mixed_precision_pass.h
paddle/fluid/framework/ir/auto_mixed_precision_pass.h
+109
-0
paddle/fluid/framework/ir/constant_folding_pass.cc
paddle/fluid/framework/ir/constant_folding_pass.cc
+3
-0
paddle/fluid/framework/ir/delete_fill_constant_op_pass.cc
paddle/fluid/framework/ir/delete_fill_constant_op_pass.cc
+5
-0
paddle/fluid/framework/ir/graph.cc
paddle/fluid/framework/ir/graph.cc
+0
-2
paddle/fluid/framework/ir/graph_pattern_detector.cc
paddle/fluid/framework/ir/graph_pattern_detector.cc
+1
-0
paddle/fluid/framework/ir/graph_test.cc
paddle/fluid/framework/ir/graph_test.cc
+8
-102
paddle/fluid/framework/ir/multi_batch_merge_pass.cc
paddle/fluid/framework/ir/multi_batch_merge_pass.cc
+0
-2
paddle/fluid/framework/naive_executor.cc
paddle/fluid/framework/naive_executor.cc
+1
-0
paddle/fluid/framework/new_executor/standalone_executor_test.cc
.../fluid/framework/new_executor/standalone_executor_test.cc
+2
-1
paddle/fluid/framework/operator.h
paddle/fluid/framework/operator.h
+7
-0
paddle/fluid/imperative/prepared_operator.cc
paddle/fluid/imperative/prepared_operator.cc
+48
-0
paddle/fluid/inference/analysis/analyzer.cc
paddle/fluid/inference/analysis/analyzer.cc
+1
-2
paddle/fluid/inference/analysis/analyzer_tester.cc
paddle/fluid/inference/analysis/analyzer_tester.cc
+2
-2
paddle/fluid/inference/analysis/argument.h
paddle/fluid/inference/analysis/argument.h
+3
-3
paddle/fluid/inference/analysis/helper.h
paddle/fluid/inference/analysis/helper.h
+0
-19
No files found.
paddle/fluid/framework/dlpack_tensor.h
View file @
dbe08e9b
...
...
@@ -28,7 +28,7 @@ class DLPackTensor {
std
::
remove_reference
<
decltype
(
::
DLTensor
::
shape
[
0
])
>::
type
;
// int64_t
// lanes is only used in CPU to enable vectorization
explicit
DLPackTensor
(
const
Tensor
&
tensor
,
LaneType
lanes
=
1
);
explicit
DLPackTensor
(
const
phi
::
Dense
Tensor
&
tensor
,
LaneType
lanes
=
1
);
inline
operator
const
::
DLTensor
&
()
const
{
return
t_
;
}
...
...
@@ -44,5 +44,7 @@ class DLPackTensor {
ShapeType
shape_
[
DDim
::
kMaxRank
];
};
DLManagedTensor
*
toDLPack
(
const
phi
::
DenseTensor
&
src
);
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/infershape_utils.cc
View file @
dbe08e9b
...
...
@@ -87,6 +87,15 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext {
});
}
bool
IsSelectedRowsInputs
(
const
std
::
string
&
name
)
const
override
{
auto
var_types
=
ctx_
.
GetInputsVarType
(
name
);
return
std
::
all_of
(
var_types
.
begin
(),
var_types
.
end
(),
[](
const
proto
::
VarType
::
Type
&
type
)
{
return
type
==
proto
::
VarType
::
SELECTED_ROWS
;
});
}
bool
IsSelectedRowsInput
(
const
std
::
string
&
name
)
const
override
{
auto
var_type
=
ctx_
.
GetInputVarType
(
name
);
return
var_type
==
proto
::
VarType
::
SELECTED_ROWS
;
...
...
@@ -155,6 +164,16 @@ int64_t CompatMetaTensor::numel() const {
}
}
bool
CompatMetaTensor
::
is_selected_rows
()
const
{
if
(
is_runtime_
)
{
auto
*
var
=
PADDLE_GET_CONST
(
Variable
*
,
var_
);
return
var
->
IsType
<
phi
::
SelectedRows
>
();
}
else
{
auto
*
var
=
PADDLE_GET_CONST
(
VarDesc
*
,
var_
);
return
var
->
GetType
()
==
proto
::
VarType
::
SELECTED_ROWS
;
}
}
bool
CompatMetaTensor
::
is_dense
()
const
{
if
(
is_runtime_
)
{
auto
*
var
=
PADDLE_GET_CONST
(
Variable
*
,
var_
);
...
...
@@ -182,7 +201,7 @@ DDim CompatMetaTensor::dims() const {
if
(
var
->
IsType
<
phi
::
DenseTensor
>
())
{
return
var
->
Get
<
phi
::
DenseTensor
>
().
dims
();
}
else
if
(
var
->
IsType
<
phi
::
SelectedRows
>
())
{
return
var
->
Get
<
phi
::
SelectedRows
>
().
d
ims
();
return
var
->
Get
<
phi
::
SelectedRows
>
().
GetCompleteD
ims
();
}
else
if
(
var
->
IsType
<
phi
::
SparseCooTensor
>
())
{
return
var
->
Get
<
phi
::
SparseCooTensor
>
().
dims
();
}
else
if
(
var
->
IsType
<
framework
::
LoDTensorArray
>
())
{
...
...
@@ -260,8 +279,7 @@ void CompatMetaTensor::set_dims(const DDim& dims) {
auto
*
tensor
=
var
->
GetMutable
<
phi
::
DenseTensor
>
();
phi
::
DenseTensorUtils
::
GetMutableMeta
(
tensor
)
->
dims
=
dims
;
}
else
if
(
var
->
IsType
<
phi
::
SelectedRows
>
())
{
auto
*
tensor
=
var
->
GetMutable
<
phi
::
SelectedRows
>
()
->
mutable_value
();
phi
::
DenseTensorUtils
::
GetMutableMeta
(
tensor
)
->
dims
=
dims
;
var
->
GetMutable
<
phi
::
SelectedRows
>
()
->
set_height
(
dims
[
0
]);
}
else
if
(
var
->
IsType
<
phi
::
SparseCooTensor
>
())
{
auto
*
tensor
=
var
->
GetMutable
<
phi
::
SparseCooTensor
>
();
phi
::
DenseTensorUtils
::
GetMutableMeta
(
tensor
)
->
dims
=
dims
;
...
...
paddle/fluid/framework/infershape_utils.h
View file @
dbe08e9b
...
...
@@ -59,6 +59,8 @@ class CompatMetaTensor : public phi::MetaTensor {
bool
initialized
()
const
override
{
return
initialized_
;
};
bool
is_selected_rows
()
const
;
bool
is_tensor_array
()
const
;
bool
is_dense
()
const
;
...
...
paddle/fluid/framework/ir/CMakeLists.txt
View file @
dbe08e9b
...
...
@@ -148,6 +148,7 @@ pass_library(delete_c_identity_op_pass inference)
pass_library
(
preln_residual_bias_fuse_pass inference
)
pass_library
(
delete_fill_constant_op_pass inference
)
pass_library
(
constant_folding_pass inference
)
pass_library
(
auto_mixed_precision_pass inference
)
pass_library
(
simplify_with_basic_ops_pass base
)
pass_library
(
fc_elementwise_layernorm_fuse_pass base
)
pass_library
(
skip_layernorm_fuse_pass base
)
...
...
paddle/fluid/framework/ir/auto_mixed_precision_pass.cc
0 → 100644
View file @
dbe08e9b
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/auto_mixed_precision_pass.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
namespace
{
using
VarType
=
AutoMixedPrecisionPass
::
VarType
;
bool
PhiKernelSupportPrecision
(
const
std
::
string
&
op_type
,
phi
::
Backend
backend
,
phi
::
DataType
data_type
,
phi
::
DataLayout
layout
=
phi
::
DataLayout
::
ALL_LAYOUT
)
{
const
auto
&
kernels
=
phi
::
KernelFactory
::
Instance
().
kernels
();
if
(
kernels
.
count
(
op_type
)
==
0
)
{
return
false
;
}
phi
::
KernelKey
kernel_key
(
backend
,
layout
,
data_type
);
return
phi
::
KernelFactory
::
Instance
().
HasKernel
(
op_type
,
kernel_key
);
}
// Returns true when `op_type` can run on GPU at `precision`, checking first
// the phi kernel registry (GPU and GPUDNN backends) and then falling back to
// the legacy fluid op-kernel registry.
bool GpuKernelSupportPrecision(
    const std::string& op_type,
    phi::DataType precision,
    phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) {
  const auto phi_kernel_name = phi::TransToPhiKernelName(op_type);
  if (PhiKernelSupportPrecision(
          phi_kernel_name, phi::Backend::GPU, precision, layout) ||
      PhiKernelSupportPrecision(
          phi_kernel_name, phi::Backend::GPUDNN, precision, layout)) {
    return true;
  }

  // Fall back to the legacy fluid registry: any GPU-place kernel registered
  // with the matching proto var type counts as support.
  const auto& all_kernels = framework::OperatorWithKernel::AllOpKernels();
  auto it = all_kernels.find(op_type);
  if (it == all_kernels.end()) {
    return false;
  }
  const auto proto_dtype = framework::TransToProtoVarType(precision);
  for (const auto& kern_pair : it->second) {
    if (platform::is_gpu_place(kern_pair.first.place_) &&
        kern_pair.first.data_type_ == proto_dtype) {
      return true;
    }
  }
  return false;
}
// Returns true when the var node's type is one that carries a data type the
// pass can reason about (tensors, selected rows, tensor arrays, strings,
// vocab).
inline bool VarNodeHasDtype(Node* var_node) {
  switch (var_node->Var()->GetType()) {
    case VarType::SELECTED_ROWS:
    case VarType::LOD_TENSOR:
    case VarType::LOD_TENSOR_ARRAY:
    case VarType::STRINGS:
    case VarType::VOCAB:
      return true;
    default:
      return false;
  }
}
// True for the two full-precision float var types (fp32 / fp64).
inline bool IsFP32AndFP64(VarType::Type type) {
  return type == VarType::FP32 || type == VarType::FP64;
}
// True for the two low-precision float var types (fp16 / bf16).
inline bool IsFP16AndBFP16(VarType::Type type) {
  return type == VarType::FP16 || type == VarType::BF16;
}
};
// namespace
// Reroutes the graph edge var_node -> op_node through a "cast" op converting
// from_type to to_type. The first time a given var_node is cast, a new cast
// op and output var are created and remembered in `cache`; later consumers of
// the same var reuse that cast output instead of stacking a second cast.
// `suffix` is a shared counter used to make the generated var names unique.
// No-op when from_type == to_type.
void DoInsertCastOp(Graph* graph,
                    Node* var_node,
                    Node* op_node,
                    VarType::Type from_type,
                    VarType::Type to_type,
                    framework::BlockDesc* block_desc,
                    int* suffix,
                    std::unordered_map<Node*, Node*>* cache) {
  if (from_type == to_type) return;

  // Fills `desc` so it describes cast(X=x_name) -> Out=out_name with the
  // given integer dtype codes.
  auto update_cast_desc = [&](framework::OpDesc& desc,
                              const std::string& x_name,
                              const std::string& out_name,
                              const int in_dtype,
                              const int out_dtype) {
    desc.SetType("cast");
    desc.SetInput("X", {x_name});
    desc.SetOutput("Out", {out_name});
    desc.SetAttr("in_dtype", in_dtype);
    desc.SetAttr("out_dtype", out_dtype);
    desc.SetAttr("use_mkldnn", false);
    desc.SetAttr("with_quant_attr", false);
    desc.Flush();
  };

  if (cache->count(var_node) == 0) {
    // insert cast op between var_node and op_node
    std::string cast_input_name = var_node->Var()->Name();
    std::string cast_output_name =
        var_node->Var()->Name() + "_cast.tmp_" + std::to_string((*suffix)++);
    framework::OpDesc cast_op_desc(block_desc);
    update_cast_desc(cast_op_desc,
                     cast_input_name,
                     cast_output_name,
                     static_cast<int>(from_type),
                     static_cast<int>(to_type));
    auto* cast_op_node = graph->CreateOpNode(&cast_op_desc);
    // The cast output var mirrors the source var's shape but carries the
    // target dtype and is never persistable (it is a temporary).
    auto* cast_output_vardesc = block_desc->Var(cast_output_name);
    cast_output_vardesc->SetPersistable(false);
    cast_output_vardesc->SetDataType(to_type);
    cast_output_vardesc->SetShape(var_node->Var()->GetShape());
    auto* cast_output_node = graph->CreateVarNode(cast_output_vardesc);
    IR_NODE_LINK_TO(cast_op_node, cast_output_node);
    (*cache)[var_node] = cast_output_node;
  }
  // Point op_node at the cast output instead of the original var, then fix up
  // the graph edges accordingly. inputs[0] of the cached cast-output node is
  // the cast op node linked above.
  op_node->Op()->Rename(var_node->Name(), cache->at(var_node)->Name());
  IR_NODE_LINK_TO(var_node, cache->at(var_node)->inputs[0]);
  IR_NODE_LINK_TO(cache->at(var_node), op_node);
  IR_NODE_UNLINK(var_node, op_node);
}
// Returns true when `op_type` is not blacklisted and has a kernel supporting
// `precision` on `backend`. Only the GPU backend is implemented; any other
// backend raises InvalidArgument.
bool OpSupportPrecision(const std::string& op_type,
                        phi::Backend backend,
                        phi::DataType precision,
                        const std::unordered_set<std::string>& black_list) {
  if (black_list.count(op_type) != 0) {
    return false;
  }
  if (backend != phi::Backend::GPU) {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "Now, only support backend of GPU."));
  }
  return GpuKernelSupportPrecision(op_type, precision);
}
// The set of ops that support fp16 calculation and are considered
// numerically-dangerous, slower and whose effects may also be observed in
// downstream ops.
// ref to python/paddle/fluid/contrib/mixed_precision/fp16_lists.py
// Extends black_list_ with ops that should always stay at full precision,
// on top of whatever the user configured. Inserting into a member from a
// const method implies black_list_ is presumably declared mutable in the
// header — confirm against auto_mixed_precision_pass.h.
void AutoMixedPrecisionPass::SetDefaultBlacklist() const {
  black_list_.insert({
      // numerically-dangerous
      "exp",
      "square",
      "log",
      "mean",
      "sum",
      "cos_sim",
      "softmax_with_cross_entropy",
      "sigmoid_cross_entropy_with_logits",
      "c_softmax_with_cross_entropy",
      "cross_entropy",
      "cross_entropy2",
      // slower than fp32
      "conv2d_transpose",
      // default fp32 can avoid return inf when the sum value large than 65504
      "reduce_sum",
  });
}
// Reads the pass configuration and caches per-subgraph state (topologically
// sorted op nodes, canonical var nodes) before any rewriting happens. Only
// the GPU backend is currently wired up; when GPU mixed precision is not
// enabled the whole pass is flagged to be skipped.
void AutoMixedPrecisionPass::Init(Graph* graph) const {
  bool enable_gpu_mixed = Get<bool>("enable_gpu_mixed");
  if (enable_gpu_mixed) {
    backend_ = phi::Backend::GPU;
  }
  skip_pass_ = !enable_gpu_mixed;

  // "mixed_precision_mode" arrives as an int attribute; reinterpret it as a
  // phi::DataType (fp16 or bf16).
  low_precision_ = static_cast<phi::DataType>(Get<int>("mixed_precision_mode"));

  black_list_ = Get<std::unordered_set<std::string>>("mixed_black_list");
  SetDefaultBlacklist();
  VLOG(4) << "black_list has ";
  for (const auto& name : black_list_) {
    VLOG(4) << " - " << name;
  }

  // Default to keeping feed/fetch vars at their original dtype unless the
  // caller explicitly set "keep_io_types".
  keep_io_types_ = true;
  if (Has("keep_io_types")) {
    keep_io_types_ = Get<bool>("keep_io_types");
  }

  auto graph_size = graph->SubGraphsSize();
  VLOG(4) << "graph size: " << graph_size;
  subgraphes_.resize(graph_size);
  all_op_nodes_.resize(graph_size);

  for (size_t i = 0; i < graph_size; i++) {
    subgraphes_[i] = graph->GetSubGraph(i);
    all_op_nodes_[i] = TopologySortOperations(*subgraphes_[i]);
    // Fix: the original log was missing the space before "op nodes" and
    // printed e.g. "has 5op nodes".
    VLOG(4) << "subgraph " << i << " has " << all_op_nodes_[i].size()
            << " op nodes";
    // A var name can appear in several subgraphs; remember the first node
    // seen under each name as the canonical ("real") var node.
    for (auto* var_node : subgraphes_[i]->Nodes()) {
      if (!var_node->IsVar()) continue;
      auto var_name = var_node->Var()->Name();
      if (real_vars_.count(var_name) == 0) {
        real_vars_[var_name] = var_node;
        VLOG(4) << var_name << " is in graph " << i;
      }
    }
  }
}
// Pass entry point. Validates preconditions, then runs the fixed pipeline:
// tag every op with a unique type name, decide which ops can run at low
// precision, propagate/veto those decisions, retype vars, convert weight
// data, patch dtype attributes, insert boundary cast ops, and finally
// restore the original op type names. The step order is load-bearing.
void AutoMixedPrecisionPass::ApplyImpl(Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(graph,
                          platform::errors::PreconditionNotMet(
                              "During the auto_mixed_precision_pass, the graph "
                              "should not be nullptr."));
  PADDLE_ENFORCE_EQ(graph->IsMainGraph(),
                    true,
                    platform::errors::PreconditionNotMet(
                        "During the auto_mixed_precision_pass, the graph "
                        "should be main graph."));

  FusePassBase::Init("auto_mixed_precision", graph);

  Init(graph);
  VLOG(4) << "Init done";
  // Configuration may disable the pass entirely (see Init).
  if (skip_pass_) {
    VLOG(3) << "Skip auto_mixed_precision_pass.";
    return;
  }

  SetOpUniqueType();
  VLOG(4) << "SetOpUniqueType done";
  GetOpPrecision();
  VLOG(4) << "GetOpPrecision done";
  UpdateOpPrecision();
  VLOG(4) << "UpdateOpPrecision done";
  SetVarPrecision();
  VLOG(4) << "SetVarPrecision done";
  ConvertWeightsData();
  VLOG(4) << "ConvertWeightsData done";
  ProcessOpWithDtypeAttr();
  VLOG(4) << "ProcessOpWithDtypeAttr done";
  InsertCastOp();
  VLOG(4) << "InsertCastOp done";
  RestoreOpOriginType();
  VLOG(4) << "RestoreOpOriginType done";

  LOG(INFO) << "The number of ops run at low precision ["
            << op_run_low_precision_.size() << "/" << op_original_type_.size()
            << "]";
}
// Renames every op (except feed/fetch) to "<type>_<n>" with a globally
// increasing counter, recording the original type in op_original_type_ so
// later stages can treat each op instance independently and
// RestoreOpOriginType can undo the renaming.
void AutoMixedPrecisionPass::SetOpUniqueType() const {
  int counter = 0;
  for (const auto& graph_ops : all_op_nodes_) {
    for (auto* node : graph_ops) {
      auto* op_desc = node->Op();
      const auto original_type = op_desc->Type();
      // feed/fetch keep their names so IO handling stays recognizable.
      if (original_type == "feed" || original_type == "fetch") continue;
      const std::string unique_type =
          original_type + "_" + std::to_string(counter++);
      op_original_type_[unique_type] = original_type;
      op_desc->SetType(unique_type);
      op_desc->Flush();
      VLOG(4) << "change op type: " << original_type << " ---> " << unique_type;
    }
  }
}
// Undoes SetOpUniqueType: rewrites every op's type back to the original name
// recorded in op_original_type_ (ops that were never renamed map to
// themselves via GetOpOriginalType).
void AutoMixedPrecisionPass::RestoreOpOriginType() const {
  for (const auto& graph_ops : all_op_nodes_) {
    for (auto* node : graph_ops) {
      auto* op_desc = node->Op();
      const auto unique_type = op_desc->Type();
      op_desc->SetType(GetOpOriginalType(unique_type));
      op_desc->Flush();
      VLOG(4) << "restore op type: " << unique_type << " ---> "
              << op_desc->Type();
    }
  }
}
// Maps a uniquified op type ("relu_42") back to its original type ("relu").
// Types that were never uniquified are returned unchanged. Single hash
// lookup instead of count() followed by at().
inline std::string AutoMixedPrecisionPass::GetOpOriginalType(
    const std::string& op_type) const {
  auto it = op_original_type_.find(op_type);
  return it != op_original_type_.end() ? it->second : op_type;
}
// Rewrites dtype-describing attributes so they agree with the precision
// decisions made earlier: "in_dtype" is synced to an already-low-precision
// input var, while "dtype"/"out_dtype" of ops chosen to run at low precision
// are switched from fp32/fp64 to the target low precision.
void AutoMixedPrecisionPass::ProcessOpWithDtypeAttr() const {
  for (const auto& nodes : all_op_nodes_) {
    for (auto* op_node : nodes) {
      auto op_type = op_node->Op()->Type();

      // "in_dtype" is updated whenever the op's first input var already
      // carries fp16/bf16 — regardless of whether this op itself was chosen
      // to run at low precision. Assumes inputs[0] is the relevant input for
      // ops carrying "in_dtype" (e.g. cast) — TODO confirm.
      if (op_node->Op()->HasAttr("in_dtype")) {
        auto* var_node = op_node->inputs[0];
        auto* real_var_node = real_vars_[var_node->Var()->Name()];
        if (IsFP16AndBFP16(real_var_node->Var()->GetDataType())) {
          op_node->Op()->SetAttr(
              "in_dtype",
              static_cast<int>(framework::TransToProtoVarType(low_precision_)));
          op_node->Op()->Flush();
          VLOG(4) << "process op with in_dtype attr: " << op_type << " ( "
                  << static_cast<int>(real_var_node->Var()->GetDataType())
                  << " --->" << static_cast<int>(low_precision_) << " )";
        }
      }

      // "dtype"/"out_dtype" are only touched on ops selected to run at low
      // precision, and only when they currently describe fp32/fp64.
      if (op_run_low_precision_.count(op_type) == 0) continue;

      if (op_node->Op()->HasAttr("dtype")) {
        auto dtype = op_node->Op()->GetAttrIfExists<int>("dtype");
        if (IsFP32AndFP64(static_cast<VarType::Type>(dtype))) {
          op_node->Op()->SetAttr(
              "dtype",
              static_cast<int>(framework::TransToProtoVarType(low_precision_)));
          op_node->Op()->Flush();
          VLOG(4) << "process op with dtype attr: " << op_type << " ( " << dtype
                  << " --->" << static_cast<int>(low_precision_) << " )";
        }
      } else if (op_node->Op()->HasAttr("out_dtype")) {
        auto out_dtype = op_node->Op()->GetAttrIfExists<int>("out_dtype");
        if (IsFP32AndFP64(static_cast<VarType::Type>(out_dtype))) {
          op_node->Op()->SetAttr(
              "out_dtype",
              static_cast<int>(framework::TransToProtoVarType(low_precision_)));
          op_node->Op()->Flush();
          VLOG(4) << "process op with out_dtype attr: " << op_type << " ( "
                  << out_dtype << " --->" << static_cast<int>(low_precision_)
                  << " )";
        }
      }
    }
  }
}
// First-pass per-op decision: an op joins op_run_low_precision_ when (a) it
// is feed/fetch and IO types need not be kept, or a low-precision kernel
// exists and it is not blacklisted; (b) any existing dtype/out_dtype attr is
// still fp32/fp64; (c) for scale, its attrs stay finite at the target
// precision; and (d) all non-persistable inputs and outputs are LOD_TENSORs.
void AutoMixedPrecisionPass::GetOpPrecision() const {
  for (const auto& nodes : all_op_nodes_) {
    for (auto* op_node : nodes) {
      auto op_type = op_node->Op()->Type();
      bool support_low_precision = true;
      if (GetOpOriginalType(op_type) == "feed" ||
          GetOpOriginalType(op_type) == "fetch") {
        // IO ops may only change precision when keep_io_types_ is off.
        support_low_precision = !keep_io_types_;
      } else {
        support_low_precision = OpSupportPrecision(
            GetOpOriginalType(op_type), backend_, low_precision_, black_list_);
      }

      // An op whose dtype/out_dtype attr is already non-float (e.g. int)
      // must not be converted.
      if (op_node->Op()->HasAttr("dtype")) {
        auto dtype = op_node->Op()->GetAttrIfExists<int>("dtype");
        support_low_precision =
            support_low_precision &&
            IsFP32AndFP64(static_cast<VarType::Type>(dtype));
      } else if (op_node->Op()->HasAttr("out_dtype")) {
        auto out_dtype = op_node->Op()->GetAttrIfExists<int>("out_dtype");
        support_low_precision =
            support_low_precision &&
            IsFP32AndFP64(static_cast<VarType::Type>(out_dtype));
      }

      // If scale op's "scale" and "bias" attr value exceed the range of fp16
      // and bf16, it cannot run at low precision.
      if (GetOpOriginalType(op_node->Op()->Type()) == "scale") {
        auto scale = op_node->Op()->GetAttrIfExists<float>("scale");
        auto bias = op_node->Op()->GetAttrIfExists<float>("bias");
        if (low_precision_ == phi::DataType::FLOAT16) {
          support_low_precision =
              support_low_precision &&
              phi::dtype::isfinite(static_cast<phi::dtype::float16>(scale)) &&
              phi::dtype::isfinite(static_cast<phi::dtype::float16>(bias));
        } else if (low_precision_ == phi::DataType::BFLOAT16) {
          support_low_precision =
              support_low_precision &&
              phi::dtype::isfinite(static_cast<phi::dtype::bfloat16>(scale)) &&
              phi::dtype::isfinite(static_cast<phi::dtype::bfloat16>(bias));
        }
      }

      // if op's input var and output var is not dense tensor, the op should
      // not run at low precision.
      for (auto* in_var_node : op_node->inputs) {
        CHECK_EQ(in_var_node->IsVar(), true);
        auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
        if (real_in_var_node->Var()->Persistable()) continue;

        support_low_precision =
            support_low_precision &&
            (real_in_var_node->Var()->GetType() == VarType::LOD_TENSOR);
      }

      for (auto* out_var_node : op_node->outputs) {
        CHECK_EQ(out_var_node->IsVar(), true);
        auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
        if (real_out_var_node->Var()->Persistable()) continue;

        support_low_precision =
            support_low_precision &&
            (real_out_var_node->Var()->GetType() == VarType::LOD_TENSOR);
      }

      if (support_low_precision) {
        op_run_low_precision_.insert(op_type);
        VLOG(4) << "support precision: " << op_type << " run at low precision";
      } else {
        VLOG(4) << "support precision: " << op_type
                << " not run at low precision";
      }
    }
  }
}
// Second pass: propagates vetoes until a fixed point. First collects vars
// that must stay full precision (select_input inputs, outputs of CPU-only
// ops) and a producer map var-name -> ops producing it. Then repeatedly
// evicts from op_run_low_precision_ any op touching a vetoed var, or whose
// output var also has a full-precision producer (all producers of a shared
// var must agree), until no eviction happens in a sweep.
void AutoMixedPrecisionPass::UpdateOpPrecision() const {
  std::unordered_set<std::string> vars_should_not_low_precision;
  // var -> the var's all input op
  std::unordered_map<std::string, std::vector<Node*>> var_input_ops;

  auto GetVarInputOps = [&] {
    for (const auto& nodes : all_op_nodes_) {
      for (auto* op_node : nodes) {
        auto op_type = op_node->Op()->Type();

        if (GetOpOriginalType(op_type) == "fetch") continue;
        // Ops carrying a sub_block (control flow) are not tracked here.
        if (op_node->Op()->HasAttr("sub_block")) continue;

        for (auto* var_node : op_node->outputs) {
          CHECK_EQ(var_node->IsVar(), true);
          if (var_node->Var()->Persistable()) continue;
          if (!VarNodeHasDtype(var_node)) continue;

          var_input_ops[var_node->Var()->Name()].push_back(op_node);
          VLOG(4) << "var input ops: " << var_node->Var()->Name()
                  << " is output of " << op_type;
        }

        // the select_input op's input var should not convert to low precision.
        // when op's output var is select_input op's input var, the op should
        // not run at low precision.
        if (GetOpOriginalType(op_node->Op()->Type()) == "select_input") {
          for (auto* in_var_node : op_node->inputs) {
            CHECK_EQ(in_var_node->IsVar(), true);
            if (in_var_node->Var()->Persistable()) continue;
            if (!VarNodeHasDtype(in_var_node)) continue;

            vars_should_not_low_precision.insert(in_var_node->Var()->Name());
          }
        }

        // when op_1 only support cpu kernel. if op_2's intput var is op_1's
        // output var, then op_2 should not run at low precision.
        if (GetOpOriginalType(op_type) != "feed" &&
            !GpuKernelSupportPrecision(GetOpOriginalType(op_type),
                                       phi::DataType::FLOAT32)) {
          for (auto* out_var_node : op_node->outputs) {
            CHECK_EQ(out_var_node->IsVar(), true);
            if (out_var_node->Var()->Persistable()) continue;
            if (!VarNodeHasDtype(out_var_node)) continue;

            vars_should_not_low_precision.insert(out_var_node->Var()->Name());
          }
        }
      }
    }
  };
  GetVarInputOps();

  bool precision_updated = false;
  do {
    precision_updated = false;
    for (const auto& nodes : all_op_nodes_) {
      for (auto* op_node : nodes) {
        // Only ops still marked for low precision can be evicted.
        if (op_run_low_precision_.count(op_node->Op()->Type()) == 0) continue;

        // Veto via inputs: any vetoed input var disqualifies the op.
        for (auto* in_var_node : op_node->inputs) {
          CHECK_EQ(in_var_node->IsVar(), true);
          if (!VarNodeHasDtype(in_var_node)) continue;

          auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
          if (real_in_var_node->Var()->Persistable()) continue;

          if (vars_should_not_low_precision.count(
                  real_in_var_node->Var()->Name())) {
            op_run_low_precision_.erase(op_node->Op()->Type());
            precision_updated = true;
            VLOG(4) << op_node->Op()->Type()
                    << " should not run at low precision.";
            break;
          }
        }

        // Re-check: the op may just have been evicted above.
        if (op_run_low_precision_.count(op_node->Op()->Type()) == 0) continue;

        // Veto via outputs: a shared output var requires all of its
        // producers to run at low precision.
        for (auto* out_var_node : op_node->outputs) {
          CHECK_EQ(out_var_node->IsVar(), true);
          if (!VarNodeHasDtype(out_var_node)) continue;

          auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
          if (real_out_var_node->Var()->Persistable()) continue;

          bool not_run_low_precision = false;
          const auto& input_op_nodes =
              var_input_ops[real_out_var_node->Var()->Name()];
          if (vars_should_not_low_precision.count(
                  real_out_var_node->Var()->Name())) {
            not_run_low_precision = true;
          } else {
            for (auto* node : input_op_nodes) {
              if (op_run_low_precision_.count(node->Op()->Type()) == 0) {
                not_run_low_precision = true;
                break;
              }
            }
          }
          if (not_run_low_precision) {
            op_run_low_precision_.erase(op_node->Op()->Type());
            precision_updated = true;
            VLOG(4) << op_node->Op()->Type()
                    << " should not run at low precision.";
            break;
          }
        }
      }
    }
  } while (precision_updated);
}
// special ops, its weights should not be low precision.
// Returns true when `var_name` feeds one of the input slots of `op_node`
// whose weights must stay at full precision (batch_norm statistics,
// fused_multi_transformer layer-norm params).
bool AutoMixedPrecisionPass::InputVarsNotConvert(
    Node* op_node, const std::string& var_name) const {
  auto* op_desc = op_node->Op();

  // Whether var_name is wired into the given input slot of this op.
  auto feeds_slot = [&](const char* slot) {
    auto names = op_desc->Input(slot);
    return std::find(names.begin(), names.end(), var_name) != names.end();
  };

  const auto original_type = GetOpOriginalType(op_desc->Type());
  if (original_type == "batch_norm") {
    // The normalization statistics and affine params stay fp32.
    return feeds_slot("Bias") || feeds_slot("Mean") || feeds_slot("Scale") ||
           feeds_slot("Variance");
  } else if (original_type == "fused_multi_transformer") {
    // Layer-norm scales/biases stay fp32.
    return feeds_slot("LnScale") || feeds_slot("LnBias") ||
           feeds_slot("FFNLnScale") || feeds_slot("FFNLnBias");
  }
  return false;
}
// Returns true when `var_name` is one of the output slots of `op_node` that
// must stay at full precision.
bool AutoMixedPrecisionPass::OutputVarsNotConvert(
    Node* op_node, const std::string& var_name) const {
  auto* op_desc = op_node->Op();

  // Whether var_name is wired into the given output slot of this op.
  auto writes_slot = [&](const char* slot) {
    auto names = op_desc->Output(slot);
    return std::find(names.begin(), names.end(), var_name) != names.end();
  };

  // batch_norm's input and output (variance and mean) are the same.
  if (GetOpOriginalType(op_desc->Type()) == "batch_norm") {
    return writes_slot("MeanOut") || writes_slot("VarianceOut") ||
           writes_slot("SavedMean") || writes_slot("SavedVariance");
  }
  return false;
}
// Retypes var descs around low-precision ops: persistable (weight) inputs
// and all eligible outputs of a low-precision op get the target dtype, and
// converted weight names are recorded in vars_convert_to_low_precision_ so
// ConvertWeightsData can rewrite the actual tensors. A final sweep makes
// every persistable var node sharing a converted name agree on the dtype.
void AutoMixedPrecisionPass::SetVarPrecision() const {
  for (const auto& nodes : all_op_nodes_) {
    for (auto* op_node : nodes) {
      if (op_run_low_precision_.count(op_node->Op()->Type()) == 0) {
        continue;
      }

      // feed's inputs are left alone; fetch's outputs are left alone.
      if (GetOpOriginalType(op_node->Op()->Type()) != "feed") {
        for (auto* in_var_node : op_node->inputs) {
          CHECK_EQ(in_var_node->IsVar(), true);

          auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
          auto in_var_name = real_in_var_node->Var()->Name();

          if (!IsFP32AndFP64(real_in_var_node->Var()->GetDataType())) continue;
          if (!VarNodeHasDtype(real_in_var_node)) continue;
          // Some slots (e.g. batch_norm stats) must keep full precision.
          if (InputVarsNotConvert(op_node, in_var_name)) continue;

          // Only persistable inputs (weights) are retyped here; activations
          // are handled through the op's own output retyping below.
          if (real_in_var_node->Var()->Persistable()) {
            real_in_var_node->Var()->SetDataType(
                framework::TransToProtoVarType(low_precision_));
            vars_convert_to_low_precision_.insert(in_var_name);
          }
        }
      }

      if (GetOpOriginalType(op_node->Op()->Type()) != "fetch") {
        for (auto* out_var_node : op_node->outputs) {
          CHECK_EQ(out_var_node->IsVar(), true);

          auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
          auto out_var_name = real_out_var_node->Var()->Name();

          if (!IsFP32AndFP64(real_out_var_node->Var()->GetDataType())) continue;
          if (!VarNodeHasDtype(real_out_var_node)) continue;
          if (OutputVarsNotConvert(op_node, out_var_name)) continue;

          real_out_var_node->Var()->SetDataType(
              framework::TransToProtoVarType(low_precision_));
          // Only persistable outputs need their stored data converted later.
          if (real_out_var_node->Var()->Persistable()) {
            vars_convert_to_low_precision_.insert(out_var_name);
          }
        }
      }
    }
  }

  // This code used to precess vars with the same name. Vars with the same
  // name should have the same data type.
  for (auto* subgraph : subgraphes_) {
    for (auto* var_node : subgraph->Nodes()) {
      if (!var_node->IsVar() || !var_node->Var()->Persistable()) continue;
      if (!VarNodeHasDtype(var_node)) continue;

      auto var_name = var_node->Var()->Name();
      if (vars_convert_to_low_precision_.count(var_name)) {
        var_node->Var()->SetDataType(
            framework::TransToProtoVarType(low_precision_));
      }
    }
  }
}
// Rewrites, in the parameter scope, the data of every weight tensor recorded
// in vars_convert_to_low_precision_: the fp32/fp64 payload is converted
// element-wise to the target fp16/bf16 type on CPU and copied back into the
// original tensor object (so graph vars keep pointing at the same tensor).
void AutoMixedPrecisionPass::ConvertWeightsData() const {
  auto* scope = param_scope();
  PADDLE_ENFORCE_NOT_NULL(scope,
                          platform::errors::PreconditionNotMet(
                              "During the auto_mixed_precision_pass, the scope "
                              "should not be null."));

  auto var_names = scope->LocalVarNames();
  for (const auto& var_name : var_names) {
    if (!vars_convert_to_low_precision_.count(var_name)) continue;
    VLOG(4) << var_name << "'s data type was convert to low precision";

    auto* var = scope->FindLocalVar(var_name);
    CHECK_EQ(var->IsType<phi::DenseTensor>(), true);
    auto* origin_tensor = var->GetMutable<phi::DenseTensor>();

    phi::DenseTensor low_precision_tensor;
    low_precision_tensor.Resize(origin_tensor->dims());
    low_precision_tensor.set_type(low_precision_);

    // Hoist the source-dtype dispatch and numel() out of the element loop;
    // the original re-evaluated both on every element.
    const int64_t numel = origin_tensor->numel();
    if (low_precision_ == phi::DataType::FLOAT16) {
      auto* dst = low_precision_tensor.mutable_data<phi::dtype::float16>(
          phi::CPUPlace{});
      if (origin_tensor->dtype() == phi::DataType::FLOAT64) {
        const auto* src = origin_tensor->data<double>();
        for (int64_t i = 0; i < numel; i++) {
          dst[i] = static_cast<phi::dtype::float16>(src[i]);
        }
      } else if (origin_tensor->dtype() == phi::DataType::FLOAT32) {
        const auto* src = origin_tensor->data<float>();
        for (int64_t i = 0; i < numel; i++) {
          dst[i] = static_cast<phi::dtype::float16>(src[i]);
        }
      }
    } else if (low_precision_ == phi::DataType::BFLOAT16) {
      auto* dst = low_precision_tensor.mutable_data<phi::dtype::bfloat16>(
          phi::CPUPlace{});
      if (origin_tensor->dtype() == phi::DataType::FLOAT64) {
        const auto* src = origin_tensor->data<double>();
        for (int64_t i = 0; i < numel; i++) {
          dst[i] = static_cast<phi::dtype::bfloat16>(src[i]);
        }
      } else if (origin_tensor->dtype() == phi::DataType::FLOAT32) {
        const auto* src = origin_tensor->data<float>();
        for (int64_t i = 0; i < numel; i++) {
          dst[i] = static_cast<phi::dtype::bfloat16>(src[i]);
        }
      }
    }

    // Release the original buffer, then copy the converted data back in
    // place so existing references to the tensor stay valid.
    origin_tensor->clear();
    paddle::framework::TensorCopySync(
        low_precision_tensor, phi::CPUPlace{}, origin_tensor);
  }
}
// Inserts "cast" ops on precision boundaries: fp32/fp64 vars feeding a
// low-precision op are cast down, fp16/bf16 vars feeding a full-precision op
// are cast back to fp32. When the producer is already a cast op it is
// retargeted instead of stacking a second cast. `cache` dedupes casts so
// each source var gets at most one cast op.
void AutoMixedPrecisionPass::InsertCastOp() const {
  int suffix = 0;
  std::unordered_map<Node*, Node*> cache;

  for (size_t i = 0; i < all_op_nodes_.size(); i++) {
    // Fix: guard against a subgraph without any op node — indexing [0]
    // below would otherwise read out of range (UB).
    if (all_op_nodes_[i].empty()) continue;
    auto* block_desc = all_op_nodes_[i][0]->Op()->Block();
    CHECK_NOTNULL(block_desc);
    for (auto* op_node : all_op_nodes_[i]) {
      auto op_type = op_node->Op()->Type();

      if (GetOpOriginalType(op_type) == "feed") continue;
      if (op_node->Op()->HasAttr("sub_block")) continue;

      VLOG(4) << "process op: " << op_type
              << " run low precision: " << op_run_low_precision_.count(op_type);

      // Iterate over a copy: DoInsertCastOp relinks op_node->inputs.
      auto inputs = op_node->inputs;
      for (auto* in_var_node : inputs) {
        if (!in_var_node->IsVar()) continue;
        if (!VarNodeHasDtype(in_var_node)) continue;
        if (in_var_node->Var()->Persistable()) continue;

        auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
        auto in_var_type = real_in_var_node->Var()->GetDataType();

        VLOG(4) << "process var: " << real_in_var_node->Var()->Name()
                << " with type " << in_var_type;

        if (IsFP32AndFP64(in_var_type) &&
            op_run_low_precision_.count(op_type)) {
          // Full-precision var feeding a low-precision op: cast down.
          auto to_type = framework::TransToProtoVarType(low_precision_);
          auto* prev_op =
              in_var_node->inputs.empty() ? nullptr : in_var_node->inputs[0];
          if (prev_op && GetOpOriginalType(prev_op->Op()->Type()) == "cast") {
            // Producer is already a cast op: just retarget its out_dtype.
            in_var_node->Var()->SetDataType(to_type);
            prev_op->Op()->SetAttr("out_dtype", static_cast<int>(to_type));
            prev_op->Op()->Flush();
          } else {
            DoInsertCastOp(subgraphes_[i],
                           in_var_node,
                           op_node,
                           in_var_type,
                           to_type,
                           block_desc,
                           &suffix,
                           &cache);
          }
        } else if (IsFP16AndBFP16(in_var_type) &&
                   op_run_low_precision_.count(op_type) == 0) {
          // Low-precision var feeding a full-precision op: cast back up.
          auto to_type = VarType::FP32;
          auto* prev_op =
              in_var_node->inputs.empty() ? nullptr : in_var_node->inputs[0];
          if (prev_op && GetOpOriginalType(prev_op->Op()->Type()) == "cast") {
            in_var_node->Var()->SetDataType(to_type);
            prev_op->Op()->SetAttr("out_dtype", static_cast<int>(to_type));
            prev_op->Op()->Flush();
          } else {
            DoInsertCastOp(subgraphes_[i],
                           in_var_node,
                           op_node,
                           in_var_type,
                           to_type,
                           block_desc,
                           &suffix,
                           &cache);
          }
        }
      }

      // Special op.
      // fused_multi_transformer's input(CacheKV) and output(CacheKVOut) vars
      // have same name.
      if (GetOpOriginalType(op_type) == "fused_multi_transformer") {
        auto cache_kv_inputs = op_node->Op()->Input("CacheKV");
        auto cache_kv_outputs = op_node->Op()->Output("CacheKVOut");
        CHECK_EQ(cache_kv_inputs.size(), cache_kv_outputs.size());
        // Fix: loop index renamed from `i` so it no longer shadows the
        // subgraph index of the outer loop.
        for (size_t idx = 0; idx < cache_kv_inputs.size(); ++idx) {
          op_node->Op()->RenameOutput(cache_kv_outputs[idx],
                                      cache_kv_inputs[idx]);
        }
      }
    }
  }
  VLOG(4) << "insert number of cast op: " << cache.size();
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
auto_mixed_precision_pass
,
paddle
::
framework
::
ir
::
AutoMixedPrecisionPass
);
paddle/fluid/framework/ir/auto_mixed_precision_pass.h
0 → 100644
View file @
dbe08e9b
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/phi/common/backend.h"
#include "paddle/phi/common/data_type.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
// Pass that rewrites a graph to run eligible ops at a low precision
// (float16 or bfloat16) and inserts the cast ops needed at precision
// boundaries. All state is mutable because ApplyImpl() is const.
class AutoMixedPrecisionPass : public FusePassBase {
 public:
  using VarType = framework::proto::VarType;

 public:
  AutoMixedPrecisionPass() = default;
  ~AutoMixedPrecisionPass() = default;

 protected:
  // Entry point of the pass; orchestrates the private steps below.
  void ApplyImpl(Graph* graph) const override;

 private:
  // Collects subgraphs / op nodes / real var nodes from `graph` into the
  // mutable members below.
  void Init(Graph* graph) const;
  // Populates black_list_ with default ops that must stay at full precision.
  void SetDefaultBlacklist() const;
  // Rewrites each op's type to a unique name (restored later) so per-op
  // decisions can be keyed by type string.
  void SetOpUniqueType() const;
  // Undoes SetOpUniqueType(), restoring the original op types.
  void RestoreOpOriginType() const;
  // Maps an op's unique type back to its original type.
  inline std::string GetOpOriginalType(const std::string& op_type) const;
  // Decides the initial precision for every op.
  void GetOpPrecision() const;
  // Propagates/refines the per-op precision decisions across the graph.
  void UpdateOpPrecision() const;
  // Inserts cast ops where an op's required input precision differs from
  // the producing var's precision.
  void InsertCastOp() const;
  // Fixes dtype attributes of ops that carry an explicit dtype attr.
  void ProcessOpWithDtypeAttr() const;
  // True if the named input var of `op_node` must NOT be converted.
  bool InputVarsNotConvert(Node* op_node, const std::string& var_name) const;
  // True if the named output var of `op_node` must NOT be converted.
  bool OutputVarsNotConvert(Node* op_node, const std::string& var_name) const;
  // Updates var dtypes to match the chosen op precisions.
  void SetVarPrecision() const;
  // Converts persistable weight tensors' data to the low precision.
  void ConvertWeightsData() const;

 private:
  // When true, ApplyImpl() exits without modifying the graph.
  mutable bool skip_pass_{false};
  // When true, graph inputs/outputs keep their original dtypes.
  mutable bool keep_io_types_{false};
  // float16 or bfloat16 now
  mutable phi::DataType low_precision_{phi::DataType::FLOAT16};
  mutable phi::Backend backend_{phi::Backend::GPU};
  // Op types that are never run at low precision.
  mutable std::unordered_set<std::string> black_list_;
  // subgraph id -> pointer to subgraph
  mutable std::vector<Graph*> subgraphes_;
  // var name -> real var node
  mutable std::unordered_map<std::string, Node*> real_vars_;
  // subgraph id -> all op nodes in subgraph
  mutable std::vector<std::vector<Node*>> all_op_nodes_;
  // op's unique type -> the op's origin type
  mutable std::unordered_map<std::string, std::string> op_original_type_;
  // op's unique type -> whether the op run at low precision
  mutable std::unordered_set<std::string> op_run_low_precision_;
  // Names of vars whose data was converted to the low precision.
  mutable std::unordered_set<std::string> vars_convert_to_low_precision_;
};
// Returns whether op `op_type` can run at `precision` on `backend`.
// Ops listed in `black_list` are excluded. (Declaration only; defined in
// auto_mixed_precision_pass.cc.)
bool OpSupportPrecision(const std::string& op_type,
                        phi::Backend backend,
                        phi::DataType precision,
                        const std::unordered_set<std::string>& black_list);
// Inserts a cast op into `graph` that converts `var_node` from `from_type`
// to `to_type` before it is consumed by `op_node`. `suffix` is used to make
// the inserted var/op names unique (incremented per insertion — TODO
// confirm), and `cache` maps an already-cast var node to its cast result so
// repeated consumers reuse one cast instead of inserting duplicates.
void DoInsertCastOp(Graph* graph,
                    Node* var_node,
                    Node* op_node,
                    proto::VarType::Type from_type,
                    proto::VarType::Type to_type,
                    framework::BlockDesc* block_desc,
                    int* suffix,
                    std::unordered_map<Node*, Node*>* cache);
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/constant_folding_pass.cc
View file @
dbe08e9b
...
...
@@ -142,6 +142,9 @@ void ConstantFoldingPass::ApplyImpl(ir::Graph *graph) const {
}
out_desc
->
SetShape
(
out_shape
);
out_desc
->
SetPersistable
(
true
);
auto
*
var_desc_out
=
op_node
->
Op
()
->
Block
()
->
Var
(
out_name
);
var_desc_out
->
SetShape
(
out_shape
);
var_desc_out
->
SetPersistable
(
true
);
auto
*
global_out_tensor
=
scope
->
Var
(
out_name
)
->
GetMutable
<
LoDTensor
>
();
*
global_out_tensor
=
*
local_out_tensor
;
}
...
...
paddle/fluid/framework/ir/delete_fill_constant_op_pass.cc
View file @
dbe08e9b
...
...
@@ -29,6 +29,11 @@ void FillConstData(LoDTensor* out_t, T value) {
}
void
DeleteFillConstantOpPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
bool
with_dynamic_shape
=
Get
<
bool
>
(
"with_dynamic_shape"
);
// Not support
if
(
with_dynamic_shape
)
{
return
;
}
FusePassBase
::
Init
(
"delete_fill_constant_op_pass"
,
graph
);
GraphPatternDetector
detector
;
auto
fill_constant_op
=
...
...
paddle/fluid/framework/ir/graph.cc
View file @
dbe08e9b
...
...
@@ -75,7 +75,6 @@ Graph::Graph(const ProgramDesc &program,
}
}
else
{
auto
var_nodes
=
InitFromProgram
(
program_
,
start_op_index
,
end_op_index
);
ResolveHazard
(
var_nodes
);
}
}
...
...
@@ -88,7 +87,6 @@ Graph::Graph(const BlockDesc &block,
const
int64_t
end_op_index
)
:
main_graph_
(
main_graph
)
{
auto
var_nodes
=
InitFromBlock
(
block
,
start_op_index
,
end_op_index
);
ResolveHazard
(
var_nodes
);
}
// TODO(levi): delete this interface after when we can convert all
...
...
paddle/fluid/framework/ir/graph_pattern_detector.cc
View file @
dbe08e9b
...
...
@@ -1045,6 +1045,7 @@ PDNode *patterns::SeqConvEltAddRelu::operator()(
PDNode
*
patterns
::
Squeeze2Transpose2
::
operator
()()
{
auto
*
squeeze2_op_in
=
pattern
->
NewNode
(
squeeze2_op_in_repr
())
->
AsInput
()
->
assert_has_n_outputs
(
1
)
->
assert_is_op_input
(
"squeeze2"
,
"X"
);
auto
*
squeeze2_op
=
pattern
->
NewNode
(
squeeze2_op_repr
())
->
assert_is_op
(
"squeeze2"
)
...
...
paddle/fluid/framework/ir/graph_test.cc
View file @
dbe08e9b
...
...
@@ -130,86 +130,6 @@ TEST(GraphTest, Basic) {
ASSERT_EQ
(
nodes
.
size
(),
5UL
);
}
TEST(GraphTest, WriteAfterRead) {
  // Build a program where "dummy" writes var "a" after "sum" reads it:
  //   b = sum(a);  a = dummy(c)
  // Graph construction must serialize the two ops through a shared
  // control-dependency variable.
  ProgramDesc prog;
  auto* op = prog.MutableBlock(0)->AppendOp();
  op->SetType("sum");
  op->SetInput("X", {"a"});
  op->SetOutput("Out", {"b"});
  op->SetAttr("op_role", 1);
  op = prog.MutableBlock(0)->AppendOp();
  op->SetType("dummy");
  op->SetInput("X", {"c"});
  op->SetOutput("Out", {"a"});
  op->SetAttr("op_role", 1);
  prog.MutableBlock(0)->Var("a")->SetType(proto::VarType::LOD_TENSOR);
  prog.MutableBlock(0)->Var("b")->SetType(proto::VarType::LOD_TENSOR);
  prog.MutableBlock(0)->Var("c")->SetType(proto::VarType::LOD_TENSOR);

  std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
  ir::Node* control_dep1 = nullptr;
  ir::Node* control_dep2 = nullptr;
  for (ir::Node* n : g->Nodes()) {
    if (n->Name() == "sum") {
      ASSERT_EQ(n->outputs[0]->Name(), "b");
      ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1]));
      control_dep1 = n->outputs[1];
      ASSERT_EQ(n->outputs.size(), 2UL);
    }
    if (n->Name() == "dummy") {
      ASSERT_EQ(n->inputs[0]->Name(), "c");
      ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1]));
      control_dep2 = n->inputs[1];
      ASSERT_EQ(n->inputs.size(), 2UL);
    }
  }
  // Guard against a vacuous pass: if neither branch above ran, both
  // pointers would still be nullptr and the final equality would succeed
  // without checking anything. The sibling WriteAfterWrite test already
  // carries these guards; keep the two tests consistent.
  ASSERT_NE(control_dep1, nullptr);
  ASSERT_NE(control_dep2, nullptr);
  // Both ops must reference the SAME control-dependency variable.
  ASSERT_EQ(control_dep1, control_dep2);
}
TEST(GraphTest, WriteAfterWrite) {
  // Two ops both write var "b": b = sum(a), then b = dummy(c).
  // The graph must connect them with a single shared control-dep var.
  ProgramDesc prog;
  auto append_op = [&prog](const std::string& type,
                           const std::string& in,
                           const std::string& out) {
    auto* op = prog.MutableBlock(0)->AppendOp();
    op->SetType(type);
    op->SetInput("X", {in});
    op->SetOutput("Out", {out});
    op->SetAttr("op_role", 1);
  };
  append_op("sum", "a", "b");
  append_op("dummy", "c", "b");
  for (const char* var_name : {"a", "b", "c"}) {
    prog.MutableBlock(0)->Var(var_name)->SetType(proto::VarType::LOD_TENSOR);
  }

  std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
  ir::Node* dep_from_writer = nullptr;  // control dep emitted by "sum"
  ir::Node* dep_to_writer = nullptr;    // control dep consumed by "dummy"
  for (ir::Node* node : g->Nodes()) {
    if (node->Name() == "sum") {
      ASSERT_EQ(node->outputs[0]->Name(), "b");
      ASSERT_TRUE(ir::IsControlDepVar(*node->outputs[1]));
      ASSERT_EQ(node->outputs.size(), 2UL);
      dep_from_writer = node->outputs[1];
    } else if (node->Name() == "dummy") {
      ASSERT_EQ(node->inputs[0]->Name(), "c");
      ASSERT_TRUE(ir::IsControlDepVar(*node->inputs[1]));
      dep_to_writer = node->inputs[1];
      ASSERT_EQ(node->inputs.size(), 2UL);
    }
  }
  // Both branches must have fired, and both ops must reference the same
  // control-dependency variable.
  ASSERT_NE(dep_from_writer, nullptr);
  ASSERT_NE(dep_to_writer, nullptr);
  ASSERT_EQ(dep_from_writer, dep_to_writer);
}
TEST
(
GraphTest
,
TestException
)
{
ProgramDesc
prog
;
std
::
unique_ptr
<
ir
::
Graph
>
g
(
new
ir
::
Graph
(
prog
));
...
...
@@ -350,12 +270,13 @@ TEST(GraphTest, TestMultiBlock) {
op
=
prog
.
MutableBlock
(
1
)
->
AppendOp
();
op
->
SetType
(
"dummy"
);
op
->
SetInput
(
"X"
,
{
"c"
});
op
->
SetOutput
(
"Out"
,
{
"
a
"
});
op
->
SetOutput
(
"Out"
,
{
"
d
"
});
op
->
SetAttr
(
"op_role"
,
1
);
prog
.
MutableBlock
(
1
)
->
Var
(
"a"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
prog
.
MutableBlock
(
1
)
->
Var
(
"b"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
prog
.
MutableBlock
(
1
)
->
Var
(
"c"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
prog
.
MutableBlock
(
1
)
->
Var
(
"d"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
// Set contents in block_2.
op
=
prog
.
MutableBlock
(
2
)
->
AppendOp
();
...
...
@@ -367,12 +288,13 @@ TEST(GraphTest, TestMultiBlock) {
op
=
prog
.
MutableBlock
(
2
)
->
AppendOp
();
op
->
SetType
(
"dummy"
);
op
->
SetInput
(
"X"
,
{
"c"
});
op
->
SetOutput
(
"Out"
,
{
"
b
"
});
op
->
SetOutput
(
"Out"
,
{
"
d
"
});
op
->
SetAttr
(
"op_role"
,
1
);
prog
.
MutableBlock
(
2
)
->
Var
(
"a"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
prog
.
MutableBlock
(
2
)
->
Var
(
"b"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
prog
.
MutableBlock
(
2
)
->
Var
(
"c"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
prog
.
MutableBlock
(
1
)
->
Var
(
"d"
)
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
// Step2: Convert program into graph, 3 blocks corresponding 3 sub_graphs.
std
::
unique_ptr
<
ir
::
Graph
>
g
(
new
ir
::
Graph
(
prog
));
...
...
@@ -399,45 +321,29 @@ TEST(GraphTest, TestMultiBlock) {
// Check contents in sub_graph_1.
const
ir
::
Graph
*
g1
=
g
->
GetSubGraph
(
1
);
ir
::
Node
*
control_dep1
=
nullptr
;
ir
::
Node
*
control_dep2
=
nullptr
;
for
(
ir
::
Node
*
n
:
g1
->
Nodes
())
{
if
(
n
->
Name
()
==
"sum"
)
{
ASSERT_EQ
(
n
->
outputs
[
0
]
->
Name
(),
"b"
);
ASSERT_TRUE
(
ir
::
IsControlDepVar
(
*
n
->
outputs
[
1
]));
control_dep1
=
n
->
outputs
[
1
];
ASSERT_EQ
(
n
->
outputs
.
size
(),
2UL
);
ASSERT_EQ
(
n
->
outputs
.
size
(),
1UL
);
}
if
(
n
->
Name
()
==
"dummy"
)
{
ASSERT_EQ
(
n
->
inputs
[
0
]
->
Name
(),
"c"
);
ASSERT_TRUE
(
ir
::
IsControlDepVar
(
*
n
->
inputs
[
1
]));
control_dep2
=
n
->
inputs
[
1
];
ASSERT_EQ
(
n
->
inputs
.
size
(),
2UL
);
ASSERT_EQ
(
n
->
inputs
.
size
(),
1UL
);
}
}
ASSERT_EQ
(
control_dep1
,
control_dep2
);
// Check contents in sub_graph_2.
const
ir
::
Graph
*
g2
=
g
->
GetSubGraph
(
2
);
control_dep1
=
nullptr
;
control_dep2
=
nullptr
;
for
(
ir
::
Node
*
n
:
g2
->
Nodes
())
{
if
(
n
->
Name
()
==
"sum"
)
{
ASSERT_EQ
(
n
->
outputs
[
0
]
->
Name
(),
"b"
);
ASSERT_TRUE
(
ir
::
IsControlDepVar
(
*
n
->
outputs
[
1
]));
ASSERT_EQ
(
n
->
outputs
.
size
(),
2UL
);
control_dep1
=
n
->
outputs
[
1
];
ASSERT_EQ
(
n
->
outputs
.
size
(),
1UL
);
}
if
(
n
->
Name
()
==
"dummy"
)
{
ASSERT_EQ
(
n
->
inputs
[
0
]
->
Name
(),
"c"
);
ASSERT_TRUE
(
ir
::
IsControlDepVar
(
*
n
->
inputs
[
1
]));
control_dep2
=
n
->
inputs
[
1
];
ASSERT_EQ
(
n
->
inputs
.
size
(),
2UL
);
ASSERT_EQ
(
n
->
inputs
.
size
(),
1UL
);
}
}
ASSERT_NE
(
control_dep1
,
nullptr
);
ASSERT_NE
(
control_dep2
,
nullptr
);
ASSERT_EQ
(
control_dep1
,
control_dep2
);
// Step3: Clone graph.
std
::
shared_ptr
<
ir
::
Graph
>
clone_g
=
g
->
Clone
();
...
...
paddle/fluid/framework/ir/multi_batch_merge_pass.cc
View file @
dbe08e9b
...
...
@@ -331,8 +331,6 @@ void BatchMergePass::ApplyImpl(ir::Graph* graph) const {
copy_node
(
node
);
}
}
result
.
ResolveHazard
(
created
);
}
}
// namespace ir
...
...
paddle/fluid/framework/naive_executor.cc
View file @
dbe08e9b
...
...
@@ -183,5 +183,6 @@ void NaiveExecutor::ResetTrtOps(int num) {
}
#endif
}
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/new_executor/standalone_executor_test.cc
View file @
dbe08e9b
...
...
@@ -50,7 +50,7 @@ USE_OP_ITSELF(concat_grad);
USE_OP_ITSELF
(
elementwise_mul_grad
);
USE_OP_ITSELF
(
sigmoid_grad
);
USE_OP_ITSELF
(
tanh_grad
);
USE_OP
(
sum
);
USE_OP
_ITSELF
(
sum
);
USE_OP_ITSELF
(
slice_grad
);
USE_OP_ITSELF
(
lookup_table_grad
);
USE_OP_ITSELF
(
sqrt
);
...
...
@@ -101,6 +101,7 @@ PD_DECLARE_KERNEL(slice_grad, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL
(
cross_entropy_with_softmax
,
GPU
,
ALL_LAYOUT
);
PD_DECLARE_KERNEL
(
cross_entropy_with_softmax_grad
,
GPU
,
ALL_LAYOUT
);
PD_DECLARE_KERNEL
(
sqrt
,
GPU
,
ALL_LAYOUT
);
PD_DECLARE_KERNEL
(
add_n
,
GPU
,
ALL_LAYOUT
);
namespace
paddle
{
namespace
framework
{
...
...
paddle/fluid/framework/operator.h
View file @
dbe08e9b
...
...
@@ -512,6 +512,13 @@ class ExecutionArgumentMappingContext : public phi::ArgumentMappingContext {
});
}
  // Returns true iff every variable bound to input slot `name` holds a
  // phi::SelectedRows. Note: vacuously true when the slot has no variables
  // (std::all_of over an empty range).
  bool IsSelectedRowsInputs(const std::string& name) const override {
    auto vars = ctx_.MultiInputVar(name);
    return std::all_of(vars.begin(), vars.end(), [](const Variable* var) {
      return var->IsType<phi::SelectedRows>();
    });
  }
bool
IsSelectedRowsInput
(
const
std
::
string
&
name
)
const
override
{
const
auto
*
var
=
ctx_
.
InputVar
(
name
);
return
var
->
IsType
<
phi
::
SelectedRows
>
();
...
...
paddle/fluid/imperative/prepared_operator.cc
View file @
dbe08e9b
...
...
@@ -146,6 +146,48 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op,
kernel_signature_
(
std
::
move
(
kernel_signature
)),
phi_kernel_
(
phi_kernel
)
{}
#ifdef PADDLE_WITH_MLU
// Splits `ops` on `delim` and inserts every piece into `op_set`.
// Matches the original behavior exactly: empty pieces (leading, trailing,
// or between consecutive delimiters) are inserted too, and an empty input
// inserts a single empty string.
static void tokenize(const std::string& ops,
                     char delim,
                     std::unordered_set<std::string>* op_set) {
  std::string::size_type beg = 0;
  // Use std::string::size_type for the search cursor (the original used
  // uint64_t, whose comparison against std::string::npos is only exact on
  // platforms where size_type happens to be 64 bits).
  for (std::string::size_type end = 0;
       (end = ops.find(delim, end)) != std::string::npos;
       ++end) {
    op_set->insert(ops.substr(beg, end - beg));
    beg = end + 1;
  }
  // Tail piece after the last delimiter (or the whole string if none).
  op_set->insert(ops.substr(beg));
}
// Returns true if `op_name` is listed in the comma-separated MLU_BLACK_LIST
// environment variable (ops forced off the MLU device).
//
// The black list is built exactly once. The original hand-rolled
// double-checked locking read the `inited` flag without synchronization,
// which is a data race; a C++11 function-local static (thread-safe "magic
// static" initialization) gives the same once-only semantics safely and
// drops the mutex entirely.
static bool is_in_mlu_black_list(const std::string& op_name) {
  static const std::unordered_set<std::string> mlu_black_list = [] {
    std::unordered_set<std::string> black_list;
    if (std::getenv("MLU_BLACK_LIST") != nullptr) {
      std::string ops(std::getenv("MLU_BLACK_LIST"));
      tokenize(ops, ',', &black_list);
    }
    VLOG(3) << "MLU Black List: ";
    for (auto iter = black_list.begin(); iter != black_list.end(); ++iter) {
      VLOG(3) << *iter << " ";
    }
    return black_list;
  }();
  return mlu_black_list.find(op_name) != mlu_black_list.end();
}
#endif
template
<
typename
VarType
>
PreparedOp
PrepareImpl
(
const
NameVarMap
<
VarType
>&
ins
,
...
...
@@ -194,6 +236,12 @@ PreparedOp PrepareImpl(
#endif
#ifdef PADDLE_WITH_MLU
if
(
is_in_mlu_black_list
(
op
.
Type
()))
{
expected_kernel_key
.
place_
=
platform
::
CPUPlace
();
}
#endif
bool
has_phi_kernel
=
false
;
const
auto
*
arg_map_fn
=
phi_op_utils_map
.
GetArgumentMappingFn
(
op
.
Type
());
...
...
paddle/fluid/inference/analysis/analyzer.cc
View file @
dbe08e9b
...
...
@@ -38,8 +38,7 @@ void Analyzer::RunAnalysis(Argument *argument) {
if
(
!
disable_logs
)
{
string
::
PrettyLogH1
(
"--- Running analysis [%s]"
,
pass
);
}
if
(
!
argument
->
enable_analysis_optim
()
&&
pass
==
"ir_analysis_pass"
)
continue
;
if
(
!
argument
->
enable_ir_optim
()
&&
pass
==
"ir_analysis_pass"
)
continue
;
auto
*
ptr
=
PassRegistry
::
Global
().
Retreive
(
pass
);
PADDLE_ENFORCE_NOT_NULL
(
ptr
,
...
...
paddle/fluid/inference/analysis/analyzer_tester.cc
View file @
dbe08e9b
...
...
@@ -31,7 +31,7 @@ TEST(Analyzer, analysis_without_tensorrt) {
Argument
argument
;
argument
.
SetDisableLogs
(
false
);
argument
.
SetModelDir
(
FLAGS_inference_model_dir
);
argument
.
SetEnable
Analysis
Optim
(
false
);
argument
.
SetEnable
Ir
Optim
(
false
);
argument
.
SetUseGPU
(
false
);
argument
.
SetAnalysisPasses
({
"ir_graph_build_pass"
,
"ir_analysis_pass"
,
...
...
@@ -44,7 +44,7 @@ TEST(Analyzer, analysis_without_tensorrt) {
TEST
(
Analyzer
,
analysis_with_tensorrt
)
{
Argument
argument
;
argument
.
SetDisableLogs
(
false
);
argument
.
SetEnable
Analysis
Optim
(
false
);
argument
.
SetEnable
Ir
Optim
(
false
);
argument
.
SetTensorRtMaxBatchSize
(
3
);
argument
.
SetTensorRtWorkspaceSize
(
1
<<
20
);
argument
.
SetModelDir
(
FLAGS_inference_model_dir
);
...
...
paddle/fluid/inference/analysis/argument.h
View file @
dbe08e9b
...
...
@@ -42,8 +42,6 @@ namespace paddle {
namespace
inference
{
namespace
analysis
{
using
framework
::
ir
::
Graph
;
#ifdef PADDLE_WITH_MKLDNN
using
VarQuantScale
=
std
::
unordered_map
<
std
::
string
,
std
::
pair
<
bool
,
framework
::
LoDTensor
>>
;
...
...
@@ -148,7 +146,7 @@ struct Argument {
DECL_ARGUMENT_FIELD
(
model_params_path
,
ModelParamsPath
,
std
::
string
);
DECL_ARGUMENT_FIELD
(
model_from_memory
,
ModelFromMemory
,
bool
);
DECL_ARGUMENT_FIELD
(
optim_cache_dir
,
OptimCacheDir
,
std
::
string
);
DECL_ARGUMENT_FIELD
(
enable_
analysis
_optim
,
Enable
Analysis
Optim
,
bool
);
DECL_ARGUMENT_FIELD
(
enable_
ir
_optim
,
Enable
Ir
Optim
,
bool
);
// For JITLayer
DECL_ARGUMENT_FIELD
(
skip_load_params
,
SkipLoadParams
,
bool
);
...
...
@@ -362,6 +360,8 @@ struct Argument {
DECL_ARGUMENT_FIELD
(
mixed_black_list
,
MixedBlackList
,
std
::
unordered_set
<
std
::
string
>
);
DECL_ARGUMENT_FIELD
(
enable_gpu_mixed
,
EnableGPUMixed
,
bool
);
DECL_ARGUMENT_FIELD
(
mixed_precision_mode
,
MixedPrecisionMode
,
int
);
private:
std
::
unordered_set
<
std
::
string
>
valid_fields_
;
...
...
paddle/fluid/inference/analysis/helper.h
View file @
dbe08e9b
...
...
@@ -153,25 +153,6 @@ T &GetFromScope(const framework::Scope &scope, const std::string &name) {
return
*
var
->
GetMutable
<
T
>
();
}
// Reads the binary protobuf file at `model_path` and parses it into a
// framework::proto::ProgramDesc.
// Raises (via PADDLE_ENFORCE_EQ) NotFound if the file cannot be opened and
// InvalidArgument if its contents do not parse as a ProgramDesc.
static framework::proto::ProgramDesc LoadProgramDesc(
    const std::string &model_path) {
  std::ifstream fin(model_path, std::ios::in | std::ios::binary);
  PADDLE_ENFORCE_EQ(
      fin.is_open(),
      true,
      platform::errors::NotFound(
          "Cannot open file %s, please confirm whether the file exists",
          model_path));
  // Slurp the whole file: size the buffer from the end offset, then read
  // from the beginning. static_cast avoids the implicit streampos->size_t
  // narrowing the original relied on.
  fin.seekg(0, std::ios::end);
  std::string buffer(static_cast<size_t>(fin.tellg()), ' ');
  fin.seekg(0, std::ios::beg);
  fin.read(&buffer[0], buffer.size());
  fin.close();
  framework::proto::ProgramDesc program_desc;
  // Fail loudly on a corrupt/truncated model file instead of silently
  // returning a partially-parsed ProgramDesc (the original discarded the
  // bool returned by ParseFromString).
  PADDLE_ENFORCE_EQ(
      program_desc.ParseFromString(buffer),
      true,
      platform::errors::InvalidArgument(
          "Failed to parse program desc loaded from file %s.", model_path));
  return program_desc;
}
static
bool
FileExists
(
const
std
::
string
&
filepath
)
{
std
::
ifstream
file
(
filepath
);
bool
exists
=
file
.
is_open
();
...
...
Prev
1
2
3
4
5
6
…
16
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment