Commit 0312ecde authored by Nikita Titov's avatar Nikita Titov Committed by Qiwei Ye
Browse files

Added precision param to plot_importance function (#1777)

parent af3c4f89
...@@ -21,11 +21,18 @@ def _check_not_tuple_of_2_elements(obj, obj_name='obj'): ...@@ -21,11 +21,18 @@ def _check_not_tuple_of_2_elements(obj, obj_name='obj'):
raise TypeError('%s must be a tuple of 2 elements.' % obj_name) raise TypeError('%s must be a tuple of 2 elements.' % obj_name)
def _float2str(value, precision=None):
return ("{0:.{1}f}".format(value, precision)
if precision is not None and not isinstance(value, string_type)
else str(value))
def plot_importance(booster, ax=None, height=0.2, def plot_importance(booster, ax=None, height=0.2,
xlim=None, ylim=None, title='Feature importance', xlim=None, ylim=None, title='Feature importance',
xlabel='Feature importance', ylabel='Features', xlabel='Feature importance', ylabel='Features',
importance_type='split', max_num_features=None, importance_type='split', max_num_features=None,
ignore_zero=True, figsize=None, grid=True, **kwargs): ignore_zero=True, figsize=None, grid=True,
precision=None, **kwargs):
"""Plot model's feature importances. """Plot model's feature importances.
Parameters Parameters
...@@ -63,6 +70,8 @@ def plot_importance(booster, ax=None, height=0.2, ...@@ -63,6 +70,8 @@ def plot_importance(booster, ax=None, height=0.2,
Figure size. Figure size.
grid : bool, optional (default=True) grid : bool, optional (default=True)
Whether to add a grid for axes. Whether to add a grid for axes.
precision : int or None, optional (default=None)
Used to restrict the display of floating point values to a certain precision.
**kwargs **kwargs
Other parameters passed to ``ax.barh()``. Other parameters passed to ``ax.barh()``.
...@@ -103,7 +112,9 @@ def plot_importance(booster, ax=None, height=0.2, ...@@ -103,7 +112,9 @@ def plot_importance(booster, ax=None, height=0.2,
ax.barh(ylocs, values, align='center', height=height, **kwargs) ax.barh(ylocs, values, align='center', height=height, **kwargs)
for x, y in zip_(values, ylocs): for x, y in zip_(values, ylocs):
ax.text(x + 1, y, x, va='center') ax.text(x + 1, y,
_float2str(x, precision) if importance_type == 'gain' else x,
va='center')
ax.set_yticks(ylocs) ax.set_yticks(ylocs)
ax.set_yticklabels(labels) ax.set_yticklabels(labels)
...@@ -265,10 +276,6 @@ def _to_graphviz(tree_info, show_info, feature_names, precision=None, **kwargs): ...@@ -265,10 +276,6 @@ def _to_graphviz(tree_info, show_info, feature_names, precision=None, **kwargs):
else: else:
raise ImportError('You must install graphviz to plot tree.') raise ImportError('You must install graphviz to plot tree.')
def float2str(value, precision=None):
return "{0:.{1}f}".format(value, precision) \
if precision is not None and not isinstance(value, string_type) else str(value)
def add(root, parent=None, decision=None): def add(root, parent=None, decision=None):
"""Recursively add node or edge.""" """Recursively add node or edge."""
if 'split_index' in root: # non-leaf if 'split_index' in root: # non-leaf
...@@ -277,10 +284,10 @@ def _to_graphviz(tree_info, show_info, feature_names, precision=None, **kwargs): ...@@ -277,10 +284,10 @@ def _to_graphviz(tree_info, show_info, feature_names, precision=None, **kwargs):
label = 'split_feature_name: {0}'.format(feature_names[root['split_feature']]) label = 'split_feature_name: {0}'.format(feature_names[root['split_feature']])
else: else:
label = 'split_feature_index: {0}'.format(root['split_feature']) label = 'split_feature_index: {0}'.format(root['split_feature'])
label += r'\nthreshold: {0}'.format(float2str(root['threshold'], precision)) label += r'\nthreshold: {0}'.format(_float2str(root['threshold'], precision))
for info in show_info: for info in show_info:
if info in {'split_gain', 'internal_value'}: if info in {'split_gain', 'internal_value'}:
label += r'\n{0}: {1}'.format(info, float2str(root[info], precision)) label += r'\n{0}: {1}'.format(info, _float2str(root[info], precision))
elif info == 'internal_count': elif info == 'internal_count':
label += r'\n{0}: {1}'.format(info, root[info]) label += r'\n{0}: {1}'.format(info, root[info])
graph.node(name, label=label) graph.node(name, label=label)
...@@ -295,7 +302,7 @@ def _to_graphviz(tree_info, show_info, feature_names, precision=None, **kwargs): ...@@ -295,7 +302,7 @@ def _to_graphviz(tree_info, show_info, feature_names, precision=None, **kwargs):
else: # leaf else: # leaf
name = 'leaf{0}'.format(root['leaf_index']) name = 'leaf{0}'.format(root['leaf_index'])
label = 'leaf_index: {0}'.format(root['leaf_index']) label = 'leaf_index: {0}'.format(root['leaf_index'])
label += r'\nleaf_value: {0}'.format(float2str(root['leaf_value'], precision)) label += r'\nleaf_value: {0}'.format(_float2str(root['leaf_value'], precision))
if 'leaf_count' in show_info: if 'leaf_count' in show_info:
label += r'\nleaf_count: {0}'.format(root['leaf_count']) label += r'\nleaf_count: {0}'.format(root['leaf_count'])
graph.node(name, label=label) graph.node(name, label=label)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment