Unverified commit 619c06d8, authored by Guolin Ke and committed by GitHub
Browse files

fix save_model_to_string for large model (#1080)

* use int64 for string

* [R] Fatal when CSC exceeds int32.max
parent 53739670
......@@ -193,7 +193,9 @@ Dataset <- R6Class(
ref_handle)
} else if (is(private$raw_data, "dgCMatrix")) {
if (length(private$raw_data@p) > 2147483647) {
stop("Cannot support large CSC matrix")
}
# Are we using a dgCMatrix (sparsed matrix column compressed)
handle <- lgb.call("LGBM_DatasetCreateFromCSC_R",
ret = handle,
......
......@@ -127,7 +127,9 @@ Predictor <- R6Class(
private$params)
} else if (is(data, "dgCMatrix")) {
if (length(data@p) > 2147483647) {
stop("Cannot support large CSC matrix")
}
# Check if data is a dgCMatrix (sparse matrix, column compressed format)
preds <- lgb.call("LGBM_BoosterPredictForCSC_R",
ret = preds,
......
......@@ -677,8 +677,8 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle,
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle,
int num_iteration,
int buffer_len,
int* out_len,
int64_t buffer_len,
int64_t* out_len,
char* out_str);
/*!
......@@ -692,8 +692,8 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle,
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterDumpModel(BoosterHandle handle,
int num_iteration,
int buffer_len,
int* out_len,
int64_t buffer_len,
int64_t* out_len,
char* out_str);
/*!
......
......@@ -1668,13 +1668,13 @@ class Booster(object):
if num_iteration <= 0:
num_iteration = self.best_iteration
buffer_len = 1 << 20
tmp_out_len = ctypes.c_int(0)
tmp_out_len = ctypes.c_int64(0)
string_buffer = ctypes.create_string_buffer(buffer_len)
ptr_string_buffer = ctypes.c_char_p(*[ctypes.addressof(string_buffer)])
_safe_call(_LIB.LGBM_BoosterSaveModelToString(
self.handle,
ctypes.c_int(num_iteration),
ctypes.c_int(buffer_len),
ctypes.c_int64(buffer_len),
ctypes.byref(tmp_out_len),
ptr_string_buffer))
actual_len = tmp_out_len.value
......@@ -1685,7 +1685,7 @@ class Booster(object):
_safe_call(_LIB.LGBM_BoosterSaveModelToString(
self.handle,
ctypes.c_int(num_iteration),
ctypes.c_int(actual_len),
ctypes.c_int64(actual_len),
ctypes.byref(tmp_out_len),
ptr_string_buffer))
return string_buffer.value.decode()
......@@ -1707,13 +1707,13 @@ class Booster(object):
if num_iteration <= 0:
num_iteration = self.best_iteration
buffer_len = 1 << 20
tmp_out_len = ctypes.c_int(0)
tmp_out_len = ctypes.c_int64(0)
string_buffer = ctypes.create_string_buffer(buffer_len)
ptr_string_buffer = ctypes.c_char_p(*[ctypes.addressof(string_buffer)])
_safe_call(_LIB.LGBM_BoosterDumpModel(
self.handle,
ctypes.c_int(num_iteration),
ctypes.c_int(buffer_len),
ctypes.c_int64(buffer_len),
ctypes.byref(tmp_out_len),
ptr_string_buffer))
actual_len = tmp_out_len.value
......@@ -1724,7 +1724,7 @@ class Booster(object):
_safe_call(_LIB.LGBM_BoosterDumpModel(
self.handle,
ctypes.c_int(num_iteration),
ctypes.c_int(actual_len),
ctypes.c_int64(actual_len),
ctypes.byref(tmp_out_len),
ptr_string_buffer))
return json.loads(string_buffer.value.decode())
......
......@@ -1140,13 +1140,13 @@ int LGBM_BoosterSaveModel(BoosterHandle handle,
#pragma warning(disable : 4996)
int LGBM_BoosterSaveModelToString(BoosterHandle handle,
int num_iteration,
int buffer_len,
int* out_len,
int64_t buffer_len,
int64_t* out_len,
char* out_str) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
std::string model = ref_booster->SaveModelToString(num_iteration);
*out_len = static_cast<int>(model.size()) + 1;
*out_len = static_cast<int64_t>(model.size()) + 1;
if (*out_len <= buffer_len) {
std::strcpy(out_str, model.c_str());
}
......@@ -1156,13 +1156,13 @@ int LGBM_BoosterSaveModelToString(BoosterHandle handle,
#pragma warning(disable : 4996)
int LGBM_BoosterDumpModel(BoosterHandle handle,
int num_iteration,
int buffer_len,
int* out_len,
int64_t buffer_len,
int64_t* out_len,
char* out_str) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
std::string model = ref_booster->DumpModel(num_iteration);
*out_len = static_cast<int>(model.size()) + 1;
*out_len = static_cast<int64_t>(model.size()) + 1;
if (*out_len <= buffer_len) {
std::strcpy(out_str, model.c_str());
}
......
......@@ -601,14 +601,17 @@ LGBM_SE LGBM_BoosterSaveModelToString_R(LGBM_SE handle,
LGBM_SE out_str,
LGBM_SE call_state) {
R_API_BEGIN();
int out_len = 0;
int64_t out_len = 0;
std::vector<char> inner_char_buf(R_AS_INT(buffer_len));
CHECK_CALL(LGBM_BoosterSaveModelToString(R_GET_PTR(handle), R_AS_INT(num_iteration), R_AS_INT(buffer_len), &out_len, inner_char_buf.data()));
EncodeChar(out_str, inner_char_buf.data(), buffer_len, actual_len);
if (out_len < R_AS_INT(buffer_len)) {
EncodeChar(out_str, inner_char_buf.data(), buffer_len, actual_len);
} else {
R_INT_PTR(actual_len)[0] = static_cast<int>(out_len);
if (out_len <= INT32_MAX) {
R_INT_PTR(actual_len)[0] = static_cast<int>(out_len);
} else {
Log::Fatal("Don't support large model in R package.");
}
}
R_API_END();
}
......@@ -620,14 +623,17 @@ LGBM_SE LGBM_BoosterDumpModel_R(LGBM_SE handle,
LGBM_SE out_str,
LGBM_SE call_state) {
R_API_BEGIN();
int out_len = 0;
int64_t out_len = 0;
std::vector<char> inner_char_buf(R_AS_INT(buffer_len));
CHECK_CALL(LGBM_BoosterDumpModel(R_GET_PTR(handle), R_AS_INT(num_iteration), R_AS_INT(buffer_len), &out_len, inner_char_buf.data()));
EncodeChar(out_str, inner_char_buf.data(), buffer_len, actual_len);
if (out_len < R_AS_INT(buffer_len)) {
EncodeChar(out_str, inner_char_buf.data(), buffer_len, actual_len);
} else {
R_INT_PTR(actual_len)[0] = static_cast<int>(out_len);
if (out_len <= INT32_MAX) {
R_INT_PTR(actual_len)[0] = static_cast<int>(out_len);
} else {
Log::Fatal("Don't support large model in R package.");
}
}
R_API_END();
}
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment