Unverified Commit f8bab7fc authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] fix mypy errors in `plotting.py` (#4838)

* [python-package] fix mypy errors in plotting.py

* empty commit
parent 8f4126d6
...@@ -236,12 +236,12 @@ def plot_split_value_histogram( ...@@ -236,12 +236,12 @@ def plot_split_value_histogram(
elif not isinstance(booster, Booster): elif not isinstance(booster, Booster):
raise TypeError('booster must be Booster or LGBMModel.') raise TypeError('booster must be Booster or LGBMModel.')
hist, bins = booster.get_split_value_histogram(feature=feature, bins=bins, xgboost_style=False) hist, split_bins = booster.get_split_value_histogram(feature=feature, bins=bins, xgboost_style=False)
if np.count_nonzero(hist) == 0: if np.count_nonzero(hist) == 0:
raise ValueError('Cannot plot split value histogram, ' raise ValueError('Cannot plot split value histogram, '
f'because feature {feature} was not used in splitting') f'because feature {feature} was not used in splitting')
width = width_coef * (bins[1] - bins[0]) width = width_coef * (split_bins[1] - split_bins[0])
centred = (bins[:-1] + bins[1:]) / 2 centred = (split_bins[:-1] + split_bins[1:]) / 2
if ax is None: if ax is None:
if figsize is not None: if figsize is not None:
...@@ -253,8 +253,8 @@ def plot_split_value_histogram( ...@@ -253,8 +253,8 @@ def plot_split_value_histogram(
if xlim is not None: if xlim is not None:
_check_not_tuple_of_2_elements(xlim, 'xlim') _check_not_tuple_of_2_elements(xlim, 'xlim')
else: else:
range_result = bins[-1] - bins[0] range_result = split_bins[-1] - split_bins[0]
xlim = (bins[0] - range_result * 0.2, bins[-1] + range_result * 0.2) xlim = (split_bins[0] - range_result * 0.2, split_bins[-1] + range_result * 0.2)
ax.set_xlim(xlim) ax.set_xlim(xlim)
ax.yaxis.set_major_locator(MaxNLocator(integer=True)) ax.yaxis.set_major_locator(MaxNLocator(integer=True))
...@@ -358,13 +358,13 @@ def plot_metric( ...@@ -358,13 +358,13 @@ def plot_metric(
_, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi) _, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)
if dataset_names is None: if dataset_names is None:
dataset_names = iter(eval_results.keys()) dataset_names_iter = iter(eval_results.keys())
elif not isinstance(dataset_names, (list, tuple, set)) or not dataset_names: elif not isinstance(dataset_names, (list, tuple, set)) or not dataset_names:
raise ValueError('dataset_names should be iterable and cannot be empty') raise ValueError('dataset_names should be iterable and cannot be empty')
else: else:
dataset_names = iter(dataset_names) dataset_names_iter = iter(dataset_names)
name = next(dataset_names) # take one as sample name = next(dataset_names_iter) # take one as sample
metrics_for_one = eval_results[name] metrics_for_one = eval_results[name]
num_metric = len(metrics_for_one) num_metric = len(metrics_for_one)
if metric is None: if metric is None:
...@@ -381,7 +381,7 @@ def plot_metric( ...@@ -381,7 +381,7 @@ def plot_metric(
x_ = range(num_iteration) x_ = range(num_iteration)
ax.plot(x_, results, label=name) ax.plot(x_, results, label=name)
for name in dataset_names: for name in dataset_names_iter:
metrics_for_one = eval_results[name] metrics_for_one = eval_results[name]
results = metrics_for_one[metric] results = metrics_for_one[metric]
max_result = max(max(results), max_result) max_result = max(max(results), max_result)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment