Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
1e7ccbbb
"include/git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "6d34fb86359f955a12369ef9a2803f0171ea0a3b"
Commit
1e7ccbbb
authored
Jul 04, 2017
by
Guolin Ke
Browse files
clean code for Boosting::ResetTrainingData.
parent
a98b23d2
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
204 additions
and
153 deletions
+204
-153
include/LightGBM/boosting.h
include/LightGBM/boosting.h
+4
-8
include/LightGBM/config.h
include/LightGBM/config.h
+1
-1
src/boosting/dart.hpp
src/boosting/dart.hpp
+0
-4
src/boosting/gbdt.cpp
src/boosting/gbdt.cpp
+160
-113
src/boosting/gbdt.h
src/boosting/gbdt.h
+4
-7
src/boosting/goss.hpp
src/boosting/goss.hpp
+17
-10
src/c_api.cpp
src/c_api.cpp
+18
-10
No files found.
include/LightGBM/boosting.h
View file @
1e7ccbbb
...
@@ -43,14 +43,10 @@ public:
...
@@ -43,14 +43,10 @@ public:
*/
*/
virtual
void
MergeFrom
(
const
Boosting
*
other
)
=
0
;
virtual
void
MergeFrom
(
const
Boosting
*
other
)
=
0
;
/*!
virtual
void
ResetTrainingData
(
const
Dataset
*
train_data
,
const
ObjectiveFunction
*
objective_function
,
* \brief Reset training data for current boosting
const
std
::
vector
<
const
Metric
*>&
training_metrics
)
=
0
;
* \param config Configs for boosting
* \param train_data Training data
virtual
void
ResetConfig
(
const
BoostingConfig
*
config
)
=
0
;
* \param objective_function Training objective function
* \param training_metrics Training metric
*/
virtual
void
ResetTrainingData
(
const
BoostingConfig
*
config
,
const
Dataset
*
train_data
,
const
ObjectiveFunction
*
objective_function
,
const
std
::
vector
<
const
Metric
*>&
training_metrics
)
=
0
;
/*!
/*!
* \brief Add a validation data
* \brief Add a validation data
...
...
include/LightGBM/config.h
View file @
1e7ccbbb
...
@@ -91,7 +91,7 @@ public:
...
@@ -91,7 +91,7 @@ public:
int
data_random_seed
=
1
;
int
data_random_seed
=
1
;
std
::
string
data_filename
=
""
;
std
::
string
data_filename
=
""
;
std
::
vector
<
std
::
string
>
valid_data_filenames
;
std
::
vector
<
std
::
string
>
valid_data_filenames
;
int
snapshot_freq
=
1
00
;
int
snapshot_freq
=
-
1
;
std
::
string
output_model
=
"LightGBM_model.txt"
;
std
::
string
output_model
=
"LightGBM_model.txt"
;
std
::
string
output_result
=
"LightGBM_predict_result.txt"
;
std
::
string
output_result
=
"LightGBM_predict_result.txt"
;
std
::
string
convert_model
=
"gbdt_prediction.cpp"
;
std
::
string
convert_model
=
"gbdt_prediction.cpp"
;
...
...
src/boosting/dart.hpp
View file @
1e7ccbbb
...
@@ -39,10 +39,6 @@ public:
...
@@ -39,10 +39,6 @@ public:
sum_weight_
=
0.0
f
;
sum_weight_
=
0.0
f
;
}
}
void
ResetTrainingData
(
const
BoostingConfig
*
config
,
const
Dataset
*
train_data
,
const
ObjectiveFunction
*
objective_function
,
const
std
::
vector
<
const
Metric
*>&
training_metrics
)
override
{
GBDT
::
ResetTrainingData
(
config
,
train_data
,
objective_function
,
training_metrics
);
}
/*!
/*!
* \brief one training iteration
* \brief one training iteration
*/
*/
...
...
src/boosting/gbdt.cpp
View file @
1e7ccbbb
...
@@ -64,24 +64,14 @@ GBDT::~GBDT() {
...
@@ -64,24 +64,14 @@ GBDT::~GBDT() {
void
GBDT
::
Init
(
const
BoostingConfig
*
config
,
const
Dataset
*
train_data
,
const
ObjectiveFunction
*
objective_function
,
void
GBDT
::
Init
(
const
BoostingConfig
*
config
,
const
Dataset
*
train_data
,
const
ObjectiveFunction
*
objective_function
,
const
std
::
vector
<
const
Metric
*>&
training_metrics
)
{
const
std
::
vector
<
const
Metric
*>&
training_metrics
)
{
train_data_
=
train_data
;
iter_
=
0
;
iter_
=
0
;
num_iteration_for_pred_
=
0
;
num_iteration_for_pred_
=
0
;
max_feature_idx_
=
0
;
max_feature_idx_
=
0
;
num_class_
=
config
->
num_class
;
num_class_
=
config
->
num_class
;
train_data_
=
nullptr
;
gbdt_config_
=
std
::
unique_ptr
<
BoostingConfig
>
(
new
BoostingConfig
(
*
config
));
gbdt_config_
=
nullptr
;
early_stopping_round_
=
gbdt_config_
->
early_stopping_round
;
tree_learner_
=
nullptr
;
shrinkage_rate_
=
gbdt_config_
->
learning_rate
;
ResetTrainingData
(
config
,
train_data
,
objective_function
,
training_metrics
);
}
void
GBDT
::
ResetTrainingData
(
const
BoostingConfig
*
config
,
const
Dataset
*
train_data
,
const
ObjectiveFunction
*
objective_function
,
const
std
::
vector
<
const
Metric
*>&
training_metrics
)
{
auto
new_config
=
std
::
unique_ptr
<
BoostingConfig
>
(
new
BoostingConfig
(
*
config
));
if
(
train_data_
!=
nullptr
&&
!
train_data_
->
CheckAlign
(
*
train_data
))
{
Log
::
Fatal
(
"cannot reset training data, since new training data has different bin mappers"
);
}
early_stopping_round_
=
new_config
->
early_stopping_round
;
shrinkage_rate_
=
new_config
->
learning_rate
;
objective_function_
=
objective_function
;
objective_function_
=
objective_function
;
num_tree_per_iteration_
=
num_class_
;
num_tree_per_iteration_
=
num_class_
;
...
@@ -92,12 +82,10 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
...
@@ -92,12 +82,10 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
is_constant_hessian_
=
false
;
is_constant_hessian_
=
false
;
}
}
if
(
train_data_
!=
train_data
&&
train_data
!=
nullptr
)
{
tree_learner_
=
std
::
unique_ptr
<
TreeLearner
>
(
TreeLearner
::
CreateTreeLearner
(
gbdt_config_
->
tree_learner_type
,
gbdt_config_
->
device_type
,
&
gbdt_config_
->
tree_config
));
if
(
tree_learner_
==
nullptr
)
{
tree_learner_
=
std
::
unique_ptr
<
TreeLearner
>
(
TreeLearner
::
CreateTreeLearner
(
new_config
->
tree_learner_type
,
new_config
->
device_type
,
&
new_config
->
tree_config
));
}
// init tree learner
// init tree learner
tree_learner_
->
Init
(
train_data
,
is_constant_hessian_
);
tree_learner_
->
Init
(
train_data
_
,
is_constant_hessian_
);
// push training metrics
// push training metrics
training_metrics_
.
clear
();
training_metrics_
.
clear
();
...
@@ -105,17 +93,10 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
...
@@ -105,17 +93,10 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
training_metrics_
.
push_back
(
metric
);
training_metrics_
.
push_back
(
metric
);
}
}
training_metrics_
.
shrink_to_fit
();
training_metrics_
.
shrink_to_fit
();
// not same training data, need reset score and others
// create score tracker
train_score_updater_
.
reset
(
new
ScoreUpdater
(
train_data_
,
num_tree_per_iteration_
));
train_score_updater_
.
reset
(
new
ScoreUpdater
(
train_data
,
num_tree_per_iteration_
));
// update score
num_data_
=
train_data_
->
num_data
();
for
(
int
i
=
0
;
i
<
iter_
;
++
i
)
{
for
(
int
cur_tree_id
=
0
;
cur_tree_id
<
num_tree_per_iteration_
;
++
cur_tree_id
)
{
auto
curr_tree
=
(
i
+
num_init_iteration_
)
*
num_tree_per_iteration_
+
cur_tree_id
;
train_score_updater_
->
AddScore
(
models_
[
curr_tree
].
get
(),
cur_tree_id
);
}
}
num_data_
=
train_data
->
num_data
();
// create buffer for gradients and hessians
// create buffer for gradients and hessians
if
(
objective_function_
!=
nullptr
)
{
if
(
objective_function_
!=
nullptr
)
{
size_t
total_size
=
static_cast
<
size_t
>
(
num_data_
)
*
num_tree_per_iteration_
;
size_t
total_size
=
static_cast
<
size_t
>
(
num_data_
)
*
num_tree_per_iteration_
;
...
@@ -123,56 +104,17 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
...
@@ -123,56 +104,17 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
hessians_
.
resize
(
total_size
);
hessians_
.
resize
(
total_size
);
}
}
// get max feature index
// get max feature index
max_feature_idx_
=
train_data
->
num_total_features
()
-
1
;
max_feature_idx_
=
train_data
_
->
num_total_features
()
-
1
;
// get label index
// get label index
label_idx_
=
train_data
->
label_idx
();
label_idx_
=
train_data
_
->
label_idx
();
// get feature names
// get feature names
feature_names_
=
train_data
->
feature_names
();
feature_names_
=
train_data_
->
feature_names
();
feature_infos_
=
train_data_
->
feature_infos
();
feature_infos_
=
train_data
->
feature_infos
();
}
if
((
train_data_
!=
train_data
&&
train_data
!=
nullptr
)
||
(
gbdt_config_
!=
nullptr
&&
gbdt_config_
->
bagging_fraction
!=
new_config
->
bagging_fraction
))
{
// if need bagging, create buffer
// if need bagging, create buffer
if
(
new_config
->
bagging_fraction
<
1.0
&&
new_config
->
bagging_freq
>
0
)
{
ResetBaggingConfig
(
gbdt_config_
.
get
());
bag_data_cnt_
=
static_cast
<
data_size_t
>
(
new_config
->
bagging_fraction
*
num_data_
);
bag_data_indices_
.
resize
(
num_data_
);
tmp_indices_
.
resize
(
num_data_
);
offsets_buf_
.
resize
(
num_threads_
);
left_cnts_buf_
.
resize
(
num_threads_
);
right_cnts_buf_
.
resize
(
num_threads_
);
left_write_pos_buf_
.
resize
(
num_threads_
);
right_write_pos_buf_
.
resize
(
num_threads_
);
double
average_bag_rate
=
new_config
->
bagging_fraction
/
new_config
->
bagging_freq
;
int
sparse_group
=
0
;
for
(
int
i
=
0
;
i
<
train_data
->
num_feature_groups
();
++
i
)
{
if
(
train_data
->
FeatureGroupIsSparse
(
i
))
{
++
sparse_group
;
}
}
is_use_subset_
=
false
;
const
int
group_threshold_usesubset
=
100
;
const
int
sparse_group_threshold_usesubset
=
train_data
->
num_feature_groups
()
/
4
;
if
(
average_bag_rate
<=
0.5
&&
(
train_data
->
num_feature_groups
()
<
group_threshold_usesubset
||
sparse_group
<
sparse_group_threshold_usesubset
))
{
tmp_subset_
.
reset
(
new
Dataset
(
bag_data_cnt_
));
tmp_subset_
->
CopyFeatureMapperFrom
(
train_data
);
is_use_subset_
=
true
;
Log
::
Debug
(
"use subset for bagging"
);
}
}
else
{
bag_data_cnt_
=
num_data_
;
bag_data_indices_
.
clear
();
tmp_indices_
.
clear
();
is_use_subset_
=
false
;
}
}
train_data_
=
train_data
;
if
(
train_data_
!=
nullptr
)
{
// reset config for tree learner
// reset config for tree learner
tree_learner_
->
ResetConfig
(
&
new_config
->
tree_config
);
class_need_train_
=
std
::
vector
<
bool
>
(
num_tree_per_iteration_
,
true
);
class_need_train_
=
std
::
vector
<
bool
>
(
num_tree_per_iteration_
,
true
);
if
(
objective_function_
!=
nullptr
&&
objective_function_
->
SkipEmptyClass
())
{
if
(
objective_function_
!=
nullptr
&&
objective_function_
->
SkipEmptyClass
())
{
CHECK
(
num_tree_per_iteration_
==
num_class_
);
CHECK
(
num_tree_per_iteration_
==
num_class_
);
...
@@ -213,10 +155,115 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
...
@@ -213,10 +155,115 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
}
}
}
}
}
}
}
void
GBDT
::
ResetTrainingData
(
const
Dataset
*
train_data
,
const
ObjectiveFunction
*
objective_function
,
const
std
::
vector
<
const
Metric
*>&
training_metrics
)
{
if
(
train_data
!=
train_data_
&&
!
train_data_
->
CheckAlign
(
*
train_data
))
{
Log
::
Fatal
(
"cannot reset training data, since new training data has different bin mappers"
);
}
objective_function_
=
objective_function
;
num_tree_per_iteration_
=
num_class_
;
if
(
objective_function_
!=
nullptr
)
{
is_constant_hessian_
=
objective_function_
->
IsConstantHessian
();
num_tree_per_iteration_
=
objective_function_
->
NumTreePerIteration
();
}
else
{
is_constant_hessian_
=
false
;
}
// push training metrics
training_metrics_
.
clear
();
for
(
const
auto
&
metric
:
training_metrics
)
{
training_metrics_
.
push_back
(
metric
);
}
training_metrics_
.
shrink_to_fit
();
if
(
train_data
!=
train_data_
)
{
train_data_
=
train_data
;
// not same training data, need reset score and others
// create score tracker
train_score_updater_
.
reset
(
new
ScoreUpdater
(
train_data_
,
num_tree_per_iteration_
));
// update score
for
(
int
i
=
0
;
i
<
iter_
;
++
i
)
{
for
(
int
cur_tree_id
=
0
;
cur_tree_id
<
num_tree_per_iteration_
;
++
cur_tree_id
)
{
auto
curr_tree
=
(
i
+
num_init_iteration_
)
*
num_tree_per_iteration_
+
cur_tree_id
;
train_score_updater_
->
AddScore
(
models_
[
curr_tree
].
get
(),
cur_tree_id
);
}
}
num_data_
=
train_data_
->
num_data
();
// create buffer for gradients and hessians
if
(
objective_function_
!=
nullptr
)
{
size_t
total_size
=
static_cast
<
size_t
>
(
num_data_
)
*
num_tree_per_iteration_
;
gradients_
.
resize
(
total_size
);
hessians_
.
resize
(
total_size
);
}
// get max feature index
max_feature_idx_
=
train_data_
->
num_total_features
()
-
1
;
// get label index
label_idx_
=
train_data_
->
label_idx
();
// get feature names
feature_names_
=
train_data_
->
feature_names
();
feature_infos_
=
train_data_
->
feature_infos
();
ResetBaggingConfig
(
gbdt_config_
.
get
());
tree_learner_
->
ResetTrainingData
(
train_data
);
}
}
}
void
GBDT
::
ResetConfig
(
const
BoostingConfig
*
config
)
{
auto
new_config
=
std
::
unique_ptr
<
BoostingConfig
>
(
new
BoostingConfig
(
*
config
));
early_stopping_round_
=
new_config
->
early_stopping_round
;
shrinkage_rate_
=
new_config
->
learning_rate
;
ResetBaggingConfig
(
new_config
.
get
());
tree_learner_
->
ResetConfig
(
&
new_config
->
tree_config
);
gbdt_config_
.
reset
(
new_config
.
release
());
gbdt_config_
.
reset
(
new_config
.
release
());
}
}
void
GBDT
::
ResetBaggingConfig
(
const
BoostingConfig
*
config
)
{
// if need bagging, create buffer
if
(
config
->
bagging_fraction
<
1.0
&&
config
->
bagging_freq
>
0
)
{
bag_data_cnt_
=
static_cast
<
data_size_t
>
(
config
->
bagging_fraction
*
num_data_
);
bag_data_indices_
.
resize
(
num_data_
);
tmp_indices_
.
resize
(
num_data_
);
offsets_buf_
.
resize
(
num_threads_
);
left_cnts_buf_
.
resize
(
num_threads_
);
right_cnts_buf_
.
resize
(
num_threads_
);
left_write_pos_buf_
.
resize
(
num_threads_
);
right_write_pos_buf_
.
resize
(
num_threads_
);
double
average_bag_rate
=
config
->
bagging_fraction
/
config
->
bagging_freq
;
int
sparse_group
=
0
;
for
(
int
i
=
0
;
i
<
train_data_
->
num_feature_groups
();
++
i
)
{
if
(
train_data_
->
FeatureGroupIsSparse
(
i
))
{
++
sparse_group
;
}
}
is_use_subset_
=
false
;
const
int
group_threshold_usesubset
=
100
;
const
int
sparse_group_threshold_usesubset
=
train_data_
->
num_feature_groups
()
/
4
;
if
(
average_bag_rate
<=
0.5
&&
(
train_data_
->
num_feature_groups
()
<
group_threshold_usesubset
||
sparse_group
<
sparse_group_threshold_usesubset
))
{
tmp_subset_
.
reset
(
new
Dataset
(
bag_data_cnt_
));
tmp_subset_
->
CopyFeatureMapperFrom
(
train_data_
);
is_use_subset_
=
true
;
Log
::
Debug
(
"use subset for bagging"
);
}
}
else
{
bag_data_cnt_
=
num_data_
;
bag_data_indices_
.
clear
();
tmp_indices_
.
clear
();
is_use_subset_
=
false
;
}
}
void
GBDT
::
AddValidDataset
(
const
Dataset
*
valid_data
,
void
GBDT
::
AddValidDataset
(
const
Dataset
*
valid_data
,
const
std
::
vector
<
const
Metric
*>&
valid_metrics
)
{
const
std
::
vector
<
const
Metric
*>&
valid_metrics
)
{
if
(
!
train_data_
->
CheckAlign
(
*
valid_data
))
{
if
(
!
train_data_
->
CheckAlign
(
*
valid_data
))
{
...
@@ -358,7 +405,7 @@ double LabelAverage(const float* label, data_size_t num_data) {
...
@@ -358,7 +405,7 @@ double LabelAverage(const float* label, data_size_t num_data) {
Network
::
Allreduce
(
reinterpret_cast
<
char
*>
(
&
init_score
),
Network
::
Allreduce
(
reinterpret_cast
<
char
*>
(
&
init_score
),
sizeof
(
init_score
),
sizeof
(
init_score
),
sizeof
(
init_score
),
sizeof
(
init_score
),
reinterpret_cast
<
char
*>
(
&
global_init_score
),
reinterpret_cast
<
char
*>
(
&
global_init_score
),
[](
const
char
*
src
,
char
*
dst
,
int
len
)
{
[]
(
const
char
*
src
,
char
*
dst
,
int
len
)
{
int
used_size
=
0
;
int
used_size
=
0
;
const
int
type_size
=
sizeof
(
double
);
const
int
type_size
=
sizeof
(
double
);
const
double
*
p1
;
const
double
*
p1
;
...
@@ -1027,7 +1074,7 @@ std::vector<std::pair<size_t, std::string>> GBDT::FeatureImportance() const {
...
@@ -1027,7 +1074,7 @@ std::vector<std::pair<size_t, std::string>> GBDT::FeatureImportance() const {
}
}
// sort the importance
// sort the importance
std
::
sort
(
pairs
.
begin
(),
pairs
.
end
(),
std
::
sort
(
pairs
.
begin
(),
pairs
.
end
(),
[](
const
std
::
pair
<
size_t
,
std
::
string
>&
lhs
,
[]
(
const
std
::
pair
<
size_t
,
std
::
string
>&
lhs
,
const
std
::
pair
<
size_t
,
std
::
string
>&
rhs
)
{
const
std
::
pair
<
size_t
,
std
::
string
>&
rhs
)
{
return
lhs
.
first
>
rhs
.
first
;
return
lhs
.
first
>
rhs
.
first
;
});
});
...
...
src/boosting/gbdt.h
View file @
1e7ccbbb
...
@@ -63,14 +63,10 @@ public:
...
@@ -63,14 +63,10 @@ public:
num_iteration_for_pred_
=
static_cast
<
int
>
(
models_
.
size
())
/
num_tree_per_iteration_
;
num_iteration_for_pred_
=
static_cast
<
int
>
(
models_
.
size
())
/
num_tree_per_iteration_
;
}
}
/*!
void
ResetTrainingData
(
const
Dataset
*
train_data
,
const
ObjectiveFunction
*
objective_function
,
* \brief Reset training data for current boosting
const
std
::
vector
<
const
Metric
*>&
training_metrics
)
override
;
* \param train_data Training data
* \param objective_function Training objective function
* \param training_metrics Training metric
*/
void
ResetTrainingData
(
const
BoostingConfig
*
config
,
const
Dataset
*
train_data
,
const
ObjectiveFunction
*
objective_function
,
const
std
::
vector
<
const
Metric
*>&
training_metrics
)
override
;
void
ResetConfig
(
const
BoostingConfig
*
config
)
override
;
/*!
/*!
* \brief Adding a validation dataset
* \brief Adding a validation dataset
* \param valid_data Validation dataset
* \param valid_data Validation dataset
...
@@ -258,6 +254,7 @@ public:
...
@@ -258,6 +254,7 @@ public:
virtual
const
char
*
SubModelName
()
const
override
{
return
"tree"
;
}
virtual
const
char
*
SubModelName
()
const
override
{
return
"tree"
;
}
protected:
protected:
void
ResetBaggingConfig
(
const
BoostingConfig
*
config
);
/*!
/*!
* \brief Implement bagging logic
* \brief Implement bagging logic
* \param iter Current interation
* \param iter Current interation
...
...
src/boosting/goss.hpp
View file @
1e7ccbbb
...
@@ -41,21 +41,28 @@ public:
...
@@ -41,21 +41,28 @@ public:
void
Init
(
const
BoostingConfig
*
config
,
const
Dataset
*
train_data
,
const
ObjectiveFunction
*
objective_function
,
void
Init
(
const
BoostingConfig
*
config
,
const
Dataset
*
train_data
,
const
ObjectiveFunction
*
objective_function
,
const
std
::
vector
<
const
Metric
*>&
training_metrics
)
override
{
const
std
::
vector
<
const
Metric
*>&
training_metrics
)
override
{
GBDT
::
Init
(
config
,
train_data
,
objective_function
,
training_metrics
);
GBDT
::
Init
(
config
,
train_data
,
objective_function
,
training_metrics
);
ResetGoss
();
}
void
ResetTrainingData
(
const
Dataset
*
train_data
,
const
ObjectiveFunction
*
objective_function
,
const
std
::
vector
<
const
Metric
*>&
training_metrics
)
override
{
GBDT
::
ResetTrainingData
(
train_data
,
objective_function
,
training_metrics
);
ResetGoss
();
}
void
ResetConfig
(
const
BoostingConfig
*
config
)
override
{
GBDT
::
ResetConfig
(
config
);
ResetGoss
();
}
void
ResetGoss
()
{
CHECK
(
gbdt_config_
->
top_rate
+
gbdt_config_
->
other_rate
<=
1.0
f
);
CHECK
(
gbdt_config_
->
top_rate
+
gbdt_config_
->
other_rate
<=
1.0
f
);
CHECK
(
gbdt_config_
->
top_rate
>
0.0
f
&&
gbdt_config_
->
other_rate
>
0.0
f
);
CHECK
(
gbdt_config_
->
top_rate
>
0.0
f
&&
gbdt_config_
->
other_rate
>
0.0
f
);
if
(
gbdt_config_
->
bagging_freq
>
0
&&
gbdt_config_
->
bagging_fraction
!=
1.0
f
)
{
if
(
gbdt_config_
->
bagging_freq
>
0
&&
gbdt_config_
->
bagging_fraction
!=
1.0
f
)
{
Log
::
Fatal
(
"cannot use bagging in GOSS"
);
Log
::
Fatal
(
"cannot use bagging in GOSS"
);
}
}
Log
::
Info
(
"using GOSS"
);
Log
::
Info
(
"using GOSS"
);
}
void
ResetTrainingData
(
const
BoostingConfig
*
config
,
const
Dataset
*
train_data
,
const
ObjectiveFunction
*
objective_function
,
const
std
::
vector
<
const
Metric
*>&
training_metrics
)
override
{
if
(
config
->
bagging_freq
>
0
&&
config
->
bagging_fraction
!=
1.0
f
)
{
Log
::
Fatal
(
"cannot use bagging in GOSS"
);
}
GBDT
::
ResetTrainingData
(
config
,
train_data
,
objective_function
,
training_metrics
);
if
(
train_data_
==
nullptr
)
{
return
;
}
bag_data_indices_
.
resize
(
num_data_
);
bag_data_indices_
.
resize
(
num_data_
);
tmp_indices_
.
resize
(
num_data_
);
tmp_indices_
.
resize
(
num_data_
);
tmp_indice_right_
.
resize
(
num_data_
);
tmp_indice_right_
.
resize
(
num_data_
);
...
@@ -66,8 +73,8 @@ public:
...
@@ -66,8 +73,8 @@ public:
right_write_pos_buf_
.
resize
(
num_threads_
);
right_write_pos_buf_
.
resize
(
num_threads_
);
is_use_subset_
=
false
;
is_use_subset_
=
false
;
if
(
config
->
top_rate
+
config
->
other_rate
<=
0.5
)
{
if
(
gbdt_
config
_
->
top_rate
+
gbdt_
config
_
->
other_rate
<=
0.5
)
{
auto
bag_data_cnt
=
static_cast
<
data_size_t
>
((
config
->
top_rate
+
config
->
other_rate
)
*
num_data_
);
auto
bag_data_cnt
=
static_cast
<
data_size_t
>
((
gbdt_
config
_
->
top_rate
+
gbdt_
config
_
->
other_rate
)
*
num_data_
);
tmp_subset_
.
reset
(
new
Dataset
(
bag_data_cnt
));
tmp_subset_
.
reset
(
new
Dataset
(
bag_data_cnt
));
tmp_subset_
->
CopyFeatureMapperFrom
(
train_data_
);
tmp_subset_
->
CopyFeatureMapperFrom
(
train_data_
);
is_use_subset_
=
true
;
is_use_subset_
=
true
;
...
...
src/c_api.cpp
View file @
1e7ccbbb
...
@@ -51,11 +51,12 @@ public:
...
@@ -51,11 +51,12 @@ public:
boosting_
.
reset
(
Boosting
::
CreateBoosting
(
config_
.
boosting_type
,
nullptr
));
boosting_
.
reset
(
Boosting
::
CreateBoosting
(
config_
.
boosting_type
,
nullptr
));
train_data_
=
train_data
;
CreateObjectiveAndMetrics
();
// initialize the boosting
// initialize the boosting
boosting_
->
Init
(
&
config_
.
boosting_config
,
nullptr
,
objective_fun_
.
get
(),
boosting_
->
Init
(
&
config_
.
boosting_config
,
train_data_
,
objective_fun_
.
get
(),
Common
::
ConstPtrInVectorWrapper
<
Metric
>
(
train_metric_
));
Common
::
ConstPtrInVectorWrapper
<
Metric
>
(
train_metric_
));
ResetTrainingData
(
train_data
);
}
}
void
MergeFrom
(
const
Booster
*
other
)
{
void
MergeFrom
(
const
Booster
*
other
)
{
...
@@ -67,9 +68,7 @@ public:
...
@@ -67,9 +68,7 @@ public:
}
}
void
ResetTrainingData
(
const
Dataset
*
train_data
)
{
void
CreateObjectiveAndMetrics
()
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
train_data_
=
train_data
;
// create objective function
// create objective function
objective_fun_
.
reset
(
ObjectiveFunction
::
CreateObjectiveFunction
(
config_
.
objective_type
,
objective_fun_
.
reset
(
ObjectiveFunction
::
CreateObjectiveFunction
(
config_
.
objective_type
,
config_
.
objective_config
));
config_
.
objective_config
));
...
@@ -91,10 +90,18 @@ public:
...
@@ -91,10 +90,18 @@ public:
train_metric_
.
push_back
(
std
::
move
(
metric
));
train_metric_
.
push_back
(
std
::
move
(
metric
));
}
}
train_metric_
.
shrink_to_fit
();
train_metric_
.
shrink_to_fit
();
}
void
ResetTrainingData
(
const
Dataset
*
train_data
)
{
if
(
train_data
!=
train_data_
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
train_data_
=
train_data
;
CreateObjectiveAndMetrics
();
// reset the boosting
// reset the boosting
boosting_
->
ResetTrainingData
(
&
config_
.
boosting_config
,
train_data_
,
boosting_
->
ResetTrainingData
(
train_data_
,
objective_fun_
.
get
(),
Common
::
ConstPtrInVectorWrapper
<
Metric
>
(
train_metric_
));
objective_fun_
.
get
(),
Common
::
ConstPtrInVectorWrapper
<
Metric
>
(
train_metric_
));
}
}
}
void
ResetConfig
(
const
char
*
parameters
)
{
void
ResetConfig
(
const
char
*
parameters
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
...
@@ -125,10 +132,11 @@ public:
...
@@ -125,10 +132,11 @@ public:
if
(
objective_fun_
!=
nullptr
)
{
if
(
objective_fun_
!=
nullptr
)
{
objective_fun_
->
Init
(
train_data_
->
metadata
(),
train_data_
->
num_data
());
objective_fun_
->
Init
(
train_data_
->
metadata
(),
train_data_
->
num_data
());
}
}
boosting_
->
ResetTrainingData
(
train_data_
,
objective_fun_
.
get
(),
Common
::
ConstPtrInVectorWrapper
<
Metric
>
(
train_metric_
));
}
}
boosting_
->
ResetTrainingData
(
&
config_
.
boosting_config
,
train_data_
,
boosting_
->
ResetConfig
(
&
config_
.
boosting_config
);
objective_fun_
.
get
(),
Common
::
ConstPtrInVectorWrapper
<
Metric
>
(
train_metric_
));
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment