Unverified Commit eda0d3ca authored by david-cortes's avatar david-cortes Committed by GitHub
Browse files

[R-package] Fix R memory leaks (fixes #4282, fixes #3462) (#4597)

* fix R memory leaks

* attempt at solving linter complaints

* fix compilation on windows

* move R_API_BEGIN to correct place

* make sure exception objects reach out of scope

* better way to solve rchk complaints

* remove goto statement
parent 2c8bb45b
...@@ -25,17 +25,59 @@ ...@@ -25,17 +25,59 @@
#define COL_MAJOR (0) #define COL_MAJOR (0)
#define MAX_LENGTH_ERR_MSG 1024
char R_errmsg_buffer[MAX_LENGTH_ERR_MSG];
struct LGBM_R_ErrorClass { SEXP cont_token; };
void LGBM_R_save_exception_msg(const std::exception &err);
void LGBM_R_save_exception_msg(const std::string &err);
#define R_API_BEGIN() \ #define R_API_BEGIN() \
try { try {
#define R_API_END() } \ #define R_API_END() } \
catch(std::exception& ex) { LGBM_SetLastError(ex.what()); } \ catch(LGBM_R_ErrorClass &cont) { R_ContinueUnwind(cont.cont_token); } \
catch(std::string& ex) { LGBM_SetLastError(ex.c_str()); } \ catch(std::exception& ex) { LGBM_R_save_exception_msg(ex); } \
catch(...) { LGBM_SetLastError("unknown exception"); } catch(std::string& ex) { LGBM_R_save_exception_msg(ex); } \
catch(...) { Rf_error("unknown exception"); } \
Rf_error(R_errmsg_buffer); \
return R_NilValue; /* <- won't be reached */
#define CHECK_CALL(x) \ #define CHECK_CALL(x) \
if ((x) != 0) { \ if ((x) != 0) { \
Rf_error(LGBM_GetLastError()); \ throw std::runtime_error(LGBM_GetLastError()); \
}
// These are helper functions to allow doing a stack unwind
// after an R allocation error, which would trigger a long jump.
void LGBM_R_save_exception_msg(const std::exception &err) {
std::snprintf(R_errmsg_buffer, MAX_LENGTH_ERR_MSG, "%s\n", err.what());
}
void LGBM_R_save_exception_msg(const std::string &err) {
std::snprintf(R_errmsg_buffer, MAX_LENGTH_ERR_MSG, "%s\n", err.c_str());
}
SEXP wrapped_R_string(void *len) {
return Rf_allocVector(STRSXP, *(reinterpret_cast<R_xlen_t*>(len)));
}
SEXP wrapped_Rf_mkChar(void *txt) {
return Rf_mkChar(reinterpret_cast<char*>(txt));
}
void throw_R_memerr(void *ptr_cont_token, Rboolean jump) {
if (jump) {
LGBM_R_ErrorClass err{*(reinterpret_cast<SEXP*>(ptr_cont_token))};
throw err;
} }
}
SEXP safe_R_string(R_xlen_t len, SEXP *cont_token) {
return R_UnwindProtect(wrapped_R_string, reinterpret_cast<void*>(&len), throw_R_memerr, cont_token, *cont_token);
}
SEXP safe_R_mkChar(char *txt, SEXP *cont_token) {
return R_UnwindProtect(wrapped_Rf_mkChar, reinterpret_cast<void*>(txt), throw_R_memerr, cont_token, *cont_token);
}
using LightGBM::Common::Split; using LightGBM::Common::Split;
using LightGBM::Log; using LightGBM::Log;
...@@ -51,6 +93,7 @@ void _DatasetFinalizer(SEXP handle) { ...@@ -51,6 +93,7 @@ void _DatasetFinalizer(SEXP handle) {
SEXP LGBM_DatasetCreateFromFile_R(SEXP filename, SEXP LGBM_DatasetCreateFromFile_R(SEXP filename,
SEXP parameters, SEXP parameters,
SEXP reference) { SEXP reference) {
R_API_BEGIN();
SEXP ret; SEXP ret;
DatasetHandle handle = nullptr; DatasetHandle handle = nullptr;
DatasetHandle ref = nullptr; DatasetHandle ref = nullptr;
...@@ -59,13 +102,12 @@ SEXP LGBM_DatasetCreateFromFile_R(SEXP filename, ...@@ -59,13 +102,12 @@ SEXP LGBM_DatasetCreateFromFile_R(SEXP filename,
} }
const char* filename_ptr = CHAR(PROTECT(Rf_asChar(filename))); const char* filename_ptr = CHAR(PROTECT(Rf_asChar(filename)));
const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters))); const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters)));
R_API_BEGIN();
CHECK_CALL(LGBM_DatasetCreateFromFile(filename_ptr, parameters_ptr, ref, &handle)); CHECK_CALL(LGBM_DatasetCreateFromFile(filename_ptr, parameters_ptr, ref, &handle));
R_API_END();
ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue)); ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _DatasetFinalizer, TRUE); R_RegisterCFinalizerEx(ret, _DatasetFinalizer, TRUE);
UNPROTECT(3); UNPROTECT(3);
return ret; return ret;
R_API_END();
} }
SEXP LGBM_DatasetCreateFromCSC_R(SEXP indptr, SEXP LGBM_DatasetCreateFromCSC_R(SEXP indptr,
...@@ -76,6 +118,7 @@ SEXP LGBM_DatasetCreateFromCSC_R(SEXP indptr, ...@@ -76,6 +118,7 @@ SEXP LGBM_DatasetCreateFromCSC_R(SEXP indptr,
SEXP num_row, SEXP num_row,
SEXP parameters, SEXP parameters,
SEXP reference) { SEXP reference) {
R_API_BEGIN();
SEXP ret; SEXP ret;
const int* p_indptr = INTEGER(indptr); const int* p_indptr = INTEGER(indptr);
const int* p_indices = INTEGER(indices); const int* p_indices = INTEGER(indices);
...@@ -89,15 +132,14 @@ SEXP LGBM_DatasetCreateFromCSC_R(SEXP indptr, ...@@ -89,15 +132,14 @@ SEXP LGBM_DatasetCreateFromCSC_R(SEXP indptr,
if (!Rf_isNull(reference)) { if (!Rf_isNull(reference)) {
ref = R_ExternalPtrAddr(reference); ref = R_ExternalPtrAddr(reference);
} }
R_API_BEGIN();
CHECK_CALL(LGBM_DatasetCreateFromCSC(p_indptr, C_API_DTYPE_INT32, p_indices, CHECK_CALL(LGBM_DatasetCreateFromCSC(p_indptr, C_API_DTYPE_INT32, p_indices,
p_data, C_API_DTYPE_FLOAT64, nindptr, ndata, p_data, C_API_DTYPE_FLOAT64, nindptr, ndata,
nrow, parameters_ptr, ref, &handle)); nrow, parameters_ptr, ref, &handle));
R_API_END();
ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue)); ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _DatasetFinalizer, TRUE); R_RegisterCFinalizerEx(ret, _DatasetFinalizer, TRUE);
UNPROTECT(2); UNPROTECT(2);
return ret; return ret;
R_API_END();
} }
SEXP LGBM_DatasetCreateFromMat_R(SEXP data, SEXP LGBM_DatasetCreateFromMat_R(SEXP data,
...@@ -105,6 +147,7 @@ SEXP LGBM_DatasetCreateFromMat_R(SEXP data, ...@@ -105,6 +147,7 @@ SEXP LGBM_DatasetCreateFromMat_R(SEXP data,
SEXP num_col, SEXP num_col,
SEXP parameters, SEXP parameters,
SEXP reference) { SEXP reference) {
R_API_BEGIN();
SEXP ret; SEXP ret;
int32_t nrow = static_cast<int32_t>(Rf_asInteger(num_row)); int32_t nrow = static_cast<int32_t>(Rf_asInteger(num_row));
int32_t ncol = static_cast<int32_t>(Rf_asInteger(num_col)); int32_t ncol = static_cast<int32_t>(Rf_asInteger(num_col));
...@@ -115,20 +158,20 @@ SEXP LGBM_DatasetCreateFromMat_R(SEXP data, ...@@ -115,20 +158,20 @@ SEXP LGBM_DatasetCreateFromMat_R(SEXP data,
if (!Rf_isNull(reference)) { if (!Rf_isNull(reference)) {
ref = R_ExternalPtrAddr(reference); ref = R_ExternalPtrAddr(reference);
} }
R_API_BEGIN();
CHECK_CALL(LGBM_DatasetCreateFromMat(p_mat, C_API_DTYPE_FLOAT64, nrow, ncol, COL_MAJOR, CHECK_CALL(LGBM_DatasetCreateFromMat(p_mat, C_API_DTYPE_FLOAT64, nrow, ncol, COL_MAJOR,
parameters_ptr, ref, &handle)); parameters_ptr, ref, &handle));
R_API_END();
ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue)); ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _DatasetFinalizer, TRUE); R_RegisterCFinalizerEx(ret, _DatasetFinalizer, TRUE);
UNPROTECT(2); UNPROTECT(2);
return ret; return ret;
R_API_END();
} }
SEXP LGBM_DatasetGetSubset_R(SEXP handle, SEXP LGBM_DatasetGetSubset_R(SEXP handle,
SEXP used_row_indices, SEXP used_row_indices,
SEXP len_used_row_indices, SEXP len_used_row_indices,
SEXP parameters) { SEXP parameters) {
R_API_BEGIN();
SEXP ret; SEXP ret;
int32_t len = static_cast<int32_t>(Rf_asInteger(len_used_row_indices)); int32_t len = static_cast<int32_t>(Rf_asInteger(len_used_row_indices));
std::vector<int32_t> idxvec(len); std::vector<int32_t> idxvec(len);
...@@ -139,39 +182,38 @@ SEXP LGBM_DatasetGetSubset_R(SEXP handle, ...@@ -139,39 +182,38 @@ SEXP LGBM_DatasetGetSubset_R(SEXP handle,
} }
const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters))); const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters)));
DatasetHandle res = nullptr; DatasetHandle res = nullptr;
R_API_BEGIN();
CHECK_CALL(LGBM_DatasetGetSubset(R_ExternalPtrAddr(handle), CHECK_CALL(LGBM_DatasetGetSubset(R_ExternalPtrAddr(handle),
idxvec.data(), len, parameters_ptr, idxvec.data(), len, parameters_ptr,
&res)); &res));
R_API_END();
ret = PROTECT(R_MakeExternalPtr(res, R_NilValue, R_NilValue)); ret = PROTECT(R_MakeExternalPtr(res, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _DatasetFinalizer, TRUE); R_RegisterCFinalizerEx(ret, _DatasetFinalizer, TRUE);
UNPROTECT(2); UNPROTECT(2);
return ret; return ret;
R_API_END();
} }
SEXP LGBM_DatasetSetFeatureNames_R(SEXP handle, SEXP LGBM_DatasetSetFeatureNames_R(SEXP handle,
SEXP feature_names) { SEXP feature_names) {
R_API_BEGIN();
auto vec_names = Split(CHAR(PROTECT(Rf_asChar(feature_names))), '\t'); auto vec_names = Split(CHAR(PROTECT(Rf_asChar(feature_names))), '\t');
std::vector<const char*> vec_sptr; std::vector<const char*> vec_sptr;
int len = static_cast<int>(vec_names.size()); int len = static_cast<int>(vec_names.size());
for (int i = 0; i < len; ++i) { for (int i = 0; i < len; ++i) {
vec_sptr.push_back(vec_names[i].c_str()); vec_sptr.push_back(vec_names[i].c_str());
} }
R_API_BEGIN();
CHECK_CALL(LGBM_DatasetSetFeatureNames(R_ExternalPtrAddr(handle), CHECK_CALL(LGBM_DatasetSetFeatureNames(R_ExternalPtrAddr(handle),
vec_sptr.data(), len)); vec_sptr.data(), len));
R_API_END();
UNPROTECT(1); UNPROTECT(1);
return R_NilValue; return R_NilValue;
R_API_END();
} }
SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) { SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) {
SEXP cont_token = PROTECT(R_MakeUnwindCont());
R_API_BEGIN();
SEXP feature_names; SEXP feature_names;
int len = 0; int len = 0;
R_API_BEGIN();
CHECK_CALL(LGBM_DatasetGetNumFeature(R_ExternalPtrAddr(handle), &len)); CHECK_CALL(LGBM_DatasetGetNumFeature(R_ExternalPtrAddr(handle), &len));
R_API_END();
const size_t reserved_string_size = 256; const size_t reserved_string_size = 256;
std::vector<std::vector<char>> names(len); std::vector<std::vector<char>> names(len);
std::vector<char*> ptr_names(len); std::vector<char*> ptr_names(len);
...@@ -181,14 +223,12 @@ SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) { ...@@ -181,14 +223,12 @@ SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) {
} }
int out_len; int out_len;
size_t required_string_size; size_t required_string_size;
R_API_BEGIN();
CHECK_CALL( CHECK_CALL(
LGBM_DatasetGetFeatureNames( LGBM_DatasetGetFeatureNames(
R_ExternalPtrAddr(handle), R_ExternalPtrAddr(handle),
len, &out_len, len, &out_len,
reserved_string_size, &required_string_size, reserved_string_size, &required_string_size,
ptr_names.data())); ptr_names.data()));
R_API_END();
// if any feature names were larger than allocated size, // if any feature names were larger than allocated size,
// allow for a larger size and try again // allow for a larger size and try again
if (required_string_size > reserved_string_size) { if (required_string_size > reserved_string_size) {
...@@ -196,7 +236,6 @@ SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) { ...@@ -196,7 +236,6 @@ SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) {
names[i].resize(required_string_size); names[i].resize(required_string_size);
ptr_names[i] = names[i].data(); ptr_names[i] = names[i].data();
} }
R_API_BEGIN();
CHECK_CALL( CHECK_CALL(
LGBM_DatasetGetFeatureNames( LGBM_DatasetGetFeatureNames(
R_ExternalPtrAddr(handle), R_ExternalPtrAddr(handle),
...@@ -205,26 +244,26 @@ SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) { ...@@ -205,26 +244,26 @@ SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) {
required_string_size, required_string_size,
&required_string_size, &required_string_size,
ptr_names.data())); ptr_names.data()));
R_API_END();
} }
CHECK_EQ(len, out_len); CHECK_EQ(len, out_len);
feature_names = PROTECT(Rf_allocVector(STRSXP, len)); feature_names = PROTECT(safe_R_string(static_cast<R_xlen_t>(len), &cont_token));
for (int i = 0; i < len; ++i) { for (int i = 0; i < len; ++i) {
SET_STRING_ELT(feature_names, i, Rf_mkChar(ptr_names[i])); SET_STRING_ELT(feature_names, i, safe_R_mkChar(ptr_names[i], &cont_token));
} }
UNPROTECT(1); UNPROTECT(2);
return feature_names; return feature_names;
R_API_END();
} }
SEXP LGBM_DatasetSaveBinary_R(SEXP handle, SEXP LGBM_DatasetSaveBinary_R(SEXP handle,
SEXP filename) { SEXP filename) {
const char* filename_ptr = CHAR(PROTECT(Rf_asChar(filename)));
R_API_BEGIN(); R_API_BEGIN();
const char* filename_ptr = CHAR(PROTECT(Rf_asChar(filename)));
CHECK_CALL(LGBM_DatasetSaveBinary(R_ExternalPtrAddr(handle), CHECK_CALL(LGBM_DatasetSaveBinary(R_ExternalPtrAddr(handle),
filename_ptr)); filename_ptr));
R_API_END();
UNPROTECT(1); UNPROTECT(1);
return R_NilValue; return R_NilValue;
R_API_END();
} }
SEXP LGBM_DatasetFree_R(SEXP handle) { SEXP LGBM_DatasetFree_R(SEXP handle) {
...@@ -233,17 +272,17 @@ SEXP LGBM_DatasetFree_R(SEXP handle) { ...@@ -233,17 +272,17 @@ SEXP LGBM_DatasetFree_R(SEXP handle) {
CHECK_CALL(LGBM_DatasetFree(R_ExternalPtrAddr(handle))); CHECK_CALL(LGBM_DatasetFree(R_ExternalPtrAddr(handle)));
R_ClearExternalPtr(handle); R_ClearExternalPtr(handle);
} }
R_API_END();
return R_NilValue; return R_NilValue;
R_API_END();
} }
SEXP LGBM_DatasetSetField_R(SEXP handle, SEXP LGBM_DatasetSetField_R(SEXP handle,
SEXP field_name, SEXP field_name,
SEXP field_data, SEXP field_data,
SEXP num_element) { SEXP num_element) {
R_API_BEGIN();
int len = Rf_asInteger(num_element); int len = Rf_asInteger(num_element);
const char* name = CHAR(PROTECT(Rf_asChar(field_name))); const char* name = CHAR(PROTECT(Rf_asChar(field_name)));
R_API_BEGIN();
if (!strcmp("group", name) || !strcmp("query", name)) { if (!strcmp("group", name) || !strcmp("query", name)) {
std::vector<int32_t> vec(len); std::vector<int32_t> vec(len);
#pragma omp parallel for schedule(static, 512) if (len >= 1024) #pragma omp parallel for schedule(static, 512) if (len >= 1024)
...@@ -261,19 +300,19 @@ SEXP LGBM_DatasetSetField_R(SEXP handle, ...@@ -261,19 +300,19 @@ SEXP LGBM_DatasetSetField_R(SEXP handle,
} }
CHECK_CALL(LGBM_DatasetSetField(R_ExternalPtrAddr(handle), name, vec.data(), len, C_API_DTYPE_FLOAT32)); CHECK_CALL(LGBM_DatasetSetField(R_ExternalPtrAddr(handle), name, vec.data(), len, C_API_DTYPE_FLOAT32));
} }
R_API_END();
UNPROTECT(1); UNPROTECT(1);
return R_NilValue; return R_NilValue;
R_API_END();
} }
SEXP LGBM_DatasetGetField_R(SEXP handle, SEXP LGBM_DatasetGetField_R(SEXP handle,
SEXP field_name, SEXP field_name,
SEXP field_data) { SEXP field_data) {
R_API_BEGIN();
const char* name = CHAR(PROTECT(Rf_asChar(field_name))); const char* name = CHAR(PROTECT(Rf_asChar(field_name)));
int out_len = 0; int out_len = 0;
int out_type = 0; int out_type = 0;
const void* res; const void* res;
R_API_BEGIN();
CHECK_CALL(LGBM_DatasetGetField(R_ExternalPtrAddr(handle), name, &out_len, &res, &out_type)); CHECK_CALL(LGBM_DatasetGetField(R_ExternalPtrAddr(handle), name, &out_len, &res, &out_type));
if (!strcmp("group", name) || !strcmp("query", name)) { if (!strcmp("group", name) || !strcmp("query", name)) {
auto p_data = reinterpret_cast<const int32_t*>(res); auto p_data = reinterpret_cast<const int32_t*>(res);
...@@ -295,57 +334,57 @@ SEXP LGBM_DatasetGetField_R(SEXP handle, ...@@ -295,57 +334,57 @@ SEXP LGBM_DatasetGetField_R(SEXP handle,
REAL(field_data)[i] = p_data[i]; REAL(field_data)[i] = p_data[i];
} }
} }
R_API_END();
UNPROTECT(1); UNPROTECT(1);
return R_NilValue; return R_NilValue;
R_API_END();
} }
SEXP LGBM_DatasetGetFieldSize_R(SEXP handle, SEXP LGBM_DatasetGetFieldSize_R(SEXP handle,
SEXP field_name, SEXP field_name,
SEXP out) { SEXP out) {
R_API_BEGIN();
const char* name = CHAR(PROTECT(Rf_asChar(field_name))); const char* name = CHAR(PROTECT(Rf_asChar(field_name)));
int out_len = 0; int out_len = 0;
int out_type = 0; int out_type = 0;
const void* res; const void* res;
R_API_BEGIN();
CHECK_CALL(LGBM_DatasetGetField(R_ExternalPtrAddr(handle), name, &out_len, &res, &out_type)); CHECK_CALL(LGBM_DatasetGetField(R_ExternalPtrAddr(handle), name, &out_len, &res, &out_type));
if (!strcmp("group", name) || !strcmp("query", name)) { if (!strcmp("group", name) || !strcmp("query", name)) {
out_len -= 1; out_len -= 1;
} }
INTEGER(out)[0] = out_len; INTEGER(out)[0] = out_len;
R_API_END();
UNPROTECT(1); UNPROTECT(1);
return R_NilValue; return R_NilValue;
R_API_END();
} }
SEXP LGBM_DatasetUpdateParamChecking_R(SEXP old_params, SEXP LGBM_DatasetUpdateParamChecking_R(SEXP old_params,
SEXP new_params) { SEXP new_params) {
R_API_BEGIN();
const char* old_params_ptr = CHAR(PROTECT(Rf_asChar(old_params))); const char* old_params_ptr = CHAR(PROTECT(Rf_asChar(old_params)));
const char* new_params_ptr = CHAR(PROTECT(Rf_asChar(new_params))); const char* new_params_ptr = CHAR(PROTECT(Rf_asChar(new_params)));
R_API_BEGIN();
CHECK_CALL(LGBM_DatasetUpdateParamChecking(old_params_ptr, new_params_ptr)); CHECK_CALL(LGBM_DatasetUpdateParamChecking(old_params_ptr, new_params_ptr));
R_API_END();
UNPROTECT(2); UNPROTECT(2);
return R_NilValue; return R_NilValue;
R_API_END();
} }
SEXP LGBM_DatasetGetNumData_R(SEXP handle, SEXP out) { SEXP LGBM_DatasetGetNumData_R(SEXP handle, SEXP out) {
int nrow;
R_API_BEGIN(); R_API_BEGIN();
int nrow;
CHECK_CALL(LGBM_DatasetGetNumData(R_ExternalPtrAddr(handle), &nrow)); CHECK_CALL(LGBM_DatasetGetNumData(R_ExternalPtrAddr(handle), &nrow));
INTEGER(out)[0] = nrow; INTEGER(out)[0] = nrow;
R_API_END();
return R_NilValue; return R_NilValue;
R_API_END();
} }
SEXP LGBM_DatasetGetNumFeature_R(SEXP handle, SEXP LGBM_DatasetGetNumFeature_R(SEXP handle,
SEXP out) { SEXP out) {
int nfeature;
R_API_BEGIN(); R_API_BEGIN();
int nfeature;
CHECK_CALL(LGBM_DatasetGetNumFeature(R_ExternalPtrAddr(handle), &nfeature)); CHECK_CALL(LGBM_DatasetGetNumFeature(R_ExternalPtrAddr(handle), &nfeature));
INTEGER(out)[0] = nfeature; INTEGER(out)[0] = nfeature;
R_API_END();
return R_NilValue; return R_NilValue;
R_API_END();
} }
// --- start Booster interfaces // --- start Booster interfaces
...@@ -360,110 +399,110 @@ SEXP LGBM_BoosterFree_R(SEXP handle) { ...@@ -360,110 +399,110 @@ SEXP LGBM_BoosterFree_R(SEXP handle) {
CHECK_CALL(LGBM_BoosterFree(R_ExternalPtrAddr(handle))); CHECK_CALL(LGBM_BoosterFree(R_ExternalPtrAddr(handle)));
R_ClearExternalPtr(handle); R_ClearExternalPtr(handle);
} }
R_API_END();
return R_NilValue; return R_NilValue;
R_API_END();
} }
SEXP LGBM_BoosterCreate_R(SEXP train_data, SEXP LGBM_BoosterCreate_R(SEXP train_data,
SEXP parameters) { SEXP parameters) {
R_API_BEGIN();
SEXP ret; SEXP ret;
const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters))); const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters)));
BoosterHandle handle = nullptr; BoosterHandle handle = nullptr;
R_API_BEGIN();
CHECK_CALL(LGBM_BoosterCreate(R_ExternalPtrAddr(train_data), parameters_ptr, &handle)); CHECK_CALL(LGBM_BoosterCreate(R_ExternalPtrAddr(train_data), parameters_ptr, &handle));
R_API_END();
ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue)); ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE); R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE);
UNPROTECT(2); UNPROTECT(2);
return ret; return ret;
R_API_END();
} }
SEXP LGBM_BoosterCreateFromModelfile_R(SEXP filename) { SEXP LGBM_BoosterCreateFromModelfile_R(SEXP filename) {
R_API_BEGIN();
SEXP ret; SEXP ret;
int out_num_iterations = 0; int out_num_iterations = 0;
const char* filename_ptr = CHAR(PROTECT(Rf_asChar(filename))); const char* filename_ptr = CHAR(PROTECT(Rf_asChar(filename)));
BoosterHandle handle = nullptr; BoosterHandle handle = nullptr;
R_API_BEGIN();
CHECK_CALL(LGBM_BoosterCreateFromModelfile(filename_ptr, &out_num_iterations, &handle)); CHECK_CALL(LGBM_BoosterCreateFromModelfile(filename_ptr, &out_num_iterations, &handle));
R_API_END();
ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue)); ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE); R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE);
UNPROTECT(2); UNPROTECT(2);
return ret; return ret;
R_API_END();
} }
SEXP LGBM_BoosterLoadModelFromString_R(SEXP model_str) { SEXP LGBM_BoosterLoadModelFromString_R(SEXP model_str) {
R_API_BEGIN();
SEXP ret; SEXP ret;
int out_num_iterations = 0; int out_num_iterations = 0;
const char* model_str_ptr = CHAR(PROTECT(Rf_asChar(model_str))); const char* model_str_ptr = CHAR(PROTECT(Rf_asChar(model_str)));
BoosterHandle handle = nullptr; BoosterHandle handle = nullptr;
R_API_BEGIN();
CHECK_CALL(LGBM_BoosterLoadModelFromString(model_str_ptr, &out_num_iterations, &handle)); CHECK_CALL(LGBM_BoosterLoadModelFromString(model_str_ptr, &out_num_iterations, &handle));
R_API_END();
ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue)); ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE); R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE);
UNPROTECT(2); UNPROTECT(2);
return ret; return ret;
R_API_END();
} }
SEXP LGBM_BoosterMerge_R(SEXP handle, SEXP LGBM_BoosterMerge_R(SEXP handle,
SEXP other_handle) { SEXP other_handle) {
R_API_BEGIN(); R_API_BEGIN();
CHECK_CALL(LGBM_BoosterMerge(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(other_handle))); CHECK_CALL(LGBM_BoosterMerge(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(other_handle)));
R_API_END();
return R_NilValue; return R_NilValue;
R_API_END();
} }
SEXP LGBM_BoosterAddValidData_R(SEXP handle, SEXP LGBM_BoosterAddValidData_R(SEXP handle,
SEXP valid_data) { SEXP valid_data) {
R_API_BEGIN(); R_API_BEGIN();
CHECK_CALL(LGBM_BoosterAddValidData(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(valid_data))); CHECK_CALL(LGBM_BoosterAddValidData(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(valid_data)));
R_API_END();
return R_NilValue; return R_NilValue;
R_API_END();
} }
SEXP LGBM_BoosterResetTrainingData_R(SEXP handle, SEXP LGBM_BoosterResetTrainingData_R(SEXP handle,
SEXP train_data) { SEXP train_data) {
R_API_BEGIN(); R_API_BEGIN();
CHECK_CALL(LGBM_BoosterResetTrainingData(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(train_data))); CHECK_CALL(LGBM_BoosterResetTrainingData(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(train_data)));
R_API_END();
return R_NilValue; return R_NilValue;
R_API_END();
} }
SEXP LGBM_BoosterResetParameter_R(SEXP handle, SEXP LGBM_BoosterResetParameter_R(SEXP handle,
SEXP parameters) { SEXP parameters) {
const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters)));
R_API_BEGIN(); R_API_BEGIN();
const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters)));
CHECK_CALL(LGBM_BoosterResetParameter(R_ExternalPtrAddr(handle), parameters_ptr)); CHECK_CALL(LGBM_BoosterResetParameter(R_ExternalPtrAddr(handle), parameters_ptr));
R_API_END();
UNPROTECT(1); UNPROTECT(1);
return R_NilValue; return R_NilValue;
R_API_END();
} }
SEXP LGBM_BoosterGetNumClasses_R(SEXP handle, SEXP LGBM_BoosterGetNumClasses_R(SEXP handle,
SEXP out) { SEXP out) {
int num_class;
R_API_BEGIN(); R_API_BEGIN();
int num_class;
CHECK_CALL(LGBM_BoosterGetNumClasses(R_ExternalPtrAddr(handle), &num_class)); CHECK_CALL(LGBM_BoosterGetNumClasses(R_ExternalPtrAddr(handle), &num_class));
INTEGER(out)[0] = num_class; INTEGER(out)[0] = num_class;
R_API_END();
return R_NilValue; return R_NilValue;
R_API_END();
} }
SEXP LGBM_BoosterUpdateOneIter_R(SEXP handle) { SEXP LGBM_BoosterUpdateOneIter_R(SEXP handle) {
int is_finished = 0;
R_API_BEGIN(); R_API_BEGIN();
int is_finished = 0;
CHECK_CALL(LGBM_BoosterUpdateOneIter(R_ExternalPtrAddr(handle), &is_finished)); CHECK_CALL(LGBM_BoosterUpdateOneIter(R_ExternalPtrAddr(handle), &is_finished));
R_API_END();
return R_NilValue; return R_NilValue;
R_API_END();
} }
SEXP LGBM_BoosterUpdateOneIterCustom_R(SEXP handle, SEXP LGBM_BoosterUpdateOneIterCustom_R(SEXP handle,
SEXP grad, SEXP grad,
SEXP hess, SEXP hess,
SEXP len) { SEXP len) {
int is_finished = 0;
R_API_BEGIN(); R_API_BEGIN();
int is_finished = 0;
int int_len = Rf_asInteger(len); int int_len = Rf_asInteger(len);
std::vector<float> tgrad(int_len), thess(int_len); std::vector<float> tgrad(int_len), thess(int_len);
#pragma omp parallel for schedule(static, 512) if (int_len >= 1024) #pragma omp parallel for schedule(static, 512) if (int_len >= 1024)
...@@ -472,25 +511,25 @@ SEXP LGBM_BoosterUpdateOneIterCustom_R(SEXP handle, ...@@ -472,25 +511,25 @@ SEXP LGBM_BoosterUpdateOneIterCustom_R(SEXP handle,
thess[j] = static_cast<float>(REAL(hess)[j]); thess[j] = static_cast<float>(REAL(hess)[j]);
} }
CHECK_CALL(LGBM_BoosterUpdateOneIterCustom(R_ExternalPtrAddr(handle), tgrad.data(), thess.data(), &is_finished)); CHECK_CALL(LGBM_BoosterUpdateOneIterCustom(R_ExternalPtrAddr(handle), tgrad.data(), thess.data(), &is_finished));
R_API_END();
return R_NilValue; return R_NilValue;
R_API_END();
} }
SEXP LGBM_BoosterRollbackOneIter_R(SEXP handle) { SEXP LGBM_BoosterRollbackOneIter_R(SEXP handle) {
R_API_BEGIN(); R_API_BEGIN();
CHECK_CALL(LGBM_BoosterRollbackOneIter(R_ExternalPtrAddr(handle))); CHECK_CALL(LGBM_BoosterRollbackOneIter(R_ExternalPtrAddr(handle)));
R_API_END();
return R_NilValue; return R_NilValue;
R_API_END();
} }
SEXP LGBM_BoosterGetCurrentIteration_R(SEXP handle, SEXP LGBM_BoosterGetCurrentIteration_R(SEXP handle,
SEXP out) { SEXP out) {
int out_iteration;
R_API_BEGIN(); R_API_BEGIN();
int out_iteration;
CHECK_CALL(LGBM_BoosterGetCurrentIteration(R_ExternalPtrAddr(handle), &out_iteration)); CHECK_CALL(LGBM_BoosterGetCurrentIteration(R_ExternalPtrAddr(handle), &out_iteration));
INTEGER(out)[0] = out_iteration; INTEGER(out)[0] = out_iteration;
R_API_END();
return R_NilValue; return R_NilValue;
R_API_END();
} }
SEXP LGBM_BoosterGetUpperBoundValue_R(SEXP handle, SEXP LGBM_BoosterGetUpperBoundValue_R(SEXP handle,
...@@ -498,8 +537,8 @@ SEXP LGBM_BoosterGetUpperBoundValue_R(SEXP handle, ...@@ -498,8 +537,8 @@ SEXP LGBM_BoosterGetUpperBoundValue_R(SEXP handle,
R_API_BEGIN(); R_API_BEGIN();
double* ptr_ret = REAL(out_result); double* ptr_ret = REAL(out_result);
CHECK_CALL(LGBM_BoosterGetUpperBoundValue(R_ExternalPtrAddr(handle), ptr_ret)); CHECK_CALL(LGBM_BoosterGetUpperBoundValue(R_ExternalPtrAddr(handle), ptr_ret));
R_API_END();
return R_NilValue; return R_NilValue;
R_API_END();
} }
SEXP LGBM_BoosterGetLowerBoundValue_R(SEXP handle, SEXP LGBM_BoosterGetLowerBoundValue_R(SEXP handle,
...@@ -507,16 +546,16 @@ SEXP LGBM_BoosterGetLowerBoundValue_R(SEXP handle, ...@@ -507,16 +546,16 @@ SEXP LGBM_BoosterGetLowerBoundValue_R(SEXP handle,
R_API_BEGIN(); R_API_BEGIN();
double* ptr_ret = REAL(out_result); double* ptr_ret = REAL(out_result);
CHECK_CALL(LGBM_BoosterGetLowerBoundValue(R_ExternalPtrAddr(handle), ptr_ret)); CHECK_CALL(LGBM_BoosterGetLowerBoundValue(R_ExternalPtrAddr(handle), ptr_ret));
R_API_END();
return R_NilValue; return R_NilValue;
R_API_END();
} }
SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) { SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) {
SEXP cont_token = PROTECT(R_MakeUnwindCont());
R_API_BEGIN();
SEXP eval_names; SEXP eval_names;
int len; int len;
R_API_BEGIN();
CHECK_CALL(LGBM_BoosterGetEvalCounts(R_ExternalPtrAddr(handle), &len)); CHECK_CALL(LGBM_BoosterGetEvalCounts(R_ExternalPtrAddr(handle), &len));
R_API_END();
const size_t reserved_string_size = 128; const size_t reserved_string_size = 128;
std::vector<std::vector<char>> names(len); std::vector<std::vector<char>> names(len);
std::vector<char*> ptr_names(len); std::vector<char*> ptr_names(len);
...@@ -527,14 +566,12 @@ SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) { ...@@ -527,14 +566,12 @@ SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) {
int out_len; int out_len;
size_t required_string_size; size_t required_string_size;
R_API_BEGIN();
CHECK_CALL( CHECK_CALL(
LGBM_BoosterGetEvalNames( LGBM_BoosterGetEvalNames(
R_ExternalPtrAddr(handle), R_ExternalPtrAddr(handle),
len, &out_len, len, &out_len,
reserved_string_size, &required_string_size, reserved_string_size, &required_string_size,
ptr_names.data())); ptr_names.data()));
R_API_END();
// if any eval names were larger than allocated size, // if any eval names were larger than allocated size,
// allow for a larger size and try again // allow for a larger size and try again
if (required_string_size > reserved_string_size) { if (required_string_size > reserved_string_size) {
...@@ -542,7 +579,6 @@ SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) { ...@@ -542,7 +579,6 @@ SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) {
names[i].resize(required_string_size); names[i].resize(required_string_size);
ptr_names[i] = names[i].data(); ptr_names[i] = names[i].data();
} }
R_API_BEGIN();
CHECK_CALL( CHECK_CALL(
LGBM_BoosterGetEvalNames( LGBM_BoosterGetEvalNames(
R_ExternalPtrAddr(handle), R_ExternalPtrAddr(handle),
...@@ -551,15 +587,15 @@ SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) { ...@@ -551,15 +587,15 @@ SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) {
required_string_size, required_string_size,
&required_string_size, &required_string_size,
ptr_names.data())); ptr_names.data()));
R_API_END();
} }
CHECK_EQ(out_len, len); CHECK_EQ(out_len, len);
eval_names = PROTECT(Rf_allocVector(STRSXP, len)); eval_names = PROTECT(safe_R_string(static_cast<R_xlen_t>(len), &cont_token));
for (int i = 0; i < len; ++i) { for (int i = 0; i < len; ++i) {
SET_STRING_ELT(eval_names, i, Rf_mkChar(ptr_names[i])); SET_STRING_ELT(eval_names, i, safe_R_mkChar(ptr_names[i], &cont_token));
} }
UNPROTECT(1); UNPROTECT(2);
return eval_names; return eval_names;
R_API_END();
} }
SEXP LGBM_BoosterGetEval_R(SEXP handle, SEXP LGBM_BoosterGetEval_R(SEXP handle,
...@@ -572,8 +608,8 @@ SEXP LGBM_BoosterGetEval_R(SEXP handle, ...@@ -572,8 +608,8 @@ SEXP LGBM_BoosterGetEval_R(SEXP handle,
int out_len; int out_len;
CHECK_CALL(LGBM_BoosterGetEval(R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &out_len, ptr_ret)); CHECK_CALL(LGBM_BoosterGetEval(R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &out_len, ptr_ret));
CHECK_EQ(out_len, len); CHECK_EQ(out_len, len);
R_API_END();
return R_NilValue; return R_NilValue;
R_API_END();
} }
SEXP LGBM_BoosterGetNumPredict_R(SEXP handle, SEXP LGBM_BoosterGetNumPredict_R(SEXP handle,
...@@ -583,8 +619,8 @@ SEXP LGBM_BoosterGetNumPredict_R(SEXP handle, ...@@ -583,8 +619,8 @@ SEXP LGBM_BoosterGetNumPredict_R(SEXP handle,
int64_t len; int64_t len;
CHECK_CALL(LGBM_BoosterGetNumPredict(R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &len)); CHECK_CALL(LGBM_BoosterGetNumPredict(R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &len));
INTEGER(out)[0] = static_cast<int>(len); INTEGER(out)[0] = static_cast<int>(len);
R_API_END();
return R_NilValue; return R_NilValue;
R_API_END();
} }
SEXP LGBM_BoosterGetPredict_R(SEXP handle, SEXP LGBM_BoosterGetPredict_R(SEXP handle,
...@@ -594,8 +630,8 @@ SEXP LGBM_BoosterGetPredict_R(SEXP handle, ...@@ -594,8 +630,8 @@ SEXP LGBM_BoosterGetPredict_R(SEXP handle,
double* ptr_ret = REAL(out_result); double* ptr_ret = REAL(out_result);
int64_t out_len; int64_t out_len;
CHECK_CALL(LGBM_BoosterGetPredict(R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &out_len, ptr_ret)); CHECK_CALL(LGBM_BoosterGetPredict(R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &out_len, ptr_ret));
R_API_END();
return R_NilValue; return R_NilValue;
R_API_END();
} }
int GetPredictType(SEXP is_rawscore, SEXP is_leafidx, SEXP is_predcontrib) { int GetPredictType(SEXP is_rawscore, SEXP is_leafidx, SEXP is_predcontrib) {
...@@ -622,17 +658,17 @@ SEXP LGBM_BoosterPredictForFile_R(SEXP handle, ...@@ -622,17 +658,17 @@ SEXP LGBM_BoosterPredictForFile_R(SEXP handle,
SEXP num_iteration, SEXP num_iteration,
SEXP parameter, SEXP parameter,
SEXP result_filename) { SEXP result_filename) {
R_API_BEGIN();
const char* data_filename_ptr = CHAR(PROTECT(Rf_asChar(data_filename))); const char* data_filename_ptr = CHAR(PROTECT(Rf_asChar(data_filename)));
const char* parameter_ptr = CHAR(PROTECT(Rf_asChar(parameter))); const char* parameter_ptr = CHAR(PROTECT(Rf_asChar(parameter)));
const char* result_filename_ptr = CHAR(PROTECT(Rf_asChar(result_filename))); const char* result_filename_ptr = CHAR(PROTECT(Rf_asChar(result_filename)));
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib); int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
R_API_BEGIN();
CHECK_CALL(LGBM_BoosterPredictForFile(R_ExternalPtrAddr(handle), data_filename_ptr, CHECK_CALL(LGBM_BoosterPredictForFile(R_ExternalPtrAddr(handle), data_filename_ptr,
Rf_asInteger(data_has_header), pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), parameter_ptr, Rf_asInteger(data_has_header), pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), parameter_ptr,
result_filename_ptr)); result_filename_ptr));
R_API_END();
UNPROTECT(3); UNPROTECT(3);
return R_NilValue; return R_NilValue;
R_API_END();
} }
SEXP LGBM_BoosterCalcNumPredict_R(SEXP handle, SEXP LGBM_BoosterCalcNumPredict_R(SEXP handle,
...@@ -649,8 +685,8 @@ SEXP LGBM_BoosterCalcNumPredict_R(SEXP handle, ...@@ -649,8 +685,8 @@ SEXP LGBM_BoosterCalcNumPredict_R(SEXP handle,
CHECK_CALL(LGBM_BoosterCalcNumPredict(R_ExternalPtrAddr(handle), Rf_asInteger(num_row), CHECK_CALL(LGBM_BoosterCalcNumPredict(R_ExternalPtrAddr(handle), Rf_asInteger(num_row),
pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), &len)); pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), &len));
INTEGER(out_len)[0] = static_cast<int>(len); INTEGER(out_len)[0] = static_cast<int>(len);
R_API_END();
return R_NilValue; return R_NilValue;
R_API_END();
} }
SEXP LGBM_BoosterPredictForCSC_R(SEXP handle, SEXP LGBM_BoosterPredictForCSC_R(SEXP handle,
...@@ -667,6 +703,7 @@ SEXP LGBM_BoosterPredictForCSC_R(SEXP handle, ...@@ -667,6 +703,7 @@ SEXP LGBM_BoosterPredictForCSC_R(SEXP handle,
SEXP num_iteration, SEXP num_iteration,
SEXP parameter, SEXP parameter,
SEXP out_result) { SEXP out_result) {
R_API_BEGIN();
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib); int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
const int* p_indptr = INTEGER(indptr); const int* p_indptr = INTEGER(indptr);
const int32_t* p_indices = reinterpret_cast<const int32_t*>(INTEGER(indices)); const int32_t* p_indices = reinterpret_cast<const int32_t*>(INTEGER(indices));
...@@ -677,14 +714,13 @@ SEXP LGBM_BoosterPredictForCSC_R(SEXP handle, ...@@ -677,14 +714,13 @@ SEXP LGBM_BoosterPredictForCSC_R(SEXP handle,
double* ptr_ret = REAL(out_result); double* ptr_ret = REAL(out_result);
int64_t out_len; int64_t out_len;
const char* parameter_ptr = CHAR(PROTECT(Rf_asChar(parameter))); const char* parameter_ptr = CHAR(PROTECT(Rf_asChar(parameter)));
R_API_BEGIN();
CHECK_CALL(LGBM_BoosterPredictForCSC(R_ExternalPtrAddr(handle), CHECK_CALL(LGBM_BoosterPredictForCSC(R_ExternalPtrAddr(handle),
p_indptr, C_API_DTYPE_INT32, p_indices, p_indptr, C_API_DTYPE_INT32, p_indices,
p_data, C_API_DTYPE_FLOAT64, nindptr, ndata, p_data, C_API_DTYPE_FLOAT64, nindptr, ndata,
nrow, pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), parameter_ptr, &out_len, ptr_ret)); nrow, pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), parameter_ptr, &out_len, ptr_ret));
R_API_END();
UNPROTECT(1); UNPROTECT(1);
return R_NilValue; return R_NilValue;
R_API_END();
} }
SEXP LGBM_BoosterPredictForMat_R(SEXP handle, SEXP LGBM_BoosterPredictForMat_R(SEXP handle,
...@@ -698,6 +734,7 @@ SEXP LGBM_BoosterPredictForMat_R(SEXP handle, ...@@ -698,6 +734,7 @@ SEXP LGBM_BoosterPredictForMat_R(SEXP handle,
SEXP num_iteration, SEXP num_iteration,
SEXP parameter, SEXP parameter,
SEXP out_result) { SEXP out_result) {
R_API_BEGIN();
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib); int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
int32_t nrow = static_cast<int32_t>(Rf_asInteger(num_row)); int32_t nrow = static_cast<int32_t>(Rf_asInteger(num_row));
int32_t ncol = static_cast<int32_t>(Rf_asInteger(num_col)); int32_t ncol = static_cast<int32_t>(Rf_asInteger(num_col));
...@@ -705,75 +742,72 @@ SEXP LGBM_BoosterPredictForMat_R(SEXP handle, ...@@ -705,75 +742,72 @@ SEXP LGBM_BoosterPredictForMat_R(SEXP handle,
double* ptr_ret = REAL(out_result); double* ptr_ret = REAL(out_result);
const char* parameter_ptr = CHAR(PROTECT(Rf_asChar(parameter))); const char* parameter_ptr = CHAR(PROTECT(Rf_asChar(parameter)));
int64_t out_len; int64_t out_len;
R_API_BEGIN();
CHECK_CALL(LGBM_BoosterPredictForMat(R_ExternalPtrAddr(handle), CHECK_CALL(LGBM_BoosterPredictForMat(R_ExternalPtrAddr(handle),
p_mat, C_API_DTYPE_FLOAT64, nrow, ncol, COL_MAJOR, p_mat, C_API_DTYPE_FLOAT64, nrow, ncol, COL_MAJOR,
pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), parameter_ptr, &out_len, ptr_ret)); pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), parameter_ptr, &out_len, ptr_ret));
R_API_END();
UNPROTECT(1); UNPROTECT(1);
return R_NilValue; return R_NilValue;
R_API_END();
} }
SEXP LGBM_BoosterSaveModel_R(SEXP handle, SEXP LGBM_BoosterSaveModel_R(SEXP handle,
SEXP num_iteration, SEXP num_iteration,
SEXP feature_importance_type, SEXP feature_importance_type,
SEXP filename) { SEXP filename) {
const char* filename_ptr = CHAR(PROTECT(Rf_asChar(filename)));
R_API_BEGIN(); R_API_BEGIN();
const char* filename_ptr = CHAR(PROTECT(Rf_asChar(filename)));
CHECK_CALL(LGBM_BoosterSaveModel(R_ExternalPtrAddr(handle), 0, Rf_asInteger(num_iteration), Rf_asInteger(feature_importance_type), filename_ptr)); CHECK_CALL(LGBM_BoosterSaveModel(R_ExternalPtrAddr(handle), 0, Rf_asInteger(num_iteration), Rf_asInteger(feature_importance_type), filename_ptr));
R_API_END();
UNPROTECT(1); UNPROTECT(1);
return R_NilValue; return R_NilValue;
R_API_END();
} }
SEXP LGBM_BoosterSaveModelToString_R(SEXP handle, SEXP LGBM_BoosterSaveModelToString_R(SEXP handle,
SEXP num_iteration, SEXP num_iteration,
SEXP feature_importance_type) { SEXP feature_importance_type) {
SEXP cont_token = PROTECT(R_MakeUnwindCont());
R_API_BEGIN();
SEXP model_str; SEXP model_str;
int64_t out_len = 0; int64_t out_len = 0;
int64_t buf_len = 1024 * 1024; int64_t buf_len = 1024 * 1024;
int num_iter = Rf_asInteger(num_iteration); int num_iter = Rf_asInteger(num_iteration);
int importance_type = Rf_asInteger(feature_importance_type); int importance_type = Rf_asInteger(feature_importance_type);
std::vector<char> inner_char_buf(buf_len); std::vector<char> inner_char_buf(buf_len);
R_API_BEGIN();
CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data())); CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data()));
R_API_END();
// if the model string was larger than the initial buffer, allocate a bigger buffer and try again // if the model string was larger than the initial buffer, allocate a bigger buffer and try again
if (out_len > buf_len) { if (out_len > buf_len) {
inner_char_buf.resize(out_len); inner_char_buf.resize(out_len);
R_API_BEGIN();
CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, out_len, &out_len, inner_char_buf.data())); CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, out_len, &out_len, inner_char_buf.data()));
R_API_END();
} }
model_str = PROTECT(Rf_allocVector(STRSXP, 1)); model_str = PROTECT(safe_R_string(static_cast<R_xlen_t>(1), &cont_token));
SET_STRING_ELT(model_str, 0, Rf_mkChar(inner_char_buf.data())); SET_STRING_ELT(model_str, 0, safe_R_mkChar(inner_char_buf.data(), &cont_token));
UNPROTECT(1); UNPROTECT(2);
return model_str; return model_str;
R_API_END();
} }
SEXP LGBM_BoosterDumpModel_R(SEXP handle, SEXP LGBM_BoosterDumpModel_R(SEXP handle,
SEXP num_iteration, SEXP num_iteration,
SEXP feature_importance_type) { SEXP feature_importance_type) {
SEXP cont_token = PROTECT(R_MakeUnwindCont());
R_API_BEGIN();
SEXP model_str; SEXP model_str;
int64_t out_len = 0; int64_t out_len = 0;
int64_t buf_len = 1024 * 1024; int64_t buf_len = 1024 * 1024;
int num_iter = Rf_asInteger(num_iteration); int num_iter = Rf_asInteger(num_iteration);
int importance_type = Rf_asInteger(feature_importance_type); int importance_type = Rf_asInteger(feature_importance_type);
std::vector<char> inner_char_buf(buf_len); std::vector<char> inner_char_buf(buf_len);
R_API_BEGIN();
CHECK_CALL(LGBM_BoosterDumpModel(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data())); CHECK_CALL(LGBM_BoosterDumpModel(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data()));
R_API_END();
// if the model string was larger than the initial buffer, allocate a bigger buffer and try again // if the model string was larger than the initial buffer, allocate a bigger buffer and try again
if (out_len > buf_len) { if (out_len > buf_len) {
inner_char_buf.resize(out_len); inner_char_buf.resize(out_len);
R_API_BEGIN();
CHECK_CALL(LGBM_BoosterDumpModel(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, out_len, &out_len, inner_char_buf.data())); CHECK_CALL(LGBM_BoosterDumpModel(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, out_len, &out_len, inner_char_buf.data()));
R_API_END();
} }
model_str = PROTECT(Rf_allocVector(STRSXP, 1)); model_str = PROTECT(safe_R_string(static_cast<R_xlen_t>(1), &cont_token));
SET_STRING_ELT(model_str, 0, Rf_mkChar(inner_char_buf.data())); SET_STRING_ELT(model_str, 0, safe_R_mkChar(inner_char_buf.data(), &cont_token));
UNPROTECT(1); UNPROTECT(2);
return model_str; return model_str;
R_API_END();
} }
// .Call() calls // .Call() calls
......
...@@ -1779,7 +1779,6 @@ test_that("lgb.train() fit on linearly-relatead data improves when using linear ...@@ -1779,7 +1779,6 @@ test_that("lgb.train() fit on linearly-relatead data improves when using linear
test_that("lgb.train() w/ linear learner fails already-constructed dataset with linear=false", { test_that("lgb.train() w/ linear learner fails already-constructed dataset with linear=false", {
testthat::skip("Skipping this test because it causes issues for valgrind")
set.seed(708L) set.seed(708L)
params <- list( params <- list(
objective = "regression" objective = "regression"
......
...@@ -693,7 +693,6 @@ test_that("Saving a model with different feature importance types works", { ...@@ -693,7 +693,6 @@ test_that("Saving a model with different feature importance types works", {
}) })
test_that("Saving a model with unknown importance type fails", { test_that("Saving a model with unknown importance type fails", {
testthat::skip("Skipping this test because it causes issues for valgrind")
set.seed(708L) set.seed(708L)
data(agaricus.train, package = "lightgbm") data(agaricus.train, package = "lightgbm")
train <- agaricus.train train <- agaricus.train
......
...@@ -18,8 +18,8 @@ ...@@ -18,8 +18,8 @@
#ifdef LGB_R_BUILD #ifdef LGB_R_BUILD
#define R_NO_REMAP #define R_NO_REMAP
#define R_USE_C99_IN_CXX #define R_USE_C99_IN_CXX
#include <R_ext/Error.h>
#include <R_ext/Print.h> #include <R_ext/Print.h>
extern "C" void R_FlushConsole(void);
#endif #endif
namespace LightGBM { namespace LightGBM {
...@@ -124,7 +124,8 @@ class Log { ...@@ -124,7 +124,8 @@ class Log {
fprintf(stderr, "[LightGBM] [Fatal] %s\n", str_buf); fprintf(stderr, "[LightGBM] [Fatal] %s\n", str_buf);
fflush(stderr); fflush(stderr);
#else #else
Rf_error("[LightGBM] [Fatal] %s\n", str_buf); REprintf("[LightGBM] [Fatal] %s\n", str_buf);
R_FlushConsole();
#endif #endif
throw std::runtime_error(std::string(str_buf)); throw std::runtime_error(std::string(str_buf));
} }
...@@ -154,6 +155,7 @@ class Log { ...@@ -154,6 +155,7 @@ class Log {
Rprintf("[LightGBM] [%s] ", level_str); Rprintf("[LightGBM] [%s] ", level_str);
Rvprintf(format, val); Rvprintf(format, val);
Rprintf("\n"); Rprintf("\n");
R_FlushConsole();
#endif #endif
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment