Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
9c57793e
"...git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "2dfb9a40478b965db8325baa21a63d9281f96b7c"
Commit
9c57793e
authored
Sep 26, 2017
by
wxchan
Committed by
Guolin Ke
Sep 26, 2017
Browse files
refine set params (#933)
parent
e66a8a3c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
30 additions
and
42 deletions
+30
-42
src/io/config.cpp
src/io/config.cpp
+30
-42
No files found.
src/io/config.cpp
View file @
9c57793e
...
@@ -32,42 +32,37 @@ std::unordered_map<std::string, std::string> ConfigBase::Str2Map(const char* par
...
@@ -32,42 +32,37 @@ std::unordered_map<std::string, std::string> ConfigBase::Str2Map(const char* par
return
params
;
return
params
;
}
}
std
::
string
GetBoostingType
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
{
void
GetBoostingType
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
,
std
::
string
*
boosting_type
)
{
std
::
string
value
;
std
::
string
value
;
std
::
string
boosting_type
=
kDefaultBoostingType
;
if
(
ConfigBase
::
GetString
(
params
,
"boosting_type"
,
&
value
))
{
if
(
ConfigBase
::
GetString
(
params
,
"boosting_type"
,
&
value
))
{
std
::
transform
(
value
.
begin
(),
value
.
end
(),
value
.
begin
(),
Common
::
tolower
);
std
::
transform
(
value
.
begin
(),
value
.
end
(),
value
.
begin
(),
Common
::
tolower
);
if
(
value
==
std
::
string
(
"gbdt"
)
||
value
==
std
::
string
(
"gbrt"
))
{
if
(
value
==
std
::
string
(
"gbdt"
)
||
value
==
std
::
string
(
"gbrt"
))
{
boosting_type
=
"gbdt"
;
*
boosting_type
=
"gbdt"
;
}
else
if
(
value
==
std
::
string
(
"dart"
))
{
}
else
if
(
value
==
std
::
string
(
"dart"
))
{
boosting_type
=
"dart"
;
*
boosting_type
=
"dart"
;
}
else
if
(
value
==
std
::
string
(
"goss"
))
{
}
else
if
(
value
==
std
::
string
(
"goss"
))
{
boosting_type
=
"goss"
;
*
boosting_type
=
"goss"
;
}
else
if
(
value
==
std
::
string
(
"rf"
)
||
value
==
std
::
string
(
"randomforest"
))
{
}
else
if
(
value
==
std
::
string
(
"rf"
)
||
value
==
std
::
string
(
"randomforest"
))
{
boosting_type
=
"rf"
;
*
boosting_type
=
"rf"
;
}
else
{
}
else
{
Log
::
Fatal
(
"Unknown boosting type %s"
,
value
.
c_str
());
Log
::
Fatal
(
"Unknown boosting type %s"
,
value
.
c_str
());
}
}
}
}
return
boosting_type
;
}
}
std
::
string
GetObjectiveType
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
{
void
GetObjectiveType
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
,
std
::
string
*
objective_type
)
{
std
::
string
value
;
std
::
string
value
;
std
::
string
objective_type
=
kDefaultObjectiveType
;
if
(
ConfigBase
::
GetString
(
params
,
"objective"
,
&
value
))
{
if
(
ConfigBase
::
GetString
(
params
,
"objective"
,
&
value
))
{
std
::
transform
(
value
.
begin
(),
value
.
end
(),
value
.
begin
(),
Common
::
tolower
);
std
::
transform
(
value
.
begin
(),
value
.
end
(),
value
.
begin
(),
Common
::
tolower
);
objective_type
=
value
;
*
objective_type
=
value
;
}
}
return
objective_type
;
}
}
std
::
vector
<
std
::
string
>
GetMetricType
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
{
void
GetMetricType
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
,
std
::
vector
<
std
::
string
>*
metric_types
)
{
std
::
string
value
;
std
::
string
value
;
std
::
vector
<
std
::
string
>
metric_types
;
if
(
ConfigBase
::
GetString
(
params
,
"metric"
,
&
value
))
{
if
(
ConfigBase
::
GetString
(
params
,
"metric"
,
&
value
))
{
// clear old metrics
// clear old metrics
metric_types
.
clear
();
metric_types
->
clear
();
// to lower
// to lower
std
::
transform
(
value
.
begin
(),
value
.
end
(),
value
.
begin
(),
Common
::
tolower
);
std
::
transform
(
value
.
begin
(),
value
.
end
(),
value
.
begin
(),
Common
::
tolower
);
// split
// split
...
@@ -81,66 +76,59 @@ std::vector<std::string> GetMetricType(const std::unordered_map<std::string, std
...
@@ -81,66 +76,59 @@ std::vector<std::string> GetMetricType(const std::unordered_map<std::string, std
}
}
}
}
for
(
auto
&
metric
:
metric_sets
)
{
for
(
auto
&
metric
:
metric_sets
)
{
metric_types
.
push_back
(
metric
);
metric_types
->
push_back
(
metric
);
}
}
metric_types
.
shrink_to_fit
();
metric_types
->
shrink_to_fit
();
}
}
return
metric_types
;
}
}
TaskType
GetTaskType
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
{
void
GetTaskType
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
,
TaskType
*
task_type
)
{
std
::
string
value
;
std
::
string
value
;
TaskType
task_type
=
TaskType
::
kTrain
;
if
(
ConfigBase
::
GetString
(
params
,
"task"
,
&
value
))
{
if
(
ConfigBase
::
GetString
(
params
,
"task"
,
&
value
))
{
std
::
transform
(
value
.
begin
(),
value
.
end
(),
value
.
begin
(),
Common
::
tolower
);
std
::
transform
(
value
.
begin
(),
value
.
end
(),
value
.
begin
(),
Common
::
tolower
);
if
(
value
==
std
::
string
(
"train"
)
||
value
==
std
::
string
(
"training"
))
{
if
(
value
==
std
::
string
(
"train"
)
||
value
==
std
::
string
(
"training"
))
{
task_type
=
TaskType
::
kTrain
;
*
task_type
=
TaskType
::
kTrain
;
}
else
if
(
value
==
std
::
string
(
"predict"
)
||
value
==
std
::
string
(
"prediction"
)
}
else
if
(
value
==
std
::
string
(
"predict"
)
||
value
==
std
::
string
(
"prediction"
)
||
value
==
std
::
string
(
"test"
))
{
||
value
==
std
::
string
(
"test"
))
{
task_type
=
TaskType
::
kPredict
;
*
task_type
=
TaskType
::
kPredict
;
}
else
if
(
value
==
std
::
string
(
"convert_model"
))
{
}
else
if
(
value
==
std
::
string
(
"convert_model"
))
{
task_type
=
TaskType
::
kConvertModel
;
*
task_type
=
TaskType
::
kConvertModel
;
}
else
{
}
else
{
Log
::
Fatal
(
"Unknown task type %s"
,
value
.
c_str
());
Log
::
Fatal
(
"Unknown task type %s"
,
value
.
c_str
());
}
}
}
}
return
task_type
;
}
}
std
::
string
GetDeviceType
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
{
void
GetDeviceType
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
,
std
::
string
*
device_type
)
{
std
::
string
value
;
std
::
string
value
;
std
::
string
device_type
=
kDefaultDevice
;
if
(
ConfigBase
::
GetString
(
params
,
"device"
,
&
value
))
{
if
(
ConfigBase
::
GetString
(
params
,
"device"
,
&
value
))
{
std
::
transform
(
value
.
begin
(),
value
.
end
(),
value
.
begin
(),
Common
::
tolower
);
std
::
transform
(
value
.
begin
(),
value
.
end
(),
value
.
begin
(),
Common
::
tolower
);
if
(
value
==
std
::
string
(
"cpu"
))
{
if
(
value
==
std
::
string
(
"cpu"
))
{
device_type
=
"cpu"
;
*
device_type
=
"cpu"
;
}
else
if
(
value
==
std
::
string
(
"gpu"
))
{
}
else
if
(
value
==
std
::
string
(
"gpu"
))
{
device_type
=
"gpu"
;
*
device_type
=
"gpu"
;
}
else
{
}
else
{
Log
::
Fatal
(
"Unknown device type %s"
,
value
.
c_str
());
Log
::
Fatal
(
"Unknown device type %s"
,
value
.
c_str
());
}
}
}
}
return
device_type
;
}
}
std
::
string
GetTreeLearnerType
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
{
void
GetTreeLearnerType
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
,
std
::
string
*
tree_learner_type
)
{
std
::
string
value
;
std
::
string
value
;
std
::
string
tree_learner_type
=
kDefaultTreeLearnerType
;
if
(
ConfigBase
::
GetString
(
params
,
"tree_learner"
,
&
value
))
{
if
(
ConfigBase
::
GetString
(
params
,
"tree_learner"
,
&
value
))
{
std
::
transform
(
value
.
begin
(),
value
.
end
(),
value
.
begin
(),
Common
::
tolower
);
std
::
transform
(
value
.
begin
(),
value
.
end
(),
value
.
begin
(),
Common
::
tolower
);
if
(
value
==
std
::
string
(
"serial"
))
{
if
(
value
==
std
::
string
(
"serial"
))
{
tree_learner_type
=
"serial"
;
*
tree_learner_type
=
"serial"
;
}
else
if
(
value
==
std
::
string
(
"feature"
)
||
value
==
std
::
string
(
"feature_parallel"
))
{
}
else
if
(
value
==
std
::
string
(
"feature"
)
||
value
==
std
::
string
(
"feature_parallel"
))
{
tree_learner_type
=
"feature"
;
*
tree_learner_type
=
"feature"
;
}
else
if
(
value
==
std
::
string
(
"data"
)
||
value
==
std
::
string
(
"data_parallel"
))
{
}
else
if
(
value
==
std
::
string
(
"data"
)
||
value
==
std
::
string
(
"data_parallel"
))
{
tree_learner_type
=
"data"
;
*
tree_learner_type
=
"data"
;
}
else
if
(
value
==
std
::
string
(
"voting"
)
||
value
==
std
::
string
(
"voting_parallel"
))
{
}
else
if
(
value
==
std
::
string
(
"voting"
)
||
value
==
std
::
string
(
"voting_parallel"
))
{
tree_learner_type
=
"voting"
;
*
tree_learner_type
=
"voting"
;
}
else
{
}
else
{
Log
::
Fatal
(
"Unknown tree learner type %s"
,
value
.
c_str
());
Log
::
Fatal
(
"Unknown tree learner type %s"
,
value
.
c_str
());
}
}
}
}
return
tree_learner_type
;
}
}
void
OverallConfig
::
Set
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
{
void
OverallConfig
::
Set
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
{
...
@@ -157,17 +145,17 @@ void OverallConfig::Set(const std::unordered_map<std::string, std::string>& para
...
@@ -157,17 +145,17 @@ void OverallConfig::Set(const std::unordered_map<std::string, std::string>& para
boosting_config
.
drop_seed
=
static_cast
<
int
>
(
rand
.
NextShort
(
0
,
int_max
));
boosting_config
.
drop_seed
=
static_cast
<
int
>
(
rand
.
NextShort
(
0
,
int_max
));
boosting_config
.
tree_config
.
feature_fraction_seed
=
static_cast
<
int
>
(
rand
.
NextShort
(
0
,
int_max
));
boosting_config
.
tree_config
.
feature_fraction_seed
=
static_cast
<
int
>
(
rand
.
NextShort
(
0
,
int_max
));
}
}
task_type
=
GetTaskType
(
params
);
GetTaskType
(
params
,
&
task_type
);
boosting_type
=
GetBoostingType
(
params
);
GetBoostingType
(
params
,
&
boosting_type
);
metric_types
=
GetMetricType
(
params
);
GetMetricType
(
params
,
&
metric_types
);
// sub-config setup
// sub-config setup
network_config
.
Set
(
params
);
network_config
.
Set
(
params
);
io_config
.
Set
(
params
);
io_config
.
Set
(
params
);
boosting_config
.
Set
(
params
);
boosting_config
.
Set
(
params
);
objective_type
=
GetObjectiveType
(
params
);
GetObjectiveType
(
params
,
&
objective_type
);
objective_config
.
Set
(
params
);
objective_config
.
Set
(
params
);
metric_config
.
Set
(
params
);
metric_config
.
Set
(
params
);
...
@@ -298,7 +286,7 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
...
@@ -298,7 +286,7 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetDouble
(
params
,
"pred_early_stop_margin"
,
&
pred_early_stop_margin
);
GetDouble
(
params
,
"pred_early_stop_margin"
,
&
pred_early_stop_margin
);
GetBool
(
params
,
"use_missing"
,
&
use_missing
);
GetBool
(
params
,
"use_missing"
,
&
use_missing
);
GetBool
(
params
,
"zero_as_missing"
,
&
zero_as_missing
);
GetBool
(
params
,
"zero_as_missing"
,
&
zero_as_missing
);
device_type
=
GetDeviceType
(
params
);
GetDeviceType
(
params
,
&
device_type
);
}
}
void
ObjectiveConfig
::
Set
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
{
void
ObjectiveConfig
::
Set
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
{
...
@@ -413,8 +401,8 @@ void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& par
...
@@ -413,8 +401,8 @@ void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& par
GetBool
(
params
,
"boost_from_average"
,
&
boost_from_average
);
GetBool
(
params
,
"boost_from_average"
,
&
boost_from_average
);
CHECK
(
drop_rate
<=
1.0
&&
drop_rate
>=
0.0
);
CHECK
(
drop_rate
<=
1.0
&&
drop_rate
>=
0.0
);
CHECK
(
skip_drop
<=
1.0
&&
skip_drop
>=
0.0
);
CHECK
(
skip_drop
<=
1.0
&&
skip_drop
>=
0.0
);
device_type
=
GetDeviceType
(
params
);
GetDeviceType
(
params
,
&
device_type
);
tree_learner_type
=
GetTreeLearnerType
(
params
);
GetTreeLearnerType
(
params
,
&
tree_learner_type
);
tree_config
.
Set
(
params
);
tree_config
.
Set
(
params
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment