"include/git@developer.sourcefind.cn:gaoqiong/pybind11.git" did not exist on "b961ac644ff0ec22e3a6edc5925bd9b0f2eff0df"
Commit f52be9be authored by CharlesAuguste's avatar CharlesAuguste Committed by Nikita Titov
Browse files

[python] Improved python tree plots (#2304)

* Some basic changes to the plot of the trees to make them readable.

* Squeezed the information in the nodes.

* Added colouring when a dictionary mapping the features to the constraints is passed.

* Fix spaces.

* Added data percentage as an option in the nodes.

* Squeezed the information in the leaves.

* Important information is now in bold.

* Added a legend for the color of monotone splits.

* Changed "split_gain" to "gain" and "internal_value" to "value".

* Sqeezed leaves a bit more.

* Changed description in the legend.

* Revert "Sqeezed leaves a bit more."

This reverts commit dd8bf14a3ba604b0dfae3b7bb1c64b6784d15e03.

* Increased the readability for the gain.

* Tidied up the legend.

* Added the data percentage in the leaves.

* Added the monotone constraints to the dumped model.

* Monotone constraints are now specified automatically when plotting trees.

* Raise an exception instead of the bug that was here before.

* Removed operators on the branches for a clearer design.

* Small cleaning of the code.

* Setting a monotone constraint on a categorical feature now raises an exception instead of doing nothing.

* Fix bug when monotone constraints are empty.

* Fix another bug when monotone constraints are empty.

* Variable name change.

* Added is / isn't on every edge of the trees.

* Fix test "tree_create_digraph".

* Add new test for plotting trees with monotone constraints.

* Typo.

* Update documentation of categorical features.

* Typo.

* Information in nodes more explicit.

* Used regular strings instead of raw strings.

* Small refactoring.

* Some cleaning.

* Added future statement.

* Changed output for consistency.

* Updated documentation.

* Added comments for colors.

* Changed text on edges for more clarity.

* Small refactoring.

* Modified text in leaves for consistency with nodes.

* Updated default values and documentation for consistency.

* Replaced CHECK with Log::Fatal for user-friendliness.

* Updated tests.

* Typo.

* Simplify imports.

* Swapped count and weight to improve readability of the leaves in the plotted trees.

* Thresholds in bold.

* Made information in nodes written in a specific order.

* Added information to clarify legend.

* Code cleaning.
parent b6d4ad83
...@@ -659,6 +659,8 @@ IO Parameters ...@@ -659,6 +659,8 @@ IO Parameters
- **Note**: all negative values will be treated as **missing values** - **Note**: all negative values will be treated as **missing values**
- **Note**: the output cannot be monotonically constrained with respect to a categorical feature
- ``predict_raw_score`` :raw-html:`<a id="predict_raw_score" title="Permalink to this parameter" href="#predict_raw_score">&#x1F517;&#xFE0E;</a>`, default = ``false``, type = bool, aliases: ``is_predict_raw_score``, ``predict_rawscore``, ``raw_score`` - ``predict_raw_score`` :raw-html:`<a id="predict_raw_score" title="Permalink to this parameter" href="#predict_raw_score">&#x1F517;&#xFE0E;</a>`, default = ``false``, type = bool, aliases: ``is_predict_raw_score``, ``predict_rawscore``, ``raw_score``
- used only in ``prediction`` task - used only in ``prediction`` task
......
...@@ -609,6 +609,7 @@ struct Config { ...@@ -609,6 +609,7 @@ struct Config {
// desc = **Note**: all values should be less than ``Int32.MaxValue`` (2147483647) // desc = **Note**: all values should be less than ``Int32.MaxValue`` (2147483647)
// desc = **Note**: using large values could be memory consuming. Tree decision rule works best when categorical features are presented by consecutive integers starting from zero // desc = **Note**: using large values could be memory consuming. Tree decision rule works best when categorical features are presented by consecutive integers starting from zero
// desc = **Note**: all negative values will be treated as **missing values** // desc = **Note**: all negative values will be treated as **missing values**
// desc = **Note**: the output cannot be monotonically constrained with respect to a categorical feature
std::string categorical_feature = ""; std::string categorical_feature = "";
// alias = is_predict_raw_score, predict_rawscore, raw_score // alias = is_predict_raw_score, predict_rawscore, raw_score
......
...@@ -162,6 +162,31 @@ inline static const char* Atoi(const char* p, T* out) { ...@@ -162,6 +162,31 @@ inline static const char* Atoi(const char* p, T* out) {
return p; return p;
} }
template <typename T>
inline void SplitToIntLike(const char *c_str, char delimiter,
                           std::vector<T> &ret) {
  // Splits a delimiter-separated string into integer-like values appended to
  // `ret`, parsing each token with Atoi. Empty tokens (consecutive delimiters,
  // or a leading/trailing delimiter) are skipped. `ret` must start out empty.
  CHECK(ret.empty());
  const std::string str(c_str);
  size_t token_start = 0;
  for (size_t cur = 0; cur < str.length(); ++cur) {
    if (str[cur] != delimiter) {
      continue;
    }
    if (token_start < cur) {
      // A non-empty token ends just before this delimiter; parse it in place.
      ret.push_back({});
      Atoi(str.substr(token_start, cur - token_start).c_str(), &ret.back());
    }
    token_start = cur + 1;
  }
  // Handle the trailing token, which has no delimiter after it.
  if (token_start < str.length()) {
    ret.push_back({});
    Atoi(str.substr(token_start).c_str(), &ret.back());
  }
}
template<typename T> template<typename T>
inline static double Pow(T base, int power) { inline static double Pow(T base, int power) {
if (power < 0) { if (power < 0) {
...@@ -551,6 +576,21 @@ inline static std::string Join(const std::vector<T>& strs, const char* delimiter ...@@ -551,6 +576,21 @@ inline static std::string Join(const std::vector<T>& strs, const char* delimiter
return str_buf.str(); return str_buf.str();
} }
// Specialization of Join for int8_t elements.
//
// The generic Join streams each element directly, but operator<< prints an
// int8_t (signed char) as a character rather than as a number. Each element is
// therefore widened to int16_t first so its numeric value is emitted.
//
// Returns the elements joined by `delimiter`, or "" for an empty vector.
// (The std::setprecision call present in the generic Join is omitted here:
// stream precision only affects floating-point output, so it was dead code
// for purely integral values.)
template<>
inline std::string Join<int8_t>(const std::vector<int8_t>& strs, const char* delimiter) {
  if (strs.empty()) {
    return std::string("");
  }
  std::stringstream str_buf;
  str_buf << static_cast<int16_t>(strs[0]);
  for (size_t i = 1; i < strs.size(); ++i) {
    str_buf << delimiter;
    str_buf << static_cast<int16_t>(strs[i]);
  }
  return str_buf.str();
}
template<typename T> template<typename T>
inline static std::string Join(const std::vector<T>& strs, size_t start, size_t end, const char* delimiter) { inline static std::string Join(const std::vector<T>& strs, size_t start, size_t end, const char* delimiter) {
if (end - start <= 0) { if (end - start <= 0) {
......
...@@ -698,6 +698,7 @@ class Dataset(object): ...@@ -698,6 +698,7 @@ class Dataset(object):
All values in categorical features should be less than int32 max value (2147483647). All values in categorical features should be less than int32 max value (2147483647).
Large values could be memory consuming. Consider using consecutive integers starting from zero. Large values could be memory consuming. Consider using consecutive integers starting from zero.
All negative values in categorical features will be treated as missing values. All negative values in categorical features will be treated as missing values.
The output cannot be monotonically constrained with respect to a categorical feature.
params : dict or None, optional (default=None) params : dict or None, optional (default=None)
Other parameters for Dataset. Other parameters for Dataset.
free_raw_data : bool, optional (default=True) free_raw_data : bool, optional (default=True)
......
...@@ -88,6 +88,7 @@ def train(params, train_set, num_boost_round=100, ...@@ -88,6 +88,7 @@ def train(params, train_set, num_boost_round=100,
All values in categorical features should be less than int32 max value (2147483647). All values in categorical features should be less than int32 max value (2147483647).
Large values could be memory consuming. Consider using consecutive integers starting from zero. Large values could be memory consuming. Consider using consecutive integers starting from zero.
All negative values in categorical features will be treated as missing values. All negative values in categorical features will be treated as missing values.
The output cannot be monotonically constrained with respect to a categorical feature.
early_stopping_rounds : int or None, optional (default=None) early_stopping_rounds : int or None, optional (default=None)
Activates early stopping. The model will train until the validation score stops improving. Activates early stopping. The model will train until the validation score stops improving.
Validation score needs to improve at least every ``early_stopping_rounds`` round(s) Validation score needs to improve at least every ``early_stopping_rounds`` round(s)
...@@ -451,6 +452,7 @@ def cv(params, train_set, num_boost_round=100, ...@@ -451,6 +452,7 @@ def cv(params, train_set, num_boost_round=100,
All values in categorical features should be less than int32 max value (2147483647). All values in categorical features should be less than int32 max value (2147483647).
Large values could be memory consuming. Consider using consecutive integers starting from zero. Large values could be memory consuming. Consider using consecutive integers starting from zero.
All negative values in categorical features will be treated as missing values. All negative values in categorical features will be treated as missing values.
The output cannot be monotonically constrained with respect to a categorical feature.
early_stopping_rounds : int or None, optional (default=None) early_stopping_rounds : int or None, optional (default=None)
Activates early stopping. Activates early stopping.
CV score needs to improve at least every ``early_stopping_rounds`` round(s) CV score needs to improve at least every ``early_stopping_rounds`` round(s)
......
# coding: utf-8 # coding: utf-8
# pylint: disable = C0103 # pylint: disable = C0103
"""Plotting library.""" """Plotting library."""
from __future__ import absolute_import from __future__ import absolute_import, division
import warnings import warnings
from copy import deepcopy from copy import deepcopy
...@@ -369,7 +369,7 @@ def plot_metric(booster, metric=None, dataset_names=None, ...@@ -369,7 +369,7 @@ def plot_metric(booster, metric=None, dataset_names=None,
return ax return ax
def _to_graphviz(tree_info, show_info, feature_names, precision=None, **kwargs): def _to_graphviz(tree_info, show_info, feature_names, precision=3, constraints=None, **kwargs):
"""Convert specified tree to graphviz instance. """Convert specified tree to graphviz instance.
See: See:
...@@ -380,48 +380,90 @@ def _to_graphviz(tree_info, show_info, feature_names, precision=None, **kwargs): ...@@ -380,48 +380,90 @@ def _to_graphviz(tree_info, show_info, feature_names, precision=None, **kwargs):
else: else:
raise ImportError('You must install graphviz to plot tree.') raise ImportError('You must install graphviz to plot tree.')
def add(root, parent=None, decision=None): def add(root, total_count, parent=None, decision=None):
"""Recursively add node or edge.""" """Recursively add node or edge."""
if 'split_index' in root: # non-leaf if 'split_index' in root: # non-leaf
name = 'split{0}'.format(root['split_index']) l_dec = 'yes'
if feature_names is not None: r_dec = 'no'
label = 'split_feature_name: {0}'.format(feature_names[root['split_feature']])
else:
label = 'split_feature_index: {0}'.format(root['split_feature'])
label += r'\nthreshold: {0}'.format(_float2str(root['threshold'], precision))
for info in show_info:
if info in {'split_gain', 'internal_value', 'internal_weight'}:
label += r'\n{0}: {1}'.format(info, _float2str(root[info], precision))
elif info == 'internal_count':
label += r'\n{0}: {1}'.format(info, root[info])
graph.node(name, label=label)
if root['decision_type'] == '<=': if root['decision_type'] == '<=':
l_dec, r_dec = '<=', '>' lte_symbol = "&#8804;"
operator = lte_symbol
elif root['decision_type'] == '==': elif root['decision_type'] == '==':
l_dec, r_dec = 'is', "isn't" operator = "="
else: else:
raise ValueError('Invalid decision type in tree model.') raise ValueError('Invalid decision type in tree model.')
add(root['left_child'], name, l_dec) name = 'split{0}'.format(root['split_index'])
add(root['right_child'], name, r_dec) if feature_names is not None:
label = '<B>{0}</B> {1} '.format(feature_names[root['split_feature']], operator)
else:
label = 'feature <B>{0}</B> {1} '.format(root['split_feature'], operator)
label += '<B>{0}</B>'.format(_float2str(root['threshold'], precision))
for info in ['split_gain', 'internal_value', 'internal_weight', "internal_count", "data_percentage"]:
if info in show_info:
output = info.split('_')[-1]
if info in {'split_gain', 'internal_value', 'internal_weight'}:
label += '<br/>{0} {1}'.format(_float2str(root[info], precision), output)
elif info == 'internal_count':
label += '<br/>{0}: {1}'.format(output, root[info])
elif info == "data_percentage":
label += '<br/>{0}% of data'.format(_float2str(root['internal_count'] / total_count * 100, 2))
fillcolor = "white"
style = ""
if constraints:
if constraints[root['split_feature']] == 1:
fillcolor = "#ddffdd" # light green
if constraints[root['split_feature']] == -1:
fillcolor = "#ffdddd" # light red
style = "filled"
label = "<" + label + ">"
graph.node(name, label=label, shape="rectangle", style=style, fillcolor=fillcolor)
add(root['left_child'], total_count, name, l_dec)
add(root['right_child'], total_count, name, r_dec)
else: # leaf else: # leaf
name = 'leaf{0}'.format(root['leaf_index']) name = 'leaf{0}'.format(root['leaf_index'])
label = 'leaf_index: {0}'.format(root['leaf_index']) label = 'leaf {0}: '.format(root['leaf_index'])
label += r'\nleaf_value: {0}'.format(_float2str(root['leaf_value'], precision)) label += '<B>{0}</B>'.format(_float2str(root['leaf_value'], precision))
if 'leaf_count' in show_info:
label += r'\nleaf_count: {0}'.format(root['leaf_count'])
if 'leaf_weight' in show_info: if 'leaf_weight' in show_info:
label += r'\nleaf_weight: {0}'.format(_float2str(root['leaf_weight'], precision)) label += '<br/>{0} weight'.format(_float2str(root['leaf_weight'], precision))
if 'leaf_count' in show_info:
label += '<br/>count: {0}'.format(root['leaf_count'])
if "data_percentage" in show_info:
label += '<br/>{0}% of data'.format(_float2str(root['leaf_count'] / total_count * 100, 2))
label = "<" + label + ">"
graph.node(name, label=label) graph.node(name, label=label)
if parent is not None: if parent is not None:
graph.edge(parent, name, decision) graph.edge(parent, name, decision)
graph = Digraph(**kwargs) graph = Digraph(**kwargs)
add(tree_info['tree_structure']) graph.attr("graph", nodesep="0.05", ranksep="0.3", rankdir="LR")
if "internal_count" in tree_info['tree_structure']:
add(tree_info['tree_structure'], tree_info['tree_structure']["internal_count"])
else:
raise Exception("Cannnot plot trees with no split")
if constraints:
# "#ddffdd" is light green, "#ffdddd" is light red
legend = """<
<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" CELLPADDING="4">
<TR>
<TD COLSPAN="2"><B>Monotone constraints</B></TD>
</TR>
<TR>
<TD>Increasing</TD>
<TD BGCOLOR="#ddffdd"></TD>
</TR>
<TR>
<TD>Decreasing</TD>
<TD BGCOLOR="#ffdddd"></TD>
</TR>
</TABLE>
>"""
graph.node("legend", label=legend, shape="rectangle", color="white")
return graph return graph
def create_tree_digraph(booster, tree_index=0, show_info=None, precision=None, def create_tree_digraph(booster, tree_index=0, show_info=None, precision=3,
old_name=None, old_comment=None, old_filename=None, old_directory=None, old_name=None, old_comment=None, old_filename=None, old_directory=None,
old_format=None, old_engine=None, old_encoding=None, old_graph_attr=None, old_format=None, old_engine=None, old_encoding=None, old_graph_attr=None,
old_node_attr=None, old_edge_attr=None, old_body=None, old_strict=False, **kwargs): old_node_attr=None, old_edge_attr=None, old_body=None, old_strict=False, **kwargs):
...@@ -441,8 +483,9 @@ def create_tree_digraph(booster, tree_index=0, show_info=None, precision=None, ...@@ -441,8 +483,9 @@ def create_tree_digraph(booster, tree_index=0, show_info=None, precision=None,
show_info : list of strings or None, optional (default=None) show_info : list of strings or None, optional (default=None)
What information should be shown in nodes. What information should be shown in nodes.
Possible values of list items: Possible values of list items:
'split_gain', 'internal_value', 'internal_count', 'internal_weight', 'leaf_count', 'leaf_weight'. 'split_gain', 'internal_value', 'internal_count', 'internal_weight',
precision : int or None, optional (default=None) 'leaf_count', 'leaf_weight', 'data_percentage'.
precision : int or None, optional (default=3)
Used to restrict the display of floating point values to a certain precision. Used to restrict the display of floating point values to a certain precision.
**kwargs **kwargs
Other parameters passed to ``Digraph`` constructor. Other parameters passed to ``Digraph`` constructor.
...@@ -482,6 +525,8 @@ def create_tree_digraph(booster, tree_index=0, show_info=None, precision=None, ...@@ -482,6 +525,8 @@ def create_tree_digraph(booster, tree_index=0, show_info=None, precision=None,
else: else:
feature_names = None feature_names = None
monotone_constraints = model.get('monotone_constraints', None)
if tree_index < len(tree_infos): if tree_index < len(tree_infos):
tree_info = tree_infos[tree_index] tree_info = tree_infos[tree_index]
else: else:
...@@ -490,14 +535,14 @@ def create_tree_digraph(booster, tree_index=0, show_info=None, precision=None, ...@@ -490,14 +535,14 @@ def create_tree_digraph(booster, tree_index=0, show_info=None, precision=None,
if show_info is None: if show_info is None:
show_info = [] show_info = []
graph = _to_graphviz(tree_info, show_info, feature_names, precision, **kwargs) graph = _to_graphviz(tree_info, show_info, feature_names, precision, monotone_constraints, **kwargs)
return graph return graph
def plot_tree(booster, ax=None, tree_index=0, figsize=None, def plot_tree(booster, ax=None, tree_index=0, figsize=None,
old_graph_attr=None, old_node_attr=None, old_edge_attr=None, old_graph_attr=None, old_node_attr=None, old_edge_attr=None,
show_info=None, precision=None, **kwargs): show_info=None, precision=3, **kwargs):
"""Plot specified tree. """Plot specified tree.
Note Note
...@@ -519,8 +564,9 @@ def plot_tree(booster, ax=None, tree_index=0, figsize=None, ...@@ -519,8 +564,9 @@ def plot_tree(booster, ax=None, tree_index=0, figsize=None,
show_info : list of strings or None, optional (default=None) show_info : list of strings or None, optional (default=None)
What information should be shown in nodes. What information should be shown in nodes.
Possible values of list items: Possible values of list items:
'split_gain', 'internal_value', 'internal_count', 'internal_weight', 'leaf_count', 'leaf_weight'. 'split_gain', 'internal_value', 'internal_count', 'internal_weight',
precision : int or None, optional (default=None) 'leaf_count', 'leaf_weight', 'data_percentage'.
precision : int or None, optional (default=3)
Used to restrict the display of floating point values to a certain precision. Used to restrict the display of floating point values to a certain precision.
**kwargs **kwargs
Other parameters passed to ``Digraph`` constructor. Other parameters passed to ``Digraph`` constructor.
......
...@@ -437,6 +437,7 @@ class LGBMModel(_LGBMModelBase): ...@@ -437,6 +437,7 @@ class LGBMModel(_LGBMModelBase):
All values in categorical features should be less than int32 max value (2147483647). All values in categorical features should be less than int32 max value (2147483647).
Large values could be memory consuming. Consider using consecutive integers starting from zero. Large values could be memory consuming. Consider using consecutive integers starting from zero.
All negative values in categorical features will be treated as missing values. All negative values in categorical features will be treated as missing values.
The output cannot be monotonically constrained with respect to a categorical feature.
callbacks : list of callback functions or None, optional (default=None) callbacks : list of callback functions or None, optional (default=None)
List of callback functions that are applied at each iteration. List of callback functions that are applied at each iteration.
See Callbacks in Python API for more information. See Callbacks in Python API for more information.
......
...@@ -103,6 +103,7 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective ...@@ -103,6 +103,7 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective
// get feature names // get feature names
feature_names_ = train_data_->feature_names(); feature_names_ = train_data_->feature_names();
feature_infos_ = train_data_->feature_infos(); feature_infos_ = train_data_->feature_infos();
monotone_constraints_ = config->monotone_constraints;
// if need bagging, create buffer // if need bagging, create buffer
ResetBaggingConfig(config_.get(), true); ResetBaggingConfig(config_.get(), true);
......
...@@ -504,6 +504,7 @@ class GBDT : public GBDTBase { ...@@ -504,6 +504,7 @@ class GBDT : public GBDTBase {
bool need_re_bagging_; bool need_re_bagging_;
bool balanced_bagging_; bool balanced_bagging_;
std::string loaded_parameter_; std::string loaded_parameter_;
std::vector<int8_t> monotone_constraints_;
Json forced_splits_json_; Json forced_splits_json_;
}; };
......
...@@ -31,9 +31,11 @@ std::string GBDT::DumpModel(int start_iteration, int num_iteration) const { ...@@ -31,9 +31,11 @@ std::string GBDT::DumpModel(int start_iteration, int num_iteration) const {
str_buf << "\"objective\":\"" << objective_function_->ToString() << "\",\n"; str_buf << "\"objective\":\"" << objective_function_->ToString() << "\",\n";
} }
str_buf << "\"feature_names\":[\"" str_buf << "\"feature_names\":[\"" << Common::Join(feature_names_, "\",\"")
<< Common::Join(feature_names_, "\",\"") << "\"]," << "\"]," << '\n';
<< '\n';
str_buf << "\"monotone_constraints\":["
<< Common::Join(monotone_constraints_, ",") << "]," << '\n';
str_buf << "\"tree_info\":["; str_buf << "\"tree_info\":[";
int num_used_model = static_cast<int>(models_.size()); int num_used_model = static_cast<int>(models_.size());
...@@ -269,6 +271,11 @@ std::string GBDT::SaveModelToString(int start_iteration, int num_iteration) cons ...@@ -269,6 +271,11 @@ std::string GBDT::SaveModelToString(int start_iteration, int num_iteration) cons
ss << "feature_names=" << Common::Join(feature_names_, " ") << '\n'; ss << "feature_names=" << Common::Join(feature_names_, " ") << '\n';
if (monotone_constraints_.size() != 0) {
ss << "monotone_constraints=" << Common::Join(monotone_constraints_, " ")
<< '\n';
}
ss << "feature_infos=" << Common::Join(feature_infos_, " ") << '\n'; ss << "feature_infos=" << Common::Join(feature_infos_, " ") << '\n';
int num_used_model = static_cast<int>(models_.size()); int num_used_model = static_cast<int>(models_.size());
...@@ -364,6 +371,8 @@ bool GBDT::LoadModelFromString(const char* buffer, size_t len) { ...@@ -364,6 +371,8 @@ bool GBDT::LoadModelFromString(const char* buffer, size_t len) {
} else if (strs.size() > 2) { } else if (strs.size() > 2) {
if (strs[0] == "feature_names") { if (strs[0] == "feature_names") {
key_vals[strs[0]] = cur_line.substr(std::strlen("feature_names=")); key_vals[strs[0]] = cur_line.substr(std::strlen("feature_names="));
} else if (strs[0] == "monotone_constraints") {
key_vals[strs[0]] = cur_line.substr(std::strlen("monotone_constraints="));
} else { } else {
// Use first 128 chars to avoid exceed the message buffer. // Use first 128 chars to avoid exceed the message buffer.
Log::Fatal("Wrong line at model file: %s", cur_line.substr(0, std::min<size_t>(128, cur_line.size())).c_str()); Log::Fatal("Wrong line at model file: %s", cur_line.substr(0, std::min<size_t>(128, cur_line.size())).c_str());
...@@ -424,6 +433,15 @@ bool GBDT::LoadModelFromString(const char* buffer, size_t len) { ...@@ -424,6 +433,15 @@ bool GBDT::LoadModelFromString(const char* buffer, size_t len) {
return false; return false;
} }
// get monotone_constraints
if (key_vals.count("monotone_constraints")) {
Common::SplitToIntLike(key_vals["monotone_constraints"].c_str(), ' ', monotone_constraints_);
if (monotone_constraints_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
Log::Fatal("Wrong size of monotone_constraints");
return false;
}
}
if (key_vals.count("feature_infos")) { if (key_vals.count("feature_infos")) {
feature_infos_ = Common::Split(key_vals["feature_infos"].c_str(), ' '); feature_infos_ = Common::Split(key_vals["feature_infos"].c_str(), ' ');
if (feature_infos_.size() != static_cast<size_t>(max_feature_idx_ + 1)) { if (feature_infos_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
......
...@@ -580,6 +580,10 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values, ...@@ -580,6 +580,10 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values,
BinType bin_type = BinType::NumericalBin; BinType bin_type = BinType::NumericalBin;
if (categorical_features_.count(i)) { if (categorical_features_.count(i)) {
bin_type = BinType::CategoricalBin; bin_type = BinType::CategoricalBin;
bool feat_is_unconstrained = ((config_.monotone_constraints.size() == 0) || (config_.monotone_constraints[i] == 0));
if (!feat_is_unconstrained) {
Log::Fatal("The output cannot be monotone with respect to categorical features");
}
} }
bin_mappers[i].reset(new BinMapper()); bin_mappers[i].reset(new BinMapper());
if (config_.max_bin_by_feature.empty()) { if (config_.max_bin_by_feature.empty()) {
......
...@@ -114,7 +114,8 @@ class TestBasic(unittest.TestCase): ...@@ -114,7 +114,8 @@ class TestBasic(unittest.TestCase):
@unittest.skipIf(not GRAPHVIZ_INSTALLED, 'graphviz is not installed') @unittest.skipIf(not GRAPHVIZ_INSTALLED, 'graphviz is not installed')
def test_create_tree_digraph(self): def test_create_tree_digraph(self):
gbm = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True) constraints = [-1, 1] * int(self.X_train.shape[1] / 2)
gbm = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True, monotone_constraints=constraints)
gbm.fit(self.X_train, self.y_train, verbose=False) gbm.fit(self.X_train, self.y_train, verbose=False)
self.assertRaises(IndexError, lgb.create_tree_digraph, gbm, tree_index=83) self.assertRaises(IndexError, lgb.create_tree_digraph, gbm, tree_index=83)
...@@ -131,16 +132,14 @@ class TestBasic(unittest.TestCase): ...@@ -131,16 +132,14 @@ class TestBasic(unittest.TestCase):
self.assertEqual(len(graph.graph_attr), 0) self.assertEqual(len(graph.graph_attr), 0)
self.assertEqual(len(graph.edge_attr), 0) self.assertEqual(len(graph.edge_attr), 0)
graph_body = ''.join(graph.body) graph_body = ''.join(graph.body)
self.assertIn('threshold', graph_body) self.assertIn('leaf', graph_body)
self.assertIn('split_feature_name', graph_body) self.assertIn('gain', graph_body)
self.assertNotIn('split_feature_index', graph_body) self.assertIn('value', graph_body)
self.assertIn('leaf_index', graph_body) self.assertIn('weight', graph_body)
self.assertIn('split_gain', graph_body) self.assertIn('#ffdddd', graph_body)
self.assertIn('internal_value', graph_body) self.assertIn('#ddffdd', graph_body)
self.assertIn('internal_weight', graph_body) self.assertNotIn('data', graph_body)
self.assertNotIn('internal_count', graph_body) self.assertNotIn('count', graph_body)
self.assertNotIn('leaf_count', graph_body)
self.assertNotIn('leaf_weight', graph_body)
@unittest.skipIf(not MATPLOTLIB_INSTALLED, 'matplotlib is not installed') @unittest.skipIf(not MATPLOTLIB_INSTALLED, 'matplotlib is not installed')
def test_plot_metrics(self): def test_plot_metrics(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment