Commit e9275fb9 authored by Jayvee He's avatar Jayvee He Committed by Guolin Ke
Browse files

Add Python examples on how to load models (#371)

* Update Python-API.md

* for a better jump in page

A space is needed between `#` and the headers content according to Github's markdown format [guideline](https://guides.github.com/features/mastering-markdown/)

After adding the spaces, we can jump to the exact position in page by click the link.

* fixed something mentioned by @wxchan

* Update Python-API.md

* Add examples on how to use the saved model

* move 'load model' examples to advanced_example
parent 21861cd4
......@@ -30,6 +30,8 @@ Examples including:
- [advanced_example.py](https://github.com/Microsoft/LightGBM/blob/master/examples/python-guide/advanced_example.py)
- Set feature names
- Directly use categorical features without one-hot encoding
- Load model to predict
- Dump and load model with pickle
- Load model file to continue training
- Change learning rates during training
- Self-defined objective function
......
......@@ -3,6 +3,12 @@
import lightgbm as lgb
import pandas as pd
import numpy as np
from sklearn.metrics import mean_squared_error
try:
import cPickle as pickle
except:
import pickle
# load or create your dataset
print('Load data...')
......@@ -57,6 +63,25 @@ print('7th feature name is:', repr(lgb_train.feature_name[6]))
# save model to file
gbm.save_model('model.txt')
# load model to predict
print('Load model to predict')
bst = lgb.Booster(model_file='model.txt')
# can only predict with the best iteration (or the saving iteration)
y_pred = bst.predict(X_test)
# eval with loaded model
print('The rmse of loaded model\'s prediction is:', mean_squared_error(y_test, y_pred) ** 0.5)
# dump model with pickle
with open('model.pkl', 'wb') as fout:
pickle.dump(gbm, fout)
# load model with pickle to predict
with open('model.pkl', 'rb') as fin:
pkl_bst = pickle.load(fin)
# can predict with any iteration when loaded in pickle way
y_pred = pkl_bst.predict(X_test, num_iteration=7)
# eval with loaded model
print('The rmse of pickled model\'s prediction is:', mean_squared_error(y_test, y_pred) ** 0.5)
# continue training
# init_model accepts:
# 1. model file name
......
......@@ -5,6 +5,7 @@ import lightgbm as lgb
import pandas as pd
from sklearn.metrics import mean_squared_error
# load or create your dataset
print('Load data...')
df_train = pd.read_csv('../regression/regression.train', header=None, sep='\t')
......@@ -58,6 +59,7 @@ model_json = gbm.dump_model()
with open('model.json', 'w+') as f:
json.dump(model_json, f, indent=4)
print('Feature names:', gbm.feature_name())
print('Calculate feature importances...')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment