"src/vscode:/vscode.git/clone" did not exist on "d4c4d9ae828d827dfc7ad569f2da0efa6f871d9e"
Commit aeeef276 authored by wxchan's avatar wxchan Committed by Guolin Ke
Browse files

[WIP] expose 'load/save model from/to string' to c_api (#241)

* expose save model to string to c_api

* add 'save-model-to-string' to python; use it to copy model

* remove boosting_type
parent 7dca0bb2
......@@ -259,6 +259,17 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterCreateFromModelfile(
int* out_num_iterations,
BoosterHandle* out);
/*!
* \brief load an existing boosting from string
* \param model_str model string
* \param out_num_iterations number of iterations of this booster
* \param out handle of created Booster
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterLoadModelFromString(
const char* model_str,
int* out_num_iterations,
BoosterHandle* out);
/*!
* \brief free obj in handle
......@@ -558,6 +569,21 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle,
int num_iteration,
const char* filename);
/*!
* \brief save model to string
* \param handle handle
* \param num_iteration, <= 0 means save all
* \param buffer_len string buffer length, if buffer_len < out_len, re-allocate buffer
* \param out_len actual output length
* \param out_str string of model, need to pre-allocate memory before call this
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle,
int num_iteration,
int buffer_len,
int* out_len,
char* out_str);
/*!
* \brief dump model to json
* \param handle handle
......
......@@ -1213,6 +1213,8 @@ class Booster(object):
self.pandas_categorical = eval(last_line[len('pandas_categorical:'):])
else:
self.pandas_categorical = None
elif 'model_str' in params:
self.__load_model_from_string(params['model_str'])
else:
raise TypeError('Need at least one training dataset or model file to create booster instance')
......@@ -1224,9 +1226,10 @@ class Booster(object):
return self.__deepcopy__(None)
def __deepcopy__(self, _):
with _temp_file() as f:
self.save_model(f.name)
return Booster(model_file=f.name)
model_str = self.__save_model_to_string()
booster = Booster({'model_str': model_str})
booster.pandas_categorical = self.pandas_categorical
return booster
def __getstate__(self):
this = self.__dict__.copy()
......@@ -1234,20 +1237,16 @@ class Booster(object):
this.pop('train_set', None)
this.pop('valid_sets', None)
if handle is not None:
with _temp_file() as f:
self.save_model(f.name)
this["handle"] = f.readlines()
this["handle"] = self.__save_model_to_string()
return this
def __setstate__(self, state):
model = state['handle']
if model is not None:
model_str = state.get('handle', None)
if model_str is not None:
handle = ctypes.c_void_p()
out_num_iterations = ctypes.c_int(0)
with _temp_file() as f:
f.writelines(model)
_safe_call(_LIB.LGBM_BoosterCreateFromModelfile(
c_str(f.name),
_safe_call(_LIB.LGBM_BoosterLoadModelFromString(
c_str(model_str),
ctypes.byref(out_num_iterations),
ctypes.byref(handle)))
state['handle'] = handle
......@@ -1472,6 +1471,46 @@ class Booster(object):
with open(filename, 'a') as f:
f.write('\npandas_categorical:' + repr(self.pandas_categorical))
def __load_model_from_string(self, model_str):
"""[Private] Load model from string"""
out_num_iterations = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterLoadModelFromString(
c_str(model_str),
ctypes.byref(out_num_iterations),
ctypes.byref(self.handle)))
out_num_class = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterGetNumClasses(
self.handle,
ctypes.byref(out_num_class)))
self.__num_class = out_num_class.value
def __save_model_to_string(self, num_iteration=-1):
"""[Private] Save model to string"""
if num_iteration <= 0:
num_iteration = self.best_iteration
buffer_len = 1 << 20
tmp_out_len = ctypes.c_int(0)
string_buffer = ctypes.create_string_buffer(buffer_len)
ptr_string_buffer = ctypes.c_char_p(*[ctypes.addressof(string_buffer)])
_safe_call(_LIB.LGBM_BoosterSaveModelToString(
self.handle,
ctypes.c_int(num_iteration),
ctypes.c_int(buffer_len),
ctypes.byref(tmp_out_len),
ptr_string_buffer))
actual_len = tmp_out_len.value
'''if buffer length is not long enough, re-allocate a buffer'''
if actual_len > buffer_len:
string_buffer = ctypes.create_string_buffer(actual_len)
ptr_string_buffer = ctypes.c_char_p(*[ctypes.addressof(string_buffer)])
_safe_call(_LIB.LGBM_BoosterSaveModelToString(
self.handle,
ctypes.c_int(num_iteration),
ctypes.c_int(actual_len),
ctypes.byref(tmp_out_len),
ptr_string_buffer))
return string_buffer.value.decode()
def dump_model(self, num_iteration=-1):
"""
Dump model to json format
......
......@@ -31,6 +31,10 @@ public:
boosting_.reset(Boosting::CreateBoosting(filename));
}
Booster() {
boosting_.reset(Boosting::CreateBoosting("gbdt", nullptr));
}
Booster(const Dataset* train_data,
const char* parameters) {
auto param = ConfigBase::Str2Map(parameters);
......@@ -181,6 +185,14 @@ public:
boosting_->SaveModelToFile(num_iteration, filename);
}
void LoadModelFromString(const char* model_str) {
boosting_->LoadModelFromString(model_str);
}
std::string SaveModelToString(int num_iteration) {
return boosting_->SaveModelToString(num_iteration);
}
std::string DumpModel(int num_iteration) {
return boosting_->DumpModel(num_iteration);
}
......@@ -605,6 +617,18 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterCreateFromModelfile(
API_END();
}
LIGHTGBM_C_EXPORT int LGBM_BoosterLoadModelFromString(
const char* model_str,
int* out_num_iterations,
BoosterHandle* out) {
API_BEGIN();
auto ret = std::unique_ptr<Booster>(new Booster());
ret->LoadModelFromString(model_str);
*out_num_iterations = ret->GetBoosting()->GetCurrentIteration();
*out = ret.release();
API_END();
}
LIGHTGBM_C_EXPORT int LGBM_BoosterFree(BoosterHandle handle) {
API_BEGIN();
delete reinterpret_cast<Booster*>(handle);
......@@ -889,6 +913,21 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle,
API_END();
}
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle,
int num_iteration,
int buffer_len,
int* out_len,
char* out_str) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
std::string model = ref_booster->SaveModelToString(num_iteration);
*out_len = static_cast<int>(model.size()) + 1;
if (*out_len <= buffer_len) {
std::strcpy(out_str, model.c_str());
}
API_END();
}
LIGHTGBM_C_EXPORT int LGBM_BoosterDumpModel(BoosterHandle handle,
int num_iteration,
int buffer_len,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment