Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
5392c9ea
Unverified
Commit
5392c9ea
authored
Jan 16, 2018
by
Guolin Ke
Committed by
GitHub
Jan 16, 2018
Browse files
Fix objective functions with zero hessian (#1199)
parent
d90369a0
Changes
22
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
37 additions
and
0 deletions
+37
-0
src/treelearner/serial_tree_learner.cpp
src/treelearner/serial_tree_learner.cpp
+34
-0
src/treelearner/serial_tree_learner.h
src/treelearner/serial_tree_learner.h
+3
-0
No files found.
src/treelearner/serial_tree_learner.cpp
View file @
5392c9ea
#include "serial_tree_learner.h"
#include "serial_tree_learner.h"
#include <LightGBM/utils/array_args.h>
#include <LightGBM/utils/array_args.h>
#include <LightGBM/network.h>
#include <LightGBM/objective_function.h>
#include <algorithm>
#include <algorithm>
#include <vector>
#include <vector>
...
@@ -587,4 +589,36 @@ void SerialTreeLearner::Split(Tree* tree, int best_leaf, int* left_leaf, int* ri
...
@@ -587,4 +589,36 @@ void SerialTreeLearner::Split(Tree* tree, int best_leaf, int* left_leaf, int* ri
}
}
}
}
void
SerialTreeLearner
::
RenewTreeOutput
(
Tree
*
tree
,
const
ObjectiveFunction
*
obj
,
const
double
*
prediction
,
data_size_t
total_num_data
,
const
data_size_t
*
bag_indices
,
data_size_t
bag_cnt
)
const
{
if
(
obj
!=
nullptr
&&
obj
->
IsRenewTreeOutput
())
{
CHECK
(
tree
->
num_leaves
()
<=
data_partition_
->
num_leaves
());
const
data_size_t
*
bag_mapper
=
nullptr
;
if
(
total_num_data
!=
num_data_
)
{
CHECK
(
bag_cnt
==
num_data_
);
bag_mapper
=
bag_indices
;
}
#pragma omp parallel for schedule(static)
for
(
int
i
=
0
;
i
<
tree
->
num_leaves
();
++
i
)
{
const
double
output
=
static_cast
<
double
>
(
tree
->
LeafOutput
(
i
));
data_size_t
cnt_leaf_data
=
0
;
auto
index_mapper
=
data_partition_
->
GetIndexOnLeaf
(
i
,
&
cnt_leaf_data
);
CHECK
(
cnt_leaf_data
>
0
);
// bag_mapper[index_mapper[i]]
const
double
new_output
=
obj
->
RenewTreeOutput
(
output
,
prediction
,
index_mapper
,
bag_mapper
,
cnt_leaf_data
);
tree
->
SetLeafOutput
(
i
,
new_output
);
}
if
(
Network
::
num_machines
()
>
1
)
{
std
::
vector
<
double
>
outputs
(
tree
->
num_leaves
());
for
(
int
i
=
0
;
i
<
tree
->
num_leaves
();
++
i
)
{
outputs
[
i
]
=
static_cast
<
double
>
(
tree
->
LeafOutput
(
i
));
}
Network
::
GlobalSum
(
outputs
);
for
(
int
i
=
0
;
i
<
tree
->
num_leaves
();
++
i
)
{
tree
->
SetLeafOutput
(
i
,
outputs
[
i
]
/
Network
::
num_machines
());
}
}
}
}
}
// namespace LightGBM
}
// namespace LightGBM
src/treelearner/serial_tree_learner.h
View file @
5392c9ea
...
@@ -66,6 +66,9 @@ public:
...
@@ -66,6 +66,9 @@ public:
}
}
}
}
void
RenewTreeOutput
(
Tree
*
tree
,
const
ObjectiveFunction
*
obj
,
const
double
*
prediction
,
data_size_t
total_num_data
,
const
data_size_t
*
bag_indices
,
data_size_t
bag_cnt
)
const
override
;
protected:
protected:
/*!
/*!
* \brief Some initial works before training
* \brief Some initial works before training
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment