Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
629fc047
Commit
629fc047
authored
Nov 24, 2016
by
Guolin Ke
Browse files
more flexity python basic object
parent
b41e0f0a
Changes
5
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
483 additions
and
358 deletions
+483
-358
include/LightGBM/boosting.h
include/LightGBM/boosting.h
+1
-0
python-package/lightgbm/basic.py
python-package/lightgbm/basic.py
+424
-313
src/boosting/gbdt.cpp
src/boosting/gbdt.cpp
+35
-32
src/boosting/gbdt.h
src/boosting/gbdt.h
+17
-0
src/c_api.cpp
src/c_api.cpp
+6
-13
No files found.
include/LightGBM/boosting.h
View file @
629fc047
...
@@ -37,6 +37,7 @@ public:
...
@@ -37,6 +37,7 @@ public:
/*!
/*!
* \brief Merge model from other boosting object
* \brief Merge model from other boosting object
Will insert to the front of current boosting object
* \param other
* \param other
*/
*/
virtual
void
MergeFrom
(
const
Boosting
*
other
)
=
0
;
virtual
void
MergeFrom
(
const
Boosting
*
other
)
=
0
;
...
...
python-package/lightgbm/basic.py
View file @
629fc047
This diff is collapsed.
Click to expand it.
src/boosting/gbdt.cpp
View file @
629fc047
...
@@ -46,12 +46,12 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
...
@@ -46,12 +46,12 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
gbdt_config_
=
config
;
gbdt_config_
=
config
;
early_stopping_round_
=
gbdt_config_
->
early_stopping_round
;
early_stopping_round_
=
gbdt_config_
->
early_stopping_round
;
shrinkage_rate_
=
gbdt_config_
->
learning_rate
;
shrinkage_rate_
=
gbdt_config_
->
learning_rate
;
t
ra
in_data_
=
train_data
;
ra
ndom_
=
Random
(
gbdt_config_
->
bagging_seed
)
;
// create tree learner
// create tree learner
tree_learner_
.
clear
();
tree_learner_
.
clear
();
for
(
int
i
=
0
;
i
<
num_class_
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num_class_
;
++
i
)
{
auto
new_tree_learner
=
std
::
unique_ptr
<
TreeLearner
>
(
TreeLearner
::
CreateTreeLearner
(
gbdt_config_
->
tree_learner_type
,
gbdt_config_
->
tree_config
));
auto
new_tree_learner
=
std
::
unique_ptr
<
TreeLearner
>
(
TreeLearner
::
CreateTreeLearner
(
gbdt_config_
->
tree_learner_type
,
gbdt_config_
->
tree_config
));
new_tree_learner
->
Init
(
train_data
_
);
new_tree_learner
->
Init
(
train_data
);
// init tree learner
// init tree learner
tree_learner_
.
push_back
(
std
::
move
(
new_tree_learner
));
tree_learner_
.
push_back
(
std
::
move
(
new_tree_learner
));
}
}
...
@@ -63,24 +63,33 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
...
@@ -63,24 +63,33 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
training_metrics_
.
push_back
(
metric
);
training_metrics_
.
push_back
(
metric
);
}
}
training_metrics_
.
shrink_to_fit
();
training_metrics_
.
shrink_to_fit
();
// create score tracker
train_score_updater_
.
reset
(
new
ScoreUpdater
(
train_data_
,
num_class_
));
num_data_
=
train_data_
->
num_data
();
// create buffer for gradients and hessians
if
(
object_function_
!=
nullptr
)
{
gradients_
=
std
::
vector
<
score_t
>
(
num_data_
*
num_class_
);
hessians_
=
std
::
vector
<
score_t
>
(
num_data_
*
num_class_
);
}
sigmoid_
=
-
1.0
f
;
sigmoid_
=
-
1.0
f
;
if
(
object_function_
!=
nullptr
if
(
object_function_
!=
nullptr
&&
std
::
string
(
object_function_
->
GetName
())
==
std
::
string
(
"binary"
))
{
&&
std
::
string
(
object_function_
->
GetName
())
==
std
::
string
(
"binary"
))
{
// only binary classification need sigmoid transform
// only binary classification need sigmoid transform
sigmoid_
=
gbdt_config_
->
sigmoid
;
sigmoid_
=
gbdt_config_
->
sigmoid
;
}
}
if
(
train_data_
!=
train_data
)
{
// not same training data, need reset score and others
// create score tracker
train_score_updater_
.
reset
(
new
ScoreUpdater
(
train_data
,
num_class_
));
// update score
for
(
int
i
=
0
;
i
<
iter_
;
++
i
)
{
for
(
int
curr_class
=
0
;
curr_class
<
num_class_
;
++
curr_class
)
{
auto
curr_tree
=
(
i
+
num_init_iteration_
)
*
num_class_
+
curr_class
;
train_score_updater_
->
AddScore
(
models_
[
curr_tree
].
get
(),
curr_class
);
}
}
num_data_
=
train_data
->
num_data
();
// create buffer for gradients and hessians
if
(
object_function_
!=
nullptr
)
{
gradients_
=
std
::
vector
<
score_t
>
(
num_data_
*
num_class_
);
hessians_
=
std
::
vector
<
score_t
>
(
num_data_
*
num_class_
);
}
// get max feature index
// get max feature index
max_feature_idx_
=
train_data
_
->
num_total_features
()
-
1
;
max_feature_idx_
=
train_data
->
num_total_features
()
-
1
;
// get label index
// get label index
label_idx_
=
train_data
_
->
label_idx
();
label_idx_
=
train_data
->
label_idx
();
// if need bagging, create buffer
// if need bagging, create buffer
if
(
gbdt_config_
->
bagging_fraction
<
1.0
&&
gbdt_config_
->
bagging_freq
>
0
)
{
if
(
gbdt_config_
->
bagging_fraction
<
1.0
&&
gbdt_config_
->
bagging_freq
>
0
)
{
out_of_bag_data_indices_
=
std
::
vector
<
data_size_t
>
(
num_data_
);
out_of_bag_data_indices_
=
std
::
vector
<
data_size_t
>
(
num_data_
);
...
@@ -91,14 +100,8 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
...
@@ -91,14 +100,8 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
bag_data_cnt_
=
num_data_
;
bag_data_cnt_
=
num_data_
;
bag_data_indices_
.
clear
();
bag_data_indices_
.
clear
();
}
}
random_
=
Random
(
gbdt_config_
->
bagging_seed
);
// update score
for
(
int
i
=
0
;
i
<
iter_
;
++
i
)
{
for
(
int
curr_class
=
0
;
curr_class
<
num_class_
;
++
curr_class
)
{
auto
curr_tree
=
i
*
num_class_
+
curr_class
;
train_score_updater_
->
AddScore
(
models_
[
curr_tree
].
get
(),
curr_class
);
}
}
}
train_data_
=
train_data
;
}
}
void
GBDT
::
AddValidDataset
(
const
Dataset
*
valid_data
,
void
GBDT
::
AddValidDataset
(
const
Dataset
*
valid_data
,
...
@@ -111,7 +114,7 @@ void GBDT::AddValidDataset(const Dataset* valid_data,
...
@@ -111,7 +114,7 @@ void GBDT::AddValidDataset(const Dataset* valid_data,
// update score
// update score
for
(
int
i
=
0
;
i
<
iter_
;
++
i
)
{
for
(
int
i
=
0
;
i
<
iter_
;
++
i
)
{
for
(
int
curr_class
=
0
;
curr_class
<
num_class_
;
++
curr_class
)
{
for
(
int
curr_class
=
0
;
curr_class
<
num_class_
;
++
curr_class
)
{
auto
curr_tree
=
i
*
num_class_
+
curr_class
;
auto
curr_tree
=
(
i
+
num_init_iteration_
)
*
num_class_
+
curr_class
;
new_score_updater
->
AddScore
(
models_
[
curr_tree
].
get
(),
curr_class
);
new_score_updater
->
AddScore
(
models_
[
curr_tree
].
get
(),
curr_class
);
}
}
}
}
...
@@ -232,7 +235,7 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
...
@@ -232,7 +235,7 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
void
GBDT
::
RollbackOneIter
()
{
void
GBDT
::
RollbackOneIter
()
{
if
(
iter_
==
0
)
{
return
;
}
if
(
iter_
==
0
)
{
return
;
}
int
cur_iter
=
iter_
-
1
;
int
cur_iter
=
iter_
+
num_init_iteration_
-
1
;
// reset score
// reset score
for
(
int
curr_class
=
0
;
curr_class
<
num_class_
;
++
curr_class
)
{
for
(
int
curr_class
=
0
;
curr_class
<
num_class_
;
++
curr_class
)
{
auto
curr_tree
=
cur_iter
*
num_class_
+
curr_class
;
auto
curr_tree
=
cur_iter
*
num_class_
+
curr_class
;
...
...
src/boosting/gbdt.h
View file @
629fc047
...
@@ -36,12 +36,28 @@ public:
...
@@ -36,12 +36,28 @@ public:
const
std
::
vector
<
const
Metric
*>&
training_metrics
)
const
std
::
vector
<
const
Metric
*>&
training_metrics
)
override
;
override
;
/*!
* \brief Merge model from other boosting object
Will insert to the front of current boosting object
* \param other
*/
void
MergeFrom
(
const
Boosting
*
other
)
override
{
void
MergeFrom
(
const
Boosting
*
other
)
override
{
auto
other_gbdt
=
reinterpret_cast
<
const
GBDT
*>
(
other
);
auto
other_gbdt
=
reinterpret_cast
<
const
GBDT
*>
(
other
);
// tmp move to other vector
auto
original_models
=
std
::
move
(
models_
);
models_
=
std
::
vector
<
std
::
unique_ptr
<
Tree
>>
();
// push model from other first
for
(
const
auto
&
tree
:
other_gbdt
->
models_
)
{
for
(
const
auto
&
tree
:
other_gbdt
->
models_
)
{
auto
new_tree
=
std
::
unique_ptr
<
Tree
>
(
new
Tree
(
*
(
tree
.
get
())));
auto
new_tree
=
std
::
unique_ptr
<
Tree
>
(
new
Tree
(
*
(
tree
.
get
())));
models_
.
push_back
(
std
::
move
(
new_tree
));
models_
.
push_back
(
std
::
move
(
new_tree
));
}
}
num_init_iteration_
=
static_cast
<
int
>
(
models_
.
size
())
/
num_class_
;
// push model in current object
for
(
const
auto
&
tree
:
original_models
)
{
auto
new_tree
=
std
::
unique_ptr
<
Tree
>
(
new
Tree
(
*
(
tree
.
get
())));
models_
.
push_back
(
std
::
move
(
new_tree
));
}
num_iteration_for_pred_
=
static_cast
<
int
>
(
models_
.
size
())
/
num_class_
;
}
}
/*!
/*!
...
@@ -266,6 +282,7 @@ protected:
...
@@ -266,6 +282,7 @@ protected:
int
num_iteration_for_pred_
;
int
num_iteration_for_pred_
;
/*! \brief Shrinkage rate for one iteration */
/*! \brief Shrinkage rate for one iteration */
double
shrinkage_rate_
;
double
shrinkage_rate_
;
/*! \brief Number of loaded initial models */
int
num_init_iteration_
;
int
num_init_iteration_
;
};
};
...
...
src/c_api.cpp
View file @
629fc047
...
@@ -36,7 +36,7 @@ public:
...
@@ -36,7 +36,7 @@ public:
Log
::
Warning
(
"continued train from model is not support for c_api, \
Log
::
Warning
(
"continued train from model is not support for c_api, \
please use continued train with input score"
);
please use continued train with input score"
);
}
}
boosting_
.
reset
(
Boosting
::
CreateBoosting
(
config_
.
boosting_type
,
""
));
boosting_
.
reset
(
Boosting
::
CreateBoosting
(
config_
.
boosting_type
,
nullptr
));
ConstructObjectAndTrainingMetrics
(
train_data
);
ConstructObjectAndTrainingMetrics
(
train_data
);
// initialize the boosting
// initialize the boosting
boosting_
->
Init
(
&
config_
.
boosting_config
,
train_data
,
objective_fun_
.
get
(),
boosting_
->
Init
(
&
config_
.
boosting_config
,
train_data
,
objective_fun_
.
get
(),
...
@@ -114,6 +114,10 @@ public:
...
@@ -114,6 +114,10 @@ public:
return
boosting_
->
TrainOneIter
(
gradients
,
hessians
,
false
);
return
boosting_
->
TrainOneIter
(
gradients
,
hessians
,
false
);
}
}
void
RollbackOneIter
()
{
boosting_
->
RollbackOneIter
();
}
void
PrepareForPrediction
(
int
num_iteration
,
int
predict_type
)
{
void
PrepareForPrediction
(
int
num_iteration
,
int
predict_type
)
{
boosting_
->
SetNumIterationForPred
(
num_iteration
);
boosting_
->
SetNumIterationForPred
(
num_iteration
);
bool
is_predict_leaf
=
false
;
bool
is_predict_leaf
=
false
;
...
@@ -156,24 +160,13 @@ public:
...
@@ -156,24 +160,13 @@ public:
int
idx
=
0
;
int
idx
=
0
;
for
(
const
auto
&
metric
:
train_metric_
)
{
for
(
const
auto
&
metric
:
train_metric_
)
{
for
(
const
auto
&
name
:
metric
->
GetName
())
{
for
(
const
auto
&
name
:
metric
->
GetName
())
{
int
j
=
0
;
std
::
strcpy
(
out_strs
[
idx
],
name
.
c_str
());
auto
name_cstr
=
name
.
c_str
();
while
(
name_cstr
[
j
]
!=
'\0'
)
{
out_strs
[
idx
][
j
]
=
name_cstr
[
j
];
++
j
;
}
out_strs
[
idx
][
j
]
=
'\0'
;
++
idx
;
++
idx
;
}
}
}
}
return
idx
;
return
idx
;
}
}
void
RollbackOneIter
()
{
boosting_
->
RollbackOneIter
();
}
const
Boosting
*
GetBoosting
()
const
{
return
boosting_
.
get
();
}
const
Boosting
*
GetBoosting
()
const
{
return
boosting_
.
get
();
}
private:
private:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment