Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
42710827
Commit
42710827
authored
Dec 29, 2017
by
Guolin Ke
Browse files
fix `max_drop`. add many checks for parameters.
parent
0af44ac8
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
31 additions
and
7 deletions
+31
-7
include/LightGBM/config.h
include/LightGBM/config.h
+0
-1
src/boosting/dart.hpp
src/boosting/dart.hpp
+6
-0
src/io/config.cpp
src/io/config.cpp
+25
-6
No files found.
include/LightGBM/config.h
View file @
42710827
...
@@ -239,7 +239,6 @@ public:
...
@@ -239,7 +239,6 @@ public:
struct
BoostingConfig
:
public
ConfigBase
{
struct
BoostingConfig
:
public
ConfigBase
{
public:
public:
virtual
~
BoostingConfig
()
{}
virtual
~
BoostingConfig
()
{}
double
sigmoid
=
1.0
f
;
int
output_freq
=
1
;
int
output_freq
=
1
;
bool
is_provide_training_metric
=
false
;
bool
is_provide_training_metric
=
false
;
int
num_iterations
=
100
;
int
num_iterations
=
100
;
...
...
src/boosting/dart.hpp
View file @
42710827
...
@@ -96,6 +96,9 @@ private:
...
@@ -96,6 +96,9 @@ private:
for
(
int
i
=
0
;
i
<
iter_
;
++
i
)
{
for
(
int
i
=
0
;
i
<
iter_
;
++
i
)
{
if
(
random_for_drop_
.
NextFloat
()
<
drop_rate
*
tree_weight_
[
i
]
*
inv_average_weight
)
{
if
(
random_for_drop_
.
NextFloat
()
<
drop_rate
*
tree_weight_
[
i
]
*
inv_average_weight
)
{
drop_index_
.
push_back
(
num_init_iteration_
+
i
);
drop_index_
.
push_back
(
num_init_iteration_
+
i
);
if
(
drop_index_
.
size
()
>=
static_cast
<
size_t
>
(
gbdt_config_
->
max_drop
))
{
break
;
}
}
}
}
}
}
else
{
}
else
{
...
@@ -105,6 +108,9 @@ private:
...
@@ -105,6 +108,9 @@ private:
for
(
int
i
=
0
;
i
<
iter_
;
++
i
)
{
for
(
int
i
=
0
;
i
<
iter_
;
++
i
)
{
if
(
random_for_drop_
.
NextFloat
()
<
drop_rate
)
{
if
(
random_for_drop_
.
NextFloat
()
<
drop_rate
)
{
drop_index_
.
push_back
(
num_init_iteration_
+
i
);
drop_index_
.
push_back
(
num_init_iteration_
+
i
);
if
(
drop_index_
.
size
()
>=
static_cast
<
size_t
>
(
gbdt_config_
->
max_drop
))
{
break
;
}
}
}
}
}
}
}
...
...
src/io/config.cpp
View file @
42710827
...
@@ -251,12 +251,14 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
...
@@ -251,12 +251,14 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetInt
(
params
,
"max_bin"
,
&
max_bin
);
GetInt
(
params
,
"max_bin"
,
&
max_bin
);
CHECK
(
max_bin
>
0
);
CHECK
(
max_bin
>
0
);
GetInt
(
params
,
"num_class"
,
&
num_class
);
GetInt
(
params
,
"num_class"
,
&
num_class
);
CHECK
(
num_class
>
0
);
GetInt
(
params
,
"data_random_seed"
,
&
data_random_seed
);
GetInt
(
params
,
"data_random_seed"
,
&
data_random_seed
);
GetString
(
params
,
"data"
,
&
data_filename
);
GetString
(
params
,
"data"
,
&
data_filename
);
GetString
(
params
,
"init_score_file"
,
&
initscore_filename
);
GetString
(
params
,
"init_score_file"
,
&
initscore_filename
);
GetInt
(
params
,
"verbose"
,
&
verbosity
);
GetInt
(
params
,
"verbose"
,
&
verbosity
);
GetInt
(
params
,
"num_iteration_predict"
,
&
num_iteration_predict
);
GetInt
(
params
,
"num_iteration_predict"
,
&
num_iteration_predict
);
GetInt
(
params
,
"bin_construct_sample_cnt"
,
&
bin_construct_sample_cnt
);
GetInt
(
params
,
"bin_construct_sample_cnt"
,
&
bin_construct_sample_cnt
);
CHECK
(
bin_construct_sample_cnt
>
0
);
GetBool
(
params
,
"is_pre_partition"
,
&
is_pre_partition
);
GetBool
(
params
,
"is_pre_partition"
,
&
is_pre_partition
);
GetBool
(
params
,
"is_enable_sparse"
,
&
is_enable_sparse
);
GetBool
(
params
,
"is_enable_sparse"
,
&
is_enable_sparse
);
GetDouble
(
params
,
"sparse_threshold"
,
&
sparse_threshold
);
GetDouble
(
params
,
"sparse_threshold"
,
&
sparse_threshold
);
...
@@ -290,9 +292,10 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
...
@@ -290,9 +292,10 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetInt
(
params
,
"min_data_in_leaf"
,
&
min_data_in_leaf
);
GetInt
(
params
,
"min_data_in_leaf"
,
&
min_data_in_leaf
);
GetInt
(
params
,
"min_data_in_bin"
,
&
min_data_in_bin
);
GetInt
(
params
,
"min_data_in_bin"
,
&
min_data_in_bin
);
CHECK
(
min_data_in_bin
>
0
);
CHECK
(
min_data_in_bin
>
0
);
CHECK
(
min_data_in_leaf
>=
0
);
GetDouble
(
params
,
"max_conflict_rate"
,
&
max_conflict_rate
);
GetDouble
(
params
,
"max_conflict_rate"
,
&
max_conflict_rate
);
CHECK
(
max_conflict_rate
>=
0
);
GetBool
(
params
,
"enable_bundle"
,
&
enable_bundle
);
GetBool
(
params
,
"enable_bundle"
,
&
enable_bundle
);
GetBool
(
params
,
"pred_early_stop"
,
&
pred_early_stop
);
GetBool
(
params
,
"pred_early_stop"
,
&
pred_early_stop
);
GetInt
(
params
,
"pred_early_stop_freq"
,
&
pred_early_stop_freq
);
GetInt
(
params
,
"pred_early_stop_freq"
,
&
pred_early_stop_freq
);
GetDouble
(
params
,
"pred_early_stop_margin"
,
&
pred_early_stop_margin
);
GetDouble
(
params
,
"pred_early_stop_margin"
,
&
pred_early_stop_margin
);
...
@@ -304,15 +307,21 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
...
@@ -304,15 +307,21 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
void
ObjectiveConfig
::
Set
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
{
void
ObjectiveConfig
::
Set
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
{
GetBool
(
params
,
"is_unbalance"
,
&
is_unbalance
);
GetBool
(
params
,
"is_unbalance"
,
&
is_unbalance
);
GetDouble
(
params
,
"sigmoid"
,
&
sigmoid
);
GetDouble
(
params
,
"sigmoid"
,
&
sigmoid
);
CHECK
(
sigmoid
>
0
);
GetDouble
(
params
,
"fair_c"
,
&
fair_c
);
GetDouble
(
params
,
"fair_c"
,
&
fair_c
);
CHECK
(
fair_c
>
0
);
GetDouble
(
params
,
"gaussian_eta"
,
&
gaussian_eta
);
GetDouble
(
params
,
"gaussian_eta"
,
&
gaussian_eta
);
CHECK
(
gaussian_eta
>
0
);
GetDouble
(
params
,
"poisson_max_delta_step"
,
&
poisson_max_delta_step
);
GetDouble
(
params
,
"poisson_max_delta_step"
,
&
poisson_max_delta_step
);
CHECK
(
poisson_max_delta_step
>
0
);
GetInt
(
params
,
"max_position"
,
&
max_position
);
GetInt
(
params
,
"max_position"
,
&
max_position
);
CHECK
(
max_position
>
0
);
CHECK
(
max_position
>
0
);
GetInt
(
params
,
"num_class"
,
&
num_class
);
GetInt
(
params
,
"num_class"
,
&
num_class
);
CHECK
(
num_class
>
=
1
);
CHECK
(
num_class
>
0
);
GetDouble
(
params
,
"scale_pos_weight"
,
&
scale_pos_weight
);
GetDouble
(
params
,
"scale_pos_weight"
,
&
scale_pos_weight
);
CHECK
(
scale_pos_weight
>
0
);
GetDouble
(
params
,
"alpha"
,
&
alpha
);
GetDouble
(
params
,
"alpha"
,
&
alpha
);
CHECK
(
alpha
>
0
&&
alpha
<
1
);
GetBool
(
params
,
"reg_sqrt"
,
&
reg_sqrt
);
GetBool
(
params
,
"reg_sqrt"
,
&
reg_sqrt
);
std
::
string
tmp_str
=
""
;
std
::
string
tmp_str
=
""
;
if
(
GetString
(
params
,
"label_gain"
,
&
tmp_str
))
{
if
(
GetString
(
params
,
"label_gain"
,
&
tmp_str
))
{
...
@@ -331,9 +340,13 @@ void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& pa
...
@@ -331,9 +340,13 @@ void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& pa
void
MetricConfig
::
Set
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
{
void
MetricConfig
::
Set
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
{
GetDouble
(
params
,
"sigmoid"
,
&
sigmoid
);
GetDouble
(
params
,
"sigmoid"
,
&
sigmoid
);
CHECK
(
sigmoid
>
0
);
GetDouble
(
params
,
"fair_c"
,
&
fair_c
);
GetDouble
(
params
,
"fair_c"
,
&
fair_c
);
CHECK
(
fair_c
>
0
);
GetInt
(
params
,
"num_class"
,
&
num_class
);
GetInt
(
params
,
"num_class"
,
&
num_class
);
CHECK
(
num_class
>
0
);
GetDouble
(
params
,
"alpha"
,
&
alpha
);
GetDouble
(
params
,
"alpha"
,
&
alpha
);
CHECK
(
alpha
>
0
&&
alpha
<
1
);
std
::
string
tmp_str
=
""
;
std
::
string
tmp_str
=
""
;
if
(
GetString
(
params
,
"label_gain"
,
&
tmp_str
))
{
if
(
GetString
(
params
,
"label_gain"
,
&
tmp_str
))
{
label_gain
=
Common
::
StringToArray
<
double
>
(
tmp_str
,
','
);
label_gain
=
Common
::
StringToArray
<
double
>
(
tmp_str
,
','
);
...
@@ -365,7 +378,8 @@ void MetricConfig::Set(const std::unordered_map<std::string, std::string>& param
...
@@ -365,7 +378,8 @@ void MetricConfig::Set(const std::unordered_map<std::string, std::string>& param
void
TreeConfig
::
Set
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
{
void
TreeConfig
::
Set
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
{
GetInt
(
params
,
"min_data_in_leaf"
,
&
min_data_in_leaf
);
GetInt
(
params
,
"min_data_in_leaf"
,
&
min_data_in_leaf
);
GetDouble
(
params
,
"min_sum_hessian_in_leaf"
,
&
min_sum_hessian_in_leaf
);
GetDouble
(
params
,
"min_sum_hessian_in_leaf"
,
&
min_sum_hessian_in_leaf
);
CHECK
(
min_sum_hessian_in_leaf
>
0
||
min_data_in_leaf
>
0
);
CHECK
(
min_data_in_leaf
>
0
);
CHECK
(
min_sum_hessian_in_leaf
>=
0
);
GetDouble
(
params
,
"lambda_l1"
,
&
lambda_l1
);
GetDouble
(
params
,
"lambda_l1"
,
&
lambda_l1
);
CHECK
(
lambda_l1
>=
0.0
f
);
CHECK
(
lambda_l1
>=
0.0
f
);
GetDouble
(
params
,
"lambda_l2"
,
&
lambda_l2
);
GetDouble
(
params
,
"lambda_l2"
,
&
lambda_l2
);
...
@@ -380,6 +394,7 @@ void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params)
...
@@ -380,6 +394,7 @@ void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params)
GetDouble
(
params
,
"histogram_pool_size"
,
&
histogram_pool_size
);
GetDouble
(
params
,
"histogram_pool_size"
,
&
histogram_pool_size
);
GetInt
(
params
,
"max_depth"
,
&
max_depth
);
GetInt
(
params
,
"max_depth"
,
&
max_depth
);
GetInt
(
params
,
"top_k"
,
&
top_k
);
GetInt
(
params
,
"top_k"
,
&
top_k
);
CHECK
(
top_k
>
0
);
GetInt
(
params
,
"gpu_platform_id"
,
&
gpu_platform_id
);
GetInt
(
params
,
"gpu_platform_id"
,
&
gpu_platform_id
);
GetInt
(
params
,
"gpu_device_id"
,
&
gpu_device_id
);
GetInt
(
params
,
"gpu_device_id"
,
&
gpu_device_id
);
GetBool
(
params
,
"gpu_use_dp"
,
&
gpu_use_dp
);
GetBool
(
params
,
"gpu_use_dp"
,
&
gpu_use_dp
);
...
@@ -397,7 +412,6 @@ void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params)
...
@@ -397,7 +412,6 @@ void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params)
void
BoostingConfig
::
Set
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
{
void
BoostingConfig
::
Set
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
{
GetInt
(
params
,
"num_iterations"
,
&
num_iterations
);
GetInt
(
params
,
"num_iterations"
,
&
num_iterations
);
GetDouble
(
params
,
"sigmoid"
,
&
sigmoid
);
CHECK
(
num_iterations
>=
0
);
CHECK
(
num_iterations
>=
0
);
GetInt
(
params
,
"bagging_seed"
,
&
bagging_seed
);
GetInt
(
params
,
"bagging_seed"
,
&
bagging_seed
);
GetInt
(
params
,
"bagging_freq"
,
&
bagging_freq
);
GetInt
(
params
,
"bagging_freq"
,
&
bagging_freq
);
...
@@ -412,17 +426,22 @@ void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& par
...
@@ -412,17 +426,22 @@ void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& par
CHECK
(
output_freq
>=
0
);
CHECK
(
output_freq
>=
0
);
GetBool
(
params
,
"is_training_metric"
,
&
is_provide_training_metric
);
GetBool
(
params
,
"is_training_metric"
,
&
is_provide_training_metric
);
GetInt
(
params
,
"num_class"
,
&
num_class
);
GetInt
(
params
,
"num_class"
,
&
num_class
);
CHECK
(
num_class
>
0
);
GetInt
(
params
,
"drop_seed"
,
&
drop_seed
);
GetInt
(
params
,
"drop_seed"
,
&
drop_seed
);
GetDouble
(
params
,
"drop_rate"
,
&
drop_rate
);
GetDouble
(
params
,
"drop_rate"
,
&
drop_rate
);
GetDouble
(
params
,
"skip_drop"
,
&
skip_drop
);
GetDouble
(
params
,
"skip_drop"
,
&
skip_drop
);
CHECK
(
drop_rate
<=
1.0
&&
drop_rate
>=
0.0
);
CHECK
(
skip_drop
<=
1.0
&&
skip_drop
>=
0.0
);
GetInt
(
params
,
"max_drop"
,
&
max_drop
);
GetInt
(
params
,
"max_drop"
,
&
max_drop
);
CHECK
(
max_drop
>
0
);
GetBool
(
params
,
"xgboost_dart_mode"
,
&
xgboost_dart_mode
);
GetBool
(
params
,
"xgboost_dart_mode"
,
&
xgboost_dart_mode
);
GetBool
(
params
,
"uniform_drop"
,
&
uniform_drop
);
GetBool
(
params
,
"uniform_drop"
,
&
uniform_drop
);
GetDouble
(
params
,
"top_rate"
,
&
top_rate
);
GetDouble
(
params
,
"top_rate"
,
&
top_rate
);
GetDouble
(
params
,
"other_rate"
,
&
other_rate
);
GetDouble
(
params
,
"other_rate"
,
&
other_rate
);
CHECK
(
top_rate
>
0
);
CHECK
(
top_rate
>
0
);
CHECK
(
top_rate
+
top_rate
<=
1.0
);
GetBool
(
params
,
"boost_from_average"
,
&
boost_from_average
);
GetBool
(
params
,
"boost_from_average"
,
&
boost_from_average
);
CHECK
(
drop_rate
<=
1.0
&&
drop_rate
>=
0.0
);
CHECK
(
skip_drop
<=
1.0
&&
skip_drop
>=
0.0
);
GetDeviceType
(
params
,
&
device_type
);
GetDeviceType
(
params
,
&
device_type
);
GetTreeLearnerType
(
params
,
&
tree_learner_type
);
GetTreeLearnerType
(
params
,
&
tree_learner_type
);
tree_config
.
Set
(
params
);
tree_config
.
Set
(
params
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment