Commit 84fef715 authored by Jerry Liu's avatar Jerry Liu Committed by Guolin Ke
Browse files

add force_split functionality (#1310)

parent 71539cc2
......@@ -144,7 +144,7 @@ if(USE_MPI)
include_directories(${MPI_CXX_INCLUDE_PATH})
endif(USE_MPI)
file(GLOB SOURCES
file(GLOB SOURCES
src/application/*.cpp
src/boosting/*.cpp
src/io/*.cpp
......
......@@ -520,6 +520,16 @@ IO Parameters
- separate by ``,`` for multi-validation data
- ``forced_splits``, default=\ ``""``, type=string
- path to a ``.json`` file that specifies splits to force at the top of every decision tree before best-first learning commences.
- ``.json`` file can be arbitrarily nested, and each split contains ``feature`` and ``threshold`` fields, as well as ``left`` and ``right``
fields representing subsplits. Categorical splits are forced in a one-hot fashion, with ``left`` representing the split containing
the feature value and ``right`` representing other values.
- see ``examples/binary_classification/forced_splits.json`` as an example.
Objective Parameters
--------------------
......
{
"feature": 25,
"threshold": 1.30,
"left": {
"feature": 26,
"threshold": 0.85
},
"right": {
"feature": 26,
"threshold": 0.85
}
}
......@@ -109,3 +109,6 @@ local_listen_port = 12400
# machines list file for parallel training, alias: mlist
machine_list_file = mlist.txt
# force splits
# forced_splits = forced_splits.json
......@@ -105,6 +105,7 @@ public:
std::string output_result = "LightGBM_predict_result.txt";
std::string convert_model = "gbdt_prediction.cpp";
std::string input_model = "";
int verbosity = 1;
int num_iteration_predict = -1;
bool is_pre_partition = false;
......@@ -264,6 +265,9 @@ public:
std::string device_type = kDefaultDevice;
TreeConfig tree_config;
LIGHTGBM_EXPORT void Set(const std::unordered_map<std::string, std::string>& params) override;
/* filename of forced splits */
std::string forcedsplits_filename = "";
};
/*! \brief Config for Network */
......@@ -482,7 +486,8 @@ struct ParameterAlias {
"histogram_pool_size", "is_provide_training_metric", "machine_list_filename", "machines",
"zero_as_missing", "init_score_file", "valid_init_score_file", "is_predict_contrib",
"max_cat_threshold", "cat_smooth", "min_data_per_group", "cat_l2", "max_cat_to_onehot",
"alpha", "reg_sqrt", "tweedie_variance_power", "monotone_constraints", "max_delta_step"
"alpha", "reg_sqrt", "tweedie_variance_power", "monotone_constraints", "max_delta_step",
"forced_splits"
});
std::unordered_map<std::string, std::string> tmp_map;
for (const auto& pair : *params) {
......
......@@ -495,6 +495,13 @@ public:
return feature_groups_[group]->bin_mappers_[sub_feature]->BinToValue(threshold);
}
// given a real threshold, find the closest threshold bin
inline uint32_t BinThreshold(int i, double threshold_double) const {
  // Translate a real-valued threshold into the corresponding bin index of
  // feature i, using that feature's bin mapper.
  const auto& mapper =
      feature_groups_[feature2group_[i]]->bin_mappers_[feature2subfeature_[i]];
  return mapper->ValueToBin(threshold_double);
}
inline void CreateOrderedBins(std::vector<std::unique_ptr<OrderedBin>>* ordered_bins) const {
ordered_bins->resize(num_groups_);
OMP_INIT_EX();
......
/* json11
*
* json11 is a tiny JSON library for C++11, providing JSON parsing and serialization.
*
* The core object provided by the library is json11::Json. A Json object represents any JSON
* value: null, bool, number (int or double), string (std::string), array (std::vector), or
* object (std::map).
*
* Json objects act like values: they can be assigned, copied, moved, compared for equality or
* order, etc. There are also helper methods Json::dump, to serialize a Json to a string, and
* Json::parse (static) to parse a std::string as a Json object.
*
* Internally, the various types of Json object are represented by the JsonValue class
* hierarchy.
*
* A note on numbers - JSON specifies the syntax of number formatting but not its semantics,
* so some JSON implementations distinguish between integers and floating-point numbers, while
* some don't. In json11, we choose the latter. Because some JSON implementations (namely
* Javascript itself) treat all numbers as the same type, distinguishing the two leads
* to JSON that will be *silently* changed by a round-trip through those implementations.
* Dangerous! To avoid that risk, json11 stores all numbers as double internally, but also
* provides integer helpers.
*
* Fortunately, double-precision IEEE754 ('double') can precisely store any integer in the
* range +/-2^53, which includes every 'int' on most systems. (Timestamps often use int64
* or long long to avoid the Y2038K problem; a double storing microseconds since some epoch
* will be exact for +/- 275 years.)
*/
/* Copyright (c) 2013 Dropbox, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#pragma once
#include <string>
#include <vector>
#include <map>
#include <memory>
#include <initializer_list>
#ifdef _MSC_VER
#if _MSC_VER <= 1800 // VS 2013
#ifndef noexcept
#define noexcept throw()
#endif
#ifndef snprintf
#define snprintf _snprintf_s
#endif
#endif
#endif
namespace json11 {
// Parse strategy: STANDARD is strict JSON; COMMENTS presumably also accepts
// and skips comments in the input — confirm against the json11 parser source.
enum JsonParse {
    STANDARD, COMMENTS
};
class JsonValue;
class Json final {
public:
    // Types
    enum Type {
        NUL, NUMBER, BOOL, STRING, ARRAY, OBJECT
    };

    // Array and object typedefs
    typedef std::vector<Json> array;
    typedef std::map<std::string, Json> object;

    // Constructors for the various types of JSON value.
    Json() noexcept;                // NUL
    Json(std::nullptr_t) noexcept;  // NUL
    Json(double value);             // NUMBER
    Json(int value);                // NUMBER
    Json(bool value);               // BOOL
    Json(const std::string &value); // STRING
    Json(std::string &&value);      // STRING
    Json(const char * value);       // STRING
    Json(const array &values);      // ARRAY
    Json(array &&values);           // ARRAY
    Json(const object &values);     // OBJECT
    Json(object &&values);          // OBJECT

    // Implicit constructor: anything with a to_json() function.
    template <class T, class = decltype(&T::to_json)>
    Json(const T & t) : Json(t.to_json()) {}

    // Implicit constructor: map-like objects (std::map, std::unordered_map, etc)
    // whose keys convert to std::string and whose values convert to Json.
    template <class M, typename std::enable_if<
        std::is_constructible<std::string, decltype(std::declval<M>().begin()->first)>::value
        && std::is_constructible<Json, decltype(std::declval<M>().begin()->second)>::value,
            int>::type = 0>
    Json(const M & m) : Json(object(m.begin(), m.end())) {}

    // Implicit constructor: vector-like objects (std::list, std::vector, std::set, etc)
    // whose elements convert to Json.
    template <class V, typename std::enable_if<
        std::is_constructible<Json, decltype(*std::declval<V>().begin())>::value,
            int>::type = 0>
    Json(const V & v) : Json(array(v.begin(), v.end())) {}

    // This prevents Json(some_pointer) from accidentally producing a bool. Use
    // Json(bool(some_pointer)) if that behavior is desired.
    Json(void *) = delete;

    // Accessors
    Type type() const;

    bool is_null()   const { return type() == NUL; }
    bool is_number() const { return type() == NUMBER; }
    bool is_bool()   const { return type() == BOOL; }
    bool is_string() const { return type() == STRING; }
    bool is_array()  const { return type() == ARRAY; }
    bool is_object() const { return type() == OBJECT; }

    // Return the enclosed value if this is a number, 0 otherwise. Note that json11 does not
    // distinguish between integer and non-integer numbers - number_value() and int_value()
    // can both be applied to a NUMBER-typed object.
    double number_value() const;
    int int_value() const;

    // Return the enclosed value if this is a boolean, false otherwise.
    bool bool_value() const;
    // Return the enclosed string if this is a string, "" otherwise.
    const std::string &string_value() const;
    // Return the enclosed std::vector if this is an array, or an empty vector otherwise.
    const array &array_items() const;
    // Return the enclosed std::map if this is an object, or an empty map otherwise.
    const object &object_items() const;

    // Return a reference to arr[i] if this is an array, Json() otherwise.
    const Json & operator[](size_t i) const;
    // Return a reference to obj[key] if this is an object, Json() otherwise.
    const Json & operator[](const std::string &key) const;

    // Serialize.
    void dump(std::string &out) const;
    std::string dump() const {
        std::string out;
        dump(out);
        return out;
    }

    // Parse. If parse fails, return Json() and assign an error message to err.
    static Json parse(const std::string & in,
                      std::string & err,
                      JsonParse strategy = JsonParse::STANDARD);
    // Overload for C strings; a null pointer is reported via err rather than dereferenced.
    static Json parse(const char * in,
                      std::string & err,
                      JsonParse strategy = JsonParse::STANDARD) {
        if (in) {
            return parse(std::string(in), err, strategy);
        } else {
            err = "null input";
            return nullptr;
        }
    }
    // Parse multiple objects, concatenated or separated by whitespace
    static std::vector<Json> parse_multi(
        const std::string & in,
        std::string::size_type & parser_stop_pos,
        std::string & err,
        JsonParse strategy = JsonParse::STANDARD);

    // Convenience overload when the caller does not need the stop position.
    static inline std::vector<Json> parse_multi(
        const std::string & in,
        std::string & err,
        JsonParse strategy = JsonParse::STANDARD) {
        std::string::size_type parser_stop_pos;
        return parse_multi(in, parser_stop_pos, err, strategy);
    }

    // Comparisons: only == and < are primitive; the remaining four operators
    // are derived from them.
    bool operator== (const Json &rhs) const;
    bool operator<  (const Json &rhs) const;
    bool operator!= (const Json &rhs) const { return !(*this == rhs); }
    bool operator<= (const Json &rhs) const { return !(rhs < *this); }
    bool operator>  (const Json &rhs) const { return (rhs < *this); }
    bool operator>= (const Json &rhs) const { return !(*this < rhs); }

    /* has_shape(types, err)
     *
     * Return true if this is a JSON object and, for each item in types, has a field of
     * the given type. If not, return false and set err to a descriptive message.
     */
    typedef std::initializer_list<std::pair<std::string, Type>> shape;
    bool has_shape(const shape & types, std::string & err) const;

private:
    // Pointer to the internal representation (JsonValue hierarchy); copying a
    // Json is cheap because copies share this pointer.
    std::shared_ptr<JsonValue> m_ptr;
};
// Internal class hierarchy - JsonValue objects are not exposed to users of this API.
class JsonValue {
protected:
    friend class Json;
    friend class JsonInt;
    friend class JsonDouble;
    virtual Json::Type type() const = 0;
    virtual bool equals(const JsonValue * other) const = 0;
    virtual bool less(const JsonValue * other) const = 0;
    virtual void dump(std::string &out) const = 0;
    // Non-pure accessors: per the Json accessor contract above, they fall back
    // to a zero/false/empty value when the node is not of the requested type.
    virtual double number_value() const;
    virtual int int_value() const;
    virtual bool bool_value() const;
    virtual const std::string &string_value() const;
    virtual const Json::array &array_items() const;
    virtual const Json &operator[](size_t i) const;
    virtual const Json::object &object_items() const;
    virtual const Json &operator[](const std::string &key) const;
    virtual ~JsonValue() {}
};
} // namespace json11
......@@ -4,9 +4,12 @@
#include <LightGBM/meta.h>
#include <LightGBM/config.h>
#include <LightGBM/json11.hpp>
#include <vector>
using namespace json11;
namespace LightGBM {
/*! \brief forward declaration */
......@@ -44,7 +47,8 @@ public:
* \param is_constant_hessian True if all hessians share the same value
* \return A trained tree
*/
virtual Tree* Train(const score_t* gradients, const score_t* hessians, bool is_constant_hessian) = 0;
virtual Tree* Train(const score_t* gradients, const score_t* hessians, bool is_constant_hessian,
Json& forced_split_json) = 0;
/*!
* \brief use a existing tree to fit the new gradients and hessians.
......
......@@ -32,7 +32,8 @@ public:
* \param training_metrics Training metrics
* \param output_model_filename Filename of output model
*/
void Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
void Init(const BoostingConfig* config, const Dataset* train_data,
const ObjectiveFunction* objective_function,
const std::vector<const Metric*>& training_metrics) override {
GBDT::Init(config, train_data, objective_function, training_metrics);
random_for_drop_ = Random(gbdt_config_->drop_seed);
......
......@@ -3,7 +3,6 @@
#include <LightGBM/utils/openmp_wrapper.h>
#include <LightGBM/utils/common.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/prediction_early_stop.h>
......@@ -75,6 +74,16 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
early_stopping_round_ = gbdt_config_->early_stopping_round;
shrinkage_rate_ = gbdt_config_->learning_rate;
std::string forced_splits_path = config->forcedsplits_filename;
//load forced_splits file
if (forced_splits_path != "") {
std::ifstream forced_splits_file(forced_splits_path.c_str());
std::stringstream buffer;
buffer << forced_splits_file.rdbuf();
std::string err;
forced_splits_json_ = Json::parse(buffer.str(), err);
}
objective_function_ = objective_function;
num_tree_per_iteration_ = num_class_;
if (objective_function_ != nullptr) {
......@@ -425,7 +434,7 @@ bool GBDT::TrainOneIter(const score_t* gradients, const score_t* hessians) {
hess = hessians_.data() + bias;
}
new_tree.reset(tree_learner_->Train(grad, hess, is_constant_hessian_));
new_tree.reset(tree_learner_->Train(grad, hess, is_constant_hessian_, forced_splits_json_));
}
#ifdef TIMETAG
......
......@@ -4,6 +4,7 @@
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/prediction_early_stop.h>
#include <LightGBM/json11.hpp>
#include "score_updater.hpp"
......@@ -15,6 +16,8 @@
#include <mutex>
#include <map>
using namespace json11;
namespace LightGBM {
/*!
......@@ -40,7 +43,8 @@ public:
* \param objective_function Training objective function
* \param training_metrics Training metrics
*/
void Init(const BoostingConfig* gbdt_config, const Dataset* train_data, const ObjectiveFunction* objective_function,
void Init(const BoostingConfig* gbdt_config, const Dataset* train_data,
const ObjectiveFunction* objective_function,
const std::vector<const Metric*>& training_metrics) override;
/*!
......@@ -452,6 +456,8 @@ protected:
std::unique_ptr<ObjectiveFunction> loaded_objective_;
bool average_output_;
bool need_re_bagging_;
Json forced_splits_json_;
};
} // namespace LightGBM
......
......@@ -112,7 +112,8 @@ public:
hess = tmp_hess_.data() + bias;
}
new_tree.reset(tree_learner_->Train(grad, hess, is_constant_hessian_));
new_tree.reset(tree_learner_->Train(grad, hess, is_constant_hessian_,
forced_splits_json_));
}
if (new_tree->num_leaves() > 1) {
......
......@@ -466,6 +466,7 @@ void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& par
GetBool(params, "boost_from_average", &boost_from_average);
GetDeviceType(params, &device_type);
GetTreeLearnerType(params, &tree_learner_type);
GetString(params, "forced_splits", &forcedsplits_filename);
tree_config.Set(params);
}
......
This diff is collapsed.
......@@ -7,6 +7,7 @@
#include <LightGBM/dataset.h>
#include <cstring>
#include <cmath>
namespace LightGBM
{
......@@ -20,6 +21,7 @@ public:
int8_t monotone_type;
/*! \brief pointer of tree config */
const TreeConfig* tree_config;
BinType bin_type;
};
/*!
* \brief FeatureHistogram is used to construct and store a histogram for a feature.
......@@ -43,10 +45,10 @@ public:
* \param feature the feature data for this histogram
* \param min_num_data_one_leaf minimal number of data in one leaf
*/
void Init(HistogramBinEntry* data, const FeatureMetainfo* meta, BinType bin_type) {
void Init(HistogramBinEntry* data, const FeatureMetainfo* meta) {
meta_ = meta;
data_ = data;
if (bin_type == BinType::NumericalBin) {
if (meta_->bin_type == BinType::NumericalBin) {
find_best_threshold_fun_ = std::bind(&FeatureHistogram::FindBestThresholdNumerical, this, std::placeholders::_1
, std::placeholders::_2, std::placeholders::_3, std::placeholders::_4, std::placeholders::_5, std::placeholders::_6);
} else {
......@@ -105,7 +107,8 @@ public:
output->max_constraint = max_constraint;
}
void FindBestThresholdCategorical(double sum_gradient, double sum_hessian, data_size_t num_data, double min_constraint, double max_constraint,
void FindBestThresholdCategorical(double sum_gradient, double sum_hessian, data_size_t num_data,
double min_constraint, double max_constraint,
SplitInfo* output) {
output->default_left = false;
double best_gain = kMinScore;
......@@ -267,6 +270,149 @@ public:
}
}
void GatherInfoForThreshold(double sum_gradient, double sum_hessian,
                            uint32_t threshold, data_size_t num_data, SplitInfo *output) {
  // Fill `output` with the statistics of forcing this feature to split at
  // `threshold`, dispatching on the feature's bin type.
  const bool is_numerical = (meta_->bin_type == BinType::NumericalBin);
  if (is_numerical) {
    GatherInfoForThresholdNumerical(sum_gradient, sum_hessian, threshold,
                                    num_data, output);
    return;
  }
  GatherInfoForThresholdCategorical(sum_gradient, sum_hessian, threshold,
                                    num_data, output);
}
void GatherInfoForThresholdNumerical(double sum_gradient, double sum_hessian,
                                     uint32_t threshold, data_size_t num_data,
                                     SplitInfo *output) {
  // Compute the SplitInfo (leaf outputs, counts, gain) of forcing a numerical
  // split of this feature at bin index `threshold`, instead of searching for
  // the best threshold. Sets output->gain = kMinScore if the forced split is
  // no better than not splitting.
  // gain of the leaf if it is not split at all
  double gain_shift = GetLeafSplitGain(sum_gradient, sum_hessian,
                                       meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2,
                                       meta_->tree_config->max_delta_step);
  double min_gain_shift = gain_shift + meta_->tree_config->min_gain_to_split;
  // accumulate the right-child statistics by scanning bins right-to-left
  const int8_t bias = meta_->bias;
  double sum_right_gradient = 0.0f;
  double sum_right_hessian = kEpsilon;  // kEpsilon keeps the hessian away from zero
  data_size_t right_count = 0;
  // decide how missing values are treated during the scan
  bool use_na_as_missing;
  bool skip_default_bin;
  if (meta_->missing_type == MissingType::Zero) {
    skip_default_bin = true;
    use_na_as_missing = false;
  } else {
    skip_default_bin = false;
    use_na_as_missing = true;
  }

  int t = meta_->num_bin - 1 - bias - use_na_as_missing;
  const int t_end = 1 - bias;

  // from right to left, and we don't need data in bin0
  for (; t >= t_end; --t) {
    // stop once the scan crosses the forced threshold; everything at or above
    // `threshold` has now been accumulated into the right child
    if (static_cast<uint32_t>(t + bias) < threshold) { break; }

    // need to skip default bin
    if (skip_default_bin && (t + bias) == static_cast<int>(meta_->default_bin)) { continue; }
    sum_right_gradient += data_[t].sum_gradients;
    sum_right_hessian += data_[t].sum_hessians;
    right_count += data_[t].cnt;
  }
  // left child is the complement of what was accumulated on the right
  double sum_left_gradient = sum_gradient - sum_right_gradient;
  double sum_left_hessian = sum_hessian - sum_right_hessian;
  data_size_t left_count = num_data - right_count;
  double current_gain = GetLeafSplitGain(sum_left_gradient, sum_left_hessian,
                                         meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2,
                                         meta_->tree_config->max_delta_step)
      + GetLeafSplitGain(sum_right_gradient, sum_right_hessian,
                         meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2,
                         meta_->tree_config->max_delta_step);

  // gain with split is worse than without split
  if (std::isnan(current_gain) || current_gain <= min_gain_shift) {
    output->gain = kMinScore;
    Log::Warning("Gain with forced split worse than without split");
    return;
  };

  // update split information
  output->threshold = threshold;
  output->left_output = CalculateSplittedLeafOutput(sum_left_gradient, sum_left_hessian,
                                                    meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2,
                                                    meta_->tree_config->max_delta_step);
  output->left_count = left_count;
  output->left_sum_gradient = sum_left_gradient;
  output->left_sum_hessian = sum_left_hessian - kEpsilon;
  output->right_output = CalculateSplittedLeafOutput(sum_gradient - sum_left_gradient,
                                                     sum_hessian - sum_left_hessian,
                                                     meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2,
                                                     meta_->tree_config->max_delta_step);
  output->right_count = num_data - left_count;
  output->right_sum_gradient = sum_gradient - sum_left_gradient;
  output->right_sum_hessian = sum_hessian - sum_left_hessian - kEpsilon;
  output->gain = current_gain;
  output->gain -= min_gain_shift;
  // missing values follow the left child for numerical forced splits
  output->default_left = true;
}
void GatherInfoForThresholdCategorical(double sum_gradient, double  sum_hessian,
                                       uint32_t threshold, data_size_t num_data, SplitInfo *output) {
  // Compute the SplitInfo for a forced one-hot categorical split: the single
  // category in bin `threshold` goes to the left child, all other categories
  // go right. Sets output->gain = kMinScore when the threshold is out of range
  // or the forced split is no better than not splitting.
  output->default_left = false;
  // gain of the unsplit leaf; a useful split must exceed this by min_gain_to_split
  double gain_shift = GetLeafSplitGain(
      sum_gradient, sum_hessian,
      meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2,
      meta_->tree_config->max_delta_step);
  double min_gain_shift = gain_shift + meta_->tree_config->min_gain_to_split;
  bool is_full_categorical = meta_->missing_type == MissingType::None;
  int used_bin = meta_->num_bin - 1 + is_full_categorical;
  if (threshold >= static_cast<uint32_t>(used_bin)) {
    output->gain = kMinScore;
    Log::Warning("Invalid categorical threshold split");
    return;
  }
  double l2 = meta_->tree_config->lambda_l2;
  // left child = the one category at `threshold`; right child = the rest
  data_size_t left_count = data_[threshold].cnt;
  data_size_t right_count = num_data - left_count;
  double sum_left_hessian = data_[threshold].sum_hessians + kEpsilon;
  double sum_right_hessian = sum_hessian - sum_left_hessian;
  double sum_left_gradient = data_[threshold].sum_gradients;
  double sum_right_gradient = sum_gradient - sum_left_gradient;
  // current split gain
  // BUG FIX: the left-child term previously paired sum_left_gradient with
  // sum_right_hessian; each child's gain must use its own hessian sum.
  double current_gain = GetLeafSplitGain(sum_right_gradient, sum_right_hessian,
                                         meta_->tree_config->lambda_l1, l2,
                                         meta_->tree_config->max_delta_step)
      + GetLeafSplitGain(sum_left_gradient, sum_left_hessian,
                         meta_->tree_config->lambda_l1, l2,
                         meta_->tree_config->max_delta_step);
  if (std::isnan(current_gain) || current_gain <= min_gain_shift) {
    output->gain = kMinScore;
    Log::Warning("Gain with forced split worse than without split");
    return;
  }
  output->left_output = CalculateSplittedLeafOutput(sum_left_gradient, sum_left_hessian,
                                                    meta_->tree_config->lambda_l1, l2,
                                                    meta_->tree_config->max_delta_step);
  output->left_count = left_count;
  output->left_sum_gradient = sum_left_gradient;
  output->left_sum_hessian = sum_left_hessian - kEpsilon;
  output->right_output = CalculateSplittedLeafOutput(sum_right_gradient, sum_right_hessian,
                                                     meta_->tree_config->lambda_l1, l2,
                                                     meta_->tree_config->max_delta_step);
  output->right_count = right_count;
  output->right_sum_gradient = sum_gradient - sum_left_gradient;
  output->right_sum_hessian = sum_right_hessian - kEpsilon;
  output->gain = current_gain - min_gain_shift;
  // one-hot: exactly one category bin on the left side
  output->num_cat_threshold = 1;
  output->cat_threshold = std::vector<uint32_t>(1, threshold);
}
/*!
* \brief Binary size of this histogram
*/
......@@ -500,7 +646,6 @@ private:
/*! \brief sum of gradient of each bin */
HistogramBinEntry* data_;
//std::vector<HistogramBinEntry> data_;
/*! \brief False if this histogram cannot split */
bool is_splittable_ = true;
std::function<void(double, double, data_size_t, double, double, SplitInfo*)> find_best_threshold_fun_;
......@@ -568,6 +713,7 @@ public:
feature_metas_[i].bias = 0;
}
feature_metas_[i].tree_config = tree_config;
feature_metas_[i].bin_type = train_data->FeatureBinMapper(i)->bin_type();
}
}
uint64_t num_total_bin = train_data->NumTotalBin();
......@@ -589,7 +735,7 @@ public:
uint64_t offset = 0;
for (int j = 0; j < train_data->num_features(); ++j) {
offset += static_cast<uint64_t>(train_data->SubFeatureBinOffset(j));
pool_[i][j].Init(data_[i].data() + offset, &feature_metas_[j], train_data->FeatureBinMapper(j)->bin_type());
pool_[i][j].Init(data_[i].data() + offset, &feature_metas_[j]);
auto num_bin = train_data->FeatureNumBin(j);
if (train_data->FeatureBinMapper(j)->GetDefaultBin() == 0) {
num_bin -= 1;
......
......@@ -751,7 +751,8 @@ void GPUTreeLearner::InitGPU(int platform_id, int device_id) {
SetupKernelArguments();
}
Tree* GPUTreeLearner::Train(const score_t* gradients, const score_t *hessians, bool is_constant_hessian) {
Tree* GPUTreeLearner::Train(const score_t* gradients, const score_t *hessians,
bool is_constant_hessian, Json& forced_split_json) {
// check if we need to recompile the GPU kernel (is_constant_hessian changed)
// this should rarely occur
if (is_constant_hessian != is_constant_hessian_) {
......@@ -760,7 +761,7 @@ Tree* GPUTreeLearner::Train(const score_t* gradients, const score_t *hessians, b
BuildGPUKernels();
SetupKernelArguments();
}
return SerialTreeLearner::Train(gradients, hessians, is_constant_hessian);
return SerialTreeLearner::Train(gradients, hessians, is_constant_hessian, forced_split_json);
}
void GPUTreeLearner::ResetTrainingData(const Dataset* train_data) {
......
......@@ -28,6 +28,7 @@
#include <boost/compute/container/vector.hpp>
#include <boost/align/aligned_allocator.hpp>
using namespace json11;
namespace LightGBM {
......@@ -40,7 +41,8 @@ public:
~GPUTreeLearner();
void Init(const Dataset* train_data, bool is_constant_hessian) override;
void ResetTrainingData(const Dataset* train_data) override;
Tree* Train(const score_t* gradients, const score_t *hessians, bool is_constant_hessian) override;
Tree* Train(const score_t* gradients, const score_t *hessians,
bool is_constant_hessian, Json& forced_split_json) override;
void SetBaggingData(const data_size_t* used_indices, data_size_t num_data) override {
SerialTreeLearner::SetBaggingData(used_indices, num_data);
......
......@@ -6,6 +6,7 @@
#include <algorithm>
#include <vector>
#include <queue>
namespace LightGBM {
......@@ -152,7 +153,7 @@ void SerialTreeLearner::ResetConfig(const TreeConfig* tree_config) {
histogram_pool_.ResetConfig(tree_config_);
}
Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians, bool is_constant_hessian) {
Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians, bool is_constant_hessian, Json& forced_split_json) {
gradients_ = gradients;
hessians_ = hessians;
is_constant_hessian_ = is_constant_hessian;
......@@ -172,18 +173,29 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians
int cur_depth = 1;
// only root leaf can be splitted on first time
int right_leaf = -1;
for (int split = 0; split < tree_config_->num_leaves - 1; ++split) {
int init_splits = 0;
bool aborted_last_force_split = false;
if (!forced_split_json.is_null()) {
init_splits = ForceSplits(tree.get(), forced_split_json, &left_leaf,
&right_leaf, &cur_depth, &aborted_last_force_split);
}
for (int split = init_splits; split < tree_config_->num_leaves - 1; ++split) {
#ifdef TIMETAG
start_time = std::chrono::steady_clock::now();
#endif
// some initial works before finding best split
if (BeforeFindBestSplit(tree.get(), left_leaf, right_leaf)) {
if (!aborted_last_force_split && BeforeFindBestSplit(tree.get(), left_leaf, right_leaf)) {
#ifdef TIMETAG
init_split_time += std::chrono::steady_clock::now() - start_time;
#endif
// find best threshold for every feature
FindBestSplits();
} else if (aborted_last_force_split) {
aborted_last_force_split = false;
}
// Get a leaf with max split gain
int best_leaf = static_cast<int>(ArrayArgs<SplitInfo>::ArgMax(best_split_per_leaf_));
// Get split information for best leaf
......@@ -528,6 +540,162 @@ void SerialTreeLearner::FindBestSplitsFromHistograms(const std::vector<int8_t>&
#endif
}
int32_t SerialTreeLearner::ForceSplits(Tree* tree, Json& forced_split_json, int* left_leaf,
                                       int* right_leaf, int *cur_depth,
                                       bool *aborted_last_force_split) {
  // Apply the user-supplied forced splits (a nested JSON of
  // {feature, threshold, left, right} nodes) at the top of the tree,
  // breadth-first, before best-first learning proceeds.
  // Returns the number of splits actually forced; sets
  // *aborted_last_force_split when a queued split could not be applied
  // (so the caller skips the redundant BeforeFindBestSplit on that leaf).
  int32_t result_count = 0;
  // start at root leaf
  *left_leaf = 0;
  // BFS queue of (JSON split node, index of the leaf it applies to)
  std::queue<std::pair<Json, int>> q;
  Json left = forced_split_json;
  Json right;
  // tracks whether the left child is currently the "smaller" leaf, which
  // decides which histogram array / leaf-splits object holds its statistics
  bool left_smaller = true;
  std::unordered_map<int, SplitInfo> forceSplitMap;
  q.push(std::make_pair(forced_split_json, *left_leaf));
  while(!q.empty()) {
    // before processing next node from queue, store info for current left/right leaf
    // store "best split" for left and right, even if they might be overwritten by forced split
    if (BeforeFindBestSplit(tree, *left_leaf, *right_leaf)) {
      FindBestSplits();
    }
    // then, compute own splits
    SplitInfo left_split;
    SplitInfo right_split;

    if (!left.is_null()) {
      const int left_feature = left["feature"].int_value();
      const double left_threshold_double = left["threshold"].number_value();
      const int left_inner_feature_index = train_data_->InnerFeatureIndex(left_feature);
      // map the real-valued threshold onto the feature's bin index
      const uint32_t left_threshold = train_data_->BinThreshold(
          left_inner_feature_index, left_threshold_double);
      auto leaf_histogram_array = (left_smaller) ? smaller_leaf_histogram_array_ : larger_leaf_histogram_array_;
      auto left_leaf_splits = (left_smaller) ? smaller_leaf_splits_.get() : larger_leaf_splits_.get();
      leaf_histogram_array[left_inner_feature_index].GatherInfoForThreshold(
          left_leaf_splits->sum_gradients(),
          left_leaf_splits->sum_hessians(),
          left_threshold,
          left_leaf_splits->num_data_in_leaf(),
          &left_split);
      left_split.feature = left_feature;
      forceSplitMap[*left_leaf] = left_split;
      // negative gain means the forced split is invalid/unprofitable;
      // removing the entry makes the leaf abort when popped from the queue
      if (left_split.gain < 0) {
        forceSplitMap.erase(*left_leaf);
      }
    }

    if (!right.is_null()) {
      const int right_feature = right["feature"].int_value();
      const double right_threshold_double = right["threshold"].number_value();
      const int right_inner_feature_index = train_data_->InnerFeatureIndex(right_feature);
      // map the real-valued threshold onto the feature's bin index
      const uint32_t right_threshold = train_data_->BinThreshold(
          right_inner_feature_index, right_threshold_double);
      // note the mirrored selection: the right leaf lives in the histogram
      // array that the left leaf does NOT occupy
      auto leaf_histogram_array = (left_smaller) ? larger_leaf_histogram_array_ : smaller_leaf_histogram_array_;
      auto right_leaf_splits = (left_smaller) ? larger_leaf_splits_.get() : smaller_leaf_splits_.get();
      leaf_histogram_array[right_inner_feature_index].GatherInfoForThreshold(
          right_leaf_splits->sum_gradients(),
          right_leaf_splits->sum_hessians(),
          right_threshold,
          right_leaf_splits->num_data_in_leaf(),
          &right_split);
      right_split.feature = right_feature;
      forceSplitMap[*right_leaf] = right_split;
      if (right_split.gain < 0) {
        forceSplitMap.erase(*right_leaf);
      }
    }

    std::pair<Json, int> pair = q.front();
    q.pop();
    int current_leaf = pair.second;
    // split info should exist because searching in bfs fashion - should have added from parent
    if (forceSplitMap.find(current_leaf) == forceSplitMap.end()) {
      *aborted_last_force_split = true;
      break;
    }
    SplitInfo current_split_info = forceSplitMap[current_leaf];
    const int inner_feature_index = train_data_->InnerFeatureIndex(
        current_split_info.feature);
    auto threshold_double = train_data_->RealThreshold(
        inner_feature_index, current_split_info.threshold);

    // split tree, will return right leaf
    *left_leaf = current_leaf;
    if (train_data_->FeatureBinMapper(inner_feature_index)->bin_type() == BinType::NumericalBin) {
      *right_leaf = tree->Split(current_leaf,
                                inner_feature_index,
                                current_split_info.feature,
                                current_split_info.threshold,
                                threshold_double,
                                static_cast<double>(current_split_info.left_output),
                                static_cast<double>(current_split_info.right_output),
                                static_cast<data_size_t>(current_split_info.left_count),
                                static_cast<data_size_t>(current_split_info.right_count),
                                static_cast<float>(current_split_info.gain),
                                train_data_->FeatureBinMapper(inner_feature_index)->missing_type(),
                                current_split_info.default_left);
      // move the data rows into the two new leaves
      data_partition_->Split(current_leaf, train_data_, inner_feature_index,
                             &current_split_info.threshold, 1,
                             current_split_info.default_left, *right_leaf);
    } else {
      // categorical: the forced split is one-hot, so the bitsets encode a
      // single category (inner bin value and original feature value)
      std::vector<uint32_t> cat_bitset_inner = Common::ConstructBitset(
          current_split_info.cat_threshold.data(), current_split_info.num_cat_threshold);
      std::vector<int> threshold_int(current_split_info.num_cat_threshold);
      for (int i = 0; i < current_split_info.num_cat_threshold; ++i) {
        threshold_int[i] = static_cast<int>(train_data_->RealThreshold(
            inner_feature_index, current_split_info.cat_threshold[i]));
      }
      std::vector<uint32_t> cat_bitset = Common::ConstructBitset(
          threshold_int.data(), current_split_info.num_cat_threshold);
      *right_leaf = tree->SplitCategorical(current_leaf,
                                           inner_feature_index,
                                           current_split_info.feature,
                                           cat_bitset_inner.data(),
                                           static_cast<int>(cat_bitset_inner.size()),
                                           cat_bitset.data(),
                                           static_cast<int>(cat_bitset.size()),
                                           static_cast<double>(current_split_info.left_output),
                                           static_cast<double>(current_split_info.right_output),
                                           static_cast<data_size_t>(current_split_info.left_count),
                                           static_cast<data_size_t>(current_split_info.right_count),
                                           static_cast<float>(current_split_info.gain),
                                           train_data_->FeatureBinMapper(inner_feature_index)->missing_type());
      data_partition_->Split(current_leaf, train_data_, inner_feature_index,
                             cat_bitset_inner.data(), static_cast<int>(cat_bitset_inner.size()),
                             current_split_info.default_left, *right_leaf);
    }

    // re-establish the smaller/larger leaf bookkeeping for the two children
    if (current_split_info.left_count < current_split_info.right_count) {
      left_smaller = true;
      smaller_leaf_splits_->Init(*left_leaf, data_partition_.get(),
                                 current_split_info.left_sum_gradient,
                                 current_split_info.left_sum_hessian);
      larger_leaf_splits_->Init(*right_leaf, data_partition_.get(),
                                current_split_info.right_sum_gradient,
                                current_split_info.right_sum_hessian);
    } else {
      left_smaller = false;
      smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(),
                                 current_split_info.right_sum_gradient, current_split_info.right_sum_hessian);
      larger_leaf_splits_->Init(*left_leaf, data_partition_.get(),
                                current_split_info.left_sum_gradient, current_split_info.left_sum_hessian);
    }

    // enqueue child JSON nodes (if present) against the new leaf indices
    left = Json();
    right = Json();
    if ((pair.first).object_items().count("left") > 0) {
      left = (pair.first)["left"];
      q.push(std::make_pair(left, *left_leaf));
    }
    if ((pair.first).object_items().count("right") > 0) {
      right = (pair.first)["right"];
      q.push(std::make_pair(right, *right_leaf));
    }
    result_count++;
    *(cur_depth) = std::max(*(cur_depth), tree->leaf_depth(*left_leaf));
  }
  return result_count;
}
void SerialTreeLearner::Split(Tree* tree, int best_leaf, int* left_leaf, int* right_leaf) {
const SplitInfo& best_split_info = best_split_per_leaf_[best_leaf];
......
......@@ -24,6 +24,8 @@
#include <boost/align/aligned_allocator.hpp>
#endif
using namespace json11;
namespace LightGBM {
/*!
......@@ -41,7 +43,8 @@ public:
void ResetConfig(const TreeConfig* tree_config) override;
Tree* Train(const score_t* gradients, const score_t *hessians, bool is_constant_hessian) override;
Tree* Train(const score_t* gradients, const score_t *hessians, bool is_constant_hessian,
Json& forced_split_json) override;
Tree* FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t* hessians) const override;
......@@ -95,6 +98,12 @@ protected:
*/
virtual void Split(Tree* tree, int best_leaf, int* left_leaf, int* right_leaf);
/* Force splits with forced_split_json dict and then return num splits forced.*/
virtual int32_t ForceSplits(Tree* tree, Json& forced_split_json, int* left_leaf,
int* right_leaf, int* cur_depth,
bool *aborted_last_force_split);
/*!
* \brief Get the number of data in a leaf
* \param leaf_idx The index of leaf
......
......@@ -76,12 +76,13 @@ void VotingParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, b
feature_metas_[i].bias = 0;
}
feature_metas_[i].tree_config = this->tree_config_;
feature_metas_[i].bin_type = train_data->FeatureBinMapper(i)->bin_type();
}
uint64_t offset = 0;
for (int j = 0; j < train_data->num_features(); ++j) {
offset += static_cast<uint64_t>(train_data->SubFeatureBinOffset(j));
smaller_leaf_histogram_array_global_[j].Init(smaller_leaf_histogram_data_.data() + offset, &feature_metas_[j], train_data->FeatureBinMapper(j)->bin_type());
larger_leaf_histogram_array_global_[j].Init(larger_leaf_histogram_data_.data() + offset, &feature_metas_[j], train_data->FeatureBinMapper(j)->bin_type());
smaller_leaf_histogram_array_global_[j].Init(smaller_leaf_histogram_data_.data() + offset, &feature_metas_[j]);
larger_leaf_histogram_array_global_[j].Init(larger_leaf_histogram_data_.data() + offset, &feature_metas_[j]);
auto num_bin = train_data->FeatureNumBin(j);
if (train_data->FeatureBinMapper(j)->GetDefaultBin() == 0) {
num_bin -= 1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment