# coding: utf-8
"""Demonstrate LightGBM's plotting helpers on the regression example data.

Trains a small regression model on the data in the sibling ``regression``
example directory. It then plots the recorded evaluation metric, the feature
importances, a split value histogram and a single tree (first with matplotlib,
then with graphviz).
"""
from pathlib import Path

import pandas as pd

import lightgbm as lgb

# Every plot below needs matplotlib. Fail early with an actionable message
# instead of failing deep inside one of the lgb.plot_* helpers.
if lgb.compat.MATPLOTLIB_INSTALLED:
    import matplotlib.pyplot as plt
else:
    raise ImportError('You need to install matplotlib and restart your session for plot_example.py.')
print('Loading data...')
# load or create your dataset
# Both files are tab-separated with no header row. Column 0 holds the target
# and the remaining columns are the features.
regression_example_dir = Path(__file__).absolute().parents[1] / 'regression'
df_train = pd.read_csv(str(regression_example_dir / 'regression.train'), header=None, sep='\t')
df_test = pd.read_csv(str(regression_example_dir / 'regression.test'), header=None, sep='\t')

y_train = df_train[0]
y_test = df_test[0]
X_train = df_train.drop(0, axis=1)
X_test = df_test.drop(0, axis=1)

# create dataset for lightgbm
lgb_train = lgb.Dataset(X_train, y_train)
# The LightGBM docs say a validation Dataset should use the training Dataset
# as `reference`, so both share the same feature binning.
lgb_test = lgb.Dataset(X_test, y_test, reference=lgb_train)
# specify your configurations as a dict
params = {
    # At most 5 leaves per tree, which keeps single-tree plots small.
    'num_leaves': 5,
    # Evaluate both L1 and L2 on every validation set during training.
    'metric': ('l1', 'l2'),
    # LightGBM `verbosity` alias; 0 limits its own log output to
    # errors/warnings (see the LightGBM parameter docs).
    'verbose': 0
}
evals_result = {}  # to record eval results for plotting

print('Starting training...')
# train
# Features are named f1..fN (one name per feature column), and later plots
# refer to features by these names. `categorical_feature` takes 0-based column
# indices, so [21] marks the feature named 'f22' as categorical.
# NOTE(review): newer LightGBM releases may prefer feature_name /
# categorical_feature on lgb.Dataset rather than lgb.train. Confirm against the
# installed version.
gbm = lgb.train(
    params,
    lgb_train,
    num_boost_round=100,
    valid_sets=[lgb_train, lgb_test],
    feature_name=[f'f{i + 1}' for i in range(X_train.shape[-1])],
    categorical_feature=[21],
    callbacks=[
        lgb.log_evaluation(10),  # print the metrics every 10 iterations
        lgb.record_evaluation(evals_result),  # fill evals_result for plotting
    ],
)
print('Plotting metrics recorded during training...')
ax = lgb.plot_metric(evals_result, metric='l1')
plt.show()

print('Plotting feature importances...')
ax = lgb.plot_importance(gbm, max_num_features=10)
plt.show()

print('Plotting split value histogram...')
# LightGBM cannot build a split value histogram for a categorical feature,
# so use a feature that was not declared categorical during training.
ax = lgb.plot_split_value_histogram(gbm, feature='f26', bins='auto')
plt.show()

print('Plotting 54th tree...')  # tree index 53 is one that splits on a categorical feature
ax = lgb.plot_tree(gbm, tree_index=53, figsize=(15, 15), show_info=['split_gain'])
plt.show()

print('Plotting 54th tree with graphviz...')
# Rendering requires the `graphviz` Python package and the Graphviz
# executables. render(view=True) writes the DOT source and the rendered file
# (both named after 'Tree54') to the current working directory, then opens
# the result in the system's default viewer.
graph = lgb.create_tree_digraph(gbm, tree_index=53, name='Tree54')
graph.render(view=True)