Unverified Commit aacb4c8f authored by Fabio Sigrist's avatar Fabio Sigrist Committed by GitHub
Browse files

[R-package] fix protection stack imbalance and unprotected objects (fixes #4390) (#4391)



* [R-package] fix protection stack imbalance and unprotected objects issues

* [R-package] fix minor linting issues

* [ci][R-package] change timeout-minutes in valgrind test

* [R-package] remove extra space
Co-authored-by: default avatarJames Lamb <jaylamb20@gmail.com>

* [R-package] remove counter for number of protected objects

* Update .github/workflows/r_valgrind.yml
Co-authored-by: default avatarJames Lamb <jaylamb20@gmail.com>
parent f62c4904
...@@ -28,15 +28,13 @@ ...@@ -28,15 +28,13 @@
#define R_API_BEGIN() \ #define R_API_BEGIN() \
try { try {
#define R_API_END() } \ #define R_API_END() } \
catch(std::exception& ex) { LGBM_SetLastError(ex.what()); return R_NilValue;} \ catch(std::exception& ex) { LGBM_SetLastError(ex.what()); } \
catch(std::string& ex) { LGBM_SetLastError(ex.c_str()); return R_NilValue; } \ catch(std::string& ex) { LGBM_SetLastError(ex.c_str()); } \
catch(...) { LGBM_SetLastError("unknown exception"); return R_NilValue;} \ catch(...) { LGBM_SetLastError("unknown exception"); }
return R_NilValue;
#define CHECK_CALL(x) \ #define CHECK_CALL(x) \
if ((x) != 0) { \ if ((x) != 0) { \
Rf_error(LGBM_GetLastError()); \ Rf_error(LGBM_GetLastError()); \
return R_NilValue; \
} }
using LightGBM::Common::Split; using LightGBM::Common::Split;
...@@ -54,19 +52,20 @@ SEXP LGBM_DatasetCreateFromFile_R(SEXP filename, ...@@ -54,19 +52,20 @@ SEXP LGBM_DatasetCreateFromFile_R(SEXP filename,
SEXP parameters, SEXP parameters,
SEXP reference) { SEXP reference) {
SEXP ret; SEXP ret;
R_API_BEGIN();
DatasetHandle handle = nullptr; DatasetHandle handle = nullptr;
DatasetHandle ref = nullptr; DatasetHandle ref = nullptr;
if (!Rf_isNull(reference)) { if (!Rf_isNull(reference)) {
ref = R_ExternalPtrAddr(reference); ref = R_ExternalPtrAddr(reference);
} }
CHECK_CALL(LGBM_DatasetCreateFromFile(CHAR(Rf_asChar(filename)), CHAR(Rf_asChar(parameters)), const char* filename_ptr = CHAR(PROTECT(Rf_asChar(filename)));
ref, &handle)); const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters)));
R_API_BEGIN();
CHECK_CALL(LGBM_DatasetCreateFromFile(filename_ptr, parameters_ptr, ref, &handle));
R_API_END();
ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue)); ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _DatasetFinalizer, TRUE); R_RegisterCFinalizerEx(ret, _DatasetFinalizer, TRUE);
UNPROTECT(1); UNPROTECT(3);
return ret; return ret;
R_API_END();
} }
SEXP LGBM_DatasetCreateFromCSC_R(SEXP indptr, SEXP LGBM_DatasetCreateFromCSC_R(SEXP indptr,
...@@ -78,27 +77,27 @@ SEXP LGBM_DatasetCreateFromCSC_R(SEXP indptr, ...@@ -78,27 +77,27 @@ SEXP LGBM_DatasetCreateFromCSC_R(SEXP indptr,
SEXP parameters, SEXP parameters,
SEXP reference) { SEXP reference) {
SEXP ret; SEXP ret;
R_API_BEGIN();
const int* p_indptr = INTEGER(indptr); const int* p_indptr = INTEGER(indptr);
const int* p_indices = INTEGER(indices); const int* p_indices = INTEGER(indices);
const double* p_data = REAL(data); const double* p_data = REAL(data);
int64_t nindptr = static_cast<int64_t>(Rf_asInteger(num_indptr)); int64_t nindptr = static_cast<int64_t>(Rf_asInteger(num_indptr));
int64_t ndata = static_cast<int64_t>(Rf_asInteger(nelem)); int64_t ndata = static_cast<int64_t>(Rf_asInteger(nelem));
int64_t nrow = static_cast<int64_t>(Rf_asInteger(num_row)); int64_t nrow = static_cast<int64_t>(Rf_asInteger(num_row));
const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters)));
DatasetHandle handle = nullptr; DatasetHandle handle = nullptr;
DatasetHandle ref = nullptr; DatasetHandle ref = nullptr;
if (!Rf_isNull(reference)) { if (!Rf_isNull(reference)) {
ref = R_ExternalPtrAddr(reference); ref = R_ExternalPtrAddr(reference);
} }
R_API_BEGIN();
CHECK_CALL(LGBM_DatasetCreateFromCSC(p_indptr, C_API_DTYPE_INT32, p_indices, CHECK_CALL(LGBM_DatasetCreateFromCSC(p_indptr, C_API_DTYPE_INT32, p_indices,
p_data, C_API_DTYPE_FLOAT64, nindptr, ndata, p_data, C_API_DTYPE_FLOAT64, nindptr, ndata,
nrow, CHAR(Rf_asChar(parameters)), ref, &handle)); nrow, parameters_ptr, ref, &handle));
R_API_END();
ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue)); ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _DatasetFinalizer, TRUE); R_RegisterCFinalizerEx(ret, _DatasetFinalizer, TRUE);
UNPROTECT(1); UNPROTECT(2);
return ret; return ret;
R_API_END();
} }
SEXP LGBM_DatasetCreateFromMat_R(SEXP data, SEXP LGBM_DatasetCreateFromMat_R(SEXP data,
...@@ -107,22 +106,23 @@ SEXP LGBM_DatasetCreateFromMat_R(SEXP data, ...@@ -107,22 +106,23 @@ SEXP LGBM_DatasetCreateFromMat_R(SEXP data,
SEXP parameters, SEXP parameters,
SEXP reference) { SEXP reference) {
SEXP ret; SEXP ret;
R_API_BEGIN();
int32_t nrow = static_cast<int32_t>(Rf_asInteger(num_row)); int32_t nrow = static_cast<int32_t>(Rf_asInteger(num_row));
int32_t ncol = static_cast<int32_t>(Rf_asInteger(num_col)); int32_t ncol = static_cast<int32_t>(Rf_asInteger(num_col));
double* p_mat = REAL(data); double* p_mat = REAL(data);
const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters)));
DatasetHandle handle = nullptr; DatasetHandle handle = nullptr;
DatasetHandle ref = nullptr; DatasetHandle ref = nullptr;
if (!Rf_isNull(reference)) { if (!Rf_isNull(reference)) {
ref = R_ExternalPtrAddr(reference); ref = R_ExternalPtrAddr(reference);
} }
R_API_BEGIN();
CHECK_CALL(LGBM_DatasetCreateFromMat(p_mat, C_API_DTYPE_FLOAT64, nrow, ncol, COL_MAJOR, CHECK_CALL(LGBM_DatasetCreateFromMat(p_mat, C_API_DTYPE_FLOAT64, nrow, ncol, COL_MAJOR,
CHAR(Rf_asChar(parameters)), ref, &handle)); parameters_ptr, ref, &handle));
R_API_END();
ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue)); ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _DatasetFinalizer, TRUE); R_RegisterCFinalizerEx(ret, _DatasetFinalizer, TRUE);
UNPROTECT(1); UNPROTECT(2);
return ret; return ret;
R_API_END();
} }
SEXP LGBM_DatasetGetSubset_R(SEXP handle, SEXP LGBM_DatasetGetSubset_R(SEXP handle,
...@@ -130,7 +130,6 @@ SEXP LGBM_DatasetGetSubset_R(SEXP handle, ...@@ -130,7 +130,6 @@ SEXP LGBM_DatasetGetSubset_R(SEXP handle,
SEXP len_used_row_indices, SEXP len_used_row_indices,
SEXP parameters) { SEXP parameters) {
SEXP ret; SEXP ret;
R_API_BEGIN();
int32_t len = static_cast<int32_t>(Rf_asInteger(len_used_row_indices)); int32_t len = static_cast<int32_t>(Rf_asInteger(len_used_row_indices));
std::vector<int32_t> idxvec(len); std::vector<int32_t> idxvec(len);
// convert from one-based to zero-based index // convert from one-based to zero-based index
...@@ -138,36 +137,41 @@ SEXP LGBM_DatasetGetSubset_R(SEXP handle, ...@@ -138,36 +137,41 @@ SEXP LGBM_DatasetGetSubset_R(SEXP handle,
for (int32_t i = 0; i < len; ++i) { for (int32_t i = 0; i < len; ++i) {
idxvec[i] = static_cast<int32_t>(INTEGER(used_row_indices)[i] - 1); idxvec[i] = static_cast<int32_t>(INTEGER(used_row_indices)[i] - 1);
} }
const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters)));
DatasetHandle res = nullptr; DatasetHandle res = nullptr;
R_API_BEGIN();
CHECK_CALL(LGBM_DatasetGetSubset(R_ExternalPtrAddr(handle), CHECK_CALL(LGBM_DatasetGetSubset(R_ExternalPtrAddr(handle),
idxvec.data(), len, CHAR(Rf_asChar(parameters)), idxvec.data(), len, parameters_ptr,
&res)); &res));
R_API_END();
ret = PROTECT(R_MakeExternalPtr(res, R_NilValue, R_NilValue)); ret = PROTECT(R_MakeExternalPtr(res, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _DatasetFinalizer, TRUE); R_RegisterCFinalizerEx(ret, _DatasetFinalizer, TRUE);
UNPROTECT(1); UNPROTECT(2);
return ret; return ret;
R_API_END();
} }
SEXP LGBM_DatasetSetFeatureNames_R(SEXP handle, SEXP LGBM_DatasetSetFeatureNames_R(SEXP handle,
SEXP feature_names) { SEXP feature_names) {
R_API_BEGIN(); auto vec_names = Split(CHAR(PROTECT(Rf_asChar(feature_names))), '\t');
auto vec_names = Split(CHAR(Rf_asChar(feature_names)), '\t');
std::vector<const char*> vec_sptr; std::vector<const char*> vec_sptr;
int len = static_cast<int>(vec_names.size()); int len = static_cast<int>(vec_names.size());
for (int i = 0; i < len; ++i) { for (int i = 0; i < len; ++i) {
vec_sptr.push_back(vec_names[i].c_str()); vec_sptr.push_back(vec_names[i].c_str());
} }
R_API_BEGIN();
CHECK_CALL(LGBM_DatasetSetFeatureNames(R_ExternalPtrAddr(handle), CHECK_CALL(LGBM_DatasetSetFeatureNames(R_ExternalPtrAddr(handle),
vec_sptr.data(), len)); vec_sptr.data(), len));
R_API_END(); R_API_END();
UNPROTECT(1);
return R_NilValue;
} }
SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) { SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) {
SEXP feature_names; SEXP feature_names;
R_API_BEGIN();
int len = 0; int len = 0;
R_API_BEGIN();
CHECK_CALL(LGBM_DatasetGetNumFeature(R_ExternalPtrAddr(handle), &len)); CHECK_CALL(LGBM_DatasetGetNumFeature(R_ExternalPtrAddr(handle), &len));
R_API_END();
const size_t reserved_string_size = 256; const size_t reserved_string_size = 256;
std::vector<std::vector<char>> names(len); std::vector<std::vector<char>> names(len);
std::vector<char*> ptr_names(len); std::vector<char*> ptr_names(len);
...@@ -177,12 +181,14 @@ SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) { ...@@ -177,12 +181,14 @@ SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) {
} }
int out_len; int out_len;
size_t required_string_size; size_t required_string_size;
R_API_BEGIN();
CHECK_CALL( CHECK_CALL(
LGBM_DatasetGetFeatureNames( LGBM_DatasetGetFeatureNames(
R_ExternalPtrAddr(handle), R_ExternalPtrAddr(handle),
len, &out_len, len, &out_len,
reserved_string_size, &required_string_size, reserved_string_size, &required_string_size,
ptr_names.data())); ptr_names.data()));
R_API_END();
// if any feature names were larger than allocated size, // if any feature names were larger than allocated size,
// allow for a larger size and try again // allow for a larger size and try again
if (required_string_size > reserved_string_size) { if (required_string_size > reserved_string_size) {
...@@ -190,6 +196,7 @@ SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) { ...@@ -190,6 +196,7 @@ SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) {
names[i].resize(required_string_size); names[i].resize(required_string_size);
ptr_names[i] = names[i].data(); ptr_names[i] = names[i].data();
} }
R_API_BEGIN();
CHECK_CALL( CHECK_CALL(
LGBM_DatasetGetFeatureNames( LGBM_DatasetGetFeatureNames(
R_ExternalPtrAddr(handle), R_ExternalPtrAddr(handle),
...@@ -198,6 +205,7 @@ SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) { ...@@ -198,6 +205,7 @@ SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) {
required_string_size, required_string_size,
&required_string_size, &required_string_size,
ptr_names.data())); ptr_names.data()));
R_API_END();
} }
CHECK_EQ(len, out_len); CHECK_EQ(len, out_len);
feature_names = PROTECT(Rf_allocVector(STRSXP, len)); feature_names = PROTECT(Rf_allocVector(STRSXP, len));
...@@ -206,15 +214,17 @@ SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) { ...@@ -206,15 +214,17 @@ SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) {
} }
UNPROTECT(1); UNPROTECT(1);
return feature_names; return feature_names;
R_API_END();
} }
SEXP LGBM_DatasetSaveBinary_R(SEXP handle, SEXP LGBM_DatasetSaveBinary_R(SEXP handle,
SEXP filename) { SEXP filename) {
const char* filename_ptr = CHAR(PROTECT(Rf_asChar(filename)));
R_API_BEGIN(); R_API_BEGIN();
CHECK_CALL(LGBM_DatasetSaveBinary(R_ExternalPtrAddr(handle), CHECK_CALL(LGBM_DatasetSaveBinary(R_ExternalPtrAddr(handle),
CHAR(Rf_asChar(filename)))); filename_ptr));
R_API_END(); R_API_END();
UNPROTECT(1);
return R_NilValue;
} }
SEXP LGBM_DatasetFree_R(SEXP handle) { SEXP LGBM_DatasetFree_R(SEXP handle) {
...@@ -224,15 +234,16 @@ SEXP LGBM_DatasetFree_R(SEXP handle) { ...@@ -224,15 +234,16 @@ SEXP LGBM_DatasetFree_R(SEXP handle) {
R_ClearExternalPtr(handle); R_ClearExternalPtr(handle);
} }
R_API_END(); R_API_END();
return R_NilValue;
} }
SEXP LGBM_DatasetSetField_R(SEXP handle, SEXP LGBM_DatasetSetField_R(SEXP handle,
SEXP field_name, SEXP field_name,
SEXP field_data, SEXP field_data,
SEXP num_element) { SEXP num_element) {
R_API_BEGIN();
int len = Rf_asInteger(num_element); int len = Rf_asInteger(num_element);
const char* name = CHAR(Rf_asChar(field_name)); const char* name = CHAR(PROTECT(Rf_asChar(field_name)));
R_API_BEGIN();
if (!strcmp("group", name) || !strcmp("query", name)) { if (!strcmp("group", name) || !strcmp("query", name)) {
std::vector<int32_t> vec(len); std::vector<int32_t> vec(len);
#pragma omp parallel for schedule(static, 512) if (len >= 1024) #pragma omp parallel for schedule(static, 512) if (len >= 1024)
...@@ -251,18 +262,19 @@ SEXP LGBM_DatasetSetField_R(SEXP handle, ...@@ -251,18 +262,19 @@ SEXP LGBM_DatasetSetField_R(SEXP handle,
CHECK_CALL(LGBM_DatasetSetField(R_ExternalPtrAddr(handle), name, vec.data(), len, C_API_DTYPE_FLOAT32)); CHECK_CALL(LGBM_DatasetSetField(R_ExternalPtrAddr(handle), name, vec.data(), len, C_API_DTYPE_FLOAT32));
} }
R_API_END(); R_API_END();
UNPROTECT(1);
return R_NilValue;
} }
SEXP LGBM_DatasetGetField_R(SEXP handle, SEXP LGBM_DatasetGetField_R(SEXP handle,
SEXP field_name, SEXP field_name,
SEXP field_data) { SEXP field_data) {
R_API_BEGIN(); const char* name = CHAR(PROTECT(Rf_asChar(field_name)));
const char* name = CHAR(Rf_asChar(field_name));
int out_len = 0; int out_len = 0;
int out_type = 0; int out_type = 0;
const void* res; const void* res;
R_API_BEGIN();
CHECK_CALL(LGBM_DatasetGetField(R_ExternalPtrAddr(handle), name, &out_len, &res, &out_type)); CHECK_CALL(LGBM_DatasetGetField(R_ExternalPtrAddr(handle), name, &out_len, &res, &out_type));
if (!strcmp("group", name) || !strcmp("query", name)) { if (!strcmp("group", name) || !strcmp("query", name)) {
auto p_data = reinterpret_cast<const int32_t*>(res); auto p_data = reinterpret_cast<const int32_t*>(res);
// convert from boundaries to size // convert from boundaries to size
...@@ -284,29 +296,37 @@ SEXP LGBM_DatasetGetField_R(SEXP handle, ...@@ -284,29 +296,37 @@ SEXP LGBM_DatasetGetField_R(SEXP handle,
} }
} }
R_API_END(); R_API_END();
UNPROTECT(1);
return R_NilValue;
} }
SEXP LGBM_DatasetGetFieldSize_R(SEXP handle, SEXP LGBM_DatasetGetFieldSize_R(SEXP handle,
SEXP field_name, SEXP field_name,
SEXP out) { SEXP out) {
R_API_BEGIN(); const char* name = CHAR(PROTECT(Rf_asChar(field_name)));
const char* name = CHAR(Rf_asChar(field_name));
int out_len = 0; int out_len = 0;
int out_type = 0; int out_type = 0;
const void* res; const void* res;
R_API_BEGIN();
CHECK_CALL(LGBM_DatasetGetField(R_ExternalPtrAddr(handle), name, &out_len, &res, &out_type)); CHECK_CALL(LGBM_DatasetGetField(R_ExternalPtrAddr(handle), name, &out_len, &res, &out_type));
if (!strcmp("group", name) || !strcmp("query", name)) { if (!strcmp("group", name) || !strcmp("query", name)) {
out_len -= 1; out_len -= 1;
} }
INTEGER(out)[0] = out_len; INTEGER(out)[0] = out_len;
R_API_END(); R_API_END();
UNPROTECT(1);
return R_NilValue;
} }
SEXP LGBM_DatasetUpdateParamChecking_R(SEXP old_params, SEXP LGBM_DatasetUpdateParamChecking_R(SEXP old_params,
SEXP new_params) { SEXP new_params) {
const char* old_params_ptr = CHAR(PROTECT(Rf_asChar(old_params)));
const char* new_params_ptr = CHAR(PROTECT(Rf_asChar(new_params)));
R_API_BEGIN(); R_API_BEGIN();
CHECK_CALL(LGBM_DatasetUpdateParamChecking(CHAR(Rf_asChar(old_params)), CHAR(Rf_asChar(new_params)))); CHECK_CALL(LGBM_DatasetUpdateParamChecking(old_params_ptr, new_params_ptr));
R_API_END(); R_API_END();
UNPROTECT(2);
return R_NilValue;
} }
SEXP LGBM_DatasetGetNumData_R(SEXP handle, SEXP out) { SEXP LGBM_DatasetGetNumData_R(SEXP handle, SEXP out) {
...@@ -315,6 +335,7 @@ SEXP LGBM_DatasetGetNumData_R(SEXP handle, SEXP out) { ...@@ -315,6 +335,7 @@ SEXP LGBM_DatasetGetNumData_R(SEXP handle, SEXP out) {
CHECK_CALL(LGBM_DatasetGetNumData(R_ExternalPtrAddr(handle), &nrow)); CHECK_CALL(LGBM_DatasetGetNumData(R_ExternalPtrAddr(handle), &nrow));
INTEGER(out)[0] = nrow; INTEGER(out)[0] = nrow;
R_API_END(); R_API_END();
return R_NilValue;
} }
SEXP LGBM_DatasetGetNumFeature_R(SEXP handle, SEXP LGBM_DatasetGetNumFeature_R(SEXP handle,
...@@ -324,6 +345,7 @@ SEXP LGBM_DatasetGetNumFeature_R(SEXP handle, ...@@ -324,6 +345,7 @@ SEXP LGBM_DatasetGetNumFeature_R(SEXP handle,
CHECK_CALL(LGBM_DatasetGetNumFeature(R_ExternalPtrAddr(handle), &nfeature)); CHECK_CALL(LGBM_DatasetGetNumFeature(R_ExternalPtrAddr(handle), &nfeature));
INTEGER(out)[0] = nfeature; INTEGER(out)[0] = nfeature;
R_API_END(); R_API_END();
return R_NilValue;
} }
// --- start Booster interfaces // --- start Booster interfaces
...@@ -339,45 +361,49 @@ SEXP LGBM_BoosterFree_R(SEXP handle) { ...@@ -339,45 +361,49 @@ SEXP LGBM_BoosterFree_R(SEXP handle) {
R_ClearExternalPtr(handle); R_ClearExternalPtr(handle);
} }
R_API_END(); R_API_END();
return R_NilValue;
} }
SEXP LGBM_BoosterCreate_R(SEXP train_data, SEXP LGBM_BoosterCreate_R(SEXP train_data,
SEXP parameters) { SEXP parameters) {
SEXP ret; SEXP ret;
R_API_BEGIN(); const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters)));
BoosterHandle handle = nullptr; BoosterHandle handle = nullptr;
CHECK_CALL(LGBM_BoosterCreate(R_ExternalPtrAddr(train_data), CHAR(Rf_asChar(parameters)), &handle)); R_API_BEGIN();
CHECK_CALL(LGBM_BoosterCreate(R_ExternalPtrAddr(train_data), parameters_ptr, &handle));
R_API_END();
ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue)); ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE); R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE);
UNPROTECT(1); UNPROTECT(2);
return ret; return ret;
R_API_END();
} }
SEXP LGBM_BoosterCreateFromModelfile_R(SEXP filename) { SEXP LGBM_BoosterCreateFromModelfile_R(SEXP filename) {
SEXP ret; SEXP ret;
R_API_BEGIN();
int out_num_iterations = 0; int out_num_iterations = 0;
const char* filename_ptr = CHAR(PROTECT(Rf_asChar(filename)));
BoosterHandle handle = nullptr; BoosterHandle handle = nullptr;
CHECK_CALL(LGBM_BoosterCreateFromModelfile(CHAR(Rf_asChar(filename)), &out_num_iterations, &handle)); R_API_BEGIN();
CHECK_CALL(LGBM_BoosterCreateFromModelfile(filename_ptr, &out_num_iterations, &handle));
R_API_END();
ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue)); ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE); R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE);
UNPROTECT(1); UNPROTECT(2);
return ret; return ret;
R_API_END();
} }
SEXP LGBM_BoosterLoadModelFromString_R(SEXP model_str) { SEXP LGBM_BoosterLoadModelFromString_R(SEXP model_str) {
SEXP ret; SEXP ret;
R_API_BEGIN();
int out_num_iterations = 0; int out_num_iterations = 0;
const char* model_str_ptr = CHAR(PROTECT(Rf_asChar(model_str)));
BoosterHandle handle = nullptr; BoosterHandle handle = nullptr;
CHECK_CALL(LGBM_BoosterLoadModelFromString(CHAR(Rf_asChar(model_str)), &out_num_iterations, &handle)); R_API_BEGIN();
CHECK_CALL(LGBM_BoosterLoadModelFromString(model_str_ptr, &out_num_iterations, &handle));
R_API_END();
ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue)); ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE); R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE);
UNPROTECT(1); UNPROTECT(2);
return ret; return ret;
R_API_END();
} }
SEXP LGBM_BoosterMerge_R(SEXP handle, SEXP LGBM_BoosterMerge_R(SEXP handle,
...@@ -385,6 +411,7 @@ SEXP LGBM_BoosterMerge_R(SEXP handle, ...@@ -385,6 +411,7 @@ SEXP LGBM_BoosterMerge_R(SEXP handle,
R_API_BEGIN(); R_API_BEGIN();
CHECK_CALL(LGBM_BoosterMerge(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(other_handle))); CHECK_CALL(LGBM_BoosterMerge(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(other_handle)));
R_API_END(); R_API_END();
return R_NilValue;
} }
SEXP LGBM_BoosterAddValidData_R(SEXP handle, SEXP LGBM_BoosterAddValidData_R(SEXP handle,
...@@ -392,6 +419,7 @@ SEXP LGBM_BoosterAddValidData_R(SEXP handle, ...@@ -392,6 +419,7 @@ SEXP LGBM_BoosterAddValidData_R(SEXP handle,
R_API_BEGIN(); R_API_BEGIN();
CHECK_CALL(LGBM_BoosterAddValidData(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(valid_data))); CHECK_CALL(LGBM_BoosterAddValidData(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(valid_data)));
R_API_END(); R_API_END();
return R_NilValue;
} }
SEXP LGBM_BoosterResetTrainingData_R(SEXP handle, SEXP LGBM_BoosterResetTrainingData_R(SEXP handle,
...@@ -399,13 +427,17 @@ SEXP LGBM_BoosterResetTrainingData_R(SEXP handle, ...@@ -399,13 +427,17 @@ SEXP LGBM_BoosterResetTrainingData_R(SEXP handle,
R_API_BEGIN(); R_API_BEGIN();
CHECK_CALL(LGBM_BoosterResetTrainingData(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(train_data))); CHECK_CALL(LGBM_BoosterResetTrainingData(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(train_data)));
R_API_END(); R_API_END();
return R_NilValue;
} }
SEXP LGBM_BoosterResetParameter_R(SEXP handle, SEXP LGBM_BoosterResetParameter_R(SEXP handle,
SEXP parameters) { SEXP parameters) {
const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters)));
R_API_BEGIN(); R_API_BEGIN();
CHECK_CALL(LGBM_BoosterResetParameter(R_ExternalPtrAddr(handle), CHAR(Rf_asChar(parameters)))); CHECK_CALL(LGBM_BoosterResetParameter(R_ExternalPtrAddr(handle), parameters_ptr));
R_API_END(); R_API_END();
UNPROTECT(1);
return R_NilValue;
} }
SEXP LGBM_BoosterGetNumClasses_R(SEXP handle, SEXP LGBM_BoosterGetNumClasses_R(SEXP handle,
...@@ -415,6 +447,7 @@ SEXP LGBM_BoosterGetNumClasses_R(SEXP handle, ...@@ -415,6 +447,7 @@ SEXP LGBM_BoosterGetNumClasses_R(SEXP handle,
CHECK_CALL(LGBM_BoosterGetNumClasses(R_ExternalPtrAddr(handle), &num_class)); CHECK_CALL(LGBM_BoosterGetNumClasses(R_ExternalPtrAddr(handle), &num_class));
INTEGER(out)[0] = num_class; INTEGER(out)[0] = num_class;
R_API_END(); R_API_END();
return R_NilValue;
} }
SEXP LGBM_BoosterUpdateOneIter_R(SEXP handle) { SEXP LGBM_BoosterUpdateOneIter_R(SEXP handle) {
...@@ -422,6 +455,7 @@ SEXP LGBM_BoosterUpdateOneIter_R(SEXP handle) { ...@@ -422,6 +455,7 @@ SEXP LGBM_BoosterUpdateOneIter_R(SEXP handle) {
R_API_BEGIN(); R_API_BEGIN();
CHECK_CALL(LGBM_BoosterUpdateOneIter(R_ExternalPtrAddr(handle), &is_finished)); CHECK_CALL(LGBM_BoosterUpdateOneIter(R_ExternalPtrAddr(handle), &is_finished));
R_API_END(); R_API_END();
return R_NilValue;
} }
SEXP LGBM_BoosterUpdateOneIterCustom_R(SEXP handle, SEXP LGBM_BoosterUpdateOneIterCustom_R(SEXP handle,
...@@ -439,12 +473,14 @@ SEXP LGBM_BoosterUpdateOneIterCustom_R(SEXP handle, ...@@ -439,12 +473,14 @@ SEXP LGBM_BoosterUpdateOneIterCustom_R(SEXP handle,
} }
CHECK_CALL(LGBM_BoosterUpdateOneIterCustom(R_ExternalPtrAddr(handle), tgrad.data(), thess.data(), &is_finished)); CHECK_CALL(LGBM_BoosterUpdateOneIterCustom(R_ExternalPtrAddr(handle), tgrad.data(), thess.data(), &is_finished));
R_API_END(); R_API_END();
return R_NilValue;
} }
SEXP LGBM_BoosterRollbackOneIter_R(SEXP handle) { SEXP LGBM_BoosterRollbackOneIter_R(SEXP handle) {
R_API_BEGIN(); R_API_BEGIN();
CHECK_CALL(LGBM_BoosterRollbackOneIter(R_ExternalPtrAddr(handle))); CHECK_CALL(LGBM_BoosterRollbackOneIter(R_ExternalPtrAddr(handle)));
R_API_END(); R_API_END();
return R_NilValue;
} }
SEXP LGBM_BoosterGetCurrentIteration_R(SEXP handle, SEXP LGBM_BoosterGetCurrentIteration_R(SEXP handle,
...@@ -454,6 +490,7 @@ SEXP LGBM_BoosterGetCurrentIteration_R(SEXP handle, ...@@ -454,6 +490,7 @@ SEXP LGBM_BoosterGetCurrentIteration_R(SEXP handle,
CHECK_CALL(LGBM_BoosterGetCurrentIteration(R_ExternalPtrAddr(handle), &out_iteration)); CHECK_CALL(LGBM_BoosterGetCurrentIteration(R_ExternalPtrAddr(handle), &out_iteration));
INTEGER(out)[0] = out_iteration; INTEGER(out)[0] = out_iteration;
R_API_END(); R_API_END();
return R_NilValue;
} }
SEXP LGBM_BoosterGetUpperBoundValue_R(SEXP handle, SEXP LGBM_BoosterGetUpperBoundValue_R(SEXP handle,
...@@ -462,6 +499,7 @@ SEXP LGBM_BoosterGetUpperBoundValue_R(SEXP handle, ...@@ -462,6 +499,7 @@ SEXP LGBM_BoosterGetUpperBoundValue_R(SEXP handle,
double* ptr_ret = REAL(out_result); double* ptr_ret = REAL(out_result);
CHECK_CALL(LGBM_BoosterGetUpperBoundValue(R_ExternalPtrAddr(handle), ptr_ret)); CHECK_CALL(LGBM_BoosterGetUpperBoundValue(R_ExternalPtrAddr(handle), ptr_ret));
R_API_END(); R_API_END();
return R_NilValue;
} }
SEXP LGBM_BoosterGetLowerBoundValue_R(SEXP handle, SEXP LGBM_BoosterGetLowerBoundValue_R(SEXP handle,
...@@ -470,14 +508,15 @@ SEXP LGBM_BoosterGetLowerBoundValue_R(SEXP handle, ...@@ -470,14 +508,15 @@ SEXP LGBM_BoosterGetLowerBoundValue_R(SEXP handle,
double* ptr_ret = REAL(out_result); double* ptr_ret = REAL(out_result);
CHECK_CALL(LGBM_BoosterGetLowerBoundValue(R_ExternalPtrAddr(handle), ptr_ret)); CHECK_CALL(LGBM_BoosterGetLowerBoundValue(R_ExternalPtrAddr(handle), ptr_ret));
R_API_END(); R_API_END();
return R_NilValue;
} }
SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) { SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) {
SEXP eval_names; SEXP eval_names;
R_API_BEGIN();
int len; int len;
R_API_BEGIN();
CHECK_CALL(LGBM_BoosterGetEvalCounts(R_ExternalPtrAddr(handle), &len)); CHECK_CALL(LGBM_BoosterGetEvalCounts(R_ExternalPtrAddr(handle), &len));
R_API_END();
const size_t reserved_string_size = 128; const size_t reserved_string_size = 128;
std::vector<std::vector<char>> names(len); std::vector<std::vector<char>> names(len);
std::vector<char*> ptr_names(len); std::vector<char*> ptr_names(len);
...@@ -488,12 +527,14 @@ SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) { ...@@ -488,12 +527,14 @@ SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) {
int out_len; int out_len;
size_t required_string_size; size_t required_string_size;
R_API_BEGIN();
CHECK_CALL( CHECK_CALL(
LGBM_BoosterGetEvalNames( LGBM_BoosterGetEvalNames(
R_ExternalPtrAddr(handle), R_ExternalPtrAddr(handle),
len, &out_len, len, &out_len,
reserved_string_size, &required_string_size, reserved_string_size, &required_string_size,
ptr_names.data())); ptr_names.data()));
R_API_END();
// if any eval names were larger than allocated size, // if any eval names were larger than allocated size,
// allow for a larger size and try again // allow for a larger size and try again
if (required_string_size > reserved_string_size) { if (required_string_size > reserved_string_size) {
...@@ -501,6 +542,7 @@ SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) { ...@@ -501,6 +542,7 @@ SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) {
names[i].resize(required_string_size); names[i].resize(required_string_size);
ptr_names[i] = names[i].data(); ptr_names[i] = names[i].data();
} }
R_API_BEGIN();
CHECK_CALL( CHECK_CALL(
LGBM_BoosterGetEvalNames( LGBM_BoosterGetEvalNames(
R_ExternalPtrAddr(handle), R_ExternalPtrAddr(handle),
...@@ -509,6 +551,7 @@ SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) { ...@@ -509,6 +551,7 @@ SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) {
required_string_size, required_string_size,
&required_string_size, &required_string_size,
ptr_names.data())); ptr_names.data()));
R_API_END();
} }
CHECK_EQ(out_len, len); CHECK_EQ(out_len, len);
eval_names = PROTECT(Rf_allocVector(STRSXP, len)); eval_names = PROTECT(Rf_allocVector(STRSXP, len));
...@@ -517,7 +560,6 @@ SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) { ...@@ -517,7 +560,6 @@ SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) {
} }
UNPROTECT(1); UNPROTECT(1);
return eval_names; return eval_names;
R_API_END();
} }
SEXP LGBM_BoosterGetEval_R(SEXP handle, SEXP LGBM_BoosterGetEval_R(SEXP handle,
...@@ -531,6 +573,7 @@ SEXP LGBM_BoosterGetEval_R(SEXP handle, ...@@ -531,6 +573,7 @@ SEXP LGBM_BoosterGetEval_R(SEXP handle,
CHECK_CALL(LGBM_BoosterGetEval(R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &out_len, ptr_ret)); CHECK_CALL(LGBM_BoosterGetEval(R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &out_len, ptr_ret));
CHECK_EQ(out_len, len); CHECK_EQ(out_len, len);
R_API_END(); R_API_END();
return R_NilValue;
} }
SEXP LGBM_BoosterGetNumPredict_R(SEXP handle, SEXP LGBM_BoosterGetNumPredict_R(SEXP handle,
...@@ -541,6 +584,7 @@ SEXP LGBM_BoosterGetNumPredict_R(SEXP handle, ...@@ -541,6 +584,7 @@ SEXP LGBM_BoosterGetNumPredict_R(SEXP handle,
CHECK_CALL(LGBM_BoosterGetNumPredict(R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &len)); CHECK_CALL(LGBM_BoosterGetNumPredict(R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &len));
INTEGER(out)[0] = static_cast<int>(len); INTEGER(out)[0] = static_cast<int>(len);
R_API_END(); R_API_END();
return R_NilValue;
} }
SEXP LGBM_BoosterGetPredict_R(SEXP handle, SEXP LGBM_BoosterGetPredict_R(SEXP handle,
...@@ -551,6 +595,7 @@ SEXP LGBM_BoosterGetPredict_R(SEXP handle, ...@@ -551,6 +595,7 @@ SEXP LGBM_BoosterGetPredict_R(SEXP handle,
int64_t out_len; int64_t out_len;
CHECK_CALL(LGBM_BoosterGetPredict(R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &out_len, ptr_ret)); CHECK_CALL(LGBM_BoosterGetPredict(R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &out_len, ptr_ret));
R_API_END(); R_API_END();
return R_NilValue;
} }
int GetPredictType(SEXP is_rawscore, SEXP is_leafidx, SEXP is_predcontrib) { int GetPredictType(SEXP is_rawscore, SEXP is_leafidx, SEXP is_predcontrib) {
...@@ -577,12 +622,17 @@ SEXP LGBM_BoosterPredictForFile_R(SEXP handle, ...@@ -577,12 +622,17 @@ SEXP LGBM_BoosterPredictForFile_R(SEXP handle,
SEXP num_iteration, SEXP num_iteration,
SEXP parameter, SEXP parameter,
SEXP result_filename) { SEXP result_filename) {
R_API_BEGIN(); const char* data_filename_ptr = CHAR(PROTECT(Rf_asChar(data_filename)));
const char* parameter_ptr = CHAR(PROTECT(Rf_asChar(parameter)));
const char* result_filename_ptr = CHAR(PROTECT(Rf_asChar(result_filename)));
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib); int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
CHECK_CALL(LGBM_BoosterPredictForFile(R_ExternalPtrAddr(handle), CHAR(Rf_asChar(data_filename)), R_API_BEGIN();
Rf_asInteger(data_has_header), pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), CHAR(Rf_asChar(parameter)), CHECK_CALL(LGBM_BoosterPredictForFile(R_ExternalPtrAddr(handle), data_filename_ptr,
CHAR(Rf_asChar(result_filename)))); Rf_asInteger(data_has_header), pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), parameter_ptr,
result_filename_ptr));
R_API_END(); R_API_END();
UNPROTECT(3);
return R_NilValue;
} }
SEXP LGBM_BoosterCalcNumPredict_R(SEXP handle, SEXP LGBM_BoosterCalcNumPredict_R(SEXP handle,
...@@ -600,6 +650,7 @@ SEXP LGBM_BoosterCalcNumPredict_R(SEXP handle, ...@@ -600,6 +650,7 @@ SEXP LGBM_BoosterCalcNumPredict_R(SEXP handle,
pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), &len)); pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), &len));
INTEGER(out_len)[0] = static_cast<int>(len); INTEGER(out_len)[0] = static_cast<int>(len);
R_API_END(); R_API_END();
return R_NilValue;
} }
SEXP LGBM_BoosterPredictForCSC_R(SEXP handle, SEXP LGBM_BoosterPredictForCSC_R(SEXP handle,
...@@ -616,23 +667,24 @@ SEXP LGBM_BoosterPredictForCSC_R(SEXP handle, ...@@ -616,23 +667,24 @@ SEXP LGBM_BoosterPredictForCSC_R(SEXP handle,
SEXP num_iteration, SEXP num_iteration,
SEXP parameter, SEXP parameter,
SEXP out_result) { SEXP out_result) {
R_API_BEGIN();
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib); int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
const int* p_indptr = INTEGER(indptr); const int* p_indptr = INTEGER(indptr);
const int32_t* p_indices = reinterpret_cast<const int32_t*>(INTEGER(indices)); const int32_t* p_indices = reinterpret_cast<const int32_t*>(INTEGER(indices));
const double* p_data = REAL(data); const double* p_data = REAL(data);
int64_t nindptr = static_cast<int64_t>(Rf_asInteger(num_indptr)); int64_t nindptr = static_cast<int64_t>(Rf_asInteger(num_indptr));
int64_t ndata = static_cast<int64_t>(Rf_asInteger(nelem)); int64_t ndata = static_cast<int64_t>(Rf_asInteger(nelem));
int64_t nrow = static_cast<int64_t>(Rf_asInteger(num_row)); int64_t nrow = static_cast<int64_t>(Rf_asInteger(num_row));
double* ptr_ret = REAL(out_result); double* ptr_ret = REAL(out_result);
int64_t out_len; int64_t out_len;
const char* parameter_ptr = CHAR(PROTECT(Rf_asChar(parameter)));
R_API_BEGIN();
CHECK_CALL(LGBM_BoosterPredictForCSC(R_ExternalPtrAddr(handle), CHECK_CALL(LGBM_BoosterPredictForCSC(R_ExternalPtrAddr(handle),
p_indptr, C_API_DTYPE_INT32, p_indices, p_indptr, C_API_DTYPE_INT32, p_indices,
p_data, C_API_DTYPE_FLOAT64, nindptr, ndata, p_data, C_API_DTYPE_FLOAT64, nindptr, ndata,
nrow, pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), CHAR(Rf_asChar(parameter)), &out_len, ptr_ret)); nrow, pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), parameter_ptr, &out_len, ptr_ret));
R_API_END(); R_API_END();
UNPROTECT(1);
return R_NilValue;
} }
SEXP LGBM_BoosterPredictForMat_R(SEXP handle, SEXP LGBM_BoosterPredictForMat_R(SEXP handle,
...@@ -646,75 +698,82 @@ SEXP LGBM_BoosterPredictForMat_R(SEXP handle, ...@@ -646,75 +698,82 @@ SEXP LGBM_BoosterPredictForMat_R(SEXP handle,
SEXP num_iteration, SEXP num_iteration,
SEXP parameter, SEXP parameter,
SEXP out_result) { SEXP out_result) {
R_API_BEGIN();
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib); int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
int32_t nrow = static_cast<int32_t>(Rf_asInteger(num_row)); int32_t nrow = static_cast<int32_t>(Rf_asInteger(num_row));
int32_t ncol = static_cast<int32_t>(Rf_asInteger(num_col)); int32_t ncol = static_cast<int32_t>(Rf_asInteger(num_col));
const double* p_mat = REAL(data); const double* p_mat = REAL(data);
double* ptr_ret = REAL(out_result); double* ptr_ret = REAL(out_result);
const char* parameter_ptr = CHAR(PROTECT(Rf_asChar(parameter)));
int64_t out_len; int64_t out_len;
R_API_BEGIN();
CHECK_CALL(LGBM_BoosterPredictForMat(R_ExternalPtrAddr(handle), CHECK_CALL(LGBM_BoosterPredictForMat(R_ExternalPtrAddr(handle),
p_mat, C_API_DTYPE_FLOAT64, nrow, ncol, COL_MAJOR, p_mat, C_API_DTYPE_FLOAT64, nrow, ncol, COL_MAJOR,
pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), CHAR(Rf_asChar(parameter)), &out_len, ptr_ret)); pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), parameter_ptr, &out_len, ptr_ret));
R_API_END(); R_API_END();
UNPROTECT(1);
return R_NilValue;
} }
SEXP LGBM_BoosterSaveModel_R(SEXP handle, SEXP LGBM_BoosterSaveModel_R(SEXP handle,
SEXP num_iteration, SEXP num_iteration,
SEXP feature_importance_type, SEXP feature_importance_type,
SEXP filename) { SEXP filename) {
const char* filename_ptr = CHAR(PROTECT(Rf_asChar(filename)));
R_API_BEGIN(); R_API_BEGIN();
CHECK_CALL(LGBM_BoosterSaveModel(R_ExternalPtrAddr(handle), 0, Rf_asInteger(num_iteration), Rf_asInteger(feature_importance_type), CHAR(Rf_asChar(filename)))); CHECK_CALL(LGBM_BoosterSaveModel(R_ExternalPtrAddr(handle), 0, Rf_asInteger(num_iteration), Rf_asInteger(feature_importance_type), filename_ptr));
R_API_END(); R_API_END();
UNPROTECT(1);
return R_NilValue;
} }
SEXP LGBM_BoosterSaveModelToString_R(SEXP handle, SEXP LGBM_BoosterSaveModelToString_R(SEXP handle,
SEXP num_iteration, SEXP num_iteration,
SEXP feature_importance_type) { SEXP feature_importance_type) {
SEXP model_str; SEXP model_str;
R_API_BEGIN();
int64_t out_len = 0; int64_t out_len = 0;
int64_t buf_len = 1024 * 1024; int64_t buf_len = 1024 * 1024;
int num_iter = Rf_asInteger(num_iteration); int num_iter = Rf_asInteger(num_iteration);
int importance_type = Rf_asInteger(feature_importance_type); int importance_type = Rf_asInteger(feature_importance_type);
std::vector<char> inner_char_buf(buf_len); std::vector<char> inner_char_buf(buf_len);
R_API_BEGIN();
CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data())); CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data()));
R_API_END();
// if the model string was larger than the initial buffer, allocate a bigger buffer and try again // if the model string was larger than the initial buffer, allocate a bigger buffer and try again
if (out_len > buf_len) { if (out_len > buf_len) {
inner_char_buf.resize(out_len); inner_char_buf.resize(out_len);
R_API_BEGIN();
CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, out_len, &out_len, inner_char_buf.data())); CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, out_len, &out_len, inner_char_buf.data()));
R_API_END();
} }
model_str = PROTECT(Rf_allocVector(STRSXP, 1)); model_str = PROTECT(Rf_allocVector(STRSXP, 1));
SET_STRING_ELT(model_str, 0, Rf_mkChar(inner_char_buf.data())); SET_STRING_ELT(model_str, 0, Rf_mkChar(inner_char_buf.data()));
UNPROTECT(1); UNPROTECT(1);
return model_str; return model_str;
R_API_END();
} }
SEXP LGBM_BoosterDumpModel_R(SEXP handle, SEXP LGBM_BoosterDumpModel_R(SEXP handle,
SEXP num_iteration, SEXP num_iteration,
SEXP feature_importance_type) { SEXP feature_importance_type) {
SEXP model_str; SEXP model_str;
R_API_BEGIN();
int64_t out_len = 0; int64_t out_len = 0;
int64_t buf_len = 1024 * 1024; int64_t buf_len = 1024 * 1024;
int num_iter = Rf_asInteger(num_iteration); int num_iter = Rf_asInteger(num_iteration);
int importance_type = Rf_asInteger(feature_importance_type); int importance_type = Rf_asInteger(feature_importance_type);
std::vector<char> inner_char_buf(buf_len); std::vector<char> inner_char_buf(buf_len);
R_API_BEGIN();
CHECK_CALL(LGBM_BoosterDumpModel(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data())); CHECK_CALL(LGBM_BoosterDumpModel(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data()));
R_API_END();
// if the model string was larger than the initial buffer, allocate a bigger buffer and try again // if the model string was larger than the initial buffer, allocate a bigger buffer and try again
if (out_len > buf_len) { if (out_len > buf_len) {
inner_char_buf.resize(out_len); inner_char_buf.resize(out_len);
R_API_BEGIN();
CHECK_CALL(LGBM_BoosterDumpModel(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, out_len, &out_len, inner_char_buf.data())); CHECK_CALL(LGBM_BoosterDumpModel(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, out_len, &out_len, inner_char_buf.data()));
R_API_END();
} }
model_str = PROTECT(Rf_allocVector(STRSXP, 1)); model_str = PROTECT(Rf_allocVector(STRSXP, 1));
SET_STRING_ELT(model_str, 0, Rf_mkChar(inner_char_buf.data())); SET_STRING_ELT(model_str, 0, Rf_mkChar(inner_char_buf.data()));
UNPROTECT(1); UNPROTECT(1);
return model_str; return model_str;
R_API_END();
} }
// .Call() calls // .Call() calls
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment