Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
84fef715
Commit
84fef715
authored
Apr 24, 2018
by
Jerry Liu
Committed by
Guolin Ke
Apr 24, 2018
Browse files
add force_split functionality (#1310)
parent
71539cc2
Changes
20
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1428 additions
and
22 deletions
+1428
-22
CMakeLists.txt
CMakeLists.txt
+1
-1
docs/Parameters.rst
docs/Parameters.rst
+10
-0
examples/binary_classification/forced_splits.json
examples/binary_classification/forced_splits.json
+12
-0
examples/binary_classification/train.conf
examples/binary_classification/train.conf
+3
-0
include/LightGBM/config.h
include/LightGBM/config.h
+6
-1
include/LightGBM/dataset.h
include/LightGBM/dataset.h
+7
-0
include/LightGBM/json11.hpp
include/LightGBM/json11.hpp
+232
-0
include/LightGBM/tree_learner.h
include/LightGBM/tree_learner.h
+5
-1
src/boosting/dart.hpp
src/boosting/dart.hpp
+2
-1
src/boosting/gbdt.cpp
src/boosting/gbdt.cpp
+11
-2
src/boosting/gbdt.h
src/boosting/gbdt.h
+7
-1
src/boosting/rf.hpp
src/boosting/rf.hpp
+2
-1
src/io/config.cpp
src/io/config.cpp
+1
-0
src/io/json11.cpp
src/io/json11.cpp
+788
-0
src/treelearner/feature_histogram.hpp
src/treelearner/feature_histogram.hpp
+151
-5
src/treelearner/gpu_tree_learner.cpp
src/treelearner/gpu_tree_learner.cpp
+3
-2
src/treelearner/gpu_tree_learner.h
src/treelearner/gpu_tree_learner.h
+3
-1
src/treelearner/serial_tree_learner.cpp
src/treelearner/serial_tree_learner.cpp
+171
-3
src/treelearner/serial_tree_learner.h
src/treelearner/serial_tree_learner.h
+10
-1
src/treelearner/voting_parallel_tree_learner.cpp
src/treelearner/voting_parallel_tree_learner.cpp
+3
-2
No files found.
CMakeLists.txt
View file @
84fef715
...
...
@@ -144,7 +144,7 @@ if(USE_MPI)
include_directories
(
${
MPI_CXX_INCLUDE_PATH
}
)
endif
(
USE_MPI
)
file
(
GLOB SOURCES
file
(
GLOB SOURCES
src/application/*.cpp
src/boosting/*.cpp
src/io/*.cpp
...
...
docs/Parameters.rst
View file @
84fef715
...
...
@@ -520,6 +520,16 @@ IO Parameters
- separate by ``,`` for multi-validation data
- ``forced_splits``, default=\ ``""``, type=string
- path to a ``.json`` file that specifies splits to force at the top of every decision tree before best-first learning commences.
- ``.json`` file can be arbitrarily nested, and each split contains ``feature``, ``threshold`` fields, as well as ``left`` and ``right``
fields representing subsplits. Categorical splits are forced in a one-hot fashion, with ``left`` representing the split containing
the feature value and ``right`` representing other values.
- see ``examples/binary_classification/forced_splits.json`` as an example.
Objective Parameters
--------------------
...
...
examples/binary_classification/forced_splits.json
0 → 100644
View file @
84fef715
{
"feature"
:
25
,
"threshold"
:
1.30
,
"left"
:
{
"feature"
:
26
,
"threshold"
:
0.85
},
"right"
:
{
"feature"
:
26
,
"threshold"
:
0.85
}
}
examples/binary_classification/train.conf
View file @
84fef715
...
...
@@ -109,3 +109,6 @@ local_listen_port = 12400
# machines list file for parallel training, alias: mlist
machine_list_file
=
mlist
.
txt
# # force splits
# forced_splits = forced_splits.json
include/LightGBM/config.h
View file @
84fef715
...
...
@@ -105,6 +105,7 @@ public:
std
::
string
output_result
=
"LightGBM_predict_result.txt"
;
std
::
string
convert_model
=
"gbdt_prediction.cpp"
;
std
::
string
input_model
=
""
;
int
verbosity
=
1
;
int
num_iteration_predict
=
-
1
;
bool
is_pre_partition
=
false
;
...
...
@@ -264,6 +265,9 @@ public:
std
::
string
device_type
=
kDefaultDevice
;
TreeConfig
tree_config
;
LIGHTGBM_EXPORT
void
Set
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
override
;
/* filename of forced splits */
std
::
string
forcedsplits_filename
=
""
;
};
/*! \brief Config for Network */
...
...
@@ -482,7 +486,8 @@ struct ParameterAlias {
"histogram_pool_size"
,
"is_provide_training_metric"
,
"machine_list_filename"
,
"machines"
,
"zero_as_missing"
,
"init_score_file"
,
"valid_init_score_file"
,
"is_predict_contrib"
,
"max_cat_threshold"
,
"cat_smooth"
,
"min_data_per_group"
,
"cat_l2"
,
"max_cat_to_onehot"
,
"alpha"
,
"reg_sqrt"
,
"tweedie_variance_power"
,
"monotone_constraints"
,
"max_delta_step"
"alpha"
,
"reg_sqrt"
,
"tweedie_variance_power"
,
"monotone_constraints"
,
"max_delta_step"
,
"forced_splits"
});
std
::
unordered_map
<
std
::
string
,
std
::
string
>
tmp_map
;
for
(
const
auto
&
pair
:
*
params
)
{
...
...
include/LightGBM/dataset.h
View file @
84fef715
...
...
@@ -495,6 +495,13 @@ public:
return
feature_groups_
[
group
]
->
bin_mappers_
[
sub_feature
]
->
BinToValue
(
threshold
);
}
// given a real threshold, find the closest threshold bin
inline
uint32_t
BinThreshold
(
int
i
,
double
threshold_double
)
const
{
const
int
group
=
feature2group_
[
i
];
const
int
sub_feature
=
feature2subfeature_
[
i
];
return
feature_groups_
[
group
]
->
bin_mappers_
[
sub_feature
]
->
ValueToBin
(
threshold_double
);
}
inline
void
CreateOrderedBins
(
std
::
vector
<
std
::
unique_ptr
<
OrderedBin
>>*
ordered_bins
)
const
{
ordered_bins
->
resize
(
num_groups_
);
OMP_INIT_EX
();
...
...
include/LightGBM/json11.hpp
0 → 100644
View file @
84fef715
/* json11
*
* json11 is a tiny JSON library for C++11, providing JSON parsing and serialization.
*
* The core object provided by the library is json11::Json. A Json object represents any JSON
* value: null, bool, number (int or double), string (std::string), array (std::vector), or
* object (std::map).
*
* Json objects act like values: they can be assigned, copied, moved, compared for equality or
* order, etc. There are also helper methods Json::dump, to serialize a Json to a string, and
* Json::parse (static) to parse a std::string as a Json object.
*
* Internally, the various types of Json object are represented by the JsonValue class
* hierarchy.
*
* A note on numbers - JSON specifies the syntax of number formatting but not its semantics,
* so some JSON implementations distinguish between integers and floating-point numbers, while
* some don't. In json11, we choose the latter. Because some JSON implementations (namely
* Javascript itself) treat all numbers as the same type, distinguishing the two leads
* to JSON that will be *silently* changed by a round-trip through those implementations.
* Dangerous! To avoid that risk, json11 stores all numbers as double internally, but also
* provides integer helpers.
*
* Fortunately, double-precision IEEE754 ('double') can precisely store any integer in the
* range +/-2^53, which includes every 'int' on most systems. (Timestamps often use int64
* or long long to avoid the Y2038K problem; a double storing microseconds since some epoch
* will be exact for +/- 275 years.)
*/
/* Copyright (c) 2013 Dropbox, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#pragma once
#include <string>
#include <vector>
#include <map>
#include <memory>
#include <initializer_list>
#ifdef _MSC_VER
#if _MSC_VER <= 1800 // VS 2013
#ifndef noexcept
#define noexcept throw()
#endif
#ifndef snprintf
#define snprintf _snprintf_s
#endif
#endif
#endif
namespace
json11
{
enum
JsonParse
{
STANDARD
,
COMMENTS
};
class
JsonValue
;
class
Json
final
{
public:
// Types
enum
Type
{
NUL
,
NUMBER
,
BOOL
,
STRING
,
ARRAY
,
OBJECT
};
// Array and object typedefs
typedef
std
::
vector
<
Json
>
array
;
typedef
std
::
map
<
std
::
string
,
Json
>
object
;
// Constructors for the various types of JSON value.
Json
()
noexcept
;
// NUL
Json
(
std
::
nullptr_t
)
noexcept
;
// NUL
Json
(
double
value
);
// NUMBER
Json
(
int
value
);
// NUMBER
Json
(
bool
value
);
// BOOL
Json
(
const
std
::
string
&
value
);
// STRING
Json
(
std
::
string
&&
value
);
// STRING
Json
(
const
char
*
value
);
// STRING
Json
(
const
array
&
values
);
// ARRAY
Json
(
array
&&
values
);
// ARRAY
Json
(
const
object
&
values
);
// OBJECT
Json
(
object
&&
values
);
// OBJECT
// Implicit constructor: anything with a to_json() function.
template
<
class
T
,
class
=
decltype
(
&
T
::
to_json
)>
Json
(
const
T
&
t
)
:
Json
(
t
.
to_json
())
{}
// Implicit constructor: map-like objects (std::map, std::unordered_map, etc)
template
<
class
M
,
typename
std
::
enable_if
<
std
::
is_constructible
<
std
::
string
,
decltype
(
std
::
declval
<
M
>().
begin
()
->
first
)
>::
value
&&
std
::
is_constructible
<
Json
,
decltype
(
std
::
declval
<
M
>
().
begin
()
->
second
)
>::
value
,
int
>::
type
=
0
>
Json
(
const
M
&
m
)
:
Json
(
object
(
m
.
begin
(),
m
.
end
()))
{}
// Implicit constructor: vector-like objects (std::list, std::vector, std::set, etc)
template
<
class
V
,
typename
std
::
enable_if
<
std
::
is_constructible
<
Json
,
decltype
(
*
std
::
declval
<
V
>().
begin
())
>::
value
,
int
>::
type
=
0
>
Json
(
const
V
&
v
)
:
Json
(
array
(
v
.
begin
(),
v
.
end
()))
{}
// This prevents Json(some_pointer) from accidentally producing a bool. Use
// Json(bool(some_pointer)) if that behavior is desired.
Json
(
void
*
)
=
delete
;
// Accessors
Type
type
()
const
;
bool
is_null
()
const
{
return
type
()
==
NUL
;
}
bool
is_number
()
const
{
return
type
()
==
NUMBER
;
}
bool
is_bool
()
const
{
return
type
()
==
BOOL
;
}
bool
is_string
()
const
{
return
type
()
==
STRING
;
}
bool
is_array
()
const
{
return
type
()
==
ARRAY
;
}
bool
is_object
()
const
{
return
type
()
==
OBJECT
;
}
// Return the enclosed value if this is a number, 0 otherwise. Note that json11 does not
// distinguish between integer and non-integer numbers - number_value() and int_value()
// can both be applied to a NUMBER-typed object.
double
number_value
()
const
;
int
int_value
()
const
;
// Return the enclosed value if this is a boolean, false otherwise.
bool
bool_value
()
const
;
// Return the enclosed string if this is a string, "" otherwise.
const
std
::
string
&
string_value
()
const
;
// Return the enclosed std::vector if this is an array, or an empty vector otherwise.
const
array
&
array_items
()
const
;
// Return the enclosed std::map if this is an object, or an empty map otherwise.
const
object
&
object_items
()
const
;
// Return a reference to arr[i] if this is an array, Json() otherwise.
const
Json
&
operator
[](
size_t
i
)
const
;
// Return a reference to obj[key] if this is an object, Json() otherwise.
const
Json
&
operator
[](
const
std
::
string
&
key
)
const
;
// Serialize.
void
dump
(
std
::
string
&
out
)
const
;
std
::
string
dump
()
const
{
std
::
string
out
;
dump
(
out
);
return
out
;
}
// Parse. If parse fails, return Json() and assign an error message to err.
static
Json
parse
(
const
std
::
string
&
in
,
std
::
string
&
err
,
JsonParse
strategy
=
JsonParse
::
STANDARD
);
static
Json
parse
(
const
char
*
in
,
std
::
string
&
err
,
JsonParse
strategy
=
JsonParse
::
STANDARD
)
{
if
(
in
)
{
return
parse
(
std
::
string
(
in
),
err
,
strategy
);
}
else
{
err
=
"null input"
;
return
nullptr
;
}
}
// Parse multiple objects, concatenated or separated by whitespace
static
std
::
vector
<
Json
>
parse_multi
(
const
std
::
string
&
in
,
std
::
string
::
size_type
&
parser_stop_pos
,
std
::
string
&
err
,
JsonParse
strategy
=
JsonParse
::
STANDARD
);
static
inline
std
::
vector
<
Json
>
parse_multi
(
const
std
::
string
&
in
,
std
::
string
&
err
,
JsonParse
strategy
=
JsonParse
::
STANDARD
)
{
std
::
string
::
size_type
parser_stop_pos
;
return
parse_multi
(
in
,
parser_stop_pos
,
err
,
strategy
);
}
bool
operator
==
(
const
Json
&
rhs
)
const
;
bool
operator
<
(
const
Json
&
rhs
)
const
;
bool
operator
!=
(
const
Json
&
rhs
)
const
{
return
!
(
*
this
==
rhs
);
}
bool
operator
<=
(
const
Json
&
rhs
)
const
{
return
!
(
rhs
<
*
this
);
}
bool
operator
>
(
const
Json
&
rhs
)
const
{
return
(
rhs
<
*
this
);
}
bool
operator
>=
(
const
Json
&
rhs
)
const
{
return
!
(
*
this
<
rhs
);
}
/* has_shape(types, err)
*
* Return true if this is a JSON object and, for each item in types, has a field of
* the given type. If not, return false and set err to a descriptive message.
*/
typedef
std
::
initializer_list
<
std
::
pair
<
std
::
string
,
Type
>>
shape
;
bool
has_shape
(
const
shape
&
types
,
std
::
string
&
err
)
const
;
private:
std
::
shared_ptr
<
JsonValue
>
m_ptr
;
};
// Internal class hierarchy - JsonValue objects are not exposed to users of this API.
class
JsonValue
{
protected:
friend
class
Json
;
friend
class
JsonInt
;
friend
class
JsonDouble
;
virtual
Json
::
Type
type
()
const
=
0
;
virtual
bool
equals
(
const
JsonValue
*
other
)
const
=
0
;
virtual
bool
less
(
const
JsonValue
*
other
)
const
=
0
;
virtual
void
dump
(
std
::
string
&
out
)
const
=
0
;
virtual
double
number_value
()
const
;
virtual
int
int_value
()
const
;
virtual
bool
bool_value
()
const
;
virtual
const
std
::
string
&
string_value
()
const
;
virtual
const
Json
::
array
&
array_items
()
const
;
virtual
const
Json
&
operator
[](
size_t
i
)
const
;
virtual
const
Json
::
object
&
object_items
()
const
;
virtual
const
Json
&
operator
[](
const
std
::
string
&
key
)
const
;
virtual
~
JsonValue
()
{}
};
}
// namespace json11
include/LightGBM/tree_learner.h
View file @
84fef715
...
...
@@ -4,9 +4,12 @@
#include <LightGBM/meta.h>
#include <LightGBM/config.h>
#include <LightGBM/json11.hpp>
#include <vector>
using
namespace
json11
;
namespace
LightGBM
{
/*! \brief forward declaration */
...
...
@@ -44,7 +47,8 @@ public:
* \param is_constant_hessian True if all hessians share the same value
* \return A trained tree
*/
virtual
Tree
*
Train
(
const
score_t
*
gradients
,
const
score_t
*
hessians
,
bool
is_constant_hessian
)
=
0
;
virtual
Tree
*
Train
(
const
score_t
*
gradients
,
const
score_t
*
hessians
,
bool
is_constant_hessian
,
Json
&
forced_split_json
)
=
0
;
/*!
* \brief use a existing tree to fit the new gradients and hessians.
...
...
src/boosting/dart.hpp
View file @
84fef715
...
...
@@ -32,7 +32,8 @@ public:
* \param training_metrics Training metrics
* \param output_model_filename Filename of output model
*/
void
Init
(
const
BoostingConfig
*
config
,
const
Dataset
*
train_data
,
const
ObjectiveFunction
*
objective_function
,
void
Init
(
const
BoostingConfig
*
config
,
const
Dataset
*
train_data
,
const
ObjectiveFunction
*
objective_function
,
const
std
::
vector
<
const
Metric
*>&
training_metrics
)
override
{
GBDT
::
Init
(
config
,
train_data
,
objective_function
,
training_metrics
);
random_for_drop_
=
Random
(
gbdt_config_
->
drop_seed
);
...
...
src/boosting/gbdt.cpp
View file @
84fef715
...
...
@@ -3,7 +3,6 @@
#include <LightGBM/utils/openmp_wrapper.h>
#include <LightGBM/utils/common.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/prediction_early_stop.h>
...
...
@@ -75,6 +74,16 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
early_stopping_round_
=
gbdt_config_
->
early_stopping_round
;
shrinkage_rate_
=
gbdt_config_
->
learning_rate
;
std
::
string
forced_splits_path
=
config
->
forcedsplits_filename
;
//load forced_splits file
if
(
forced_splits_path
!=
""
)
{
std
::
ifstream
forced_splits_file
(
forced_splits_path
.
c_str
());
std
::
stringstream
buffer
;
buffer
<<
forced_splits_file
.
rdbuf
();
std
::
string
err
;
forced_splits_json_
=
Json
::
parse
(
buffer
.
str
(),
err
);
}
objective_function_
=
objective_function
;
num_tree_per_iteration_
=
num_class_
;
if
(
objective_function_
!=
nullptr
)
{
...
...
@@ -425,7 +434,7 @@ bool GBDT::TrainOneIter(const score_t* gradients, const score_t* hessians) {
hess
=
hessians_
.
data
()
+
bias
;
}
new_tree
.
reset
(
tree_learner_
->
Train
(
grad
,
hess
,
is_constant_hessian_
));
new_tree
.
reset
(
tree_learner_
->
Train
(
grad
,
hess
,
is_constant_hessian_
,
forced_splits_json_
));
}
#ifdef TIMETAG
...
...
src/boosting/gbdt.h
View file @
84fef715
...
...
@@ -4,6 +4,7 @@
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/prediction_early_stop.h>
#include <LightGBM/json11.hpp>
#include "score_updater.hpp"
...
...
@@ -15,6 +16,8 @@
#include <mutex>
#include <map>
using
namespace
json11
;
namespace
LightGBM
{
/*!
...
...
@@ -40,7 +43,8 @@ public:
* \param objective_function Training objective function
* \param training_metrics Training metrics
*/
void
Init
(
const
BoostingConfig
*
gbdt_config
,
const
Dataset
*
train_data
,
const
ObjectiveFunction
*
objective_function
,
void
Init
(
const
BoostingConfig
*
gbdt_config
,
const
Dataset
*
train_data
,
const
ObjectiveFunction
*
objective_function
,
const
std
::
vector
<
const
Metric
*>&
training_metrics
)
override
;
/*!
...
...
@@ -452,6 +456,8 @@ protected:
std
::
unique_ptr
<
ObjectiveFunction
>
loaded_objective_
;
bool
average_output_
;
bool
need_re_bagging_
;
Json
forced_splits_json_
;
};
}
// namespace LightGBM
...
...
src/boosting/rf.hpp
View file @
84fef715
...
...
@@ -112,7 +112,8 @@ public:
hess
=
tmp_hess_
.
data
()
+
bias
;
}
new_tree
.
reset
(
tree_learner_
->
Train
(
grad
,
hess
,
is_constant_hessian_
));
new_tree
.
reset
(
tree_learner_
->
Train
(
grad
,
hess
,
is_constant_hessian_
,
forced_splits_json_
));
}
if
(
new_tree
->
num_leaves
()
>
1
)
{
...
...
src/io/config.cpp
View file @
84fef715
...
...
@@ -466,6 +466,7 @@ void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& par
GetBool
(
params
,
"boost_from_average"
,
&
boost_from_average
);
GetDeviceType
(
params
,
&
device_type
);
GetTreeLearnerType
(
params
,
&
tree_learner_type
);
GetString
(
params
,
"forced_splits"
,
&
forcedsplits_filename
);
tree_config
.
Set
(
params
);
}
...
...
src/io/json11.cpp
0 → 100644
View file @
84fef715
This diff is collapsed.
Click to expand it.
src/treelearner/feature_histogram.hpp
View file @
84fef715
...
...
@@ -7,6 +7,7 @@
#include <LightGBM/dataset.h>
#include <cstring>
#include <cmath>
namespace
LightGBM
{
...
...
@@ -20,6 +21,7 @@ public:
int8_t
monotone_type
;
/*! \brief pointer of tree config */
const
TreeConfig
*
tree_config
;
BinType
bin_type
;
};
/*!
* \brief FeatureHistogram is used to construct and store a histogram for a feature.
...
...
@@ -43,10 +45,10 @@ public:
* \param feature the feature data for this histogram
* \param min_num_data_one_leaf minimal number of data in one leaf
*/
void
Init
(
HistogramBinEntry
*
data
,
const
FeatureMetainfo
*
meta
,
BinType
bin_type
)
{
void
Init
(
HistogramBinEntry
*
data
,
const
FeatureMetainfo
*
meta
)
{
meta_
=
meta
;
data_
=
data
;
if
(
bin_type
==
BinType
::
NumericalBin
)
{
if
(
meta_
->
bin_type
==
BinType
::
NumericalBin
)
{
find_best_threshold_fun_
=
std
::
bind
(
&
FeatureHistogram
::
FindBestThresholdNumerical
,
this
,
std
::
placeholders
::
_1
,
std
::
placeholders
::
_2
,
std
::
placeholders
::
_3
,
std
::
placeholders
::
_4
,
std
::
placeholders
::
_5
,
std
::
placeholders
::
_6
);
}
else
{
...
...
@@ -105,7 +107,8 @@ public:
output
->
max_constraint
=
max_constraint
;
}
void
FindBestThresholdCategorical
(
double
sum_gradient
,
double
sum_hessian
,
data_size_t
num_data
,
double
min_constraint
,
double
max_constraint
,
void
FindBestThresholdCategorical
(
double
sum_gradient
,
double
sum_hessian
,
data_size_t
num_data
,
double
min_constraint
,
double
max_constraint
,
SplitInfo
*
output
)
{
output
->
default_left
=
false
;
double
best_gain
=
kMinScore
;
...
...
@@ -267,6 +270,149 @@ public:
}
}
void
GatherInfoForThreshold
(
double
sum_gradient
,
double
sum_hessian
,
uint32_t
threshold
,
data_size_t
num_data
,
SplitInfo
*
output
)
{
if
(
meta_
->
bin_type
==
BinType
::
NumericalBin
)
{
GatherInfoForThresholdNumerical
(
sum_gradient
,
sum_hessian
,
threshold
,
num_data
,
output
);
}
else
{
GatherInfoForThresholdCategorical
(
sum_gradient
,
sum_hessian
,
threshold
,
num_data
,
output
);
}
}
void
GatherInfoForThresholdNumerical
(
double
sum_gradient
,
double
sum_hessian
,
uint32_t
threshold
,
data_size_t
num_data
,
SplitInfo
*
output
)
{
double
gain_shift
=
GetLeafSplitGain
(
sum_gradient
,
sum_hessian
,
meta_
->
tree_config
->
lambda_l1
,
meta_
->
tree_config
->
lambda_l2
,
meta_
->
tree_config
->
max_delta_step
);
double
min_gain_shift
=
gain_shift
+
meta_
->
tree_config
->
min_gain_to_split
;
// do stuff here
const
int8_t
bias
=
meta_
->
bias
;
double
sum_right_gradient
=
0.0
f
;
double
sum_right_hessian
=
kEpsilon
;
data_size_t
right_count
=
0
;
// set values
bool
use_na_as_missing
;
bool
skip_default_bin
;
if
(
meta_
->
missing_type
==
MissingType
::
Zero
)
{
skip_default_bin
=
true
;
use_na_as_missing
=
false
;
}
else
{
skip_default_bin
=
false
;
use_na_as_missing
=
true
;
}
int
t
=
meta_
->
num_bin
-
1
-
bias
-
use_na_as_missing
;
const
int
t_end
=
1
-
bias
;
// from right to left, and we don't need data in bin0
for
(;
t
>=
t_end
;
--
t
)
{
if
(
static_cast
<
uint32_t
>
(
t
+
bias
)
<
threshold
)
{
break
;
}
// need to skip default bin
if
(
skip_default_bin
&&
(
t
+
bias
)
==
static_cast
<
int
>
(
meta_
->
default_bin
))
{
continue
;
}
sum_right_gradient
+=
data_
[
t
].
sum_gradients
;
sum_right_hessian
+=
data_
[
t
].
sum_hessians
;
right_count
+=
data_
[
t
].
cnt
;
}
double
sum_left_gradient
=
sum_gradient
-
sum_right_gradient
;
double
sum_left_hessian
=
sum_hessian
-
sum_right_hessian
;
data_size_t
left_count
=
num_data
-
right_count
;
double
current_gain
=
GetLeafSplitGain
(
sum_left_gradient
,
sum_left_hessian
,
meta_
->
tree_config
->
lambda_l1
,
meta_
->
tree_config
->
lambda_l2
,
meta_
->
tree_config
->
max_delta_step
)
+
GetLeafSplitGain
(
sum_right_gradient
,
sum_right_hessian
,
meta_
->
tree_config
->
lambda_l1
,
meta_
->
tree_config
->
lambda_l2
,
meta_
->
tree_config
->
max_delta_step
);
// gain with split is worse than without split
if
(
std
::
isnan
(
current_gain
)
||
current_gain
<=
min_gain_shift
)
{
output
->
gain
=
kMinScore
;
Log
::
Warning
(
"Gain with forced split worse than without split"
);
return
;
};
// update split information
output
->
threshold
=
threshold
;
output
->
left_output
=
CalculateSplittedLeafOutput
(
sum_left_gradient
,
sum_left_hessian
,
meta_
->
tree_config
->
lambda_l1
,
meta_
->
tree_config
->
lambda_l2
,
meta_
->
tree_config
->
max_delta_step
);
output
->
left_count
=
left_count
;
output
->
left_sum_gradient
=
sum_left_gradient
;
output
->
left_sum_hessian
=
sum_left_hessian
-
kEpsilon
;
output
->
right_output
=
CalculateSplittedLeafOutput
(
sum_gradient
-
sum_left_gradient
,
sum_hessian
-
sum_left_hessian
,
meta_
->
tree_config
->
lambda_l1
,
meta_
->
tree_config
->
lambda_l2
,
meta_
->
tree_config
->
max_delta_step
);
output
->
right_count
=
num_data
-
left_count
;
output
->
right_sum_gradient
=
sum_gradient
-
sum_left_gradient
;
output
->
right_sum_hessian
=
sum_hessian
-
sum_left_hessian
-
kEpsilon
;
output
->
gain
=
current_gain
;
output
->
gain
-=
min_gain_shift
;
output
->
default_left
=
true
;
}
void
GatherInfoForThresholdCategorical
(
double
sum_gradient
,
double
sum_hessian
,
uint32_t
threshold
,
data_size_t
num_data
,
SplitInfo
*
output
)
{
// get SplitInfo for a given one-hot categorical split.
output
->
default_left
=
false
;
double
gain_shift
=
GetLeafSplitGain
(
sum_gradient
,
sum_hessian
,
meta_
->
tree_config
->
lambda_l1
,
meta_
->
tree_config
->
lambda_l2
,
meta_
->
tree_config
->
max_delta_step
);
double
min_gain_shift
=
gain_shift
+
meta_
->
tree_config
->
min_gain_to_split
;
bool
is_full_categorical
=
meta_
->
missing_type
==
MissingType
::
None
;
int
used_bin
=
meta_
->
num_bin
-
1
+
is_full_categorical
;
if
(
threshold
>=
static_cast
<
uint32_t
>
(
used_bin
))
{
output
->
gain
=
kMinScore
;
Log
::
Warning
(
"Invalid categorical threshold split"
);
return
;
}
double
l2
=
meta_
->
tree_config
->
lambda_l2
;
data_size_t
left_count
=
data_
[
threshold
].
cnt
;
data_size_t
right_count
=
num_data
-
left_count
;
double
sum_left_hessian
=
data_
[
threshold
].
sum_hessians
+
kEpsilon
;
double
sum_right_hessian
=
sum_hessian
-
sum_left_hessian
;
double
sum_left_gradient
=
data_
[
threshold
].
sum_gradients
;
double
sum_right_gradient
=
sum_gradient
-
sum_left_gradient
;
// current split gain
double
current_gain
=
GetLeafSplitGain
(
sum_right_gradient
,
sum_right_hessian
,
meta_
->
tree_config
->
lambda_l1
,
l2
,
meta_
->
tree_config
->
max_delta_step
)
+
GetLeafSplitGain
(
sum_left_gradient
,
sum_right_hessian
,
meta_
->
tree_config
->
lambda_l1
,
l2
,
meta_
->
tree_config
->
max_delta_step
);
if
(
std
::
isnan
(
current_gain
)
||
current_gain
<=
min_gain_shift
)
{
output
->
gain
=
kMinScore
;
Log
::
Warning
(
"Gain with forced split worse than without split"
);
return
;
}
output
->
left_output
=
CalculateSplittedLeafOutput
(
sum_left_gradient
,
sum_left_hessian
,
meta_
->
tree_config
->
lambda_l1
,
l2
,
meta_
->
tree_config
->
max_delta_step
);
output
->
left_count
=
left_count
;
output
->
left_sum_gradient
=
sum_left_gradient
;
output
->
left_sum_hessian
=
sum_left_hessian
-
kEpsilon
;
output
->
right_output
=
CalculateSplittedLeafOutput
(
sum_right_gradient
,
sum_right_hessian
,
meta_
->
tree_config
->
lambda_l1
,
l2
,
meta_
->
tree_config
->
max_delta_step
);
output
->
right_count
=
right_count
;
output
->
right_sum_gradient
=
sum_gradient
-
sum_left_gradient
;
output
->
right_sum_hessian
=
sum_right_hessian
-
kEpsilon
;
output
->
gain
=
current_gain
-
min_gain_shift
;
output
->
num_cat_threshold
=
1
;
output
->
cat_threshold
=
std
::
vector
<
uint32_t
>
(
1
,
threshold
);
}
/*!
* \brief Binary size of this histogram
*/
...
...
@@ -500,7 +646,6 @@ private:
/*! \brief sum of gradient of each bin */
HistogramBinEntry
*
data_
;
//std::vector<HistogramBinEntry> data_;
/*! \brief False if this histogram cannot split */
bool
is_splittable_
=
true
;
std
::
function
<
void
(
double
,
double
,
data_size_t
,
double
,
double
,
SplitInfo
*
)
>
find_best_threshold_fun_
;
...
...
@@ -568,6 +713,7 @@ public:
feature_metas_
[
i
].
bias
=
0
;
}
feature_metas_
[
i
].
tree_config
=
tree_config
;
feature_metas_
[
i
].
bin_type
=
train_data
->
FeatureBinMapper
(
i
)
->
bin_type
();
}
}
uint64_t
num_total_bin
=
train_data
->
NumTotalBin
();
...
...
@@ -589,7 +735,7 @@ public:
uint64_t
offset
=
0
;
for
(
int
j
=
0
;
j
<
train_data
->
num_features
();
++
j
)
{
offset
+=
static_cast
<
uint64_t
>
(
train_data
->
SubFeatureBinOffset
(
j
));
pool_
[
i
][
j
].
Init
(
data_
[
i
].
data
()
+
offset
,
&
feature_metas_
[
j
]
,
train_data
->
FeatureBinMapper
(
j
)
->
bin_type
()
);
pool_
[
i
][
j
].
Init
(
data_
[
i
].
data
()
+
offset
,
&
feature_metas_
[
j
]);
auto
num_bin
=
train_data
->
FeatureNumBin
(
j
);
if
(
train_data
->
FeatureBinMapper
(
j
)
->
GetDefaultBin
()
==
0
)
{
num_bin
-=
1
;
...
...
src/treelearner/gpu_tree_learner.cpp
View file @
84fef715
...
...
@@ -751,7 +751,8 @@ void GPUTreeLearner::InitGPU(int platform_id, int device_id) {
SetupKernelArguments
();
}
Tree
*
GPUTreeLearner
::
Train
(
const
score_t
*
gradients
,
const
score_t
*
hessians
,
bool
is_constant_hessian
)
{
Tree
*
GPUTreeLearner
::
Train
(
const
score_t
*
gradients
,
const
score_t
*
hessians
,
bool
is_constant_hessian
,
Json
&
forced_split_json
)
{
// check if we need to recompile the GPU kernel (is_constant_hessian changed)
// this should rarely occur
if
(
is_constant_hessian
!=
is_constant_hessian_
)
{
...
...
@@ -760,7 +761,7 @@ Tree* GPUTreeLearner::Train(const score_t* gradients, const score_t *hessians, b
BuildGPUKernels
();
SetupKernelArguments
();
}
return
SerialTreeLearner
::
Train
(
gradients
,
hessians
,
is_constant_hessian
);
return
SerialTreeLearner
::
Train
(
gradients
,
hessians
,
is_constant_hessian
,
forced_split_json
);
}
void
GPUTreeLearner
::
ResetTrainingData
(
const
Dataset
*
train_data
)
{
...
...
src/treelearner/gpu_tree_learner.h
View file @
84fef715
...
...
@@ -28,6 +28,7 @@
#include <boost/compute/container/vector.hpp>
#include <boost/align/aligned_allocator.hpp>
using
namespace
json11
;
namespace
LightGBM
{
...
...
@@ -40,7 +41,8 @@ public:
~
GPUTreeLearner
();
void
Init
(
const
Dataset
*
train_data
,
bool
is_constant_hessian
)
override
;
void
ResetTrainingData
(
const
Dataset
*
train_data
)
override
;
Tree
*
Train
(
const
score_t
*
gradients
,
const
score_t
*
hessians
,
bool
is_constant_hessian
)
override
;
Tree
*
Train
(
const
score_t
*
gradients
,
const
score_t
*
hessians
,
bool
is_constant_hessian
,
Json
&
forced_split_json
)
override
;
void
SetBaggingData
(
const
data_size_t
*
used_indices
,
data_size_t
num_data
)
override
{
SerialTreeLearner
::
SetBaggingData
(
used_indices
,
num_data
);
...
...
src/treelearner/serial_tree_learner.cpp
View file @
84fef715
...
...
@@ -6,6 +6,7 @@
#include <algorithm>
#include <vector>
#include <queue>
namespace
LightGBM
{
...
...
@@ -152,7 +153,7 @@ void SerialTreeLearner::ResetConfig(const TreeConfig* tree_config) {
histogram_pool_
.
ResetConfig
(
tree_config_
);
}
Tree
*
SerialTreeLearner
::
Train
(
const
score_t
*
gradients
,
const
score_t
*
hessians
,
bool
is_constant_hessian
)
{
Tree
*
SerialTreeLearner
::
Train
(
const
score_t
*
gradients
,
const
score_t
*
hessians
,
bool
is_constant_hessian
,
Json
&
forced_split_json
)
{
gradients_
=
gradients
;
hessians_
=
hessians
;
is_constant_hessian_
=
is_constant_hessian
;
...
...
@@ -172,18 +173,29 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians
int
cur_depth
=
1
;
// only root leaf can be splitted on first time
int
right_leaf
=
-
1
;
for
(
int
split
=
0
;
split
<
tree_config_
->
num_leaves
-
1
;
++
split
)
{
int
init_splits
=
0
;
bool
aborted_last_force_split
=
false
;
if
(
!
forced_split_json
.
is_null
())
{
init_splits
=
ForceSplits
(
tree
.
get
(),
forced_split_json
,
&
left_leaf
,
&
right_leaf
,
&
cur_depth
,
&
aborted_last_force_split
);
}
for
(
int
split
=
init_splits
;
split
<
tree_config_
->
num_leaves
-
1
;
++
split
)
{
#ifdef TIMETAG
start_time
=
std
::
chrono
::
steady_clock
::
now
();
#endif
// some initial works before finding best split
if
(
BeforeFindBestSplit
(
tree
.
get
(),
left_leaf
,
right_leaf
))
{
if
(
!
aborted_last_force_split
&&
BeforeFindBestSplit
(
tree
.
get
(),
left_leaf
,
right_leaf
))
{
#ifdef TIMETAG
init_split_time
+=
std
::
chrono
::
steady_clock
::
now
()
-
start_time
;
#endif
// find best threshold for every feature
FindBestSplits
();
}
else
if
(
aborted_last_force_split
)
{
aborted_last_force_split
=
false
;
}
// Get a leaf with max split gain
int
best_leaf
=
static_cast
<
int
>
(
ArrayArgs
<
SplitInfo
>::
ArgMax
(
best_split_per_leaf_
));
// Get split information for best leaf
...
...
@@ -528,6 +540,162 @@ void SerialTreeLearner::FindBestSplitsFromHistograms(const std::vector<int8_t>&
#endif
}
/*!
 * \brief Apply user-forced splits (from the forced_splits JSON config) to the tree
 *        before normal best-gain splitting begins.
 *
 * Walks the forced-split JSON in BFS order: each JSON node holds a "feature" and
 * "threshold", with optional "left"/"right" child objects describing further forced
 * splits of the resulting leaves.
 *
 * \param tree                     Tree being grown; forced splits are applied to it.
 * \param forced_split_json        Root of the forced-split JSON description.
 * \param left_leaf                In/out: index of the leaf last split (set to root, 0, on entry).
 * \param right_leaf               In/out: index of the right child produced by the last split.
 * \param cur_depth                In/out: updated to the max depth reached by forced splits.
 * \param aborted_last_force_split Out: set to true if forcing stopped early because a
 *                                 requested split had no usable (non-negative-gain) SplitInfo.
 * \return Number of splits actually forced (used by the caller to reduce the number of
 *         remaining best-gain splits).
 */
int32_t SerialTreeLearner::ForceSplits(Tree* tree, Json& forced_split_json, int* left_leaf,
                                       int* right_leaf, int* cur_depth,
                                       bool* aborted_last_force_split) {
  int32_t result_count = 0;
  // start at root leaf
  *left_leaf = 0;
  // BFS queue of (JSON node describing the forced split, leaf index it applies to)
  std::queue<std::pair<Json, int>> q;
  Json left = forced_split_json;
  Json right;
  // Tracks whether the "left" child of the previous split is the smaller leaf,
  // i.e. which of smaller_/larger_leaf_* holds its statistics.
  bool left_smaller = true;
  // leaf index -> SplitInfo computed for the forced split of that leaf
  std::unordered_map<int, SplitInfo> forceSplitMap;
  q.push(std::make_pair(forced_split_json, *left_leaf));
  while (!q.empty()) {
    // before processing next node from queue, store info for current left/right leaf
    // store "best split" for left and right, even if they might be overwritten by forced split
    if (BeforeFindBestSplit(tree, *left_leaf, *right_leaf)) {
      FindBestSplits();
    }
    // then, compute own splits
    SplitInfo left_split;
    SplitInfo right_split;

    if (!left.is_null()) {
      const int left_feature = left["feature"].int_value();
      const double left_threshold_double = left["threshold"].number_value();
      // Map user-facing feature index / real-valued threshold to inner index / bin.
      const int left_inner_feature_index = train_data_->InnerFeatureIndex(left_feature);
      const uint32_t left_threshold = train_data_->BinThreshold(
              left_inner_feature_index, left_threshold_double);
      // Pick the histogram/leaf-splits pair that currently describes the left leaf.
      auto leaf_histogram_array = (left_smaller) ? smaller_leaf_histogram_array_ : larger_leaf_histogram_array_;
      auto left_leaf_splits = (left_smaller) ? smaller_leaf_splits_.get() : larger_leaf_splits_.get();
      // Fill left_split with the gain/outputs/counts of splitting at exactly this threshold.
      leaf_histogram_array[left_inner_feature_index].GatherInfoForThreshold(
              left_leaf_splits->sum_gradients(),
              left_leaf_splits->sum_hessians(),
              left_threshold,
              left_leaf_splits->num_data_in_leaf(),
              &left_split);
      left_split.feature = left_feature;
      forceSplitMap[*left_leaf] = left_split;
      // A negative gain means the forced split is not usable; drop it so the
      // BFS below aborts when it reaches this leaf.
      if (left_split.gain < 0) {
        forceSplitMap.erase(*left_leaf);
      }
    }

    if (!right.is_null()) {
      const int right_feature = right["feature"].int_value();
      const double right_threshold_double = right["threshold"].number_value();
      const int right_inner_feature_index = train_data_->InnerFeatureIndex(right_feature);
      const uint32_t right_threshold = train_data_->BinThreshold(
              right_inner_feature_index, right_threshold_double);
      // The right leaf lives in whichever of smaller_/larger_ the left leaf does not.
      auto leaf_histogram_array = (left_smaller) ? larger_leaf_histogram_array_ : smaller_leaf_histogram_array_;
      auto right_leaf_splits = (left_smaller) ? larger_leaf_splits_.get() : smaller_leaf_splits_.get();
      leaf_histogram_array[right_inner_feature_index].GatherInfoForThreshold(
              right_leaf_splits->sum_gradients(),
              right_leaf_splits->sum_hessians(),
              right_threshold,
              right_leaf_splits->num_data_in_leaf(),
              &right_split);
      right_split.feature = right_feature;
      forceSplitMap[*right_leaf] = right_split;
      if (right_split.gain < 0) {
        forceSplitMap.erase(*right_leaf);
      }
    }

    // Dequeue the next forced split to actually apply.
    std::pair<Json, int> pair = q.front();
    q.pop();
    int current_leaf = pair.second;
    // split info should exist because searching in bfs fashion - should have added from parent
    if (forceSplitMap.find(current_leaf) == forceSplitMap.end()) {
      // The requested split was rejected (negative gain) when its parent was
      // processed; stop forcing and let normal training continue.
      *aborted_last_force_split = true;
      break;
    }
    SplitInfo current_split_info = forceSplitMap[current_leaf];
    const int inner_feature_index = train_data_->InnerFeatureIndex(
            current_split_info.feature);
    auto threshold_double = train_data_->RealThreshold(
            inner_feature_index, current_split_info.threshold);

    // split tree, will return right leaf
    *left_leaf = current_leaf;
    if (train_data_->FeatureBinMapper(inner_feature_index)->bin_type() == BinType::NumericalBin) {
      // Numerical feature: single-threshold split.
      *right_leaf = tree->Split(current_leaf,
                                inner_feature_index,
                                current_split_info.feature,
                                current_split_info.threshold,
                                threshold_double,
                                static_cast<double>(current_split_info.left_output),
                                static_cast<double>(current_split_info.right_output),
                                static_cast<data_size_t>(current_split_info.left_count),
                                static_cast<data_size_t>(current_split_info.right_count),
                                static_cast<float>(current_split_info.gain),
                                train_data_->FeatureBinMapper(inner_feature_index)->missing_type(),
                                current_split_info.default_left);
      // Mirror the tree split in the data partition so rows move to the new leaves.
      data_partition_->Split(current_leaf, train_data_, inner_feature_index,
                             &current_split_info.threshold, 1,
                             current_split_info.default_left, *right_leaf);
    } else {
      // Categorical feature: build bitsets of the chosen categories, both in
      // inner-bin space (for partitioning) and in real-category space (for the model).
      std::vector<uint32_t> cat_bitset_inner = Common::ConstructBitset(
              current_split_info.cat_threshold.data(),
              current_split_info.num_cat_threshold);
      std::vector<int> threshold_int(current_split_info.num_cat_threshold);
      for (int i = 0; i < current_split_info.num_cat_threshold; ++i) {
        // Convert each inner bin back to the real category value.
        threshold_int[i] = static_cast<int>(train_data_->RealThreshold(
                    inner_feature_index, current_split_info.cat_threshold[i]));
      }
      std::vector<uint32_t> cat_bitset = Common::ConstructBitset(
              threshold_int.data(), current_split_info.num_cat_threshold);
      *right_leaf = tree->SplitCategorical(current_leaf,
                                           inner_feature_index,
                                           current_split_info.feature,
                                           cat_bitset_inner.data(),
                                           static_cast<int>(cat_bitset_inner.size()),
                                           cat_bitset.data(),
                                           static_cast<int>(cat_bitset.size()),
                                           static_cast<double>(current_split_info.left_output),
                                           static_cast<double>(current_split_info.right_output),
                                           static_cast<data_size_t>(current_split_info.left_count),
                                           static_cast<data_size_t>(current_split_info.right_count),
                                           static_cast<float>(current_split_info.gain),
                                           train_data_->FeatureBinMapper(inner_feature_index)->missing_type());
      data_partition_->Split(current_leaf, train_data_, inner_feature_index,
                             cat_bitset_inner.data(),
                             static_cast<int>(cat_bitset_inner.size()),
                             current_split_info.default_left, *right_leaf);
    }

    // Re-initialize the smaller/larger leaf-splits objects for the two children;
    // the smaller one (fewer rows) gets its histogram built directly, the larger
    // one by subtraction elsewhere.
    if (current_split_info.left_count < current_split_info.right_count) {
      left_smaller = true;
      smaller_leaf_splits_->Init(*left_leaf, data_partition_.get(),
                                 current_split_info.left_sum_gradient,
                                 current_split_info.left_sum_hessian);
      larger_leaf_splits_->Init(*right_leaf, data_partition_.get(),
                                current_split_info.right_sum_gradient,
                                current_split_info.right_sum_hessian);
    } else {
      left_smaller = false;
      smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(),
                                 current_split_info.right_sum_gradient,
                                 current_split_info.right_sum_hessian);
      larger_leaf_splits_->Init(*left_leaf, data_partition_.get(),
                                current_split_info.left_sum_gradient,
                                current_split_info.left_sum_hessian);
    }

    // Enqueue the JSON children (if any) to force further splits of the two new leaves.
    left = Json();
    right = Json();
    if ((pair.first).object_items().count("left") > 0) {
      left = (pair.first)["left"];
      q.push(std::make_pair(left, *left_leaf));
    }
    if ((pair.first).object_items().count("right") > 0) {
      right = (pair.first)["right"];
      q.push(std::make_pair(right, *right_leaf));
    }
    result_count++;
    // Both children have equal depth, so checking the left leaf is sufficient.
    *(cur_depth) = std::max(*(cur_depth), tree->leaf_depth(*left_leaf));
  }
  return result_count;
}
void
SerialTreeLearner
::
Split
(
Tree
*
tree
,
int
best_leaf
,
int
*
left_leaf
,
int
*
right_leaf
)
{
const
SplitInfo
&
best_split_info
=
best_split_per_leaf_
[
best_leaf
];
...
...
src/treelearner/serial_tree_learner.h
View file @
84fef715
...
...
@@ -24,6 +24,8 @@
#include <boost/align/aligned_allocator.hpp>
#endif
using
namespace
json11
;
namespace
LightGBM
{
/*!
...
...
@@ -41,7 +43,8 @@ public:
void
ResetConfig
(
const
TreeConfig
*
tree_config
)
override
;
Tree
*
Train
(
const
score_t
*
gradients
,
const
score_t
*
hessians
,
bool
is_constant_hessian
)
override
;
Tree
*
Train
(
const
score_t
*
gradients
,
const
score_t
*
hessians
,
bool
is_constant_hessian
,
Json
&
forced_split_json
)
override
;
Tree
*
FitByExistingTree
(
const
Tree
*
old_tree
,
const
score_t
*
gradients
,
const
score_t
*
hessians
)
const
override
;
...
...
@@ -95,6 +98,12 @@ protected:
*/
virtual
void
Split
(
Tree
*
tree
,
int
best_leaf
,
int
*
left_leaf
,
int
*
right_leaf
);
/* Force splits with forced_split_json dict and then return num splits forced.*/
virtual
int32_t
ForceSplits
(
Tree
*
tree
,
Json
&
forced_split_json
,
int
*
left_leaf
,
int
*
right_leaf
,
int
*
cur_depth
,
bool
*
aborted_last_force_split
);
/*!
* \brief Get the number of data in a leaf
* \param leaf_idx The index of leaf
...
...
src/treelearner/voting_parallel_tree_learner.cpp
View file @
84fef715
...
...
@@ -76,12 +76,13 @@ void VotingParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, b
feature_metas_
[
i
].
bias
=
0
;
}
feature_metas_
[
i
].
tree_config
=
this
->
tree_config_
;
feature_metas_
[
i
].
bin_type
=
train_data
->
FeatureBinMapper
(
i
)
->
bin_type
();
}
uint64_t
offset
=
0
;
for
(
int
j
=
0
;
j
<
train_data
->
num_features
();
++
j
)
{
offset
+=
static_cast
<
uint64_t
>
(
train_data
->
SubFeatureBinOffset
(
j
));
smaller_leaf_histogram_array_global_
[
j
].
Init
(
smaller_leaf_histogram_data_
.
data
()
+
offset
,
&
feature_metas_
[
j
]
,
train_data
->
FeatureBinMapper
(
j
)
->
bin_type
()
);
larger_leaf_histogram_array_global_
[
j
].
Init
(
larger_leaf_histogram_data_
.
data
()
+
offset
,
&
feature_metas_
[
j
]
,
train_data
->
FeatureBinMapper
(
j
)
->
bin_type
()
);
smaller_leaf_histogram_array_global_
[
j
].
Init
(
smaller_leaf_histogram_data_
.
data
()
+
offset
,
&
feature_metas_
[
j
]);
larger_leaf_histogram_array_global_
[
j
].
Init
(
larger_leaf_histogram_data_
.
data
()
+
offset
,
&
feature_metas_
[
j
]);
auto
num_bin
=
train_data
->
FeatureNumBin
(
j
);
if
(
train_data
->
FeatureBinMapper
(
j
)
->
GetDefaultBin
()
==
0
)
{
num_bin
-=
1
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment