Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
46d21476
Unverified
Commit
46d21476
authored
May 06, 2019
by
Guolin Ke
Committed by
GitHub
May 06, 2019
Browse files
fix a bug when bagging with reset_config (#2149)
* fix a bug when bagging with reset_config * clean code
parent
2c41d15e
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
30 additions
and
159 deletions
+30
-159
include/LightGBM/objective_function.h
include/LightGBM/objective_function.h
+1
-8
include/LightGBM/tree_learner.h
include/LightGBM/tree_learner.h
+1
-4
src/boosting/gbdt.cpp
src/boosting/gbdt.cpp
+9
-4
src/boosting/rf.hpp
src/boosting/rf.hpp
+3
-1
src/objective/regression_objective.hpp
src/objective/regression_objective.hpp
+13
-96
src/treelearner/serial_tree_learner.cpp
src/treelearner/serial_tree_learner.cpp
+2
-42
src/treelearner/serial_tree_learner.h
src/treelearner/serial_tree_learner.h
+1
-4
No files found.
include/LightGBM/objective_function.h
View file @
46d21476
...
@@ -43,18 +43,11 @@ class ObjectiveFunction {
...
@@ -43,18 +43,11 @@ class ObjectiveFunction {
virtual
bool
IsRenewTreeOutput
()
const
{
return
false
;
}
virtual
bool
IsRenewTreeOutput
()
const
{
return
false
;
}
virtual
double
RenewTreeOutput
(
double
ori_output
,
const
double
*
,
virtual
double
RenewTreeOutput
(
double
ori_output
,
std
::
function
<
double
(
const
label_t
*
,
int
)
>
residual_getter
,
const
data_size_t
*
,
const
data_size_t
*
,
const
data_size_t
*
,
const
data_size_t
*
,
data_size_t
)
const
{
return
ori_output
;
}
data_size_t
)
const
{
return
ori_output
;
}
virtual
double
RenewTreeOutput
(
double
ori_output
,
double
,
const
data_size_t
*
,
const
data_size_t
*
,
data_size_t
)
const
{
return
ori_output
;
}
virtual
double
BoostFromScore
(
int
/*class_id*/
)
const
{
return
0.0
;
}
virtual
double
BoostFromScore
(
int
/*class_id*/
)
const
{
return
0.0
;
}
virtual
bool
ClassNeedTrain
(
int
/*class_id*/
)
const
{
return
true
;
}
virtual
bool
ClassNeedTrain
(
int
/*class_id*/
)
const
{
return
true
;
}
...
...
include/LightGBM/tree_learner.h
View file @
46d21476
...
@@ -77,12 +77,9 @@ class TreeLearner {
...
@@ -77,12 +77,9 @@ class TreeLearner {
*/
*/
virtual
void
AddPredictionToScore
(
const
Tree
*
tree
,
double
*
out_score
)
const
=
0
;
virtual
void
AddPredictionToScore
(
const
Tree
*
tree
,
double
*
out_score
)
const
=
0
;
virtual
void
RenewTreeOutput
(
Tree
*
tree
,
const
ObjectiveFunction
*
obj
,
const
double
*
prediction
,
virtual
void
RenewTreeOutput
(
Tree
*
tree
,
const
ObjectiveFunction
*
obj
,
std
::
function
<
double
(
const
label_t
*
,
int
)
>
residual_getter
,
data_size_t
total_num_data
,
const
data_size_t
*
bag_indices
,
data_size_t
bag_cnt
)
const
=
0
;
data_size_t
total_num_data
,
const
data_size_t
*
bag_indices
,
data_size_t
bag_cnt
)
const
=
0
;
virtual
void
RenewTreeOutput
(
Tree
*
tree
,
const
ObjectiveFunction
*
obj
,
double
prediction
,
data_size_t
total_num_data
,
const
data_size_t
*
bag_indices
,
data_size_t
bag_cnt
)
const
=
0
;
TreeLearner
()
=
default
;
TreeLearner
()
=
default
;
/*! \brief Disable copy */
/*! \brief Disable copy */
TreeLearner
&
operator
=
(
const
TreeLearner
&
)
=
delete
;
TreeLearner
&
operator
=
(
const
TreeLearner
&
)
=
delete
;
...
...
src/boosting/gbdt.cpp
View file @
46d21476
...
@@ -364,7 +364,9 @@ bool GBDT::TrainOneIter(const score_t* gradients, const score_t* hessians) {
...
@@ -364,7 +364,9 @@ bool GBDT::TrainOneIter(const score_t* gradients, const score_t* hessians) {
if
(
new_tree
->
num_leaves
()
>
1
)
{
if
(
new_tree
->
num_leaves
()
>
1
)
{
should_continue
=
true
;
should_continue
=
true
;
tree_learner_
->
RenewTreeOutput
(
new_tree
.
get
(),
objective_function_
,
train_score_updater_
->
score
()
+
bias
,
auto
score_ptr
=
train_score_updater_
->
score
()
+
bias
;
auto
residual_getter
=
[
score_ptr
](
const
label_t
*
label
,
int
i
)
{
return
static_cast
<
double
>
(
label
[
i
])
-
score_ptr
[
i
];
};
tree_learner_
->
RenewTreeOutput
(
new_tree
.
get
(),
objective_function_
,
residual_getter
,
num_data_
,
bag_data_indices_
.
data
(),
bag_data_cnt_
);
num_data_
,
bag_data_indices_
.
data
(),
bag_data_cnt_
);
// shrinkage by learning rate
// shrinkage by learning rate
new_tree
->
Shrinkage
(
shrinkage_rate_
);
new_tree
->
Shrinkage
(
shrinkage_rate_
);
...
@@ -688,6 +690,11 @@ void GBDT::ResetConfig(const Config* config) {
...
@@ -688,6 +690,11 @@ void GBDT::ResetConfig(const Config* config) {
void
GBDT
::
ResetBaggingConfig
(
const
Config
*
config
,
bool
is_change_dataset
)
{
void
GBDT
::
ResetBaggingConfig
(
const
Config
*
config
,
bool
is_change_dataset
)
{
// if need bagging, create buffer
// if need bagging, create buffer
if
(
config
->
bagging_fraction
<
1.0
&&
config
->
bagging_freq
>
0
)
{
if
(
config
->
bagging_fraction
<
1.0
&&
config
->
bagging_freq
>
0
)
{
need_re_bagging_
=
false
;
if
(
!
is_change_dataset
&&
config_
.
get
()
!=
nullptr
&&
config_
->
bagging_fraction
==
config
->
bagging_fraction
&&
config_
->
bagging_freq
==
config
->
bagging_freq
)
{
return
;
}
bag_data_cnt_
=
bag_data_cnt_
=
static_cast
<
data_size_t
>
(
config
->
bagging_fraction
*
num_data_
);
static_cast
<
data_size_t
>
(
config
->
bagging_fraction
*
num_data_
);
bag_data_indices_
.
resize
(
num_data_
);
bag_data_indices_
.
resize
(
num_data_
);
...
@@ -719,9 +726,7 @@ void GBDT::ResetBaggingConfig(const Config* config, bool is_change_dataset) {
...
@@ -719,9 +726,7 @@ void GBDT::ResetBaggingConfig(const Config* config, bool is_change_dataset) {
Log
::
Debug
(
"Use subset for bagging"
);
Log
::
Debug
(
"Use subset for bagging"
);
}
}
if
(
is_change_dataset
)
{
need_re_bagging_
=
true
;
need_re_bagging_
=
true
;
}
if
(
is_use_subset_
&&
bag_data_cnt_
<
num_data_
)
{
if
(
is_use_subset_
&&
bag_data_cnt_
<
num_data_
)
{
if
(
objective_function_
==
nullptr
)
{
if
(
objective_function_
==
nullptr
)
{
...
...
src/boosting/rf.hpp
View file @
46d21476
...
@@ -130,7 +130,9 @@ class RF : public GBDT {
...
@@ -130,7 +130,9 @@ class RF : public GBDT {
}
}
if
(
new_tree
->
num_leaves
()
>
1
)
{
if
(
new_tree
->
num_leaves
()
>
1
)
{
tree_learner_
->
RenewTreeOutput
(
new_tree
.
get
(),
objective_function_
,
init_scores_
[
cur_tree_id
],
double
pred
=
init_scores_
[
cur_tree_id
];
auto
residual_getter
=
[
pred
](
const
label_t
*
label
,
int
i
)
{
return
static_cast
<
double
>
(
label
[
i
])
-
pred
;
};
tree_learner_
->
RenewTreeOutput
(
new_tree
.
get
(),
objective_function_
,
residual_getter
,
num_data_
,
bag_data_indices_
.
data
(),
bag_data_cnt_
);
num_data_
,
bag_data_indices_
.
data
(),
bag_data_cnt_
);
if
(
std
::
fabs
(
init_scores_
[
cur_tree_id
])
>
kEpsilon
)
{
if
(
std
::
fabs
(
init_scores_
[
cur_tree_id
])
>
kEpsilon
)
{
new_tree
->
AddBias
(
init_scores_
[
cur_tree_id
]);
new_tree
->
AddBias
(
init_scores_
[
cur_tree_id
]);
...
...
src/objective/regression_objective.hpp
View file @
46d21476
...
@@ -232,62 +232,30 @@ class RegressionL1loss: public RegressionL2loss {
...
@@ -232,62 +232,30 @@ class RegressionL1loss: public RegressionL2loss {
bool
IsRenewTreeOutput
()
const
override
{
return
true
;
}
bool
IsRenewTreeOutput
()
const
override
{
return
true
;
}
double
RenewTreeOutput
(
double
,
const
double
*
pred
,
double
RenewTreeOutput
(
double
,
std
::
function
<
double
(
const
label_t
*
,
int
)
>
residual_getter
,
const
data_size_t
*
index_mapper
,
const
data_size_t
*
index_mapper
,
const
data_size_t
*
bagging_mapper
,
const
data_size_t
*
bagging_mapper
,
data_size_t
num_data_in_leaf
)
const
override
{
data_size_t
num_data_in_leaf
)
const
override
{
const
double
alpha
=
0.5
;
const
double
alpha
=
0.5
;
if
(
weights_
==
nullptr
)
{
if
(
weights_
==
nullptr
)
{
if
(
bagging_mapper
==
nullptr
)
{
if
(
bagging_mapper
==
nullptr
)
{
#define data_reader(i) (
label_[index_mapper[i]] - pred[
index_mapper[i]
]
)
#define data_reader(i) (
residual_getter(label_,
index_mapper[i]
)
)
PercentileFun
(
double
,
data_reader
,
num_data_in_leaf
,
alpha
);
PercentileFun
(
double
,
data_reader
,
num_data_in_leaf
,
alpha
);
#undef data_reader
#undef data_reader
}
else
{
}
else
{
#define data_reader(i) (
label_[bagging_mapper[index_mapper[i]]] - pred[
bagging_mapper[index_mapper[i]]
]
)
#define data_reader(i) (
residual_getter(label_,
bagging_mapper[index_mapper[i]]
)
)
PercentileFun
(
double
,
data_reader
,
num_data_in_leaf
,
alpha
);
PercentileFun
(
double
,
data_reader
,
num_data_in_leaf
,
alpha
);
#undef data_reader
#undef data_reader
}
}
}
else
{
}
else
{
if
(
bagging_mapper
==
nullptr
)
{
if
(
bagging_mapper
==
nullptr
)
{
#define data_reader(i) (
label_[index_mapper[i]] - pred[
index_mapper[i]
]
)
#define data_reader(i) (
residual_getter(label_,
index_mapper[i]
)
)
#define weight_reader(i) (weights_[index_mapper[i]])
#define weight_reader(i) (weights_[index_mapper[i]])
WeightedPercentileFun
(
double
,
data_reader
,
weight_reader
,
num_data_in_leaf
,
alpha
);
WeightedPercentileFun
(
double
,
data_reader
,
weight_reader
,
num_data_in_leaf
,
alpha
);
#undef data_reader
#undef data_reader
#undef weight_reader
#undef weight_reader
}
else
{
}
else
{
#define data_reader(i) (label_[bagging_mapper[index_mapper[i]]] - pred[bagging_mapper[index_mapper[i]]])
#define data_reader(i) (residual_getter(label_,bagging_mapper[index_mapper[i]]))
#define weight_reader(i) (weights_[bagging_mapper[index_mapper[i]]])
WeightedPercentileFun
(
double
,
data_reader
,
weight_reader
,
num_data_in_leaf
,
alpha
);
#undef data_reader
#undef weight_reader
}
}
}
double
RenewTreeOutput
(
double
,
double
pred
,
const
data_size_t
*
index_mapper
,
const
data_size_t
*
bagging_mapper
,
data_size_t
num_data_in_leaf
)
const
override
{
const
double
alpha
=
0.5
;
if
(
weights_
==
nullptr
)
{
if
(
bagging_mapper
==
nullptr
)
{
#define data_reader(i) (label_[index_mapper[i]] - pred)
PercentileFun
(
double
,
data_reader
,
num_data_in_leaf
,
alpha
);
#undef data_reader
}
else
{
#define data_reader(i) (label_[bagging_mapper[index_mapper[i]]] - pred)
PercentileFun
(
double
,
data_reader
,
num_data_in_leaf
,
alpha
);
#undef data_reader
}
}
else
{
if
(
bagging_mapper
==
nullptr
)
{
#define data_reader(i) (label_[index_mapper[i]] - pred)
#define weight_reader(i) (weights_[index_mapper[i]])
WeightedPercentileFun
(
double
,
data_reader
,
weight_reader
,
num_data_in_leaf
,
alpha
);
#undef data_reader
#undef weight_reader
}
else
{
#define data_reader(i) (label_[bagging_mapper[index_mapper[i]]] - pred)
#define weight_reader(i) (weights_[bagging_mapper[index_mapper[i]]])
#define weight_reader(i) (weights_[bagging_mapper[index_mapper[i]]])
WeightedPercentileFun
(
double
,
data_reader
,
weight_reader
,
num_data_in_leaf
,
alpha
);
WeightedPercentileFun
(
double
,
data_reader
,
weight_reader
,
num_data_in_leaf
,
alpha
);
#undef data_reader
#undef data_reader
...
@@ -552,60 +520,29 @@ class RegressionQuantileloss : public RegressionL2loss {
...
@@ -552,60 +520,29 @@ class RegressionQuantileloss : public RegressionL2loss {
bool
IsRenewTreeOutput
()
const
override
{
return
true
;
}
bool
IsRenewTreeOutput
()
const
override
{
return
true
;
}
double
RenewTreeOutput
(
double
,
const
double
*
pred
,
double
RenewTreeOutput
(
double
,
std
::
function
<
double
(
const
label_t
*
,
int
)
>
residual_getter
,
const
data_size_t
*
index_mapper
,
const
data_size_t
*
bagging_mapper
,
data_size_t
num_data_in_leaf
)
const
override
{
if
(
weights_
==
nullptr
)
{
if
(
bagging_mapper
==
nullptr
)
{
#define data_reader(i) (label_[index_mapper[i]] - pred[index_mapper[i]])
PercentileFun
(
double
,
data_reader
,
num_data_in_leaf
,
alpha_
);
#undef data_reader
}
else
{
#define data_reader(i) (label_[bagging_mapper[index_mapper[i]]] - pred[bagging_mapper[index_mapper[i]]])
PercentileFun
(
double
,
data_reader
,
num_data_in_leaf
,
alpha_
);
#undef data_reader
}
}
else
{
if
(
bagging_mapper
==
nullptr
)
{
#define data_reader(i) (label_[index_mapper[i]] - pred[index_mapper[i]])
#define weight_reader(i) (weights_[index_mapper[i]])
WeightedPercentileFun
(
double
,
data_reader
,
weight_reader
,
num_data_in_leaf
,
alpha_
);
#undef data_reader
#undef weight_reader
}
else
{
#define data_reader(i) (label_[bagging_mapper[index_mapper[i]]] - pred[bagging_mapper[index_mapper[i]]])
#define weight_reader(i) (weights_[bagging_mapper[index_mapper[i]]])
WeightedPercentileFun
(
double
,
data_reader
,
weight_reader
,
num_data_in_leaf
,
alpha_
);
#undef data_reader
#undef weight_reader
}
}
}
double
RenewTreeOutput
(
double
,
double
pred
,
const
data_size_t
*
index_mapper
,
const
data_size_t
*
index_mapper
,
const
data_size_t
*
bagging_mapper
,
const
data_size_t
*
bagging_mapper
,
data_size_t
num_data_in_leaf
)
const
override
{
data_size_t
num_data_in_leaf
)
const
override
{
if
(
weights_
==
nullptr
)
{
if
(
weights_
==
nullptr
)
{
if
(
bagging_mapper
==
nullptr
)
{
if
(
bagging_mapper
==
nullptr
)
{
#define data_reader(i) (label_
[
index_mapper[i]
] - pred
)
#define data_reader(i)
(residual_getter
(label_
,
index_mapper[i]
)
)
PercentileFun
(
double
,
data_reader
,
num_data_in_leaf
,
alpha_
);
PercentileFun
(
double
,
data_reader
,
num_data_in_leaf
,
alpha_
);
#undef data_reader
#undef data_reader
}
else
{
}
else
{
#define data_reader(i) (label_
[
bagging_mapper[index_mapper[i]]
] - pred
)
#define data_reader(i)
(residual_getter
(label_
,
bagging_mapper[index_mapper[i]]
)
)
PercentileFun
(
double
,
data_reader
,
num_data_in_leaf
,
alpha_
);
PercentileFun
(
double
,
data_reader
,
num_data_in_leaf
,
alpha_
);
#undef data_reader
#undef data_reader
}
}
}
else
{
}
else
{
if
(
bagging_mapper
==
nullptr
)
{
if
(
bagging_mapper
==
nullptr
)
{
#define data_reader(i) (label_
[
index_mapper[i]
] - pred
)
#define data_reader(i)
(residual_getter
(label_
,
index_mapper[i]
)
)
#define weight_reader(i) (weights_[index_mapper[i]])
#define weight_reader(i) (weights_[index_mapper[i]])
WeightedPercentileFun
(
double
,
data_reader
,
weight_reader
,
num_data_in_leaf
,
alpha_
);
WeightedPercentileFun
(
double
,
data_reader
,
weight_reader
,
num_data_in_leaf
,
alpha_
);
#undef data_reader
#undef data_reader
#undef weight_reader
#undef weight_reader
}
else
{
}
else
{
#define data_reader(i) (label_
[
bagging_mapper[index_mapper[i]]
] - pred
)
#define data_reader(i)
(residual_getter
(label_
,
bagging_mapper[index_mapper[i]]
)
)
#define weight_reader(i) (weights_[bagging_mapper[index_mapper[i]]])
#define weight_reader(i) (weights_[bagging_mapper[index_mapper[i]]])
WeightedPercentileFun
(
double
,
data_reader
,
weight_reader
,
num_data_in_leaf
,
alpha_
);
WeightedPercentileFun
(
double
,
data_reader
,
weight_reader
,
num_data_in_leaf
,
alpha_
);
#undef data_reader
#undef data_reader
...
@@ -684,39 +621,19 @@ class RegressionMAPELOSS : public RegressionL1loss {
...
@@ -684,39 +621,19 @@ class RegressionMAPELOSS : public RegressionL1loss {
bool
IsRenewTreeOutput
()
const
override
{
return
true
;
}
bool
IsRenewTreeOutput
()
const
override
{
return
true
;
}
double
RenewTreeOutput
(
double
,
const
double
*
pred
,
double
RenewTreeOutput
(
double
,
std
::
function
<
double
(
const
label_t
*
,
int
)
>
residual_getter
,
const
data_size_t
*
index_mapper
,
const
data_size_t
*
bagging_mapper
,
data_size_t
num_data_in_leaf
)
const
override
{
const
double
alpha
=
0.5
;
if
(
bagging_mapper
==
nullptr
)
{
#define data_reader(i) (label_[index_mapper[i]] - pred[index_mapper[i]])
#define weight_reader(i) (label_weight_[index_mapper[i]])
WeightedPercentileFun
(
double
,
data_reader
,
weight_reader
,
num_data_in_leaf
,
alpha
);
#undef data_reader
#undef weight_reader
}
else
{
#define data_reader(i) (label_[bagging_mapper[index_mapper[i]]] - pred[bagging_mapper[index_mapper[i]]])
#define weight_reader(i) (label_weight_[bagging_mapper[index_mapper[i]]])
WeightedPercentileFun
(
double
,
data_reader
,
weight_reader
,
num_data_in_leaf
,
alpha
);
#undef data_reader
#undef weight_reader
}
}
double
RenewTreeOutput
(
double
,
double
pred
,
const
data_size_t
*
index_mapper
,
const
data_size_t
*
index_mapper
,
const
data_size_t
*
bagging_mapper
,
const
data_size_t
*
bagging_mapper
,
data_size_t
num_data_in_leaf
)
const
override
{
data_size_t
num_data_in_leaf
)
const
override
{
const
double
alpha
=
0.5
;
const
double
alpha
=
0.5
;
if
(
bagging_mapper
==
nullptr
)
{
if
(
bagging_mapper
==
nullptr
)
{
#define data_reader(i) (label_
[
index_mapper[i]
] - pred
)
#define data_reader(i)
(residual_getter
(label_
,
index_mapper[i]
)
)
#define weight_reader(i) (label_weight_[index_mapper[i]])
#define weight_reader(i) (label_weight_[index_mapper[i]])
WeightedPercentileFun
(
double
,
data_reader
,
weight_reader
,
num_data_in_leaf
,
alpha
);
WeightedPercentileFun
(
double
,
data_reader
,
weight_reader
,
num_data_in_leaf
,
alpha
);
#undef data_reader
#undef data_reader
#undef weight_reader
#undef weight_reader
}
else
{
}
else
{
#define data_reader(i) (label_
[
bagging_mapper[index_mapper[i]]
] - pred
)
#define data_reader(i)
(residual_getter
(label_
,
bagging_mapper[index_mapper[i]]
)
)
#define weight_reader(i) (label_weight_[bagging_mapper[index_mapper[i]]])
#define weight_reader(i) (label_weight_[bagging_mapper[index_mapper[i]]])
WeightedPercentileFun
(
double
,
data_reader
,
weight_reader
,
num_data_in_leaf
,
alpha
);
WeightedPercentileFun
(
double
,
data_reader
,
weight_reader
,
num_data_in_leaf
,
alpha
);
#undef data_reader
#undef data_reader
...
...
src/treelearner/serial_tree_learner.cpp
View file @
46d21476
...
@@ -851,7 +851,7 @@ void SerialTreeLearner::Split(Tree* tree, int best_leaf, int* left_leaf, int* ri
...
@@ -851,7 +851,7 @@ void SerialTreeLearner::Split(Tree* tree, int best_leaf, int* left_leaf, int* ri
}
}
void
SerialTreeLearner
::
RenewTreeOutput
(
Tree
*
tree
,
const
ObjectiveFunction
*
obj
,
const
double
*
prediction
,
void
SerialTreeLearner
::
RenewTreeOutput
(
Tree
*
tree
,
const
ObjectiveFunction
*
obj
,
std
::
function
<
double
(
const
label_t
*
,
int
)
>
residual_getter
,
data_size_t
total_num_data
,
const
data_size_t
*
bag_indices
,
data_size_t
bag_cnt
)
const
{
data_size_t
total_num_data
,
const
data_size_t
*
bag_indices
,
data_size_t
bag_cnt
)
const
{
if
(
obj
!=
nullptr
&&
obj
->
IsRenewTreeOutput
())
{
if
(
obj
!=
nullptr
&&
obj
->
IsRenewTreeOutput
())
{
CHECK
(
tree
->
num_leaves
()
<=
data_partition_
->
num_leaves
());
CHECK
(
tree
->
num_leaves
()
<=
data_partition_
->
num_leaves
());
...
@@ -869,47 +869,7 @@ void SerialTreeLearner::RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj
...
@@ -869,47 +869,7 @@ void SerialTreeLearner::RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj
auto
index_mapper
=
data_partition_
->
GetIndexOnLeaf
(
i
,
&
cnt_leaf_data
);
auto
index_mapper
=
data_partition_
->
GetIndexOnLeaf
(
i
,
&
cnt_leaf_data
);
if
(
cnt_leaf_data
>
0
)
{
if
(
cnt_leaf_data
>
0
)
{
// bag_mapper[index_mapper[i]]
// bag_mapper[index_mapper[i]]
const
double
new_output
=
obj
->
RenewTreeOutput
(
output
,
prediction
,
index_mapper
,
bag_mapper
,
cnt_leaf_data
);
const
double
new_output
=
obj
->
RenewTreeOutput
(
output
,
residual_getter
,
index_mapper
,
bag_mapper
,
cnt_leaf_data
);
tree
->
SetLeafOutput
(
i
,
new_output
);
}
else
{
CHECK
(
num_machines
>
1
);
tree
->
SetLeafOutput
(
i
,
0.0
);
n_nozeroworker_perleaf
[
i
]
=
0
;
}
}
if
(
num_machines
>
1
)
{
std
::
vector
<
double
>
outputs
(
tree
->
num_leaves
());
for
(
int
i
=
0
;
i
<
tree
->
num_leaves
();
++
i
)
{
outputs
[
i
]
=
static_cast
<
double
>
(
tree
->
LeafOutput
(
i
));
}
Network
::
GlobalSum
(
outputs
);
Network
::
GlobalSum
(
n_nozeroworker_perleaf
);
for
(
int
i
=
0
;
i
<
tree
->
num_leaves
();
++
i
)
{
tree
->
SetLeafOutput
(
i
,
outputs
[
i
]
/
n_nozeroworker_perleaf
[
i
]);
}
}
}
}
void
SerialTreeLearner
::
RenewTreeOutput
(
Tree
*
tree
,
const
ObjectiveFunction
*
obj
,
double
prediction
,
data_size_t
total_num_data
,
const
data_size_t
*
bag_indices
,
data_size_t
bag_cnt
)
const
{
if
(
obj
!=
nullptr
&&
obj
->
IsRenewTreeOutput
())
{
CHECK
(
tree
->
num_leaves
()
<=
data_partition_
->
num_leaves
());
const
data_size_t
*
bag_mapper
=
nullptr
;
if
(
total_num_data
!=
num_data_
)
{
CHECK
(
bag_cnt
==
num_data_
);
bag_mapper
=
bag_indices
;
}
std
::
vector
<
int
>
n_nozeroworker_perleaf
(
tree
->
num_leaves
(),
1
);
int
num_machines
=
Network
::
num_machines
();
#pragma omp parallel for schedule(static)
for
(
int
i
=
0
;
i
<
tree
->
num_leaves
();
++
i
)
{
const
double
output
=
static_cast
<
double
>
(
tree
->
LeafOutput
(
i
));
data_size_t
cnt_leaf_data
=
0
;
auto
index_mapper
=
data_partition_
->
GetIndexOnLeaf
(
i
,
&
cnt_leaf_data
);
if
(
cnt_leaf_data
>
0
)
{
// bag_mapper[index_mapper[i]]
const
double
new_output
=
obj
->
RenewTreeOutput
(
output
,
prediction
,
index_mapper
,
bag_mapper
,
cnt_leaf_data
);
tree
->
SetLeafOutput
(
i
,
new_output
);
tree
->
SetLeafOutput
(
i
,
new_output
);
}
else
{
}
else
{
CHECK
(
num_machines
>
1
);
CHECK
(
num_machines
>
1
);
...
...
src/treelearner/serial_tree_learner.h
View file @
46d21476
...
@@ -74,10 +74,7 @@ class SerialTreeLearner: public TreeLearner {
...
@@ -74,10 +74,7 @@ class SerialTreeLearner: public TreeLearner {
}
}
}
}
void
RenewTreeOutput
(
Tree
*
tree
,
const
ObjectiveFunction
*
obj
,
const
double
*
prediction
,
void
RenewTreeOutput
(
Tree
*
tree
,
const
ObjectiveFunction
*
obj
,
std
::
function
<
double
(
const
label_t
*
,
int
)
>
residual_getter
,
data_size_t
total_num_data
,
const
data_size_t
*
bag_indices
,
data_size_t
bag_cnt
)
const
override
;
void
RenewTreeOutput
(
Tree
*
tree
,
const
ObjectiveFunction
*
obj
,
double
prediction
,
data_size_t
total_num_data
,
const
data_size_t
*
bag_indices
,
data_size_t
bag_cnt
)
const
override
;
data_size_t
total_num_data
,
const
data_size_t
*
bag_indices
,
data_size_t
bag_cnt
)
const
override
;
protected:
protected:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment