Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
3beee91d
Commit
3beee91d
authored
Mar 23, 2017
by
Guolin Ke
Browse files
only stop training when all classes are finished in multi-class.
parent
2e962c77
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
4 additions
and
12 deletions
+4
-12
src/boosting/gbdt.cpp
src/boosting/gbdt.cpp
+4
-11
src/boosting/gbdt.h
src/boosting/gbdt.h
+0
-1
No files found.
src/boosting/gbdt.cpp
View file @
3beee91d
...
@@ -162,7 +162,6 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
...
@@ -162,7 +162,6 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
if
(
train_data_
!=
nullptr
)
{
if
(
train_data_
!=
nullptr
)
{
// reset config for tree learner
// reset config for tree learner
tree_learner_
->
ResetConfig
(
&
new_config
->
tree_config
);
tree_learner_
->
ResetConfig
(
&
new_config
->
tree_config
);
is_class_end_
=
std
::
vector
<
bool
>
(
num_class_
,
false
);
}
}
gbdt_config_
.
reset
(
new_config
.
release
());
gbdt_config_
.
reset
(
new_config
.
release
());
}
}
...
@@ -284,7 +283,7 @@ void GBDT::UpdateScoreOutOfBag(const Tree* tree, const int curr_class) {
...
@@ -284,7 +283,7 @@ void GBDT::UpdateScoreOutOfBag(const Tree* tree, const int curr_class) {
#ifdef TIMETAG
#ifdef TIMETAG
auto
start_time
=
std
::
chrono
::
steady_clock
::
now
();
auto
start_time
=
std
::
chrono
::
steady_clock
::
now
();
#endif
#endif
// we need to predict out-of-bag s
o
cres of data for boosting
// we need to predict out-of-bag sc
o
res of data for boosting
if
(
num_data_
-
bag_data_cnt_
>
0
&&
!
is_use_subset_
)
{
if
(
num_data_
-
bag_data_cnt_
>
0
&&
!
is_use_subset_
)
{
train_score_updater_
->
AddScore
(
tree
,
bag_data_indices_
.
data
()
+
bag_data_cnt_
,
num_data_
-
bag_data_cnt_
,
curr_class
);
train_score_updater_
->
AddScore
(
tree
,
bag_data_indices_
.
data
()
+
bag_data_cnt_
,
num_data_
-
bag_data_cnt_
,
curr_class
);
}
}
...
@@ -351,7 +350,7 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
...
@@ -351,7 +350,7 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
// get sub gradients
// get sub gradients
for
(
int
curr_class
=
0
;
curr_class
<
num_class_
;
++
curr_class
)
{
for
(
int
curr_class
=
0
;
curr_class
<
num_class_
;
++
curr_class
)
{
auto
bias
=
curr_class
*
num_data_
;
auto
bias
=
curr_class
*
num_data_
;
// cannot multi-threding
// cannot multi-thre
a
ding
for
(
int
i
=
0
;
i
<
bag_data_cnt_
;
++
i
)
{
for
(
int
i
=
0
;
i
<
bag_data_cnt_
;
++
i
)
{
gradients_
[
bias
+
i
]
=
gradient
[
bias
+
bag_data_indices_
[
i
]];
gradients_
[
bias
+
i
]
=
gradient
[
bias
+
bag_data_indices_
[
i
]];
hessians_
[
bias
+
i
]
=
hessian
[
bias
+
bag_data_indices_
[
i
]];
hessians_
[
bias
+
i
]
=
hessian
[
bias
+
bag_data_indices_
[
i
]];
...
@@ -369,10 +368,8 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
...
@@ -369,10 +368,8 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
start_time
=
std
::
chrono
::
steady_clock
::
now
();
start_time
=
std
::
chrono
::
steady_clock
::
now
();
#endif
#endif
std
::
unique_ptr
<
Tree
>
new_tree
(
new
Tree
(
2
));
std
::
unique_ptr
<
Tree
>
new_tree
(
new
Tree
(
2
));
if
(
!
is_class_end_
[
curr_class
])
{
// train a new tree
// train a new tree
new_tree
.
reset
(
tree_learner_
->
Train
(
gradient
+
curr_class
*
num_data_
,
hessian
+
curr_class
*
num_data_
));
new_tree
.
reset
(
tree_learner_
->
Train
(
gradient
+
curr_class
*
num_data_
,
hessian
+
curr_class
*
num_data_
));
}
#ifdef TIMETAG
#ifdef TIMETAG
tree_time
+=
std
::
chrono
::
steady_clock
::
now
()
-
start_time
;
tree_time
+=
std
::
chrono
::
steady_clock
::
now
()
-
start_time
;
#endif
#endif
...
@@ -384,10 +381,7 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
...
@@ -384,10 +381,7 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
// update score
// update score
UpdateScore
(
new_tree
.
get
(),
curr_class
);
UpdateScore
(
new_tree
.
get
(),
curr_class
);
UpdateScoreOutOfBag
(
new_tree
.
get
(),
curr_class
);
UpdateScoreOutOfBag
(
new_tree
.
get
(),
curr_class
);
}
else
{
is_class_end_
[
curr_class
]
=
true
;
}
}
// add model
// add model
models_
.
push_back
(
std
::
move
(
new_tree
));
models_
.
push_back
(
std
::
move
(
new_tree
));
}
}
...
@@ -423,7 +417,6 @@ void GBDT::RollbackOneIter() {
...
@@ -423,7 +417,6 @@ void GBDT::RollbackOneIter() {
for
(
int
curr_class
=
0
;
curr_class
<
num_class_
;
++
curr_class
)
{
for
(
int
curr_class
=
0
;
curr_class
<
num_class_
;
++
curr_class
)
{
models_
.
pop_back
();
models_
.
pop_back
();
}
}
is_class_end_
=
std
::
vector
<
bool
>
(
num_class_
,
false
);
--
iter_
;
--
iter_
;
}
}
...
...
src/boosting/gbdt.h
View file @
3beee91d
...
@@ -344,7 +344,6 @@ protected:
...
@@ -344,7 +344,6 @@ protected:
std
::
vector
<
data_size_t
>
right_write_pos_buf_
;
std
::
vector
<
data_size_t
>
right_write_pos_buf_
;
std
::
unique_ptr
<
Dataset
>
tmp_subset_
;
std
::
unique_ptr
<
Dataset
>
tmp_subset_
;
bool
is_use_subset_
;
bool
is_use_subset_
;
std
::
vector
<
bool
>
is_class_end_
;
bool
boost_from_average_
;
bool
boost_from_average_
;
};
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment