Commit dbe08e9b authored by yuguo960516yuguo's avatar yuguo960516yuguo
Browse files

2.4.2

parent b5499578
......@@ -28,7 +28,7 @@ class DLPackTensor {
std::remove_reference<decltype(::DLTensor::shape[0])>::type; // int64_t
// lanes is only used in CPU to enable vectorization
explicit DLPackTensor(const Tensor& tensor, LaneType lanes = 1);
explicit DLPackTensor(const phi::DenseTensor& tensor, LaneType lanes = 1);
inline operator const ::DLTensor&() const { return t_; }
......@@ -44,5 +44,7 @@ class DLPackTensor {
ShapeType shape_[DDim::kMaxRank];
};
DLManagedTensor* toDLPack(const phi::DenseTensor& src);
} // namespace framework
} // namespace paddle
......@@ -87,6 +87,15 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext {
});
}
bool IsSelectedRowsInputs(const std::string& name) const override {
auto var_types = ctx_.GetInputsVarType(name);
return std::all_of(var_types.begin(),
var_types.end(),
[](const proto::VarType::Type& type) {
return type == proto::VarType::SELECTED_ROWS;
});
}
bool IsSelectedRowsInput(const std::string& name) const override {
auto var_type = ctx_.GetInputVarType(name);
return var_type == proto::VarType::SELECTED_ROWS;
......@@ -155,6 +164,16 @@ int64_t CompatMetaTensor::numel() const {
}
}
bool CompatMetaTensor::is_selected_rows() const {
if (is_runtime_) {
auto* var = PADDLE_GET_CONST(Variable*, var_);
return var->IsType<phi::SelectedRows>();
} else {
auto* var = PADDLE_GET_CONST(VarDesc*, var_);
return var->GetType() == proto::VarType::SELECTED_ROWS;
}
}
bool CompatMetaTensor::is_dense() const {
if (is_runtime_) {
auto* var = PADDLE_GET_CONST(Variable*, var_);
......@@ -182,7 +201,7 @@ DDim CompatMetaTensor::dims() const {
if (var->IsType<phi::DenseTensor>()) {
return var->Get<phi::DenseTensor>().dims();
} else if (var->IsType<phi::SelectedRows>()) {
return var->Get<phi::SelectedRows>().dims();
return var->Get<phi::SelectedRows>().GetCompleteDims();
} else if (var->IsType<phi::SparseCooTensor>()) {
return var->Get<phi::SparseCooTensor>().dims();
} else if (var->IsType<framework::LoDTensorArray>()) {
......@@ -260,8 +279,7 @@ void CompatMetaTensor::set_dims(const DDim& dims) {
auto* tensor = var->GetMutable<phi::DenseTensor>();
phi::DenseTensorUtils::GetMutableMeta(tensor)->dims = dims;
} else if (var->IsType<phi::SelectedRows>()) {
auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
phi::DenseTensorUtils::GetMutableMeta(tensor)->dims = dims;
var->GetMutable<phi::SelectedRows>()->set_height(dims[0]);
} else if (var->IsType<phi::SparseCooTensor>()) {
auto* tensor = var->GetMutable<phi::SparseCooTensor>();
phi::DenseTensorUtils::GetMutableMeta(tensor)->dims = dims;
......
......@@ -59,6 +59,8 @@ class CompatMetaTensor : public phi::MetaTensor {
bool initialized() const override { return initialized_; };
bool is_selected_rows() const;
bool is_tensor_array() const;
bool is_dense() const;
......
......@@ -148,6 +148,7 @@ pass_library(delete_c_identity_op_pass inference)
pass_library(preln_residual_bias_fuse_pass inference)
pass_library(delete_fill_constant_op_pass inference)
pass_library(constant_folding_pass inference)
pass_library(auto_mixed_precision_pass inference)
pass_library(simplify_with_basic_ops_pass base)
pass_library(fc_elementwise_layernorm_fuse_pass base)
pass_library(skip_layernorm_fuse_pass base)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/auto_mixed_precision_pass.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"
namespace paddle {
namespace framework {
namespace ir {
namespace {
using VarType = AutoMixedPrecisionPass::VarType;
bool PhiKernelSupportPrecision(
const std::string& op_type,
phi::Backend backend,
phi::DataType data_type,
phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) {
const auto& kernels = phi::KernelFactory::Instance().kernels();
if (kernels.count(op_type) == 0) {
return false;
}
phi::KernelKey kernel_key(backend, layout, data_type);
return phi::KernelFactory::Instance().HasKernel(op_type, kernel_key);
}
bool GpuKernelSupportPrecision(
const std::string& op_type,
phi::DataType precision,
phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) {
auto phi_op_type = phi::TransToPhiKernelName(op_type);
bool support = PhiKernelSupportPrecision(
phi_op_type, phi::Backend::GPU, precision, layout);
support |= PhiKernelSupportPrecision(
phi_op_type, phi::Backend::GPUDNN, precision, layout);
if (!support) {
const auto& all_kernels = framework::OperatorWithKernel::AllOpKernels();
auto it = all_kernels.find(op_type);
if (it != all_kernels.end()) {
for (const auto& kern_pair : it->second) {
if (platform::is_gpu_place(kern_pair.first.place_) &&
kern_pair.first.data_type_ ==
framework::TransToProtoVarType(precision)) {
support = true;
break;
}
}
}
}
return support;
}
inline bool VarNodeHasDtype(Node* var_node) {
auto type = var_node->Var()->GetType();
return (type == VarType::SELECTED_ROWS) || (type == VarType::LOD_TENSOR) ||
(type == VarType::LOD_TENSOR_ARRAY) || (type == VarType::STRINGS) ||
(type == VarType::VOCAB);
}
inline bool IsFP32AndFP64(VarType::Type type) {
return (type == VarType::FP64) || (type == VarType::FP32);
}
inline bool IsFP16AndBFP16(VarType::Type type) {
return (type == VarType::FP16) || (type == VarType::BF16);
}
}; // namespace
void DoInsertCastOp(Graph* graph,
Node* var_node,
Node* op_node,
VarType::Type from_type,
VarType::Type to_type,
framework::BlockDesc* block_desc,
int* suffix,
std::unordered_map<Node*, Node*>* cache) {
if (from_type == to_type) return;
auto update_cast_desc = [&](framework::OpDesc& desc,
const std::string& x_name,
const std::string& out_name,
const int in_dtype,
const int out_dtype) {
desc.SetType("cast");
desc.SetInput("X", {x_name});
desc.SetOutput("Out", {out_name});
desc.SetAttr("in_dtype", in_dtype);
desc.SetAttr("out_dtype", out_dtype);
desc.SetAttr("use_mkldnn", false);
desc.SetAttr("with_quant_attr", false);
desc.Flush();
};
if (cache->count(var_node) == 0) {
// insert cast op between var_node and op_node
std::string cast_input_name = var_node->Var()->Name();
std::string cast_output_name =
var_node->Var()->Name() + "_cast.tmp_" + std::to_string((*suffix)++);
framework::OpDesc cast_op_desc(block_desc);
update_cast_desc(cast_op_desc,
cast_input_name,
cast_output_name,
static_cast<int>(from_type),
static_cast<int>(to_type));
auto* cast_op_node = graph->CreateOpNode(&cast_op_desc);
auto* cast_output_vardesc = block_desc->Var(cast_output_name);
cast_output_vardesc->SetPersistable(false);
cast_output_vardesc->SetDataType(to_type);
cast_output_vardesc->SetShape(var_node->Var()->GetShape());
auto* cast_output_node = graph->CreateVarNode(cast_output_vardesc);
IR_NODE_LINK_TO(cast_op_node, cast_output_node);
(*cache)[var_node] = cast_output_node;
}
op_node->Op()->Rename(var_node->Name(), cache->at(var_node)->Name());
IR_NODE_LINK_TO(var_node, cache->at(var_node)->inputs[0]);
IR_NODE_LINK_TO(cache->at(var_node), op_node);
IR_NODE_UNLINK(var_node, op_node);
}
bool OpSupportPrecision(const std::string& op_type,
phi::Backend backend,
phi::DataType precision,
const std::unordered_set<std::string>& black_list) {
bool support = false;
if (black_list.count(op_type) == 0) {
if (backend == phi::Backend::GPU) {
support = GpuKernelSupportPrecision(op_type, precision);
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"Now, only support backend of GPU."));
}
}
return support;
}
// The set of ops that support fp16 calculation and are considered
// numerically-dangerous, slower and whose effects may also be observed in
// downstream ops.
// ref to python/paddle/fluid/contrib/mixed_precision/fp16_lists.py
void AutoMixedPrecisionPass::SetDefaultBlacklist() const {
black_list_.insert({
// numerically-dangerous
"exp",
"square",
"log",
"mean",
"sum",
"cos_sim",
"softmax_with_cross_entropy",
"sigmoid_cross_entropy_with_logits",
"c_softmax_with_cross_entropy",
"cross_entropy",
"cross_entropy2",
// slower than fp32
"conv2d_transpose",
// default fp32 can avoid return inf when the sum value large than 65504
"reduce_sum",
});
}
void AutoMixedPrecisionPass::Init(Graph* graph) const {
bool enable_gpu_mixed = Get<bool>("enable_gpu_mixed");
if (enable_gpu_mixed) {
backend_ = phi::Backend::GPU;
}
skip_pass_ = !enable_gpu_mixed;
low_precision_ = static_cast<phi::DataType>(Get<int>("mixed_precision_mode"));
black_list_ = Get<std::unordered_set<std::string>>("mixed_black_list");
SetDefaultBlacklist();
VLOG(4) << "black_list has ";
for (const auto& name : black_list_) {
VLOG(4) << " - " << name;
}
keep_io_types_ = true;
if (Has("keep_io_types")) {
keep_io_types_ = Get<bool>("keep_io_types");
}
auto graph_size = graph->SubGraphsSize();
VLOG(4) << "graph size: " << graph_size;
subgraphes_.resize(graph_size);
all_op_nodes_.resize(graph_size);
for (size_t i = 0; i < graph_size; i++) {
subgraphes_[i] = graph->GetSubGraph(i);
all_op_nodes_[i] = TopologySortOperations(*subgraphes_[i]);
VLOG(4) << "subgraph " << i << " has " << all_op_nodes_[i].size()
<< "op nodes";
for (auto* var_node : subgraphes_[i]->Nodes()) {
if (!var_node->IsVar()) continue;
auto var_name = var_node->Var()->Name();
if (real_vars_.count(var_name) == 0) {
real_vars_[var_name] = var_node;
VLOG(4) << var_name << " is in graph " << i;
}
}
}
}
void AutoMixedPrecisionPass::ApplyImpl(Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(graph,
platform::errors::PreconditionNotMet(
"During the auto_mixed_precision_pass, the graph "
"should not be nullptr."));
PADDLE_ENFORCE_EQ(graph->IsMainGraph(),
true,
platform::errors::PreconditionNotMet(
"During the auto_mixed_precision_pass, the graph "
"should be main graph."));
FusePassBase::Init("auto_mixed_precision", graph);
Init(graph);
VLOG(4) << "Init done";
if (skip_pass_) {
VLOG(3) << "Skip auto_mixed_precision_pass.";
return;
}
SetOpUniqueType();
VLOG(4) << "SetOpUniqueType done";
GetOpPrecision();
VLOG(4) << "GetOpPrecision done";
UpdateOpPrecision();
VLOG(4) << "UpdateOpPrecision done";
SetVarPrecision();
VLOG(4) << "SetVarPrecision done";
ConvertWeightsData();
VLOG(4) << "ConvertWeightsData done";
ProcessOpWithDtypeAttr();
VLOG(4) << "ProcessOpWithDtypeAttr done";
InsertCastOp();
VLOG(4) << "InsertCastOp done";
RestoreOpOriginType();
VLOG(4) << "RestoreOpOriginType done";
LOG(INFO) << "The number of ops run at low precision ["
<< op_run_low_precision_.size() << "/" << op_original_type_.size()
<< "]";
}
void AutoMixedPrecisionPass::SetOpUniqueType() const {
int suffix = 0;
for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) {
auto op_type = op_node->Op()->Type();
if (op_type == "feed" || op_type == "fetch") continue;
std::string unique_type = op_type + "_" + std::to_string(suffix++);
op_original_type_[unique_type] = op_type;
op_node->Op()->SetType(unique_type);
op_node->Op()->Flush();
VLOG(4) << "change op type: " << op_type << " ---> " << unique_type;
}
}
}
void AutoMixedPrecisionPass::RestoreOpOriginType() const {
for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) {
auto op_type = op_node->Op()->Type();
op_node->Op()->SetType(GetOpOriginalType(op_type));
op_node->Op()->Flush();
VLOG(4) << "restore op type: " << op_type << " ---> "
<< op_node->Op()->Type();
}
}
}
inline std::string AutoMixedPrecisionPass::GetOpOriginalType(
const std::string& op_type) const {
if (op_original_type_.count(op_type)) {
return op_original_type_.at(op_type);
}
return op_type;
}
void AutoMixedPrecisionPass::ProcessOpWithDtypeAttr() const {
for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) {
auto op_type = op_node->Op()->Type();
if (op_node->Op()->HasAttr("in_dtype")) {
auto* var_node = op_node->inputs[0];
auto* real_var_node = real_vars_[var_node->Var()->Name()];
if (IsFP16AndBFP16(real_var_node->Var()->GetDataType())) {
op_node->Op()->SetAttr(
"in_dtype",
static_cast<int>(framework::TransToProtoVarType(low_precision_)));
op_node->Op()->Flush();
VLOG(4) << "process op with in_dtype attr: " << op_type << " ( "
<< static_cast<int>(real_var_node->Var()->GetDataType())
<< " --->" << static_cast<int>(low_precision_) << " )";
}
}
if (op_run_low_precision_.count(op_type) == 0) continue;
if (op_node->Op()->HasAttr("dtype")) {
auto dtype = op_node->Op()->GetAttrIfExists<int>("dtype");
if (IsFP32AndFP64(static_cast<VarType::Type>(dtype))) {
op_node->Op()->SetAttr(
"dtype",
static_cast<int>(framework::TransToProtoVarType(low_precision_)));
op_node->Op()->Flush();
VLOG(4) << "process op with dtype attr: " << op_type << " ( " << dtype
<< " --->" << static_cast<int>(low_precision_) << " )";
}
} else if (op_node->Op()->HasAttr("out_dtype")) {
auto out_dtype = op_node->Op()->GetAttrIfExists<int>("out_dtype");
if (IsFP32AndFP64(static_cast<VarType::Type>(out_dtype))) {
op_node->Op()->SetAttr(
"out_dtype",
static_cast<int>(framework::TransToProtoVarType(low_precision_)));
op_node->Op()->Flush();
VLOG(4) << "process op with out_dtype attr: " << op_type << " ( "
<< out_dtype << " --->" << static_cast<int>(low_precision_)
<< " )";
}
}
}
}
}
void AutoMixedPrecisionPass::GetOpPrecision() const {
for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) {
auto op_type = op_node->Op()->Type();
bool support_low_precision = true;
if (GetOpOriginalType(op_type) == "feed" ||
GetOpOriginalType(op_type) == "fetch") {
support_low_precision = !keep_io_types_;
} else {
support_low_precision = OpSupportPrecision(
GetOpOriginalType(op_type), backend_, low_precision_, black_list_);
}
if (op_node->Op()->HasAttr("dtype")) {
auto dtype = op_node->Op()->GetAttrIfExists<int>("dtype");
support_low_precision =
support_low_precision &&
IsFP32AndFP64(static_cast<VarType::Type>(dtype));
} else if (op_node->Op()->HasAttr("out_dtype")) {
auto out_dtype = op_node->Op()->GetAttrIfExists<int>("out_dtype");
support_low_precision =
support_low_precision &&
IsFP32AndFP64(static_cast<VarType::Type>(out_dtype));
}
// If scale op's "scale" and "bias" attr value exceed the range of fp16
// and bf16, it cannot run at low precision.
if (GetOpOriginalType(op_node->Op()->Type()) == "scale") {
auto scale = op_node->Op()->GetAttrIfExists<float>("scale");
auto bias = op_node->Op()->GetAttrIfExists<float>("bias");
if (low_precision_ == phi::DataType::FLOAT16) {
support_low_precision =
support_low_precision &&
phi::dtype::isfinite(static_cast<phi::dtype::float16>(scale)) &&
phi::dtype::isfinite(static_cast<phi::dtype::float16>(bias));
} else if (low_precision_ == phi::DataType::BFLOAT16) {
support_low_precision =
support_low_precision &&
phi::dtype::isfinite(static_cast<phi::dtype::bfloat16>(scale)) &&
phi::dtype::isfinite(static_cast<phi::dtype::bfloat16>(bias));
}
}
// if op's input var and output var is not dense tensor, the op should
// not run at low precision.
for (auto* in_var_node : op_node->inputs) {
CHECK_EQ(in_var_node->IsVar(), true);
auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
if (real_in_var_node->Var()->Persistable()) continue;
support_low_precision =
support_low_precision &&
(real_in_var_node->Var()->GetType() == VarType::LOD_TENSOR);
}
for (auto* out_var_node : op_node->outputs) {
CHECK_EQ(out_var_node->IsVar(), true);
auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
if (real_out_var_node->Var()->Persistable()) continue;
support_low_precision =
support_low_precision &&
(real_out_var_node->Var()->GetType() == VarType::LOD_TENSOR);
}
if (support_low_precision) {
op_run_low_precision_.insert(op_type);
VLOG(4) << "support precision: " << op_type << " run at low precision";
} else {
VLOG(4) << "support precision: " << op_type
<< " not run at low precision";
}
}
}
}
void AutoMixedPrecisionPass::UpdateOpPrecision() const {
std::unordered_set<std::string> vars_should_not_low_precision;
// var -> the var's all input op
std::unordered_map<std::string, std::vector<Node*>> var_input_ops;
auto GetVarInputOps = [&] {
for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) {
auto op_type = op_node->Op()->Type();
if (GetOpOriginalType(op_type) == "fetch") continue;
if (op_node->Op()->HasAttr("sub_block")) continue;
for (auto* var_node : op_node->outputs) {
CHECK_EQ(var_node->IsVar(), true);
if (var_node->Var()->Persistable()) continue;
if (!VarNodeHasDtype(var_node)) continue;
var_input_ops[var_node->Var()->Name()].push_back(op_node);
VLOG(4) << "var input ops: " << var_node->Var()->Name()
<< " is output of " << op_type;
}
// the select_input op's input var should not convert to low precision.
// when op's output var is select_input op's input var, the op should
// not run at low precision.
if (GetOpOriginalType(op_node->Op()->Type()) == "select_input") {
for (auto* in_var_node : op_node->inputs) {
CHECK_EQ(in_var_node->IsVar(), true);
if (in_var_node->Var()->Persistable()) continue;
if (!VarNodeHasDtype(in_var_node)) continue;
vars_should_not_low_precision.insert(in_var_node->Var()->Name());
}
}
// when op_1 only support cpu kernel. if op_2's intput var is op_1's
// output var, then op_2 should not run at low precision.
if (GetOpOriginalType(op_type) != "feed" &&
!GpuKernelSupportPrecision(GetOpOriginalType(op_type),
phi::DataType::FLOAT32)) {
for (auto* out_var_node : op_node->outputs) {
CHECK_EQ(out_var_node->IsVar(), true);
if (out_var_node->Var()->Persistable()) continue;
if (!VarNodeHasDtype(out_var_node)) continue;
vars_should_not_low_precision.insert(out_var_node->Var()->Name());
}
}
}
}
};
GetVarInputOps();
bool precision_updated = false;
do {
precision_updated = false;
for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) {
if (op_run_low_precision_.count(op_node->Op()->Type()) == 0) continue;
for (auto* in_var_node : op_node->inputs) {
CHECK_EQ(in_var_node->IsVar(), true);
if (!VarNodeHasDtype(in_var_node)) continue;
auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
if (real_in_var_node->Var()->Persistable()) continue;
if (vars_should_not_low_precision.count(
real_in_var_node->Var()->Name())) {
op_run_low_precision_.erase(op_node->Op()->Type());
precision_updated = true;
VLOG(4) << op_node->Op()->Type()
<< " should not run at low precision.";
break;
}
}
if (op_run_low_precision_.count(op_node->Op()->Type()) == 0) continue;
for (auto* out_var_node : op_node->outputs) {
CHECK_EQ(out_var_node->IsVar(), true);
if (!VarNodeHasDtype(out_var_node)) continue;
auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
if (real_out_var_node->Var()->Persistable()) continue;
bool not_run_low_precision = false;
const auto& input_op_nodes =
var_input_ops[real_out_var_node->Var()->Name()];
if (vars_should_not_low_precision.count(
real_out_var_node->Var()->Name())) {
not_run_low_precision = true;
} else {
for (auto* node : input_op_nodes) {
if (op_run_low_precision_.count(node->Op()->Type()) == 0) {
not_run_low_precision = true;
break;
}
}
}
if (not_run_low_precision) {
op_run_low_precision_.erase(op_node->Op()->Type());
precision_updated = true;
VLOG(4) << op_node->Op()->Type()
<< " should not run at low precision.";
break;
}
}
}
}
} while (precision_updated);
}
// special ops, its weights should not be low precision.
bool AutoMixedPrecisionPass::InputVarsNotConvert(
Node* op_node, const std::string& var_name) const {
auto* op_desc = op_node->Op();
if (GetOpOriginalType(op_desc->Type()) == "batch_norm") {
auto vecs = op_desc->Input("Bias");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Input("Mean");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Input("Scale");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Input("Variance");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
} else if (GetOpOriginalType(op_desc->Type()) == "fused_multi_transformer") {
auto vecs = op_desc->Input("LnScale");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Input("LnBias");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Input("FFNLnScale");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Input("FFNLnBias");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
}
return false;
}
bool AutoMixedPrecisionPass::OutputVarsNotConvert(
Node* op_node, const std::string& var_name) const {
auto* op_desc = op_node->Op();
// batch_norm's input and output (variance and mean) are the same.
if (GetOpOriginalType(op_desc->Type()) == "batch_norm") {
auto vecs = op_desc->Output("MeanOut");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Output("VarianceOut");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Output("SavedMean");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Output("SavedVariance");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
}
return false;
}
void AutoMixedPrecisionPass::SetVarPrecision() const {
for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) {
if (op_run_low_precision_.count(op_node->Op()->Type()) == 0) {
continue;
}
if (GetOpOriginalType(op_node->Op()->Type()) != "feed") {
for (auto* in_var_node : op_node->inputs) {
CHECK_EQ(in_var_node->IsVar(), true);
auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
auto in_var_name = real_in_var_node->Var()->Name();
if (!IsFP32AndFP64(real_in_var_node->Var()->GetDataType())) continue;
if (!VarNodeHasDtype(real_in_var_node)) continue;
if (InputVarsNotConvert(op_node, in_var_name)) continue;
if (real_in_var_node->Var()->Persistable()) {
real_in_var_node->Var()->SetDataType(
framework::TransToProtoVarType(low_precision_));
vars_convert_to_low_precision_.insert(in_var_name);
}
}
}
if (GetOpOriginalType(op_node->Op()->Type()) != "fetch") {
for (auto* out_var_node : op_node->outputs) {
CHECK_EQ(out_var_node->IsVar(), true);
auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
auto out_var_name = real_out_var_node->Var()->Name();
if (!IsFP32AndFP64(real_out_var_node->Var()->GetDataType())) continue;
if (!VarNodeHasDtype(real_out_var_node)) continue;
if (OutputVarsNotConvert(op_node, out_var_name)) continue;
real_out_var_node->Var()->SetDataType(
framework::TransToProtoVarType(low_precision_));
if (real_out_var_node->Var()->Persistable()) {
vars_convert_to_low_precision_.insert(out_var_name);
}
}
}
}
}
// This code used to precess vars with the same name. Vars with the same
// name should have the same data type.
for (auto* subgraph : subgraphes_) {
for (auto* var_node : subgraph->Nodes()) {
if (!var_node->IsVar() || !var_node->Var()->Persistable()) continue;
if (!VarNodeHasDtype(var_node)) continue;
auto var_name = var_node->Var()->Name();
if (vars_convert_to_low_precision_.count(var_name)) {
var_node->Var()->SetDataType(
framework::TransToProtoVarType(low_precision_));
}
}
}
}
void AutoMixedPrecisionPass::ConvertWeightsData() const {
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(scope,
platform::errors::PreconditionNotMet(
"During the auto_mixed_precision_pass, the scope "
"should not be null."));
auto var_names = scope->LocalVarNames();
for (const auto& var_name : var_names) {
if (vars_convert_to_low_precision_.count(var_name)) {
VLOG(4) << var_name << "'s data type was convert to low precision";
auto* var = scope->FindLocalVar(var_name);
CHECK_EQ(var->IsType<phi::DenseTensor>(), true);
auto* origin_tensor = var->GetMutable<phi::DenseTensor>();
phi::DenseTensor low_precision_tensor;
low_precision_tensor.Resize(origin_tensor->dims());
low_precision_tensor.set_type(low_precision_);
if (low_precision_ == phi::DataType::FLOAT16) {
auto* low_precision_data =
low_precision_tensor.mutable_data<phi::dtype::float16>(
phi::CPUPlace{});
for (int64_t i = 0; i < origin_tensor->numel(); i++) {
if (origin_tensor->dtype() == phi::DataType::FLOAT64) {
auto* origin_data = origin_tensor->data<double>();
low_precision_data[i] =
static_cast<phi::dtype::float16>(origin_data[i]);
} else if (origin_tensor->dtype() == phi::DataType::FLOAT32) {
auto* origin_data = origin_tensor->data<float>();
low_precision_data[i] =
static_cast<phi::dtype::float16>(origin_data[i]);
}
}
} else if (low_precision_ == phi::DataType::BFLOAT16) {
auto* low_precision_data =
low_precision_tensor.mutable_data<phi::dtype::bfloat16>(
phi::CPUPlace{});
for (int64_t i = 0; i < origin_tensor->numel(); i++) {
if (origin_tensor->dtype() == phi::DataType::FLOAT64) {
auto* origin_data = origin_tensor->data<double>();
low_precision_data[i] =
static_cast<phi::dtype::bfloat16>(origin_data[i]);
} else if (origin_tensor->dtype() == phi::DataType::FLOAT32) {
auto* origin_data = origin_tensor->data<float>();
low_precision_data[i] =
static_cast<phi::dtype::bfloat16>(origin_data[i]);
}
}
}
origin_tensor->clear();
paddle::framework::TensorCopySync(
low_precision_tensor, phi::CPUPlace{}, origin_tensor);
}
}
}
void AutoMixedPrecisionPass::InsertCastOp() const {
int suffix = 0;
std::unordered_map<Node*, Node*> cache;
for (size_t i = 0; i < all_op_nodes_.size(); i++) {
auto* block_desc = all_op_nodes_[i][0]->Op()->Block();
CHECK_NOTNULL(block_desc);
for (auto* op_node : all_op_nodes_[i]) {
auto op_type = op_node->Op()->Type();
if (GetOpOriginalType(op_type) == "feed") continue;
if (op_node->Op()->HasAttr("sub_block")) continue;
VLOG(4) << "process op: " << op_type
<< " run low precision: " << op_run_low_precision_.count(op_type);
auto inputs = op_node->inputs;
for (auto* in_var_node : inputs) {
if (!in_var_node->IsVar()) continue;
if (!VarNodeHasDtype(in_var_node)) continue;
if (in_var_node->Var()->Persistable()) continue;
auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
auto in_var_type = real_in_var_node->Var()->GetDataType();
VLOG(4) << "process var: " << real_in_var_node->Var()->Name()
<< " with type " << in_var_type;
if (IsFP32AndFP64(in_var_type) &&
op_run_low_precision_.count(op_type)) {
auto to_type = framework::TransToProtoVarType(low_precision_);
auto* prev_op =
in_var_node->inputs.empty() ? nullptr : in_var_node->inputs[0];
if (prev_op && GetOpOriginalType(prev_op->Op()->Type()) == "cast") {
in_var_node->Var()->SetDataType(to_type);
prev_op->Op()->SetAttr("out_dtype", static_cast<int>(to_type));
prev_op->Op()->Flush();
} else {
DoInsertCastOp(subgraphes_[i],
in_var_node,
op_node,
in_var_type,
to_type,
block_desc,
&suffix,
&cache);
}
} else if (IsFP16AndBFP16(in_var_type) &&
op_run_low_precision_.count(op_type) == 0) {
auto to_type = VarType::FP32;
auto* prev_op =
in_var_node->inputs.empty() ? nullptr : in_var_node->inputs[0];
if (prev_op && GetOpOriginalType(prev_op->Op()->Type()) == "cast") {
in_var_node->Var()->SetDataType(to_type);
prev_op->Op()->SetAttr("out_dtype", static_cast<int>(to_type));
prev_op->Op()->Flush();
} else {
DoInsertCastOp(subgraphes_[i],
in_var_node,
op_node,
in_var_type,
to_type,
block_desc,
&suffix,
&cache);
}
}
}
// Special op.
// fused_multi_transformer's input(CacheKV) and output(CacheKVOut) vars
// have same name.
if (GetOpOriginalType(op_type) == "fused_multi_transformer") {
auto cache_kv_inputs = op_node->Op()->Input("CacheKV");
auto cache_kv_outputs = op_node->Op()->Output("CacheKVOut");
CHECK_EQ(cache_kv_inputs.size(), cache_kv_outputs.size());
for (size_t i = 0; i < cache_kv_inputs.size(); ++i) {
op_node->Op()->RenameOutput(cache_kv_outputs[i], cache_kv_inputs[i]);
}
}
}
}
VLOG(4) << "insert number of cast op: " << cache.size();
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(auto_mixed_precision_pass,
paddle::framework::ir::AutoMixedPrecisionPass);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/phi/common/backend.h"
#include "paddle/phi/common/data_type.h"
namespace paddle {
namespace framework {
namespace ir {
class AutoMixedPrecisionPass : public FusePassBase {
public:
using VarType = framework::proto::VarType;
public:
AutoMixedPrecisionPass() = default;
~AutoMixedPrecisionPass() = default;
protected:
void ApplyImpl(Graph* graph) const override;
private:
void Init(Graph* graph) const;
void SetDefaultBlacklist() const;
void SetOpUniqueType() const;
void RestoreOpOriginType() const;
inline std::string GetOpOriginalType(const std::string& op_type) const;
void GetOpPrecision() const;
void UpdateOpPrecision() const;
void InsertCastOp() const;
void ProcessOpWithDtypeAttr() const;
bool InputVarsNotConvert(Node* op_node, const std::string& var_name) const;
bool OutputVarsNotConvert(Node* op_node, const std::string& var_name) const;
void SetVarPrecision() const;
void ConvertWeightsData() const;
private:
mutable bool skip_pass_{false};
mutable bool keep_io_types_{false};
// float16 or bfloat16 now
mutable phi::DataType low_precision_{phi::DataType::FLOAT16};
mutable phi::Backend backend_{phi::Backend::GPU};
mutable std::unordered_set<std::string> black_list_;
// subgraph id -> pointer to subgraph
mutable std::vector<Graph*> subgraphes_;
// var name -> real var node
mutable std::unordered_map<std::string, Node*> real_vars_;
// subgraph id -> all op nodes in subgraph
mutable std::vector<std::vector<Node*>> all_op_nodes_;
// op's unique type -> the op's origin type
mutable std::unordered_map<std::string, std::string> op_original_type_;
// op's unique type -> whether the op run at low precision
mutable std::unordered_set<std::string> op_run_low_precision_;
mutable std::unordered_set<std::string> vars_convert_to_low_precision_;
};
bool OpSupportPrecision(const std::string& op_type,
phi::Backend backend,
phi::DataType precision,
const std::unordered_set<std::string>& black_list);
void DoInsertCastOp(Graph* graph,
Node* var_node,
Node* op_node,
proto::VarType::Type from_type,
proto::VarType::Type to_type,
framework::BlockDesc* block_desc,
int* suffix,
std::unordered_map<Node*, Node*>* cache);
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -142,6 +142,9 @@ void ConstantFoldingPass::ApplyImpl(ir::Graph *graph) const {
}
out_desc->SetShape(out_shape);
out_desc->SetPersistable(true);
auto *var_desc_out = op_node->Op()->Block()->Var(out_name);
var_desc_out->SetShape(out_shape);
var_desc_out->SetPersistable(true);
auto *global_out_tensor = scope->Var(out_name)->GetMutable<LoDTensor>();
*global_out_tensor = *local_out_tensor;
}
......
......@@ -29,6 +29,11 @@ void FillConstData(LoDTensor* out_t, T value) {
}
void DeleteFillConstantOpPass::ApplyImpl(ir::Graph* graph) const {
bool with_dynamic_shape = Get<bool>("with_dynamic_shape");
// Not support
if (with_dynamic_shape) {
return;
}
FusePassBase::Init("delete_fill_constant_op_pass", graph);
GraphPatternDetector detector;
auto fill_constant_op =
......
......@@ -75,7 +75,6 @@ Graph::Graph(const ProgramDesc &program,
}
} else {
auto var_nodes = InitFromProgram(program_, start_op_index, end_op_index);
ResolveHazard(var_nodes);
}
}
......@@ -88,7 +87,6 @@ Graph::Graph(const BlockDesc &block,
const int64_t end_op_index)
: main_graph_(main_graph) {
auto var_nodes = InitFromBlock(block, start_op_index, end_op_index);
ResolveHazard(var_nodes);
}
// TODO(levi): delete this interface after when we can convert all
......
......@@ -1045,6 +1045,7 @@ PDNode *patterns::SeqConvEltAddRelu::operator()(
PDNode *patterns::Squeeze2Transpose2::operator()() {
auto *squeeze2_op_in = pattern->NewNode(squeeze2_op_in_repr())
->AsInput()
->assert_has_n_outputs(1)
->assert_is_op_input("squeeze2", "X");
auto *squeeze2_op = pattern->NewNode(squeeze2_op_repr())
->assert_is_op("squeeze2")
......
......@@ -130,86 +130,6 @@ TEST(GraphTest, Basic) {
ASSERT_EQ(nodes.size(), 5UL);
}
TEST(GraphTest, WriteAfterRead) {
// void Test() {
ProgramDesc prog;
auto *op = prog.MutableBlock(0)->AppendOp();
op->SetType("sum");
op->SetInput("X", {"a"});
op->SetOutput("Out", {"b"});
op->SetAttr("op_role", 1);
op = prog.MutableBlock(0)->AppendOp();
op->SetType("dummy");
op->SetInput("X", {"c"});
op->SetOutput("Out", {"a"});
op->SetAttr("op_role", 1);
prog.MutableBlock(0)->Var("a")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("b")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("c")->SetType(proto::VarType::LOD_TENSOR);
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
ir::Node *control_dep1 = nullptr;
ir::Node *control_dep2 = nullptr;
for (ir::Node *n : g->Nodes()) {
if (n->Name() == "sum") {
ASSERT_EQ(n->outputs[0]->Name(), "b");
ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1]));
control_dep1 = n->outputs[1];
ASSERT_EQ(n->outputs.size(), 2UL);
}
if (n->Name() == "dummy") {
ASSERT_EQ(n->inputs[0]->Name(), "c");
ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1]));
control_dep2 = n->inputs[1];
ASSERT_EQ(n->inputs.size(), 2UL);
}
}
ASSERT_EQ(control_dep1, control_dep2);
}
TEST(GraphTest, WriteAfterWrite) {
// void Test() {
ProgramDesc prog;
auto *op = prog.MutableBlock(0)->AppendOp();
op->SetType("sum");
op->SetInput("X", {"a"});
op->SetOutput("Out", {"b"});
op->SetAttr("op_role", 1);
op = prog.MutableBlock(0)->AppendOp();
op->SetType("dummy");
op->SetInput("X", {"c"});
op->SetOutput("Out", {"b"});
op->SetAttr("op_role", 1);
prog.MutableBlock(0)->Var("a")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("b")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("c")->SetType(proto::VarType::LOD_TENSOR);
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
ir::Node *control_dep1 = nullptr;
ir::Node *control_dep2 = nullptr;
for (ir::Node *n : g->Nodes()) {
if (n->Name() == "sum") {
ASSERT_EQ(n->outputs[0]->Name(), "b");
ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1]));
ASSERT_EQ(n->outputs.size(), 2UL);
control_dep1 = n->outputs[1];
}
if (n->Name() == "dummy") {
ASSERT_EQ(n->inputs[0]->Name(), "c");
ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1]));
control_dep2 = n->inputs[1];
ASSERT_EQ(n->inputs.size(), 2UL);
}
}
ASSERT_NE(control_dep1, nullptr);
ASSERT_NE(control_dep2, nullptr);
ASSERT_EQ(control_dep1, control_dep2);
}
TEST(GraphTest, TestException) {
ProgramDesc prog;
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
......@@ -350,12 +270,13 @@ TEST(GraphTest, TestMultiBlock) {
op = prog.MutableBlock(1)->AppendOp();
op->SetType("dummy");
op->SetInput("X", {"c"});
op->SetOutput("Out", {"a"});
op->SetOutput("Out", {"d"});
op->SetAttr("op_role", 1);
prog.MutableBlock(1)->Var("a")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(1)->Var("b")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(1)->Var("c")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(1)->Var("d")->SetType(proto::VarType::LOD_TENSOR);
// Set contents in block_2.
op = prog.MutableBlock(2)->AppendOp();
......@@ -367,12 +288,13 @@ TEST(GraphTest, TestMultiBlock) {
op = prog.MutableBlock(2)->AppendOp();
op->SetType("dummy");
op->SetInput("X", {"c"});
op->SetOutput("Out", {"b"});
op->SetOutput("Out", {"d"});
op->SetAttr("op_role", 1);
prog.MutableBlock(2)->Var("a")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(2)->Var("b")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(2)->Var("c")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(1)->Var("d")->SetType(proto::VarType::LOD_TENSOR);
// Step2: Convert program into graph, 3 blocks corresponding 3 sub_graphs.
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
......@@ -399,45 +321,29 @@ TEST(GraphTest, TestMultiBlock) {
// Check contents in sub_graph_1.
const ir::Graph *g1 = g->GetSubGraph(1);
ir::Node *control_dep1 = nullptr;
ir::Node *control_dep2 = nullptr;
for (ir::Node *n : g1->Nodes()) {
if (n->Name() == "sum") {
ASSERT_EQ(n->outputs[0]->Name(), "b");
ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1]));
control_dep1 = n->outputs[1];
ASSERT_EQ(n->outputs.size(), 2UL);
ASSERT_EQ(n->outputs.size(), 1UL);
}
if (n->Name() == "dummy") {
ASSERT_EQ(n->inputs[0]->Name(), "c");
ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1]));
control_dep2 = n->inputs[1];
ASSERT_EQ(n->inputs.size(), 2UL);
ASSERT_EQ(n->inputs.size(), 1UL);
}
}
ASSERT_EQ(control_dep1, control_dep2);
// Check contents in sub_graph_2.
const ir::Graph *g2 = g->GetSubGraph(2);
control_dep1 = nullptr;
control_dep2 = nullptr;
for (ir::Node *n : g2->Nodes()) {
if (n->Name() == "sum") {
ASSERT_EQ(n->outputs[0]->Name(), "b");
ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1]));
ASSERT_EQ(n->outputs.size(), 2UL);
control_dep1 = n->outputs[1];
ASSERT_EQ(n->outputs.size(), 1UL);
}
if (n->Name() == "dummy") {
ASSERT_EQ(n->inputs[0]->Name(), "c");
ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1]));
control_dep2 = n->inputs[1];
ASSERT_EQ(n->inputs.size(), 2UL);
ASSERT_EQ(n->inputs.size(), 1UL);
}
}
ASSERT_NE(control_dep1, nullptr);
ASSERT_NE(control_dep2, nullptr);
ASSERT_EQ(control_dep1, control_dep2);
// Step3: Clone graph.
std::shared_ptr<ir::Graph> clone_g = g->Clone();
......
......@@ -331,8 +331,6 @@ void BatchMergePass::ApplyImpl(ir::Graph* graph) const {
copy_node(node);
}
}
result.ResolveHazard(created);
}
} // namespace ir
......
......@@ -183,5 +183,6 @@ void NaiveExecutor::ResetTrtOps(int num) {
}
#endif
}
} // namespace framework
} // namespace paddle
......@@ -50,7 +50,7 @@ USE_OP_ITSELF(concat_grad);
USE_OP_ITSELF(elementwise_mul_grad);
USE_OP_ITSELF(sigmoid_grad);
USE_OP_ITSELF(tanh_grad);
USE_OP(sum);
USE_OP_ITSELF(sum);
USE_OP_ITSELF(slice_grad);
USE_OP_ITSELF(lookup_table_grad);
USE_OP_ITSELF(sqrt);
......@@ -101,6 +101,7 @@ PD_DECLARE_KERNEL(slice_grad, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(cross_entropy_with_softmax, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(cross_entropy_with_softmax_grad, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(sqrt, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(add_n, GPU, ALL_LAYOUT);
namespace paddle {
namespace framework {
......
......@@ -512,6 +512,13 @@ class ExecutionArgumentMappingContext : public phi::ArgumentMappingContext {
});
}
bool IsSelectedRowsInputs(const std::string& name) const override {
auto vars = ctx_.MultiInputVar(name);
return std::all_of(vars.begin(), vars.end(), [](const Variable* var) {
return var->IsType<phi::SelectedRows>();
});
}
bool IsSelectedRowsInput(const std::string& name) const override {
const auto* var = ctx_.InputVar(name);
return var->IsType<phi::SelectedRows>();
......
......@@ -146,6 +146,48 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op,
kernel_signature_(std::move(kernel_signature)),
phi_kernel_(phi_kernel) {}
#ifdef PADDLE_WITH_MLU
static void tokenize(const std::string& ops,
char delim,
std::unordered_set<std::string>* op_set) {
std::string::size_type beg = 0;
for (uint64_t end = 0; (end = ops.find(delim, end)) != std::string::npos;
++end) {
op_set->insert(ops.substr(beg, end - beg));
beg = end + 1;
}
op_set->insert(ops.substr(beg));
}
static bool is_in_mlu_black_list(const std::string& op_name) {
static bool inited = false;
static std::unordered_set<std::string> mlu_black_list;
static std::mutex s_mtx;
if (!inited) {
std::lock_guard<std::mutex> guard(s_mtx);
if (!inited) {
if (std::getenv("MLU_BLACK_LIST") != nullptr) {
std::string ops(std::getenv("MLU_BLACK_LIST"));
tokenize(ops, ',', &mlu_black_list);
}
inited = true;
VLOG(3) << "MLU Black List: ";
for (auto iter = mlu_black_list.begin(); iter != mlu_black_list.end();
++iter) {
VLOG(3) << *iter << " ";
}
}
}
if (mlu_black_list.find(op_name) != mlu_black_list.end()) {
return true;
}
return false;
}
#endif
template <typename VarType>
PreparedOp PrepareImpl(
const NameVarMap<VarType>& ins,
......@@ -194,6 +236,12 @@ PreparedOp PrepareImpl(
#endif
#ifdef PADDLE_WITH_MLU
if (is_in_mlu_black_list(op.Type())) {
expected_kernel_key.place_ = platform::CPUPlace();
}
#endif
bool has_phi_kernel = false;
const auto* arg_map_fn = phi_op_utils_map.GetArgumentMappingFn(op.Type());
......
......@@ -38,8 +38,7 @@ void Analyzer::RunAnalysis(Argument *argument) {
if (!disable_logs) {
string::PrettyLogH1("--- Running analysis [%s]", pass);
}
if (!argument->enable_analysis_optim() && pass == "ir_analysis_pass")
continue;
if (!argument->enable_ir_optim() && pass == "ir_analysis_pass") continue;
auto *ptr = PassRegistry::Global().Retreive(pass);
PADDLE_ENFORCE_NOT_NULL(ptr,
......
......@@ -31,7 +31,7 @@ TEST(Analyzer, analysis_without_tensorrt) {
Argument argument;
argument.SetDisableLogs(false);
argument.SetModelDir(FLAGS_inference_model_dir);
argument.SetEnableAnalysisOptim(false);
argument.SetEnableIrOptim(false);
argument.SetUseGPU(false);
argument.SetAnalysisPasses({"ir_graph_build_pass",
"ir_analysis_pass",
......@@ -44,7 +44,7 @@ TEST(Analyzer, analysis_without_tensorrt) {
TEST(Analyzer, analysis_with_tensorrt) {
Argument argument;
argument.SetDisableLogs(false);
argument.SetEnableAnalysisOptim(false);
argument.SetEnableIrOptim(false);
argument.SetTensorRtMaxBatchSize(3);
argument.SetTensorRtWorkspaceSize(1 << 20);
argument.SetModelDir(FLAGS_inference_model_dir);
......
......@@ -42,8 +42,6 @@ namespace paddle {
namespace inference {
namespace analysis {
using framework::ir::Graph;
#ifdef PADDLE_WITH_MKLDNN
using VarQuantScale =
std::unordered_map<std::string, std::pair<bool, framework::LoDTensor>>;
......@@ -148,7 +146,7 @@ struct Argument {
DECL_ARGUMENT_FIELD(model_params_path, ModelParamsPath, std::string);
DECL_ARGUMENT_FIELD(model_from_memory, ModelFromMemory, bool);
DECL_ARGUMENT_FIELD(optim_cache_dir, OptimCacheDir, std::string);
DECL_ARGUMENT_FIELD(enable_analysis_optim, EnableAnalysisOptim, bool);
DECL_ARGUMENT_FIELD(enable_ir_optim, EnableIrOptim, bool);
// For JITLayer
DECL_ARGUMENT_FIELD(skip_load_params, SkipLoadParams, bool);
......@@ -362,6 +360,8 @@ struct Argument {
DECL_ARGUMENT_FIELD(mixed_black_list,
MixedBlackList,
std::unordered_set<std::string>);
DECL_ARGUMENT_FIELD(enable_gpu_mixed, EnableGPUMixed, bool);
DECL_ARGUMENT_FIELD(mixed_precision_mode, MixedPrecisionMode, int);
private:
std::unordered_set<std::string> valid_fields_;
......
......@@ -153,25 +153,6 @@ T &GetFromScope(const framework::Scope &scope, const std::string &name) {
return *var->GetMutable<T>();
}
static framework::proto::ProgramDesc LoadProgramDesc(
const std::string &model_path) {
std::ifstream fin(model_path, std::ios::in | std::ios::binary);
PADDLE_ENFORCE_EQ(
fin.is_open(),
true,
platform::errors::NotFound(
"Cannot open file %s, please confirm whether the file exists",
model_path));
fin.seekg(0, std::ios::end);
std::string buffer(fin.tellg(), ' ');
fin.seekg(0, std::ios::beg);
fin.read(&buffer[0], buffer.size());
fin.close();
framework::proto::ProgramDesc program_desc;
program_desc.ParseFromString(buffer);
return program_desc;
}
static bool FileExists(const std::string &filepath) {
std::ifstream file(filepath);
bool exists = file.is_open();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment