Commit 28972b86 authored by Guolin Ke's avatar Guolin Ke
Browse files

[python-package] fix temporary-file access problem on Windows

parent 7f778877
......@@ -136,7 +136,7 @@ public:
* \brief Dump model to json format string
* \return Json format string of model
*/
virtual std::string DumpModel() const = 0;
virtual std::string DumpModel(int num_iteration) const = 0;
/*!
* \brief Save model to file
......
......@@ -557,12 +557,14 @@ DllExport int LGBM_BoosterSaveModel(BoosterHandle handle,
/*!
* \brief dump model to json
* \param handle handle
* \param num_iteration <= 0 means save all
* \param buffer_len string buffer length, if buffer_len < out_len, re-allocate buffer
* \param out_len actual output length
* \param out_str json format string of model, need to pre-allocate memory before call this
* \return 0 when succeed, -1 when failure happens
*/
DllExport int LGBM_BoosterDumpModel(BoosterHandle handle,
int num_iteration,
int buffer_len,
int64_t* out_len,
char* out_str);
......
......@@ -9,6 +9,7 @@ import sys
import ctypes
import json
from tempfile import NamedTemporaryFile
import os
import numpy as np
import scipy.sparse
......@@ -131,6 +132,22 @@ def param_dict_to_str(data):
% (key, type(val).__name__))
return ' '.join(pairs)
class _temp_file:
def __enter__(self):
with NamedTemporaryFile(prefix="lightgbm_tmp_", delete=True) as f:
self.name = f.name
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if os.path.isfile(self.name):
os.remove(self.name)
def readlines(self):
with open(self.name, "r+") as f:
ret = f.readlines()
return ret
def writelines(self, lines):
with open(self.name, "w+") as f:
ret = f.writelines(lines)
"""marco definition of data type in c_api of LightGBM"""
C_API_DTYPE_FLOAT32 = 0
C_API_DTYPE_FLOAT64 = 1
......@@ -276,7 +293,7 @@ class _InnerPredictor(object):
if num_iteration > self.num_total_iteration:
num_iteration = self.num_total_iteration
if is_str(data):
with NamedTemporaryFile(mode='w+') as f:
with _temp_file() as f:
_safe_call(_LIB.LGBM_BoosterPredictForFile(
self.handle,
c_str(data),
......@@ -1336,7 +1353,7 @@ class Booster(object):
return self.__deepcopy__(None)
def __deepcopy__(self, _):
with NamedTemporaryFile(mode='w+') as f:
with _temp_file() as f:
self.save_model(f.name)
return Booster(model_file=f.name)
......@@ -1346,7 +1363,7 @@ class Booster(object):
this.pop('train_set', None)
this.pop('valid_sets', None)
if handle is not None:
with NamedTemporaryFile(mode='w+') as f:
with _temp_file() as f:
self.save_model(f.name)
this["handle"] = f.readlines()
return this
......@@ -1356,9 +1373,8 @@ class Booster(object):
if model is not None:
handle = ctypes.c_void_p()
out_num_iterations = ctypes.c_int64(0)
with NamedTemporaryFile(mode='w+') as f:
with _temp_file() as f:
f.writelines(model)
f.flush()
_safe_call(_LIB.LGBM_BoosterCreateFromModelfile(
c_str(f.name),
ctypes.byref(out_num_iterations),
......@@ -1570,27 +1586,37 @@ class Booster(object):
filename : str
Filename to save
num_iteration: int
Number of iteration that want to save. < 0 means save all
Number of iterations to save. <= 0 means save the best iteration (if one exists)
"""
if num_iteration <= 0:
num_iteration = self.best_iteration
_safe_call(_LIB.LGBM_BoosterSaveModel(
self.handle,
num_iteration,
c_str(filename)))
def dump_model(self):
def dump_model(self, num_iteration=-1):
"""
Dump model to json format
Parameters
----------
num_iteration: int
    Number of iterations to dump. <= 0 means dump the best iteration (if one exists)
Returns
-------
Json format of model
"""
if num_iteration <= 0:
num_iteration = self.best_iteration
buffer_len = 1 << 20
tmp_out_len = ctypes.c_int64(0)
string_buffer = ctypes.create_string_buffer(buffer_len)
ptr_string_buffer = ctypes.c_char_p(*[ctypes.addressof(string_buffer)])
_safe_call(_LIB.LGBM_BoosterDumpModel(
self.handle,
num_iteration,
buffer_len,
ctypes.byref(tmp_out_len),
ptr_string_buffer))
......@@ -1601,6 +1627,7 @@ class Booster(object):
ptr_string_buffer = ctypes.c_char_p(*[ctypes.addressof(string_buffer)])
_safe_call(_LIB.LGBM_BoosterDumpModel(
self.handle,
num_iteration,
actual_len,
ctypes.byref(tmp_out_len),
ptr_string_buffer))
......@@ -1616,7 +1643,7 @@ class Booster(object):
Data source for prediction
When data type is string, it represents the path of txt file
num_iteration : int
Used iteration for prediction
Number of iterations used for prediction. <= 0 means predict with the best iteration (if one exists)
raw_score : bool
True for predict raw score
pred_leaf : bool
......@@ -1631,6 +1658,8 @@ class Booster(object):
Prediction result
"""
predictor = _InnerPredictor(booster_handle=self.handle)
if num_iteration <= 0:
num_iteration = self.best_iteration
return predictor.predict(data, num_iteration, raw_score, pred_leaf, data_has_header, is_reshape)
def _to_predictor(self):
......
......@@ -190,7 +190,7 @@ def train(params, train_set, num_boost_round=100,
if booster.attr('best_iteration') is not None:
booster.best_iteration = int(booster.attr('best_iteration')) + 1
else:
booster.best_iteration = num_boost_round
booster.best_iteration = -1
return booster
......
......@@ -245,7 +245,7 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
}
void GBDT::RollbackOneIter() {
if (iter_ == 0) { return; }
if (iter_ <= 0) { return; }
int cur_iter = iter_ + num_init_iteration_ - 1;
// reset score
for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
......@@ -428,7 +428,7 @@ void GBDT::Boosting() {
GetGradients(GetTrainingScore(&num_score), gradients_.data(), hessians_.data());
}
std::string GBDT::DumpModel() const {
std::string GBDT::DumpModel(int num_iteration) const {
std::stringstream str_buf;
str_buf << "{";
......@@ -449,7 +449,11 @@ std::string GBDT::DumpModel() const {
<< std::endl;
str_buf << "\"tree_info\":[";
for (int i = 0; i < static_cast<int>(models_.size()); ++i) {
int num_used_model = static_cast<int>(models_.size());
if (num_iteration > 0) {
num_used_model = std::min(num_iteration * num_class_, num_used_model);
}
for (int i = 0; i < num_used_model; ++i) {
if (i > 0) {
str_buf << ",";
}
......@@ -491,13 +495,10 @@ void GBDT::SaveModelToFile(int num_iteration, const char* filename) const {
output_file << "feature_names=" << Common::Join(feature_names.get(), " ") << std::endl;
output_file << std::endl;
int num_used_model = 0;
if (num_iteration <= 0) {
num_used_model = static_cast<int>(models_.size());
} else {
num_used_model = num_iteration * num_class_;
int num_used_model = static_cast<int>(models_.size());
if (num_iteration > 0) {
num_used_model = std::min(num_iteration * num_class_, num_used_model);
}
num_used_model = std::min(num_used_model, static_cast<int>(models_.size()));
// output tree models
for (int i = 0; i < num_used_model; ++i) {
output_file << "Tree=" << i << std::endl;
......
......@@ -148,7 +148,7 @@ public:
* \brief Dump model to json format string
* \return Json format string of model
*/
std::string DumpModel() const override;
std::string DumpModel(int num_iteration) const override;
/*!
* \brief Save model to file
......@@ -175,7 +175,6 @@ public:
*/
inline int LabelIdx() const override { return label_idx_; }
/*!
* \brief Get number of weak sub-models
* \return Number of weak sub-models
......@@ -192,13 +191,10 @@ public:
* \brief Set number of iterations for prediction
*/
inline void SetNumIterationForPred(int num_iteration) override {
num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_class_;
if (num_iteration > 0) {
num_iteration_for_pred_ = num_iteration;
} else {
num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_class_;
num_iteration_for_pred_ = std::min(num_iteration, num_iteration_for_pred_);
}
num_iteration_for_pred_ = std::min(num_iteration_for_pred_,
static_cast<int>(models_.size()) / num_class_);
}
inline double GetLeafValue(int tree_idx, int leaf_idx) const {
......
......@@ -181,8 +181,8 @@ public:
boosting_->SaveModelToFile(num_iteration, filename);
}
std::string DumpModel() {
return boosting_->DumpModel();
std::string DumpModel(int num_iteration) {
return boosting_->DumpModel(num_iteration);
}
double GetLeafValue(int tree_idx, int leaf_idx) const {
......@@ -581,8 +581,7 @@ DllExport int LGBM_BoosterCreateFromModelfile(
BoosterHandle* out) {
API_BEGIN();
auto ret = std::unique_ptr<Booster>(new Booster(filename));
*out_num_iterations = static_cast<int64_t>(ret->GetBoosting()->NumberOfTotalModel()
/ ret->GetBoosting()->NumberOfClasses());
*out_num_iterations = static_cast<int64_t>(ret->GetBoosting()->GetCurrentIteration());
*out = ret.release();
API_END();
}
......@@ -872,12 +871,13 @@ DllExport int LGBM_BoosterSaveModel(BoosterHandle handle,
}
DllExport int LGBM_BoosterDumpModel(BoosterHandle handle,
int num_iteration,
int buffer_len,
int64_t* out_len,
char* out_str) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
std::string model = ref_booster->DumpModel();
std::string model = ref_booster->DumpModel(num_iteration);
*out_len = static_cast<int64_t>(model.size()) + 1;
if (*out_len <= buffer_len) {
std::strcpy(out_str, model.c_str());
......
# coding: utf-8
# pylint: skip-file
import unittest, tempfile
import unittest, tempfile, os
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
......@@ -31,9 +31,11 @@ class TestBasic(unittest.TestCase):
bst.save_model("model.txt")
pred_from_matr = bst.predict(X_test)
with tempfile.NamedTemporaryFile() as f:
tname = f.name
with open(tname, "w+b") as f:
np.savetxt(f, X_test, delimiter=',')
f.flush()
pred_from_file = bst.predict(f.name)
pred_from_file = bst.predict(tname)
os.remove(tname)
self.assertEqual(len(pred_from_matr), len(pred_from_file))
for preds in zip(pred_from_matr, pred_from_file):
self.assertAlmostEqual(*preds, places=5)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment