"include/vscode:/vscode.git/clone" did not exist on "2b8fe8b4bdc00a2611442fdee4c45316f08b1c4b"
Commit ae6ff288 authored by Guolin Ke's avatar Guolin Ke
Browse files

fix tree model format (support multi-cat threshold)

parent 574d7800
......@@ -394,6 +394,7 @@ public:
* \param max_bin max_bin of current used feature
* \param default_bin default bin if bin not in [min_bin, max_bin]
* \param threshold The split threshold.
* \param num_threshold Number of thresholds
* \param data_indices Used data indices. After calling this function, the less-than-or-equal data indices will be stored in this object.
* \param num_data Number of used data
* \param lte_indices After called this function. The less or equal data indices will store on this object.
......@@ -401,7 +402,7 @@ public:
* \return The number of less than or equal data.
*/
virtual data_size_t SplitCategorical(uint32_t min_bin, uint32_t max_bin,
uint32_t default_bin, uint32_t threshold,
uint32_t default_bin, const uint32_t* threshold, int num_threshold,
data_size_t* data_indices, data_size_t num_data,
data_size_t* lte_indices, data_size_t* gt_indices) const = 0;
......
......@@ -224,6 +224,7 @@ public:
int gpu_device_id = -1;
/*! \brief Set to true to use double precision math on GPU (default using single precision) */
bool gpu_use_dp = false;
int max_cat_threshold = 256;
LIGHTGBM_EXPORT void Set(const std::unordered_map<std::string, std::string>& params) override;
};
......@@ -461,7 +462,8 @@ struct ParameterAlias {
"feature_fraction_seed", "enable_bundle", "data_filename", "valid_data_filenames",
"snapshot_freq", "verbosity", "sparse_threshold", "enable_load_from_binary_file",
"max_conflict_rate", "poisson_max_delta_step", "gaussian_eta",
"histogram_pool_size", "output_freq", "is_provide_training_metric", "machine_list_filename", "zero_as_missing",
"histogram_pool_size", "output_freq", "is_provide_training_metric", "machine_list_filename",
"zero_as_missing", "max_cat_threshold",
"init_score_file", "valid_init_score_file", "is_predict_contrib"
});
std::unordered_map<std::string, std::string> tmp_map;
......
......@@ -402,12 +402,12 @@ public:
HistogramBinEntry* data) const;
inline data_size_t Split(int feature,
uint32_t threshold, bool default_left,
const uint32_t* threshold, int num_threshold, bool default_left,
data_size_t* data_indices, data_size_t num_data,
data_size_t* lte_indices, data_size_t* gt_indices) const {
const int group = feature2group_[feature];
const int sub_feature = feature2subfeature_[feature];
return feature_groups_[group]->Split(sub_feature, threshold, default_left, data_indices, num_data, lte_indices, gt_indices);
return feature_groups_[group]->Split(sub_feature, threshold, num_threshold, default_left, data_indices, num_data, lte_indices, gt_indices);
}
inline int SubFeatureBinOffset(int i) const {
......
......@@ -160,7 +160,8 @@ public:
inline data_size_t Split(
int sub_feature,
uint32_t threshold,
const uint32_t* threshold,
int num_threshold,
bool default_left,
data_size_t* data_indices, data_size_t num_data,
data_size_t* lte_indices, data_size_t* gt_indices) const {
......@@ -171,9 +172,9 @@ public:
if (bin_mappers_[sub_feature]->bin_type() == BinType::NumericalBin) {
auto missing_type = bin_mappers_[sub_feature]->missing_type();
return bin_data_->Split(min_bin, max_bin, default_bin, missing_type, default_left,
threshold, data_indices, num_data, lte_indices, gt_indices);
*threshold, data_indices, num_data, lte_indices, gt_indices);
} else {
return bin_data_->SplitCategorical(min_bin, max_bin, default_bin, threshold, data_indices, num_data, lte_indices, gt_indices);
return bin_data_->SplitCategorical(min_bin, max_bin, default_bin, threshold, num_threshold, data_indices, num_data, lte_indices, gt_indices);
}
}
......
......@@ -60,7 +60,8 @@ public:
* \param real_feature Index of feature, the original index on data
* \param threshold_bin Threshold(bin) of split, use bitset to represent
* \param num_threshold_bin size of threshold_bin
* \param threshold
* \param threshold Thresholds of real feature value, use bitset to represent
* \param num_threshold size of threshold
* \param left_value Model Left child output
* \param right_value Model Right child output
* \param left_cnt Count of left child
......@@ -68,8 +69,8 @@ public:
* \param gain Split gain
* \return The index of new leaf.
*/
int SplitCategorical(int leaf, int feature, int real_feature, uint32_t threshold_bin,
double threshold, double left_value, double right_value,
int SplitCategorical(int leaf, int feature, int real_feature, const uint32_t* threshold_bin, int num_threshold_bin,
const uint32_t* threshold, int num_threshold, double left_value, double right_value,
data_size_t left_cnt, data_size_t right_cnt, double gain, MissingType missing_type);
/*! \brief Get the output of one leaf */
......@@ -250,14 +251,18 @@ private:
}
int_fval = 0;
}
if (int_fval == static_cast<int>(threshold_[node])) {
int cat_idx = int(threshold_[node]);
if (Common::FindInBitset(cat_threshold_.data() + cat_boundaries_[cat_idx],
cat_boundaries_[cat_idx + 1] - cat_boundaries_[cat_idx], int_fval)) {
return left_child_[node];
}
return right_child_[node];
}
inline int CategoricalDecisionInner(uint32_t fval, int node) const {
if (fval == threshold_in_bin_[node]) {
int cat_idx = int(threshold_in_bin_[node]);
if (Common::FindInBitset(cat_threshold_inner_.data() + cat_boundaries_inner_[cat_idx],
cat_boundaries_inner_[cat_idx + 1] - cat_boundaries_inner_[cat_idx], fval)) {
return left_child_[node];
}
return right_child_[node];
......@@ -348,6 +353,10 @@ private:
/*! \brief A non-leaf node's split threshold in feature value */
std::vector<double> threshold_;
int num_cat_;
std::vector<int> cat_boundaries_inner_;
std::vector<uint32_t> cat_threshold_inner_;
std::vector<int> cat_boundaries_;
std::vector<uint32_t> cat_threshold_;
/*! \brief Store the information for categorical feature handling and missing value handling. */
std::vector<int8_t> decision_type_;
/*! \brief A non-leaf node's split gain */
......
......@@ -604,6 +604,30 @@ inline void obtain_min_max_sum(const float *w, int nw, float *mi, float *ma, dou
if (su != nullptr) *su = sumw;
}
template<class T>
inline std::vector<uint32_t> ConstructBitset(const T* vals, int n) {
  // Builds a bitset (packed into 32-bit words) with bit `vals[i]` set for
  // each of the n input values. The result is sized just large enough to
  // hold the largest value seen; duplicate values are harmless.
  // \param vals Array of non-negative values to mark in the bitset
  // \param n    Number of entries in vals
  // \return Vector of 32-bit words; bit v is word v/32, position v%32
  std::vector<uint32_t> ret;
  for (int i = 0; i < n; ++i) {
    int i1 = vals[i] / 32;
    int i2 = vals[i] % 32;
    if (static_cast<int>(ret.size()) < i1 + 1) {
      ret.resize(i1 + 1, 0);
    }
    // Use an unsigned literal: (1 << 31) on a signed int is undefined
    // behavior (signed overflow); 1U << i2 is well-defined for i2 in [0,31].
    ret[i1] |= (1U << i2);
  }
  return ret;
}
template<class T>
inline bool FindInBitset(const uint32_t* bits, int n, T pos) {
  // Tests whether bit `pos` is set in a bitset of n 32-bit words
  // (as produced by ConstructBitset).
  // \param bits Pointer to the bitset words
  // \param n    Number of 32-bit words in bits
  // \param pos  Bit position to test
  // \return true if the bit is set, false if unset or out of range
  //
  // Guard negative positions explicitly: callers (e.g. categorical decision
  // on raw feature values) may pass a negative int, and pos / 32 truncates
  // toward zero, so values in [-31, -1] would otherwise reach the shift
  // below with a negative shift count (undefined behavior).
  if (pos < 0) {
    return false;
  }
  int i1 = pos / 32;
  if (i1 >= n) {
    return false;
  }
  int i2 = pos % 32;
  return (bits[i1] >> i2) & 1;
}
} // namespace Common
} // namespace LightGBM
......
......@@ -23,7 +23,6 @@ bool Boosting::LoadFileToBoosting(Boosting* boosting, const char* filename) {
if (!boosting->LoadModelFromString(str_buf.str()))
return false;
}
return true;
}
......
......@@ -382,6 +382,7 @@ void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params)
GetInt(params, "gpu_platform_id", &gpu_platform_id);
GetInt(params, "gpu_device_id", &gpu_device_id);
GetBool(params, "gpu_use_dp", &gpu_use_dp);
GetInt(params, "max_cat_threshold", &max_cat_threshold);
}
void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& params) {
......@@ -417,8 +418,6 @@ void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& par
tree_config.Set(params);
}
void NetworkConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetInt(params, "num_machines", &num_machines);
CHECK(num_machines >= 1);
......
......@@ -250,14 +250,14 @@ public:
virtual data_size_t SplitCategorical(
uint32_t min_bin, uint32_t max_bin, uint32_t default_bin,
uint32_t threshold, data_size_t* data_indices, data_size_t num_data,
const uint32_t* threshold, int num_threahold, data_size_t* data_indices, data_size_t num_data,
data_size_t* lte_indices, data_size_t* gt_indices) const override {
if (num_data <= 0) { return 0; }
data_size_t lte_count = 0;
data_size_t gt_count = 0;
data_size_t* default_indices = gt_indices;
data_size_t* default_count = &gt_count;
if (threshold == default_bin) {
if (Common::FindInBitset(threshold, num_threahold, default_bin)) {
default_indices = lte_indices;
default_count = &lte_count;
}
......@@ -266,7 +266,7 @@ public:
const uint32_t bin = data_[idx];
if (bin < min_bin || bin > max_bin) {
default_indices[(*default_count)++] = idx;
} else if (bin - min_bin == threshold) {
} else if (Common::FindInBitset(threshold, num_threahold, bin - min_bin)) {
lte_indices[lte_count++] = idx;
} else {
gt_indices[gt_count++] = idx;
......
......@@ -289,14 +289,14 @@ public:
virtual data_size_t SplitCategorical(
uint32_t min_bin, uint32_t max_bin, uint32_t default_bin,
uint32_t threshold, data_size_t* data_indices, data_size_t num_data,
const uint32_t* threshold, int num_threahold, data_size_t* data_indices, data_size_t num_data,
data_size_t* lte_indices, data_size_t* gt_indices) const override {
if (num_data <= 0) { return 0; }
data_size_t lte_count = 0;
data_size_t gt_count = 0;
data_size_t* default_indices = gt_indices;
data_size_t* default_count = &gt_count;
if (default_bin == threshold) {
if (Common::FindInBitset(threshold, num_threahold, default_bin)) {
default_indices = lte_indices;
default_count = &lte_count;
}
......@@ -305,7 +305,7 @@ public:
const uint32_t bin = (data_[idx >> 1] >> ((idx & 1) << 2)) & 0xf;
if (bin < min_bin || bin > max_bin) {
default_indices[(*default_count)++] = idx;
} else if (bin - min_bin == threshold) {
} else if (Common::FindInBitset(threshold, num_threahold, bin - min_bin)) {
lte_indices[lte_count++] = idx;
} else {
gt_indices[gt_count++] = idx;
......
......@@ -206,7 +206,7 @@ public:
virtual data_size_t SplitCategorical(
uint32_t min_bin, uint32_t max_bin, uint32_t default_bin,
uint32_t threshold, data_size_t* data_indices, data_size_t num_data,
const uint32_t* threshold, int num_threahold, data_size_t* data_indices, data_size_t num_data,
data_size_t* lte_indices, data_size_t* gt_indices) const override {
if (num_data <= 0) { return 0; }
data_size_t lte_count = 0;
......@@ -214,7 +214,7 @@ public:
SparseBinIterator<VAL_T> iterator(this, data_indices[0]);
data_size_t* default_indices = gt_indices;
data_size_t* default_count = &gt_count;
if (default_bin == threshold) {
if (Common::FindInBitset(threshold, num_threahold, default_bin)) {
default_indices = lte_indices;
default_count = &lte_count;
}
......@@ -223,7 +223,7 @@ public:
uint32_t bin = iterator.InnerRawGet(idx);
if (bin < min_bin || bin > max_bin) {
default_indices[(*default_count)++] = idx;
} else if (bin - min_bin == threshold) {
} else if (Common::FindInBitset(threshold, num_threahold, bin - min_bin)) {
lte_indices[lte_count++] = idx;
} else {
gt_indices[gt_count++] = idx;
......
......@@ -39,6 +39,8 @@ Tree::Tree(int max_leaves)
leaf_parent_[0] = -1;
shrinkage_ = 1.0f;
num_cat_ = 0;
cat_boundaries_.push_back(0);
cat_boundaries_inner_.push_back(0);
}
Tree::~Tree() {
......@@ -66,8 +68,8 @@ int Tree::Split(int leaf, int feature, int real_feature, uint32_t threshold_bin,
return num_leaves_ - 1;
}
int Tree::SplitCategorical(int leaf, int feature, int real_feature, uint32_t threshold_bin,
double threshold, double left_value, double right_value,
int Tree::SplitCategorical(int leaf, int feature, int real_feature, const uint32_t* threshold_bin, int num_threshold_bin,
const uint32_t* threshold, int num_threshold, double left_value, double right_value,
data_size_t left_cnt, data_size_t right_cnt, double gain, MissingType missing_type) {
Split(leaf, feature, real_feature, left_value, right_value, left_cnt, right_cnt, gain);
int new_node_idx = num_leaves_ - 1;
......@@ -80,9 +82,17 @@ int Tree::SplitCategorical(int leaf, int feature, int real_feature, uint32_t thr
} else if (missing_type == MissingType::NaN) {
SetMissingType(&decision_type_[new_node_idx], 2);
}
threshold_in_bin_[new_node_idx] = threshold_bin;
threshold_[new_node_idx] = threshold;
threshold_in_bin_[new_node_idx] = num_cat_;
threshold_[new_node_idx] = num_cat_;
++num_cat_;
cat_boundaries_.push_back(cat_boundaries_.back() + num_threshold);
for (int i = 0; i < num_threshold; ++i) {
cat_threshold_.push_back(threshold[i]);
}
cat_boundaries_inner_.push_back(cat_boundaries_inner_.back() + num_threshold_bin);
for (int i = 0; i < num_threshold_bin; ++i) {
cat_threshold_inner_.push_back(threshold_bin[i]);
}
++num_leaves_;
return num_leaves_ - 1;
}
......@@ -219,6 +229,12 @@ std::string Tree::ToString() const {
<< Common::ArrayToString<double>(internal_value_, num_leaves_ - 1, ' ') << std::endl;
str_buf << "internal_count="
<< Common::ArrayToString<data_size_t>(internal_count_, num_leaves_ - 1, ' ') << std::endl;
if (num_cat_ > 0) {
str_buf << "cat_boundaries="
<< Common::ArrayToString<int>(cat_boundaries_, num_cat_ + 1, ' ') << std::endl;
str_buf << "cat_threshold="
<< Common::ArrayToString<uint32_t>(cat_threshold_, cat_threshold_.size(), ' ') << std::endl;
}
str_buf << "shrinkage=" << shrinkage_ << std::endl;
str_buf << std::endl;
return str_buf.str();
......@@ -249,7 +265,18 @@ std::string Tree::NodeToJSON(int index) const {
str_buf << "\"split_feature\":" << split_feature_[index] << "," << std::endl;
str_buf << "\"split_gain\":" << split_gain_[index] << "," << std::endl;
if (GetDecisionType(decision_type_[index], kCategoricalMask)) {
str_buf << "\"threshold\":" << static_cast<int>(threshold_[index]) << "," << std::endl;
int cat_idx = static_cast<int>(threshold_[index]);
std::vector<int> cats;
for (int i = cat_boundaries_[cat_idx]; i < cat_boundaries_[cat_idx + 1]; ++i) {
for (int j = 0; j < 32; ++j) {
int cat = (i - cat_boundaries_[cat_idx]) * 32 + j;
if (Common::FindInBitset(cat_threshold_.data() + cat_boundaries_[cat_idx],
cat_boundaries_[cat_idx + 1] - cat_boundaries_[cat_idx], cat)) {
cats.push_back(cat);
}
}
}
str_buf << "\"threshold\":\"" << Common::Join(cats, "||") << "\"," << std::endl;
str_buf << "\"decision_type\":\"==\"," << std::endl;
} else {
str_buf << "\"threshold\":" << Common::AvoidInf(threshold_[index]) << "," << std::endl;
......@@ -316,7 +343,11 @@ std::string Tree::CategoricalDecisionIfElse(int node) const {
} else {
str_buf << "if (std::isnan(fval)) { int_fval = 0; } else { int_fval = static_cast<int>(fval); }";
}
str_buf << "if (int_fval >= 0 && int_fval == " << static_cast<int>(threshold_[node]) << ") {";
int cat_idx = int(threshold_[node]);
str_buf << "if (int_fval >= 0 && int_fval < 32 * (";
str_buf << cat_boundaries_[cat_idx + 1] - cat_boundaries_[cat_idx];
str_buf << ") && (((cat_threshold[" << cat_boundaries_[cat_idx];
str_buf << " + int_fval / 32] >> (int_fval & 31)) & 1))) {";
return str_buf.str();
}
......@@ -330,6 +361,14 @@ std::string Tree::ToIfElse(int index, bool is_predict_leaf_index) const {
if (num_leaves_ <= 1) {
str_buf << "return " << leaf_value_[0] << ";";
} else {
str_buf << "const std::vector<uint32_t> cat_threshold = {";
for (size_t i = 0; i < cat_threshold_.size(); ++i) {
if (i != 0) {
str_buf << ",";
}
str_buf << cat_threshold_[i];
}
str_buf << "};";
// use this for the missing value conversion
str_buf << "double fval = 0.0f; ";
if (num_cat_ > 0) {
......@@ -459,6 +498,20 @@ Tree::Tree(const std::string& str) {
decision_type_ = std::vector<int8_t>(num_leaves_ - 1, 0);
}
if (num_cat_ > 0) {
if (key_vals.count("cat_boundaries")) {
cat_boundaries_ = Common::StringToArray<int>(key_vals["cat_boundaries"], ' ', num_cat_ + 1);
} else {
Log::Fatal("Tree model should contain cat_boundaries field.");
}
if (key_vals.count("cat_threshold")) {
cat_threshold_ = Common::StringToArray<uint32_t>(key_vals["cat_threshold"], ' ', cat_boundaries_.back());
} else {
Log::Fatal("Tree model should contain cat_threshold field.");
}
}
if (key_vals.count("shrinkage")) {
Common::Atof(key_vals["shrinkage"].c_str(), &shrinkage_);
} else {
......
......@@ -233,7 +233,7 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const
}
// sync global best info
SyncUpGlobalBestSplit(input_buffer_.data(), input_buffer_.data(), &smaller_best_split, &larger_best_split);
SyncUpGlobalBestSplit(input_buffer_.data(), input_buffer_.data(), &smaller_best_split, &larger_best_split, this->tree_config_->max_cat_threshold);
// set best split
this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()] = smaller_best_split;
......
......@@ -91,7 +91,7 @@ public:
* \param threshold threshold that want to split
* \param right_leaf index of right leaf
*/
void Split(int leaf, const Dataset* dataset, int feature, uint32_t threshold, bool default_left, int right_leaf) {
void Split(int leaf, const Dataset* dataset, int feature, const uint32_t* threshold, int num_threshold, bool default_left, int right_leaf) {
const data_size_t min_inner_size = 512;
// get leaf boundary
const data_size_t begin = leaf_begin_[leaf];
......@@ -111,7 +111,7 @@ public:
data_size_t cur_cnt = inner_size;
if (cur_start + cur_cnt > cnt) { cur_cnt = cnt - cur_start; }
// split data inner, reduce the times of function called
data_size_t cur_left_count = dataset->Split(feature, threshold, default_left, indices_.data() + begin + cur_start, cur_cnt,
data_size_t cur_left_count = dataset->Split(feature, threshold, num_threshold, default_left, indices_.data() + begin + cur_start, cur_cnt,
temp_left_indices_.data() + cur_start, temp_right_indices_.data() + cur_start);
offsets_buf_[i] = cur_start;
left_cnts_buf_[i] = cur_left_count;
......
......@@ -15,7 +15,7 @@ class FeatureMetainfo {
public:
int num_bin;
MissingType missing_type;
int bias = 0;
int8_t bias = 0;
uint32_t default_bin;
/*! \brief pointer of tree config */
const TreeConfig* tree_config;
......@@ -105,19 +105,22 @@ public:
SplitInfo* output) {
output->default_left = false;
double best_gain = kMinScore;
uint32_t best_threshold = static_cast<uint32_t>(meta_->num_bin);
data_size_t best_left_count = 0;
double best_sum_left_gradient = 0.0f;
double best_sum_left_hessian = 0.0f;
double gain_shift = GetLeafSplitGain(sum_gradient, sum_hessian,
meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2);
double min_gain_shift = gain_shift + meta_->tree_config->min_gain_to_split;
is_splittable_ = false;
const int bias = meta_->bias;
int t = meta_->num_bin - 1 - bias;
const int t_end = 0;
uint32_t best_threshold = 0;
bool is_full_categorical = meta_->missing_type == MissingType::None;
int used_bin = meta_->num_bin - 1 + is_full_categorical;
// from right to left, and we don't need data in bin0
for (; t >= t_end; --t) {
for (int t = 0; t < used_bin; ++t) {
// if data not enough, or sum hessian too small
if (data_[t].cnt < meta_->tree_config->min_data_in_leaf
|| data_[t].sum_hessians < meta_->tree_config->min_sum_hessian_in_leaf) continue;
......@@ -142,51 +145,19 @@ public:
is_splittable_ = true;
// better split point
if (current_gain > best_gain) {
best_threshold = static_cast<uint32_t>(t + bias);
best_threshold = static_cast<uint32_t>(t);
best_sum_left_gradient = data_[t].sum_gradients;
best_sum_left_hessian = data_[t].sum_hessians + kEpsilon;
best_left_count = data_[t].cnt;
best_gain = current_gain;
}
}
// need restore zero bin
if (bias == 1) {
t = meta_->num_bin - 1 - bias;
double sum_bin0_gradient = sum_gradient;
double sum_bin0_hessian = sum_hessian - 2 * kEpsilon;
data_size_t cnt_bin0 = num_data;
for (; t >= 0; --t) {
sum_bin0_gradient -= data_[t].sum_gradients;
sum_bin0_hessian -= data_[t].sum_hessians;
cnt_bin0 -= data_[t].cnt;
}
data_size_t other_count = num_data - cnt_bin0;
double sum_other_hessian = sum_hessian - sum_bin0_hessian - kEpsilon;
if (cnt_bin0 >= meta_->tree_config->min_data_in_leaf
&& sum_bin0_hessian >= meta_->tree_config->min_sum_hessian_in_leaf
&& other_count >= meta_->tree_config->min_data_in_leaf
&& sum_other_hessian >= meta_->tree_config->min_sum_hessian_in_leaf) {
double sum_other_gradient = sum_gradient - sum_bin0_gradient;
double current_gain = GetLeafSplitGain(sum_other_gradient, sum_other_hessian,
meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2)
+ GetLeafSplitGain(sum_bin0_gradient, sum_bin0_hessian + kEpsilon,
meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2);
if (current_gain > min_gain_shift) {
is_splittable_ = true;
// better split point
if (current_gain > best_gain) {
best_threshold = static_cast<uint32_t>(0);
best_sum_left_gradient = sum_bin0_gradient;
best_sum_left_hessian = sum_bin0_hessian + kEpsilon;
best_left_count = cnt_bin0;
best_gain = current_gain;
}
}
}
}
if (is_splittable_) {
// update split information
output->threshold = best_threshold;
output->num_cat_threshold = 1;
output->cat_threshold.resize(output->num_cat_threshold);
output->cat_threshold[0] = best_threshold;
output->left_output = CalculateSplittedLeafOutput(best_sum_left_gradient, best_sum_left_hessian,
meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2);
output->left_count = best_left_count;
......@@ -199,7 +170,7 @@ public:
output->right_sum_gradient = sum_gradient - best_sum_left_gradient;
output->right_sum_hessian = sum_hessian - best_sum_left_hessian - kEpsilon;
output->gain = best_gain - min_gain_shift;
}
}
}
/*!
......@@ -258,7 +229,7 @@ private:
void FindBestThresholdSequence(double sum_gradient, double sum_hessian, data_size_t num_data, double min_gain_shift,
SplitInfo* output, int dir, bool skip_default_bin, bool use_na_as_missing) {
const int bias = meta_->bias;
const int8_t bias = meta_->bias;
double best_sum_left_gradient = NAN;
double best_sum_left_hessian = NAN;
......
......@@ -22,8 +22,8 @@ void FeatureParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data,
TREELEARNER_T::Init(train_data, is_constant_hessian);
rank_ = Network::rank();
num_machines_ = Network::num_machines();
input_buffer_.resize(sizeof(SplitInfo) * 2);
output_buffer_.resize(sizeof(SplitInfo) * 2);
input_buffer_.resize((sizeof(SplitInfo) + sizeof(uint32_t) * this->tree_config_->max_cat_threshold) * 2);
output_buffer_.resize((sizeof(SplitInfo) + sizeof(uint32_t) * this->tree_config_->max_cat_threshold) * 2);
}
......@@ -60,7 +60,7 @@ void FeatureParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(con
larger_best_split = this->best_split_per_leaf_[this->larger_leaf_splits_->LeafIndex()];
}
// sync global best info
SyncUpGlobalBestSplit(input_buffer_.data(), input_buffer_.data(), &smaller_best_split, &larger_best_split);
SyncUpGlobalBestSplit(input_buffer_.data(), input_buffer_.data(), &smaller_best_split, &larger_best_split, this->tree_config_->max_cat_threshold);
// update best split
this->best_split_per_leaf_[this->smaller_leaf_splits_->LeafIndex()] = smaller_best_split;
if (this->larger_leaf_splits_->LeafIndex() >= 0) {
......
......@@ -181,9 +181,9 @@ private:
};
// To-do: reduce the communication cost by using bitset to communicate.
inline void SyncUpGlobalBestSplit(char* input_buffer_, char* output_buffer_, SplitInfo* smaller_best_split, SplitInfo* larger_best_split) {
inline void SyncUpGlobalBestSplit(char* input_buffer_, char* output_buffer_, SplitInfo* smaller_best_split, SplitInfo* larger_best_split, int max_cat_threshold) {
// sync global best info
int size = SplitInfo::Size();
int size = SplitInfo::Size(max_cat_threshold);
smaller_best_split->CopyTo(input_buffer_);
larger_best_split->CopyTo(input_buffer_ + size);
Network::Allreduce(input_buffer_, size * 2, size, output_buffer_,
......
......@@ -522,12 +522,13 @@ void SerialTreeLearner::Split(Tree* tree, int best_leaf, int* left_leaf, int* ri
// left = parent
*left_leaf = best_leaf;
if (train_data_->FeatureBinMapper(inner_feature_index)->bin_type() == BinType::NumericalBin) {
auto threshold_double = train_data_->RealThreshold(inner_feature_index, best_split_info.threshold);
// split tree, will return right leaf
*right_leaf = tree->Split(best_leaf,
inner_feature_index,
best_split_info.feature,
best_split_info.threshold,
train_data_->RealThreshold(inner_feature_index, best_split_info.threshold),
threshold_double,
static_cast<double>(best_split_info.left_output),
static_cast<double>(best_split_info.right_output),
static_cast<data_size_t>(best_split_info.left_count),
......@@ -535,22 +536,35 @@ void SerialTreeLearner::Split(Tree* tree, int best_leaf, int* left_leaf, int* ri
static_cast<double>(best_split_info.gain),
train_data_->FeatureBinMapper(inner_feature_index)->missing_type(),
best_split_info.default_left);
data_partition_->Split(best_leaf, train_data_, inner_feature_index,
&best_split_info.threshold, 1, best_split_info.default_left, *right_leaf);
} else {
std::vector<uint32_t> cat_bitset_inner = Common::ConstructBitset(best_split_info.cat_threshold.data(), best_split_info.num_cat_threshold);
std::vector<int> threshold_int(best_split_info.num_cat_threshold);
for (int i = 0; i < best_split_info.num_cat_threshold; ++i) {
threshold_int[i] = static_cast<int>(train_data_->RealThreshold(inner_feature_index, best_split_info.cat_threshold[i]));
}
std::vector<uint32_t> cat_bitset = Common::ConstructBitset(threshold_int.data(), best_split_info.num_cat_threshold);
*right_leaf = tree->SplitCategorical(best_leaf,
inner_feature_index,
best_split_info.feature,
best_split_info.threshold,
train_data_->RealThreshold(inner_feature_index, best_split_info.threshold),
cat_bitset_inner.data(),
static_cast<int>(cat_bitset_inner.size()),
cat_bitset.data(),
static_cast<int>(cat_bitset.size()),
static_cast<double>(best_split_info.left_output),
static_cast<double>(best_split_info.right_output),
static_cast<data_size_t>(best_split_info.left_count),
static_cast<data_size_t>(best_split_info.right_count),
static_cast<double>(best_split_info.gain),
train_data_->FeatureBinMapper(inner_feature_index)->missing_type());
data_partition_->Split(best_leaf, train_data_, inner_feature_index,
cat_bitset_inner.data(), static_cast<int>(cat_bitset_inner.size()), best_split_info.default_left, *right_leaf);
}
// split data partition
data_partition_->Split(best_leaf, train_data_, inner_feature_index,
best_split_info.threshold, best_split_info.default_left, *right_leaf);
#ifdef DEBUG
CHECK(best_split_info.left_count == data_partition_->leaf_count(best_leaf));
#endif
// init the leaves that used on next iteration
if (best_split_info.left_count < best_split_info.right_count) {
......
......@@ -24,6 +24,7 @@ public:
data_size_t left_count = 0;
/*! \brief Right number of data after split */
data_size_t right_count = 0;
int num_cat_threshold = 0;
/*! \brief Left output after split */
double left_output = 0.0;
/*! \brief Right output after split */
......@@ -38,11 +39,12 @@ public:
double right_sum_gradient = 0;
/*! \brief Right sum hessian after split */
double right_sum_hessian = 0;
std::vector<uint32_t> cat_threshold;
/*! \brief True if default split is left */
bool default_left = true;
inline static int Size() {
return sizeof(int) + sizeof(uint32_t) + sizeof(bool) + sizeof(double) * 7 + sizeof(data_size_t) * 2;
inline static int Size(int max_cat_threshold) {
return 2 * sizeof(int) + sizeof(uint32_t) + sizeof(bool) + sizeof(double) * 7 + sizeof(data_size_t) * 2 + max_cat_threshold * sizeof(uint32_t);
}
inline void CopyTo(char* buffer) const {
......@@ -70,6 +72,10 @@ public:
buffer += sizeof(right_sum_hessian);
std::memcpy(buffer, &default_left, sizeof(default_left));
buffer += sizeof(default_left);
for (int i = 0; i < num_cat_threshold; ++i) {
std::memcpy(buffer, &cat_threshold[i], sizeof(uint32_t));
buffer += sizeof(uint32_t);
}
}
void CopyFrom(const char* buffer) {
......@@ -97,6 +103,11 @@ public:
buffer += sizeof(right_sum_hessian);
std::memcpy(&default_left, buffer, sizeof(default_left));
buffer += sizeof(default_left);
cat_threshold.resize(num_cat_threshold);
for (int i = 0; i < num_cat_threshold; ++i) {
std::memcpy(&cat_threshold[i], buffer, sizeof(uint32_t));
buffer += sizeof(uint32_t);
}
}
inline void Reset() {
......
......@@ -442,7 +442,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(cons
larger_best_split = this->best_split_per_leaf_[this->larger_leaf_splits_->LeafIndex()];
}
// sync global best info
SyncUpGlobalBestSplit(input_buffer_.data(), input_buffer_.data(), &smaller_best_split, &larger_best_split);
SyncUpGlobalBestSplit(input_buffer_.data(), input_buffer_.data(), &smaller_best_split, &larger_best_split, this->tree_config_->max_cat_threshold);
// copy back
this->best_split_per_leaf_[smaller_leaf_splits_global_->LeafIndex()] = smaller_best_split;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.