Commit e1d7a7b9 authored by Guolin Ke, committed by Nikita Titov
Browse files

add weight in tree model output (#2269)

* add weight in tree model output

* fix bug

* updated Python plotting part to handle weights
parent 86a95783
...@@ -355,7 +355,9 @@ ...@@ -355,7 +355,9 @@
" 'split_gain',\n", " 'split_gain',\n",
" 'internal_value',\n", " 'internal_value',\n",
" 'internal_count',\n", " 'internal_count',\n",
" 'leaf_count'],\n", " 'internal_weight',\n",
" 'leaf_count',\n",
" 'leaf_weight'],\n",
" value=['None']),\n", " value=['None']),\n",
" precision=(0, 10))\n", " precision=(0, 10))\n",
" tree = None\n", " tree = None\n",
...@@ -382,7 +384,7 @@ ...@@ -382,7 +384,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.6.3" "version": "3.7.1"
}, },
"varInspector": { "varInspector": {
"cols": { "cols": {
......
...@@ -50,6 +50,8 @@ class Tree { ...@@ -50,6 +50,8 @@ class Tree {
* \param right_value Model Right child output * \param right_value Model Right child output
* \param left_cnt Count of left child * \param left_cnt Count of left child
* \param right_cnt Count of right child * \param right_cnt Count of right child
* \param left_weight Weight of left child
* \param right_weight Weight of right child
* \param gain Split gain * \param gain Split gain
* \param missing_type missing type * \param missing_type missing type
* \param default_left default direction for missing value * \param default_left default direction for missing value
...@@ -57,7 +59,8 @@ class Tree { ...@@ -57,7 +59,8 @@ class Tree {
*/ */
int Split(int leaf, int feature, int real_feature, uint32_t threshold_bin, int Split(int leaf, int feature, int real_feature, uint32_t threshold_bin,
double threshold_double, double left_value, double right_value, double threshold_double, double left_value, double right_value,
int left_cnt, int right_cnt, float gain, MissingType missing_type, bool default_left); int left_cnt, int right_cnt, double left_weight, double right_weight,
float gain, MissingType missing_type, bool default_left);
/*! /*!
* \brief Performing a split on tree leaves, with categorical feature * \brief Performing a split on tree leaves, with categorical feature
...@@ -72,12 +75,14 @@ class Tree { ...@@ -72,12 +75,14 @@ class Tree {
* \param right_value Model Right child output * \param right_value Model Right child output
* \param left_cnt Count of left child * \param left_cnt Count of left child
* \param right_cnt Count of right child * \param right_cnt Count of right child
* \param left_weight Weight of left child
* \param right_weight Weight of right child
* \param gain Split gain * \param gain Split gain
* \return The index of new leaf. * \return The index of new leaf.
*/ */
int SplitCategorical(int leaf, int feature, int real_feature, const uint32_t* threshold_bin, int num_threshold_bin, int SplitCategorical(int leaf, int feature, int real_feature, const uint32_t* threshold_bin, int num_threshold_bin,
const uint32_t* threshold, int num_threshold, double left_value, double right_value, const uint32_t* threshold, int num_threshold, double left_value, double right_value,
int left_cnt, int right_cnt, float gain, MissingType missing_type); int left_cnt, int right_cnt, double left_weight, double right_weight, float gain, MissingType missing_type);
/*! \brief Get the output of one leaf */ /*! \brief Get the output of one leaf */
inline double LeafOutput(int leaf) const { return leaf_value_[leaf]; } inline double LeafOutput(int leaf) const { return leaf_value_[leaf]; }
...@@ -297,8 +302,8 @@ class Tree { ...@@ -297,8 +302,8 @@ class Tree {
} }
} }
inline void Split(int leaf, int feature, int real_feature, inline void Split(int leaf, int feature, int real_feature, double left_value, double right_value, int left_cnt, int right_cnt,
double left_value, double right_value, int left_cnt, int right_cnt, float gain); double left_weight, double right_weight, float gain);
/*! /*!
* \brief Find leaf index of which record belongs by features * \brief Find leaf index of which record belongs by features
* \param feature_values Feature value of this record * \param feature_values Feature value of this record
...@@ -383,10 +388,14 @@ class Tree { ...@@ -383,10 +388,14 @@ class Tree {
std::vector<int> leaf_parent_; std::vector<int> leaf_parent_;
/*! \brief Output of leaves */ /*! \brief Output of leaves */
std::vector<double> leaf_value_; std::vector<double> leaf_value_;
/*! \brief weight of leaves */
std::vector<double> leaf_weight_;
/*! \brief DataCount of leaves */ /*! \brief DataCount of leaves */
std::vector<int> leaf_count_; std::vector<int> leaf_count_;
/*! \brief Output of non-leaf nodes */ /*! \brief Output of non-leaf nodes */
std::vector<double> internal_value_; std::vector<double> internal_value_;
/*! \brief weight of non-leaf nodes */
std::vector<double> internal_weight_;
/*! \brief DataCount of non-leaf nodes */ /*! \brief DataCount of non-leaf nodes */
std::vector<int> internal_count_; std::vector<int> internal_count_;
/*! \brief Depth for leaves */ /*! \brief Depth for leaves */
...@@ -396,7 +405,8 @@ class Tree { ...@@ -396,7 +405,8 @@ class Tree {
}; };
inline void Tree::Split(int leaf, int feature, int real_feature, inline void Tree::Split(int leaf, int feature, int real_feature,
double left_value, double right_value, int left_cnt, int right_cnt, float gain) { double left_value, double right_value, int left_cnt, int right_cnt,
double left_weight, double right_weight, float gain) {
int new_node_idx = num_leaves_ - 1; int new_node_idx = num_leaves_ - 1;
// update parent info // update parent info
int parent = leaf_parent_[leaf]; int parent = leaf_parent_[leaf];
...@@ -420,11 +430,14 @@ inline void Tree::Split(int leaf, int feature, int real_feature, ...@@ -420,11 +430,14 @@ inline void Tree::Split(int leaf, int feature, int real_feature,
leaf_parent_[leaf] = new_node_idx; leaf_parent_[leaf] = new_node_idx;
leaf_parent_[num_leaves_] = new_node_idx; leaf_parent_[num_leaves_] = new_node_idx;
// save current leaf value to internal node before change // save current leaf value to internal node before change
internal_weight_[new_node_idx] = leaf_weight_[leaf];
internal_value_[new_node_idx] = leaf_value_[leaf]; internal_value_[new_node_idx] = leaf_value_[leaf];
internal_count_[new_node_idx] = left_cnt + right_cnt; internal_count_[new_node_idx] = left_cnt + right_cnt;
leaf_value_[leaf] = std::isnan(left_value) ? 0.0f : left_value; leaf_value_[leaf] = std::isnan(left_value) ? 0.0f : left_value;
leaf_weight_[leaf] = left_weight;
leaf_count_[leaf] = left_cnt; leaf_count_[leaf] = left_cnt;
leaf_value_[num_leaves_] = std::isnan(right_value) ? 0.0f : right_value; leaf_value_[num_leaves_] = std::isnan(right_value) ? 0.0f : right_value;
leaf_weight_[num_leaves_] = right_weight;
leaf_count_[num_leaves_] = right_cnt; leaf_count_[num_leaves_] = right_cnt;
// update leaf depth // update leaf depth
leaf_depth_[num_leaves_] = leaf_depth_[leaf] + 1; leaf_depth_[num_leaves_] = leaf_depth_[leaf] + 1;
......
...@@ -390,7 +390,7 @@ def _to_graphviz(tree_info, show_info, feature_names, precision=None, **kwargs): ...@@ -390,7 +390,7 @@ def _to_graphviz(tree_info, show_info, feature_names, precision=None, **kwargs):
label = 'split_feature_index: {0}'.format(root['split_feature']) label = 'split_feature_index: {0}'.format(root['split_feature'])
label += r'\nthreshold: {0}'.format(_float2str(root['threshold'], precision)) label += r'\nthreshold: {0}'.format(_float2str(root['threshold'], precision))
for info in show_info: for info in show_info:
if info in {'split_gain', 'internal_value'}: if info in {'split_gain', 'internal_value', 'internal_weight'}:
label += r'\n{0}: {1}'.format(info, _float2str(root[info], precision)) label += r'\n{0}: {1}'.format(info, _float2str(root[info], precision))
elif info == 'internal_count': elif info == 'internal_count':
label += r'\n{0}: {1}'.format(info, root[info]) label += r'\n{0}: {1}'.format(info, root[info])
...@@ -409,6 +409,8 @@ def _to_graphviz(tree_info, show_info, feature_names, precision=None, **kwargs): ...@@ -409,6 +409,8 @@ def _to_graphviz(tree_info, show_info, feature_names, precision=None, **kwargs):
label += r'\nleaf_value: {0}'.format(_float2str(root['leaf_value'], precision)) label += r'\nleaf_value: {0}'.format(_float2str(root['leaf_value'], precision))
if 'leaf_count' in show_info: if 'leaf_count' in show_info:
label += r'\nleaf_count: {0}'.format(root['leaf_count']) label += r'\nleaf_count: {0}'.format(root['leaf_count'])
if 'leaf_weight' in show_info:
label += r'\nleaf_weight: {0}'.format(_float2str(root['leaf_weight'], precision))
graph.node(name, label=label) graph.node(name, label=label)
if parent is not None: if parent is not None:
graph.edge(parent, name, decision) graph.edge(parent, name, decision)
...@@ -438,7 +440,8 @@ def create_tree_digraph(booster, tree_index=0, show_info=None, precision=None, ...@@ -438,7 +440,8 @@ def create_tree_digraph(booster, tree_index=0, show_info=None, precision=None,
The index of a target tree to convert. The index of a target tree to convert.
show_info : list of strings or None, optional (default=None) show_info : list of strings or None, optional (default=None)
What information should be shown in nodes. What information should be shown in nodes.
Possible values of list items: 'split_gain', 'internal_value', 'internal_count', 'leaf_count'. Possible values of list items:
'split_gain', 'internal_value', 'internal_count', 'internal_weight', 'leaf_count', 'leaf_weight'.
precision : int or None, optional (default=None) precision : int or None, optional (default=None)
Used to restrict the display of floating point values to a certain precision. Used to restrict the display of floating point values to a certain precision.
**kwargs **kwargs
...@@ -515,7 +518,8 @@ def plot_tree(booster, ax=None, tree_index=0, figsize=None, ...@@ -515,7 +518,8 @@ def plot_tree(booster, ax=None, tree_index=0, figsize=None,
Figure size. Figure size.
show_info : list of strings or None, optional (default=None) show_info : list of strings or None, optional (default=None)
What information should be shown in nodes. What information should be shown in nodes.
Possible values of list items: 'split_gain', 'internal_value', 'internal_count', 'leaf_count'. Possible values of list items:
'split_gain', 'internal_value', 'internal_count', 'internal_weight', 'leaf_count', 'leaf_weight'.
precision : int or None, optional (default=None) precision : int or None, optional (default=None)
Used to restrict the display of floating point values to a certain precision. Used to restrict the display of floating point values to a certain precision.
**kwargs **kwargs
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
namespace LightGBM { namespace LightGBM {
const std::string kModelVersion = "v2"; const std::string kModelVersion = "v3";
std::string GBDT::DumpModel(int start_iteration, int num_iteration) const { std::string GBDT::DumpModel(int start_iteration, int num_iteration) const {
std::stringstream str_buf; std::stringstream str_buf;
......
...@@ -26,14 +26,17 @@ Tree::Tree(int max_leaves) ...@@ -26,14 +26,17 @@ Tree::Tree(int max_leaves)
split_gain_.resize(max_leaves_ - 1); split_gain_.resize(max_leaves_ - 1);
leaf_parent_.resize(max_leaves_); leaf_parent_.resize(max_leaves_);
leaf_value_.resize(max_leaves_); leaf_value_.resize(max_leaves_);
leaf_weight_.resize(max_leaves_);
leaf_count_.resize(max_leaves_); leaf_count_.resize(max_leaves_);
internal_value_.resize(max_leaves_ - 1); internal_value_.resize(max_leaves_ - 1);
internal_weight_.resize(max_leaves_ - 1);
internal_count_.resize(max_leaves_ - 1); internal_count_.resize(max_leaves_ - 1);
leaf_depth_.resize(max_leaves_); leaf_depth_.resize(max_leaves_);
// root is in the depth 0 // root is in the depth 0
leaf_depth_[0] = 0; leaf_depth_[0] = 0;
num_leaves_ = 1; num_leaves_ = 1;
leaf_value_[0] = 0.0f; leaf_value_[0] = 0.0f;
leaf_weight_[0] = 0.0f;
leaf_parent_[0] = -1; leaf_parent_[0] = -1;
shrinkage_ = 1.0f; shrinkage_ = 1.0f;
num_cat_ = 0; num_cat_ = 0;
...@@ -47,8 +50,8 @@ Tree::~Tree() { ...@@ -47,8 +50,8 @@ Tree::~Tree() {
int Tree::Split(int leaf, int feature, int real_feature, uint32_t threshold_bin, int Tree::Split(int leaf, int feature, int real_feature, uint32_t threshold_bin,
double threshold_double, double left_value, double right_value, double threshold_double, double left_value, double right_value,
int left_cnt, int right_cnt, float gain, MissingType missing_type, bool default_left) { int left_cnt, int right_cnt, double left_weight, double right_weight, float gain, MissingType missing_type, bool default_left) {
Split(leaf, feature, real_feature, left_value, right_value, left_cnt, right_cnt, gain); Split(leaf, feature, real_feature, left_value, right_value, left_cnt, right_cnt, left_weight, right_weight, gain);
int new_node_idx = num_leaves_ - 1; int new_node_idx = num_leaves_ - 1;
decision_type_[new_node_idx] = 0; decision_type_[new_node_idx] = 0;
SetDecisionType(&decision_type_[new_node_idx], false, kCategoricalMask); SetDecisionType(&decision_type_[new_node_idx], false, kCategoricalMask);
...@@ -68,8 +71,8 @@ int Tree::Split(int leaf, int feature, int real_feature, uint32_t threshold_bin, ...@@ -68,8 +71,8 @@ int Tree::Split(int leaf, int feature, int real_feature, uint32_t threshold_bin,
int Tree::SplitCategorical(int leaf, int feature, int real_feature, const uint32_t* threshold_bin, int num_threshold_bin, int Tree::SplitCategorical(int leaf, int feature, int real_feature, const uint32_t* threshold_bin, int num_threshold_bin,
const uint32_t* threshold, int num_threshold, double left_value, double right_value, const uint32_t* threshold, int num_threshold, double left_value, double right_value,
data_size_t left_cnt, data_size_t right_cnt, float gain, MissingType missing_type) { data_size_t left_cnt, data_size_t right_cnt, double left_weight, double right_weight, float gain, MissingType missing_type) {
Split(leaf, feature, real_feature, left_value, right_value, left_cnt, right_cnt, gain); Split(leaf, feature, real_feature, left_value, right_value, left_cnt, right_cnt, left_weight, right_weight, gain);
int new_node_idx = num_leaves_ - 1; int new_node_idx = num_leaves_ - 1;
decision_type_[new_node_idx] = 0; decision_type_[new_node_idx] = 0;
SetDecisionType(&decision_type_[new_node_idx], true, kCategoricalMask); SetDecisionType(&decision_type_[new_node_idx], true, kCategoricalMask);
...@@ -221,10 +224,14 @@ std::string Tree::ToString() const { ...@@ -221,10 +224,14 @@ std::string Tree::ToString() const {
<< Common::ArrayToStringFast(right_child_, num_leaves_ - 1) << '\n'; << Common::ArrayToStringFast(right_child_, num_leaves_ - 1) << '\n';
str_buf << "leaf_value=" str_buf << "leaf_value="
<< Common::ArrayToString(leaf_value_, num_leaves_) << '\n'; << Common::ArrayToString(leaf_value_, num_leaves_) << '\n';
str_buf << "leaf_weight="
<< Common::ArrayToString(leaf_weight_, num_leaves_) << '\n';
str_buf << "leaf_count=" str_buf << "leaf_count="
<< Common::ArrayToStringFast(leaf_count_, num_leaves_) << '\n'; << Common::ArrayToStringFast(leaf_count_, num_leaves_) << '\n';
str_buf << "internal_value=" str_buf << "internal_value="
<< Common::ArrayToStringFast(internal_value_, num_leaves_ - 1) << '\n'; << Common::ArrayToStringFast(internal_value_, num_leaves_ - 1) << '\n';
str_buf << "internal_weight="
<< Common::ArrayToStringFast(internal_weight_, num_leaves_ - 1) << '\n';
str_buf << "internal_count=" str_buf << "internal_count="
<< Common::ArrayToStringFast(internal_count_, num_leaves_ - 1) << '\n'; << Common::ArrayToStringFast(internal_count_, num_leaves_ - 1) << '\n';
if (num_cat_ > 0) { if (num_cat_ > 0) {
...@@ -294,6 +301,7 @@ std::string Tree::NodeToJSON(int index) const { ...@@ -294,6 +301,7 @@ std::string Tree::NodeToJSON(int index) const {
str_buf << "\"missing_type\":\"NaN\"," << '\n'; str_buf << "\"missing_type\":\"NaN\"," << '\n';
} }
str_buf << "\"internal_value\":" << internal_value_[index] << "," << '\n'; str_buf << "\"internal_value\":" << internal_value_[index] << "," << '\n';
str_buf << "\"internal_weight\":" << internal_weight_[index] << "," << '\n';
str_buf << "\"internal_count\":" << internal_count_[index] << "," << '\n'; str_buf << "\"internal_count\":" << internal_count_[index] << "," << '\n';
str_buf << "\"left_child\":" << NodeToJSON(left_child_[index]) << "," << '\n'; str_buf << "\"left_child\":" << NodeToJSON(left_child_[index]) << "," << '\n';
str_buf << "\"right_child\":" << NodeToJSON(right_child_[index]) << '\n'; str_buf << "\"right_child\":" << NodeToJSON(right_child_[index]) << '\n';
...@@ -304,6 +312,7 @@ std::string Tree::NodeToJSON(int index) const { ...@@ -304,6 +312,7 @@ std::string Tree::NodeToJSON(int index) const {
str_buf << "{" << '\n'; str_buf << "{" << '\n';
str_buf << "\"leaf_index\":" << index << "," << '\n'; str_buf << "\"leaf_index\":" << index << "," << '\n';
str_buf << "\"leaf_value\":" << leaf_value_[index] << "," << '\n'; str_buf << "\"leaf_value\":" << leaf_value_[index] << "," << '\n';
str_buf << "\"leaf_weight\":" << leaf_weight_[index] << "," << '\n';
str_buf << "\"leaf_count\":" << leaf_count_[index] << '\n'; str_buf << "\"leaf_count\":" << leaf_count_[index] << '\n';
str_buf << "}"; str_buf << "}";
} }
...@@ -472,7 +481,7 @@ std::string Tree::NodeToIfElseByMap(int index, bool predict_leaf_index) const { ...@@ -472,7 +481,7 @@ std::string Tree::NodeToIfElseByMap(int index, bool predict_leaf_index) const {
Tree::Tree(const char* str, size_t* used_len) { Tree::Tree(const char* str, size_t* used_len) {
auto p = str; auto p = str;
std::unordered_map<std::string, std::string> key_vals; std::unordered_map<std::string, std::string> key_vals;
const int max_num_line = 15; const int max_num_line = 17;
int read_line = 0; int read_line = 0;
while (read_line < max_num_line) { while (read_line < max_num_line) {
if (*p == '\r' || *p == '\n') break; if (*p == '\r' || *p == '\n') break;
...@@ -557,6 +566,20 @@ Tree::Tree(const char* str, size_t* used_len) { ...@@ -557,6 +566,20 @@ Tree::Tree(const char* str, size_t* used_len) {
internal_value_.resize(num_leaves_ - 1); internal_value_.resize(num_leaves_ - 1);
} }
if (key_vals.count("internal_weight")) {
internal_weight_ = Common::StringToArrayFast<double>(key_vals["internal_weight"], num_leaves_ - 1);
}
else {
internal_weight_.resize(num_leaves_ - 1);
}
if (key_vals.count("leaf_weight")) {
leaf_weight_ = Common::StringToArrayFast<double>(key_vals["leaf_weight"], num_leaves_);
}
else {
leaf_weight_.resize(num_leaves_);
}
if (key_vals.count("leaf_count")) { if (key_vals.count("leaf_count")) {
leaf_count_ = Common::StringToArrayFast<int>(key_vals["leaf_count"], num_leaves_); leaf_count_ = Common::StringToArrayFast<int>(key_vals["leaf_count"], num_leaves_);
} else { } else {
......
...@@ -684,6 +684,8 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, Json& forced_split_json, int* ...@@ -684,6 +684,8 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, Json& forced_split_json, int*
static_cast<double>(current_split_info.right_output), static_cast<double>(current_split_info.right_output),
static_cast<data_size_t>(current_split_info.left_count), static_cast<data_size_t>(current_split_info.left_count),
static_cast<data_size_t>(current_split_info.right_count), static_cast<data_size_t>(current_split_info.right_count),
static_cast<double>(current_split_info.left_sum_hessian),
static_cast<double>(current_split_info.right_sum_hessian),
static_cast<float>(current_split_info.gain), static_cast<float>(current_split_info.gain),
train_data_->FeatureBinMapper(inner_feature_index)->missing_type(), train_data_->FeatureBinMapper(inner_feature_index)->missing_type(),
current_split_info.default_left); current_split_info.default_left);
...@@ -711,6 +713,8 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, Json& forced_split_json, int* ...@@ -711,6 +713,8 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, Json& forced_split_json, int*
static_cast<double>(current_split_info.right_output), static_cast<double>(current_split_info.right_output),
static_cast<data_size_t>(current_split_info.left_count), static_cast<data_size_t>(current_split_info.left_count),
static_cast<data_size_t>(current_split_info.right_count), static_cast<data_size_t>(current_split_info.right_count),
static_cast<double>(current_split_info.left_sum_hessian),
static_cast<double>(current_split_info.right_sum_hessian),
static_cast<float>(current_split_info.gain), static_cast<float>(current_split_info.gain),
train_data_->FeatureBinMapper(inner_feature_index)->missing_type()); train_data_->FeatureBinMapper(inner_feature_index)->missing_type());
data_partition_->Split(current_leaf, train_data_, inner_feature_index, data_partition_->Split(current_leaf, train_data_, inner_feature_index,
...@@ -792,6 +796,8 @@ void SerialTreeLearner::Split(Tree* tree, int best_leaf, int* left_leaf, int* ri ...@@ -792,6 +796,8 @@ void SerialTreeLearner::Split(Tree* tree, int best_leaf, int* left_leaf, int* ri
static_cast<double>(best_split_info.right_output), static_cast<double>(best_split_info.right_output),
static_cast<data_size_t>(best_split_info.left_count), static_cast<data_size_t>(best_split_info.left_count),
static_cast<data_size_t>(best_split_info.right_count), static_cast<data_size_t>(best_split_info.right_count),
static_cast<double>(best_split_info.left_sum_hessian),
static_cast<double>(best_split_info.right_sum_hessian),
static_cast<float>(best_split_info.gain), static_cast<float>(best_split_info.gain),
train_data_->FeatureBinMapper(inner_feature_index)->missing_type(), train_data_->FeatureBinMapper(inner_feature_index)->missing_type(),
best_split_info.default_left); best_split_info.default_left);
...@@ -815,6 +821,8 @@ void SerialTreeLearner::Split(Tree* tree, int best_leaf, int* left_leaf, int* ri ...@@ -815,6 +821,8 @@ void SerialTreeLearner::Split(Tree* tree, int best_leaf, int* left_leaf, int* ri
static_cast<double>(best_split_info.right_output), static_cast<double>(best_split_info.right_output),
static_cast<data_size_t>(best_split_info.left_count), static_cast<data_size_t>(best_split_info.left_count),
static_cast<data_size_t>(best_split_info.right_count), static_cast<data_size_t>(best_split_info.right_count),
static_cast<double>(best_split_info.left_sum_hessian),
static_cast<double>(best_split_info.right_sum_hessian),
static_cast<float>(best_split_info.gain), static_cast<float>(best_split_info.gain),
train_data_->FeatureBinMapper(inner_feature_index)->missing_type()); train_data_->FeatureBinMapper(inner_feature_index)->missing_type());
data_partition_->Split(best_leaf, train_data_, inner_feature_index, data_partition_->Split(best_leaf, train_data_, inner_feature_index,
......
...@@ -120,7 +120,7 @@ class TestBasic(unittest.TestCase): ...@@ -120,7 +120,7 @@ class TestBasic(unittest.TestCase):
self.assertRaises(IndexError, lgb.create_tree_digraph, gbm, tree_index=83) self.assertRaises(IndexError, lgb.create_tree_digraph, gbm, tree_index=83)
graph = lgb.create_tree_digraph(gbm, tree_index=3, graph = lgb.create_tree_digraph(gbm, tree_index=3,
show_info=['split_gain', 'internal_value'], show_info=['split_gain', 'internal_value', 'internal_weight'],
name='Tree4', node_attr={'color': 'red'}) name='Tree4', node_attr={'color': 'red'})
graph.render(view=False) graph.render(view=False)
self.assertIsInstance(graph, graphviz.Digraph) self.assertIsInstance(graph, graphviz.Digraph)
...@@ -137,8 +137,10 @@ class TestBasic(unittest.TestCase): ...@@ -137,8 +137,10 @@ class TestBasic(unittest.TestCase):
self.assertIn('leaf_index', graph_body) self.assertIn('leaf_index', graph_body)
self.assertIn('split_gain', graph_body) self.assertIn('split_gain', graph_body)
self.assertIn('internal_value', graph_body) self.assertIn('internal_value', graph_body)
self.assertIn('internal_weight', graph_body)
self.assertNotIn('internal_count', graph_body) self.assertNotIn('internal_count', graph_body)
self.assertNotIn('leaf_count', graph_body) self.assertNotIn('leaf_count', graph_body)
self.assertNotIn('leaf_weight', graph_body)
@unittest.skipIf(not MATPLOTLIB_INSTALLED, 'matplotlib is not installed') @unittest.skipIf(not MATPLOTLIB_INSTALLED, 'matplotlib is not installed')
def test_plot_metrics(self): def test_plot_metrics(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment