# coding: utf-8
"""Demonstrate LightGBM's plotting helpers on the regression example data.

Trains a small regression model on ``regression/regression.train`` (also
evaluated on ``regression/regression.test``), then plots:

* the L1 metric recorded during training,
* the 10 most important features,
* a split-value histogram for feature ``f26``,
* the 54th tree, once with matplotlib and once rendered with graphviz.
"""
from pathlib import Path

import lightgbm as lgb
import pandas as pd

if lgb.compat.MATPLOTLIB_INSTALLED:
    import matplotlib.pyplot as plt
else:
    raise ImportError('You need to install matplotlib for plot_example.py.')

print('Loading data...')
# load or create your dataset
# Resolve the data directory relative to this file instead of the current
# working directory, so the example works no matter where it is launched from
# (same files as '../regression/...' when run from this script's directory).
regression_example_dir = Path(__file__).absolute().parents[1] / 'regression'
df_train = pd.read_csv(regression_example_dir / 'regression.train', header=None, sep='\t')
df_test = pd.read_csv(regression_example_dir / 'regression.test', header=None, sep='\t')

# Column 0 is the target; every remaining column is a feature.
y_train = df_train[0]
y_test = df_test[0]
X_train = df_train.drop(0, axis=1)
X_test = df_test.drop(0, axis=1)

# create dataset for lightgbm
lgb_train = lgb.Dataset(X_train, y_train)
# reference=lgb_train makes the test set reuse the training set's binning.
lgb_test = lgb.Dataset(X_test, y_test, reference=lgb_train)

# specify your configurations as a dict
params = {
    'num_leaves': 5,
    'metric': ('l1', 'l2'),
    'verbose': 0
}

evals_result = {}  # to record eval results for plotting

print('Starting training...')
# train
# Metric recording and periodic logging are done through callbacks: the
# evals_result= / verbose_eval= keyword arguments were removed from
# lgb.train in LightGBM 4.0.
gbm = lgb.train(params,
                lgb_train,
                num_boost_round=100,
                valid_sets=[lgb_train, lgb_test],
                # name the features f1, f2, ... so they can be referenced below
                feature_name=['f' + str(i + 1) for i in range(X_train.shape[-1])],
                # index 21, i.e. the feature named 'f22', is treated as categorical
                categorical_feature=[21],
                callbacks=[lgb.log_evaluation(10),
                           lgb.record_evaluation(evals_result)])

print('Plotting metrics recorded during training...')
ax = lgb.plot_metric(evals_result, metric='l1')
plt.show()

print('Plotting feature importances...')
ax = lgb.plot_importance(gbm, max_num_features=10)
plt.show()

print('Plotting split value histogram...')
ax = lgb.plot_split_value_histogram(gbm, feature='f26', bins='auto')
plt.show()

print('Plotting 54th tree...')  # this tree uses a categorical feature in a split
ax = lgb.plot_tree(gbm, tree_index=53, figsize=(15, 15), show_info=['split_gain'])
plt.show()

print('Plotting 54th tree with graphviz...')
graph = lgb.create_tree_digraph(gbm, tree_index=53, name='Tree54')
graph.render(view=True)