# plot_example.py
# coding: utf-8
# pylint: disable = invalid-name, C0111
import lightgbm as lgb
import pandas as pd

try:
    import matplotlib.pyplot as plt
except ImportError:
    raise ImportError('You need to install matplotlib for plot_example.py.')

# load or create your dataset
print('Load data...')
df_train = pd.read_csv('../regression/regression.train', header=None, sep='\t')
df_test = pd.read_csv('../regression/regression.test', header=None, sep='\t')

# Column 0 is used as the label; every remaining column is a feature.
y_train = df_train[0].values
y_test = df_test[0].values
X_train = df_train.drop(0, axis=1).values
X_test = df_test.drop(0, axis=1).values

# create dataset for lightgbm
# The validation Dataset is built with `reference=lgb_train`, as LightGBM
# requires for validation data, so it shares the training Dataset's setup.
lgb_train = lgb.Dataset(X_train, y_train)
lgb_test = lgb.Dataset(X_test, y_test, reference=lgb_train)

# specify your configurations as a dict
params = {
    'num_leaves': 5,
    # both metrics are recorded; only 'l1' is plotted below
    'metric': ('l1', 'l2'),
    'verbose': 0
}

evals_result = {}  # to record eval results for plotting
print('Start training...')
# train
# Evaluation history and periodic logging are wired up through callbacks:
# the `evals_result=` and `verbose_eval=` keyword arguments of `lgb.train`
# were removed in LightGBM 4.0 (`lgb.log_evaluation` needs LightGBM >= 3.3).
gbm = lgb.train(params,
                lgb_train,
                num_boost_round=100,
                valid_sets=[lgb_train, lgb_test],
                # name features f1..fN, N taken from the training matrix width
                feature_name=['f' + str(i + 1) for i in range(X_train.shape[1])],
                categorical_feature=[21],
                callbacks=[lgb.record_evaluation(evals_result),
                           lgb.log_evaluation(10)])

print('Plot metrics during training...')
ax = lgb.plot_metric(evals_result, metric='l1')
plt.show()

print('Plot feature importances...')
ax = lgb.plot_importance(gbm, max_num_features=10)
plt.show()

# tree_index is zero-based, so 83 selects the 84th tree; it is one of the
# trees that uses the categorical feature in a split.
print('Plot 84th tree...')
ax = lgb.plot_tree(gbm, tree_index=83, figsize=(20, 8), show_info=['split_gain'])
plt.show()

# Build the same (zero-based index 83) tree as a graphviz digraph, then
# render it and ask graphviz to open the result in a viewer.
print('Plot 84th tree with graphviz...')
graph = lgb.create_tree_digraph(
    gbm,
    tree_index=83,
    name='Tree84',
)
graph.render(view=True)