Commit dbe08e9b authored by yuguo960516yuguo's avatar yuguo960516yuguo
Browse files

2.4.2

parent b5499578
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/analysis/argument.h" #include "paddle/fluid/inference/analysis/argument.h"
#include "paddle/fluid/string/pretty_log.h" #include "paddle/fluid/string/pretty_log.h"
#include "paddle/phi/core/errors.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
...@@ -36,15 +37,6 @@ using string::PrettyLogEndl; ...@@ -36,15 +37,6 @@ using string::PrettyLogEndl;
using string::Style; using string::Style;
IRPassManager::IRPassManager(Argument *argument) { IRPassManager::IRPassManager(Argument *argument) {
ARGUMENT_CHECK_FIELD(argument, main_program);
graph_ = std::unique_ptr<Graph>(new Graph(argument->main_program()));
if (argument->Has("scope")) {
auto *scope_ptr = argument->scope_ptr();
PADDLE_ENFORCE_NOT_NULL(scope_ptr,
platform::errors::PreconditionNotMet(
"The scope ptr should not be nullptr."));
graph_->SetNotOwned(framework::ir::kParamScopeAttr, scope_ptr);
}
disable_logs_ = argument->disable_logs(); disable_logs_ = argument->disable_logs();
ARGUMENT_CHECK_FIELD(argument, ir_analysis_passes); ARGUMENT_CHECK_FIELD(argument, ir_analysis_passes);
...@@ -95,10 +87,14 @@ void IRPassManager::CreatePasses(Argument *argument, ...@@ -95,10 +87,14 @@ void IRPassManager::CreatePasses(Argument *argument,
argument->tensorrt_tuned_dynamic_shape(); argument->tensorrt_tuned_dynamic_shape();
pass->Set("with_dynamic_shape", new bool(with_dynamic_shape)); pass->Set("with_dynamic_shape", new bool(with_dynamic_shape));
// mixed precision related
pass->Set("model_precision", new int(argument->model_precision())); pass->Set("model_precision", new int(argument->model_precision()));
pass->Set( pass->Set(
"mixed_black_list", "mixed_black_list",
new std::unordered_set<std::string>(argument->mixed_black_list())); new std::unordered_set<std::string>(argument->mixed_black_list()));
pass->Set("enable_gpu_mixed", new bool(argument->enable_gpu_mixed()));
pass->Set("mixed_precision_mode",
new int(argument->mixed_precision_mode()));
if (pass_name == "graph_viz_pass") { if (pass_name == "graph_viz_pass") {
std::string optim_cache_dir = argument->optim_cache_dir(); std::string optim_cache_dir = argument->optim_cache_dir();
...@@ -302,42 +298,18 @@ void IRPassManager::CreatePasses(Argument *argument, ...@@ -302,42 +298,18 @@ void IRPassManager::CreatePasses(Argument *argument,
} }
std::unique_ptr<Graph> IRPassManager::Apply(std::unique_ptr<Graph> graph) { std::unique_ptr<Graph> IRPassManager::Apply(std::unique_ptr<Graph> graph) {
if (passes_.empty()) {
return graph;
}
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
graph.get(), graph.get(), platform::errors::InvalidArgument("Graph cannot be null."));
platform::errors::PreconditionNotMet("Graph cannot be NULL."));
// Apply all the passes // Apply all the passes
for (const auto &pass : passes_) { for (const auto &pass : passes_) {
if (pass->Type() != "graph_viz_pass" && !disable_logs_) { if (pass->Type() != "graph_viz_pass" && !disable_logs_) {
PrettyLogEndl(Style::H2(), "--- Running IR pass [%s]", pass->Type()); PrettyLogEndl(Style::H2(), "--- Running IR pass [%s]", pass->Type());
} }
// delete_fill_constant_op_pass is not apply under trt dynamic shape
if (pass->Type() == "delete_fill_constant_op_pass") {
bool use_dynamic = pass->Get<bool>("with_dynamic_shape");
if (use_dynamic) continue;
}
graph.reset(pass->Apply(graph.release())); graph.reset(pass->Apply(graph.release()));
} }
return graph; return graph;
} }
framework::proto::ProgramDesc IRPassManager::AcquireProgram(
std::unique_ptr<Graph> *graph, ProgramDesc *program) const {
auto pass =
framework::ir::PassRegistry::Instance().Get("graph_to_program_pass");
// Direct using ProgramDesc desc(argument->main_program()) may cause
// incomplete copies of information.
ProgramDesc desc;
desc.CopyFrom(*program->Proto());
pass->SetNotOwned("program", &desc);
auto *the_graph = graph->release();
graph->reset(pass->Apply(the_graph));
return *desc.Proto();
}
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -48,15 +48,9 @@ class IRPassManager final { ...@@ -48,15 +48,9 @@ class IRPassManager final {
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph); std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph);
framework::proto::ProgramDesc AcquireProgram(std::unique_ptr<Graph> *graph,
ProgramDesc *program) const;
framework::ir::Graph &graph() const { return *graph_; }
private: private:
void CreatePasses(Argument *argument, const std::vector<std::string> &passes); void CreatePasses(Argument *argument, const std::vector<std::string> &passes);
std::unique_ptr<Graph> graph_;
std::vector<std::unique_ptr<Pass>> passes_; std::vector<std::unique_ptr<Pass>> passes_;
bool disable_logs_{false}; bool disable_logs_{false};
}; };
......
...@@ -94,13 +94,13 @@ void OutputProcess(framework::ir::Graph *graph, ...@@ -94,13 +94,13 @@ void OutputProcess(framework::ir::Graph *graph,
backend, backend,
precision, precision,
blacklist)) { blacklist)) {
AddCastOp(graph, InsertCastOp(graph,
var_node, var_node,
next_op, next_op,
framework::proto::VarType::FP32, framework::proto::VarType::FP32,
to_type, to_type,
&suffix,
block_desc, block_desc,
&suffix,
&var_to_cast_op_map); &var_to_cast_op_map);
var_node->Var()->SetDataType(framework::proto::VarType::FP32); var_node->Var()->SetDataType(framework::proto::VarType::FP32);
} }
......
...@@ -13,7 +13,7 @@ cc_library( ...@@ -13,7 +13,7 @@ cc_library(
cc_library( cc_library(
convert_to_mixed_precision convert_to_mixed_precision
SRCS convert_to_mixed_precision.cc SRCS convert_to_mixed_precision.cc
DEPS analysis_pass ir_graph_build_pass) DEPS analysis_pass ir_graph_build_pass auto_mixed_precision_pass)
cc_library( cc_library(
ir_params_sync_among_devices_pass ir_params_sync_among_devices_pass
SRCS ir_params_sync_among_devices_pass.cc SRCS ir_params_sync_among_devices_pass.cc
...@@ -30,17 +30,6 @@ cc_library( ...@@ -30,17 +30,6 @@ cc_library(
inference_op_replace_pass inference_op_replace_pass
SRCS inference_op_replace_pass.cc SRCS inference_op_replace_pass.cc
DEPS analysis_pass graph_to_program_pass) DEPS analysis_pass graph_to_program_pass)
if(WITH_TESTING)
cc_library(
ir_graph_clean_pass
SRCS ir_graph_clean_pass.cc
DEPS analysis_pass gtest)
else()
cc_library(
ir_graph_clean_pass
SRCS ir_graph_clean_pass.cc
DEPS analysis_pass)
endif()
cc_library( cc_library(
analysis_passes analysis_passes
...@@ -52,8 +41,7 @@ cc_library( ...@@ -52,8 +41,7 @@ cc_library(
memory_optim_pass memory_optim_pass
convert_to_mixed_precision convert_to_mixed_precision
inference_op_replace_pass inference_op_replace_pass
ir_graph_to_program_pass ir_graph_to_program_pass)
ir_graph_clean_pass)
set(analysis_deps set(analysis_deps
${analysis_deps} analysis_passes subgraph_detector ${analysis_deps} analysis_passes subgraph_detector
......
...@@ -15,14 +15,12 @@ ...@@ -15,14 +15,12 @@
#pragma once #pragma once
#include <string> #include <string>
#include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/phi/common/backend.h" #include "paddle/phi/common/backend.h"
#include "paddle/phi/common/data_type.h" #include "paddle/phi/common/data_type.h"
...@@ -30,20 +28,52 @@ namespace paddle { ...@@ -30,20 +28,52 @@ namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
class ConvertToMixedPrecisionPass {
public:
explicit ConvertToMixedPrecisionPass(
const std::string& model_file,
const std::string& params_file,
const std::string& mixed_model_file,
const std::string& mixed_params_file,
phi::DataType mixed_precision,
phi::Backend backend,
bool keep_io_types,
const std::unordered_set<std::string>& black_list);
void Run();
private:
void LoadModel();
void SaveMixedModel();
private:
std::string model_file_;
std::string params_file_;
std::string mixed_model_file_;
std::string mixed_params_file_;
phi::DataType mixed_precision_;
phi::Backend backend_;
bool keep_io_types_;
std::unordered_set<std::string> black_list_;
framework::Scope scope_;
std::unique_ptr<framework::ir::Graph> main_graph_{nullptr};
};
bool OpSupportPrecision(const std::string& op_type, bool OpSupportPrecision(const std::string& op_type,
phi::Backend backend, phi::Backend backend,
phi::DataType precision, phi::DataType precision,
const std::unordered_set<std::string>& blacklist); const std::unordered_set<std::string>& black_list);
void AddCastOp( void InsertCastOp(
framework::ir::Graph* graph, framework::ir::Graph* graph,
framework::ir::Node* node, framework::ir::Node* var_node,
framework::ir::Node* next_op, framework::ir::Node* op_node,
framework::proto::VarType::Type from_type, framework::proto::VarType::Type from_type,
framework::proto::VarType::Type to_type, framework::proto::VarType::Type to_type,
int* suffix,
framework::BlockDesc* block_desc, framework::BlockDesc* block_desc,
std::unordered_map<framework::ir::Node*, framework::ir::Node*>* map); int* suffix,
std::unordered_map<framework::ir::Node*, framework::ir::Node*>* visited);
void ConvertToMixedPrecision(const std::string& model_file, void ConvertToMixedPrecision(const std::string& model_file,
const std::string& params_file, const std::string& params_file,
...@@ -51,8 +81,8 @@ void ConvertToMixedPrecision(const std::string& model_file, ...@@ -51,8 +81,8 @@ void ConvertToMixedPrecision(const std::string& model_file,
const std::string& mixed_params_file, const std::string& mixed_params_file,
phi::DataType mixed_precision, phi::DataType mixed_precision,
phi::Backend backend, phi::Backend backend,
bool keep_io_types = true, bool keep_io_types,
std::unordered_set<std::string> black_list = {}); const std::unordered_set<std::string>& black_list);
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
......
...@@ -40,7 +40,7 @@ void InferenceOpReplacePass::RunImpl(Argument* argument) { ...@@ -40,7 +40,7 @@ void InferenceOpReplacePass::RunImpl(Argument* argument) {
} }
std::string InferenceOpReplacePass::repr() const { std::string InferenceOpReplacePass::repr() const {
return "inference-op-replace-pass"; return "inference_op_replace_pass";
} }
} // namespace analysis } // namespace analysis
......
...@@ -105,7 +105,7 @@ void IrAnalysisPass::CollectFusionStatis(Argument* argument) { ...@@ -105,7 +105,7 @@ void IrAnalysisPass::CollectFusionStatis(Argument* argument) {
framework::ir::kFuseStatisAttr)); framework::ir::kFuseStatisAttr));
} }
std::string IrAnalysisPass::repr() const { return "ir-analysis-pass"; } std::string IrAnalysisPass::repr() const { return "ir_analysis_pass"; }
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
......
...@@ -64,7 +64,8 @@ void IrGraphBuildPass::RunImpl(Argument *argument) { ...@@ -64,7 +64,8 @@ void IrGraphBuildPass::RunImpl(Argument *argument) {
"set.")); "set."));
} }
auto graph = std::unique_ptr<Graph>(new Graph(argument->main_program())); auto graph = std::unique_ptr<framework::ir::Graph>(
new framework::ir::Graph(argument->main_program()));
argument->SetMainGraph(graph.release()); argument->SetMainGraph(graph.release());
auto *scope_ptr = argument->scope_ptr(); auto *scope_ptr = argument->scope_ptr();
PADDLE_ENFORCE_NOT_NULL(scope_ptr, PADDLE_ENFORCE_NOT_NULL(scope_ptr,
...@@ -125,7 +126,7 @@ std::unique_ptr<framework::ProgramDesc> IrGraphBuildPass::LoadModel( ...@@ -125,7 +126,7 @@ std::unique_ptr<framework::ProgramDesc> IrGraphBuildPass::LoadModel(
} }
} }
std::string IrGraphBuildPass::repr() const { return "ir-graph-build-pass"; } std::string IrGraphBuildPass::repr() const { return "ir_graph_build_pass"; }
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
......
...@@ -31,7 +31,7 @@ void IrGraphToProgramPass::RunImpl(Argument *argument) { ...@@ -31,7 +31,7 @@ void IrGraphToProgramPass::RunImpl(Argument *argument) {
new int(argument->memory_optim_sort_kind())); new int(argument->memory_optim_sort_kind()));
} }
std::unique_ptr<Graph> graph(argument->main_graph_ptr()); std::unique_ptr<framework::ir::Graph> graph(argument->main_graph_ptr());
// Direct using ProgramDesc desc(argument->main_program()) may cause // Direct using ProgramDesc desc(argument->main_program()) may cause
// incomplete copies of information. // incomplete copies of information.
......
...@@ -28,7 +28,7 @@ class IrGraphToProgramPass : public AnalysisPass { ...@@ -28,7 +28,7 @@ class IrGraphToProgramPass : public AnalysisPass {
public: public:
void RunImpl(Argument *argument) override; void RunImpl(Argument *argument) override;
std::string repr() const override { return "ir-graph-to-param-pass"; } std::string repr() const override { return "ir_graph_to_param_pass"; }
}; };
} // namespace analysis } // namespace analysis
......
...@@ -169,7 +169,7 @@ void IrParamsSyncAmongDevicesPass::RunImpl(Argument *argument) { ...@@ -169,7 +169,7 @@ void IrParamsSyncAmongDevicesPass::RunImpl(Argument *argument) {
} }
std::string IrParamsSyncAmongDevicesPass::repr() const { std::string IrParamsSyncAmongDevicesPass::repr() const {
return "ir-params-sync-among-devices-pass"; return "ir_params_sync_among_devices_pass";
} }
} // namespace analysis } // namespace analysis
......
...@@ -295,7 +295,7 @@ void UpdateOpDescsByReuse( ...@@ -295,7 +295,7 @@ void UpdateOpDescsByReuse(
} }
} }
std::string MemoryOptimizePass::repr() const { return "memory optimize pass"; } std::string MemoryOptimizePass::repr() const { return "memory_optimize_pass"; }
void MemoryOptimizePass::RunImpl(Argument* argument) { void MemoryOptimizePass::RunImpl(Argument* argument) {
// Memory optimization. // Memory optimization.
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
#include "paddle/fluid/inference/analysis/passes/inference_op_replace_pass.h" #include "paddle/fluid/inference/analysis/passes/inference_op_replace_pass.h"
#include "paddle/fluid/inference/analysis/passes/ir_analysis_pass.h" #include "paddle/fluid/inference/analysis/passes/ir_analysis_pass.h"
#include "paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h" #include "paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h"
#include "paddle/fluid/inference/analysis/passes/ir_graph_clean_pass.h"
#include "paddle/fluid/inference/analysis/passes/ir_graph_to_program_pass.h" #include "paddle/fluid/inference/analysis/passes/ir_graph_to_program_pass.h"
#include "paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.h" #include "paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.h"
#include "paddle/fluid/inference/analysis/passes/memory_optimize_pass.h" #include "paddle/fluid/inference/analysis/passes/memory_optimize_pass.h"
...@@ -34,8 +33,6 @@ PassRegistry::PassRegistry() { ...@@ -34,8 +33,6 @@ PassRegistry::PassRegistry() {
std::unique_ptr<AnalysisPass>(new IrAnalysisPass)); std::unique_ptr<AnalysisPass>(new IrAnalysisPass));
passes_.emplace("ir_graph_build_pass", passes_.emplace("ir_graph_build_pass",
std::unique_ptr<AnalysisPass>(new IrGraphBuildPass)); std::unique_ptr<AnalysisPass>(new IrGraphBuildPass));
passes_.emplace("ir_graph_clean_pass",
std::unique_ptr<AnalysisPass>(new IrInferCleanGraphPass));
passes_.emplace("memory_optimize_pass", passes_.emplace("memory_optimize_pass",
std::unique_ptr<AnalysisPass>(new MemoryOptimizePass)); std::unique_ptr<AnalysisPass>(new MemoryOptimizePass));
passes_.emplace( passes_.emplace(
......
...@@ -85,15 +85,29 @@ void AnalysisConfig::SetModel(const std::string &prog_file_path, ...@@ -85,15 +85,29 @@ void AnalysisConfig::SetModel(const std::string &prog_file_path,
Update(); Update();
} }
void AnalysisConfig::EnableUseGpu(uint64_t memory_pool_init_size_mb, void AnalysisConfig::EnableUseGpu(uint64_t memory_pool_init_size_mb,
int device_id) { int device_id,
Precision precision_mode) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
use_gpu_ = true; use_gpu_ = true;
memory_pool_init_size_mb_ = memory_pool_init_size_mb; memory_pool_init_size_mb_ = memory_pool_init_size_mb;
FLAGS_initial_gpu_memory_in_mb = memory_pool_init_size_mb_; FLAGS_initial_gpu_memory_in_mb = memory_pool_init_size_mb_;
gpu_device_id_ = device_id; gpu_device_id_ = device_id;
mixed_precision_mode_ = precision_mode;
if (precision_mode == Precision::kFloat32) {
// default
} else if (precision_mode == Precision::kHalf ||
precision_mode == Precision::kBf16) {
enable_gpu_mixed_ = true;
} else {
LOG(ERROR)
<< "The Paddle-GPU inference currently only supports "
"float32/float16/bfloat16 precision. Please check the parameters "
"you specified in EnableUseGpu or enable_use_gpu function.";
}
#else #else
LOG(ERROR) << "Please compile with gpu to EnableGpu()"; LOG(ERROR) << "Please use PaddlePaddle with GPU version.";
use_gpu_ = false; use_gpu_ = false;
#endif #endif
...@@ -279,7 +293,7 @@ void AnalysisConfig::LoadIpuConfig(const std::string &config_path) { ...@@ -279,7 +293,7 @@ void AnalysisConfig::LoadIpuConfig(const std::string &config_path) {
if (ipu_config_mapper_.find(key) == ipu_config_mapper_.end()) { if (ipu_config_mapper_.find(key) == ipu_config_mapper_.end()) {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"invalid key {} in IPU config", key)); "invalid key %s in IPU config: ", key));
} }
switch (ipu_config_mapper_.at(key)) { switch (ipu_config_mapper_.at(key)) {
case ipu_config_code::ipu_device_num: case ipu_config_code::ipu_device_num:
...@@ -315,7 +329,7 @@ void AnalysisConfig::LoadIpuConfig(const std::string &config_path) { ...@@ -315,7 +329,7 @@ void AnalysisConfig::LoadIpuConfig(const std::string &config_path) {
default: default:
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"invalid key {} in IPU config", key)); "invalid key %s in IPU config", key));
break; break;
} }
} }
...@@ -372,8 +386,10 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) { ...@@ -372,8 +386,10 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
CP_MEMBER(gpu_device_id_); CP_MEMBER(gpu_device_id_);
CP_MEMBER(memory_pool_init_size_mb_); CP_MEMBER(memory_pool_init_size_mb_);
// Mixed related. // Mixed precision related.
CP_MEMBER(mixed_black_list_); CP_MEMBER(mixed_black_list_);
CP_MEMBER(enable_gpu_mixed_);
CP_MEMBER(mixed_precision_mode_);
CP_MEMBER(enable_memory_optim_); CP_MEMBER(enable_memory_optim_);
// TensorRT related. // TensorRT related.
...@@ -740,13 +756,7 @@ void AnalysisConfig::Update() { ...@@ -740,13 +756,7 @@ void AnalysisConfig::Update() {
((use_custom_device() ^ pass_builder_->use_custom_device()))) { ((use_custom_device() ^ pass_builder_->use_custom_device()))) {
if (use_gpu()) { if (use_gpu()) {
pass_builder_.reset(new GpuPassStrategy); pass_builder_.reset(new GpuPassStrategy);
if (use_tensorrt_) {
// Append after the Affine_channel_conv_fuse pass.
pass_builder()->InsertPass(3, "tensorrt_subgraph_pass");
}
} else if (use_ipu()) { } else if (use_ipu()) {
VLOG(1) << "IpuPassStrategy has been used for new.";
pass_builder_.reset(new IpuPassStrategy); pass_builder_.reset(new IpuPassStrategy);
} else if (use_xpu()) { } else if (use_xpu()) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
...@@ -946,9 +956,6 @@ void AnalysisConfig::Update() { ...@@ -946,9 +956,6 @@ void AnalysisConfig::Update() {
"but did not have the option -DWITH_CUSTOM_DEVICE compiled.")); "but did not have the option -DWITH_CUSTOM_DEVICE compiled."));
#endif #endif
} }
if (ir_debug_) {
pass_builder()->TurnOnDebug();
}
} }
std::string AnalysisConfig::SerializeInfoCache() { std::string AnalysisConfig::SerializeInfoCache() {
...@@ -960,6 +967,7 @@ std::string AnalysisConfig::SerializeInfoCache() { ...@@ -960,6 +967,7 @@ std::string AnalysisConfig::SerializeInfoCache() {
ss << calibration_file_path_; ss << calibration_file_path_;
ss << use_gpu_; ss << use_gpu_;
ss << enable_gpu_mixed_;
ss << use_external_stream_; ss << use_external_stream_;
ss << exec_stream_; ss << exec_stream_;
ss << use_fc_padding_; ss << use_fc_padding_;
...@@ -1167,6 +1175,7 @@ std::string AnalysisConfig::Summary() { ...@@ -1167,6 +1175,7 @@ std::string AnalysisConfig::Summary() {
os.InsertRow({"use_gpu", use_gpu_ ? "true" : "false"}); os.InsertRow({"use_gpu", use_gpu_ ? "true" : "false"});
if (use_gpu_) { if (use_gpu_) {
os.InsertRow({"gpu_device_id", std::to_string(gpu_device_id_)}); os.InsertRow({"gpu_device_id", std::to_string(gpu_device_id_)});
os.InsertRow({"enable_gpu_mixed", std::to_string(enable_gpu_mixed_)});
os.InsertRow({"memory_pool_init_size", os.InsertRow({"memory_pool_init_size",
std::to_string(memory_pool_init_size_mb_) + "MB"}); std::to_string(memory_pool_init_size_mb_) + "MB"});
os.InsertRow( os.InsertRow(
...@@ -1360,7 +1369,7 @@ bool AnalysisConfig::trt_allow_build_at_runtime() { ...@@ -1360,7 +1369,7 @@ bool AnalysisConfig::trt_allow_build_at_runtime() {
return trt_allow_build_at_runtime_; return trt_allow_build_at_runtime_;
} }
void AnalysisConfig::Exp_SetBlackListOpsForMixedModel( void AnalysisConfig::Exp_DisableMixedPrecisionOps(
const std::unordered_set<std::string> &black_list) { const std::unordered_set<std::string> &black_list) {
mixed_black_list_ = black_list; mixed_black_list_ = black_list;
} }
......
...@@ -1065,7 +1065,7 @@ void AnalysisPredictor::PrepareArgument() { ...@@ -1065,7 +1065,7 @@ void AnalysisPredictor::PrepareArgument() {
argument_.SetUseGPU(config_.use_gpu()); argument_.SetUseGPU(config_.use_gpu());
argument_.SetUseFcPadding(config_.use_fc_padding()); argument_.SetUseFcPadding(config_.use_fc_padding());
argument_.SetGPUDeviceId(config_.gpu_device_id()); argument_.SetGPUDeviceId(config_.gpu_device_id());
argument_.SetEnableAnalysisOptim(config_.enable_ir_optim_); argument_.SetEnableIrOptim(config_.enable_ir_optim_);
argument_.SetEnableMemoryOptim(config_.enable_memory_optim()); argument_.SetEnableMemoryOptim(config_.enable_memory_optim());
argument_.SetModelFromMemory(config_.model_from_memory_); argument_.SetModelFromMemory(config_.model_from_memory_);
// Analyze inference_program // Analyze inference_program
...@@ -1210,53 +1210,57 @@ void AnalysisPredictor::PrepareArgument() { ...@@ -1210,53 +1210,57 @@ void AnalysisPredictor::PrepareArgument() {
} }
#endif #endif
auto passes = config_.pass_builder()->AllPasses(); auto *pass_builder = config_.pass_builder();
if (model_precision_ != phi::DataType::FLOAT32) { if (model_precision_ != phi::DataType::FLOAT32) {
LOG(INFO) << "Model is mixed precision type with " << model_precision_ LOG(INFO) << "Model is mixed precision type with " << model_precision_
<< ", we will use a new PassStrategy. Note that only the GPU " << ", we will use a new PassStrategy. Note that only the GPU "
"backend is supported for now."; "backend is supported for now.";
passes.clear(); pass_builder->ClearPasses();
const auto &deleted_passes = pass_builder->GetAllDeletedPasses();
if (config_.tensorrt_engine_enabled()) { if (config_.tensorrt_engine_enabled()) {
for (const auto &pass : kTrtLowerPrecisionPasses) { for (const auto &pass : kTrtLowerPrecisionPasses) {
passes.push_back(pass); if (deleted_passes.count(pass)) continue;
pass_builder->AppendPass(pass);
} }
} else if (config_.use_gpu()) { } else if (config_.use_gpu()) {
for (const auto &pass : kGpuLowerPrecisionPasses) { for (const auto &pass : kGpuLowerPrecisionPasses) {
passes.push_back(pass); if (deleted_passes.count(pass)) continue;
pass_builder->AppendPass(pass);
} }
} }
const auto &deleted_passes = config_.pass_builder()->GetAllDeletedPasses();
for (const auto &it : deleted_passes) {
auto iterator = std::find(passes.begin(), passes.end(), it);
if (iterator != passes.end()) {
passes.erase(iterator);
}
} }
if (config_.ir_debug_) { if (!config_.ir_optim()) {
auto it = std::begin(passes); argument_.SetEnableIrOptim(false);
while (it != std::end(passes)) { if (config_.enable_gpu_mixed_) {
if (*it != "graph_viz_pass") { argument_.SetEnableIrOptim(true);
it = passes.insert(it + 1, "graph_viz_pass"); pass_builder->ClearPasses();
pass_builder->AppendPass("auto_mixed_precision_pass");
LOG(INFO)
<< "This model run in Paddle-GPU mixed precision mode with no ir "
"optimization.";
} else { } else {
++it; LOG(INFO) << "ir_optim is turned off, no IR pass will be executed.";
}
} }
} else {
if (config_.ir_debug_) {
pass_builder->TurnOnDebug();
} }
if (config_.enable_gpu_mixed_) {
LOG(INFO) << "This model run in Paddle-GPU mixed precision mode.";
} }
if (!config_.ir_optim()) {
passes.clear();
LOG(INFO) << "ir_optim is turned off, no IR pass will be executed";
} }
argument_.SetDisableLogs(config_.glog_info_disabled()); argument_.SetDisableLogs(config_.glog_info_disabled());
argument_.SetIrAnalysisPasses(passes); argument_.SetIrAnalysisPasses(pass_builder->AllPasses());
argument_.SetAnalysisPasses(config_.pass_builder()->AnalysisPasses()); argument_.SetAnalysisPasses(pass_builder->AnalysisPasses());
argument_.SetScopeNotOwned(scope_.get()); argument_.SetScopeNotOwned(scope_.get());
// mixed precison. // mixed precison.
argument_.SetModelPrecision(static_cast<int>(model_precision_)); argument_.SetModelPrecision(static_cast<int>(model_precision_));
argument_.SetMixedBlackList(config_.mixed_black_list_); argument_.SetMixedBlackList(config_.mixed_black_list_);
argument_.SetEnableGPUMixed(config_.enable_gpu_mixed_);
argument_.SetMixedPrecisionMode(static_cast<int>(
paddle::ConvertPrecision(config_.mixed_precision_mode_)));
} }
// NOTE All the members in AnalysisConfig should be copied to Argument. // NOTE All the members in AnalysisConfig should be copied to Argument.
...@@ -2107,7 +2111,9 @@ std::unique_ptr<PaddlePredictor> AnalysisPredictor::Clone(void *stream) { ...@@ -2107,7 +2111,9 @@ std::unique_ptr<PaddlePredictor> AnalysisPredictor::Clone(void *stream) {
} }
x->predictor_stream_ = stream; x->predictor_stream_ = stream;
x->Init(scope_, inference_program_); x->Init(scope_, inference_program_);
#ifdef PADDLE_WITH_TENSORRT
x->executor_->ResetTrtOps(++AnalysisPredictor::clone_num_); x->executor_->ResetTrtOps(++AnalysisPredictor::clone_num_);
#endif
return std::unique_ptr<PaddlePredictor>(x); return std::unique_ptr<PaddlePredictor>(x);
} }
......
...@@ -604,10 +604,8 @@ void AnalysisPredictor::MkldnnQuantizer::PrepareArgument() const { ...@@ -604,10 +604,8 @@ void AnalysisPredictor::MkldnnQuantizer::PrepareArgument() const {
if (predictor_.config_.ir_debug_) builder->TurnOnDebug(); if (predictor_.config_.ir_debug_) builder->TurnOnDebug();
auto passes = builder->AllPasses(); auto passes = builder->AllPasses();
predictor_.argument_.SetIrAnalysisPasses(passes); predictor_.argument_.SetIrAnalysisPasses(passes);
predictor_.argument_.SetAnalysisPasses({"ir_graph_clean_pass", predictor_.argument_.SetAnalysisPasses(
"ir_analysis_pass", {"ir_analysis_pass", "memory_optimize_pass", "ir_graph_to_program_pass"});
"memory_optimize_pass",
"ir_graph_to_program_pass"});
predictor_.argument_.SetQuantVarScales(scales_); predictor_.argument_.SetQuantVarScales(scales_);
} }
......
...@@ -247,8 +247,12 @@ struct PD_INFER_DECL AnalysisConfig { ...@@ -247,8 +247,12 @@ struct PD_INFER_DECL AnalysisConfig {
/// ///
/// \param memory_pool_init_size_mb initial size of the GPU memory pool in MB. /// \param memory_pool_init_size_mb initial size of the GPU memory pool in MB.
/// \param device_id device_id the GPU card to use (default is 0). /// \param device_id device_id the GPU card to use (default is 0).
/// \param precision the precision used in Paddle-GPU inference.
/// ///
void EnableUseGpu(uint64_t memory_pool_init_size_mb, int device_id = 0); void EnableUseGpu(uint64_t memory_pool_init_size_mb,
int device_id = 0,
Precision precision_mode = Precision::kFloat32);
/// ///
/// \brief Turn off GPU. /// \brief Turn off GPU.
/// ///
...@@ -967,7 +971,7 @@ struct PD_INFER_DECL AnalysisConfig { ...@@ -967,7 +971,7 @@ struct PD_INFER_DECL AnalysisConfig {
/// interface is in the experimental stage and may change in the future. Note /// interface is in the experimental stage and may change in the future. Note
/// that the blacklist must be the same as the model conversion blacklist. /// that the blacklist must be the same as the model conversion blacklist.
/// ///
void Exp_SetBlackListOpsForMixedModel( void Exp_DisableMixedPrecisionOps(
const std::unordered_set<std::string>& black_list); const std::unordered_set<std::string>& black_list);
void SetApplyOptim(bool value) { apply_optim_ = value; } void SetApplyOptim(bool value) { apply_optim_ = value; }
...@@ -987,13 +991,15 @@ struct PD_INFER_DECL AnalysisConfig { ...@@ -987,13 +991,15 @@ struct PD_INFER_DECL AnalysisConfig {
mutable std::string params_file_; mutable std::string params_file_;
mutable std::string calibration_file_path_; mutable std::string calibration_file_path_;
// Mixed precision. // Mixed precision related.
Precision mixed_precision_mode_{Precision::kFloat32};
std::unordered_set<std::string> mixed_black_list_; std::unordered_set<std::string> mixed_black_list_;
// GPU related. // GPU related.
bool use_gpu_{false}; bool use_gpu_{false};
int gpu_device_id_{0}; int gpu_device_id_{0};
uint64_t memory_pool_init_size_mb_{100}; // initial size is 100MB. uint64_t memory_pool_init_size_mb_{100}; // initial size is 100MB.
bool enable_gpu_mixed_{false};
bool thread_local_stream_{false}; bool thread_local_stream_{false};
bool use_cudnn_{false}; bool use_cudnn_{false};
......
...@@ -171,8 +171,9 @@ const std::vector<std::string> kGpuLowerPrecisionPasses{ ...@@ -171,8 +171,9 @@ const std::vector<std::string> kGpuLowerPrecisionPasses{
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass", "multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass",
"gpu_cpu_map_matmul_v2_to_mul_pass", "gpu_cpu_map_matmul_v2_to_mul_pass",
"gpu_cpu_map_matmul_v2_to_matmul_pass", "gpu_cpu_map_matmul_v2_to_matmul_pass",
"gpu_cpu_map_matmul_to_mul_pass",
"fc_fuse_pass", "fc_fuse_pass",
"fc_elementwise_layernorm_fuse_pass", // "fc_elementwise_layernorm_fuse_pass",
"embedding_eltwise_layernorm_fuse_pass", "embedding_eltwise_layernorm_fuse_pass",
"runtime_context_cache_pass", "runtime_context_cache_pass",
}; };
...@@ -227,9 +228,10 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { ...@@ -227,9 +228,10 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
"conv_elementwise_add_fuse_pass", // "conv_elementwise_add_fuse_pass", //
#endif // #endif //
"transpose_flatten_concat_fuse_pass", // "transpose_flatten_concat_fuse_pass", //
"constant_folding_pass", "constant_folding_pass", //
// following pass should be located in the last, since it will // following pass should be located in the last, since it will
// work on all fused ops. // work on all fused ops.
"auto_mixed_precision_pass", //
"runtime_context_cache_pass" "runtime_context_cache_pass"
}); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment