Unverified Commit bca2da97 authored by Belinda Trotta's avatar Belinda Trotta Committed by GitHub
Browse files

Interaction constraints (#3126)

* Add interaction constraints functionality.

* Minor fixes.

* Minor fixes.

* Change lambda to function.

* Fix gpu bug, remove extra blank lines.

* Fix gpu bug.

* Fix style issues.

* Try to fix segfault on MACOS.

* Fix bug.

* Fix bug.

* Fix bugs.

* Change parameter format for R.

* Fix R style issues.

* Change string formatting code.

* Change docs to say R package not supported.

* Remove R functionality, moving to separate PR.

* Keep track of branch features in tree object.

* Only track branch features when feature interactions are enabled.

* Fix lint error.

* Update docs and simplify tests.
parent f5e51649
......@@ -124,6 +124,10 @@ lgb.train <- function(params = list(),
end_iteration <- begin_iteration + nrounds - 1L
}
if (!is.null(params[["interaction_constraints"]])) {
stop("lgb.train: interaction_constraints is not implemented")
}
# Update parameters with parsed parameters
data$update_params(params)
......
......@@ -538,6 +538,20 @@ Learning Control Parameters
- note that the parent output ``w_p`` itself has smoothing applied, unless it is the root node, so that the smoothing effect accumulates with the tree depth
- ``interaction_constraints`` :raw-html:`<a id="interaction_constraints" title="Permalink to this parameter" href="#interaction_constraints">&#x1F517;&#xFE0E;</a>`, default = ``""``, type = string
- controls which features can appear in the same branch
- by default interaction constraints are disabled, to enable them you can specify
- for CLI, lists separated by commas, e.g. ``[0,1,2],[2,3]``
- for Python-package, list of lists, e.g. ``[[0, 1, 2], [2, 3]]``
- for R-package, **not yet supported**
- any two features can appear in the same branch only if there exists a constraint containing both features
- ``verbosity`` :raw-html:`<a id="verbosity" title="Permalink to this parameter" href="#verbosity">&#x1F517;&#xFE0E;</a>`, default = ``1``, type = int, aliases: ``verbose``
- controls the level of LightGBM's verbosity
......
......@@ -505,6 +505,14 @@ struct Config {
// descl2 = note that the parent output ``w_p`` itself has smoothing applied, unless it is the root node, so that the smoothing effect accumulates with the tree depth
double path_smooth = 0;
// desc = controls which features can appear in the same branch
// desc = by default interaction constraints are disabled, to enable them you can specify
// descl2 = for CLI, lists separated by commas, e.g. ``[0,1,2],[2,3]``
// descl2 = for Python-package, list of lists, e.g. ``[[0, 1, 2], [2, 3]]``
// descl2 = for R-package, **not yet supported**
// desc = any two features can appear in the same branch only if there exists a constraint containing both features
std::string interaction_constraints = "";
// alias = verbose
// desc = controls the level of LightGBM's verbosity
// desc = ``< 0``: Fatal, ``= 0``: Error (Warning), ``= 1``: Info, ``> 1``: Debug
......@@ -958,12 +966,14 @@ struct Config {
static const std::unordered_map<std::string, std::string>& alias_table();
static const std::unordered_set<std::string>& parameter_set();
std::vector<std::vector<double>> auc_mu_weights_matrix;
std::vector<std::vector<int>> interaction_constraints_vector;
private:
void CheckParamConflict();
void GetMembersFromString(const std::unordered_map<std::string, std::string>& params);
std::string SaveMembersToString() const;
void GetAucMuWeights();
void GetInteractionConstraints();
};
inline bool Config::GetString(
......
......@@ -27,8 +27,9 @@ class Tree {
/*!
* \brief Constructor
* \param max_leaves The number of max leaves
* \param track_branch_features Whether to keep track of ancestors of leaf nodes
*/
explicit Tree(int max_leaves);
explicit Tree(int max_leaves, bool track_branch_features);
/*!
* \brief Constructor, from a string
......@@ -148,6 +149,9 @@ class Tree {
/*! \brief Get feature of specific split*/
inline int split_feature(int split_idx) const { return split_feature_[split_idx]; }
/*! \brief Get features on leaf's branch*/
inline std::vector<int> branch_features(int leaf) const { return branch_features_[leaf]; }
inline double split_gain(int split_idx) const { return split_gain_[split_idx]; }
inline double internal_value(int node_idx) const {
......@@ -436,6 +440,10 @@ class Tree {
std::vector<int> internal_count_;
/*! \brief Depth for leaves */
std::vector<int> leaf_depth_;
/*! \brief whether to keep track of ancestor nodes for each leaf (only needed when feature interactions are restricted) */
bool track_branch_features_;
/*! \brief Features on leaf's branch, original index */
std::vector<std::vector<int>> branch_features_;
double shrinkage_;
int max_depth_;
};
......@@ -477,6 +485,11 @@ inline void Tree::Split(int leaf, int feature, int real_feature,
// update leaf depth
leaf_depth_[num_leaves_] = leaf_depth_[leaf] + 1;
leaf_depth_[leaf]++;
if (track_branch_features_) {
branch_features_[num_leaves_] = branch_features_[leaf];
branch_features_[num_leaves_].push_back(split_feature_[new_node_idx]);
branch_features_[leaf].push_back(split_feature_[new_node_idx]);
}
}
inline double Tree::Predict(const double* feature_values) const {
......
......@@ -103,6 +103,30 @@ inline static std::vector<std::string> Split(const char* c_str, char delimiter)
return ret;
}
/*!
 * \brief Extract the substrings enclosed by delimiter pairs, e.g.
 *        "[0,1,2],[2,3]" with '[' and ']' yields {"0,1,2", "2,3"}.
 *        Empty bracket pairs are skipped; text outside brackets is ignored.
 * \param c_str Input C string
 * \param left_delimiter Opening delimiter character
 * \param right_delimiter Closing delimiter character
 * \return Vector of the non-empty bracketed substrings, in order
 */
inline static std::vector<std::string> SplitBrackets(const char* c_str, char left_delimiter, char right_delimiter) {
  std::vector<std::string> result;
  const std::string text(c_str);
  size_t token_start = 0;  // index just past the most recent opening delimiter
  bool inside = false;     // true while scanning between an open/close pair
  for (size_t cur = 0; cur < text.length(); ++cur) {
    const char c = text[cur];
    if (c == left_delimiter) {
      // a new opening delimiter (re)starts the token, even when already inside
      inside = true;
      token_start = cur + 1;
    } else if (inside && c == right_delimiter) {
      if (token_start < cur) {  // skip empty "[]" pairs
        result.push_back(text.substr(token_start, cur - token_start));
      }
      inside = false;
    }
  }
  return result;
}
inline static std::vector<std::string> SplitLines(const char* c_str) {
std::vector<std::string> ret;
std::string str(c_str);
......@@ -503,6 +527,17 @@ inline static std::vector<T> StringToArray(const std::string& str, char delimite
return ret;
}
/*!
 * \brief Parse a string of bracketed, delimiter-separated groups into a
 *        vector of vectors, e.g. "[0,1],[2]" -> {{0,1},{2}}.
 * \param str Input string
 * \param left_bracket Character opening each group
 * \param right_bracket Character closing each group
 * \param delimiter Separator between elements inside a group
 * \return One inner vector per non-empty bracketed group
 */
template<typename T>
inline static std::vector<std::vector<T>> StringToArrayofArrays(
    const std::string& str, char left_bracket, char right_bracket, char delimiter) {
  const std::vector<std::string> groups = SplitBrackets(str.c_str(), left_bracket, right_bracket);
  std::vector<std::vector<T>> result;
  result.reserve(groups.size());
  for (const auto& group : groups) {
    result.push_back(StringToArray<T>(group, delimiter));
  }
  return result;
}
template<typename T>
inline static std::vector<T> StringToArray(const std::string& str, int n) {
if (n == 0) {
......
......@@ -135,7 +135,12 @@ def param_dict_to_str(data):
pairs = []
for key, val in data.items():
if isinstance(val, (list, tuple, set)) or is_numpy_1d_array(val):
pairs.append(str(key) + '=' + ','.join(map(str, val)))
def to_string(x):
    """Render one parameter value: lists become a bracketed CSV, anything else is str()'d."""
    if isinstance(x, list):
        return '[' + ','.join(str(item) for item in x) + ']'
    return str(x)
pairs.append(str(key) + '=' + ','.join(map(to_string, val)))
elif isinstance(val, string_type) or isinstance(val, numeric_types) or is_numeric(val):
pairs.append(str(key) + '=' + str(val))
elif val is not None:
......
......@@ -352,7 +352,7 @@ bool GBDT::TrainOneIter(const score_t* gradients, const score_t* hessians) {
bool should_continue = false;
for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
const size_t offset = static_cast<size_t>(cur_tree_id) * num_data_;
std::unique_ptr<Tree> new_tree(new Tree(2));
std::unique_ptr<Tree> new_tree(new Tree(2, false));
if (class_need_train_[cur_tree_id] && train_data_->num_features() > 0) {
auto grad = gradients + offset;
auto hess = hessians + offset;
......
......@@ -109,7 +109,7 @@ class RF : public GBDT {
gradients = gradients_.data();
hessians = hessians_.data();
for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
std::unique_ptr<Tree> new_tree(new Tree(2));
std::unique_ptr<Tree> new_tree(new Tree(2, false));
size_t offset = static_cast<size_t>(cur_tree_id)* num_data_;
if (class_need_train_[cur_tree_id]) {
auto grad = gradients + offset;
......
......@@ -180,6 +180,14 @@ void Config::GetAucMuWeights() {
}
}
// Parse the ``interaction_constraints`` parameter string (e.g. "[0,1,2],[2,3]")
// into ``interaction_constraints_vector``; an empty string means constraints
// are disabled and the vector is left empty.
void Config::GetInteractionConstraints() {
  // idiomatic emptiness check instead of comparing against "" and allocating a
  // fresh vector; clear() preserves capacity and has identical semantics
  if (interaction_constraints.empty()) {
    interaction_constraints_vector.clear();
  } else {
    interaction_constraints_vector = Common::StringToArrayofArrays<int>(interaction_constraints, '[', ']', ',');
  }
}
void Config::Set(const std::unordered_map<std::string, std::string>& params) {
// generate seeds by seed.
if (GetInt(params, "seed", &seed)) {
......@@ -204,6 +212,8 @@ void Config::Set(const std::unordered_map<std::string, std::string>& params) {
GetAucMuWeights();
GetInteractionConstraints();
// sort eval_at
std::sort(eval_at.begin(), eval_at.end());
......
......@@ -230,6 +230,7 @@ const std::unordered_set<std::string>& Config::parameter_set() {
"cegb_penalty_feature_lazy",
"cegb_penalty_feature_coupled",
"path_smooth",
"interaction_constraints",
"verbosity",
"input_model",
"output_model",
......@@ -454,6 +455,8 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
GetDouble(params, "path_smooth", &path_smooth);
CHECK_GE(path_smooth, 0.0);
GetString(params, "interaction_constraints", &interaction_constraints);
GetInt(params, "verbosity", &verbosity);
GetString(params, "input_model", &input_model);
......@@ -659,6 +662,7 @@ std::string Config::SaveMembersToString() const {
str_buf << "[cegb_penalty_feature_lazy: " << Common::Join(cegb_penalty_feature_lazy, ",") << "]\n";
str_buf << "[cegb_penalty_feature_coupled: " << Common::Join(cegb_penalty_feature_coupled, ",") << "]\n";
str_buf << "[path_smooth: " << path_smooth << "]\n";
str_buf << "[interaction_constraints: " << interaction_constraints << "]\n";
str_buf << "[verbosity: " << verbosity << "]\n";
str_buf << "[max_bin: " << max_bin << "]\n";
str_buf << "[max_bin_by_feature: " << Common::Join(max_bin_by_feature, ",") << "]\n";
......
......@@ -14,8 +14,8 @@
namespace LightGBM {
Tree::Tree(int max_leaves)
:max_leaves_(max_leaves) {
Tree::Tree(int max_leaves, bool track_branch_features)
:max_leaves_(max_leaves), track_branch_features_(track_branch_features) {
left_child_.resize(max_leaves_ - 1);
right_child_.resize(max_leaves_ - 1);
split_feature_inner_.resize(max_leaves_ - 1);
......@@ -32,6 +32,9 @@ Tree::Tree(int max_leaves)
internal_weight_.resize(max_leaves_ - 1);
internal_count_.resize(max_leaves_ - 1);
leaf_depth_.resize(max_leaves_);
if (track_branch_features_) {
branch_features_ = std::vector<std::vector<int>>(max_leaves_);
}
// root is in the depth 0
leaf_depth_[0] = 0;
num_leaves_ = 1;
......
......@@ -13,6 +13,7 @@
#include <LightGBM/utils/random.h>
#include <algorithm>
#include <unordered_set>
#include <vector>
namespace LightGBM {
......@@ -23,6 +24,10 @@ class ColSampler {
fraction_bynode_(config->feature_fraction_bynode),
seed_(config->feature_fraction_seed),
random_(config->feature_fraction_seed) {
for (auto constraint : config->interaction_constraints_vector) {
std::unordered_set<int> constraint_set(constraint.begin(), constraint.end());
interaction_constraints_.push_back(constraint_set);
}
}
static int GetCnt(size_t total_cnt, double fraction) {
......@@ -83,32 +88,87 @@ class ColSampler {
}
}
std::vector<int8_t> GetByNode() {
std::vector<int8_t> GetByNode(const Tree* tree, int leaf) {
// get interaction constraints for current branch
std::unordered_set<int> allowed_features;
if (!interaction_constraints_.empty()) {
std::vector<int> branch_features = tree->branch_features(leaf);
allowed_features.insert(branch_features.begin(), branch_features.end());
for (auto constraint : interaction_constraints_) {
int num_feat_found = 0;
if (branch_features.size() == 0) {
allowed_features.insert(constraint.begin(), constraint.end());
}
for (int feat : branch_features) {
if (constraint.count(feat) == 0) { break; }
++num_feat_found;
if (num_feat_found == static_cast<int>(branch_features.size())) {
allowed_features.insert(constraint.begin(), constraint.end());
break;
}
}
}
}
std::vector<int8_t> ret(train_data_->num_features(), 0);
if (fraction_bynode_ >= 1.0f) {
if (interaction_constraints_.empty()) {
return std::vector<int8_t>(train_data_->num_features(), 1);
} else {
for (int feat : allowed_features) {
int inner_feat = train_data_->InnerFeatureIndex(feat);
ret[inner_feat] = 1;
}
return ret;
}
}
std::vector<int8_t> ret(train_data_->num_features(), 0);
if (need_reset_bytree_) {
auto used_feature_cnt = GetCnt(used_feature_indices_.size(), fraction_bynode_);
std::vector<int>* allowed_used_feature_indices;
std::vector<int> filtered_feature_indices;
if (interaction_constraints_.empty()) {
allowed_used_feature_indices = &used_feature_indices_;
} else {
for (int feat_ind : used_feature_indices_) {
if (allowed_features.count(valid_feature_indices_[feat_ind]) == 1) {
filtered_feature_indices.push_back(feat_ind);
}
}
used_feature_cnt = std::min(used_feature_cnt, static_cast<int>(filtered_feature_indices.size()));
allowed_used_feature_indices = &filtered_feature_indices;
}
auto sampled_indices = random_.Sample(
static_cast<int>(used_feature_indices_.size()), used_feature_cnt);
static_cast<int>((*allowed_used_feature_indices).size()), used_feature_cnt);
int omp_loop_size = static_cast<int>(sampled_indices.size());
#pragma omp parallel for schedule(static, 512) if (omp_loop_size >= 1024)
for (int i = 0; i < omp_loop_size; ++i) {
int used_feature =
valid_feature_indices_[used_feature_indices_[sampled_indices[i]]];
valid_feature_indices_[(*allowed_used_feature_indices)[sampled_indices[i]]];
int inner_feature_index = train_data_->InnerFeatureIndex(used_feature);
ret[inner_feature_index] = 1;
}
} else {
auto used_feature_cnt =
GetCnt(valid_feature_indices_.size(), fraction_bynode_);
std::vector<int>* allowed_valid_feature_indices;
std::vector<int> filtered_feature_indices;
if (interaction_constraints_.empty()) {
allowed_valid_feature_indices = &valid_feature_indices_;
} else {
for (int feat : valid_feature_indices_) {
if (allowed_features.count(feat) == 1) {
filtered_feature_indices.push_back(feat);
}
}
allowed_valid_feature_indices = &filtered_feature_indices;
used_feature_cnt = std::min(used_feature_cnt, static_cast<int>(filtered_feature_indices.size()));
}
auto sampled_indices = random_.Sample(
static_cast<int>(valid_feature_indices_.size()), used_feature_cnt);
static_cast<int>((*allowed_valid_feature_indices).size()), used_feature_cnt);
int omp_loop_size = static_cast<int>(sampled_indices.size());
#pragma omp parallel for schedule(static, 512) if (omp_loop_size >= 1024)
for (int i = 0; i < omp_loop_size; ++i) {
int used_feature = valid_feature_indices_[sampled_indices[i]];
int used_feature = (*allowed_valid_feature_indices)[sampled_indices[i]];
int inner_feature_index = train_data_->InnerFeatureIndex(used_feature);
ret[inner_feature_index] = 1;
}
......@@ -135,6 +195,8 @@ class ColSampler {
std::vector<int8_t> is_feature_used_;
std::vector<int> used_feature_indices_;
std::vector<int> valid_feature_indices_;
/*! \brief interaction constraints index in original (raw data) features */
std::vector<std::unordered_set<int>> interaction_constraints_;
};
} // namespace LightGBM
......
......@@ -152,7 +152,7 @@ void DataParallelTreeLearner<TREELEARNER_T>::BeforeTrain() {
}
template <typename TREELEARNER_T>
void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplits() {
void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplits(const Tree* tree) {
TREELEARNER_T::ConstructHistograms(
this->col_sampler_.is_feature_used_bytree(), true);
// construct local histograms
......@@ -169,17 +169,17 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplits() {
Network::ReduceScatter(input_buffer_.data(), reduce_scatter_size_, sizeof(hist_t), block_start_.data(),
block_len_.data(), output_buffer_.data(), static_cast<comm_size_t>(output_buffer_.size()), &HistogramSumReducer);
this->FindBestSplitsFromHistograms(
this->col_sampler_.is_feature_used_bytree(), true);
this->col_sampler_.is_feature_used_bytree(), true, tree);
}
template <typename TREELEARNER_T>
void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const std::vector<int8_t>&, bool) {
void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const std::vector<int8_t>&, bool, const Tree* tree) {
std::vector<SplitInfo> smaller_bests_per_thread(this->share_state_->num_threads);
std::vector<SplitInfo> larger_bests_per_thread(this->share_state_->num_threads);
std::vector<int8_t> smaller_node_used_features =
this->col_sampler_.GetByNode();
this->col_sampler_.GetByNode(tree, this->smaller_leaf_splits_->leaf_index());
std::vector<int8_t> larger_node_used_features =
this->col_sampler_.GetByNode();
this->col_sampler_.GetByNode(tree, this->larger_leaf_splits_->leaf_index());
OMP_INIT_EX();
#pragma omp parallel for schedule(static)
for (int feature_index = 0; feature_index < this->num_features_; ++feature_index) {
......
......@@ -57,8 +57,9 @@ void FeatureParallelTreeLearner<TREELEARNER_T>::BeforeTrain() {
}
template <typename TREELEARNER_T>
void FeatureParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) {
TREELEARNER_T::FindBestSplitsFromHistograms(is_feature_used, use_subtract);
void FeatureParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(
const std::vector<int8_t>& is_feature_used, bool use_subtract, const Tree* tree) {
TREELEARNER_T::FindBestSplitsFromHistograms(is_feature_used, use_subtract, tree);
SplitInfo smaller_best_split, larger_best_split;
// get best split at smaller leaf
smaller_best_split = this->best_split_per_leaf_[this->smaller_leaf_splits_->leaf_index()];
......
......@@ -1055,8 +1055,8 @@ void GPUTreeLearner::ConstructHistograms(const std::vector<int8_t>& is_feature_u
}
}
void GPUTreeLearner::FindBestSplits() {
SerialTreeLearner::FindBestSplits();
void GPUTreeLearner::FindBestSplits(const Tree* tree) {
SerialTreeLearner::FindBestSplits(tree);
#if GPU_DEBUG >= 3
for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
......
......@@ -66,7 +66,7 @@ class GPUTreeLearner: public SerialTreeLearner {
protected:
void BeforeTrain() override;
bool BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) override;
void FindBestSplits() override;
void FindBestSplits(const Tree* tree) override;
void Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) override;
void ConstructHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) override;
......
......@@ -31,7 +31,7 @@ class FeatureParallelTreeLearner: public TREELEARNER_T {
protected:
void BeforeTrain() override;
void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) override;
void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract, const Tree* tree) override;
private:
/*! \brief rank of local machine */
......@@ -59,8 +59,8 @@ class DataParallelTreeLearner: public TREELEARNER_T {
protected:
void BeforeTrain() override;
void FindBestSplits() override;
void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) override;
void FindBestSplits(const Tree* tree) override;
void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract, const Tree* tree) override;
void Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) override;
inline data_size_t GetGlobalDataCountInLeaf(int leaf_idx) const override {
......@@ -114,8 +114,8 @@ class VotingParallelTreeLearner: public TREELEARNER_T {
protected:
void BeforeTrain() override;
bool BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) override;
void FindBestSplits() override;
void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) override;
void FindBestSplits(const Tree* tree) override;
void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract, const Tree* tree) override;
void Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) override;
inline data_size_t GetGlobalDataCountInLeaf(int leaf_idx) const override {
......
......@@ -163,7 +163,8 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians
// some initial works before training
BeforeTrain();
auto tree = std::unique_ptr<Tree>(new Tree(config_->num_leaves));
bool track_branch_features = !(config_->interaction_constraints_vector.empty());
auto tree = std::unique_ptr<Tree>(new Tree(config_->num_leaves, track_branch_features));
auto tree_prt = tree.get();
constraints_->ShareTreePointer(tree_prt);
......@@ -179,7 +180,7 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians
// some initial works before finding best split
if (BeforeFindBestSplit(tree_prt, left_leaf, right_leaf)) {
// find best threshold for every feature
FindBestSplits();
FindBestSplits(tree_prt);
}
// Get a leaf with max split gain
int best_leaf = static_cast<int>(ArrayArgs<SplitInfo>::ArgMax(best_split_per_leaf_));
......@@ -310,7 +311,7 @@ bool SerialTreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int
return true;
}
void SerialTreeLearner::FindBestSplits() {
void SerialTreeLearner::FindBestSplits(const Tree* tree) {
std::vector<int8_t> is_feature_used(num_features_, 0);
#pragma omp parallel for schedule(static, 256) if (num_features_ >= 512)
for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
......@@ -324,7 +325,7 @@ void SerialTreeLearner::FindBestSplits() {
}
bool use_subtract = parent_leaf_histogram_array_ != nullptr;
ConstructHistograms(is_feature_used, use_subtract);
FindBestSplitsFromHistograms(is_feature_used, use_subtract);
FindBestSplitsFromHistograms(is_feature_used, use_subtract, tree);
}
void SerialTreeLearner::ConstructHistograms(
......@@ -353,13 +354,16 @@ void SerialTreeLearner::ConstructHistograms(
}
void SerialTreeLearner::FindBestSplitsFromHistograms(
const std::vector<int8_t>& is_feature_used, bool use_subtract) {
const std::vector<int8_t>& is_feature_used, bool use_subtract, const Tree* tree) {
Common::FunctionTimer fun_timer(
"SerialTreeLearner::FindBestSplitsFromHistograms", global_timer);
std::vector<SplitInfo> smaller_best(share_state_->num_threads);
std::vector<SplitInfo> larger_best(share_state_->num_threads);
std::vector<int8_t> smaller_node_used_features = col_sampler_.GetByNode();
std::vector<int8_t> larger_node_used_features = col_sampler_.GetByNode();
std::vector<int8_t> smaller_node_used_features = col_sampler_.GetByNode(tree, smaller_leaf_splits_->leaf_index());
std::vector<int8_t> larger_node_used_features;
if (larger_leaf_splits_->leaf_index() >= 0) {
larger_node_used_features = col_sampler_.GetByNode(tree, larger_leaf_splits_->leaf_index());
}
OMP_INIT_EX();
// find splits
#pragma omp parallel for schedule(static) num_threads(share_state_->num_threads)
......@@ -437,7 +441,7 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, int* left_leaf,
// before processing next node from queue, store info for current left/right leaf
// store "best split" for left and right, even if they might be overwritten by forced split
if (BeforeFindBestSplit(tree, *left_leaf, *right_leaf)) {
FindBestSplits();
FindBestSplits(tree);
}
// then, compute own splits
SplitInfo left_split;
......
......@@ -134,11 +134,11 @@ class SerialTreeLearner: public TreeLearner {
*/
virtual bool BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf);
virtual void FindBestSplits();
virtual void FindBestSplits(const Tree* tree);
virtual void ConstructHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract);
virtual void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract);
virtual void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract, const Tree*);
/*!
* \brief Partition tree and data according best split.
......@@ -196,7 +196,6 @@ class SerialTreeLearner: public TreeLearner {
std::unique_ptr<LeafSplits> smaller_leaf_splits_;
/*! \brief stores best thresholds for all feature for larger leaf */
std::unique_ptr<LeafSplits> larger_leaf_splits_;
#ifdef USE_GPU
/*! \brief gradients of current iteration, ordered for cache optimized, aligned to 4K page */
std::vector<score_t, boost::alignment::aligned_allocator<score_t, 4096>> ordered_gradients_;
......
......@@ -241,7 +241,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::CopyLocalHistogram(const std::vec
}
template <typename TREELEARNER_T>
void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits() {
void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits(const Tree* tree) {
// use local data to find local best splits
std::vector<int8_t> is_feature_used(this->num_features_, 0);
#pragma omp parallel for schedule(static)
......@@ -343,17 +343,17 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits() {
Network::ReduceScatter(input_buffer_.data(), reduce_scatter_size_, sizeof(hist_t), block_start_.data(), block_len_.data(),
output_buffer_.data(), static_cast<comm_size_t>(output_buffer_.size()), &HistogramSumReducer);
this->FindBestSplitsFromHistograms(is_feature_used, false);
this->FindBestSplitsFromHistograms(is_feature_used, false, tree);
}
template <typename TREELEARNER_T>
void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const std::vector<int8_t>&, bool) {
void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const std::vector<int8_t>&, bool, const Tree* tree) {
std::vector<SplitInfo> smaller_bests_per_thread(this->share_state_->num_threads);
std::vector<SplitInfo> larger_bests_per_thread(this->share_state_->num_threads);
std::vector<int8_t> smaller_node_used_features =
this->col_sampler_.GetByNode();
this->col_sampler_.GetByNode(tree, this->smaller_leaf_splits_->leaf_index());
std::vector<int8_t> larger_node_used_features =
this->col_sampler_.GetByNode();
this->col_sampler_.GetByNode(tree, this->larger_leaf_splits_->leaf_index());
// find best split from local aggregated histograms
OMP_INIT_EX();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment