Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
462612b4
"...git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "c27ebcd8a996b06689dfcdfaca98aa55b51893cd"
Commit
462612b4
authored
Feb 06, 2019
by
Nikita Titov
Committed by
Guolin Ke
Feb 06, 2019
Browse files
fixed modifiers indent (#1997)
parent
8e286b38
Changes
54
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
58 additions
and
57 deletions
+58
-57
src/network/linkers.h
src/network/linkers.h
+2
-2
src/network/socket_wrapper.hpp
src/network/socket_wrapper.hpp
+2
-2
src/objective/binary_objective.hpp
src/objective/binary_objective.hpp
+2
-2
src/objective/multiclass_objective.hpp
src/objective/multiclass_objective.hpp
+4
-4
src/objective/rank_objective.hpp
src/objective/rank_objective.hpp
+2
-2
src/objective/regression_objective.hpp
src/objective/regression_objective.hpp
+16
-16
src/objective/xentropy_objective.hpp
src/objective/xentropy_objective.hpp
+3
-3
src/treelearner/data_partition.hpp
src/treelearner/data_partition.hpp
+2
-2
src/treelearner/feature_histogram.hpp
src/treelearner/feature_histogram.hpp
+5
-5
src/treelearner/gpu_tree_learner.h
src/treelearner/gpu_tree_learner.h
+4
-4
src/treelearner/leaf_splits.hpp
src/treelearner/leaf_splits.hpp
+2
-2
src/treelearner/parallel_tree_learner.h
src/treelearner/parallel_tree_learner.h
+10
-9
src/treelearner/serial_tree_learner.h
src/treelearner/serial_tree_learner.h
+2
-2
src/treelearner/split_info.hpp
src/treelearner/split_info.hpp
+2
-2
No files found.
src/network/linkers.h
View file @
462612b4
...
@@ -32,7 +32,7 @@ namespace LightGBM {
...
@@ -32,7 +32,7 @@ namespace LightGBM {
* This class will wrap all linkers to other machines if needs
* This class will wrap all linkers to other machines if needs
*/
*/
class
Linkers
{
class
Linkers
{
public:
public:
Linkers
()
{
Linkers
()
{
is_init_
=
false
;
is_init_
=
false
;
}
}
...
@@ -136,7 +136,7 @@ public:
...
@@ -136,7 +136,7 @@ public:
#endif // USE_SOCKET
#endif // USE_SOCKET
private:
private:
/*! \brief Rank of local machine */
/*! \brief Rank of local machine */
int
rank_
;
int
rank_
;
/*! \brief Total number machines */
/*! \brief Total number machines */
...
...
src/network/socket_wrapper.hpp
View file @
462612b4
...
@@ -86,7 +86,7 @@ const bool kNoDelay = true;
...
@@ -86,7 +86,7 @@ const bool kNoDelay = true;
}
}
class
TcpSocket
{
class
TcpSocket
{
public:
public:
TcpSocket
()
{
TcpSocket
()
{
sockfd_
=
socket
(
AF_INET
,
SOCK_STREAM
,
IPPROTO_TCP
);
sockfd_
=
socket
(
AF_INET
,
SOCK_STREAM
,
IPPROTO_TCP
);
if
(
sockfd_
==
INVALID_SOCKET
)
{
if
(
sockfd_
==
INVALID_SOCKET
)
{
...
@@ -291,7 +291,7 @@ public:
...
@@ -291,7 +291,7 @@ public:
}
}
}
}
private:
private:
SOCKET
sockfd_
;
SOCKET
sockfd_
;
};
};
...
...
src/objective/binary_objective.hpp
View file @
462612b4
...
@@ -11,7 +11,7 @@ namespace LightGBM {
...
@@ -11,7 +11,7 @@ namespace LightGBM {
* \brief Objective function for binary classification
* \brief Objective function for binary classification
*/
*/
class
BinaryLogloss
:
public
ObjectiveFunction
{
class
BinaryLogloss
:
public
ObjectiveFunction
{
public:
public:
explicit
BinaryLogloss
(
const
Config
&
config
,
std
::
function
<
bool
(
label_t
)
>
is_pos
=
nullptr
)
{
explicit
BinaryLogloss
(
const
Config
&
config
,
std
::
function
<
bool
(
label_t
)
>
is_pos
=
nullptr
)
{
sigmoid_
=
static_cast
<
double
>
(
config
.
sigmoid
);
sigmoid_
=
static_cast
<
double
>
(
config
.
sigmoid
);
if
(
sigmoid_
<=
0.0
)
{
if
(
sigmoid_
<=
0.0
)
{
...
@@ -172,7 +172,7 @@ public:
...
@@ -172,7 +172,7 @@ public:
bool
NeedAccuratePrediction
()
const
override
{
return
false
;
}
bool
NeedAccuratePrediction
()
const
override
{
return
false
;
}
private:
private:
/*! \brief Number of data */
/*! \brief Number of data */
data_size_t
num_data_
;
data_size_t
num_data_
;
/*! \brief Pointer of label */
/*! \brief Pointer of label */
...
...
src/objective/multiclass_objective.hpp
View file @
462612b4
...
@@ -14,7 +14,7 @@ namespace LightGBM {
...
@@ -14,7 +14,7 @@ namespace LightGBM {
* \brief Objective function for multiclass classification, use softmax as objective functions
* \brief Objective function for multiclass classification, use softmax as objective functions
*/
*/
class
MulticlassSoftmax
:
public
ObjectiveFunction
{
class
MulticlassSoftmax
:
public
ObjectiveFunction
{
public:
public:
explicit
MulticlassSoftmax
(
const
Config
&
config
)
{
explicit
MulticlassSoftmax
(
const
Config
&
config
)
{
num_class_
=
config
.
num_class
;
num_class_
=
config
.
num_class
;
}
}
...
@@ -146,7 +146,7 @@ public:
...
@@ -146,7 +146,7 @@ public:
}
}
}
}
private:
private:
/*! \brief Number of data */
/*! \brief Number of data */
data_size_t
num_data_
;
data_size_t
num_data_
;
/*! \brief Number of classes */
/*! \brief Number of classes */
...
@@ -164,7 +164,7 @@ private:
...
@@ -164,7 +164,7 @@ private:
* \brief Objective function for multiclass classification, use one-vs-all binary objective function
* \brief Objective function for multiclass classification, use one-vs-all binary objective function
*/
*/
class
MulticlassOVA
:
public
ObjectiveFunction
{
class
MulticlassOVA
:
public
ObjectiveFunction
{
public:
public:
explicit
MulticlassOVA
(
const
Config
&
config
)
{
explicit
MulticlassOVA
(
const
Config
&
config
)
{
num_class_
=
config
.
num_class
;
num_class_
=
config
.
num_class
;
for
(
int
i
=
0
;
i
<
num_class_
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num_class_
;
++
i
)
{
...
@@ -246,7 +246,7 @@ public:
...
@@ -246,7 +246,7 @@ public:
return
binary_loss_
[
class_id
]
->
ClassNeedTrain
(
0
);
return
binary_loss_
[
class_id
]
->
ClassNeedTrain
(
0
);
}
}
private:
private:
/*! \brief Number of data */
/*! \brief Number of data */
data_size_t
num_data_
;
data_size_t
num_data_
;
/*! \brief Number of classes */
/*! \brief Number of classes */
...
...
src/objective/rank_objective.hpp
View file @
462612b4
...
@@ -17,7 +17,7 @@ namespace LightGBM {
...
@@ -17,7 +17,7 @@ namespace LightGBM {
* \brief Objective function for Lambdrank with NDCG
* \brief Objective function for Lambdrank with NDCG
*/
*/
class
LambdarankNDCG
:
public
ObjectiveFunction
{
class
LambdarankNDCG
:
public
ObjectiveFunction
{
public:
public:
explicit
LambdarankNDCG
(
const
Config
&
config
)
{
explicit
LambdarankNDCG
(
const
Config
&
config
)
{
sigmoid_
=
static_cast
<
double
>
(
config
.
sigmoid
);
sigmoid_
=
static_cast
<
double
>
(
config
.
sigmoid
);
label_gain_
=
config
.
label_gain
;
label_gain_
=
config
.
label_gain
;
...
@@ -205,7 +205,7 @@ public:
...
@@ -205,7 +205,7 @@ public:
bool
NeedAccuratePrediction
()
const
override
{
return
false
;
}
bool
NeedAccuratePrediction
()
const
override
{
return
false
;
}
private:
private:
/*! \brief Gains for labels */
/*! \brief Gains for labels */
std
::
vector
<
double
>
label_gain_
;
std
::
vector
<
double
>
label_gain_
;
/*! \brief Cache inverse max DCG, speed up calculation */
/*! \brief Cache inverse max DCG, speed up calculation */
...
...
src/objective/regression_objective.hpp
View file @
462612b4
...
@@ -69,7 +69,7 @@ namespace LightGBM {
...
@@ -69,7 +69,7 @@ namespace LightGBM {
* \brief Objective function for regression
* \brief Objective function for regression
*/
*/
class
RegressionL2loss
:
public
ObjectiveFunction
{
class
RegressionL2loss
:
public
ObjectiveFunction
{
public:
public:
explicit
RegressionL2loss
(
const
Config
&
config
)
{
explicit
RegressionL2loss
(
const
Config
&
config
)
{
sqrt_
=
config
.
reg_sqrt
;
sqrt_
=
config
.
reg_sqrt
;
}
}
...
@@ -165,7 +165,7 @@ public:
...
@@ -165,7 +165,7 @@ public:
return
suml
/
sumw
;
return
suml
/
sumw
;
}
}
protected:
protected:
bool
sqrt_
;
bool
sqrt_
;
/*! \brief Number of data */
/*! \brief Number of data */
data_size_t
num_data_
;
data_size_t
num_data_
;
...
@@ -180,7 +180,7 @@ protected:
...
@@ -180,7 +180,7 @@ protected:
* \brief L1 regression loss
* \brief L1 regression loss
*/
*/
class
RegressionL1loss
:
public
RegressionL2loss
{
class
RegressionL1loss
:
public
RegressionL2loss
{
public:
public:
explicit
RegressionL1loss
(
const
Config
&
config
)
:
RegressionL2loss
(
config
)
{
explicit
RegressionL1loss
(
const
Config
&
config
)
:
RegressionL2loss
(
config
)
{
}
}
...
@@ -298,7 +298,7 @@ public:
...
@@ -298,7 +298,7 @@ public:
* \brief Huber regression loss
* \brief Huber regression loss
*/
*/
class
RegressionHuberLoss
:
public
RegressionL2loss
{
class
RegressionHuberLoss
:
public
RegressionL2loss
{
public:
public:
explicit
RegressionHuberLoss
(
const
Config
&
config
)
:
RegressionL2loss
(
config
)
{
explicit
RegressionHuberLoss
(
const
Config
&
config
)
:
RegressionL2loss
(
config
)
{
alpha_
=
static_cast
<
double
>
(
config
.
alpha
);
alpha_
=
static_cast
<
double
>
(
config
.
alpha
);
if
(
sqrt_
)
{
if
(
sqrt_
)
{
...
@@ -352,7 +352,7 @@ public:
...
@@ -352,7 +352,7 @@ public:
return
false
;
return
false
;
}
}
private:
private:
/*! \brief delta for Huber loss */
/*! \brief delta for Huber loss */
double
alpha_
;
double
alpha_
;
};
};
...
@@ -360,7 +360,7 @@ private:
...
@@ -360,7 +360,7 @@ private:
// http://research.microsoft.com/en-us/um/people/zhang/INRIA/Publis/Tutorial-Estim/node24.html
// http://research.microsoft.com/en-us/um/people/zhang/INRIA/Publis/Tutorial-Estim/node24.html
class
RegressionFairLoss
:
public
RegressionL2loss
{
class
RegressionFairLoss
:
public
RegressionL2loss
{
public:
public:
explicit
RegressionFairLoss
(
const
Config
&
config
)
:
RegressionL2loss
(
config
)
{
explicit
RegressionFairLoss
(
const
Config
&
config
)
:
RegressionL2loss
(
config
)
{
c_
=
static_cast
<
double
>
(
config
.
fair_c
);
c_
=
static_cast
<
double
>
(
config
.
fair_c
);
}
}
...
@@ -397,7 +397,7 @@ public:
...
@@ -397,7 +397,7 @@ public:
return
false
;
return
false
;
}
}
private:
private:
/*! \brief c for Fair loss */
/*! \brief c for Fair loss */
double
c_
;
double
c_
;
};
};
...
@@ -407,7 +407,7 @@ private:
...
@@ -407,7 +407,7 @@ private:
* \brief Objective function for Poisson regression
* \brief Objective function for Poisson regression
*/
*/
class
RegressionPoissonLoss
:
public
RegressionL2loss
{
class
RegressionPoissonLoss
:
public
RegressionL2loss
{
public:
public:
explicit
RegressionPoissonLoss
(
const
Config
&
config
)
:
RegressionL2loss
(
config
)
{
explicit
RegressionPoissonLoss
(
const
Config
&
config
)
:
RegressionL2loss
(
config
)
{
max_delta_step_
=
static_cast
<
double
>
(
config
.
poisson_max_delta_step
);
max_delta_step_
=
static_cast
<
double
>
(
config
.
poisson_max_delta_step
);
if
(
sqrt_
)
{
if
(
sqrt_
)
{
...
@@ -481,13 +481,13 @@ public:
...
@@ -481,13 +481,13 @@ public:
return
false
;
return
false
;
}
}
private:
private:
/*! \brief used to safeguard optimization */
/*! \brief used to safeguard optimization */
double
max_delta_step_
;
double
max_delta_step_
;
};
};
class
RegressionQuantileloss
:
public
RegressionL2loss
{
class
RegressionQuantileloss
:
public
RegressionL2loss
{
public:
public:
explicit
RegressionQuantileloss
(
const
Config
&
config
)
:
RegressionL2loss
(
config
)
{
explicit
RegressionQuantileloss
(
const
Config
&
config
)
:
RegressionL2loss
(
config
)
{
alpha_
=
static_cast
<
score_t
>
(
config
.
alpha
);
alpha_
=
static_cast
<
score_t
>
(
config
.
alpha
);
CHECK
(
alpha_
>
0
&&
alpha_
<
1
);
CHECK
(
alpha_
>
0
&&
alpha_
<
1
);
...
@@ -607,7 +607,7 @@ public:
...
@@ -607,7 +607,7 @@ public:
}
}
}
}
private:
private:
score_t
alpha_
;
score_t
alpha_
;
};
};
...
@@ -616,7 +616,7 @@ private:
...
@@ -616,7 +616,7 @@ private:
* \brief Mape Regression Loss
* \brief Mape Regression Loss
*/
*/
class
RegressionMAPELOSS
:
public
RegressionL1loss
{
class
RegressionMAPELOSS
:
public
RegressionL1loss
{
public:
public:
explicit
RegressionMAPELOSS
(
const
Config
&
config
)
:
RegressionL1loss
(
config
)
{
explicit
RegressionMAPELOSS
(
const
Config
&
config
)
:
RegressionL1loss
(
config
)
{
}
}
...
@@ -725,7 +725,7 @@ public:
...
@@ -725,7 +725,7 @@ public:
return
true
;
return
true
;
}
}
private:
private:
std
::
vector
<
label_t
>
label_weight_
;
std
::
vector
<
label_t
>
label_weight_
;
};
};
...
@@ -735,7 +735,7 @@ private:
...
@@ -735,7 +735,7 @@ private:
* \brief Objective function for Gamma regression
* \brief Objective function for Gamma regression
*/
*/
class
RegressionGammaLoss
:
public
RegressionPoissonLoss
{
class
RegressionGammaLoss
:
public
RegressionPoissonLoss
{
public:
public:
explicit
RegressionGammaLoss
(
const
Config
&
config
)
:
RegressionPoissonLoss
(
config
)
{
explicit
RegressionGammaLoss
(
const
Config
&
config
)
:
RegressionPoissonLoss
(
config
)
{
}
}
...
@@ -770,7 +770,7 @@ public:
...
@@ -770,7 +770,7 @@ public:
* \brief Objective function for Tweedie regression
* \brief Objective function for Tweedie regression
*/
*/
class
RegressionTweedieLoss
:
public
RegressionPoissonLoss
{
class
RegressionTweedieLoss
:
public
RegressionPoissonLoss
{
public:
public:
explicit
RegressionTweedieLoss
(
const
Config
&
config
)
:
RegressionPoissonLoss
(
config
)
{
explicit
RegressionTweedieLoss
(
const
Config
&
config
)
:
RegressionPoissonLoss
(
config
)
{
rho_
=
config
.
tweedie_variance_power
;
rho_
=
config
.
tweedie_variance_power
;
}
}
...
@@ -803,7 +803,7 @@ public:
...
@@ -803,7 +803,7 @@ public:
return
"tweedie"
;
return
"tweedie"
;
}
}
private:
private:
double
rho_
;
double
rho_
;
};
};
...
...
src/objective/xentropy_objective.hpp
View file @
462612b4
...
@@ -36,7 +36,7 @@ namespace LightGBM {
...
@@ -36,7 +36,7 @@ namespace LightGBM {
* \brief Objective function for cross-entropy (with optional linear weights)
* \brief Objective function for cross-entropy (with optional linear weights)
*/
*/
class
CrossEntropy
:
public
ObjectiveFunction
{
class
CrossEntropy
:
public
ObjectiveFunction
{
public:
public:
explicit
CrossEntropy
(
const
Config
&
)
{
explicit
CrossEntropy
(
const
Config
&
)
{
}
}
...
@@ -127,7 +127,7 @@ public:
...
@@ -127,7 +127,7 @@ public:
return
initscore
;
return
initscore
;
}
}
private:
private:
/*! \brief Number of data points */
/*! \brief Number of data points */
data_size_t
num_data_
;
data_size_t
num_data_
;
/*! \brief Pointer for label */
/*! \brief Pointer for label */
...
@@ -140,7 +140,7 @@ private:
...
@@ -140,7 +140,7 @@ private:
* \brief Objective function for alternative parameterization of cross-entropy (see top of file for explanation)
* \brief Objective function for alternative parameterization of cross-entropy (see top of file for explanation)
*/
*/
class
CrossEntropyLambda
:
public
ObjectiveFunction
{
class
CrossEntropyLambda
:
public
ObjectiveFunction
{
public:
public:
explicit
CrossEntropyLambda
(
const
Config
&
)
{
explicit
CrossEntropyLambda
(
const
Config
&
)
{
min_weight_
=
max_weight_
=
0.0
f
;
min_weight_
=
max_weight_
=
0.0
f
;
}
}
...
...
src/treelearner/data_partition.hpp
View file @
462612b4
...
@@ -15,7 +15,7 @@ namespace LightGBM {
...
@@ -15,7 +15,7 @@ namespace LightGBM {
* \brief DataPartition is used to store the the partition of data on tree.
* \brief DataPartition is used to store the the partition of data on tree.
*/
*/
class
DataPartition
{
class
DataPartition
{
public:
public:
DataPartition
(
data_size_t
num_data
,
int
num_leaves
)
DataPartition
(
data_size_t
num_data
,
int
num_leaves
)
:
num_data_
(
num_data
),
num_leaves_
(
num_leaves
)
{
:
num_data_
(
num_data
),
num_leaves_
(
num_leaves
)
{
leaf_begin_
.
resize
(
num_leaves_
);
leaf_begin_
.
resize
(
num_leaves_
);
...
@@ -188,7 +188,7 @@ public:
...
@@ -188,7 +188,7 @@ public:
/*! \brief Get number of leaves */
/*! \brief Get number of leaves */
int
num_leaves
()
const
{
return
num_leaves_
;
}
int
num_leaves
()
const
{
return
num_leaves_
;
}
private:
private:
/*! \brief Number of all data */
/*! \brief Number of all data */
data_size_t
num_data_
;
data_size_t
num_data_
;
/*! \brief Number of all leaves */
/*! \brief Number of all leaves */
...
...
src/treelearner/feature_histogram.hpp
View file @
462612b4
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
namespace
LightGBM
{
namespace
LightGBM
{
class
FeatureMetainfo
{
class
FeatureMetainfo
{
public:
public:
int
num_bin
;
int
num_bin
;
MissingType
missing_type
;
MissingType
missing_type
;
int8_t
bias
=
0
;
int8_t
bias
=
0
;
...
@@ -27,7 +27,7 @@ public:
...
@@ -27,7 +27,7 @@ public:
* \brief FeatureHistogram is used to construct and store a histogram for a feature.
* \brief FeatureHistogram is used to construct and store a histogram for a feature.
*/
*/
class
FeatureHistogram
{
class
FeatureHistogram
{
public:
public:
FeatureHistogram
()
{
FeatureHistogram
()
{
data_
=
nullptr
;
data_
=
nullptr
;
}
}
...
@@ -449,7 +449,7 @@ public:
...
@@ -449,7 +449,7 @@ public:
}
}
}
}
private:
private:
static
double
GetSplitGains
(
double
sum_left_gradients
,
double
sum_left_hessians
,
static
double
GetSplitGains
(
double
sum_left_gradients
,
double
sum_left_hessians
,
double
sum_right_gradients
,
double
sum_right_hessians
,
double
sum_right_gradients
,
double
sum_right_hessians
,
double
l1
,
double
l2
,
double
max_delta_step
,
double
l1
,
double
l2
,
double
max_delta_step
,
...
@@ -644,7 +644,7 @@ private:
...
@@ -644,7 +644,7 @@ private:
std
::
function
<
void
(
double
,
double
,
data_size_t
,
double
,
double
,
SplitInfo
*
)
>
find_best_threshold_fun_
;
std
::
function
<
void
(
double
,
double
,
data_size_t
,
double
,
double
,
SplitInfo
*
)
>
find_best_threshold_fun_
;
};
};
class
HistogramPool
{
class
HistogramPool
{
public:
public:
/*!
/*!
* \brief Constructor
* \brief Constructor
*/
*/
...
@@ -804,7 +804,7 @@ public:
...
@@ -804,7 +804,7 @@ public:
inverse_mapper_
[
slot
]
=
dst_idx
;
inverse_mapper_
[
slot
]
=
dst_idx
;
}
}
private:
private:
std
::
vector
<
std
::
unique_ptr
<
FeatureHistogram
[]
>>
pool_
;
std
::
vector
<
std
::
unique_ptr
<
FeatureHistogram
[]
>>
pool_
;
std
::
vector
<
std
::
vector
<
HistogramBinEntry
>>
data_
;
std
::
vector
<
std
::
vector
<
HistogramBinEntry
>>
data_
;
std
::
vector
<
FeatureMetainfo
>
feature_metas_
;
std
::
vector
<
FeatureMetainfo
>
feature_metas_
;
...
...
src/treelearner/gpu_tree_learner.h
View file @
462612b4
...
@@ -36,7 +36,7 @@ namespace LightGBM {
...
@@ -36,7 +36,7 @@ namespace LightGBM {
* \brief GPU-based parallel learning algorithm.
* \brief GPU-based parallel learning algorithm.
*/
*/
class
GPUTreeLearner
:
public
SerialTreeLearner
{
class
GPUTreeLearner
:
public
SerialTreeLearner
{
public:
public:
explicit
GPUTreeLearner
(
const
Config
*
tree_config
);
explicit
GPUTreeLearner
(
const
Config
*
tree_config
);
~
GPUTreeLearner
();
~
GPUTreeLearner
();
void
Init
(
const
Dataset
*
train_data
,
bool
is_constant_hessian
)
override
;
void
Init
(
const
Dataset
*
train_data
,
bool
is_constant_hessian
)
override
;
...
@@ -57,14 +57,14 @@ public:
...
@@ -57,14 +57,14 @@ public:
use_bagging_
=
false
;
use_bagging_
=
false
;
}
}
protected:
protected:
void
BeforeTrain
()
override
;
void
BeforeTrain
()
override
;
bool
BeforeFindBestSplit
(
const
Tree
*
tree
,
int
left_leaf
,
int
right_leaf
)
override
;
bool
BeforeFindBestSplit
(
const
Tree
*
tree
,
int
left_leaf
,
int
right_leaf
)
override
;
void
FindBestSplits
()
override
;
void
FindBestSplits
()
override
;
void
Split
(
Tree
*
tree
,
int
best_Leaf
,
int
*
left_leaf
,
int
*
right_leaf
)
override
;
void
Split
(
Tree
*
tree
,
int
best_Leaf
,
int
*
left_leaf
,
int
*
right_leaf
)
override
;
void
ConstructHistograms
(
const
std
::
vector
<
int8_t
>&
is_feature_used
,
bool
use_subtract
)
override
;
void
ConstructHistograms
(
const
std
::
vector
<
int8_t
>&
is_feature_used
,
bool
use_subtract
)
override
;
private:
private:
/*! \brief 4-byte feature tuple used by GPU kernels */
/*! \brief 4-byte feature tuple used by GPU kernels */
struct
Feature4
{
struct
Feature4
{
uint8_t
s
[
4
];
uint8_t
s
[
4
];
...
@@ -269,7 +269,7 @@ private:
...
@@ -269,7 +269,7 @@ private:
namespace
LightGBM
{
namespace
LightGBM
{
class
GPUTreeLearner
:
public
SerialTreeLearner
{
class
GPUTreeLearner
:
public
SerialTreeLearner
{
public:
public:
#pragma warning(disable : 4702)
#pragma warning(disable : 4702)
explicit
GPUTreeLearner
(
const
Config
*
tree_config
)
:
SerialTreeLearner
(
tree_config
)
{
explicit
GPUTreeLearner
(
const
Config
*
tree_config
)
:
SerialTreeLearner
(
tree_config
)
{
Log
::
Fatal
(
"GPU Tree Learner was not enabled in this build.
\n
"
Log
::
Fatal
(
"GPU Tree Learner was not enabled in this build.
\n
"
...
...
src/treelearner/leaf_splits.hpp
View file @
462612b4
...
@@ -14,7 +14,7 @@ namespace LightGBM {
...
@@ -14,7 +14,7 @@ namespace LightGBM {
* \brief used to find split candidates for a leaf
* \brief used to find split candidates for a leaf
*/
*/
class
LeafSplits
{
class
LeafSplits
{
public:
public:
LeafSplits
(
data_size_t
num_data
)
LeafSplits
(
data_size_t
num_data
)
:
num_data_in_leaf_
(
num_data
),
num_data_
(
num_data
),
:
num_data_in_leaf_
(
num_data
),
num_data_
(
num_data
),
data_indices_
(
nullptr
)
{
data_indices_
(
nullptr
)
{
...
@@ -141,7 +141,7 @@ public:
...
@@ -141,7 +141,7 @@ public:
const
data_size_t
*
data_indices
()
const
{
return
data_indices_
;
}
const
data_size_t
*
data_indices
()
const
{
return
data_indices_
;
}
private:
private:
/*! \brief current leaf index */
/*! \brief current leaf index */
int
leaf_index_
;
int
leaf_index_
;
/*! \brief number of data on current leaf */
/*! \brief number of data on current leaf */
...
...
src/treelearner/parallel_tree_learner.h
View file @
462612b4
...
@@ -20,15 +20,16 @@ namespace LightGBM {
...
@@ -20,15 +20,16 @@ namespace LightGBM {
*/
*/
template
<
typename
TREELEARNER_T
>
template
<
typename
TREELEARNER_T
>
class
FeatureParallelTreeLearner
:
public
TREELEARNER_T
{
class
FeatureParallelTreeLearner
:
public
TREELEARNER_T
{
public:
public:
explicit
FeatureParallelTreeLearner
(
const
Config
*
config
);
explicit
FeatureParallelTreeLearner
(
const
Config
*
config
);
~
FeatureParallelTreeLearner
();
~
FeatureParallelTreeLearner
();
void
Init
(
const
Dataset
*
train_data
,
bool
is_constant_hessian
)
override
;
void
Init
(
const
Dataset
*
train_data
,
bool
is_constant_hessian
)
override
;
protected:
protected:
void
BeforeTrain
()
override
;
void
BeforeTrain
()
override
;
void
FindBestSplitsFromHistograms
(
const
std
::
vector
<
int8_t
>&
is_feature_used
,
bool
use_subtract
)
override
;
void
FindBestSplitsFromHistograms
(
const
std
::
vector
<
int8_t
>&
is_feature_used
,
bool
use_subtract
)
override
;
private:
private:
/*! \brief rank of local machine */
/*! \brief rank of local machine */
int
rank_
;
int
rank_
;
/*! \brief Number of machines of this parallel task */
/*! \brief Number of machines of this parallel task */
...
@@ -46,13 +47,13 @@ private:
...
@@ -46,13 +47,13 @@ private:
*/
*/
template
<
typename
TREELEARNER_T
>
template
<
typename
TREELEARNER_T
>
class
DataParallelTreeLearner
:
public
TREELEARNER_T
{
class
DataParallelTreeLearner
:
public
TREELEARNER_T
{
public:
public:
explicit
DataParallelTreeLearner
(
const
Config
*
config
);
explicit
DataParallelTreeLearner
(
const
Config
*
config
);
~
DataParallelTreeLearner
();
~
DataParallelTreeLearner
();
void
Init
(
const
Dataset
*
train_data
,
bool
is_constant_hessian
)
override
;
void
Init
(
const
Dataset
*
train_data
,
bool
is_constant_hessian
)
override
;
void
ResetConfig
(
const
Config
*
config
)
override
;
void
ResetConfig
(
const
Config
*
config
)
override
;
protected:
protected:
void
BeforeTrain
()
override
;
void
BeforeTrain
()
override
;
void
FindBestSplits
()
override
;
void
FindBestSplits
()
override
;
void
FindBestSplitsFromHistograms
(
const
std
::
vector
<
int8_t
>&
is_feature_used
,
bool
use_subtract
)
override
;
void
FindBestSplitsFromHistograms
(
const
std
::
vector
<
int8_t
>&
is_feature_used
,
bool
use_subtract
)
override
;
...
@@ -66,7 +67,7 @@ protected:
...
@@ -66,7 +67,7 @@ protected:
}
}
}
}
private:
private:
/*! \brief Rank of local machine */
/*! \brief Rank of local machine */
int
rank_
;
int
rank_
;
/*! \brief Number of machines of this parallel task */
/*! \brief Number of machines of this parallel task */
...
@@ -100,13 +101,13 @@ private:
...
@@ -100,13 +101,13 @@ private:
*/
*/
template
<
typename
TREELEARNER_T
>
template
<
typename
TREELEARNER_T
>
class
VotingParallelTreeLearner
:
public
TREELEARNER_T
{
class
VotingParallelTreeLearner
:
public
TREELEARNER_T
{
public:
public:
explicit
VotingParallelTreeLearner
(
const
Config
*
config
);
explicit
VotingParallelTreeLearner
(
const
Config
*
config
);
~
VotingParallelTreeLearner
()
{
}
~
VotingParallelTreeLearner
()
{
}
void
Init
(
const
Dataset
*
train_data
,
bool
is_constant_hessian
)
override
;
void
Init
(
const
Dataset
*
train_data
,
bool
is_constant_hessian
)
override
;
void
ResetConfig
(
const
Config
*
config
)
override
;
void
ResetConfig
(
const
Config
*
config
)
override
;
protected:
protected:
void
BeforeTrain
()
override
;
void
BeforeTrain
()
override
;
bool
BeforeFindBestSplit
(
const
Tree
*
tree
,
int
left_leaf
,
int
right_leaf
)
override
;
bool
BeforeFindBestSplit
(
const
Tree
*
tree
,
int
left_leaf
,
int
right_leaf
)
override
;
void
FindBestSplits
()
override
;
void
FindBestSplits
()
override
;
...
@@ -136,7 +137,7 @@ protected:
...
@@ -136,7 +137,7 @@ protected:
void
CopyLocalHistogram
(
const
std
::
vector
<
int
>&
smaller_top_features
,
void
CopyLocalHistogram
(
const
std
::
vector
<
int
>&
smaller_top_features
,
const
std
::
vector
<
int
>&
larger_top_features
);
const
std
::
vector
<
int
>&
larger_top_features
);
private:
private:
/*! \brief Tree config used in local mode */
/*! \brief Tree config used in local mode */
Config
local_config_
;
Config
local_config_
;
/*! \brief Voting size */
/*! \brief Voting size */
...
...
src/treelearner/serial_tree_learner.h
View file @
462612b4
...
@@ -32,7 +32,7 @@ namespace LightGBM {
...
@@ -32,7 +32,7 @@ namespace LightGBM {
* \brief Used for learning a tree by single machine
* \brief Used for learning a tree by single machine
*/
*/
class
SerialTreeLearner
:
public
TreeLearner
{
class
SerialTreeLearner
:
public
TreeLearner
{
public:
public:
explicit
SerialTreeLearner
(
const
Config
*
config
);
explicit
SerialTreeLearner
(
const
Config
*
config
);
~
SerialTreeLearner
();
~
SerialTreeLearner
();
...
@@ -75,7 +75,7 @@ public:
...
@@ -75,7 +75,7 @@ public:
void
RenewTreeOutput
(
Tree
*
tree
,
const
ObjectiveFunction
*
obj
,
double
prediction
,
void
RenewTreeOutput
(
Tree
*
tree
,
const
ObjectiveFunction
*
obj
,
double
prediction
,
data_size_t
total_num_data
,
const
data_size_t
*
bag_indices
,
data_size_t
bag_cnt
)
const
override
;
data_size_t
total_num_data
,
const
data_size_t
*
bag_indices
,
data_size_t
bag_cnt
)
const
override
;
protected:
protected:
/*!
/*!
* \brief Some initial works before training
* \brief Some initial works before training
*/
*/
...
...
src/treelearner/split_info.hpp
View file @
462612b4
...
@@ -15,7 +15,7 @@ namespace LightGBM {
...
@@ -15,7 +15,7 @@ namespace LightGBM {
* \brief Used to store some information for gain split point
* \brief Used to store some information for gain split point
*/
*/
struct
SplitInfo
{
struct
SplitInfo
{
public:
public:
/*! \brief Feature index */
/*! \brief Feature index */
int
feature
=
-
1
;
int
feature
=
-
1
;
/*! \brief Split threshold */
/*! \brief Split threshold */
...
@@ -188,7 +188,7 @@ public:
...
@@ -188,7 +188,7 @@ public:
};
};
struct
LightSplitInfo
{
struct
LightSplitInfo
{
public:
public:
/*! \brief Feature index */
/*! \brief Feature index */
int
feature
=
-
1
;
int
feature
=
-
1
;
/*! \brief Split gain */
/*! \brief Split gain */
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment