Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
5cc38e6a
Unverified
Commit
5cc38e6a
authored
May 15, 2020
by
Guolin Ke
Committed by
GitHub
May 15, 2020
Browse files
fix goss with constant hessian (#3077)
parent
9085f4e2
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
15 additions
and
6 deletions
+15
-6
src/boosting/gbdt.cpp
src/boosting/gbdt.cpp
+3
-6
src/boosting/gbdt.h
src/boosting/gbdt.h
+7
-0
src/boosting/goss.hpp
src/boosting/goss.hpp
+5
-0
No files found.
src/boosting/gbdt.cpp
View file @
5cc38e6a
...
...
@@ -70,12 +70,11 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective
objective_function_
=
objective_function
;
num_tree_per_iteration_
=
num_class_
;
if
(
objective_function_
!=
nullptr
)
{
is_constant_hessian_
=
objective_function_
->
IsConstantHessian
();
num_tree_per_iteration_
=
objective_function_
->
NumModelPerIteration
();
}
else
{
is_constant_hessian_
=
false
;
}
is_constant_hessian_
=
GetIsConstHessian
(
objective_function
);
tree_learner_
=
std
::
unique_ptr
<
TreeLearner
>
(
TreeLearner
::
CreateTreeLearner
(
config_
->
tree_learner
,
config_
->
device_type
,
config_
.
get
()));
// init tree learner
...
...
@@ -653,11 +652,9 @@ void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction*
objective_function_
=
objective_function
;
if
(
objective_function_
!=
nullptr
)
{
is_constant_hessian_
=
objective_function_
->
IsConstantHessian
();
CHECK_EQ
(
num_tree_per_iteration_
,
objective_function_
->
NumModelPerIteration
());
}
else
{
is_constant_hessian_
=
false
;
}
is_constant_hessian_
=
GetIsConstHessian
(
objective_function
);
// push training metrics
training_metrics_
.
clear
();
...
...
src/boosting/gbdt.h
View file @
5cc38e6a
...
...
@@ -375,6 +375,13 @@ class GBDT : public GBDTBase {
const
char
*
SubModelName
()
const
override
{
return
"tree"
;
}
protected:
virtual
bool
GetIsConstHessian
(
const
ObjectiveFunction
*
objective_function
)
{
if
(
objective_function
!=
nullptr
)
{
return
objective_function
->
IsConstantHessian
();
}
else
{
return
false
;
}
}
/*!
* \brief Print eval result and check early stopping
*/
...
...
src/boosting/goss.hpp
View file @
5cc38e6a
...
...
@@ -153,6 +153,11 @@ class GOSS: public GBDT {
bag_data_cnt_
);
}
}
protected:
bool
GetIsConstHessian
(
const
ObjectiveFunction
*
)
override
{
return
false
;
}
};
}
// namespace LightGBM
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment