Unverified commit bb88d92e, authored by James Lamb, committed by GitHub
Browse files

[R-package] Use R standard routines to access numeric and integer array data in C++ (#4247)

* real pointer for matrix

* remove R_REAL_PTR

* remove R_INT_PTR

* add test
parent aedfdd0d
...@@ -98,10 +98,6 @@ typedef union { VECTOR_SER s; double align; } SEXPREC_ALIGN; ...@@ -98,10 +98,6 @@ typedef union { VECTOR_SER s; double align; } SEXPREC_ALIGN;
#define R_CHAR_PTR(x) (reinterpret_cast<char*>DATAPTR(x)) #define R_CHAR_PTR(x) (reinterpret_cast<char*>DATAPTR(x))
#define R_INT_PTR(x) (reinterpret_cast<int*> DATAPTR(x))
#define R_REAL_PTR(x) (reinterpret_cast<double*> DATAPTR(x))
#define R_IS_NULL(x) ((*reinterpret_cast<LGBM_SE>(x)).sxpinfo.type == 0) #define R_IS_NULL(x) ((*reinterpret_cast<LGBM_SE>(x)).sxpinfo.type == 0)
// 64bit pointer // 64bit pointer
......
...@@ -43,11 +43,11 @@ using LightGBM::Common::Join; ...@@ -43,11 +43,11 @@ using LightGBM::Common::Join;
using LightGBM::Common::Split; using LightGBM::Common::Split;
using LightGBM::Log; using LightGBM::Log;
LGBM_SE EncodeChar(LGBM_SE dest, const char* src, SEXP buf_len, LGBM_SE actual_len, size_t str_len) { LGBM_SE EncodeChar(LGBM_SE dest, const char* src, SEXP buf_len, SEXP actual_len, size_t str_len) {
if (str_len > INT32_MAX) { if (str_len > INT32_MAX) {
Log::Fatal("Don't support large string in R-package"); Log::Fatal("Don't support large string in R-package");
} }
R_INT_PTR(actual_len)[0] = static_cast<int>(str_len); INTEGER(actual_len)[0] = static_cast<int>(str_len);
if (Rf_asInteger(buf_len) < static_cast<int>(str_len)) { if (Rf_asInteger(buf_len) < static_cast<int>(str_len)) {
return dest; return dest;
} }
...@@ -76,9 +76,9 @@ SEXP LGBM_DatasetCreateFromFile_R(LGBM_SE filename, ...@@ -76,9 +76,9 @@ SEXP LGBM_DatasetCreateFromFile_R(LGBM_SE filename,
R_API_END(); R_API_END();
} }
SEXP LGBM_DatasetCreateFromCSC_R(LGBM_SE indptr, SEXP LGBM_DatasetCreateFromCSC_R(SEXP indptr,
LGBM_SE indices, SEXP indices,
LGBM_SE data, SEXP data,
SEXP num_indptr, SEXP num_indptr,
SEXP nelem, SEXP nelem,
SEXP num_row, SEXP num_row,
...@@ -86,9 +86,9 @@ SEXP LGBM_DatasetCreateFromCSC_R(LGBM_SE indptr, ...@@ -86,9 +86,9 @@ SEXP LGBM_DatasetCreateFromCSC_R(LGBM_SE indptr,
LGBM_SE reference, LGBM_SE reference,
LGBM_SE out) { LGBM_SE out) {
R_API_BEGIN(); R_API_BEGIN();
const int* p_indptr = R_INT_PTR(indptr); const int* p_indptr = INTEGER(indptr);
const int* p_indices = R_INT_PTR(indices); const int* p_indices = INTEGER(indices);
const double* p_data = R_REAL_PTR(data); const double* p_data = REAL(data);
int64_t nindptr = static_cast<int64_t>(Rf_asInteger(num_indptr)); int64_t nindptr = static_cast<int64_t>(Rf_asInteger(num_indptr));
int64_t ndata = static_cast<int64_t>(Rf_asInteger(nelem)); int64_t ndata = static_cast<int64_t>(Rf_asInteger(nelem));
...@@ -101,7 +101,7 @@ SEXP LGBM_DatasetCreateFromCSC_R(LGBM_SE indptr, ...@@ -101,7 +101,7 @@ SEXP LGBM_DatasetCreateFromCSC_R(LGBM_SE indptr,
R_API_END(); R_API_END();
} }
SEXP LGBM_DatasetCreateFromMat_R(LGBM_SE data, SEXP LGBM_DatasetCreateFromMat_R(SEXP data,
SEXP num_row, SEXP num_row,
SEXP num_col, SEXP num_col,
LGBM_SE parameters, LGBM_SE parameters,
...@@ -110,7 +110,7 @@ SEXP LGBM_DatasetCreateFromMat_R(LGBM_SE data, ...@@ -110,7 +110,7 @@ SEXP LGBM_DatasetCreateFromMat_R(LGBM_SE data,
R_API_BEGIN(); R_API_BEGIN();
int32_t nrow = static_cast<int32_t>(Rf_asInteger(num_row)); int32_t nrow = static_cast<int32_t>(Rf_asInteger(num_row));
int32_t ncol = static_cast<int32_t>(Rf_asInteger(num_col)); int32_t ncol = static_cast<int32_t>(Rf_asInteger(num_col));
double* p_mat = R_REAL_PTR(data); double* p_mat = REAL(data);
DatasetHandle handle = nullptr; DatasetHandle handle = nullptr;
CHECK_CALL(LGBM_DatasetCreateFromMat(p_mat, C_API_DTYPE_FLOAT64, nrow, ncol, COL_MAJOR, CHECK_CALL(LGBM_DatasetCreateFromMat(p_mat, C_API_DTYPE_FLOAT64, nrow, ncol, COL_MAJOR,
R_CHAR_PTR(parameters), R_GET_PTR(reference), &handle)); R_CHAR_PTR(parameters), R_GET_PTR(reference), &handle));
...@@ -119,7 +119,7 @@ SEXP LGBM_DatasetCreateFromMat_R(LGBM_SE data, ...@@ -119,7 +119,7 @@ SEXP LGBM_DatasetCreateFromMat_R(LGBM_SE data,
} }
SEXP LGBM_DatasetGetSubset_R(LGBM_SE handle, SEXP LGBM_DatasetGetSubset_R(LGBM_SE handle,
LGBM_SE used_row_indices, SEXP used_row_indices,
SEXP len_used_row_indices, SEXP len_used_row_indices,
LGBM_SE parameters, LGBM_SE parameters,
LGBM_SE out) { LGBM_SE out) {
...@@ -129,7 +129,7 @@ SEXP LGBM_DatasetGetSubset_R(LGBM_SE handle, ...@@ -129,7 +129,7 @@ SEXP LGBM_DatasetGetSubset_R(LGBM_SE handle,
// convert from one-based to zero-based index // convert from one-based to zero-based index
#pragma omp parallel for schedule(static, 512) if (len >= 1024) #pragma omp parallel for schedule(static, 512) if (len >= 1024)
for (int i = 0; i < len; ++i) { for (int i = 0; i < len; ++i) {
idxvec[i] = R_INT_PTR(used_row_indices)[i] - 1; idxvec[i] = INTEGER(used_row_indices)[i] - 1;
} }
DatasetHandle res = nullptr; DatasetHandle res = nullptr;
CHECK_CALL(LGBM_DatasetGetSubset(R_GET_PTR(handle), CHECK_CALL(LGBM_DatasetGetSubset(R_GET_PTR(handle),
...@@ -155,7 +155,7 @@ SEXP LGBM_DatasetSetFeatureNames_R(LGBM_SE handle, ...@@ -155,7 +155,7 @@ SEXP LGBM_DatasetSetFeatureNames_R(LGBM_SE handle,
SEXP LGBM_DatasetGetFeatureNames_R(LGBM_SE handle, SEXP LGBM_DatasetGetFeatureNames_R(LGBM_SE handle,
SEXP buf_len, SEXP buf_len,
LGBM_SE actual_len, SEXP actual_len,
LGBM_SE feature_names) { LGBM_SE feature_names) {
R_API_BEGIN(); R_API_BEGIN();
int len = 0; int len = 0;
...@@ -201,7 +201,7 @@ SEXP LGBM_DatasetFree_R(LGBM_SE handle) { ...@@ -201,7 +201,7 @@ SEXP LGBM_DatasetFree_R(LGBM_SE handle) {
SEXP LGBM_DatasetSetField_R(LGBM_SE handle, SEXP LGBM_DatasetSetField_R(LGBM_SE handle,
LGBM_SE field_name, LGBM_SE field_name,
LGBM_SE field_data, SEXP field_data,
SEXP num_element) { SEXP num_element) {
R_API_BEGIN(); R_API_BEGIN();
int len = static_cast<int>(Rf_asInteger(num_element)); int len = static_cast<int>(Rf_asInteger(num_element));
...@@ -210,16 +210,16 @@ SEXP LGBM_DatasetSetField_R(LGBM_SE handle, ...@@ -210,16 +210,16 @@ SEXP LGBM_DatasetSetField_R(LGBM_SE handle,
std::vector<int32_t> vec(len); std::vector<int32_t> vec(len);
#pragma omp parallel for schedule(static, 512) if (len >= 1024) #pragma omp parallel for schedule(static, 512) if (len >= 1024)
for (int i = 0; i < len; ++i) { for (int i = 0; i < len; ++i) {
vec[i] = static_cast<int32_t>(R_INT_PTR(field_data)[i]); vec[i] = static_cast<int32_t>(INTEGER(field_data)[i]);
} }
CHECK_CALL(LGBM_DatasetSetField(R_GET_PTR(handle), name, vec.data(), len, C_API_DTYPE_INT32)); CHECK_CALL(LGBM_DatasetSetField(R_GET_PTR(handle), name, vec.data(), len, C_API_DTYPE_INT32));
} else if (!strcmp("init_score", name)) { } else if (!strcmp("init_score", name)) {
CHECK_CALL(LGBM_DatasetSetField(R_GET_PTR(handle), name, R_REAL_PTR(field_data), len, C_API_DTYPE_FLOAT64)); CHECK_CALL(LGBM_DatasetSetField(R_GET_PTR(handle), name, REAL(field_data), len, C_API_DTYPE_FLOAT64));
} else { } else {
std::vector<float> vec(len); std::vector<float> vec(len);
#pragma omp parallel for schedule(static, 512) if (len >= 1024) #pragma omp parallel for schedule(static, 512) if (len >= 1024)
for (int i = 0; i < len; ++i) { for (int i = 0; i < len; ++i) {
vec[i] = static_cast<float>(R_REAL_PTR(field_data)[i]); vec[i] = static_cast<float>(REAL(field_data)[i]);
} }
CHECK_CALL(LGBM_DatasetSetField(R_GET_PTR(handle), name, vec.data(), len, C_API_DTYPE_FLOAT32)); CHECK_CALL(LGBM_DatasetSetField(R_GET_PTR(handle), name, vec.data(), len, C_API_DTYPE_FLOAT32));
} }
...@@ -228,7 +228,7 @@ SEXP LGBM_DatasetSetField_R(LGBM_SE handle, ...@@ -228,7 +228,7 @@ SEXP LGBM_DatasetSetField_R(LGBM_SE handle,
SEXP LGBM_DatasetGetField_R(LGBM_SE handle, SEXP LGBM_DatasetGetField_R(LGBM_SE handle,
LGBM_SE field_name, LGBM_SE field_name,
LGBM_SE field_data) { SEXP field_data) {
R_API_BEGIN(); R_API_BEGIN();
const char* name = R_CHAR_PTR(field_name); const char* name = R_CHAR_PTR(field_name);
int out_len = 0; int out_len = 0;
...@@ -241,19 +241,19 @@ SEXP LGBM_DatasetGetField_R(LGBM_SE handle, ...@@ -241,19 +241,19 @@ SEXP LGBM_DatasetGetField_R(LGBM_SE handle,
// convert from boundaries to size // convert from boundaries to size
#pragma omp parallel for schedule(static, 512) if (out_len >= 1024) #pragma omp parallel for schedule(static, 512) if (out_len >= 1024)
for (int i = 0; i < out_len - 1; ++i) { for (int i = 0; i < out_len - 1; ++i) {
R_INT_PTR(field_data)[i] = p_data[i + 1] - p_data[i]; INTEGER(field_data)[i] = p_data[i + 1] - p_data[i];
} }
} else if (!strcmp("init_score", name)) { } else if (!strcmp("init_score", name)) {
auto p_data = reinterpret_cast<const double*>(res); auto p_data = reinterpret_cast<const double*>(res);
#pragma omp parallel for schedule(static, 512) if (out_len >= 1024) #pragma omp parallel for schedule(static, 512) if (out_len >= 1024)
for (int i = 0; i < out_len; ++i) { for (int i = 0; i < out_len; ++i) {
R_REAL_PTR(field_data)[i] = p_data[i]; REAL(field_data)[i] = p_data[i];
} }
} else { } else {
auto p_data = reinterpret_cast<const float*>(res); auto p_data = reinterpret_cast<const float*>(res);
#pragma omp parallel for schedule(static, 512) if (out_len >= 1024) #pragma omp parallel for schedule(static, 512) if (out_len >= 1024)
for (int i = 0; i < out_len; ++i) { for (int i = 0; i < out_len; ++i) {
R_REAL_PTR(field_data)[i] = p_data[i]; REAL(field_data)[i] = p_data[i];
} }
} }
R_API_END(); R_API_END();
...@@ -261,7 +261,7 @@ SEXP LGBM_DatasetGetField_R(LGBM_SE handle, ...@@ -261,7 +261,7 @@ SEXP LGBM_DatasetGetField_R(LGBM_SE handle,
SEXP LGBM_DatasetGetFieldSize_R(LGBM_SE handle, SEXP LGBM_DatasetGetFieldSize_R(LGBM_SE handle,
LGBM_SE field_name, LGBM_SE field_name,
LGBM_SE out) { SEXP out) {
R_API_BEGIN(); R_API_BEGIN();
const char* name = R_CHAR_PTR(field_name); const char* name = R_CHAR_PTR(field_name);
int out_len = 0; int out_len = 0;
...@@ -271,7 +271,7 @@ SEXP LGBM_DatasetGetFieldSize_R(LGBM_SE handle, ...@@ -271,7 +271,7 @@ SEXP LGBM_DatasetGetFieldSize_R(LGBM_SE handle,
if (!strcmp("group", name) || !strcmp("query", name)) { if (!strcmp("group", name) || !strcmp("query", name)) {
out_len -= 1; out_len -= 1;
} }
R_INT_PTR(out)[0] = static_cast<int>(out_len); INTEGER(out)[0] = static_cast<int>(out_len);
R_API_END(); R_API_END();
} }
...@@ -282,20 +282,20 @@ SEXP LGBM_DatasetUpdateParamChecking_R(LGBM_SE old_params, ...@@ -282,20 +282,20 @@ SEXP LGBM_DatasetUpdateParamChecking_R(LGBM_SE old_params,
R_API_END(); R_API_END();
} }
SEXP LGBM_DatasetGetNumData_R(LGBM_SE handle, LGBM_SE out) { SEXP LGBM_DatasetGetNumData_R(LGBM_SE handle, SEXP out) {
int nrow; int nrow;
R_API_BEGIN(); R_API_BEGIN();
CHECK_CALL(LGBM_DatasetGetNumData(R_GET_PTR(handle), &nrow)); CHECK_CALL(LGBM_DatasetGetNumData(R_GET_PTR(handle), &nrow));
R_INT_PTR(out)[0] = static_cast<int>(nrow); INTEGER(out)[0] = static_cast<int>(nrow);
R_API_END(); R_API_END();
} }
SEXP LGBM_DatasetGetNumFeature_R(LGBM_SE handle, SEXP LGBM_DatasetGetNumFeature_R(LGBM_SE handle,
LGBM_SE out) { SEXP out) {
int nfeature; int nfeature;
R_API_BEGIN(); R_API_BEGIN();
CHECK_CALL(LGBM_DatasetGetNumFeature(R_GET_PTR(handle), &nfeature)); CHECK_CALL(LGBM_DatasetGetNumFeature(R_GET_PTR(handle), &nfeature));
R_INT_PTR(out)[0] = static_cast<int>(nfeature); INTEGER(out)[0] = static_cast<int>(nfeature);
R_API_END(); R_API_END();
} }
...@@ -369,11 +369,11 @@ SEXP LGBM_BoosterResetParameter_R(LGBM_SE handle, ...@@ -369,11 +369,11 @@ SEXP LGBM_BoosterResetParameter_R(LGBM_SE handle,
} }
SEXP LGBM_BoosterGetNumClasses_R(LGBM_SE handle, SEXP LGBM_BoosterGetNumClasses_R(LGBM_SE handle,
LGBM_SE out) { SEXP out) {
int num_class; int num_class;
R_API_BEGIN(); R_API_BEGIN();
CHECK_CALL(LGBM_BoosterGetNumClasses(R_GET_PTR(handle), &num_class)); CHECK_CALL(LGBM_BoosterGetNumClasses(R_GET_PTR(handle), &num_class));
R_INT_PTR(out)[0] = static_cast<int>(num_class); INTEGER(out)[0] = static_cast<int>(num_class);
R_API_END(); R_API_END();
} }
...@@ -385,8 +385,8 @@ SEXP LGBM_BoosterUpdateOneIter_R(LGBM_SE handle) { ...@@ -385,8 +385,8 @@ SEXP LGBM_BoosterUpdateOneIter_R(LGBM_SE handle) {
} }
SEXP LGBM_BoosterUpdateOneIterCustom_R(LGBM_SE handle, SEXP LGBM_BoosterUpdateOneIterCustom_R(LGBM_SE handle,
LGBM_SE grad, SEXP grad,
LGBM_SE hess, SEXP hess,
SEXP len) { SEXP len) {
int is_finished = 0; int is_finished = 0;
R_API_BEGIN(); R_API_BEGIN();
...@@ -394,8 +394,8 @@ SEXP LGBM_BoosterUpdateOneIterCustom_R(LGBM_SE handle, ...@@ -394,8 +394,8 @@ SEXP LGBM_BoosterUpdateOneIterCustom_R(LGBM_SE handle,
std::vector<float> tgrad(int_len), thess(int_len); std::vector<float> tgrad(int_len), thess(int_len);
#pragma omp parallel for schedule(static, 512) if (int_len >= 1024) #pragma omp parallel for schedule(static, 512) if (int_len >= 1024)
for (int j = 0; j < int_len; ++j) { for (int j = 0; j < int_len; ++j) {
tgrad[j] = static_cast<float>(R_REAL_PTR(grad)[j]); tgrad[j] = static_cast<float>(REAL(grad)[j]);
thess[j] = static_cast<float>(R_REAL_PTR(hess)[j]); thess[j] = static_cast<float>(REAL(hess)[j]);
} }
CHECK_CALL(LGBM_BoosterUpdateOneIterCustom(R_GET_PTR(handle), tgrad.data(), thess.data(), &is_finished)); CHECK_CALL(LGBM_BoosterUpdateOneIterCustom(R_GET_PTR(handle), tgrad.data(), thess.data(), &is_finished));
R_API_END(); R_API_END();
...@@ -408,33 +408,33 @@ SEXP LGBM_BoosterRollbackOneIter_R(LGBM_SE handle) { ...@@ -408,33 +408,33 @@ SEXP LGBM_BoosterRollbackOneIter_R(LGBM_SE handle) {
} }
SEXP LGBM_BoosterGetCurrentIteration_R(LGBM_SE handle, SEXP LGBM_BoosterGetCurrentIteration_R(LGBM_SE handle,
LGBM_SE out) { SEXP out) {
int out_iteration; int out_iteration;
R_API_BEGIN(); R_API_BEGIN();
CHECK_CALL(LGBM_BoosterGetCurrentIteration(R_GET_PTR(handle), &out_iteration)); CHECK_CALL(LGBM_BoosterGetCurrentIteration(R_GET_PTR(handle), &out_iteration));
R_INT_PTR(out)[0] = static_cast<int>(out_iteration); INTEGER(out)[0] = static_cast<int>(out_iteration);
R_API_END(); R_API_END();
} }
SEXP LGBM_BoosterGetUpperBoundValue_R(LGBM_SE handle, SEXP LGBM_BoosterGetUpperBoundValue_R(LGBM_SE handle,
LGBM_SE out_result) { SEXP out_result) {
R_API_BEGIN(); R_API_BEGIN();
double* ptr_ret = R_REAL_PTR(out_result); double* ptr_ret = REAL(out_result);
CHECK_CALL(LGBM_BoosterGetUpperBoundValue(R_GET_PTR(handle), ptr_ret)); CHECK_CALL(LGBM_BoosterGetUpperBoundValue(R_GET_PTR(handle), ptr_ret));
R_API_END(); R_API_END();
} }
SEXP LGBM_BoosterGetLowerBoundValue_R(LGBM_SE handle, SEXP LGBM_BoosterGetLowerBoundValue_R(LGBM_SE handle,
LGBM_SE out_result) { SEXP out_result) {
R_API_BEGIN(); R_API_BEGIN();
double* ptr_ret = R_REAL_PTR(out_result); double* ptr_ret = REAL(out_result);
CHECK_CALL(LGBM_BoosterGetLowerBoundValue(R_GET_PTR(handle), ptr_ret)); CHECK_CALL(LGBM_BoosterGetLowerBoundValue(R_GET_PTR(handle), ptr_ret));
R_API_END(); R_API_END();
} }
SEXP LGBM_BoosterGetEvalNames_R(LGBM_SE handle, SEXP LGBM_BoosterGetEvalNames_R(LGBM_SE handle,
SEXP buf_len, SEXP buf_len,
LGBM_SE actual_len, SEXP actual_len,
LGBM_SE eval_names) { LGBM_SE eval_names) {
R_API_BEGIN(); R_API_BEGIN();
int len; int len;
...@@ -465,11 +465,11 @@ SEXP LGBM_BoosterGetEvalNames_R(LGBM_SE handle, ...@@ -465,11 +465,11 @@ SEXP LGBM_BoosterGetEvalNames_R(LGBM_SE handle,
SEXP LGBM_BoosterGetEval_R(LGBM_SE handle, SEXP LGBM_BoosterGetEval_R(LGBM_SE handle,
SEXP data_idx, SEXP data_idx,
LGBM_SE out_result) { SEXP out_result) {
R_API_BEGIN(); R_API_BEGIN();
int len; int len;
CHECK_CALL(LGBM_BoosterGetEvalCounts(R_GET_PTR(handle), &len)); CHECK_CALL(LGBM_BoosterGetEvalCounts(R_GET_PTR(handle), &len));
double* ptr_ret = R_REAL_PTR(out_result); double* ptr_ret = REAL(out_result);
int out_len; int out_len;
CHECK_CALL(LGBM_BoosterGetEval(R_GET_PTR(handle), Rf_asInteger(data_idx), &out_len, ptr_ret)); CHECK_CALL(LGBM_BoosterGetEval(R_GET_PTR(handle), Rf_asInteger(data_idx), &out_len, ptr_ret));
CHECK_EQ(out_len, len); CHECK_EQ(out_len, len);
...@@ -478,19 +478,19 @@ SEXP LGBM_BoosterGetEval_R(LGBM_SE handle, ...@@ -478,19 +478,19 @@ SEXP LGBM_BoosterGetEval_R(LGBM_SE handle,
SEXP LGBM_BoosterGetNumPredict_R(LGBM_SE handle, SEXP LGBM_BoosterGetNumPredict_R(LGBM_SE handle,
SEXP data_idx, SEXP data_idx,
LGBM_SE out) { SEXP out) {
R_API_BEGIN(); R_API_BEGIN();
int64_t len; int64_t len;
CHECK_CALL(LGBM_BoosterGetNumPredict(R_GET_PTR(handle), Rf_asInteger(data_idx), &len)); CHECK_CALL(LGBM_BoosterGetNumPredict(R_GET_PTR(handle), Rf_asInteger(data_idx), &len));
R_INT_PTR(out)[0] = static_cast<int>(len); INTEGER(out)[0] = static_cast<int>(len);
R_API_END(); R_API_END();
} }
SEXP LGBM_BoosterGetPredict_R(LGBM_SE handle, SEXP LGBM_BoosterGetPredict_R(LGBM_SE handle,
SEXP data_idx, SEXP data_idx,
LGBM_SE out_result) { SEXP out_result) {
R_API_BEGIN(); R_API_BEGIN();
double* ptr_ret = R_REAL_PTR(out_result); double* ptr_ret = REAL(out_result);
int64_t out_len; int64_t out_len;
CHECK_CALL(LGBM_BoosterGetPredict(R_GET_PTR(handle), Rf_asInteger(data_idx), &out_len, ptr_ret)); CHECK_CALL(LGBM_BoosterGetPredict(R_GET_PTR(handle), Rf_asInteger(data_idx), &out_len, ptr_ret));
R_API_END(); R_API_END();
...@@ -535,20 +535,20 @@ SEXP LGBM_BoosterCalcNumPredict_R(LGBM_SE handle, ...@@ -535,20 +535,20 @@ SEXP LGBM_BoosterCalcNumPredict_R(LGBM_SE handle,
SEXP is_predcontrib, SEXP is_predcontrib,
SEXP start_iteration, SEXP start_iteration,
SEXP num_iteration, SEXP num_iteration,
LGBM_SE out_len) { SEXP out_len) {
R_API_BEGIN(); R_API_BEGIN();
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib); int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
int64_t len = 0; int64_t len = 0;
CHECK_CALL(LGBM_BoosterCalcNumPredict(R_GET_PTR(handle), Rf_asInteger(num_row), CHECK_CALL(LGBM_BoosterCalcNumPredict(R_GET_PTR(handle), Rf_asInteger(num_row),
pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), &len)); pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), &len));
R_INT_PTR(out_len)[0] = static_cast<int>(len); INTEGER(out_len)[0] = static_cast<int>(len);
R_API_END(); R_API_END();
} }
SEXP LGBM_BoosterPredictForCSC_R(LGBM_SE handle, SEXP LGBM_BoosterPredictForCSC_R(LGBM_SE handle,
LGBM_SE indptr, SEXP indptr,
LGBM_SE indices, SEXP indices,
LGBM_SE data, SEXP data,
SEXP num_indptr, SEXP num_indptr,
SEXP nelem, SEXP nelem,
SEXP num_row, SEXP num_row,
...@@ -558,18 +558,18 @@ SEXP LGBM_BoosterPredictForCSC_R(LGBM_SE handle, ...@@ -558,18 +558,18 @@ SEXP LGBM_BoosterPredictForCSC_R(LGBM_SE handle,
SEXP start_iteration, SEXP start_iteration,
SEXP num_iteration, SEXP num_iteration,
LGBM_SE parameter, LGBM_SE parameter,
LGBM_SE out_result) { SEXP out_result) {
R_API_BEGIN(); R_API_BEGIN();
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib); int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
const int* p_indptr = R_INT_PTR(indptr); const int* p_indptr = INTEGER(indptr);
const int* p_indices = R_INT_PTR(indices); const int* p_indices = INTEGER(indices);
const double* p_data = R_REAL_PTR(data); const double* p_data = REAL(data);
int64_t nindptr = Rf_asInteger(num_indptr); int64_t nindptr = Rf_asInteger(num_indptr);
int64_t ndata = Rf_asInteger(nelem); int64_t ndata = Rf_asInteger(nelem);
int64_t nrow = Rf_asInteger(num_row); int64_t nrow = Rf_asInteger(num_row);
double* ptr_ret = R_REAL_PTR(out_result); double* ptr_ret = REAL(out_result);
int64_t out_len; int64_t out_len;
CHECK_CALL(LGBM_BoosterPredictForCSC(R_GET_PTR(handle), CHECK_CALL(LGBM_BoosterPredictForCSC(R_GET_PTR(handle),
p_indptr, C_API_DTYPE_INT32, p_indices, p_indptr, C_API_DTYPE_INT32, p_indices,
...@@ -579,7 +579,7 @@ SEXP LGBM_BoosterPredictForCSC_R(LGBM_SE handle, ...@@ -579,7 +579,7 @@ SEXP LGBM_BoosterPredictForCSC_R(LGBM_SE handle,
} }
SEXP LGBM_BoosterPredictForMat_R(LGBM_SE handle, SEXP LGBM_BoosterPredictForMat_R(LGBM_SE handle,
LGBM_SE data, SEXP data,
SEXP num_row, SEXP num_row,
SEXP num_col, SEXP num_col,
SEXP is_rawscore, SEXP is_rawscore,
...@@ -588,15 +588,15 @@ SEXP LGBM_BoosterPredictForMat_R(LGBM_SE handle, ...@@ -588,15 +588,15 @@ SEXP LGBM_BoosterPredictForMat_R(LGBM_SE handle,
SEXP start_iteration, SEXP start_iteration,
SEXP num_iteration, SEXP num_iteration,
LGBM_SE parameter, LGBM_SE parameter,
LGBM_SE out_result) { SEXP out_result) {
R_API_BEGIN(); R_API_BEGIN();
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib); int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
int32_t nrow = Rf_asInteger(num_row); int32_t nrow = Rf_asInteger(num_row);
int32_t ncol = Rf_asInteger(num_col); int32_t ncol = Rf_asInteger(num_col);
const double* p_mat = R_REAL_PTR(data); const double* p_mat = REAL(data);
double* ptr_ret = R_REAL_PTR(out_result); double* ptr_ret = REAL(out_result);
int64_t out_len; int64_t out_len;
CHECK_CALL(LGBM_BoosterPredictForMat(R_GET_PTR(handle), CHECK_CALL(LGBM_BoosterPredictForMat(R_GET_PTR(handle),
p_mat, C_API_DTYPE_FLOAT64, nrow, ncol, COL_MAJOR, p_mat, C_API_DTYPE_FLOAT64, nrow, ncol, COL_MAJOR,
...@@ -618,7 +618,7 @@ SEXP LGBM_BoosterSaveModelToString_R(LGBM_SE handle, ...@@ -618,7 +618,7 @@ SEXP LGBM_BoosterSaveModelToString_R(LGBM_SE handle,
SEXP num_iteration, SEXP num_iteration,
SEXP feature_importance_type, SEXP feature_importance_type,
SEXP buffer_len, SEXP buffer_len,
LGBM_SE actual_len, SEXP actual_len,
LGBM_SE out_str) { LGBM_SE out_str) {
R_API_BEGIN(); R_API_BEGIN();
int64_t out_len = 0; int64_t out_len = 0;
...@@ -633,7 +633,7 @@ SEXP LGBM_BoosterDumpModel_R(LGBM_SE handle, ...@@ -633,7 +633,7 @@ SEXP LGBM_BoosterDumpModel_R(LGBM_SE handle,
SEXP num_iteration, SEXP num_iteration,
SEXP feature_importance_type, SEXP feature_importance_type,
SEXP buffer_len, SEXP buffer_len,
LGBM_SE actual_len, SEXP actual_len,
LGBM_SE out_str) { LGBM_SE out_str) {
R_API_BEGIN(); R_API_BEGIN();
int64_t out_len = 0; int64_t out_len = 0;
......
...@@ -53,9 +53,9 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetCreateFromFile_R( ...@@ -53,9 +53,9 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetCreateFromFile_R(
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_DatasetCreateFromCSC_R( LIGHTGBM_C_EXPORT SEXP LGBM_DatasetCreateFromCSC_R(
LGBM_SE indptr, SEXP indptr,
LGBM_SE indices, SEXP indices,
LGBM_SE data, SEXP data,
SEXP num_indptr, SEXP num_indptr,
SEXP nelem, SEXP nelem,
SEXP num_row, SEXP num_row,
...@@ -75,7 +75,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetCreateFromCSC_R( ...@@ -75,7 +75,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetCreateFromCSC_R(
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_DatasetCreateFromMat_R( LIGHTGBM_C_EXPORT SEXP LGBM_DatasetCreateFromMat_R(
LGBM_SE data, SEXP data,
SEXP num_row, SEXP num_row,
SEXP num_col, SEXP num_col,
LGBM_SE parameters, LGBM_SE parameters,
...@@ -94,7 +94,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetCreateFromMat_R( ...@@ -94,7 +94,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetCreateFromMat_R(
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_DatasetGetSubset_R( LIGHTGBM_C_EXPORT SEXP LGBM_DatasetGetSubset_R(
LGBM_SE handle, LGBM_SE handle,
LGBM_SE used_row_indices, SEXP used_row_indices,
SEXP len_used_row_indices, SEXP len_used_row_indices,
LGBM_SE parameters, LGBM_SE parameters,
LGBM_SE out LGBM_SE out
...@@ -120,7 +120,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetSetFeatureNames_R( ...@@ -120,7 +120,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetSetFeatureNames_R(
LIGHTGBM_C_EXPORT SEXP LGBM_DatasetGetFeatureNames_R( LIGHTGBM_C_EXPORT SEXP LGBM_DatasetGetFeatureNames_R(
LGBM_SE handle, LGBM_SE handle,
SEXP buf_len, SEXP buf_len,
LGBM_SE actual_len, SEXP actual_len,
LGBM_SE feature_names LGBM_SE feature_names
); );
...@@ -157,7 +157,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetFree_R( ...@@ -157,7 +157,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetFree_R(
LIGHTGBM_C_EXPORT SEXP LGBM_DatasetSetField_R( LIGHTGBM_C_EXPORT SEXP LGBM_DatasetSetField_R(
LGBM_SE handle, LGBM_SE handle,
LGBM_SE field_name, LGBM_SE field_name,
LGBM_SE field_data, SEXP field_data,
SEXP num_element SEXP num_element
); );
...@@ -171,7 +171,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetSetField_R( ...@@ -171,7 +171,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetSetField_R(
LIGHTGBM_C_EXPORT SEXP LGBM_DatasetGetFieldSize_R( LIGHTGBM_C_EXPORT SEXP LGBM_DatasetGetFieldSize_R(
LGBM_SE handle, LGBM_SE handle,
LGBM_SE field_name, LGBM_SE field_name,
LGBM_SE out SEXP out
); );
/*! /*!
...@@ -184,7 +184,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetGetFieldSize_R( ...@@ -184,7 +184,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetGetFieldSize_R(
LIGHTGBM_C_EXPORT SEXP LGBM_DatasetGetField_R( LIGHTGBM_C_EXPORT SEXP LGBM_DatasetGetField_R(
LGBM_SE handle, LGBM_SE handle,
LGBM_SE field_name, LGBM_SE field_name,
LGBM_SE field_data SEXP field_data
); );
/*! /*!
...@@ -206,7 +206,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetUpdateParamChecking_R( ...@@ -206,7 +206,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetUpdateParamChecking_R(
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_DatasetGetNumData_R( LIGHTGBM_C_EXPORT SEXP LGBM_DatasetGetNumData_R(
LGBM_SE handle, LGBM_SE handle,
LGBM_SE out SEXP out
); );
/*! /*!
...@@ -217,7 +217,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetGetNumData_R( ...@@ -217,7 +217,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetGetNumData_R(
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_DatasetGetNumFeature_R( LIGHTGBM_C_EXPORT SEXP LGBM_DatasetGetNumFeature_R(
LGBM_SE handle, LGBM_SE handle,
LGBM_SE out SEXP out
); );
// --- start Booster interfaces // --- start Booster interfaces
...@@ -318,7 +318,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterResetParameter_R( ...@@ -318,7 +318,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterResetParameter_R(
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetNumClasses_R( LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetNumClasses_R(
LGBM_SE handle, LGBM_SE handle,
LGBM_SE out SEXP out
); );
/*! /*!
...@@ -341,8 +341,8 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterUpdateOneIter_R( ...@@ -341,8 +341,8 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterUpdateOneIter_R(
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterUpdateOneIterCustom_R( LIGHTGBM_C_EXPORT SEXP LGBM_BoosterUpdateOneIterCustom_R(
LGBM_SE handle, LGBM_SE handle,
LGBM_SE grad, SEXP grad,
LGBM_SE hess, SEXP hess,
SEXP len SEXP len
); );
...@@ -362,7 +362,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterRollbackOneIter_R( ...@@ -362,7 +362,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterRollbackOneIter_R(
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetCurrentIteration_R( LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetCurrentIteration_R(
LGBM_SE handle, LGBM_SE handle,
LGBM_SE out SEXP out
); );
/*! /*!
...@@ -373,7 +373,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetCurrentIteration_R( ...@@ -373,7 +373,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetCurrentIteration_R(
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetUpperBoundValue_R( LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetUpperBoundValue_R(
LGBM_SE handle, LGBM_SE handle,
LGBM_SE out_result SEXP out_result
); );
/*! /*!
...@@ -384,7 +384,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetUpperBoundValue_R( ...@@ -384,7 +384,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetUpperBoundValue_R(
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetLowerBoundValue_R( LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetLowerBoundValue_R(
LGBM_SE handle, LGBM_SE handle,
LGBM_SE out_result SEXP out_result
); );
/*! /*!
...@@ -395,7 +395,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetLowerBoundValue_R( ...@@ -395,7 +395,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetLowerBoundValue_R(
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetEvalNames_R( LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetEvalNames_R(
LGBM_SE handle, LGBM_SE handle,
SEXP buf_len, SEXP buf_len,
LGBM_SE actual_len, SEXP actual_len,
LGBM_SE eval_names LGBM_SE eval_names
); );
...@@ -409,7 +409,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetEvalNames_R( ...@@ -409,7 +409,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetEvalNames_R(
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetEval_R( LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetEval_R(
LGBM_SE handle, LGBM_SE handle,
SEXP data_idx, SEXP data_idx,
LGBM_SE out_result SEXP out_result
); );
/*! /*!
...@@ -422,7 +422,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetEval_R( ...@@ -422,7 +422,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetEval_R(
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetNumPredict_R( LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetNumPredict_R(
LGBM_SE handle, LGBM_SE handle,
SEXP data_idx, SEXP data_idx,
LGBM_SE out SEXP out
); );
/*! /*!
...@@ -436,7 +436,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetNumPredict_R( ...@@ -436,7 +436,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetNumPredict_R(
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetPredict_R( LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetPredict_R(
LGBM_SE handle, LGBM_SE handle,
SEXP data_idx, SEXP data_idx,
LGBM_SE out_result SEXP out_result
); );
/*! /*!
...@@ -481,7 +481,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterCalcNumPredict_R( ...@@ -481,7 +481,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterCalcNumPredict_R(
SEXP is_predcontrib, SEXP is_predcontrib,
SEXP start_iteration, SEXP start_iteration,
SEXP num_iteration, SEXP num_iteration,
LGBM_SE out_len SEXP out_len
); );
/*! /*!
...@@ -504,9 +504,9 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterCalcNumPredict_R( ...@@ -504,9 +504,9 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterCalcNumPredict_R(
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterPredictForCSC_R( LIGHTGBM_C_EXPORT SEXP LGBM_BoosterPredictForCSC_R(
LGBM_SE handle, LGBM_SE handle,
LGBM_SE indptr, SEXP indptr,
LGBM_SE indices, SEXP indices,
LGBM_SE data, SEXP data,
SEXP num_indptr, SEXP num_indptr,
SEXP nelem, SEXP nelem,
SEXP num_row, SEXP num_row,
...@@ -516,7 +516,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterPredictForCSC_R( ...@@ -516,7 +516,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterPredictForCSC_R(
SEXP start_iteration, SEXP start_iteration,
SEXP num_iteration, SEXP num_iteration,
LGBM_SE parameter, LGBM_SE parameter,
LGBM_SE out_result SEXP out_result
); );
/*! /*!
...@@ -536,7 +536,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterPredictForCSC_R( ...@@ -536,7 +536,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterPredictForCSC_R(
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterPredictForMat_R( LIGHTGBM_C_EXPORT SEXP LGBM_BoosterPredictForMat_R(
LGBM_SE handle, LGBM_SE handle,
LGBM_SE data, SEXP data,
SEXP num_row, SEXP num_row,
SEXP num_col, SEXP num_col,
SEXP is_rawscore, SEXP is_rawscore,
...@@ -545,7 +545,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterPredictForMat_R( ...@@ -545,7 +545,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterPredictForMat_R(
SEXP start_iteration, SEXP start_iteration,
SEXP num_iteration, SEXP num_iteration,
LGBM_SE parameter, LGBM_SE parameter,
LGBM_SE out_result SEXP out_result
); );
/*! /*!
...@@ -574,7 +574,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterSaveModelToString_R( ...@@ -574,7 +574,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterSaveModelToString_R(
SEXP num_iteration, SEXP num_iteration,
SEXP feature_importance_type, SEXP feature_importance_type,
SEXP buffer_len, SEXP buffer_len,
LGBM_SE actual_len, SEXP actual_len,
LGBM_SE out_str LGBM_SE out_str
); );
...@@ -590,7 +590,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterDumpModel_R( ...@@ -590,7 +590,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterDumpModel_R(
SEXP num_iteration, SEXP num_iteration,
SEXP feature_importance_type, SEXP feature_importance_type,
SEXP buffer_len, SEXP buffer_len,
LGBM_SE actual_len, SEXP actual_len,
LGBM_SE out_str LGBM_SE out_str
); );
......
...@@ -1247,6 +1247,38 @@ test_that("lgb.train() supports non-ASCII feature names", { ...@@ -1247,6 +1247,38 @@ test_that("lgb.train() supports non-ASCII feature names", {
} }
}) })
# Fits the same small regression on mtcars three times, with the feature matrix
# coerced to "numeric", "double", and "integer" mode in turn, and checks that
# each fit produces the expected trees and the expected mean absolute error.
test_that("lgb.train() works with integer, double, and numeric data", {
  data(mtcars)
  features <- as.matrix(mtcars[, -1L])
  target <- mtcars[, 1L, drop = TRUE]
  mae_reference <- 4.263667
  num_rounds <- 10L
  train_params <- list(
    objective = "regression"
    , min_data = 1L
    , learning_rate = 0.01
    , seed = 708L
  )
  storage_modes <- c("numeric", "double", "integer")
  for (storage_mode in storage_modes) {
    # the coercion is applied to the same matrix on every pass, in this order
    mode(features) <- storage_mode
    booster <- lightgbm(
      data = features
      , label = target
      , params = train_params
      , nrounds = num_rounds
    )

    # one tree per boosting round (tree_index is zero-based), each with splits
    tree_dt <- lgb.model.dt.tree(booster)
    last_tree_index <- tree_dt[, max(tree_index)]
    expect_equal(last_tree_index, num_rounds - 1L)
    expect_gt(nrow(tree_dt), num_rounds * 3L)

    # in-sample MAE must match the reference value within TOLERANCE
    predictions <- predict(booster, features)
    observed_mae <- mean(abs(target - predictions))
    mae_gap <- abs(observed_mae - mae_reference)
    expect_true(mae_gap < TOLERANCE)
  }
})
test_that("when early stopping is not activated, best_iter and best_score come from valids and not training data", { test_that("when early stopping is not activated, best_iter and best_score come from valids and not training data", {
set.seed(708L) set.seed(708L)
trainDF <- data.frame( trainDF <- data.frame(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment