Commit 0312ecde authored by Nikita Titov's avatar Nikita Titov Committed by Qiwei Ye
Browse files

Added precision param to plot_importance function (#1777)

parent af3c4f89
......@@ -21,11 +21,18 @@ def _check_not_tuple_of_2_elements(obj, obj_name='obj'):
raise TypeError('%s must be a tuple of 2 elements.' % obj_name)
def _float2str(value, precision=None):
return ("{0:.{1}f}".format(value, precision)
if precision is not None and not isinstance(value, string_type)
else str(value))
def plot_importance(booster, ax=None, height=0.2,
xlim=None, ylim=None, title='Feature importance',
xlabel='Feature importance', ylabel='Features',
importance_type='split', max_num_features=None,
ignore_zero=True, figsize=None, grid=True, **kwargs):
ignore_zero=True, figsize=None, grid=True,
precision=None, **kwargs):
"""Plot model's feature importances.
Parameters
......@@ -63,6 +70,8 @@ def plot_importance(booster, ax=None, height=0.2,
Figure size.
grid : bool, optional (default=True)
Whether to add a grid for axes.
precision : int or None, optional (default=None)
Used to restrict the display of floating point values to a certain precision.
**kwargs
Other parameters passed to ``ax.barh()``.
......@@ -103,7 +112,9 @@ def plot_importance(booster, ax=None, height=0.2,
ax.barh(ylocs, values, align='center', height=height, **kwargs)
for x, y in zip_(values, ylocs):
ax.text(x + 1, y, x, va='center')
ax.text(x + 1, y,
_float2str(x, precision) if importance_type == 'gain' else x,
va='center')
ax.set_yticks(ylocs)
ax.set_yticklabels(labels)
......@@ -265,10 +276,6 @@ def _to_graphviz(tree_info, show_info, feature_names, precision=None, **kwargs):
else:
raise ImportError('You must install graphviz to plot tree.')
def float2str(value, precision=None):
return "{0:.{1}f}".format(value, precision) \
if precision is not None and not isinstance(value, string_type) else str(value)
def add(root, parent=None, decision=None):
"""Recursively add node or edge."""
if 'split_index' in root: # non-leaf
......@@ -277,10 +284,10 @@ def _to_graphviz(tree_info, show_info, feature_names, precision=None, **kwargs):
label = 'split_feature_name: {0}'.format(feature_names[root['split_feature']])
else:
label = 'split_feature_index: {0}'.format(root['split_feature'])
label += r'\nthreshold: {0}'.format(float2str(root['threshold'], precision))
label += r'\nthreshold: {0}'.format(_float2str(root['threshold'], precision))
for info in show_info:
if info in {'split_gain', 'internal_value'}:
label += r'\n{0}: {1}'.format(info, float2str(root[info], precision))
label += r'\n{0}: {1}'.format(info, _float2str(root[info], precision))
elif info == 'internal_count':
label += r'\n{0}: {1}'.format(info, root[info])
graph.node(name, label=label)
......@@ -295,7 +302,7 @@ def _to_graphviz(tree_info, show_info, feature_names, precision=None, **kwargs):
else: # leaf
name = 'leaf{0}'.format(root['leaf_index'])
label = 'leaf_index: {0}'.format(root['leaf_index'])
label += r'\nleaf_value: {0}'.format(float2str(root['leaf_value'], precision))
label += r'\nleaf_value: {0}'.format(_float2str(root['leaf_value'], precision))
if 'leaf_count' in show_info:
label += r'\nleaf_count: {0}'.format(root['leaf_count'])
graph.node(name, label=label)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment