Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
73bc8ed7
Unverified
Commit
73bc8ed7
authored
Mar 04, 2020
by
Guolin Ke
Committed by
GitHub
Mar 04, 2020
Browse files
shrinkage to internal values (#2853)
parent
7d700cd3
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
23 additions
and
27 deletions
+23
-27
include/LightGBM/tree.h
include/LightGBM/tree.h
+23
-27
No files found.
include/LightGBM/tree.h
View file @
73bc8ed7
...
@@ -89,12 +89,7 @@ class Tree {
...
@@ -89,12 +89,7 @@ class Tree {
/*! \brief Set the output of one leaf */
/*! \brief Set the output of one leaf */
inline
void
SetLeafOutput
(
int
leaf
,
double
output
)
{
inline
void
SetLeafOutput
(
int
leaf
,
double
output
)
{
// Prevent denormal values because they can cause std::out_of_range exception when converting strings to doubles
leaf_value_
[
leaf
]
=
MaybeRoundToZero
(
output
);
if
(
IsZero
(
output
))
{
leaf_value_
[
leaf
]
=
0
;
}
else
{
leaf_value_
[
leaf
]
=
output
;
}
}
}
/*!
/*!
...
@@ -156,40 +151,33 @@ class Tree {
...
@@ -156,40 +151,33 @@ class Tree {
/*! \brief Get the number of data points that fall at or below this node*/
/*! \brief Get the number of data points that fall at or below this node*/
inline
int
data_count
(
int
node
)
const
{
return
node
>=
0
?
internal_count_
[
node
]
:
leaf_count_
[
~
node
];
}
inline
int
data_count
(
int
node
)
const
{
return
node
>=
0
?
internal_count_
[
node
]
:
leaf_count_
[
~
node
];
}
/*!
/*!
* \brief Shrinkage for the tree's output
* \brief Shrinkage for the tree's output
* shrinkage rate (a.k.a learning rate) is used to tune the training process
* shrinkage rate (a.k.a learning rate) is used to tune the training process
* \param rate The factor of shrinkage
* \param rate The factor of shrinkage
*/
*/
inline
void
Shrinkage
(
double
rate
)
{
inline
void
Shrinkage
(
double
rate
)
{
#pragma omp parallel for schedule(static, 1024) if (num_leaves_ >= 2048)
#pragma omp parallel for schedule(static, 1024) if (num_leaves_ >= 2048)
for
(
int
i
=
0
;
i
<
num_leaves_
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num_leaves_
-
1
;
++
i
)
{
double
new_leaf_value
=
leaf_value_
[
i
]
*
rate
;
leaf_value_
[
i
]
=
MaybeRoundToZero
(
leaf_value_
[
i
]
*
rate
);
// Prevent denormal values because they can cause std::out_of_range exception when converting strings to doubles
internal_value_
[
i
]
=
MaybeRoundToZero
(
internal_value_
[
i
]
*
rate
);
if
(
IsZero
(
new_leaf_value
))
{
leaf_value_
[
i
]
=
0
;
}
else
{
leaf_value_
[
i
]
=
new_leaf_value
;
}
}
}
leaf_value_
[
num_leaves_
-
1
]
=
MaybeRoundToZero
(
leaf_value_
[
num_leaves_
-
1
]
*
rate
);
shrinkage_
*=
rate
;
shrinkage_
*=
rate
;
}
}
inline
double
shrinkage
()
const
{
inline
double
shrinkage
()
const
{
return
shrinkage_
;
}
return
shrinkage_
;
}
inline
void
AddBias
(
double
val
)
{
inline
void
AddBias
(
double
val
)
{
#pragma omp parallel for schedule(static, 1024) if (num_leaves_ >= 2048)
#pragma omp parallel for schedule(static, 1024) if (num_leaves_ >= 2048)
for
(
int
i
=
0
;
i
<
num_leaves_
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num_leaves_
-
1
;
++
i
)
{
double
new_leaf_value
=
val
+
leaf_value_
[
i
];
leaf_value_
[
i
]
=
MaybeRoundToZero
(
leaf_value_
[
i
]
+
val
);
// Prevent denormal values because they can cause std::out_of_range exception when converting strings to doubles
internal_value_
[
i
]
=
MaybeRoundToZero
(
internal_value_
[
i
]
+
val
);
if
(
IsZero
(
new_leaf_value
))
{
leaf_value_
[
i
]
=
0
;
}
else
{
leaf_value_
[
i
]
=
new_leaf_value
;
}
}
}
leaf_value_
[
num_leaves_
-
1
]
=
MaybeRoundToZero
(
leaf_value_
[
num_leaves_
-
1
]
+
val
);
// force to 1.0
// force to 1.0
shrinkage_
=
1.0
f
;
shrinkage_
=
1.0
f
;
}
}
...
@@ -217,6 +205,14 @@ class Tree {
...
@@ -217,6 +205,14 @@ class Tree {
}
}
}
}
inline
static
double
MaybeRoundToZero
(
double
fval
)
{
if
(
fval
>
-
kZeroThreshold
&&
fval
<=
kZeroThreshold
)
{
return
0
;
}
else
{
return
fval
;
}
}
inline
static
bool
GetDecisionType
(
int8_t
decision_type
,
int8_t
mask
)
{
inline
static
bool
GetDecisionType
(
int8_t
decision_type
,
int8_t
mask
)
{
return
(
decision_type
&
mask
)
>
0
;
return
(
decision_type
&
mask
)
>
0
;
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment