Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
14a67b7e
Commit
14a67b7e
authored
Nov 24, 2016
by
Guolin Ke
Browse files
support dynamic change training data and add validation data
parent
13329682
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
195 additions
and
99 deletions
+195
-99
include/LightGBM/bin.h
include/LightGBM/bin.h
+12
-0
include/LightGBM/boosting.h
include/LightGBM/boosting.h
+9
-1
include/LightGBM/c_api.h
include/LightGBM/c_api.h
+16
-7
include/LightGBM/dataset.h
include/LightGBM/dataset.h
+21
-0
include/LightGBM/feature.h
include/LightGBM/feature.h
+7
-0
src/application/application.cpp
src/application/application.cpp
+1
-1
src/boosting/gbdt.cpp
src/boosting/gbdt.cpp
+63
-35
src/boosting/gbdt.h
src/boosting/gbdt.h
+11
-2
src/c_api.cpp
src/c_api.cpp
+51
-49
tests/c_api_test/test.py
tests/c_api_test/test.py
+4
-4
No files found.
include/LightGBM/bin.h
View file @
14a67b7e
...
...
@@ -51,6 +51,18 @@ public:
explicit
BinMapper
(
const
void
*
memory
);
~
BinMapper
();
bool
CheckAlign
(
const
BinMapper
&
other
)
const
{
if
(
num_bin_
!=
other
.
num_bin_
)
{
return
false
;
}
for
(
int
i
=
0
;
i
<
num_bin_
;
++
i
)
{
if
(
bin_upper_bound_
[
i
]
!=
other
.
bin_upper_bound_
[
i
])
{
return
false
;
}
}
return
true
;
}
/*! \brief Get number of bins */
inline
int
num_bin
()
const
{
return
num_bin_
;
}
/*! \brief True if bin is trival (contains only one bin) */
...
...
include/LightGBM/boosting.h
View file @
14a67b7e
...
...
@@ -41,12 +41,20 @@ public:
*/
virtual
void
ResetConfig
(
const
BoostingConfig
*
config
)
=
0
;
/*!
* \brief Reset training data for current boosting
* \param train_data Training data
* \param object_function Training objective function
* \param training_metrics Training metric
*/
virtual
void
ResetTrainingData
(
const
Dataset
*
train_data
,
const
ObjectiveFunction
*
object_function
,
const
std
::
vector
<
const
Metric
*>&
training_metrics
)
=
0
;
/*!
* \brief Add a validation data
* \param valid_data Validation data
* \param valid_metrics Metric for validation data
*/
virtual
void
AddDataset
(
const
Dataset
*
valid_data
,
virtual
void
Add
Valid
Dataset
(
const
Dataset
*
valid_data
,
const
std
::
vector
<
const
Metric
*>&
valid_metrics
)
=
0
;
/*!
...
...
include/LightGBM/c_api.h
View file @
14a67b7e
...
...
@@ -212,19 +212,12 @@ DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle,
/*!
* \brief create an new boosting learner
* \param train_data training data set
* \param valid_datas validation data sets
* \param valid_names names of validation data sets
* \param n_valid_datas number of validation set
* \param parameters format: 'key1=value1 key2=value2'
* \param init_model_filename filename of model
* \prama out handle of created Booster
* \return 0 when success, -1 when failure happens
*/
DllExport
int
LGBM_BoosterCreate
(
const
DatesetHandle
train_data
,
const
DatesetHandle
valid_datas
[],
int
n_valid_datas
,
const
char
*
parameters
,
const
char
*
init_model_filename
,
BoosterHandle
*
out
);
/*!
...
...
@@ -247,6 +240,22 @@ DllExport int LGBM_BoosterCreateFromModelfile(
*/
DllExport
int
LGBM_BoosterFree
(
BoosterHandle
handle
);
/*!
* \brief Add new validation to booster
* \param valid_data validation data set
* \return 0 when success, -1 when failure happens
*/
DllExport
int
LGBM_BoosterAddValidData
(
BoosterHandle
handle
,
const
DatesetHandle
valid_data
);
/*!
* \brief Add new validation to booster
* \param train_data training data set
* \return 0 when success, -1 when failure happens
*/
DllExport
int
LGBM_BoosterResetTrainingData
(
BoosterHandle
handle
,
const
DatesetHandle
train_data
);
/*!
* \brief Reset config for current booster
* \param parameters format: 'key1=value1 key2=value2'
...
...
include/LightGBM/dataset.h
View file @
14a67b7e
...
...
@@ -277,6 +277,27 @@ public:
/*! \brief Destructor */
~
Dataset
();
bool
CheckAlign
(
const
Dataset
&
other
)
const
{
if
(
num_features_
!=
other
.
num_features_
)
{
return
false
;
}
if
(
num_total_features_
!=
other
.
num_total_features_
)
{
return
false
;
}
if
(
num_class_
!=
other
.
num_class_
)
{
return
false
;
}
if
(
label_idx_
!=
other
.
label_idx_
)
{
return
false
;
}
for
(
int
i
=
0
;
i
<
num_features_
;
++
i
)
{
if
(
!
features_
[
i
]
->
CheckAlign
(
*
(
other
.
features_
[
i
].
get
())))
{
return
false
;
}
}
return
true
;
}
inline
void
PushOneRow
(
int
tid
,
data_size_t
row_idx
,
const
std
::
vector
<
double
>&
feature_values
)
{
for
(
size_t
i
=
0
;
i
<
feature_values
.
size
()
&&
i
<
static_cast
<
size_t
>
(
num_total_features_
);
++
i
)
{
int
feature_idx
=
used_feature_map_
[
i
];
...
...
include/LightGBM/feature.h
View file @
14a67b7e
...
...
@@ -63,6 +63,13 @@ public:
~
Feature
()
{
}
bool
CheckAlign
(
const
Feature
&
other
)
const
{
if
(
feature_index_
!=
other
.
feature_index_
)
{
return
false
;
}
return
bin_mapper_
->
CheckAlign
(
*
(
other
.
bin_mapper_
.
get
()));
}
/*!
* \brief Push one record, will auto convert to bin and push to bin data
* \param tid Thread id
...
...
src/application/application.cpp
View file @
14a67b7e
...
...
@@ -207,7 +207,7 @@ void Application::InitTrain() {
Common
::
ConstPtrInVectorWrapper
<
Metric
>
(
train_metric_
));
// add validation data into boosting
for
(
size_t
i
=
0
;
i
<
valid_datas_
.
size
();
++
i
)
{
boosting_
->
AddDataset
(
valid_datas_
[
i
].
get
(),
boosting_
->
Add
Valid
Dataset
(
valid_datas_
[
i
].
get
(),
Common
::
ConstPtrInVectorWrapper
<
Metric
>
(
valid_metrics_
[
i
]));
}
Log
::
Info
(
"Finished initializing training"
);
...
...
src/boosting/gbdt.cpp
View file @
14a67b7e
...
...
@@ -16,7 +16,10 @@
namespace
LightGBM
{
GBDT
::
GBDT
()
:
saved_model_size_
(
-
1
),
num_iteration_for_pred_
(
0
)
{
GBDT
::
GBDT
()
:
saved_model_size_
(
-
1
),
num_iteration_for_pred_
(
0
),
num_init_iteration_
(
0
)
{
}
...
...
@@ -33,8 +36,46 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
max_feature_idx_
=
0
;
early_stopping_round_
=
gbdt_config_
->
early_stopping_round
;
shrinkage_rate_
=
gbdt_config_
->
learning_rate
;
train_data_
=
train_data
;
num_class_
=
config
->
num_class
;
train_data_
=
nullptr
;
ResetTrainingData
(
train_data
,
object_function
,
training_metrics
);
// initialize random generator
random_
=
Random
(
gbdt_config_
->
bagging_seed
);
}
void
GBDT
::
ResetConfig
(
const
BoostingConfig
*
config
)
{
gbdt_config_
=
config
;
early_stopping_round_
=
gbdt_config_
->
early_stopping_round
;
shrinkage_rate_
=
gbdt_config_
->
learning_rate
;
// create tree learner
tree_learner_
.
clear
();
for
(
int
i
=
0
;
i
<
num_class_
;
++
i
)
{
auto
new_tree_learner
=
std
::
unique_ptr
<
TreeLearner
>
(
TreeLearner
::
CreateTreeLearner
(
gbdt_config_
->
tree_learner_type
,
gbdt_config_
->
tree_config
));
new_tree_learner
->
Init
(
train_data_
);
// init tree learner
tree_learner_
.
push_back
(
std
::
move
(
new_tree_learner
));
}
tree_learner_
.
shrink_to_fit
();
// if need bagging, create buffer
if
(
gbdt_config_
->
bagging_fraction
<
1.0
&&
gbdt_config_
->
bagging_freq
>
0
)
{
out_of_bag_data_indices_
=
std
::
vector
<
data_size_t
>
(
num_data_
);
bag_data_indices_
=
std
::
vector
<
data_size_t
>
(
num_data_
);
}
else
{
out_of_bag_data_cnt_
=
0
;
out_of_bag_data_indices_
.
clear
();
bag_data_cnt_
=
num_data_
;
bag_data_indices_
.
clear
();
}
// initialize random generator
random_
=
Random
(
gbdt_config_
->
bagging_seed
);
}
void
GBDT
::
ResetTrainingData
(
const
Dataset
*
train_data
,
const
ObjectiveFunction
*
object_function
,
const
std
::
vector
<
const
Metric
*>&
training_metrics
)
{
if
(
train_data_
!=
nullptr
&&
!
train_data_
->
CheckAlign
(
*
train_data
))
{
Log
::
Fatal
(
"cannot reset training data, since new training data has different bin mappers"
);
}
train_data_
=
train_data
;
// create tree learner
tree_learner_
.
clear
();
for
(
int
i
=
0
;
i
<
num_class_
;
++
i
)
{
...
...
@@ -46,6 +87,7 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
tree_learner_
.
shrink_to_fit
();
object_function_
=
object_function
;
// push training metrics
training_metrics_
.
clear
();
for
(
const
auto
&
metric
:
training_metrics
)
{
training_metrics_
.
push_back
(
metric
);
}
...
...
@@ -59,7 +101,7 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
hessians_
=
std
::
vector
<
score_t
>
(
num_data_
*
num_class_
);
}
sigmoid_
=
-
1.0
f
;
if
(
object_function_
!=
nullptr
if
(
object_function_
!=
nullptr
&&
std
::
string
(
object_function_
->
GetName
())
==
std
::
string
(
"binary"
))
{
// only binary classification need sigmoid transform
sigmoid_
=
gbdt_config_
->
sigmoid
;
...
...
@@ -78,44 +120,29 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
bag_data_cnt_
=
num_data_
;
bag_data_indices_
.
clear
();
}
// initialize random generator
random_
=
Random
(
gbdt_config_
->
bagging_seed
);
}
void
GBDT
::
ResetConfig
(
const
BoostingConfig
*
config
)
{
gbdt_config_
=
config
;
early_stopping_round_
=
gbdt_config_
->
early_stopping_round
;
shrinkage_rate_
=
gbdt_config_
->
learning_rate
;
// create tree learner
tree_learner_
.
clear
();
for
(
int
i
=
0
;
i
<
num_class_
;
++
i
)
{
auto
new_tree_learner
=
std
::
unique_ptr
<
TreeLearner
>
(
TreeLearner
::
CreateTreeLearner
(
gbdt_config_
->
tree_learner_type
,
gbdt_config_
->
tree_config
));
new_tree_learner
->
Init
(
train_data_
);
// init tree learner
tree_learner_
.
push_back
(
std
::
move
(
new_tree_learner
));
}
tree_learner_
.
shrink_to_fit
();
// if need bagging, create buffer
if
(
gbdt_config_
->
bagging_fraction
<
1.0
&&
gbdt_config_
->
bagging_freq
>
0
)
{
out_of_bag_data_indices_
=
std
::
vector
<
data_size_t
>
(
num_data_
);
bag_data_indices_
=
std
::
vector
<
data_size_t
>
(
num_data_
);
}
else
{
out_of_bag_data_cnt_
=
0
;
out_of_bag_data_indices_
.
clear
();
bag_data_cnt_
=
num_data_
;
bag_data_indices_
.
clear
();
// update score
for
(
int
i
=
0
;
i
<
iter_
;
++
i
)
{
for
(
int
curr_class
=
0
;
curr_class
<
num_class_
;
++
curr_class
)
{
auto
curr_tree
=
i
*
num_class_
+
curr_class
;
train_score_updater_
->
AddScore
(
models_
[
curr_tree
].
get
(),
curr_class
);
}
}
// initialize random generator
random_
=
Random
(
gbdt_config_
->
bagging_seed
);
}
void
GBDT
::
AddDataset
(
const
Dataset
*
valid_data
,
void
GBDT
::
AddValidDataset
(
const
Dataset
*
valid_data
,
const
std
::
vector
<
const
Metric
*>&
valid_metrics
)
{
if
(
iter_
>
0
)
{
Log
::
Fatal
(
"
C
annot add validation data
after
training
started
"
);
if
(
!
train_data_
->
CheckAlign
(
*
valid_data
)
)
{
Log
::
Fatal
(
"
c
annot add validation data
, since it has different bin mappers with
training
data
"
);
}
// for a validation dataset, we need its score and metric
auto
new_score_updater
=
std
::
unique_ptr
<
ScoreUpdater
>
(
new
ScoreUpdater
(
valid_data
,
num_class_
));
// update score
for
(
int
i
=
0
;
i
<
iter_
;
++
i
)
{
for
(
int
curr_class
=
0
;
curr_class
<
num_class_
;
++
curr_class
)
{
auto
curr_tree
=
i
*
num_class_
+
curr_class
;
new_score_updater
->
AddScore
(
models_
[
curr_tree
].
get
(),
curr_class
);
}
}
valid_score_updater_
.
push_back
(
std
::
move
(
new_score_updater
));
valid_metrics_
.
emplace_back
();
if
(
early_stopping_round_
>
0
)
{
...
...
@@ -499,6 +526,7 @@ void GBDT::LoadModelFromString(const std::string& model_str) {
}
Log
::
Info
(
"Finished loading %d models"
,
models_
.
size
());
num_iteration_for_pred_
=
static_cast
<
int
>
(
models_
.
size
())
/
num_class_
;
num_init_iteration_
=
num_iteration_for_pred_
;
}
std
::
string
GBDT
::
FeatureImportance
()
const
{
...
...
src/boosting/gbdt.h
View file @
14a67b7e
...
...
@@ -42,12 +42,20 @@ public:
*/
void
ResetConfig
(
const
BoostingConfig
*
config
)
override
;
/*!
* \brief Reset training data for current boosting
* \param train_data Training data
* \param object_function Training objective function
* \param training_metrics Training metric
*/
void
ResetTrainingData
(
const
Dataset
*
train_data
,
const
ObjectiveFunction
*
object_function
,
const
std
::
vector
<
const
Metric
*>&
training_metrics
)
override
;
/*!
* \brief Adding a validation dataset
* \param valid_data Validation dataset
* \param valid_metrics Metrics for validation dataset
*/
void
AddDataset
(
const
Dataset
*
valid_data
,
void
Add
Valid
Dataset
(
const
Dataset
*
valid_data
,
const
std
::
vector
<
const
Metric
*>&
valid_metrics
)
override
;
/*!
* \brief Training logic
...
...
@@ -63,7 +71,7 @@ public:
*/
void
RollbackOneIter
()
override
;
int
GetCurrentIteration
()
const
override
{
return
iter_
;
}
int
GetCurrentIteration
()
const
override
{
return
iter_
+
num_init_iteration_
;
}
bool
EvalAndCheckEarlyStopping
()
override
;
...
...
@@ -256,6 +264,7 @@ protected:
int
num_iteration_for_pred_
;
/*! \brief Shrinkage rate for one iteration */
double
shrinkage_rate_
;
int
num_init_iteration_
;
};
}
// namespace LightGBM
...
...
src/c_api.cpp
View file @
14a67b7e
...
...
@@ -28,9 +28,7 @@ public:
}
Booster
(
const
Dataset
*
train_data
,
std
::
vector
<
const
Dataset
*>
valid_data
,
const
char
*
parameters
)
:
train_data_
(
train_data
),
valid_datas_
(
valid_data
)
{
const
char
*
parameters
)
{
config_
.
LoadFromString
(
parameters
);
// create boosting
if
(
config_
.
io_config
.
input_model
.
size
()
>
0
)
{
...
...
@@ -38,6 +36,17 @@ public:
please use continued train with input score"
);
}
boosting_
.
reset
(
Boosting
::
CreateBoosting
(
config_
.
boosting_type
,
""
));
ConstructObjectAndTrainingMetrics
(
train_data
);
// initialize the boosting
boosting_
->
Init
(
&
config_
.
boosting_config
,
train_data
,
objective_fun_
.
get
(),
Common
::
ConstPtrInVectorWrapper
<
Metric
>
(
train_metric_
));
}
~
Booster
()
{
}
void
ConstructObjectAndTrainingMetrics
(
const
Dataset
*
train_data
)
{
// create objective function
objective_fun_
.
reset
(
ObjectiveFunction
::
CreateObjectiveFunction
(
config_
.
objective_type
,
config_
.
objective_config
));
...
...
@@ -45,48 +54,39 @@ public:
Log
::
Warning
(
"Using self-defined objective functions"
);
}
// create training metric
train_metric_
.
clear
();
for
(
auto
metric_type
:
config_
.
metric_types
)
{
auto
metric
=
std
::
unique_ptr
<
Metric
>
(
Metric
::
CreateMetric
(
metric_type
,
config_
.
metric_config
));
if
(
metric
==
nullptr
)
{
continue
;
}
metric
->
Init
(
train_data
_
->
metadata
(),
train_data
_
->
num_data
());
metric
->
Init
(
train_data
->
metadata
(),
train_data
->
num_data
());
train_metric_
.
push_back
(
std
::
move
(
metric
));
}
train_metric_
.
shrink_to_fit
();
// add metric for validation data
for
(
size_t
i
=
0
;
i
<
valid_datas_
.
size
();
++
i
)
{
valid_metrics_
.
emplace_back
();
for
(
auto
metric_type
:
config_
.
metric_types
)
{
auto
metric
=
std
::
unique_ptr
<
Metric
>
(
Metric
::
CreateMetric
(
metric_type
,
config_
.
metric_config
));
if
(
metric
==
nullptr
)
{
continue
;
}
metric
->
Init
(
valid_datas_
[
i
]
->
metadata
(),
valid_datas_
[
i
]
->
num_data
());
valid_metrics_
.
back
().
push_back
(
std
::
move
(
metric
));
}
valid_metrics_
.
back
().
shrink_to_fit
();
}
valid_metrics_
.
shrink_to_fit
();
// initialize the objective function
if
(
objective_fun_
!=
nullptr
)
{
objective_fun_
->
Init
(
train_data_
->
metadata
(),
train_data_
->
num_data
());
}
// initialize the boosting
boosting_
->
Init
(
&
config_
.
boosting_config
,
train_data_
,
objective_fun_
.
get
(),
Common
::
ConstPtrInVectorWrapper
<
Metric
>
(
train_metric_
));
// add validation data into boosting
for
(
size_t
i
=
0
;
i
<
valid_datas_
.
size
();
++
i
)
{
boosting_
->
AddDataset
(
valid_datas_
[
i
],
Common
::
ConstPtrInVectorWrapper
<
Metric
>
(
valid_metrics_
[
i
]));
objective_fun_
->
Init
(
train_data
->
metadata
(),
train_data
->
num_data
());
}
}
void
LoadModelFromFile
(
const
char
*
filename
)
{
Boosting
::
LoadFileToBoosting
(
boosting_
.
get
(),
filename
);
void
ResetTrainingData
(
const
Dataset
*
train_data
)
{
ConstructObjectAndTrainingMetrics
(
train_data
);
// initialize the boosting
boosting_
->
ResetTrainingData
(
train_data
,
objective_fun_
.
get
(),
Common
::
ConstPtrInVectorWrapper
<
Metric
>
(
train_metric_
));
}
~
Booster
()
{
void
AddValidData
(
const
Dataset
*
valid_data
)
{
valid_metrics_
.
emplace_back
();
for
(
auto
metric_type
:
config_
.
metric_types
)
{
auto
metric
=
std
::
unique_ptr
<
Metric
>
(
Metric
::
CreateMetric
(
metric_type
,
config_
.
metric_config
));
if
(
metric
==
nullptr
)
{
continue
;
}
metric
->
Init
(
valid_data
->
metadata
(),
valid_data
->
num_data
());
valid_metrics_
.
back
().
push_back
(
std
::
move
(
metric
));
}
valid_metrics_
.
back
().
shrink_to_fit
();
boosting_
->
AddValidDataset
(
valid_data
,
Common
::
ConstPtrInVectorWrapper
<
Metric
>
(
valid_metrics_
.
back
()));
}
bool
TrainOneIter
()
{
return
boosting_
->
TrainOneIter
(
nullptr
,
nullptr
,
false
);
}
...
...
@@ -151,9 +151,7 @@ public:
}
void
ResetBoostingConfig
(
const
char
*
parameters
)
{
OverallConfig
new_config
;
new_config
.
LoadFromString
(
parameters
);
config_
.
boosting_config
=
new_config
.
boosting_config
;
config_
.
LoadFromString
(
parameters
);
boosting_
->
ResetConfig
(
&
config_
.
boosting_config
);
}
...
...
@@ -164,14 +162,9 @@ public:
const
Boosting
*
GetBoosting
()
const
{
return
boosting_
.
get
();
}
private:
std
::
unique_ptr
<
Boosting
>
boosting_
;
/*! \brief All configs */
OverallConfig
config_
;
/*! \brief Training data */
const
Dataset
*
train_data_
;
/*! \brief Validation data */
std
::
vector
<
const
Dataset
*>
valid_datas_
;
/*! \brief Metric for training data */
std
::
vector
<
std
::
unique_ptr
<
Metric
>>
train_metric_
;
/*! \brief Metrics for validation data */
...
...
@@ -446,21 +439,11 @@ DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle,
// ---- start of booster
DllExport
int
LGBM_BoosterCreate
(
const
DatesetHandle
train_data
,
const
DatesetHandle
valid_datas
[],
int
n_valid_datas
,
const
char
*
parameters
,
const
char
*
init_model_filename
,
BoosterHandle
*
out
)
{
API_BEGIN
();
const
Dataset
*
p_train_data
=
reinterpret_cast
<
const
Dataset
*>
(
train_data
);
std
::
vector
<
const
Dataset
*>
p_valid_datas
;
for
(
int
i
=
0
;
i
<
n_valid_datas
;
++
i
)
{
p_valid_datas
.
emplace_back
(
reinterpret_cast
<
const
Dataset
*>
(
valid_datas
[
i
]));
}
auto
ret
=
std
::
unique_ptr
<
Booster
>
(
new
Booster
(
p_train_data
,
p_valid_datas
,
parameters
));
if
(
init_model_filename
!=
nullptr
)
{
ret
->
LoadModelFromFile
(
init_model_filename
);
}
auto
ret
=
std
::
unique_ptr
<
Booster
>
(
new
Booster
(
p_train_data
,
parameters
));
*
out
=
ret
.
release
();
API_END
();
}
...
...
@@ -482,6 +465,25 @@ DllExport int LGBM_BoosterFree(BoosterHandle handle) {
API_END
();
}
DllExport
int
LGBM_BoosterAddValidData
(
BoosterHandle
handle
,
const
DatesetHandle
valid_data
)
{
API_BEGIN
();
Booster
*
ref_booster
=
reinterpret_cast
<
Booster
*>
(
handle
);
const
Dataset
*
p_dataset
=
reinterpret_cast
<
const
Dataset
*>
(
valid_data
);
ref_booster
->
AddValidData
(
p_dataset
);
API_END
();
}
DllExport
int
LGBM_BoosterResetTrainingData
(
BoosterHandle
handle
,
const
DatesetHandle
train_data
)
{
API_BEGIN
();
Booster
*
ref_booster
=
reinterpret_cast
<
Booster
*>
(
handle
);
const
Dataset
*
p_dataset
=
reinterpret_cast
<
const
Dataset
*>
(
train_data
);
ref_booster
->
ResetTrainingData
(
p_dataset
);
API_END
();
}
DllExport
int
LGBM_BoosterResetParameter
(
BoosterHandle
handle
,
const
char
*
parameters
)
{
API_BEGIN
();
Booster
*
ref_booster
=
reinterpret_cast
<
Booster
*>
(
handle
);
...
...
tests/c_api_test/test.py
View file @
14a67b7e
...
...
@@ -174,10 +174,10 @@ def test_dataset():
test_free_dataset
(
train
)
def
test_booster
():
train
=
test_load_from_mat
(
'../../examples/binary_classification/binary.train'
,
None
)
test
=
[
test_load_from_mat
(
'../../examples/binary_classification/binary.test'
,
train
)
]
test
=
test_load_from_mat
(
'../../examples/binary_classification/binary.test'
,
train
)
booster
=
ctypes
.
c_void_p
()
LIB
.
LGBM_BoosterCreate
(
train
,
c_
array
(
ctypes
.
c_void_p
,
test
),
len
(
test
),
c_str
(
"app=binary metric=auc num_leaves=31 verbose=0"
),
None
,
ctypes
.
byref
(
booster
)
)
LIB
.
LGBM_BoosterCreate
(
train
,
c_
str
(
"app=binary metric=auc num_leaves=31 verbose=0"
),
ctypes
.
byref
(
booster
))
LIB
.
LGBM_BoosterAddValidData
(
booster
,
test
)
is_finished
=
ctypes
.
c_int
(
0
)
for
i
in
range
(
100
):
LIB
.
LGBM_BoosterUpdateOneIter
(
booster
,
ctypes
.
byref
(
is_finished
))
...
...
@@ -188,7 +188,7 @@ def test_booster():
LIB
.
LGBM_BoosterSaveModel
(
booster
,
-
1
,
c_str
(
'model.txt'
))
LIB
.
LGBM_BoosterFree
(
booster
)
test_free_dataset
(
train
)
test_free_dataset
(
test
[
0
]
)
test_free_dataset
(
test
)
booster2
=
ctypes
.
c_void_p
()
num_total_model
=
ctypes
.
c_long
()
LIB
.
LGBM_BoosterCreateFromModelfile
(
c_str
(
'model.txt'
),
ctypes
.
byref
(
num_total_model
),
ctypes
.
byref
(
booster2
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment