Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
c05cfa89
Commit
c05cfa89
authored
Jul 09, 2017
by
Guolin Ke
Browse files
clean code for config.
parent
43d50370
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
119 additions
and
132 deletions
+119
-132
.travis/test.sh
.travis/test.sh
+2
-2
include/LightGBM/config.h
include/LightGBM/config.h
+14
-26
src/io/config.cpp
src/io/config.cpp
+103
-104
No files found.
.travis/test.sh
View file @
c05cfa89
...
...
@@ -59,7 +59,7 @@ if [[ ${TASK} == "gpu" ]]; then
export
PATH
=
"
$AMDAPPSDK
/include/:
$PATH
"
export
BOOST_ROOT
=
"
$HOME
/miniconda/"
LGB_VER
=
$(
head
-n
1 VERSION.txt
)
sed
-i
's/std::string
device_typ
e = "cpu";/std::string
device_typ
e = "gpu";/'
../include/LightGBM/config.h
sed
-i
's/
const
std::string
kDefaultDevic
e = "cpu";/
const
std::string
kDefaultDevic
e = "gpu";/'
../include/LightGBM/config.h
cd
$TRAVIS_BUILD_DIR
/python-package
&&
python setup.py sdist
||
exit
-1
cd
$TRAVIS_BUILD_DIR
/python-package/dist
&&
pip
install
lightgbm-
$LGB_VER
.tar.gz
-v
--install-option
=
--gpu
||
exit
-1
cd
$TRAVIS_BUILD_DIR
&&
pytest tests/python_package_test
||
exit
-1
...
...
@@ -73,7 +73,7 @@ if [[ ${TASK} == "mpi" ]]; then
cmake
-DUSE_MPI
=
ON ..
elif
[[
${
TASK
}
==
"gpu"
]]
;
then
cmake
-DUSE_GPU
=
ON
-DBOOST_ROOT
=
"
$HOME
/miniconda/"
-DOpenCL_INCLUDE_DIR
=
$AMDAPPSDK
/include/ ..
sed
-i
's/std::string
device_typ
e = "cpu";/std::string
device_typ
e = "gpu";/'
../include/LightGBM/config.h
sed
-i
's/
const
std::string
kDefaultDevic
e = "cpu";/
const
std::string
kDefaultDevic
e = "gpu";/'
../include/LightGBM/config.h
else
cmake ..
fi
...
...
include/LightGBM/config.h
View file @
c05cfa89
...
...
@@ -16,6 +16,11 @@
namespace
LightGBM
{
const
std
::
string
kDefaultTreeLearnerType
=
"serial"
;
const
std
::
string
kDefaultDevice
=
"cpu"
;
const
std
::
string
kDefaultBoostingType
=
"gbdt"
;
const
std
::
string
kDefaultObjectiveType
=
"regression"
;
/*!
* \brief The interface for Config
*/
...
...
@@ -38,7 +43,7 @@ public:
* \param out Value will assign to out if key exists
* \return True if key exists
*/
inline
bool
GetString
(
inline
static
bool
GetString
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
,
const
std
::
string
&
name
,
std
::
string
*
out
);
...
...
@@ -49,7 +54,7 @@ public:
* \param out Value will assign to out if key exists
* \return True if key exists
*/
inline
bool
GetInt
(
inline
static
bool
GetInt
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
,
const
std
::
string
&
name
,
int
*
out
);
...
...
@@ -60,7 +65,7 @@ public:
* \param out Value will assign to out if key exists
* \return True if key exists
*/
inline
bool
GetDouble
(
inline
static
bool
GetDouble
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
,
const
std
::
string
&
name
,
double
*
out
);
...
...
@@ -71,7 +76,7 @@ public:
* \param out Value will assign to out if key exists
* \return True if key exists
*/
inline
bool
GetBool
(
inline
static
bool
GetBool
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
,
const
std
::
string
&
name
,
bool
*
out
);
...
...
@@ -135,7 +140,7 @@ public:
* And add an prefix "name:" while using column name
* Note: when using Index, it doesn't count the label index */
std
::
string
categorical_column
=
""
;
std
::
string
device_type
=
"cpu"
;
std
::
string
device_type
=
kDefaultDevice
;
/*! \brief Set to true if want to use early stop for the prediction */
bool
pred_early_stop
=
false
;
...
...
@@ -145,9 +150,6 @@ public:
double
pred_early_stop_margin
=
10.0
f
;
LIGHTGBM_EXPORT
void
Set
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
override
;
private:
void
GetDeviceType
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
);
};
/*! \brief Config for objective function */
...
...
@@ -246,15 +248,10 @@ public:
double
other_rate
=
0.1
f
;
// only used for the regression. Will boost from the average labels.
bool
boost_from_average
=
true
;
std
::
string
tree_learner_type
=
"serial"
;
std
::
string
device_type
=
"cpu"
;
std
::
string
tree_learner_type
=
kDefaultTreeLearnerType
;
std
::
string
device_type
=
kDefaultDevice
;
TreeConfig
tree_config
;
LIGHTGBM_EXPORT
void
Set
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
override
;
private:
void
GetTreeLearnerType
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
);
void
GetDeviceType
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
);
};
/*! \brief Config for Network */
...
...
@@ -278,25 +275,16 @@ public:
bool
is_parallel
=
false
;
bool
is_parallel_find_bin
=
false
;
IOConfig
io_config
;
std
::
string
boosting_type
=
"gbdt"
;
std
::
string
boosting_type
=
kDefaultBoostingType
;
BoostingConfig
boosting_config
;
std
::
string
objective_type
=
"regression"
;
std
::
string
objective_type
=
kDefaultObjectiveType
;
ObjectiveConfig
objective_config
;
std
::
vector
<
std
::
string
>
metric_types
;
MetricConfig
metric_config
;
std
::
string
convert_model_language
=
""
;
LIGHTGBM_EXPORT
void
Set
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
override
;
private:
void
GetBoostingType
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
);
void
GetObjectiveType
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
);
void
GetMetricType
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
);
void
GetTaskType
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
);
void
CheckParamConflict
();
};
...
...
src/io/config.cpp
View file @
c05cfa89
...
...
@@ -32,49 +32,10 @@ std::unordered_map<std::string, std::string> ConfigBase::Str2Map(const char* par
return
params
;
}
void
OverallConfig
::
Set
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
{
// load main config types
GetInt
(
params
,
"num_threads"
,
&
num_threads
);
GetString
(
params
,
"convert_model_language"
,
&
convert_model_language
);
// generate seeds by seed.
if
(
GetInt
(
params
,
"seed"
,
&
seed
))
{
Random
rand
(
seed
);
int
int_max
=
std
::
numeric_limits
<
short
>::
max
();
io_config
.
data_random_seed
=
static_cast
<
int
>
(
rand
.
NextShort
(
0
,
int_max
));
boosting_config
.
bagging_seed
=
static_cast
<
int
>
(
rand
.
NextShort
(
0
,
int_max
));
boosting_config
.
drop_seed
=
static_cast
<
int
>
(
rand
.
NextShort
(
0
,
int_max
));
boosting_config
.
tree_config
.
feature_fraction_seed
=
static_cast
<
int
>
(
rand
.
NextShort
(
0
,
int_max
));
}
GetTaskType
(
params
);
GetBoostingType
(
params
);
GetObjectiveType
(
params
);
GetMetricType
(
params
);
// sub-config setup
network_config
.
Set
(
params
);
io_config
.
Set
(
params
);
boosting_config
.
Set
(
params
);
objective_config
.
Set
(
params
);
metric_config
.
Set
(
params
);
// check for conflicts
CheckParamConflict
();
if
(
io_config
.
verbosity
==
1
)
{
LightGBM
::
Log
::
ResetLogLevel
(
LightGBM
::
LogLevel
::
Info
);
}
else
if
(
io_config
.
verbosity
==
0
)
{
LightGBM
::
Log
::
ResetLogLevel
(
LightGBM
::
LogLevel
::
Warning
);
}
else
if
(
io_config
.
verbosity
>=
2
)
{
LightGBM
::
Log
::
ResetLogLevel
(
LightGBM
::
LogLevel
::
Debug
);
}
else
{
LightGBM
::
Log
::
ResetLogLevel
(
LightGBM
::
LogLevel
::
Fatal
);
}
}
void
OverallConfig
::
GetBoostingType
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
{
std
::
string
GetBoostingType
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
{
std
::
string
value
;
if
(
GetString
(
params
,
"boosting_type"
,
&
value
))
{
std
::
string
boosting_type
=
kDefaultBoostingType
;
if
(
ConfigBase
::
GetString
(
params
,
"boosting_type"
,
&
value
))
{
std
::
transform
(
value
.
begin
(),
value
.
end
(),
value
.
begin
(),
Common
::
tolower
);
if
(
value
==
std
::
string
(
"gbdt"
)
||
value
==
std
::
string
(
"gbrt"
))
{
boosting_type
=
"gbdt"
;
...
...
@@ -86,19 +47,23 @@ void OverallConfig::GetBoostingType(const std::unordered_map<std::string, std::s
Log
::
Fatal
(
"Unknown boosting type %s"
,
value
.
c_str
());
}
}
return
boosting_type
;
}
void
OverallConfig
::
GetObjectiveType
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
{
std
::
string
GetObjectiveType
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
{
std
::
string
value
;
if
(
GetString
(
params
,
"objective"
,
&
value
))
{
std
::
string
objective_type
=
kDefaultObjectiveType
;
if
(
ConfigBase
::
GetString
(
params
,
"objective"
,
&
value
))
{
std
::
transform
(
value
.
begin
(),
value
.
end
(),
value
.
begin
(),
Common
::
tolower
);
objective_type
=
value
;
}
return
objective_type
;
}
void
OverallConfig
::
GetMetricType
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
{
std
::
vector
<
std
::
string
>
GetMetricType
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
{
std
::
string
value
;
if
(
GetString
(
params
,
"metric"
,
&
value
))
{
std
::
vector
<
std
::
string
>
metric_types
;
if
(
ConfigBase
::
GetString
(
params
,
"metric"
,
&
value
))
{
// clear old metrics
metric_types
.
clear
();
// to lower
...
...
@@ -118,12 +83,13 @@ void OverallConfig::GetMetricType(const std::unordered_map<std::string, std::str
}
metric_types
.
shrink_to_fit
();
}
return
metric_types
;
}
void
OverallConfig
::
GetTaskType
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
{
TaskType
GetTaskType
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
{
std
::
string
value
;
if
(
GetString
(
params
,
"task"
,
&
value
))
{
TaskType
task_type
=
TaskType
::
kTrain
;
if
(
ConfigBase
::
GetString
(
params
,
"task"
,
&
value
))
{
std
::
transform
(
value
.
begin
(),
value
.
end
(),
value
.
begin
(),
Common
::
tolower
);
if
(
value
==
std
::
string
(
"train"
)
||
value
==
std
::
string
(
"training"
))
{
task_type
=
TaskType
::
kTrain
;
...
...
@@ -136,10 +102,88 @@ void OverallConfig::GetTaskType(const std::unordered_map<std::string, std::strin
Log
::
Fatal
(
"Unknown task type %s"
,
value
.
c_str
());
}
}
return
task_type
;
}
void
OverallConfig
::
CheckParamConflict
()
{
std
::
string
GetDeviceType
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
{
std
::
string
value
;
std
::
string
device_type
=
kDefaultDevice
;
if
(
ConfigBase
::
GetString
(
params
,
"device"
,
&
value
))
{
std
::
transform
(
value
.
begin
(),
value
.
end
(),
value
.
begin
(),
Common
::
tolower
);
if
(
value
==
std
::
string
(
"cpu"
))
{
device_type
=
"cpu"
;
}
else
if
(
value
==
std
::
string
(
"gpu"
))
{
device_type
=
"gpu"
;
}
else
{
Log
::
Fatal
(
"Unknown device type %s"
,
value
.
c_str
());
}
}
return
device_type
;
}
std
::
string
GetTreeLearnerType
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
{
std
::
string
value
;
std
::
string
tree_learner_type
=
kDefaultTreeLearnerType
;
if
(
ConfigBase
::
GetString
(
params
,
"tree_learner"
,
&
value
))
{
std
::
transform
(
value
.
begin
(),
value
.
end
(),
value
.
begin
(),
Common
::
tolower
);
if
(
value
==
std
::
string
(
"serial"
))
{
tree_learner_type
=
"serial"
;
}
else
if
(
value
==
std
::
string
(
"feature"
)
||
value
==
std
::
string
(
"feature_parallel"
))
{
tree_learner_type
=
"feature"
;
}
else
if
(
value
==
std
::
string
(
"data"
)
||
value
==
std
::
string
(
"data_parallel"
))
{
tree_learner_type
=
"data"
;
}
else
if
(
value
==
std
::
string
(
"voting"
)
||
value
==
std
::
string
(
"voting_parallel"
))
{
tree_learner_type
=
"voting"
;
}
else
{
Log
::
Fatal
(
"Unknown tree learner type %s"
,
value
.
c_str
());
}
}
return
tree_learner_type
;
}
void
OverallConfig
::
Set
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
{
// load main config types
GetInt
(
params
,
"num_threads"
,
&
num_threads
);
GetString
(
params
,
"convert_model_language"
,
&
convert_model_language
);
// generate seeds by seed.
if
(
GetInt
(
params
,
"seed"
,
&
seed
))
{
Random
rand
(
seed
);
int
int_max
=
std
::
numeric_limits
<
short
>::
max
();
io_config
.
data_random_seed
=
static_cast
<
int
>
(
rand
.
NextShort
(
0
,
int_max
));
boosting_config
.
bagging_seed
=
static_cast
<
int
>
(
rand
.
NextShort
(
0
,
int_max
));
boosting_config
.
drop_seed
=
static_cast
<
int
>
(
rand
.
NextShort
(
0
,
int_max
));
boosting_config
.
tree_config
.
feature_fraction_seed
=
static_cast
<
int
>
(
rand
.
NextShort
(
0
,
int_max
));
}
task_type
=
GetTaskType
(
params
);
boosting_type
=
GetBoostingType
(
params
);
metric_types
=
GetMetricType
(
params
);
// sub-config setup
network_config
.
Set
(
params
);
io_config
.
Set
(
params
);
boosting_config
.
Set
(
params
);
objective_type
=
GetObjectiveType
(
params
);
objective_config
.
Set
(
params
);
metric_config
.
Set
(
params
);
// check for conflicts
CheckParamConflict
();
if
(
io_config
.
verbosity
==
1
)
{
LightGBM
::
Log
::
ResetLogLevel
(
LightGBM
::
LogLevel
::
Info
);
}
else
if
(
io_config
.
verbosity
==
0
)
{
LightGBM
::
Log
::
ResetLogLevel
(
LightGBM
::
LogLevel
::
Warning
);
}
else
if
(
io_config
.
verbosity
>=
2
)
{
LightGBM
::
Log
::
ResetLogLevel
(
LightGBM
::
LogLevel
::
Debug
);
}
else
{
LightGBM
::
Log
::
ResetLogLevel
(
LightGBM
::
LogLevel
::
Fatal
);
}
}
void
OverallConfig
::
CheckParamConflict
()
{
// check if objective_type, metric_type, and num_class match
bool
objective_type_multiclass
=
(
objective_type
==
std
::
string
(
"multiclass"
)
||
objective_type
==
std
::
string
(
"multiclassova"
));
...
...
@@ -171,13 +215,14 @@ void OverallConfig::CheckParamConflict() {
boosting_config
.
tree_learner_type
=
"serial"
;
}
if
(
boosting_config
.
tree_learner_type
==
std
::
string
(
"serial"
))
{
bool
is_single_tree_learner
=
boosting_config
.
tree_learner_type
==
std
::
string
(
"serial"
);
if
(
is_single_tree_learner
)
{
is_parallel
=
false
;
network_config
.
num_machines
=
1
;
}
if
(
boosting_config
.
tree_learner_type
==
std
::
string
(
"serial"
)
||
boosting_config
.
tree_learner_type
==
std
::
string
(
"feature"
))
{
if
(
is_single_tree_learner
||
boosting_config
.
tree_learner_type
==
std
::
string
(
"feature"
))
{
is_parallel_find_bin
=
false
;
}
else
if
(
boosting_config
.
tree_learner_type
==
std
::
string
(
"data"
)
||
boosting_config
.
tree_learner_type
==
std
::
string
(
"voting"
))
{
...
...
@@ -189,7 +234,6 @@ void OverallConfig::CheckParamConflict() {
// Change pool size to -1 (no limit) when using data parallel to reduce communication costs
boosting_config
.
tree_config
.
histogram_pool_size
=
-
1
;
}
}
}
...
...
@@ -235,21 +279,7 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetInt
(
params
,
"pred_early_stop_freq"
,
&
pred_early_stop_freq
);
GetDouble
(
params
,
"pred_early_stop_margin"
,
&
pred_early_stop_margin
);
GetDeviceType
(
params
);
}
void
IOConfig
::
GetDeviceType
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
{
std
::
string
value
;
if
(
GetString
(
params
,
"device"
,
&
value
))
{
std
::
transform
(
value
.
begin
(),
value
.
end
(),
value
.
begin
(),
Common
::
tolower
);
if
(
value
==
std
::
string
(
"cpu"
))
{
device_type
=
"cpu"
;
}
else
if
(
value
==
std
::
string
(
"gpu"
))
{
device_type
=
"gpu"
;
}
else
{
Log
::
Fatal
(
"Unknown device type %s"
,
value
.
c_str
());
}
}
device_type
=
GetDeviceType
(
params
);
}
void
ObjectiveConfig
::
Set
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
{
...
...
@@ -336,7 +366,6 @@ void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params)
GetBool
(
params
,
"use_missing"
,
&
use_missing
);
}
void
BoostingConfig
::
Set
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
{
GetInt
(
params
,
"num_iterations"
,
&
num_iterations
);
GetDouble
(
params
,
"sigmoid"
,
&
sigmoid
);
...
...
@@ -365,42 +394,12 @@ void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& par
GetBool
(
params
,
"boost_from_average"
,
&
boost_from_average
);
CHECK
(
drop_rate
<=
1.0
&&
drop_rate
>=
0.0
);
CHECK
(
skip_drop
<=
1.0
&&
skip_drop
>=
0.0
);
GetDeviceType
(
params
);
GetTreeLearnerType
(
params
);
device_type
=
GetDeviceType
(
params
);
tree_learner_type
=
GetTreeLearnerType
(
params
);
tree_config
.
Set
(
params
);
}
void
BoostingConfig
::
GetTreeLearnerType
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
{
std
::
string
value
;
if
(
GetString
(
params
,
"tree_learner"
,
&
value
))
{
std
::
transform
(
value
.
begin
(),
value
.
end
(),
value
.
begin
(),
Common
::
tolower
);
if
(
value
==
std
::
string
(
"serial"
))
{
tree_learner_type
=
"serial"
;
}
else
if
(
value
==
std
::
string
(
"feature"
)
||
value
==
std
::
string
(
"feature_parallel"
))
{
tree_learner_type
=
"feature"
;
}
else
if
(
value
==
std
::
string
(
"data"
)
||
value
==
std
::
string
(
"data_parallel"
))
{
tree_learner_type
=
"data"
;
}
else
if
(
value
==
std
::
string
(
"voting"
)
||
value
==
std
::
string
(
"voting_parallel"
))
{
tree_learner_type
=
"voting"
;
}
else
{
Log
::
Fatal
(
"Unknown tree learner type %s"
,
value
.
c_str
());
}
}
}
void
BoostingConfig
::
GetDeviceType
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
{
std
::
string
value
;
if
(
GetString
(
params
,
"device"
,
&
value
))
{
std
::
transform
(
value
.
begin
(),
value
.
end
(),
value
.
begin
(),
Common
::
tolower
);
if
(
value
==
std
::
string
(
"cpu"
))
{
device_type
=
"cpu"
;
}
else
if
(
value
==
std
::
string
(
"gpu"
))
{
device_type
=
"gpu"
;
}
else
{
Log
::
Fatal
(
"Unknown device type %s"
,
value
.
c_str
());
}
}
}
void
NetworkConfig
::
Set
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
{
GetInt
(
params
,
"num_machines"
,
&
num_machines
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment