Commit 6c4a9750 authored by Guolin Ke's avatar Guolin Ke
Browse files

clean code for the split of bins and leaves.

parent 8fb26b06
...@@ -12,20 +12,20 @@ ...@@ -12,20 +12,20 @@
namespace LightGBM { namespace LightGBM {
enum BinType { enum BinType {
NumericalBin, NumericalBin,
CategoricalBin CategoricalBin
}; };
enum MissingType { enum MissingType {
None, None,
Zero, Zero,
NaN NaN
}; };
/*! \brief Store data for one histogram bin */ /*! \brief Store data for one histogram bin */
struct HistogramBinEntry { struct HistogramBinEntry {
public: public:
/*! \brief Sum of gradients on this bin */ /*! \brief Sum of gradients on this bin */
double sum_gradients = 0.0f; double sum_gradients = 0.0f;
/*! \brief Sum of hessians on this bin */ /*! \brief Sum of hessians on this bin */
...@@ -53,12 +53,12 @@ namespace LightGBM { ...@@ -53,12 +53,12 @@ namespace LightGBM {
used_size += type_size; used_size += type_size;
} }
} }
}; };
/*! \brief This class used to convert feature values into bin, /*! \brief This class used to convert feature values into bin,
* and store some meta information for bin*/ * and store some meta information for bin*/
class BinMapper { class BinMapper {
public: public:
BinMapper(); BinMapper();
BinMapper(const BinMapper& other); BinMapper(const BinMapper& other);
explicit BinMapper(const void* memory); explicit BinMapper(const void* memory);
...@@ -183,7 +183,7 @@ namespace LightGBM { ...@@ -183,7 +183,7 @@ namespace LightGBM {
} }
} }
private: private:
/*! \brief Number of bins */ /*! \brief Number of bins */
int num_bin_; int num_bin_;
MissingType missing_type_; MissingType missing_type_;
...@@ -205,18 +205,18 @@ namespace LightGBM { ...@@ -205,18 +205,18 @@ namespace LightGBM {
double max_val_; double max_val_;
/*! \brief bin value of feature value 0 */ /*! \brief bin value of feature value 0 */
uint32_t default_bin_; uint32_t default_bin_;
}; };
/*! /*!
* \brief Interface for ordered bin data. efficient for construct histogram, especially for sparse bin * \brief Interface for ordered bin data. efficient for construct histogram, especially for sparse bin
* There are 2 advantages by using ordered bin. * There are 2 advantages by using ordered bin.
* 1. group the data by leafs to improve the cache hit. * 1. group the data by leafs to improve the cache hit.
* 2. only store the non-zero bin, which can speed up the histogram consturction for sparse features. * 2. only store the non-zero bin, which can speed up the histogram consturction for sparse features.
* However it brings additional cost: it need re-order the bins after every split, which will cost much for dense feature. * However it brings additional cost: it need re-order the bins after every split, which will cost much for dense feature.
* So we only using ordered bin for sparse situations. * So we only using ordered bin for sparse situations.
*/ */
class OrderedBin { class OrderedBin {
public: public:
/*! \brief virtual destructor */ /*! \brief virtual destructor */
virtual ~OrderedBin() {} virtual ~OrderedBin() {}
...@@ -260,11 +260,11 @@ namespace LightGBM { ...@@ -260,11 +260,11 @@ namespace LightGBM {
virtual void Split(int leaf, int right_leaf, const char* is_in_leaf, char mark) = 0; virtual void Split(int leaf, int right_leaf, const char* is_in_leaf, char mark) = 0;
virtual data_size_t NonZeroCount(int leaf) const = 0; virtual data_size_t NonZeroCount(int leaf) const = 0;
}; };
/*! \brief Iterator for one bin column */ /*! \brief Iterator for one bin column */
class BinIterator { class BinIterator {
public: public:
/*! /*!
* \brief Get bin data on specific row index * \brief Get bin data on specific row index
* \param idx Index of this data * \param idx Index of this data
...@@ -274,16 +274,16 @@ namespace LightGBM { ...@@ -274,16 +274,16 @@ namespace LightGBM {
virtual uint32_t RawGet(data_size_t idx) = 0; virtual uint32_t RawGet(data_size_t idx) = 0;
virtual void Reset(data_size_t idx) = 0; virtual void Reset(data_size_t idx) = 0;
virtual ~BinIterator() = default; virtual ~BinIterator() = default;
}; };
/*! /*!
* \brief Interface for bin data. This class will store bin data for one feature. * \brief Interface for bin data. This class will store bin data for one feature.
* unlike OrderedBin, this class will store data by original order. * unlike OrderedBin, this class will store data by original order.
* Note that it may cause cache misses when construct histogram, * Note that it may cause cache misses when construct histogram,
* but it doesn't need to re-order operation, So it will be faster than OrderedBin for dense feature * but it doesn't need to re-order operation, So it will be faster than OrderedBin for dense feature
*/ */
class Bin { class Bin {
public: public:
/*! \brief virtual destructor */ /*! \brief virtual destructor */
virtual ~Bin() {} virtual ~Bin() {}
/*! /*!
...@@ -381,13 +381,29 @@ namespace LightGBM { ...@@ -381,13 +381,29 @@ namespace LightGBM {
* \param num_data Number of used data * \param num_data Number of used data
* \param lte_indices After called this function. The less or equal data indices will store on this object. * \param lte_indices After called this function. The less or equal data indices will store on this object.
* \param gt_indices After called this function. The greater data indices will store on this object. * \param gt_indices After called this function. The greater data indices will store on this object.
* \param bin_type type of bin
* \return The number of less than or equal data. * \return The number of less than or equal data.
*/ */
virtual data_size_t Split(uint32_t min_bin, uint32_t max_bin, virtual data_size_t Split(uint32_t min_bin, uint32_t max_bin,
uint32_t default_bin, MissingType missing_type, bool default_left, uint32_t threshold, uint32_t default_bin, MissingType missing_type, bool default_left, uint32_t threshold,
data_size_t* data_indices, data_size_t num_data, data_size_t* data_indices, data_size_t num_data,
data_size_t* lte_indices, data_size_t* gt_indices, BinType bin_type) const = 0; data_size_t* lte_indices, data_size_t* gt_indices) const = 0;
/*!
* \brief Split data according to threshold, if bin <= threshold, will put into left(lte_indices), else put into right(gt_indices)
* \param min_bin min_bin of current used feature
* \param max_bin max_bin of current used feature
* \param default_bin defualt bin if bin not in [min_bin, max_bin]
* \param threshold The split threshold.
* \param data_indices Used data indices. After called this function. The less than or equal data indices will store on this object.
* \param num_data Number of used data
* \param lte_indices After called this function. The less or equal data indices will store on this object.
* \param gt_indices After called this function. The greater data indices will store on this object.
* \return The number of less than or equal data.
*/
virtual data_size_t SplitCategorical(uint32_t min_bin, uint32_t max_bin,
uint32_t default_bin, uint32_t threshold,
data_size_t* data_indices, data_size_t num_data,
data_size_t* lte_indices, data_size_t* gt_indices) const = 0;
/*! /*!
* \brief Create the ordered bin for this bin * \brief Create the ordered bin for this bin
...@@ -429,9 +445,9 @@ namespace LightGBM { ...@@ -429,9 +445,9 @@ namespace LightGBM {
* \return The bin data object * \return The bin data object
*/ */
static Bin* CreateSparseBin(data_size_t num_data, int num_bin); static Bin* CreateSparseBin(data_size_t num_data, int num_bin);
}; };
inline uint32_t BinMapper::ValueToBin(double value) const { inline uint32_t BinMapper::ValueToBin(double value) const {
if (std::isnan(value)) { if (std::isnan(value)) {
if (missing_type_ == MissingType::NaN) { if (missing_type_ == MissingType::NaN) {
return num_bin_ - 1; return num_bin_ - 1;
...@@ -467,7 +483,7 @@ namespace LightGBM { ...@@ -467,7 +483,7 @@ namespace LightGBM {
return num_bin_ - 1; return num_bin_ - 1;
} }
} }
} }
} // namespace LightGBM } // namespace LightGBM
......
...@@ -168,9 +168,14 @@ public: ...@@ -168,9 +168,14 @@ public:
uint32_t min_bin = bin_offsets_[sub_feature]; uint32_t min_bin = bin_offsets_[sub_feature];
uint32_t max_bin = bin_offsets_[sub_feature + 1] - 1; uint32_t max_bin = bin_offsets_[sub_feature + 1] - 1;
uint32_t default_bin = bin_mappers_[sub_feature]->GetDefaultBin(); uint32_t default_bin = bin_mappers_[sub_feature]->GetDefaultBin();
if (bin_mappers_[sub_feature]->bin_type() == BinType::NumericalBin) {
auto missing_type = bin_mappers_[sub_feature]->missing_type(); auto missing_type = bin_mappers_[sub_feature]->missing_type();
return bin_data_->Split(min_bin, max_bin, default_bin, missing_type, default_left, return bin_data_->Split(min_bin, max_bin, default_bin, missing_type, default_left,
threshold, data_indices, num_data, lte_indices, gt_indices, bin_mappers_[sub_feature]->bin_type()); threshold, data_indices, num_data, lte_indices, gt_indices);
} else {
return bin_data_->SplitCategorical(min_bin, max_bin, default_bin, threshold, data_indices, num_data, lte_indices, gt_indices);
}
} }
/*! /*!
* \brief From bin to feature value * \brief From bin to feature value
......
...@@ -37,9 +37,8 @@ public: ...@@ -37,9 +37,8 @@ public:
* \brief Performing a split on tree leaves. * \brief Performing a split on tree leaves.
* \param leaf Index of leaf to be split * \param leaf Index of leaf to be split
* \param feature Index of feature; the converted index after removing useless features * \param feature Index of feature; the converted index after removing useless features
* \param bin_type type of this feature, numerical or categorical
* \param threshold Threshold(bin) of split
* \param real_feature Index of feature, the original index on data * \param real_feature Index of feature, the original index on data
* \param threshold_bin Threshold(bin) of split
* \param threshold_double Threshold on feature value * \param threshold_double Threshold on feature value
* \param left_value Model Left child output * \param left_value Model Left child output
* \param right_value Model Right child output * \param right_value Model Right child output
...@@ -50,10 +49,29 @@ public: ...@@ -50,10 +49,29 @@ public:
* \param default_left default direction for missing value * \param default_left default direction for missing value
* \return The index of new leaf. * \return The index of new leaf.
*/ */
int Split(int leaf, int feature, BinType bin_type, uint32_t threshold, int real_feature, int Split(int leaf, int feature, int real_feature, uint32_t threshold_bin,
double threshold_double, double left_value, double right_value, double threshold_double, double left_value, double right_value,
data_size_t left_cnt, data_size_t right_cnt, double gain, MissingType missing_type, bool default_left); data_size_t left_cnt, data_size_t right_cnt, double gain, MissingType missing_type, bool default_left);
/*!
* \brief Performing a split on tree leaves, with categorical feature
* \param leaf Index of leaf to be split
* \param feature Index of feature; the converted index after removing useless features
* \param real_feature Index of feature, the original index on data
* \param threshold_bin Threshold(bin) of split, use bitset to represent
* \param num_threshold_bin size of threshold_bin
* \param threshold
* \param left_value Model Left child output
* \param right_value Model Right child output
* \param left_cnt Count of left child
* \param right_cnt Count of right child
* \param gain Split gain
* \return The index of new leaf.
*/
int SplitCategorical(int leaf, int feature, int real_feature, uint32_t threshold_bin,
double threshold, double left_value, double right_value,
data_size_t left_cnt, data_size_t right_cnt, double gain, MissingType missing_type);
/*! \brief Get the output of one leaf */ /*! \brief Get the output of one leaf */
inline double LeafOutput(int leaf) const { return leaf_value_[leaf]; } inline double LeafOutput(int leaf) const { return leaf_value_[leaf]; }
...@@ -89,6 +107,7 @@ public: ...@@ -89,6 +107,7 @@ public:
* \return Prediction result * \return Prediction result
*/ */
inline double Predict(const double* feature_values) const; inline double Predict(const double* feature_values) const;
inline int PredictLeafIndex(const double* feature_values) const; inline int PredictLeafIndex(const double* feature_values) const;
inline void PredictContrib(const double* feature_values, int num_features, double* output) const; inline void PredictContrib(const double* feature_values, int num_features, double* output) const;
...@@ -139,7 +158,7 @@ public: ...@@ -139,7 +158,7 @@ public:
* \param rate The factor of shrinkage * \param rate The factor of shrinkage
*/ */
inline void Shrinkage(double rate) { inline void Shrinkage(double rate) {
#pragma omp parallel for schedule(static, 512) if (num_leaves_ >= 1024) #pragma omp parallel for schedule(static, 1024) if (num_leaves_ >= 2048)
for (int i = 0; i < num_leaves_; ++i) { for (int i = 0; i < num_leaves_; ++i) {
leaf_value_[i] *= rate; leaf_value_[i] *= rate;
if (leaf_value_[i] > kMaxTreeOutput) { leaf_value_[i] = kMaxTreeOutput; } if (leaf_value_[i] > kMaxTreeOutput) { leaf_value_[i] = kMaxTreeOutput; }
...@@ -157,24 +176,6 @@ public: ...@@ -157,24 +176,6 @@ public:
/*! \brief Serialize this object to if-else statement*/ /*! \brief Serialize this object to if-else statement*/
std::string ToIfElse(int index, bool is_predict_leaf_index); std::string ToIfElse(int index, bool is_predict_leaf_index);
template<typename T>
inline static bool CategoricalDecision(T fval, T threshold) {
if (static_cast<int>(fval) == static_cast<int>(threshold)) {
return true;
} else {
return false;
}
}
template<typename T>
inline static bool NumericalDecision(T fval, T threshold) {
if (fval <= threshold) {
return true;
} else {
return false;
}
}
inline static bool IsZero(double fval) { inline static bool IsZero(double fval) {
if (fval > -kZeroAsMissingValueRange && fval <= kZeroAsMissingValueRange) { if (fval > -kZeroAsMissingValueRange && fval <= kZeroAsMissingValueRange) {
return true; return true;
...@@ -204,21 +205,44 @@ public: ...@@ -204,21 +205,44 @@ public:
(*decision_type) |= (input << 2); (*decision_type) |= (input << 2);
} }
inline static uint32_t ConvertMissingValue(uint32_t fval, uint32_t threshold, int8_t decision_type, uint32_t default_bin, uint32_t max_bin) { private:
uint8_t missing_type = GetMissingType(decision_type);
if ((missing_type == 1 && fval == default_bin) inline std::string NumericalDecisionIfElse(int node) {
|| (missing_type == 2 && fval == max_bin)) { std::stringstream str_buf;
if (GetDecisionType(decision_type, kDefaultLeftMask)) { uint8_t missing_type = GetMissingType(decision_type_[node]);
fval = threshold; bool default_left = GetDecisionType(decision_type_[node], kDefaultLeftMask);
if (missing_type == 0 || (missing_type == 1 && default_left && kZeroAsMissingValueRange < threshold_[node])) {
str_buf << "if (fval <= " << threshold_[node] << ") {";
} else if (missing_type == 1) {
if (default_left) {
str_buf << "if (fval <= " << threshold_[node] << " || Tree::IsZero(fval)" << " || std::isnan(fval)) {";
} else { } else {
fval = threshold + 1; str_buf << "if (fval <= " << threshold_[node] << " && !Tree::IsZero(fval)" << " && !std::isnan(fval)) {";
} }
} else {
if (default_left) {
str_buf << "if (fval <= " << threshold_[node] << " || std::isnan(fval)) {";
} else {
str_buf << "if (fval <= " << threshold_[node] << " && !std::isnan(fval)) {";
} }
return fval; }
return str_buf.str();
} }
inline static double ConvertMissingValue(double fval, double threshold, int8_t decision_type) { inline std::string CategoricalDecisionIfElse(int node) const {
uint8_t missing_type = GetMissingType(decision_type); uint8_t missing_type = GetMissingType(decision_type_[node]);
std::stringstream str_buf;
if (missing_type == 2) {
str_buf << "if (std::isnan(fval)) { int_fval = -1; } else { int_fval = static_cast<int>(fval); }";
} else {
str_buf << "if (std::isnan(fval)) { int_fval = 0; } else { int_fval = static_cast<int>(fval); }";
}
str_buf << "if (int_fval >= 0 && int_fval == " << static_cast<int>(threshold_[node]) << ") {";
return str_buf.str();
}
inline int NumericalDecision(double fval, int node) const {
uint8_t missing_type = GetMissingType(decision_type_[node]);
if (std::isnan(fval)) { if (std::isnan(fval)) {
if (missing_type != 2) { if (missing_type != 2) {
fval = 0.0f; fval = 0.0f;
...@@ -226,28 +250,79 @@ public: ...@@ -226,28 +250,79 @@ public:
} }
if ((missing_type == 1 && IsZero(fval)) if ((missing_type == 1 && IsZero(fval))
|| (missing_type == 2 && std::isnan(fval))) { || (missing_type == 2 && std::isnan(fval))) {
if (GetDecisionType(decision_type, kDefaultLeftMask)) { if (GetDecisionType(decision_type_[node], kDefaultLeftMask)) {
fval = threshold; return left_child_[node];
} else { } else {
fval = 10.0f * threshold; return right_child_[node];
}
} }
if (fval <= threshold_[node]) {
return left_child_[node];
} else {
return right_child_[node];
} }
return fval;
} }
inline static const char* GetDecisionTypeName(int8_t type) { inline int NumericalDecisionInner(uint32_t fval, int node, uint32_t default_bin, uint32_t max_bin) const {
if (type == 0) { uint8_t missing_type = GetMissingType(decision_type_[node]);
return "no_greater"; if ((missing_type == 1 && fval == default_bin)
|| (missing_type == 2 && fval == max_bin)) {
if (GetDecisionType(decision_type_[node], kDefaultLeftMask)) {
return left_child_[node];
} else {
return right_child_[node];
}
}
if (fval <= threshold_in_bin_[node]) {
return left_child_[node];
} else { } else {
return "is"; return right_child_[node];
} }
} }
static std::vector<bool(*)(uint32_t, uint32_t)> inner_decision_funs; inline int CategoricalDecision(double fval, int node) const {
static std::vector<bool(*)(double, double)> decision_funs; uint8_t missing_type = GetMissingType(decision_type_[node]);
int int_fval = static_cast<int>(fval);
if (int_fval < 0) {
return right_child_[node];;
} else if (std::isnan(fval)) {
// NaN is always in the right
if (missing_type == 2) {
return right_child_[node];
}
int_fval = 0;
}
if (int_fval == static_cast<int>(threshold_[node])) {
return left_child_[node];
}
return right_child_[node];
}
private: inline int CategoricalDecisionInner(uint32_t fval, int node) const {
if (fval == threshold_in_bin_[node]) {
return left_child_[node];
}
return right_child_[node];
}
inline int Decision(double fval, int node) const {
if (GetDecisionType(decision_type_[node], kCategoricalMask)) {
return CategoricalDecision(fval, node);
} else {
return NumericalDecision(fval, node);
}
}
inline int DecisionInner(uint32_t fval, int node, uint32_t default_bin, uint32_t max_bin) const {
if (GetDecisionType(decision_type_[node], kCategoricalMask)) {
return CategoricalDecisionInner(fval, node);
} else {
return NumericalDecisionInner(fval, node, default_bin, max_bin);
}
}
inline void Split(int leaf, int feature, int real_feature,
double left_value, double right_value, data_size_t left_cnt, data_size_t right_cnt, double gain);
/*! /*!
* \brief Find leaf index of which record belongs by features * \brief Find leaf index of which record belongs by features
* \param feature_values Feature value of this record * \param feature_values Feature value of this record
...@@ -288,6 +363,7 @@ private: ...@@ -288,6 +363,7 @@ private:
std::vector<uint32_t> threshold_in_bin_; std::vector<uint32_t> threshold_in_bin_;
/*! \brief A non-leaf node's split threshold in feature value */ /*! \brief A non-leaf node's split threshold in feature value */
std::vector<double> threshold_; std::vector<double> threshold_;
int num_cat_;
/*! \brief Store the information for categorical feature handle and mising value handle. */ /*! \brief Store the information for categorical feature handle and mising value handle. */
std::vector<int8_t> decision_type_; std::vector<int8_t> decision_type_;
/*! \brief A non-leaf node's split gain */ /*! \brief A non-leaf node's split gain */
...@@ -306,9 +382,44 @@ private: ...@@ -306,9 +382,44 @@ private:
/*! \brief Depth for leaves */ /*! \brief Depth for leaves */
std::vector<int> leaf_depth_; std::vector<int> leaf_depth_;
double shrinkage_; double shrinkage_;
bool has_categorical_;
}; };
inline void Tree::Split(int leaf, int feature, int real_feature,
double left_value, double right_value, data_size_t left_cnt, data_size_t right_cnt, double gain) {
int new_node_idx = num_leaves_ - 1;
// update parent info
int parent = leaf_parent_[leaf];
if (parent >= 0) {
// if cur node is left child
if (left_child_[parent] == ~leaf) {
left_child_[parent] = new_node_idx;
} else {
right_child_[parent] = new_node_idx;
}
}
// add new node
split_feature_inner_[new_node_idx] = feature;
split_feature_[new_node_idx] = real_feature;
split_gain_[new_node_idx] = Common::AvoidInf(gain);
// add two new leaves
left_child_[new_node_idx] = ~leaf;
right_child_[new_node_idx] = ~num_leaves_;
// update new leaves
leaf_parent_[leaf] = new_node_idx;
leaf_parent_[num_leaves_] = new_node_idx;
// save current leaf value to internal node before change
internal_value_[new_node_idx] = leaf_value_[leaf];
internal_count_[new_node_idx] = left_cnt + right_cnt;
leaf_value_[leaf] = std::isnan(left_value) ? 0.0f : left_value;
leaf_count_[leaf] = left_cnt;
leaf_value_[num_leaves_] = std::isnan(right_value) ? 0.0f : right_value;
leaf_count_[num_leaves_] = right_cnt;
// update leaf depth
leaf_depth_[num_leaves_] = leaf_depth_[leaf] + 1;
leaf_depth_[leaf]++;
}
inline double Tree::Predict(const double* feature_values) const { inline double Tree::Predict(const double* feature_values) const {
if (num_leaves_ > 1) { if (num_leaves_ > 1) {
int leaf = GetLeaf(feature_values); int leaf = GetLeaf(feature_values);
...@@ -409,8 +520,7 @@ inline void Tree::TreeSHAP(const double *feature_values, double *phi, ...@@ -409,8 +520,7 @@ inline void Tree::TreeSHAP(const double *feature_values, double *phi,
// internal node // internal node
} else { } else {
const int hot_index = const int hot_index = Decision(feature_values[split_index], node);
decision_funs[GetDecisionType(decision_type_[node], kCategoricalMask)](feature_values[split_index], threshold_[node]);
const int cold_index = (hot_index == left_child_[node] ? right_child_[node] : left_child_[node]); const int cold_index = (hot_index == left_child_[node] ? right_child_[node] : left_child_[node]);
const double w = data_count(node); const double w = data_count(node);
const double hot_zero_fraction = data_count(hot_index)/w; const double hot_zero_fraction = data_count(hot_index)/w;
...@@ -469,27 +579,13 @@ inline int Tree::MaxDepth() const { ...@@ -469,27 +579,13 @@ inline int Tree::MaxDepth() const {
inline int Tree::GetLeaf(const double* feature_values) const { inline int Tree::GetLeaf(const double* feature_values) const {
int node = 0; int node = 0;
if (has_categorical_) { if (num_cat_ > 0) {
while (node >= 0) { while (node >= 0) {
double fval = ConvertMissingValue(feature_values[split_feature_[node]], threshold_[node], decision_type_[node]); node = Decision(feature_values[split_feature_[node]], node);
if (decision_funs[GetDecisionType(decision_type_[node], kCategoricalMask)](
fval,
threshold_[node])) {
node = left_child_[node];
} else {
node = right_child_[node];
}
} }
} else { } else {
while (node >= 0) { while (node >= 0) {
double fval = ConvertMissingValue(feature_values[split_feature_[node]], threshold_[node], decision_type_[node]); node = NumericalDecision(feature_values[split_feature_[node]], node);
if (NumericalDecision<double>(
fval,
threshold_[node])) {
node = left_child_[node];
} else {
node = right_child_[node];
}
} }
} }
return ~node; return ~node;
......
...@@ -473,7 +473,7 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is ...@@ -473,7 +473,7 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
auto label = train_data_->metadata().label(); auto label = train_data_->metadata().label();
double init_score = ObtainAutomaticInitialScore(objective_function_, label, num_data_); double init_score = ObtainAutomaticInitialScore(objective_function_, label, num_data_);
std::unique_ptr<Tree> new_tree(new Tree(2)); std::unique_ptr<Tree> new_tree(new Tree(2));
new_tree->Split(0, 0, BinType::NumericalBin, 0, 0, 0, init_score, init_score, 0, 0, -1, MissingType::None, true); new_tree->Split(0, 0, 0, 0, 0, init_score, init_score, 0, 0, -1, MissingType::None, true);
train_score_updater_->AddScore(init_score, 0); train_score_updater_->AddScore(init_score, 0);
for (auto& score_updater : valid_score_updater_) { for (auto& score_updater : valid_score_updater_) {
score_updater->AddScore(init_score, 0); score_updater->AddScore(init_score, 0);
...@@ -553,7 +553,7 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is ...@@ -553,7 +553,7 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
// only add default score one-time // only add default score one-time
if (!class_need_train_[cur_tree_id] && models_.size() < static_cast<size_t>(num_tree_per_iteration_)) { if (!class_need_train_[cur_tree_id] && models_.size() < static_cast<size_t>(num_tree_per_iteration_)) {
auto output = class_default_output_[cur_tree_id]; auto output = class_default_output_[cur_tree_id];
new_tree->Split(0, 0, BinType::NumericalBin, 0, 0, 0, new_tree->Split(0, 0, 0, 0, 0,
output, output, 0, 0, -1, MissingType::None, true); output, output, 0, 0, -1, MissingType::None, true);
train_score_updater_->AddScore(output, cur_tree_id); train_score_updater_->AddScore(output, cur_tree_id);
for (auto& score_updater : valid_score_updater_) { for (auto& score_updater : valid_score_updater_) {
......
...@@ -127,7 +127,7 @@ public: ...@@ -127,7 +127,7 @@ public:
if (!class_need_train_[cur_tree_id] && models_.size() < static_cast<size_t>(num_tree_per_iteration_)) { if (!class_need_train_[cur_tree_id] && models_.size() < static_cast<size_t>(num_tree_per_iteration_)) {
double output = class_default_output_[cur_tree_id]; double output = class_default_output_[cur_tree_id];
objective_function_->ConvertOutput(&output, &output); objective_function_->ConvertOutput(&output, &output);
new_tree->Split(0, 0, BinType::NumericalBin, 0, 0, 0, new_tree->Split(0, 0, 0, 0, 0,
output, output, 0, 0, -1, MissingType::None, true); output, output, 0, 0, -1, MissingType::None, true);
train_score_updater_->AddScore(output, cur_tree_id); train_score_updater_->AddScore(output, cur_tree_id);
for (auto& score_updater : valid_score_updater_) { for (auto& score_updater : valid_score_updater_) {
......
...@@ -190,11 +190,11 @@ public: ...@@ -190,11 +190,11 @@ public:
virtual data_size_t Split( virtual data_size_t Split(
uint32_t min_bin, uint32_t max_bin, uint32_t default_bin, MissingType missing_type, bool default_left, uint32_t min_bin, uint32_t max_bin, uint32_t default_bin, MissingType missing_type, bool default_left,
uint32_t threshold, data_size_t* data_indices, data_size_t num_data, uint32_t threshold, data_size_t* data_indices, data_size_t num_data,
data_size_t* lte_indices, data_size_t* gt_indices, BinType bin_type) const override { data_size_t* lte_indices, data_size_t* gt_indices) const override {
if (num_data <= 0) { return 0; } if (num_data <= 0) { return 0; }
VAL_T th = static_cast<VAL_T>(threshold + min_bin); VAL_T th = static_cast<VAL_T>(threshold + min_bin);
VAL_T minb = static_cast<VAL_T>(min_bin); const VAL_T minb = static_cast<VAL_T>(min_bin);
VAL_T maxb = static_cast<VAL_T>(max_bin); const VAL_T maxb = static_cast<VAL_T>(max_bin);
VAL_T t_default_bin = static_cast<VAL_T>(min_bin + default_bin); VAL_T t_default_bin = static_cast<VAL_T>(min_bin + default_bin);
if (default_bin == 0) { if (default_bin == 0) {
th -= 1; th -= 1;
...@@ -204,16 +204,11 @@ public: ...@@ -204,16 +204,11 @@ public:
data_size_t gt_count = 0; data_size_t gt_count = 0;
data_size_t* default_indices = gt_indices; data_size_t* default_indices = gt_indices;
data_size_t* default_count = &gt_count; data_size_t* default_count = &gt_count;
if (bin_type == BinType::NumericalBin) { if (missing_type == MissingType::NaN) {
if (missing_type != MissingType::Zero && default_bin <= threshold) { if (default_bin <= threshold) {
default_indices = lte_indices;
default_count = &lte_count;
}
if (default_left && missing_type == MissingType::Zero) {
default_indices = lte_indices; default_indices = lte_indices;
default_count = &lte_count; default_count = &lte_count;
} }
if (missing_type == MissingType::NaN) {
data_size_t* missing_default_indices = gt_indices; data_size_t* missing_default_indices = gt_indices;
data_size_t* missing_default_count = &gt_count; data_size_t* missing_default_count = &gt_count;
if (default_left) { if (default_left) {
...@@ -222,7 +217,7 @@ public: ...@@ -222,7 +217,7 @@ public:
} }
for (data_size_t i = 0; i < num_data; ++i) { for (data_size_t i = 0; i < num_data; ++i) {
const data_size_t idx = data_indices[i]; const data_size_t idx = data_indices[i];
VAL_T bin = data_[idx]; const VAL_T bin = data_[idx];
if (bin < minb || bin > maxb || t_default_bin == bin) { if (bin < minb || bin > maxb || t_default_bin == bin) {
default_indices[(*default_count)++] = idx; default_indices[(*default_count)++] = idx;
} else if (bin == maxb) { } else if (bin == maxb) {
...@@ -234,9 +229,13 @@ public: ...@@ -234,9 +229,13 @@ public:
} }
} }
} else { } else {
if ((default_left && missing_type == MissingType::Zero) || (default_bin <= threshold && missing_type != MissingType::Zero)) {
default_indices = lte_indices;
default_count = &lte_count;
}
for (data_size_t i = 0; i < num_data; ++i) { for (data_size_t i = 0; i < num_data; ++i) {
const data_size_t idx = data_indices[i]; const data_size_t idx = data_indices[i];
VAL_T bin = data_[idx]; const VAL_T bin = data_[idx];
if (bin < minb || bin > maxb || t_default_bin == bin) { if (bin < minb || bin > maxb || t_default_bin == bin) {
default_indices[(*default_count)++] = idx; default_indices[(*default_count)++] = idx;
} else if (bin > th) { } else if (bin > th) {
...@@ -246,21 +245,31 @@ public: ...@@ -246,21 +245,31 @@ public:
} }
} }
} }
} else { return lte_count;
if (default_bin == threshold) { }
virtual data_size_t SplitCategorical(
uint32_t min_bin, uint32_t max_bin, uint32_t default_bin,
uint32_t threshold, data_size_t* data_indices, data_size_t num_data,
data_size_t* lte_indices, data_size_t* gt_indices) const override {
if (num_data <= 0) { return 0; }
data_size_t lte_count = 0;
data_size_t gt_count = 0;
data_size_t* default_indices = gt_indices;
data_size_t* default_count = &gt_count;
if (threshold == default_bin) {
default_indices = lte_indices; default_indices = lte_indices;
default_count = &lte_count; default_count = &lte_count;
} }
for (data_size_t i = 0; i < num_data; ++i) { for (data_size_t i = 0; i < num_data; ++i) {
const data_size_t idx = data_indices[i]; const data_size_t idx = data_indices[i];
VAL_T bin = data_[idx]; const uint32_t bin = data_[idx];
if (bin < minb || bin > maxb || t_default_bin == bin) { if (bin < min_bin || bin > max_bin) {
default_indices[(*default_count)++] = idx; default_indices[(*default_count)++] = idx;
} else if (bin != th) { } else if (bin - min_bin == threshold) {
gt_indices[gt_count++] = idx;
} else {
lte_indices[lte_count++] = idx; lte_indices[lte_count++] = idx;
} } else {
gt_indices[gt_count++] = idx;
} }
} }
return lte_count; return lte_count;
......
...@@ -229,11 +229,11 @@ public: ...@@ -229,11 +229,11 @@ public:
virtual data_size_t Split( virtual data_size_t Split(
uint32_t min_bin, uint32_t max_bin, uint32_t default_bin, MissingType missing_type, bool default_left, uint32_t min_bin, uint32_t max_bin, uint32_t default_bin, MissingType missing_type, bool default_left,
uint32_t threshold, data_size_t* data_indices, data_size_t num_data, uint32_t threshold, data_size_t* data_indices, data_size_t num_data,
data_size_t* lte_indices, data_size_t* gt_indices, BinType bin_type) const override { data_size_t* lte_indices, data_size_t* gt_indices) const override {
if (num_data <= 0) { return 0; } if (num_data <= 0) { return 0; }
uint8_t th = static_cast<uint8_t>(threshold + min_bin); uint8_t th = static_cast<uint8_t>(threshold + min_bin);
uint8_t minb = static_cast<uint8_t>(min_bin); const uint8_t minb = static_cast<uint8_t>(min_bin);
uint8_t maxb = static_cast<uint8_t>(max_bin); const uint8_t maxb = static_cast<uint8_t>(max_bin);
uint8_t t_default_bin = static_cast<uint8_t>(min_bin + default_bin); uint8_t t_default_bin = static_cast<uint8_t>(min_bin + default_bin);
if (default_bin == 0) { if (default_bin == 0) {
th -= 1; th -= 1;
...@@ -243,16 +243,11 @@ public: ...@@ -243,16 +243,11 @@ public:
data_size_t gt_count = 0; data_size_t gt_count = 0;
data_size_t* default_indices = gt_indices; data_size_t* default_indices = gt_indices;
data_size_t* default_count = &gt_count; data_size_t* default_count = &gt_count;
if (bin_type == BinType::NumericalBin) { if (missing_type == MissingType::NaN) {
if (missing_type != MissingType::Zero && default_bin <= threshold) { if (default_bin <= threshold) {
default_indices = lte_indices;
default_count = &lte_count;
}
if (default_left && missing_type == MissingType::Zero) {
default_indices = lte_indices; default_indices = lte_indices;
default_count = &lte_count; default_count = &lte_count;
} }
if (missing_type == MissingType::NaN) {
data_size_t* missing_default_indices = gt_indices; data_size_t* missing_default_indices = gt_indices;
data_size_t* missing_default_count = &gt_count; data_size_t* missing_default_count = &gt_count;
if (default_left) { if (default_left) {
...@@ -261,7 +256,7 @@ public: ...@@ -261,7 +256,7 @@ public:
} }
for (data_size_t i = 0; i < num_data; ++i) { for (data_size_t i = 0; i < num_data; ++i) {
const data_size_t idx = data_indices[i]; const data_size_t idx = data_indices[i];
const auto bin = (data_[idx >> 1] >> ((idx & 1) << 2)) & 0xf; const uint8_t bin = (data_[idx >> 1] >> ((idx & 1) << 2)) & 0xf;
if (bin < minb || bin > maxb || t_default_bin == bin) { if (bin < minb || bin > maxb || t_default_bin == bin) {
default_indices[(*default_count)++] = idx; default_indices[(*default_count)++] = idx;
} else if (bin == maxb) { } else if (bin == maxb) {
...@@ -273,9 +268,13 @@ public: ...@@ -273,9 +268,13 @@ public:
} }
} }
} else { } else {
if ((default_left && missing_type == MissingType::Zero) || (default_bin <= threshold && missing_type != MissingType::Zero)) {
default_indices = lte_indices;
default_count = &lte_count;
}
for (data_size_t i = 0; i < num_data; ++i) { for (data_size_t i = 0; i < num_data; ++i) {
const data_size_t idx = data_indices[i]; const data_size_t idx = data_indices[i];
const auto bin = (data_[idx >> 1] >> ((idx & 1) << 2)) & 0xf; const uint8_t bin = (data_[idx >> 1] >> ((idx & 1) << 2)) & 0xf;
if (bin < minb || bin > maxb || t_default_bin == bin) { if (bin < minb || bin > maxb || t_default_bin == bin) {
default_indices[(*default_count)++] = idx; default_indices[(*default_count)++] = idx;
} else if (bin > th) { } else if (bin > th) {
...@@ -285,25 +284,36 @@ public: ...@@ -285,25 +284,36 @@ public:
} }
} }
} }
} else { return lte_count;
}
virtual data_size_t SplitCategorical(
uint32_t min_bin, uint32_t max_bin, uint32_t default_bin,
uint32_t threshold, data_size_t* data_indices, data_size_t num_data,
data_size_t* lte_indices, data_size_t* gt_indices) const override {
if (num_data <= 0) { return 0; }
data_size_t lte_count = 0;
data_size_t gt_count = 0;
data_size_t* default_indices = gt_indices;
data_size_t* default_count = &gt_count;
if (default_bin == threshold) { if (default_bin == threshold) {
default_indices = lte_indices; default_indices = lte_indices;
default_count = &lte_count; default_count = &lte_count;
} }
for (data_size_t i = 0; i < num_data; ++i) { for (data_size_t i = 0; i < num_data; ++i) {
const data_size_t idx = data_indices[i]; const data_size_t idx = data_indices[i];
const auto bin = (data_[idx >> 1] >> ((idx & 1) << 2)) & 0xf; const uint32_t bin = (data_[idx >> 1] >> ((idx & 1) << 2)) & 0xf;
if (bin < minb || bin > maxb || t_default_bin == bin) { if (bin < min_bin || bin > max_bin) {
default_indices[(*default_count)++] = idx; default_indices[(*default_count)++] = idx;
} else if (bin != th) { } else if (bin - min_bin == threshold) {
gt_indices[gt_count++] = idx;
} else {
lte_indices[lte_count++] = idx; lte_indices[lte_count++] = idx;
} } else {
gt_indices[gt_count++] = idx;
} }
} }
return lte_count; return lte_count;
} }
data_size_t num_data() const override { return num_data_; } data_size_t num_data() const override { return num_data_; }
/*! \brief not ordered bin for dense feature */ /*! \brief not ordered bin for dense feature */
......
...@@ -144,12 +144,12 @@ public: ...@@ -144,12 +144,12 @@ public:
virtual data_size_t Split( virtual data_size_t Split(
uint32_t min_bin, uint32_t max_bin, uint32_t default_bin, MissingType missing_type, bool default_left, uint32_t min_bin, uint32_t max_bin, uint32_t default_bin, MissingType missing_type, bool default_left,
uint32_t threshold, data_size_t* data_indices, data_size_t num_data, uint32_t threshold, data_size_t* data_indices, data_size_t num_data,
data_size_t* lte_indices, data_size_t* gt_indices, BinType bin_type) const override { data_size_t* lte_indices, data_size_t* gt_indices) const override {
// not need to split // not need to split
if (num_data <= 0) { return 0; } if (num_data <= 0) { return 0; }
VAL_T th = static_cast<VAL_T>(threshold + min_bin); VAL_T th = static_cast<VAL_T>(threshold + min_bin);
VAL_T minb = static_cast<VAL_T>(min_bin); const VAL_T minb = static_cast<VAL_T>(min_bin);
VAL_T maxb = static_cast<VAL_T>(max_bin); const VAL_T maxb = static_cast<VAL_T>(max_bin);
VAL_T t_default_bin = static_cast<VAL_T>(min_bin + default_bin); VAL_T t_default_bin = static_cast<VAL_T>(min_bin + default_bin);
if (default_bin == 0) { if (default_bin == 0) {
th -= 1; th -= 1;
...@@ -160,16 +160,11 @@ public: ...@@ -160,16 +160,11 @@ public:
data_size_t gt_count = 0; data_size_t gt_count = 0;
data_size_t* default_indices = gt_indices; data_size_t* default_indices = gt_indices;
data_size_t* default_count = &gt_count; data_size_t* default_count = &gt_count;
if (bin_type == BinType::NumericalBin) { if (missing_type == MissingType::NaN) {
if (missing_type != MissingType::Zero && default_bin <= threshold) { if (default_bin <= threshold) {
default_indices = lte_indices;
default_count = &lte_count;
}
if (default_left && missing_type == MissingType::Zero) {
default_indices = lte_indices; default_indices = lte_indices;
default_count = &lte_count; default_count = &lte_count;
} }
if (missing_type == MissingType::NaN) {
data_size_t* missing_default_indices = gt_indices; data_size_t* missing_default_indices = gt_indices;
data_size_t* missing_default_count = &gt_count; data_size_t* missing_default_count = &gt_count;
if (default_left) { if (default_left) {
...@@ -178,7 +173,7 @@ public: ...@@ -178,7 +173,7 @@ public:
} }
for (data_size_t i = 0; i < num_data; ++i) { for (data_size_t i = 0; i < num_data; ++i) {
const data_size_t idx = data_indices[i]; const data_size_t idx = data_indices[i];
VAL_T bin = iterator.InnerRawGet(idx); const VAL_T bin = iterator.InnerRawGet(idx);
if (bin < minb || bin > maxb || t_default_bin == bin) { if (bin < minb || bin > maxb || t_default_bin == bin) {
default_indices[(*default_count)++] = idx; default_indices[(*default_count)++] = idx;
} else if (bin == maxb) { } else if (bin == maxb) {
...@@ -190,9 +185,13 @@ public: ...@@ -190,9 +185,13 @@ public:
} }
} }
} else { } else {
if ((default_left && missing_type == MissingType::Zero) || (default_bin <= threshold && missing_type != MissingType::Zero)) {
default_indices = lte_indices;
default_count = &lte_count;
}
for (data_size_t i = 0; i < num_data; ++i) { for (data_size_t i = 0; i < num_data; ++i) {
const data_size_t idx = data_indices[i]; const data_size_t idx = data_indices[i];
VAL_T bin = iterator.InnerRawGet(idx); const VAL_T bin = iterator.InnerRawGet(idx);
if (bin < minb || bin > maxb || t_default_bin == bin) { if (bin < minb || bin > maxb || t_default_bin == bin) {
default_indices[(*default_count)++] = idx; default_indices[(*default_count)++] = idx;
} else if (bin > th) { } else if (bin > th) {
...@@ -202,21 +201,32 @@ public: ...@@ -202,21 +201,32 @@ public:
} }
} }
} }
} else { return lte_count;
}
virtual data_size_t SplitCategorical(
uint32_t min_bin, uint32_t max_bin, uint32_t default_bin,
uint32_t threshold, data_size_t* data_indices, data_size_t num_data,
data_size_t* lte_indices, data_size_t* gt_indices) const override {
if (num_data <= 0) { return 0; }
data_size_t lte_count = 0;
data_size_t gt_count = 0;
SparseBinIterator<VAL_T> iterator(this, data_indices[0]);
data_size_t* default_indices = gt_indices;
data_size_t* default_count = &gt_count;
if (default_bin == threshold) { if (default_bin == threshold) {
default_indices = lte_indices; default_indices = lte_indices;
default_count = &lte_count; default_count = &lte_count;
} }
for (data_size_t i = 0; i < num_data; ++i) { for (data_size_t i = 0; i < num_data; ++i) {
const data_size_t idx = data_indices[i]; const data_size_t idx = data_indices[i];
VAL_T bin = iterator.InnerRawGet(idx); uint32_t bin = iterator.InnerRawGet(idx);
if (bin < minb || bin > maxb || t_default_bin == bin) { if (bin < min_bin || bin > max_bin) {
default_indices[(*default_count)++] = idx; default_indices[(*default_count)++] = idx;
} else if (bin != th) { } else if (bin - min_bin == threshold) {
gt_indices[gt_count++] = idx;
} else {
lte_indices[lte_count++] = idx; lte_indices[lte_count++] = idx;
} } else {
gt_indices[gt_count++] = idx;
} }
} }
return lte_count; return lte_count;
......
This diff is collapsed.
...@@ -84,9 +84,9 @@ void DCGCalculator::CalMaxDCG(const std::vector<data_size_t>& ks, ...@@ -84,9 +84,9 @@ void DCGCalculator::CalMaxDCG(const std::vector<data_size_t>& ks,
double DCGCalculator::CalDCGAtK(data_size_t k, const float* label, double DCGCalculator::CalDCGAtK(data_size_t k, const float* label,
const double* score, data_size_t num_data) { const double* score, data_size_t num_data) {
// get sorted indices by score // get sorted indices by score
std::vector<data_size_t> sorted_idx; std::vector<data_size_t> sorted_idx(num_data);
for (data_size_t i = 0; i < num_data; ++i) { for (data_size_t i = 0; i < num_data; ++i) {
sorted_idx.emplace_back(i); sorted_idx[i] = i;
} }
std::sort(sorted_idx.begin(), sorted_idx.end(), std::sort(sorted_idx.begin(), sorted_idx.end(),
[score](data_size_t a, data_size_t b) {return score[a] > score[b]; }); [score](data_size_t a, data_size_t b) {return score[a] > score[b]; });
...@@ -104,9 +104,9 @@ double DCGCalculator::CalDCGAtK(data_size_t k, const float* label, ...@@ -104,9 +104,9 @@ double DCGCalculator::CalDCGAtK(data_size_t k, const float* label,
void DCGCalculator::CalDCG(const std::vector<data_size_t>& ks, const float* label, void DCGCalculator::CalDCG(const std::vector<data_size_t>& ks, const float* label,
const double * score, data_size_t num_data, std::vector<double>* out) { const double * score, data_size_t num_data, std::vector<double>* out) {
// get sorted indices by score // get sorted indices by score
std::vector<data_size_t> sorted_idx; std::vector<data_size_t> sorted_idx(num_data);
for (data_size_t i = 0; i < num_data; ++i) { for (data_size_t i = 0; i < num_data; ++i) {
sorted_idx.emplace_back(i); sorted_idx[i] = i;
} }
std::sort(sorted_idx.begin(), sorted_idx.end(), std::sort(sorted_idx.begin(), sorted_idx.end(),
[score](data_size_t a, data_size_t b) {return score[a] > score[b]; }); [score](data_size_t a, data_size_t b) {return score[a] > score[b]; });
......
...@@ -516,17 +516,17 @@ void SerialTreeLearner::FindBestSplitsFromHistograms(const std::vector<int8_t>& ...@@ -516,17 +516,17 @@ void SerialTreeLearner::FindBestSplitsFromHistograms(const std::vector<int8_t>&
} }
void SerialTreeLearner::Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) { void SerialTreeLearner::Split(Tree* tree, int best_leaf, int* left_leaf, int* right_leaf) {
const SplitInfo& best_split_info = best_split_per_leaf_[best_Leaf]; const SplitInfo& best_split_info = best_split_per_leaf_[best_leaf];
const int inner_feature_index = train_data_->InnerFeatureIndex(best_split_info.feature); const int inner_feature_index = train_data_->InnerFeatureIndex(best_split_info.feature);
// left = parent // left = parent
*left_leaf = best_Leaf; *left_leaf = best_leaf;
if (train_data_->FeatureBinMapper(inner_feature_index)->bin_type() == BinType::NumericalBin) {
// split tree, will return right leaf // split tree, will return right leaf
*right_leaf = tree->Split(best_Leaf, *right_leaf = tree->Split(best_leaf,
inner_feature_index, inner_feature_index,
train_data_->FeatureBinMapper(inner_feature_index)->bin_type(),
best_split_info.threshold,
best_split_info.feature, best_split_info.feature,
best_split_info.threshold,
train_data_->RealThreshold(inner_feature_index, best_split_info.threshold), train_data_->RealThreshold(inner_feature_index, best_split_info.threshold),
static_cast<double>(best_split_info.left_output), static_cast<double>(best_split_info.left_output),
static_cast<double>(best_split_info.right_output), static_cast<double>(best_split_info.right_output),
...@@ -535,8 +535,21 @@ void SerialTreeLearner::Split(Tree* tree, int best_Leaf, int* left_leaf, int* ri ...@@ -535,8 +535,21 @@ void SerialTreeLearner::Split(Tree* tree, int best_Leaf, int* left_leaf, int* ri
static_cast<double>(best_split_info.gain), static_cast<double>(best_split_info.gain),
train_data_->FeatureBinMapper(inner_feature_index)->missing_type(), train_data_->FeatureBinMapper(inner_feature_index)->missing_type(),
best_split_info.default_left); best_split_info.default_left);
} else {
*right_leaf = tree->SplitCategorical(best_leaf,
inner_feature_index,
best_split_info.feature,
best_split_info.threshold,
train_data_->RealThreshold(inner_feature_index, best_split_info.threshold),
static_cast<double>(best_split_info.left_output),
static_cast<double>(best_split_info.right_output),
static_cast<data_size_t>(best_split_info.left_count),
static_cast<data_size_t>(best_split_info.right_count),
static_cast<double>(best_split_info.gain),
train_data_->FeatureBinMapper(inner_feature_index)->missing_type());
}
// split data partition // split data partition
data_partition_->Split(best_Leaf, train_data_, inner_feature_index, data_partition_->Split(best_leaf, train_data_, inner_feature_index,
best_split_info.threshold, best_split_info.default_left, *right_leaf); best_split_info.threshold, best_split_info.default_left, *right_leaf);
// init the leaves that used on next iteration // init the leaves that used on next iteration
......
...@@ -218,6 +218,7 @@ ...@@ -218,6 +218,7 @@
<ClInclude Include="..\src\boosting\gbdt.h" /> <ClInclude Include="..\src\boosting\gbdt.h" />
<ClInclude Include="..\src\boosting\dart.hpp" /> <ClInclude Include="..\src\boosting\dart.hpp" />
<ClInclude Include="..\src\boosting\goss.hpp" /> <ClInclude Include="..\src\boosting\goss.hpp" />
<ClInclude Include="..\src\boosting\rf.hpp" />
<ClInclude Include="..\src\boosting\score_updater.hpp" /> <ClInclude Include="..\src\boosting\score_updater.hpp" />
<ClInclude Include="..\src\io\dense_bin.hpp" /> <ClInclude Include="..\src\io\dense_bin.hpp" />
<ClInclude Include="..\src\io\dense_nbits_bin.hpp" /> <ClInclude Include="..\src\io\dense_nbits_bin.hpp" />
......
...@@ -192,6 +192,9 @@ ...@@ -192,6 +192,9 @@
<ClInclude Include="..\include\LightGBM\R_object_helper.h"> <ClInclude Include="..\include\LightGBM\R_object_helper.h">
<Filter>include\LightGBM</Filter> <Filter>include\LightGBM</Filter>
</ClInclude> </ClInclude>
<ClInclude Include="..\src\boosting\rf.hpp">
<Filter>src\boosting</Filter>
</ClInclude>
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>
<ClCompile Include="..\src\application\application.cpp"> <ClCompile Include="..\src\application\application.cpp">
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment