Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
63eddae0
Commit
63eddae0
authored
Nov 27, 2016
by
Guolin Ke
Browse files
provide a light weight interface for reset learning rate
parent
19512d82
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
55 additions
and
28 deletions
+55
-28
include/LightGBM/boosting.h
include/LightGBM/boosting.h
+6
-0
python-package/lightgbm/basic.py
python-package/lightgbm/basic.py
+5
-3
src/application/application.cpp
src/application/application.cpp
+30
-24
src/boosting/gbdt.h
src/boosting/gbdt.h
+8
-0
src/c_api.cpp
src/c_api.cpp
+6
-1
No files found.
include/LightGBM/boosting.h
View file @
63eddae0
...
...
@@ -51,6 +51,12 @@ public:
*/
virtual
void
ResetTrainingData
(
const
BoostingConfig
*
config
,
const
Dataset
*
train_data
,
const
ObjectiveFunction
*
object_function
,
const
std
::
vector
<
const
Metric
*>&
training_metrics
)
=
0
;
/*!
* \brief Reset shrinkage_rate data for current boosting
* \param shrinkage_rate Configs for boosting
*/
virtual
void
ResetShrinkageRate
(
double
shrinkage_rate
)
=
0
;
/*!
* \brief Add a validation data
* \param valid_data Validation data
...
...
python-package/lightgbm/basic.py
View file @
63eddae0
...
...
@@ -128,9 +128,11 @@ def param_dict_to_str(data):
return
""
pairs
=
[]
for
key
,
val
in
data
.
items
():
if
isinstance
(
val
,
list
):
pairs
.
append
(
str
(
key
)
+
'='
+
','
.
join
(
val
))
elif
isinstance
(
val
,
(
int
,
float
,
str
,
bool
)):
if
is_str
(
val
):
pairs
.
append
(
str
(
key
)
+
'='
+
str
(
val
))
elif
isinstance
(
val
,
(
list
,
tuple
)):
pairs
.
append
(
str
(
key
)
+
'='
+
','
.
join
(
map
(
str
,
val
)))
elif
isinstance
(
val
,
(
int
,
float
,
bool
)):
pairs
.
append
(
str
(
key
)
+
'='
+
str
(
val
))
else
:
raise
TypeError
(
'unknow type of parameter:%s , got:%s'
%
(
key
,
type
(
val
).
__name__
))
...
...
src/application/application.cpp
View file @
63eddae0
...
...
@@ -144,6 +144,11 @@ void Application::LoadData() {
}
}
train_metric_
.
shrink_to_fit
();
if
(
config_
.
metric_types
.
size
()
>
0
)
{
// only when have metrics then need to construct validation data
// Add validation data, if it exists
for
(
size_t
i
=
0
;
i
<
config_
.
io_config
.
valid_data_filenames
.
size
();
++
i
)
{
// add
...
...
@@ -171,6 +176,7 @@ void Application::LoadData() {
}
valid_datas_
.
shrink_to_fit
();
valid_metrics_
.
shrink_to_fit
();
}
auto
end_time
=
std
::
chrono
::
high_resolution_clock
::
now
();
// output used time on each iteration
Log
::
Info
(
"Finished loading data in %f seconds"
,
...
...
src/boosting/gbdt.h
View file @
63eddae0
...
...
@@ -68,6 +68,14 @@ public:
*/
void
ResetTrainingData
(
const
BoostingConfig
*
config
,
const
Dataset
*
train_data
,
const
ObjectiveFunction
*
object_function
,
const
std
::
vector
<
const
Metric
*>&
training_metrics
)
override
;
/*!
* \brief Reset shrinkage_rate data for current boosting
* \param shrinkage_rate Configs for boosting
*/
void
ResetShrinkageRate
(
double
shrinkage_rate
)
override
{
shrinkage_rate_
=
shrinkage_rate
;
}
/*!
* \brief Adding a validation dataset
* \param valid_data Validation dataset
...
...
src/c_api.cpp
View file @
63eddae0
...
...
@@ -72,8 +72,13 @@ public:
Log
::
Fatal
(
"cannot change boosting_type during training"
);
}
config_
.
Set
(
param
);
if
(
param
.
size
()
==
1
&&
(
param
.
count
(
"learning_rate"
)
||
param
.
count
(
"shrinkage_rate"
)))
{
// only need to set learning rate
boosting_
->
ResetShrinkageRate
(
config_
.
boosting_config
.
learning_rate
);
}
else
{
ResetTrainingData
(
train_data_
);
}
}
void
AddValidData
(
const
Dataset
*
valid_data
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment