Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
b41e0f0a
Commit
b41e0f0a
authored
Nov 24, 2016
by
Guolin Ke
Browse files
more flexible reset config/training data logic for boosting
parent
5b4ee9db
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
64 additions
and
84 deletions
+64
-84
include/LightGBM/boosting.h
include/LightGBM/boosting.h
+2
-6
include/LightGBM/config.h
include/LightGBM/config.h
+3
-1
src/boosting/gbdt.cpp
src/boosting/gbdt.cpp
+7
-35
src/boosting/gbdt.h
src/boosting/gbdt.h
+1
-7
src/c_api.cpp
src/c_api.cpp
+48
-32
src/io/config.cpp
src/io/config.cpp
+3
-3
No files found.
include/LightGBM/boosting.h
View file @
b41e0f0a
...
@@ -40,19 +40,15 @@ public:
...
@@ -40,19 +40,15 @@ public:
* \param other
* \param other
*/
*/
virtual
void
MergeFrom
(
const
Boosting
*
other
)
=
0
;
virtual
void
MergeFrom
(
const
Boosting
*
other
)
=
0
;
/*!
* \brief Reset Config for current boosting
* \param config Configs for boosting
*/
virtual
void
ResetConfig
(
const
BoostingConfig
*
config
)
=
0
;
/*!
/*!
* \brief Reset training data for current boosting
* \brief Reset training data for current boosting
* \param config Configs for boosting
* \param train_data Training data
* \param train_data Training data
* \param object_function Training objective function
* \param object_function Training objective function
* \param training_metrics Training metric
* \param training_metrics Training metric
*/
*/
virtual
void
ResetTrainingData
(
const
Dataset
*
train_data
,
const
ObjectiveFunction
*
object_function
,
const
std
::
vector
<
const
Metric
*>&
training_metrics
)
=
0
;
virtual
void
ResetTrainingData
(
const
BoostingConfig
*
config
,
const
Dataset
*
train_data
,
const
ObjectiveFunction
*
object_function
,
const
std
::
vector
<
const
Metric
*>&
training_metrics
)
=
0
;
/*!
/*!
* \brief Add a validation data
* \brief Add a validation data
...
...
include/LightGBM/config.h
View file @
b41e0f0a
...
@@ -72,6 +72,8 @@ public:
...
@@ -72,6 +72,8 @@ public:
inline
bool
GetBool
(
inline
bool
GetBool
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
,
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
,
const
std
::
string
&
name
,
bool
*
out
);
const
std
::
string
&
name
,
bool
*
out
);
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
Str2Map
(
const
char
*
parameters
);
};
};
/*! \brief Types of boosting */
/*! \brief Types of boosting */
...
@@ -231,7 +233,7 @@ public:
...
@@ -231,7 +233,7 @@ public:
MetricConfig
metric_config
;
MetricConfig
metric_config
;
void
Set
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
override
;
void
Set
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
override
;
void
LoadFromString
(
const
char
*
str
);
private:
private:
void
GetBoostingType
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
);
void
GetBoostingType
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
);
...
...
src/boosting/gbdt.cpp
View file @
b41e0f0a
...
@@ -29,52 +29,23 @@ GBDT::~GBDT() {
...
@@ -29,52 +29,23 @@ GBDT::~GBDT() {
void
GBDT
::
Init
(
const
BoostingConfig
*
config
,
const
Dataset
*
train_data
,
const
ObjectiveFunction
*
object_function
,
void
GBDT
::
Init
(
const
BoostingConfig
*
config
,
const
Dataset
*
train_data
,
const
ObjectiveFunction
*
object_function
,
const
std
::
vector
<
const
Metric
*>&
training_metrics
)
{
const
std
::
vector
<
const
Metric
*>&
training_metrics
)
{
gbdt_config_
=
config
;
iter_
=
0
;
iter_
=
0
;
saved_model_size_
=
-
1
;
saved_model_size_
=
-
1
;
num_iteration_for_pred_
=
0
;
num_iteration_for_pred_
=
0
;
max_feature_idx_
=
0
;
max_feature_idx_
=
0
;
early_stopping_round_
=
gbdt_config_
->
early_stopping_round
;
shrinkage_rate_
=
gbdt_config_
->
learning_rate
;
num_class_
=
config
->
num_class
;
num_class_
=
config
->
num_class
;
train_data_
=
nullptr
;
train_data_
=
nullptr
;
ResetTrainingData
(
train_data
,
object_function
,
training_metrics
);
ResetTrainingData
(
config
,
train_data
,
object_function
,
training_metrics
);
// initialize random generator
random_
=
Random
(
gbdt_config_
->
bagging_seed
);
}
void
GBDT
::
ResetConfig
(
const
BoostingConfig
*
config
)
{
gbdt_config_
=
config
;
early_stopping_round_
=
gbdt_config_
->
early_stopping_round
;
shrinkage_rate_
=
gbdt_config_
->
learning_rate
;
// create tree learner
tree_learner_
.
clear
();
for
(
int
i
=
0
;
i
<
num_class_
;
++
i
)
{
auto
new_tree_learner
=
std
::
unique_ptr
<
TreeLearner
>
(
TreeLearner
::
CreateTreeLearner
(
gbdt_config_
->
tree_learner_type
,
gbdt_config_
->
tree_config
));
new_tree_learner
->
Init
(
train_data_
);
// init tree learner
tree_learner_
.
push_back
(
std
::
move
(
new_tree_learner
));
}
tree_learner_
.
shrink_to_fit
();
// if need bagging, create buffer
if
(
gbdt_config_
->
bagging_fraction
<
1.0
&&
gbdt_config_
->
bagging_freq
>
0
)
{
out_of_bag_data_indices_
=
std
::
vector
<
data_size_t
>
(
num_data_
);
bag_data_indices_
=
std
::
vector
<
data_size_t
>
(
num_data_
);
}
else
{
out_of_bag_data_cnt_
=
0
;
out_of_bag_data_indices_
.
clear
();
bag_data_cnt_
=
num_data_
;
bag_data_indices_
.
clear
();
}
// initialize random generator
random_
=
Random
(
gbdt_config_
->
bagging_seed
);
}
}
void
GBDT
::
ResetTrainingData
(
const
Dataset
*
train_data
,
const
ObjectiveFunction
*
object_function
,
const
std
::
vector
<
const
Metric
*>&
training_metrics
)
{
void
GBDT
::
ResetTrainingData
(
const
BoostingConfig
*
config
,
const
Dataset
*
train_data
,
const
ObjectiveFunction
*
object_function
,
const
std
::
vector
<
const
Metric
*>&
training_metrics
)
{
if
(
train_data_
!=
nullptr
&&
!
train_data_
->
CheckAlign
(
*
train_data
))
{
if
(
train_data_
!=
nullptr
&&
!
train_data_
->
CheckAlign
(
*
train_data
))
{
Log
::
Fatal
(
"cannot reset training data, since new training data has different bin mappers"
);
Log
::
Fatal
(
"cannot reset training data, since new training data has different bin mappers"
);
}
}
gbdt_config_
=
config
;
early_stopping_round_
=
gbdt_config_
->
early_stopping_round
;
shrinkage_rate_
=
gbdt_config_
->
learning_rate
;
train_data_
=
train_data
;
train_data_
=
train_data
;
// create tree learner
// create tree learner
tree_learner_
.
clear
();
tree_learner_
.
clear
();
...
@@ -120,6 +91,7 @@ void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction*
...
@@ -120,6 +91,7 @@ void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction*
bag_data_cnt_
=
num_data_
;
bag_data_cnt_
=
num_data_
;
bag_data_indices_
.
clear
();
bag_data_indices_
.
clear
();
}
}
random_
=
Random
(
gbdt_config_
->
bagging_seed
);
// update score
// update score
for
(
int
i
=
0
;
i
<
iter_
;
++
i
)
{
for
(
int
i
=
0
;
i
<
iter_
;
++
i
)
{
for
(
int
curr_class
=
0
;
curr_class
<
num_class_
;
++
curr_class
)
{
for
(
int
curr_class
=
0
;
curr_class
<
num_class_
;
++
curr_class
)
{
...
...
src/boosting/gbdt.h
View file @
b41e0f0a
...
@@ -44,19 +44,13 @@ public:
...
@@ -44,19 +44,13 @@ public:
}
}
}
}
/*!
* \brief Reset Config for current boosting
* \param config Configs for boosting
*/
void
ResetConfig
(
const
BoostingConfig
*
config
)
override
;
/*!
/*!
* \brief Reset training data for current boosting
* \brief Reset training data for current boosting
* \param train_data Training data
* \param train_data Training data
* \param object_function Training objective function
* \param object_function Training objective function
* \param training_metrics Training metric
* \param training_metrics Training metric
*/
*/
void
ResetTrainingData
(
const
Dataset
*
train_data
,
const
ObjectiveFunction
*
object_function
,
const
std
::
vector
<
const
Metric
*>&
training_metrics
)
override
;
void
ResetTrainingData
(
const
BoostingConfig
*
config
,
const
Dataset
*
train_data
,
const
ObjectiveFunction
*
object_function
,
const
std
::
vector
<
const
Metric
*>&
training_metrics
)
override
;
/*!
/*!
* \brief Adding a validation dataset
* \brief Adding a validation dataset
...
...
src/c_api.cpp
View file @
b41e0f0a
...
@@ -29,7 +29,8 @@ public:
...
@@ -29,7 +29,8 @@ public:
Booster
(
const
Dataset
*
train_data
,
Booster
(
const
Dataset
*
train_data
,
const
char
*
parameters
)
{
const
char
*
parameters
)
{
config_
.
LoadFromString
(
parameters
);
auto
param
=
ConfigBase
::
Str2Map
(
parameters
);
config_
.
Set
(
param
);
// create boosting
// create boosting
if
(
config_
.
io_config
.
input_model
.
size
()
>
0
)
{
if
(
config_
.
io_config
.
input_model
.
size
()
>
0
)
{
Log
::
Warning
(
"continued train from model is not support for c_api, \
Log
::
Warning
(
"continued train from model is not support for c_api, \
...
@@ -74,9 +75,23 @@ public:
...
@@ -74,9 +75,23 @@ public:
}
}
void
ResetTrainingData
(
const
Dataset
*
train_data
)
{
void
ResetTrainingData
(
const
Dataset
*
train_data
)
{
ConstructObjectAndTrainingMetrics
(
train_data
);
train_data_
=
train_data
;
ConstructObjectAndTrainingMetrics
(
train_data_
);
// initialize the boosting
// initialize the boosting
boosting_
->
ResetTrainingData
(
train_data
,
objective_fun_
.
get
(),
Common
::
ConstPtrInVectorWrapper
<
Metric
>
(
train_metric_
));
boosting_
->
ResetTrainingData
(
&
config_
.
boosting_config
,
train_data_
,
objective_fun_
.
get
(),
Common
::
ConstPtrInVectorWrapper
<
Metric
>
(
train_metric_
));
}
void
ResetConfig
(
const
char
*
parameters
)
{
auto
param
=
ConfigBase
::
Str2Map
(
parameters
);
if
(
param
.
count
(
"num_class"
))
{
Log
::
Fatal
(
"cannot change num class during training"
);
}
if
(
param
.
count
(
"boosting_type"
))
{
Log
::
Fatal
(
"cannot change boosting_type during training"
);
}
config_
.
Set
(
param
);
ResetTrainingData
(
train_data_
);
}
}
void
AddValidData
(
const
Dataset
*
valid_data
)
{
void
AddValidData
(
const
Dataset
*
valid_data
)
{
...
@@ -154,10 +169,6 @@ public:
...
@@ -154,10 +169,6 @@ public:
return
idx
;
return
idx
;
}
}
void
ResetBoostingConfig
(
const
char
*
parameters
)
{
config_
.
LoadFromString
(
parameters
);
boosting_
->
ResetConfig
(
&
config_
.
boosting_config
);
}
void
RollbackOneIter
()
{
void
RollbackOneIter
()
{
boosting_
->
RollbackOneIter
();
boosting_
->
RollbackOneIter
();
...
@@ -166,6 +177,7 @@ public:
...
@@ -166,6 +177,7 @@ public:
const
Boosting
*
GetBoosting
()
const
{
return
boosting_
.
get
();
}
const
Boosting
*
GetBoosting
()
const
{
return
boosting_
.
get
();
}
private:
private:
const
Dataset
*
train_data_
;
std
::
unique_ptr
<
Boosting
>
boosting_
;
std
::
unique_ptr
<
Boosting
>
boosting_
;
/*! \brief All configs */
/*! \brief All configs */
OverallConfig
config_
;
OverallConfig
config_
;
...
@@ -193,9 +205,10 @@ DllExport int LGBM_CreateDatasetFromFile(const char* filename,
...
@@ -193,9 +205,10 @@ DllExport int LGBM_CreateDatasetFromFile(const char* filename,
const
DatesetHandle
*
reference
,
const
DatesetHandle
*
reference
,
DatesetHandle
*
out
)
{
DatesetHandle
*
out
)
{
API_BEGIN
();
API_BEGIN
();
OverallConfig
config
;
auto
param
=
ConfigBase
::
Str2Map
(
parameters
);
config
.
LoadFromString
(
parameters
);
IOConfig
io_config
;
DatasetLoader
loader
(
config
.
io_config
,
nullptr
);
io_config
.
Set
(
param
);
DatasetLoader
loader
(
io_config
,
nullptr
);
loader
.
SetHeader
(
filename
);
loader
.
SetHeader
(
filename
);
if
(
reference
==
nullptr
)
{
if
(
reference
==
nullptr
)
{
*
out
=
loader
.
LoadFromFile
(
filename
);
*
out
=
loader
.
LoadFromFile
(
filename
);
...
@@ -224,15 +237,16 @@ DllExport int LGBM_CreateDatasetFromMat(const void* data,
...
@@ -224,15 +237,16 @@ DllExport int LGBM_CreateDatasetFromMat(const void* data,
const
DatesetHandle
*
reference
,
const
DatesetHandle
*
reference
,
DatesetHandle
*
out
)
{
DatesetHandle
*
out
)
{
API_BEGIN
();
API_BEGIN
();
OverallConfig
config
;
auto
param
=
ConfigBase
::
Str2Map
(
parameters
);
config
.
LoadFromString
(
parameters
);
IOConfig
io_config
;
DatasetLoader
loader
(
config
.
io_config
,
nullptr
);
io_config
.
Set
(
param
);
DatasetLoader
loader
(
io_config
,
nullptr
);
std
::
unique_ptr
<
Dataset
>
ret
;
std
::
unique_ptr
<
Dataset
>
ret
;
auto
get_row_fun
=
RowFunctionFromDenseMatric
(
data
,
nrow
,
ncol
,
data_type
,
is_row_major
);
auto
get_row_fun
=
RowFunctionFromDenseMatric
(
data
,
nrow
,
ncol
,
data_type
,
is_row_major
);
if
(
reference
==
nullptr
)
{
if
(
reference
==
nullptr
)
{
// sample data first
// sample data first
Random
rand
(
config
.
io_config
.
data_random_seed
);
Random
rand
(
io_config
.
data_random_seed
);
const
int
sample_cnt
=
static_cast
<
int
>
(
nrow
<
config
.
io_config
.
bin_construct_sample_cnt
?
nrow
:
config
.
io_config
.
bin_construct_sample_cnt
);
const
int
sample_cnt
=
static_cast
<
int
>
(
nrow
<
io_config
.
bin_construct_sample_cnt
?
nrow
:
io_config
.
bin_construct_sample_cnt
);
auto
sample_indices
=
rand
.
Sample
(
nrow
,
sample_cnt
);
auto
sample_indices
=
rand
.
Sample
(
nrow
,
sample_cnt
);
std
::
vector
<
std
::
vector
<
double
>>
sample_values
(
ncol
);
std
::
vector
<
std
::
vector
<
double
>>
sample_values
(
ncol
);
for
(
size_t
i
=
0
;
i
<
sample_indices
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
sample_indices
.
size
();
++
i
)
{
...
@@ -246,10 +260,10 @@ DllExport int LGBM_CreateDatasetFromMat(const void* data,
...
@@ -246,10 +260,10 @@ DllExport int LGBM_CreateDatasetFromMat(const void* data,
}
}
ret
.
reset
(
loader
.
CostructFromSampleData
(
sample_values
,
sample_cnt
,
nrow
));
ret
.
reset
(
loader
.
CostructFromSampleData
(
sample_values
,
sample_cnt
,
nrow
));
}
else
{
}
else
{
ret
.
reset
(
new
Dataset
(
nrow
,
config
.
io_config
.
num_class
));
ret
.
reset
(
new
Dataset
(
nrow
,
io_config
.
num_class
));
ret
->
CopyFeatureMapperFrom
(
ret
->
CopyFeatureMapperFrom
(
reinterpret_cast
<
const
Dataset
*>
(
*
reference
),
reinterpret_cast
<
const
Dataset
*>
(
*
reference
),
config
.
io_config
.
is_enable_sparse
);
io_config
.
is_enable_sparse
);
}
}
#pragma omp parallel for schedule(guided)
#pragma omp parallel for schedule(guided)
...
@@ -275,16 +289,17 @@ DllExport int LGBM_CreateDatasetFromCSR(const void* indptr,
...
@@ -275,16 +289,17 @@ DllExport int LGBM_CreateDatasetFromCSR(const void* indptr,
const
DatesetHandle
*
reference
,
const
DatesetHandle
*
reference
,
DatesetHandle
*
out
)
{
DatesetHandle
*
out
)
{
API_BEGIN
();
API_BEGIN
();
OverallConfig
config
;
auto
param
=
ConfigBase
::
Str2Map
(
parameters
);
config
.
LoadFromString
(
parameters
);
IOConfig
io_config
;
DatasetLoader
loader
(
config
.
io_config
,
nullptr
);
io_config
.
Set
(
param
);
DatasetLoader
loader
(
io_config
,
nullptr
);
std
::
unique_ptr
<
Dataset
>
ret
;
std
::
unique_ptr
<
Dataset
>
ret
;
auto
get_row_fun
=
RowFunctionFromCSR
(
indptr
,
indptr_type
,
indices
,
data
,
data_type
,
nindptr
,
nelem
);
auto
get_row_fun
=
RowFunctionFromCSR
(
indptr
,
indptr_type
,
indices
,
data
,
data_type
,
nindptr
,
nelem
);
int32_t
nrow
=
static_cast
<
int32_t
>
(
nindptr
-
1
);
int32_t
nrow
=
static_cast
<
int32_t
>
(
nindptr
-
1
);
if
(
reference
==
nullptr
)
{
if
(
reference
==
nullptr
)
{
// sample data first
// sample data first
Random
rand
(
config
.
io_config
.
data_random_seed
);
Random
rand
(
io_config
.
data_random_seed
);
const
int
sample_cnt
=
static_cast
<
int
>
(
nrow
<
config
.
io_config
.
bin_construct_sample_cnt
?
nrow
:
config
.
io_config
.
bin_construct_sample_cnt
);
const
int
sample_cnt
=
static_cast
<
int
>
(
nrow
<
io_config
.
bin_construct_sample_cnt
?
nrow
:
io_config
.
bin_construct_sample_cnt
);
auto
sample_indices
=
rand
.
Sample
(
nrow
,
sample_cnt
);
auto
sample_indices
=
rand
.
Sample
(
nrow
,
sample_cnt
);
std
::
vector
<
std
::
vector
<
double
>>
sample_values
;
std
::
vector
<
std
::
vector
<
double
>>
sample_values
;
for
(
size_t
i
=
0
;
i
<
sample_indices
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
sample_indices
.
size
();
++
i
)
{
...
@@ -307,10 +322,10 @@ DllExport int LGBM_CreateDatasetFromCSR(const void* indptr,
...
@@ -307,10 +322,10 @@ DllExport int LGBM_CreateDatasetFromCSR(const void* indptr,
CHECK
(
num_col
>=
static_cast
<
int
>
(
sample_values
.
size
()));
CHECK
(
num_col
>=
static_cast
<
int
>
(
sample_values
.
size
()));
ret
.
reset
(
loader
.
CostructFromSampleData
(
sample_values
,
sample_cnt
,
nrow
));
ret
.
reset
(
loader
.
CostructFromSampleData
(
sample_values
,
sample_cnt
,
nrow
));
}
else
{
}
else
{
ret
.
reset
(
new
Dataset
(
nrow
,
config
.
io_config
.
num_class
));
ret
.
reset
(
new
Dataset
(
nrow
,
io_config
.
num_class
));
ret
->
CopyFeatureMapperFrom
(
ret
->
CopyFeatureMapperFrom
(
reinterpret_cast
<
const
Dataset
*>
(
*
reference
),
reinterpret_cast
<
const
Dataset
*>
(
*
reference
),
config
.
io_config
.
is_enable_sparse
);
io_config
.
is_enable_sparse
);
}
}
#pragma omp parallel for schedule(guided)
#pragma omp parallel for schedule(guided)
...
@@ -336,17 +351,18 @@ DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr,
...
@@ -336,17 +351,18 @@ DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr,
const
DatesetHandle
*
reference
,
const
DatesetHandle
*
reference
,
DatesetHandle
*
out
)
{
DatesetHandle
*
out
)
{
API_BEGIN
();
API_BEGIN
();
OverallConfig
config
;
auto
param
=
ConfigBase
::
Str2Map
(
parameters
);
config
.
LoadFromString
(
parameters
);
IOConfig
io_config
;
DatasetLoader
loader
(
config
.
io_config
,
nullptr
);
io_config
.
Set
(
param
);
DatasetLoader
loader
(
io_config
,
nullptr
);
std
::
unique_ptr
<
Dataset
>
ret
;
std
::
unique_ptr
<
Dataset
>
ret
;
auto
get_col_fun
=
ColumnFunctionFromCSC
(
col_ptr
,
col_ptr_type
,
indices
,
data
,
data_type
,
ncol_ptr
,
nelem
);
auto
get_col_fun
=
ColumnFunctionFromCSC
(
col_ptr
,
col_ptr_type
,
indices
,
data
,
data_type
,
ncol_ptr
,
nelem
);
int32_t
nrow
=
static_cast
<
int32_t
>
(
num_row
);
int32_t
nrow
=
static_cast
<
int32_t
>
(
num_row
);
if
(
reference
==
nullptr
)
{
if
(
reference
==
nullptr
)
{
Log
::
Warning
(
"Construct from CSC format is not efficient"
);
Log
::
Warning
(
"Construct from CSC format is not efficient"
);
// sample data first
// sample data first
Random
rand
(
config
.
io_config
.
data_random_seed
);
Random
rand
(
io_config
.
data_random_seed
);
const
int
sample_cnt
=
static_cast
<
int
>
(
nrow
<
config
.
io_config
.
bin_construct_sample_cnt
?
nrow
:
config
.
io_config
.
bin_construct_sample_cnt
);
const
int
sample_cnt
=
static_cast
<
int
>
(
nrow
<
io_config
.
bin_construct_sample_cnt
?
nrow
:
io_config
.
bin_construct_sample_cnt
);
auto
sample_indices
=
rand
.
Sample
(
nrow
,
sample_cnt
);
auto
sample_indices
=
rand
.
Sample
(
nrow
,
sample_cnt
);
std
::
vector
<
std
::
vector
<
double
>>
sample_values
(
ncol_ptr
-
1
);
std
::
vector
<
std
::
vector
<
double
>>
sample_values
(
ncol_ptr
-
1
);
#pragma omp parallel for schedule(guided)
#pragma omp parallel for schedule(guided)
...
@@ -356,10 +372,10 @@ DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr,
...
@@ -356,10 +372,10 @@ DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr,
}
}
ret
.
reset
(
loader
.
CostructFromSampleData
(
sample_values
,
sample_cnt
,
nrow
));
ret
.
reset
(
loader
.
CostructFromSampleData
(
sample_values
,
sample_cnt
,
nrow
));
}
else
{
}
else
{
ret
.
reset
(
new
Dataset
(
nrow
,
config
.
io_config
.
num_class
));
ret
.
reset
(
new
Dataset
(
nrow
,
io_config
.
num_class
));
ret
->
CopyFeatureMapperFrom
(
ret
->
CopyFeatureMapperFrom
(
reinterpret_cast
<
const
Dataset
*>
(
*
reference
),
reinterpret_cast
<
const
Dataset
*>
(
*
reference
),
config
.
io_config
.
is_enable_sparse
);
io_config
.
is_enable_sparse
);
}
}
#pragma omp parallel for schedule(guided)
#pragma omp parallel for schedule(guided)
...
@@ -500,7 +516,7 @@ DllExport int LGBM_BoosterResetTrainingData(BoosterHandle handle,
...
@@ -500,7 +516,7 @@ DllExport int LGBM_BoosterResetTrainingData(BoosterHandle handle,
DllExport
int
LGBM_BoosterResetParameter
(
BoosterHandle
handle
,
const
char
*
parameters
)
{
DllExport
int
LGBM_BoosterResetParameter
(
BoosterHandle
handle
,
const
char
*
parameters
)
{
API_BEGIN
();
API_BEGIN
();
Booster
*
ref_booster
=
reinterpret_cast
<
Booster
*>
(
handle
);
Booster
*
ref_booster
=
reinterpret_cast
<
Booster
*>
(
handle
);
ref_booster
->
Reset
Boosting
Config
(
parameters
);
ref_booster
->
ResetConfig
(
parameters
);
API_END
();
API_END
();
}
}
...
...
src/io/config.cpp
View file @
b41e0f0a
...
@@ -10,9 +10,9 @@
...
@@ -10,9 +10,9 @@
namespace
LightGBM
{
namespace
LightGBM
{
void
OverallConfig
::
LoadFromString
(
const
char
*
str
)
{
std
::
unordered_map
<
std
::
string
,
std
::
string
>
ConfigBase
::
Str2Map
(
const
char
*
parameters
)
{
std
::
unordered_map
<
std
::
string
,
std
::
string
>
params
;
std
::
unordered_map
<
std
::
string
,
std
::
string
>
params
;
auto
args
=
Common
::
Split
(
str
,
"
\t\n\r
"
);
auto
args
=
Common
::
Split
(
parameters
,
"
\t\n\r
"
);
for
(
auto
arg
:
args
)
{
for
(
auto
arg
:
args
)
{
std
::
vector
<
std
::
string
>
tmp_strs
=
Common
::
Split
(
arg
.
c_str
(),
'='
);
std
::
vector
<
std
::
string
>
tmp_strs
=
Common
::
Split
(
arg
.
c_str
(),
'='
);
if
(
tmp_strs
.
size
()
==
2
)
{
if
(
tmp_strs
.
size
()
==
2
)
{
...
@@ -27,7 +27,7 @@ void OverallConfig::LoadFromString(const char* str) {
...
@@ -27,7 +27,7 @@ void OverallConfig::LoadFromString(const char* str) {
}
}
}
}
ParameterAlias
::
KeyAliasTransform
(
&
params
);
ParameterAlias
::
KeyAliasTransform
(
&
params
);
S
et
(
params
)
;
r
et
urn
params
;
}
}
void
OverallConfig
::
Set
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
{
void
OverallConfig
::
Set
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment