Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
63eddae0
Commit
63eddae0
authored
Nov 27, 2016
by
Guolin Ke
Browse files
provide a light weight interface for reset learning rate
parent
19512d82
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
55 additions
and
28 deletions
+55
-28
include/LightGBM/boosting.h
include/LightGBM/boosting.h
+6
-0
python-package/lightgbm/basic.py
python-package/lightgbm/basic.py
+5
-3
src/application/application.cpp
src/application/application.cpp
+30
-24
src/boosting/gbdt.h
src/boosting/gbdt.h
+8
-0
src/c_api.cpp
src/c_api.cpp
+6
-1
No files found.
include/LightGBM/boosting.h
View file @
63eddae0
...
...
@@ -51,6 +51,12 @@ public:
*/
virtual
void
ResetTrainingData
(
const
BoostingConfig
*
config
,
const
Dataset
*
train_data
,
const
ObjectiveFunction
*
object_function
,
const
std
::
vector
<
const
Metric
*>&
training_metrics
)
=
0
;
/*!
* \brief Reset shrinkage_rate data for current boosting
* \param shrinkage_rate Configs for boosting
*/
virtual
void
ResetShrinkageRate
(
double
shrinkage_rate
)
=
0
;
/*!
* \brief Add a validation data
* \param valid_data Validation data
...
...
python-package/lightgbm/basic.py
View file @
63eddae0
...
...
@@ -128,9 +128,11 @@ def param_dict_to_str(data):
return
""
pairs
=
[]
for
key
,
val
in
data
.
items
():
if
isinstance
(
val
,
list
):
pairs
.
append
(
str
(
key
)
+
'='
+
','
.
join
(
val
))
elif
isinstance
(
val
,
(
int
,
float
,
str
,
bool
)):
if
is_str
(
val
):
pairs
.
append
(
str
(
key
)
+
'='
+
str
(
val
))
elif
isinstance
(
val
,
(
list
,
tuple
)):
pairs
.
append
(
str
(
key
)
+
'='
+
','
.
join
(
map
(
str
,
val
)))
elif
isinstance
(
val
,
(
int
,
float
,
bool
)):
pairs
.
append
(
str
(
key
)
+
'='
+
str
(
val
))
else
:
raise
TypeError
(
'unknow type of parameter:%s , got:%s'
%
(
key
,
type
(
val
).
__name__
))
...
...
src/application/application.cpp
View file @
63eddae0
...
...
@@ -144,33 +144,39 @@ void Application::LoadData() {
}
}
train_metric_
.
shrink_to_fit
();
// Add validation data, if it exists
for
(
size_t
i
=
0
;
i
<
config_
.
io_config
.
valid_data_filenames
.
size
();
++
i
)
{
// add
auto
new_dataset
=
std
::
unique_ptr
<
Dataset
>
(
dataset_loader
.
LoadFromFileAlignWithOtherDataset
(
config_
.
io_config
.
valid_data_filenames
[
i
].
c_str
(),
train_data_
.
get
())
);
valid_datas_
.
push_back
(
std
::
move
(
new_dataset
));
// need save binary file
if
(
config_
.
io_config
.
is_save_binary_file
)
{
valid_datas_
.
back
()
->
SaveBinaryFile
(
nullptr
);
}
// add metric for validation data
valid_metrics_
.
emplace_back
();
for
(
auto
metric_type
:
config_
.
metric_types
)
{
auto
metric
=
std
::
unique_ptr
<
Metric
>
(
Metric
::
CreateMetric
(
metric_type
,
config_
.
metric_config
));
if
(
metric
==
nullptr
)
{
continue
;
}
metric
->
Init
(
valid_datas_
.
back
()
->
metadata
(),
valid_datas_
.
back
()
->
num_data
());
valid_metrics_
.
back
().
push_back
(
std
::
move
(
metric
));
if
(
config_
.
metric_types
.
size
()
>
0
)
{
// only when have metrics then need to construct validation data
// Add validation data, if it exists
for
(
size_t
i
=
0
;
i
<
config_
.
io_config
.
valid_data_filenames
.
size
();
++
i
)
{
// add
auto
new_dataset
=
std
::
unique_ptr
<
Dataset
>
(
dataset_loader
.
LoadFromFileAlignWithOtherDataset
(
config_
.
io_config
.
valid_data_filenames
[
i
].
c_str
(),
train_data_
.
get
())
);
valid_datas_
.
push_back
(
std
::
move
(
new_dataset
));
// need save binary file
if
(
config_
.
io_config
.
is_save_binary_file
)
{
valid_datas_
.
back
()
->
SaveBinaryFile
(
nullptr
);
}
// add metric for validation data
valid_metrics_
.
emplace_back
();
for
(
auto
metric_type
:
config_
.
metric_types
)
{
auto
metric
=
std
::
unique_ptr
<
Metric
>
(
Metric
::
CreateMetric
(
metric_type
,
config_
.
metric_config
));
if
(
metric
==
nullptr
)
{
continue
;
}
metric
->
Init
(
valid_datas_
.
back
()
->
metadata
(),
valid_datas_
.
back
()
->
num_data
());
valid_metrics_
.
back
().
push_back
(
std
::
move
(
metric
));
}
valid_metrics_
.
back
().
shrink_to_fit
();
}
valid_metrics_
.
back
().
shrink_to_fit
();
valid_datas_
.
shrink_to_fit
();
valid_metrics_
.
shrink_to_fit
();
}
valid_datas_
.
shrink_to_fit
();
valid_metrics_
.
shrink_to_fit
();
auto
end_time
=
std
::
chrono
::
high_resolution_clock
::
now
();
// output used time on each iteration
Log
::
Info
(
"Finished loading data in %f seconds"
,
...
...
src/boosting/gbdt.h
View file @
63eddae0
...
...
@@ -68,6 +68,14 @@ public:
*/
void
ResetTrainingData
(
const
BoostingConfig
*
config
,
const
Dataset
*
train_data
,
const
ObjectiveFunction
*
object_function
,
const
std
::
vector
<
const
Metric
*>&
training_metrics
)
override
;
/*!
* \brief Reset shrinkage_rate data for current boosting
* \param shrinkage_rate Configs for boosting
*/
void
ResetShrinkageRate
(
double
shrinkage_rate
)
override
{
shrinkage_rate_
=
shrinkage_rate
;
}
/*!
* \brief Adding a validation dataset
* \param valid_data Validation dataset
...
...
src/c_api.cpp
View file @
63eddae0
...
...
@@ -72,7 +72,12 @@ public:
Log
::
Fatal
(
"cannot change boosting_type during training"
);
}
config_
.
Set
(
param
);
ResetTrainingData
(
train_data_
);
if
(
param
.
size
()
==
1
&&
(
param
.
count
(
"learning_rate"
)
||
param
.
count
(
"shrinkage_rate"
)))
{
// only need to set learning rate
boosting_
->
ResetShrinkageRate
(
config_
.
boosting_config
.
learning_rate
);
}
else
{
ResetTrainingData
(
train_data_
);
}
}
void
AddValidData
(
const
Dataset
*
valid_data
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment