Unverified Commit 676c95fb authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[R-package] manage Dataset and Booster handles as R external pointers (fixes #3016) (#4265)



* started converting handles

* more changes

* sort of working for Dataset

* yay all the tests are passing for Dataset handle changes

* working for other handle types

* remove debugging logging

* remove unnecessary spaces

* fix null logic

* more NULL

* updates

* Apply suggestions from code review
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>

* consolidate steps
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
parent 0a172d9e
...@@ -37,7 +37,7 @@ Booster <- R6::R6Class( ...@@ -37,7 +37,7 @@ Booster <- R6::R6Class(
# Create parameters and handle # Create parameters and handle
params <- append(params, list(...)) params <- append(params, list(...))
handle <- lgb.null.handle() handle <- NULL
# Attempts to create a handle for the dataset # Attempts to create a handle for the dataset
try({ try({
...@@ -52,11 +52,10 @@ Booster <- R6::R6Class( ...@@ -52,11 +52,10 @@ Booster <- R6::R6Class(
params <- modifyList(params, train_set$get_params()) params <- modifyList(params, train_set$get_params())
params_str <- lgb.params2str(params = params) params_str <- lgb.params2str(params = params)
# Store booster handle # Store booster handle
.Call( handle <- .Call(
LGBM_BoosterCreate_R LGBM_BoosterCreate_R
, train_set_handle , train_set_handle
, params_str , params_str
, handle
) )
# Create private booster information # Create private booster information
...@@ -88,10 +87,9 @@ Booster <- R6::R6Class( ...@@ -88,10 +87,9 @@ Booster <- R6::R6Class(
} }
# Create booster from model # Create booster from model
.Call( handle <- .Call(
LGBM_BoosterCreateFromModelfile_R LGBM_BoosterCreateFromModelfile_R
, modelfile , modelfile
, handle
) )
} else if (!is.null(model_str)) { } else if (!is.null(model_str)) {
...@@ -102,10 +100,9 @@ Booster <- R6::R6Class( ...@@ -102,10 +100,9 @@ Booster <- R6::R6Class(
} }
# Create booster from model # Create booster from model
.Call( handle <- .Call(
LGBM_BoosterLoadModelFromString_R LGBM_BoosterLoadModelFromString_R
, model_str , model_str
, handle
) )
} else { } else {
......
...@@ -192,7 +192,6 @@ Dataset <- R6::R6Class( ...@@ -192,7 +192,6 @@ Dataset <- R6::R6Class(
if (!is.null(private$reference)) { if (!is.null(private$reference)) {
ref_handle <- private$reference$.__enclos_env__$private$get_handle() ref_handle <- private$reference$.__enclos_env__$private$get_handle()
} }
handle <- lgb.null.handle()
# Not subsetting # Not subsetting
if (is.null(private$used_indices)) { if (is.null(private$used_indices)) {
...@@ -200,25 +199,23 @@ Dataset <- R6::R6Class( ...@@ -200,25 +199,23 @@ Dataset <- R6::R6Class(
# Are we using a data file? # Are we using a data file?
if (is.character(private$raw_data)) { if (is.character(private$raw_data)) {
.Call( handle <- .Call(
LGBM_DatasetCreateFromFile_R LGBM_DatasetCreateFromFile_R
, private$raw_data , private$raw_data
, params_str , params_str
, ref_handle , ref_handle
, handle
) )
} else if (is.matrix(private$raw_data)) { } else if (is.matrix(private$raw_data)) {
# Are we using a matrix? # Are we using a matrix?
.Call( handle <- .Call(
LGBM_DatasetCreateFromMat_R LGBM_DatasetCreateFromMat_R
, private$raw_data , private$raw_data
, nrow(private$raw_data) , nrow(private$raw_data)
, ncol(private$raw_data) , ncol(private$raw_data)
, params_str , params_str
, ref_handle , ref_handle
, handle
) )
} else if (methods::is(private$raw_data, "dgCMatrix")) { } else if (methods::is(private$raw_data, "dgCMatrix")) {
...@@ -226,7 +223,7 @@ Dataset <- R6::R6Class( ...@@ -226,7 +223,7 @@ Dataset <- R6::R6Class(
stop("Cannot support large CSC matrix") stop("Cannot support large CSC matrix")
} }
# Are we using a dgCMatrix (sparsed matrix column compressed) # Are we using a dgCMatrix (sparsed matrix column compressed)
.Call( handle <- .Call(
LGBM_DatasetCreateFromCSC_R LGBM_DatasetCreateFromCSC_R
, private$raw_data@p , private$raw_data@p
, private$raw_data@i , private$raw_data@i
...@@ -236,7 +233,6 @@ Dataset <- R6::R6Class( ...@@ -236,7 +233,6 @@ Dataset <- R6::R6Class(
, nrow(private$raw_data) , nrow(private$raw_data)
, params_str , params_str
, ref_handle , ref_handle
, handle
) )
} else { } else {
...@@ -257,13 +253,12 @@ Dataset <- R6::R6Class( ...@@ -257,13 +253,12 @@ Dataset <- R6::R6Class(
} }
# Construct subset # Construct subset
.Call( handle <- .Call(
LGBM_DatasetGetSubset_R LGBM_DatasetGetSubset_R
, ref_handle , ref_handle
, c(private$used_indices) # Adding c() fixes issue in R v3.5 , c(private$used_indices) # Adding c() fixes issue in R v3.5
, length(private$used_indices) , length(private$used_indices)
, params_str , params_str
, handle
) )
} }
......
...@@ -30,17 +30,15 @@ Predictor <- R6::R6Class( ...@@ -30,17 +30,15 @@ Predictor <- R6::R6Class(
initialize = function(modelfile, ...) { initialize = function(modelfile, ...) {
params <- list(...) params <- list(...)
private$params <- lgb.params2str(params = params) private$params <- lgb.params2str(params = params)
# Create new lgb handle handle <- NULL
handle <- lgb.null.handle()
# Check if handle is a character # Check if handle is a character
if (is.character(modelfile)) { if (is.character(modelfile)) {
# Create handle on it # Create handle on it
.Call( handle <- .Call(
LGBM_BoosterCreateFromModelfile_R LGBM_BoosterCreateFromModelfile_R
, modelfile , modelfile
, handle
) )
private$need_free_handle <- TRUE private$need_free_handle <- TRUE
......
...@@ -6,16 +6,13 @@ lgb.is.Dataset <- function(x) { ...@@ -6,16 +6,13 @@ lgb.is.Dataset <- function(x) {
return(lgb.check.r6.class(object = x, name = "lgb.Dataset")) return(lgb.check.r6.class(object = x, name = "lgb.Dataset"))
} }
lgb.null.handle <- function() {
if (.Machine$sizeof.pointer == 8L) {
return(NA_real_)
} else {
return(NA_integer_)
}
}
lgb.is.null.handle <- function(x) { lgb.is.null.handle <- function(x) {
return(is.null(x) || is.na(x)) if (is.null(x)) {
return(TRUE)
}
return(
isTRUE(.Call(LGBM_HandleIsNull_R, x))
)
} }
# [description] Get the most recent error stored on the C++ side and raise it # [description] Get the most recent error stored on the C++ side and raise it
......
/*!
* Copyright (c) 2017 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for license information.
*
* \brief A simple wrapper for accessing data in R object.
*
* \note
* We previously did not want to use R's headers because of license concerns. This is no longer a concern:
* https://github.com/microsoft/LightGBM/issues/629#issuecomment-474995635
* For now, this wrapper is LightGBM's interface from R to C.
* If R changes the way it defines objects, this file will need to be updated as well.
*/
#ifndef R_OBJECT_HELPER_H_
#define R_OBJECT_HELPER_H_
#include <cstdint>
#define NAMED_BITS 16
struct lgbm_sxpinfo {
unsigned int type : 5;
unsigned int scalar : 1;
unsigned int obj : 1;
unsigned int alt : 1;
unsigned int gp : 16;
unsigned int mark : 1;
unsigned int debug : 1;
unsigned int trace : 1;
unsigned int spare : 1;
unsigned int gcgen : 1;
unsigned int gccls : 3;
unsigned int named : NAMED_BITS;
unsigned int extra : 32 - NAMED_BITS;
};
struct lgbm_primsxp {
int offset;
};
struct lgbm_symsxp {
struct LGBM_SER *pname;
struct LGBM_SER *value;
struct LGBM_SER *internal;
};
struct lgbm_listsxp {
struct LGBM_SER *carval;
struct LGBM_SER *cdrval;
struct LGBM_SER *tagval;
};
struct lgbm_envsxp {
struct LGBM_SER *frame;
struct LGBM_SER *enclos;
struct LGBM_SER *hashtab;
};
struct lgbm_closxp {
struct LGBM_SER *formals;
struct LGBM_SER *body;
struct LGBM_SER *env;
};
struct lgbm_promsxp {
struct LGBM_SER *value;
struct LGBM_SER *expr;
struct LGBM_SER *env;
};
typedef struct LGBM_SER {
struct lgbm_sxpinfo sxpinfo;
struct LGBM_SER* attrib;
struct LGBM_SER* gengc_next_node, *gengc_prev_node;
union {
struct lgbm_primsxp primsxp;
struct lgbm_symsxp symsxp;
struct lgbm_listsxp listsxp;
struct lgbm_envsxp envsxp;
struct lgbm_closxp closxp;
struct lgbm_promsxp promsxp;
} u;
} LGBM_SER, *LGBM_SE;
struct lgbm_vecsxp {
R_xlen_t length;
R_xlen_t truelength;
};
typedef struct VECTOR_SER {
struct lgbm_sxpinfo sxpinfo;
struct LGBM_SER* attrib;
struct LGBM_SER* gengc_next_node, *gengc_prev_node;
struct lgbm_vecsxp vecsxp;
} VECTOR_SER, *VECSE;
typedef union { VECTOR_SER s; double align; } SEXPREC_ALIGN;
#define DATAPTR(x) ((reinterpret_cast<SEXPREC_ALIGN*>(x)) + 1)
#define R_IS_NULL(x) ((*reinterpret_cast<LGBM_SE>(x)).sxpinfo.type == 0)
// 64bit pointer
#if INTPTR_MAX == INT64_MAX
#define R_ADDR(x) (reinterpret_cast<int64_t*> DATAPTR(x))
inline void R_SET_PTR(LGBM_SE x, void* ptr) {
if (ptr == nullptr) {
R_ADDR(x)[0] = (int64_t)(NULL);
} else {
R_ADDR(x)[0] = (int64_t)(ptr);
}
}
inline void* R_GET_PTR(LGBM_SE x) {
if (R_IS_NULL(x)) {
return nullptr;
} else {
auto ret = reinterpret_cast<void*>(R_ADDR(x)[0]);
if (ret == NULL) {
ret = nullptr;
}
return ret;
}
}
#else
#define R_ADDR(x) (reinterpret_cast<int32_t*> DATAPTR(x))
inline void R_SET_PTR(LGBM_SE x, void* ptr) {
if (ptr == nullptr) {
R_ADDR(x)[0] = (int32_t)(NULL);
} else {
R_ADDR(x)[0] = (int32_t)(ptr);
}
}
inline void* R_GET_PTR(LGBM_SE x) {
if (R_IS_NULL(x)) {
return nullptr;
} else {
auto ret = reinterpret_cast<void*>(R_ADDR(x)[0]);
if (ret == NULL) {
ret = nullptr;
}
return ret;
}
}
#endif // INTPTR_MAX == INT64_MAX
#endif // R_OBJECT_HELPER_H_
...@@ -50,15 +50,25 @@ SEXP LGBM_GetLastError_R() { ...@@ -50,15 +50,25 @@ SEXP LGBM_GetLastError_R() {
return out; return out;
} }
SEXP LGBM_HandleIsNull_R(SEXP handle) {
return Rf_ScalarLogical(R_ExternalPtrAddr(handle) == NULL);
}
SEXP LGBM_DatasetCreateFromFile_R(SEXP filename, SEXP LGBM_DatasetCreateFromFile_R(SEXP filename,
SEXP parameters, SEXP parameters,
LGBM_SE reference, SEXP reference) {
LGBM_SE out) { SEXP ret;
R_API_BEGIN(); R_API_BEGIN();
DatasetHandle handle = nullptr; DatasetHandle handle = nullptr;
DatasetHandle ref = nullptr;
if (!Rf_isNull(reference)) {
ref = R_ExternalPtrAddr(reference);
}
CHECK_CALL(LGBM_DatasetCreateFromFile(CHAR(Rf_asChar(filename)), CHAR(Rf_asChar(parameters)), CHECK_CALL(LGBM_DatasetCreateFromFile(CHAR(Rf_asChar(filename)), CHAR(Rf_asChar(parameters)),
R_GET_PTR(reference), &handle)); ref, &handle));
R_SET_PTR(out, handle); ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
UNPROTECT(1);
return ret;
R_API_END(); R_API_END();
} }
...@@ -69,8 +79,8 @@ SEXP LGBM_DatasetCreateFromCSC_R(SEXP indptr, ...@@ -69,8 +79,8 @@ SEXP LGBM_DatasetCreateFromCSC_R(SEXP indptr,
SEXP nelem, SEXP nelem,
SEXP num_row, SEXP num_row,
SEXP parameters, SEXP parameters,
LGBM_SE reference, SEXP reference) {
LGBM_SE out) { SEXP ret;
R_API_BEGIN(); R_API_BEGIN();
const int* p_indptr = INTEGER(indptr); const int* p_indptr = INTEGER(indptr);
const int* p_indices = INTEGER(indices); const int* p_indices = INTEGER(indices);
...@@ -80,10 +90,16 @@ SEXP LGBM_DatasetCreateFromCSC_R(SEXP indptr, ...@@ -80,10 +90,16 @@ SEXP LGBM_DatasetCreateFromCSC_R(SEXP indptr,
int64_t ndata = static_cast<int64_t>(Rf_asInteger(nelem)); int64_t ndata = static_cast<int64_t>(Rf_asInteger(nelem));
int64_t nrow = static_cast<int64_t>(Rf_asInteger(num_row)); int64_t nrow = static_cast<int64_t>(Rf_asInteger(num_row));
DatasetHandle handle = nullptr; DatasetHandle handle = nullptr;
DatasetHandle ref = nullptr;
if (!Rf_isNull(reference)) {
ref = R_ExternalPtrAddr(reference);
}
CHECK_CALL(LGBM_DatasetCreateFromCSC(p_indptr, C_API_DTYPE_INT32, p_indices, CHECK_CALL(LGBM_DatasetCreateFromCSC(p_indptr, C_API_DTYPE_INT32, p_indices,
p_data, C_API_DTYPE_FLOAT64, nindptr, ndata, p_data, C_API_DTYPE_FLOAT64, nindptr, ndata,
nrow, CHAR(Rf_asChar(parameters)), R_GET_PTR(reference), &handle)); nrow, CHAR(Rf_asChar(parameters)), ref, &handle));
R_SET_PTR(out, handle); ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
UNPROTECT(1);
return ret;
R_API_END(); R_API_END();
} }
...@@ -91,24 +107,30 @@ SEXP LGBM_DatasetCreateFromMat_R(SEXP data, ...@@ -91,24 +107,30 @@ SEXP LGBM_DatasetCreateFromMat_R(SEXP data,
SEXP num_row, SEXP num_row,
SEXP num_col, SEXP num_col,
SEXP parameters, SEXP parameters,
LGBM_SE reference, SEXP reference) {
LGBM_SE out) { SEXP ret;
R_API_BEGIN(); R_API_BEGIN();
int32_t nrow = static_cast<int32_t>(Rf_asInteger(num_row)); int32_t nrow = static_cast<int32_t>(Rf_asInteger(num_row));
int32_t ncol = static_cast<int32_t>(Rf_asInteger(num_col)); int32_t ncol = static_cast<int32_t>(Rf_asInteger(num_col));
double* p_mat = REAL(data); double* p_mat = REAL(data);
DatasetHandle handle = nullptr; DatasetHandle handle = nullptr;
DatasetHandle ref = nullptr;
if (!Rf_isNull(reference)) {
ref = R_ExternalPtrAddr(reference);
}
CHECK_CALL(LGBM_DatasetCreateFromMat(p_mat, C_API_DTYPE_FLOAT64, nrow, ncol, COL_MAJOR, CHECK_CALL(LGBM_DatasetCreateFromMat(p_mat, C_API_DTYPE_FLOAT64, nrow, ncol, COL_MAJOR,
CHAR(Rf_asChar(parameters)), R_GET_PTR(reference), &handle)); CHAR(Rf_asChar(parameters)), ref, &handle));
R_SET_PTR(out, handle); ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
UNPROTECT(1);
return ret;
R_API_END(); R_API_END();
} }
SEXP LGBM_DatasetGetSubset_R(LGBM_SE handle, SEXP LGBM_DatasetGetSubset_R(SEXP handle,
SEXP used_row_indices, SEXP used_row_indices,
SEXP len_used_row_indices, SEXP len_used_row_indices,
SEXP parameters, SEXP parameters) {
LGBM_SE out) { SEXP ret;
R_API_BEGIN(); R_API_BEGIN();
int len = Rf_asInteger(len_used_row_indices); int len = Rf_asInteger(len_used_row_indices);
std::vector<int> idxvec(len); std::vector<int> idxvec(len);
...@@ -118,14 +140,16 @@ SEXP LGBM_DatasetGetSubset_R(LGBM_SE handle, ...@@ -118,14 +140,16 @@ SEXP LGBM_DatasetGetSubset_R(LGBM_SE handle,
idxvec[i] = INTEGER(used_row_indices)[i] - 1; idxvec[i] = INTEGER(used_row_indices)[i] - 1;
} }
DatasetHandle res = nullptr; DatasetHandle res = nullptr;
CHECK_CALL(LGBM_DatasetGetSubset(R_GET_PTR(handle), CHECK_CALL(LGBM_DatasetGetSubset(R_ExternalPtrAddr(handle),
idxvec.data(), len, CHAR(Rf_asChar(parameters)), idxvec.data(), len, CHAR(Rf_asChar(parameters)),
&res)); &res));
R_SET_PTR(out, res); ret = PROTECT(R_MakeExternalPtr(res, R_NilValue, R_NilValue));
UNPROTECT(1);
return ret;
R_API_END(); R_API_END();
} }
SEXP LGBM_DatasetSetFeatureNames_R(LGBM_SE handle, SEXP LGBM_DatasetSetFeatureNames_R(SEXP handle,
SEXP feature_names) { SEXP feature_names) {
R_API_BEGIN(); R_API_BEGIN();
auto vec_names = Split(CHAR(Rf_asChar(feature_names)), '\t'); auto vec_names = Split(CHAR(Rf_asChar(feature_names)), '\t');
...@@ -134,16 +158,16 @@ SEXP LGBM_DatasetSetFeatureNames_R(LGBM_SE handle, ...@@ -134,16 +158,16 @@ SEXP LGBM_DatasetSetFeatureNames_R(LGBM_SE handle,
for (int i = 0; i < len; ++i) { for (int i = 0; i < len; ++i) {
vec_sptr.push_back(vec_names[i].c_str()); vec_sptr.push_back(vec_names[i].c_str());
} }
CHECK_CALL(LGBM_DatasetSetFeatureNames(R_GET_PTR(handle), CHECK_CALL(LGBM_DatasetSetFeatureNames(R_ExternalPtrAddr(handle),
vec_sptr.data(), len)); vec_sptr.data(), len));
R_API_END(); R_API_END();
} }
SEXP LGBM_DatasetGetFeatureNames_R(LGBM_SE handle) { SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) {
SEXP feature_names; SEXP feature_names;
R_API_BEGIN(); R_API_BEGIN();
int len = 0; int len = 0;
CHECK_CALL(LGBM_DatasetGetNumFeature(R_GET_PTR(handle), &len)); CHECK_CALL(LGBM_DatasetGetNumFeature(R_ExternalPtrAddr(handle), &len));
const size_t reserved_string_size = 256; const size_t reserved_string_size = 256;
std::vector<std::vector<char>> names(len); std::vector<std::vector<char>> names(len);
std::vector<char*> ptr_names(len); std::vector<char*> ptr_names(len);
...@@ -155,7 +179,7 @@ SEXP LGBM_DatasetGetFeatureNames_R(LGBM_SE handle) { ...@@ -155,7 +179,7 @@ SEXP LGBM_DatasetGetFeatureNames_R(LGBM_SE handle) {
size_t required_string_size; size_t required_string_size;
CHECK_CALL( CHECK_CALL(
LGBM_DatasetGetFeatureNames( LGBM_DatasetGetFeatureNames(
R_GET_PTR(handle), R_ExternalPtrAddr(handle),
len, &out_len, len, &out_len,
reserved_string_size, &required_string_size, reserved_string_size, &required_string_size,
ptr_names.data())); ptr_names.data()));
...@@ -168,7 +192,7 @@ SEXP LGBM_DatasetGetFeatureNames_R(LGBM_SE handle) { ...@@ -168,7 +192,7 @@ SEXP LGBM_DatasetGetFeatureNames_R(LGBM_SE handle) {
} }
CHECK_CALL( CHECK_CALL(
LGBM_DatasetGetFeatureNames( LGBM_DatasetGetFeatureNames(
R_GET_PTR(handle), R_ExternalPtrAddr(handle),
len, len,
&out_len, &out_len,
required_string_size, required_string_size,
...@@ -185,24 +209,24 @@ SEXP LGBM_DatasetGetFeatureNames_R(LGBM_SE handle) { ...@@ -185,24 +209,24 @@ SEXP LGBM_DatasetGetFeatureNames_R(LGBM_SE handle) {
R_API_END(); R_API_END();
} }
SEXP LGBM_DatasetSaveBinary_R(LGBM_SE handle, SEXP LGBM_DatasetSaveBinary_R(SEXP handle,
SEXP filename) { SEXP filename) {
R_API_BEGIN(); R_API_BEGIN();
CHECK_CALL(LGBM_DatasetSaveBinary(R_GET_PTR(handle), CHECK_CALL(LGBM_DatasetSaveBinary(R_ExternalPtrAddr(handle),
CHAR(Rf_asChar(filename)))); CHAR(Rf_asChar(filename))));
R_API_END(); R_API_END();
} }
SEXP LGBM_DatasetFree_R(LGBM_SE handle) { SEXP LGBM_DatasetFree_R(SEXP handle) {
R_API_BEGIN(); R_API_BEGIN();
if (R_GET_PTR(handle) != nullptr) { if (R_ExternalPtrAddr(handle)) {
CHECK_CALL(LGBM_DatasetFree(R_GET_PTR(handle))); CHECK_CALL(LGBM_DatasetFree(R_ExternalPtrAddr(handle)));
R_SET_PTR(handle, nullptr); R_ClearExternalPtr(handle);
} }
R_API_END(); R_API_END();
} }
SEXP LGBM_DatasetSetField_R(LGBM_SE handle, SEXP LGBM_DatasetSetField_R(SEXP handle,
SEXP field_name, SEXP field_name,
SEXP field_data, SEXP field_data,
SEXP num_element) { SEXP num_element) {
...@@ -215,21 +239,21 @@ SEXP LGBM_DatasetSetField_R(LGBM_SE handle, ...@@ -215,21 +239,21 @@ SEXP LGBM_DatasetSetField_R(LGBM_SE handle,
for (int i = 0; i < len; ++i) { for (int i = 0; i < len; ++i) {
vec[i] = static_cast<int32_t>(INTEGER(field_data)[i]); vec[i] = static_cast<int32_t>(INTEGER(field_data)[i]);
} }
CHECK_CALL(LGBM_DatasetSetField(R_GET_PTR(handle), name, vec.data(), len, C_API_DTYPE_INT32)); CHECK_CALL(LGBM_DatasetSetField(R_ExternalPtrAddr(handle), name, vec.data(), len, C_API_DTYPE_INT32));
} else if (!strcmp("init_score", name)) { } else if (!strcmp("init_score", name)) {
CHECK_CALL(LGBM_DatasetSetField(R_GET_PTR(handle), name, REAL(field_data), len, C_API_DTYPE_FLOAT64)); CHECK_CALL(LGBM_DatasetSetField(R_ExternalPtrAddr(handle), name, REAL(field_data), len, C_API_DTYPE_FLOAT64));
} else { } else {
std::vector<float> vec(len); std::vector<float> vec(len);
#pragma omp parallel for schedule(static, 512) if (len >= 1024) #pragma omp parallel for schedule(static, 512) if (len >= 1024)
for (int i = 0; i < len; ++i) { for (int i = 0; i < len; ++i) {
vec[i] = static_cast<float>(REAL(field_data)[i]); vec[i] = static_cast<float>(REAL(field_data)[i]);
} }
CHECK_CALL(LGBM_DatasetSetField(R_GET_PTR(handle), name, vec.data(), len, C_API_DTYPE_FLOAT32)); CHECK_CALL(LGBM_DatasetSetField(R_ExternalPtrAddr(handle), name, vec.data(), len, C_API_DTYPE_FLOAT32));
} }
R_API_END(); R_API_END();
} }
SEXP LGBM_DatasetGetField_R(LGBM_SE handle, SEXP LGBM_DatasetGetField_R(SEXP handle,
SEXP field_name, SEXP field_name,
SEXP field_data) { SEXP field_data) {
R_API_BEGIN(); R_API_BEGIN();
...@@ -237,7 +261,7 @@ SEXP LGBM_DatasetGetField_R(LGBM_SE handle, ...@@ -237,7 +261,7 @@ SEXP LGBM_DatasetGetField_R(LGBM_SE handle,
int out_len = 0; int out_len = 0;
int out_type = 0; int out_type = 0;
const void* res; const void* res;
CHECK_CALL(LGBM_DatasetGetField(R_GET_PTR(handle), name, &out_len, &res, &out_type)); CHECK_CALL(LGBM_DatasetGetField(R_ExternalPtrAddr(handle), name, &out_len, &res, &out_type));
if (!strcmp("group", name) || !strcmp("query", name)) { if (!strcmp("group", name) || !strcmp("query", name)) {
auto p_data = reinterpret_cast<const int32_t*>(res); auto p_data = reinterpret_cast<const int32_t*>(res);
...@@ -262,7 +286,7 @@ SEXP LGBM_DatasetGetField_R(LGBM_SE handle, ...@@ -262,7 +286,7 @@ SEXP LGBM_DatasetGetField_R(LGBM_SE handle,
R_API_END(); R_API_END();
} }
SEXP LGBM_DatasetGetFieldSize_R(LGBM_SE handle, SEXP LGBM_DatasetGetFieldSize_R(SEXP handle,
SEXP field_name, SEXP field_name,
SEXP out) { SEXP out) {
R_API_BEGIN(); R_API_BEGIN();
...@@ -270,7 +294,7 @@ SEXP LGBM_DatasetGetFieldSize_R(LGBM_SE handle, ...@@ -270,7 +294,7 @@ SEXP LGBM_DatasetGetFieldSize_R(LGBM_SE handle,
int out_len = 0; int out_len = 0;
int out_type = 0; int out_type = 0;
const void* res; const void* res;
CHECK_CALL(LGBM_DatasetGetField(R_GET_PTR(handle), name, &out_len, &res, &out_type)); CHECK_CALL(LGBM_DatasetGetField(R_ExternalPtrAddr(handle), name, &out_len, &res, &out_type));
if (!strcmp("group", name) || !strcmp("query", name)) { if (!strcmp("group", name) || !strcmp("query", name)) {
out_len -= 1; out_len -= 1;
} }
...@@ -285,109 +309,115 @@ SEXP LGBM_DatasetUpdateParamChecking_R(SEXP old_params, ...@@ -285,109 +309,115 @@ SEXP LGBM_DatasetUpdateParamChecking_R(SEXP old_params,
R_API_END(); R_API_END();
} }
SEXP LGBM_DatasetGetNumData_R(LGBM_SE handle, SEXP out) { SEXP LGBM_DatasetGetNumData_R(SEXP handle, SEXP out) {
int nrow; int nrow;
R_API_BEGIN(); R_API_BEGIN();
CHECK_CALL(LGBM_DatasetGetNumData(R_GET_PTR(handle), &nrow)); CHECK_CALL(LGBM_DatasetGetNumData(R_ExternalPtrAddr(handle), &nrow));
INTEGER(out)[0] = static_cast<int>(nrow); INTEGER(out)[0] = static_cast<int>(nrow);
R_API_END(); R_API_END();
} }
SEXP LGBM_DatasetGetNumFeature_R(LGBM_SE handle, SEXP LGBM_DatasetGetNumFeature_R(SEXP handle,
SEXP out) { SEXP out) {
int nfeature; int nfeature;
R_API_BEGIN(); R_API_BEGIN();
CHECK_CALL(LGBM_DatasetGetNumFeature(R_GET_PTR(handle), &nfeature)); CHECK_CALL(LGBM_DatasetGetNumFeature(R_ExternalPtrAddr(handle), &nfeature));
INTEGER(out)[0] = static_cast<int>(nfeature); INTEGER(out)[0] = static_cast<int>(nfeature);
R_API_END(); R_API_END();
} }
// --- start Booster interfaces // --- start Booster interfaces
SEXP LGBM_BoosterFree_R(LGBM_SE handle) { SEXP LGBM_BoosterFree_R(SEXP handle) {
R_API_BEGIN(); R_API_BEGIN();
if (R_GET_PTR(handle) != nullptr) { if (R_ExternalPtrAddr(handle)) {
CHECK_CALL(LGBM_BoosterFree(R_GET_PTR(handle))); CHECK_CALL(LGBM_BoosterFree(R_ExternalPtrAddr(handle)));
R_SET_PTR(handle, nullptr); R_ClearExternalPtr(handle);
} }
R_API_END(); R_API_END();
} }
SEXP LGBM_BoosterCreate_R(LGBM_SE train_data, SEXP LGBM_BoosterCreate_R(SEXP train_data,
SEXP parameters, SEXP parameters) {
LGBM_SE out) { SEXP ret;
R_API_BEGIN(); R_API_BEGIN();
BoosterHandle handle = nullptr; BoosterHandle handle = nullptr;
CHECK_CALL(LGBM_BoosterCreate(R_GET_PTR(train_data), CHAR(Rf_asChar(parameters)), &handle)); CHECK_CALL(LGBM_BoosterCreate(R_ExternalPtrAddr(train_data), CHAR(Rf_asChar(parameters)), &handle));
R_SET_PTR(out, handle); ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
UNPROTECT(1);
return ret;
R_API_END(); R_API_END();
} }
SEXP LGBM_BoosterCreateFromModelfile_R(SEXP filename, SEXP LGBM_BoosterCreateFromModelfile_R(SEXP filename) {
LGBM_SE out) { SEXP ret;
R_API_BEGIN(); R_API_BEGIN();
int out_num_iterations = 0; int out_num_iterations = 0;
BoosterHandle handle = nullptr; BoosterHandle handle = nullptr;
CHECK_CALL(LGBM_BoosterCreateFromModelfile(CHAR(Rf_asChar(filename)), &out_num_iterations, &handle)); CHECK_CALL(LGBM_BoosterCreateFromModelfile(CHAR(Rf_asChar(filename)), &out_num_iterations, &handle));
R_SET_PTR(out, handle); ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
UNPROTECT(1);
return ret;
R_API_END(); R_API_END();
} }
SEXP LGBM_BoosterLoadModelFromString_R(SEXP model_str, SEXP LGBM_BoosterLoadModelFromString_R(SEXP model_str) {
LGBM_SE out) { SEXP ret;
R_API_BEGIN(); R_API_BEGIN();
int out_num_iterations = 0; int out_num_iterations = 0;
BoosterHandle handle = nullptr; BoosterHandle handle = nullptr;
CHECK_CALL(LGBM_BoosterLoadModelFromString(CHAR(Rf_asChar(model_str)), &out_num_iterations, &handle)); CHECK_CALL(LGBM_BoosterLoadModelFromString(CHAR(Rf_asChar(model_str)), &out_num_iterations, &handle));
R_SET_PTR(out, handle); ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
UNPROTECT(1);
return ret;
R_API_END(); R_API_END();
} }
SEXP LGBM_BoosterMerge_R(LGBM_SE handle, SEXP LGBM_BoosterMerge_R(SEXP handle,
LGBM_SE other_handle) { SEXP other_handle) {
R_API_BEGIN(); R_API_BEGIN();
CHECK_CALL(LGBM_BoosterMerge(R_GET_PTR(handle), R_GET_PTR(other_handle))); CHECK_CALL(LGBM_BoosterMerge(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(other_handle)));
R_API_END(); R_API_END();
} }
SEXP LGBM_BoosterAddValidData_R(LGBM_SE handle, SEXP LGBM_BoosterAddValidData_R(SEXP handle,
LGBM_SE valid_data) { SEXP valid_data) {
R_API_BEGIN(); R_API_BEGIN();
CHECK_CALL(LGBM_BoosterAddValidData(R_GET_PTR(handle), R_GET_PTR(valid_data))); CHECK_CALL(LGBM_BoosterAddValidData(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(valid_data)));
R_API_END(); R_API_END();
} }
SEXP LGBM_BoosterResetTrainingData_R(LGBM_SE handle, SEXP LGBM_BoosterResetTrainingData_R(SEXP handle,
LGBM_SE train_data) { SEXP train_data) {
R_API_BEGIN(); R_API_BEGIN();
CHECK_CALL(LGBM_BoosterResetTrainingData(R_GET_PTR(handle), R_GET_PTR(train_data))); CHECK_CALL(LGBM_BoosterResetTrainingData(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(train_data)));
R_API_END(); R_API_END();
} }
SEXP LGBM_BoosterResetParameter_R(LGBM_SE handle, SEXP LGBM_BoosterResetParameter_R(SEXP handle,
SEXP parameters) { SEXP parameters) {
R_API_BEGIN(); R_API_BEGIN();
CHECK_CALL(LGBM_BoosterResetParameter(R_GET_PTR(handle), CHAR(Rf_asChar(parameters)))); CHECK_CALL(LGBM_BoosterResetParameter(R_ExternalPtrAddr(handle), CHAR(Rf_asChar(parameters))));
R_API_END(); R_API_END();
} }
SEXP LGBM_BoosterGetNumClasses_R(LGBM_SE handle, SEXP LGBM_BoosterGetNumClasses_R(SEXP handle,
SEXP out) { SEXP out) {
int num_class; int num_class;
R_API_BEGIN(); R_API_BEGIN();
CHECK_CALL(LGBM_BoosterGetNumClasses(R_GET_PTR(handle), &num_class)); CHECK_CALL(LGBM_BoosterGetNumClasses(R_ExternalPtrAddr(handle), &num_class));
INTEGER(out)[0] = static_cast<int>(num_class); INTEGER(out)[0] = static_cast<int>(num_class);
R_API_END(); R_API_END();
} }
SEXP LGBM_BoosterUpdateOneIter_R(LGBM_SE handle) { SEXP LGBM_BoosterUpdateOneIter_R(SEXP handle) {
int is_finished = 0; int is_finished = 0;
R_API_BEGIN(); R_API_BEGIN();
CHECK_CALL(LGBM_BoosterUpdateOneIter(R_GET_PTR(handle), &is_finished)); CHECK_CALL(LGBM_BoosterUpdateOneIter(R_ExternalPtrAddr(handle), &is_finished));
R_API_END(); R_API_END();
} }
SEXP LGBM_BoosterUpdateOneIterCustom_R(LGBM_SE handle, SEXP LGBM_BoosterUpdateOneIterCustom_R(SEXP handle,
SEXP grad, SEXP grad,
SEXP hess, SEXP hess,
SEXP len) { SEXP len) {
...@@ -400,46 +430,46 @@ SEXP LGBM_BoosterUpdateOneIterCustom_R(LGBM_SE handle, ...@@ -400,46 +430,46 @@ SEXP LGBM_BoosterUpdateOneIterCustom_R(LGBM_SE handle,
tgrad[j] = static_cast<float>(REAL(grad)[j]); tgrad[j] = static_cast<float>(REAL(grad)[j]);
thess[j] = static_cast<float>(REAL(hess)[j]); thess[j] = static_cast<float>(REAL(hess)[j]);
} }
CHECK_CALL(LGBM_BoosterUpdateOneIterCustom(R_GET_PTR(handle), tgrad.data(), thess.data(), &is_finished)); CHECK_CALL(LGBM_BoosterUpdateOneIterCustom(R_ExternalPtrAddr(handle), tgrad.data(), thess.data(), &is_finished));
R_API_END(); R_API_END();
} }
SEXP LGBM_BoosterRollbackOneIter_R(LGBM_SE handle) { SEXP LGBM_BoosterRollbackOneIter_R(SEXP handle) {
R_API_BEGIN(); R_API_BEGIN();
CHECK_CALL(LGBM_BoosterRollbackOneIter(R_GET_PTR(handle))); CHECK_CALL(LGBM_BoosterRollbackOneIter(R_ExternalPtrAddr(handle)));
R_API_END(); R_API_END();
} }
SEXP LGBM_BoosterGetCurrentIteration_R(LGBM_SE handle, SEXP LGBM_BoosterGetCurrentIteration_R(SEXP handle,
SEXP out) { SEXP out) {
int out_iteration; int out_iteration;
R_API_BEGIN(); R_API_BEGIN();
CHECK_CALL(LGBM_BoosterGetCurrentIteration(R_GET_PTR(handle), &out_iteration)); CHECK_CALL(LGBM_BoosterGetCurrentIteration(R_ExternalPtrAddr(handle), &out_iteration));
INTEGER(out)[0] = static_cast<int>(out_iteration); INTEGER(out)[0] = static_cast<int>(out_iteration);
R_API_END(); R_API_END();
} }
SEXP LGBM_BoosterGetUpperBoundValue_R(LGBM_SE handle, SEXP LGBM_BoosterGetUpperBoundValue_R(SEXP handle,
SEXP out_result) { SEXP out_result) {
R_API_BEGIN(); R_API_BEGIN();
double* ptr_ret = REAL(out_result); double* ptr_ret = REAL(out_result);
CHECK_CALL(LGBM_BoosterGetUpperBoundValue(R_GET_PTR(handle), ptr_ret)); CHECK_CALL(LGBM_BoosterGetUpperBoundValue(R_ExternalPtrAddr(handle), ptr_ret));
R_API_END(); R_API_END();
} }
SEXP LGBM_BoosterGetLowerBoundValue_R(LGBM_SE handle, SEXP LGBM_BoosterGetLowerBoundValue_R(SEXP handle,
SEXP out_result) { SEXP out_result) {
R_API_BEGIN(); R_API_BEGIN();
double* ptr_ret = REAL(out_result); double* ptr_ret = REAL(out_result);
CHECK_CALL(LGBM_BoosterGetLowerBoundValue(R_GET_PTR(handle), ptr_ret)); CHECK_CALL(LGBM_BoosterGetLowerBoundValue(R_ExternalPtrAddr(handle), ptr_ret));
R_API_END(); R_API_END();
} }
SEXP LGBM_BoosterGetEvalNames_R(LGBM_SE handle) { SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) {
SEXP eval_names; SEXP eval_names;
R_API_BEGIN(); R_API_BEGIN();
int len; int len;
CHECK_CALL(LGBM_BoosterGetEvalCounts(R_GET_PTR(handle), &len)); CHECK_CALL(LGBM_BoosterGetEvalCounts(R_ExternalPtrAddr(handle), &len));
const size_t reserved_string_size = 128; const size_t reserved_string_size = 128;
std::vector<std::vector<char>> names(len); std::vector<std::vector<char>> names(len);
...@@ -453,7 +483,7 @@ SEXP LGBM_BoosterGetEvalNames_R(LGBM_SE handle) { ...@@ -453,7 +483,7 @@ SEXP LGBM_BoosterGetEvalNames_R(LGBM_SE handle) {
size_t required_string_size; size_t required_string_size;
CHECK_CALL( CHECK_CALL(
LGBM_BoosterGetEvalNames( LGBM_BoosterGetEvalNames(
R_GET_PTR(handle), R_ExternalPtrAddr(handle),
len, &out_len, len, &out_len,
reserved_string_size, &required_string_size, reserved_string_size, &required_string_size,
ptr_names.data())); ptr_names.data()));
...@@ -466,7 +496,7 @@ SEXP LGBM_BoosterGetEvalNames_R(LGBM_SE handle) { ...@@ -466,7 +496,7 @@ SEXP LGBM_BoosterGetEvalNames_R(LGBM_SE handle) {
} }
CHECK_CALL( CHECK_CALL(
LGBM_BoosterGetEvalNames( LGBM_BoosterGetEvalNames(
R_GET_PTR(handle), R_ExternalPtrAddr(handle),
len, len,
&out_len, &out_len,
required_string_size, required_string_size,
...@@ -483,36 +513,36 @@ SEXP LGBM_BoosterGetEvalNames_R(LGBM_SE handle) { ...@@ -483,36 +513,36 @@ SEXP LGBM_BoosterGetEvalNames_R(LGBM_SE handle) {
R_API_END(); R_API_END();
} }
SEXP LGBM_BoosterGetEval_R(LGBM_SE handle, SEXP LGBM_BoosterGetEval_R(SEXP handle,
SEXP data_idx, SEXP data_idx,
SEXP out_result) { SEXP out_result) {
R_API_BEGIN(); R_API_BEGIN();
int len; int len;
CHECK_CALL(LGBM_BoosterGetEvalCounts(R_GET_PTR(handle), &len)); CHECK_CALL(LGBM_BoosterGetEvalCounts(R_ExternalPtrAddr(handle), &len));
double* ptr_ret = REAL(out_result); double* ptr_ret = REAL(out_result);
int out_len; int out_len;
CHECK_CALL(LGBM_BoosterGetEval(R_GET_PTR(handle), Rf_asInteger(data_idx), &out_len, ptr_ret)); CHECK_CALL(LGBM_BoosterGetEval(R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &out_len, ptr_ret));
CHECK_EQ(out_len, len); CHECK_EQ(out_len, len);
R_API_END(); R_API_END();
} }
SEXP LGBM_BoosterGetNumPredict_R(LGBM_SE handle, SEXP LGBM_BoosterGetNumPredict_R(SEXP handle,
SEXP data_idx, SEXP data_idx,
SEXP out) { SEXP out) {
R_API_BEGIN(); R_API_BEGIN();
int64_t len; int64_t len;
CHECK_CALL(LGBM_BoosterGetNumPredict(R_GET_PTR(handle), Rf_asInteger(data_idx), &len)); CHECK_CALL(LGBM_BoosterGetNumPredict(R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &len));
INTEGER(out)[0] = static_cast<int>(len); INTEGER(out)[0] = static_cast<int>(len);
R_API_END(); R_API_END();
} }
SEXP LGBM_BoosterGetPredict_R(LGBM_SE handle, SEXP LGBM_BoosterGetPredict_R(SEXP handle,
SEXP data_idx, SEXP data_idx,
SEXP out_result) { SEXP out_result) {
R_API_BEGIN(); R_API_BEGIN();
double* ptr_ret = REAL(out_result); double* ptr_ret = REAL(out_result);
int64_t out_len; int64_t out_len;
CHECK_CALL(LGBM_BoosterGetPredict(R_GET_PTR(handle), Rf_asInteger(data_idx), &out_len, ptr_ret)); CHECK_CALL(LGBM_BoosterGetPredict(R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &out_len, ptr_ret));
R_API_END(); R_API_END();
} }
...@@ -530,7 +560,7 @@ int GetPredictType(SEXP is_rawscore, SEXP is_leafidx, SEXP is_predcontrib) { ...@@ -530,7 +560,7 @@ int GetPredictType(SEXP is_rawscore, SEXP is_leafidx, SEXP is_predcontrib) {
return pred_type; return pred_type;
} }
SEXP LGBM_BoosterPredictForFile_R(LGBM_SE handle, SEXP LGBM_BoosterPredictForFile_R(SEXP handle,
SEXP data_filename, SEXP data_filename,
SEXP data_has_header, SEXP data_has_header,
SEXP is_rawscore, SEXP is_rawscore,
...@@ -542,13 +572,13 @@ SEXP LGBM_BoosterPredictForFile_R(LGBM_SE handle, ...@@ -542,13 +572,13 @@ SEXP LGBM_BoosterPredictForFile_R(LGBM_SE handle,
SEXP result_filename) { SEXP result_filename) {
R_API_BEGIN(); R_API_BEGIN();
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib); int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
CHECK_CALL(LGBM_BoosterPredictForFile(R_GET_PTR(handle), CHAR(Rf_asChar(data_filename)), CHECK_CALL(LGBM_BoosterPredictForFile(R_ExternalPtrAddr(handle), CHAR(Rf_asChar(data_filename)),
Rf_asInteger(data_has_header), pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), CHAR(Rf_asChar(parameter)), Rf_asInteger(data_has_header), pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), CHAR(Rf_asChar(parameter)),
CHAR(Rf_asChar(result_filename)))); CHAR(Rf_asChar(result_filename))));
R_API_END(); R_API_END();
} }
SEXP LGBM_BoosterCalcNumPredict_R(LGBM_SE handle, SEXP LGBM_BoosterCalcNumPredict_R(SEXP handle,
SEXP num_row, SEXP num_row,
SEXP is_rawscore, SEXP is_rawscore,
SEXP is_leafidx, SEXP is_leafidx,
...@@ -559,13 +589,13 @@ SEXP LGBM_BoosterCalcNumPredict_R(LGBM_SE handle, ...@@ -559,13 +589,13 @@ SEXP LGBM_BoosterCalcNumPredict_R(LGBM_SE handle,
R_API_BEGIN(); R_API_BEGIN();
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib); int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
int64_t len = 0; int64_t len = 0;
CHECK_CALL(LGBM_BoosterCalcNumPredict(R_GET_PTR(handle), Rf_asInteger(num_row), CHECK_CALL(LGBM_BoosterCalcNumPredict(R_ExternalPtrAddr(handle), Rf_asInteger(num_row),
pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), &len)); pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), &len));
INTEGER(out_len)[0] = static_cast<int>(len); INTEGER(out_len)[0] = static_cast<int>(len);
R_API_END(); R_API_END();
} }
SEXP LGBM_BoosterPredictForCSC_R(LGBM_SE handle, SEXP LGBM_BoosterPredictForCSC_R(SEXP handle,
SEXP indptr, SEXP indptr,
SEXP indices, SEXP indices,
SEXP data, SEXP data,
...@@ -591,14 +621,14 @@ SEXP LGBM_BoosterPredictForCSC_R(LGBM_SE handle, ...@@ -591,14 +621,14 @@ SEXP LGBM_BoosterPredictForCSC_R(LGBM_SE handle,
int64_t nrow = Rf_asInteger(num_row); int64_t nrow = Rf_asInteger(num_row);
double* ptr_ret = REAL(out_result); double* ptr_ret = REAL(out_result);
int64_t out_len; int64_t out_len;
CHECK_CALL(LGBM_BoosterPredictForCSC(R_GET_PTR(handle), CHECK_CALL(LGBM_BoosterPredictForCSC(R_ExternalPtrAddr(handle),
p_indptr, C_API_DTYPE_INT32, p_indices, p_indptr, C_API_DTYPE_INT32, p_indices,
p_data, C_API_DTYPE_FLOAT64, nindptr, ndata, p_data, C_API_DTYPE_FLOAT64, nindptr, ndata,
nrow, pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), CHAR(Rf_asChar(parameter)), &out_len, ptr_ret)); nrow, pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), CHAR(Rf_asChar(parameter)), &out_len, ptr_ret));
R_API_END(); R_API_END();
} }
SEXP LGBM_BoosterPredictForMat_R(LGBM_SE handle, SEXP LGBM_BoosterPredictForMat_R(SEXP handle,
SEXP data, SEXP data,
SEXP num_row, SEXP num_row,
SEXP num_col, SEXP num_col,
...@@ -618,23 +648,23 @@ SEXP LGBM_BoosterPredictForMat_R(LGBM_SE handle, ...@@ -618,23 +648,23 @@ SEXP LGBM_BoosterPredictForMat_R(LGBM_SE handle,
const double* p_mat = REAL(data); const double* p_mat = REAL(data);
double* ptr_ret = REAL(out_result); double* ptr_ret = REAL(out_result);
int64_t out_len; int64_t out_len;
CHECK_CALL(LGBM_BoosterPredictForMat(R_GET_PTR(handle), CHECK_CALL(LGBM_BoosterPredictForMat(R_ExternalPtrAddr(handle),
p_mat, C_API_DTYPE_FLOAT64, nrow, ncol, COL_MAJOR, p_mat, C_API_DTYPE_FLOAT64, nrow, ncol, COL_MAJOR,
pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), CHAR(Rf_asChar(parameter)), &out_len, ptr_ret)); pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), CHAR(Rf_asChar(parameter)), &out_len, ptr_ret));
R_API_END(); R_API_END();
} }
SEXP LGBM_BoosterSaveModel_R(LGBM_SE handle, SEXP LGBM_BoosterSaveModel_R(SEXP handle,
SEXP num_iteration, SEXP num_iteration,
SEXP feature_importance_type, SEXP feature_importance_type,
SEXP filename) { SEXP filename) {
R_API_BEGIN(); R_API_BEGIN();
CHECK_CALL(LGBM_BoosterSaveModel(R_GET_PTR(handle), 0, Rf_asInteger(num_iteration), Rf_asInteger(feature_importance_type), CHAR(Rf_asChar(filename)))); CHECK_CALL(LGBM_BoosterSaveModel(R_ExternalPtrAddr(handle), 0, Rf_asInteger(num_iteration), Rf_asInteger(feature_importance_type), CHAR(Rf_asChar(filename))));
R_API_END(); R_API_END();
} }
SEXP LGBM_BoosterSaveModelToString_R(LGBM_SE handle, SEXP LGBM_BoosterSaveModelToString_R(SEXP handle,
SEXP num_iteration, SEXP num_iteration,
SEXP feature_importance_type) { SEXP feature_importance_type) {
SEXP model_str; SEXP model_str;
...@@ -644,11 +674,11 @@ SEXP LGBM_BoosterSaveModelToString_R(LGBM_SE handle, ...@@ -644,11 +674,11 @@ SEXP LGBM_BoosterSaveModelToString_R(LGBM_SE handle,
int64_t num_iter = Rf_asInteger(num_iteration); int64_t num_iter = Rf_asInteger(num_iteration);
int64_t importance_type = Rf_asInteger(feature_importance_type); int64_t importance_type = Rf_asInteger(feature_importance_type);
std::vector<char> inner_char_buf(buf_len); std::vector<char> inner_char_buf(buf_len);
CHECK_CALL(LGBM_BoosterSaveModelToString(R_GET_PTR(handle), 0, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data())); CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data()));
// if the model string was larger than the initial buffer, allocate a bigger buffer and try again // if the model string was larger than the initial buffer, allocate a bigger buffer and try again
if (out_len > buf_len) { if (out_len > buf_len) {
inner_char_buf.resize(out_len); inner_char_buf.resize(out_len);
CHECK_CALL(LGBM_BoosterSaveModelToString(R_GET_PTR(handle), 0, num_iter, importance_type, out_len, &out_len, inner_char_buf.data())); CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, out_len, &out_len, inner_char_buf.data()));
} }
model_str = PROTECT(Rf_allocVector(STRSXP, 1)); model_str = PROTECT(Rf_allocVector(STRSXP, 1));
SET_STRING_ELT(model_str, 0, Rf_mkChar(inner_char_buf.data())); SET_STRING_ELT(model_str, 0, Rf_mkChar(inner_char_buf.data()));
...@@ -657,7 +687,7 @@ SEXP LGBM_BoosterSaveModelToString_R(LGBM_SE handle, ...@@ -657,7 +687,7 @@ SEXP LGBM_BoosterSaveModelToString_R(LGBM_SE handle,
R_API_END(); R_API_END();
} }
SEXP LGBM_BoosterDumpModel_R(LGBM_SE handle, SEXP LGBM_BoosterDumpModel_R(SEXP handle,
SEXP num_iteration, SEXP num_iteration,
SEXP feature_importance_type) { SEXP feature_importance_type) {
SEXP model_str; SEXP model_str;
...@@ -667,11 +697,11 @@ SEXP LGBM_BoosterDumpModel_R(LGBM_SE handle, ...@@ -667,11 +697,11 @@ SEXP LGBM_BoosterDumpModel_R(LGBM_SE handle,
int64_t num_iter = Rf_asInteger(num_iteration); int64_t num_iter = Rf_asInteger(num_iteration);
int64_t importance_type = Rf_asInteger(feature_importance_type); int64_t importance_type = Rf_asInteger(feature_importance_type);
std::vector<char> inner_char_buf(buf_len); std::vector<char> inner_char_buf(buf_len);
CHECK_CALL(LGBM_BoosterDumpModel(R_GET_PTR(handle), 0, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data())); CHECK_CALL(LGBM_BoosterDumpModel(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data()));
// if the model string was larger than the initial buffer, allocate a bigger buffer and try again // if the model string was larger than the initial buffer, allocate a bigger buffer and try again
if (out_len > buf_len) { if (out_len > buf_len) {
inner_char_buf.resize(out_len); inner_char_buf.resize(out_len);
CHECK_CALL(LGBM_BoosterDumpModel(R_GET_PTR(handle), 0, num_iter, importance_type, out_len, &out_len, inner_char_buf.data())); CHECK_CALL(LGBM_BoosterDumpModel(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, out_len, &out_len, inner_char_buf.data()));
} }
model_str = PROTECT(Rf_allocVector(STRSXP, 1)); model_str = PROTECT(Rf_allocVector(STRSXP, 1));
SET_STRING_ELT(model_str, 0, Rf_mkChar(inner_char_buf.data())); SET_STRING_ELT(model_str, 0, Rf_mkChar(inner_char_buf.data()));
...@@ -683,10 +713,11 @@ SEXP LGBM_BoosterDumpModel_R(LGBM_SE handle, ...@@ -683,10 +713,11 @@ SEXP LGBM_BoosterDumpModel_R(LGBM_SE handle,
// .Call() calls // .Call() calls
static const R_CallMethodDef CallEntries[] = { static const R_CallMethodDef CallEntries[] = {
{"LGBM_GetLastError_R" , (DL_FUNC) &LGBM_GetLastError_R , 0}, {"LGBM_GetLastError_R" , (DL_FUNC) &LGBM_GetLastError_R , 0},
{"LGBM_DatasetCreateFromFile_R" , (DL_FUNC) &LGBM_DatasetCreateFromFile_R , 4}, {"LGBM_HandleIsNull_R" , (DL_FUNC) &LGBM_HandleIsNull_R , 1},
{"LGBM_DatasetCreateFromCSC_R" , (DL_FUNC) &LGBM_DatasetCreateFromCSC_R , 9}, {"LGBM_DatasetCreateFromFile_R" , (DL_FUNC) &LGBM_DatasetCreateFromFile_R , 3},
{"LGBM_DatasetCreateFromMat_R" , (DL_FUNC) &LGBM_DatasetCreateFromMat_R , 6}, {"LGBM_DatasetCreateFromCSC_R" , (DL_FUNC) &LGBM_DatasetCreateFromCSC_R , 8},
{"LGBM_DatasetGetSubset_R" , (DL_FUNC) &LGBM_DatasetGetSubset_R , 5}, {"LGBM_DatasetCreateFromMat_R" , (DL_FUNC) &LGBM_DatasetCreateFromMat_R , 5},
{"LGBM_DatasetGetSubset_R" , (DL_FUNC) &LGBM_DatasetGetSubset_R , 4},
{"LGBM_DatasetSetFeatureNames_R" , (DL_FUNC) &LGBM_DatasetSetFeatureNames_R , 2}, {"LGBM_DatasetSetFeatureNames_R" , (DL_FUNC) &LGBM_DatasetSetFeatureNames_R , 2},
{"LGBM_DatasetGetFeatureNames_R" , (DL_FUNC) &LGBM_DatasetGetFeatureNames_R , 1}, {"LGBM_DatasetGetFeatureNames_R" , (DL_FUNC) &LGBM_DatasetGetFeatureNames_R , 1},
{"LGBM_DatasetSaveBinary_R" , (DL_FUNC) &LGBM_DatasetSaveBinary_R , 2}, {"LGBM_DatasetSaveBinary_R" , (DL_FUNC) &LGBM_DatasetSaveBinary_R , 2},
...@@ -697,10 +728,10 @@ static const R_CallMethodDef CallEntries[] = { ...@@ -697,10 +728,10 @@ static const R_CallMethodDef CallEntries[] = {
{"LGBM_DatasetUpdateParamChecking_R", (DL_FUNC) &LGBM_DatasetUpdateParamChecking_R, 2}, {"LGBM_DatasetUpdateParamChecking_R", (DL_FUNC) &LGBM_DatasetUpdateParamChecking_R, 2},
{"LGBM_DatasetGetNumData_R" , (DL_FUNC) &LGBM_DatasetGetNumData_R , 2}, {"LGBM_DatasetGetNumData_R" , (DL_FUNC) &LGBM_DatasetGetNumData_R , 2},
{"LGBM_DatasetGetNumFeature_R" , (DL_FUNC) &LGBM_DatasetGetNumFeature_R , 2}, {"LGBM_DatasetGetNumFeature_R" , (DL_FUNC) &LGBM_DatasetGetNumFeature_R , 2},
{"LGBM_BoosterCreate_R" , (DL_FUNC) &LGBM_BoosterCreate_R , 3}, {"LGBM_BoosterCreate_R" , (DL_FUNC) &LGBM_BoosterCreate_R , 2},
{"LGBM_BoosterFree_R" , (DL_FUNC) &LGBM_BoosterFree_R , 1}, {"LGBM_BoosterFree_R" , (DL_FUNC) &LGBM_BoosterFree_R , 1},
{"LGBM_BoosterCreateFromModelfile_R", (DL_FUNC) &LGBM_BoosterCreateFromModelfile_R, 2}, {"LGBM_BoosterCreateFromModelfile_R", (DL_FUNC) &LGBM_BoosterCreateFromModelfile_R, 1},
{"LGBM_BoosterLoadModelFromString_R", (DL_FUNC) &LGBM_BoosterLoadModelFromString_R, 2}, {"LGBM_BoosterLoadModelFromString_R", (DL_FUNC) &LGBM_BoosterLoadModelFromString_R, 1},
{"LGBM_BoosterMerge_R" , (DL_FUNC) &LGBM_BoosterMerge_R , 2}, {"LGBM_BoosterMerge_R" , (DL_FUNC) &LGBM_BoosterMerge_R , 2},
{"LGBM_BoosterAddValidData_R" , (DL_FUNC) &LGBM_BoosterAddValidData_R , 2}, {"LGBM_BoosterAddValidData_R" , (DL_FUNC) &LGBM_BoosterAddValidData_R , 2},
{"LGBM_BoosterResetTrainingData_R" , (DL_FUNC) &LGBM_BoosterResetTrainingData_R , 2}, {"LGBM_BoosterResetTrainingData_R" , (DL_FUNC) &LGBM_BoosterResetTrainingData_R , 2},
......
...@@ -11,14 +11,21 @@ ...@@ -11,14 +11,21 @@
#define R_USE_C99_IN_CXX #define R_USE_C99_IN_CXX
#include <Rinternals.h> #include <Rinternals.h>
#include "R_object_helper.h"
/*! /*!
* \brief get string message of the last error * \brief get string message of the last error
* \return err_msg string with error information * \return err_msg string with error information
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_GetLastError_R(); LIGHTGBM_C_EXPORT SEXP LGBM_GetLastError_R();
/*!
* \brief check if an R external pointer (like a Booster or Dataset handle) is a null pointer
* \param handle handle for a Booster, Dataset, or Predictor
* \return R logical, TRUE if the handle is a null pointer
*/
LIGHTGBM_C_EXPORT SEXP LGBM_HandleIsNull_R(
SEXP handle
);
// --- start Dataset interface // --- start Dataset interface
/*! /*!
...@@ -26,14 +33,12 @@ LIGHTGBM_C_EXPORT SEXP LGBM_GetLastError_R(); ...@@ -26,14 +33,12 @@ LIGHTGBM_C_EXPORT SEXP LGBM_GetLastError_R();
* \param filename the name of the file * \param filename the name of the file
* \param parameters additional parameters * \param parameters additional parameters
* \param reference used to align bin mapper with other Dataset, nullptr means not used * \param reference used to align bin mapper with other Dataset, nullptr means not used
* \param out created Dataset * \return Dataset handle
* \return R NULL value
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_DatasetCreateFromFile_R( LIGHTGBM_C_EXPORT SEXP LGBM_DatasetCreateFromFile_R(
SEXP filename, SEXP filename,
SEXP parameters, SEXP parameters,
LGBM_SE reference, SEXP reference
LGBM_SE out
); );
/*! /*!
...@@ -46,8 +51,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetCreateFromFile_R( ...@@ -46,8 +51,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetCreateFromFile_R(
* \param num_row number of rows * \param num_row number of rows
* \param parameters additional parameters * \param parameters additional parameters
* \param reference used to align bin mapper with other Dataset, nullptr means not used * \param reference used to align bin mapper with other Dataset, nullptr means not used
* \param out created Dataset * \return Dataset handle
* \return R NULL value
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_DatasetCreateFromCSC_R( LIGHTGBM_C_EXPORT SEXP LGBM_DatasetCreateFromCSC_R(
SEXP indptr, SEXP indptr,
...@@ -57,8 +61,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetCreateFromCSC_R( ...@@ -57,8 +61,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetCreateFromCSC_R(
SEXP nelem, SEXP nelem,
SEXP num_row, SEXP num_row,
SEXP parameters, SEXP parameters,
LGBM_SE reference, SEXP reference
LGBM_SE out
); );
/*! /*!
...@@ -68,16 +71,14 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetCreateFromCSC_R( ...@@ -68,16 +71,14 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetCreateFromCSC_R(
* \param num_col number columns * \param num_col number columns
* \param parameters additional parameters * \param parameters additional parameters
* \param reference used to align bin mapper with other Dataset, nullptr means not used * \param reference used to align bin mapper with other Dataset, nullptr means not used
* \param out created Dataset * \return Dataset handle
* \return R NULL value
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_DatasetCreateFromMat_R( LIGHTGBM_C_EXPORT SEXP LGBM_DatasetCreateFromMat_R(
SEXP data, SEXP data,
SEXP num_row, SEXP num_row,
SEXP num_col, SEXP num_col,
SEXP parameters, SEXP parameters,
LGBM_SE reference, SEXP reference
LGBM_SE out
); );
/*! /*!
...@@ -86,15 +87,13 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetCreateFromMat_R( ...@@ -86,15 +87,13 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetCreateFromMat_R(
* \param used_row_indices Indices used in subset * \param used_row_indices Indices used in subset
* \param len_used_row_indices length of Indices used in subset * \param len_used_row_indices length of Indices used in subset
* \param parameters additional parameters * \param parameters additional parameters
* \param out created Dataset * \return Dataset handle
* \return R NULL value
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_DatasetGetSubset_R( LIGHTGBM_C_EXPORT SEXP LGBM_DatasetGetSubset_R(
LGBM_SE handle, SEXP handle,
SEXP used_row_indices, SEXP used_row_indices,
SEXP len_used_row_indices, SEXP len_used_row_indices,
SEXP parameters, SEXP parameters
LGBM_SE out
); );
/*! /*!
...@@ -104,7 +103,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetGetSubset_R( ...@@ -104,7 +103,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetGetSubset_R(
* \return R character vector of feature names * \return R character vector of feature names
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_DatasetSetFeatureNames_R( LIGHTGBM_C_EXPORT SEXP LGBM_DatasetSetFeatureNames_R(
LGBM_SE handle, SEXP handle,
SEXP feature_names SEXP feature_names
); );
...@@ -114,7 +113,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetSetFeatureNames_R( ...@@ -114,7 +113,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetSetFeatureNames_R(
* \return an R character vector with feature names from the Dataset or NULL if no feature names * \return an R character vector with feature names from the Dataset or NULL if no feature names
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_DatasetGetFeatureNames_R( LIGHTGBM_C_EXPORT SEXP LGBM_DatasetGetFeatureNames_R(
LGBM_SE handle SEXP handle
); );
/*! /*!
...@@ -124,7 +123,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetGetFeatureNames_R( ...@@ -124,7 +123,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetGetFeatureNames_R(
* \return R NULL value * \return R NULL value
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_DatasetSaveBinary_R( LIGHTGBM_C_EXPORT SEXP LGBM_DatasetSaveBinary_R(
LGBM_SE handle, SEXP handle,
SEXP filename SEXP filename
); );
...@@ -134,7 +133,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetSaveBinary_R( ...@@ -134,7 +133,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetSaveBinary_R(
* \return R NULL value * \return R NULL value
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_DatasetFree_R( LIGHTGBM_C_EXPORT SEXP LGBM_DatasetFree_R(
LGBM_SE handle SEXP handle
); );
/*! /*!
...@@ -148,7 +147,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetFree_R( ...@@ -148,7 +147,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetFree_R(
* \return R NULL value * \return R NULL value
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_DatasetSetField_R( LIGHTGBM_C_EXPORT SEXP LGBM_DatasetSetField_R(
LGBM_SE handle, SEXP handle,
SEXP field_name, SEXP field_name,
SEXP field_data, SEXP field_data,
SEXP num_element SEXP num_element
...@@ -162,7 +161,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetSetField_R( ...@@ -162,7 +161,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetSetField_R(
* \return R NULL value * \return R NULL value
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_DatasetGetFieldSize_R( LIGHTGBM_C_EXPORT SEXP LGBM_DatasetGetFieldSize_R(
LGBM_SE handle, SEXP handle,
SEXP field_name, SEXP field_name,
SEXP out SEXP out
); );
...@@ -175,7 +174,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetGetFieldSize_R( ...@@ -175,7 +174,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetGetFieldSize_R(
* \return R NULL value * \return R NULL value
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_DatasetGetField_R( LIGHTGBM_C_EXPORT SEXP LGBM_DatasetGetField_R(
LGBM_SE handle, SEXP handle,
SEXP field_name, SEXP field_name,
SEXP field_data SEXP field_data
); );
...@@ -199,7 +198,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetUpdateParamChecking_R( ...@@ -199,7 +198,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetUpdateParamChecking_R(
* \return R NULL value * \return R NULL value
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_DatasetGetNumData_R( LIGHTGBM_C_EXPORT SEXP LGBM_DatasetGetNumData_R(
LGBM_SE handle, SEXP handle,
SEXP out SEXP out
); );
...@@ -210,7 +209,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetGetNumData_R( ...@@ -210,7 +209,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetGetNumData_R(
* \return R NULL value * \return R NULL value
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_DatasetGetNumFeature_R( LIGHTGBM_C_EXPORT SEXP LGBM_DatasetGetNumFeature_R(
LGBM_SE handle, SEXP handle,
SEXP out SEXP out
); );
...@@ -220,13 +219,11 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetGetNumFeature_R( ...@@ -220,13 +219,11 @@ LIGHTGBM_C_EXPORT SEXP LGBM_DatasetGetNumFeature_R(
* \brief create a new boosting learner * \brief create a new boosting learner
* \param train_data training Dataset * \param train_data training Dataset
* \param parameters format: 'key1=value1 key2=value2' * \param parameters format: 'key1=value1 key2=value2'
* \param out handle of created Booster * \return Booster handle
* \return R NULL value
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterCreate_R( LIGHTGBM_C_EXPORT SEXP LGBM_BoosterCreate_R(
LGBM_SE train_data, SEXP train_data,
SEXP parameters, SEXP parameters
LGBM_SE out
); );
/*! /*!
...@@ -235,29 +232,25 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterCreate_R( ...@@ -235,29 +232,25 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterCreate_R(
* \return R NULL value * \return R NULL value
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterFree_R( LIGHTGBM_C_EXPORT SEXP LGBM_BoosterFree_R(
LGBM_SE handle SEXP handle
); );
/*! /*!
* \brief load an existing Booster from model file * \brief load an existing Booster from model file
* \param filename filename of model * \param filename filename of model
* \param out handle of created Booster * \return Booster handle
* \return R NULL value
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterCreateFromModelfile_R( LIGHTGBM_C_EXPORT SEXP LGBM_BoosterCreateFromModelfile_R(
SEXP filename, SEXP filename
LGBM_SE out
); );
/*! /*!
* \brief load an existing Booster from a string * \brief load an existing Booster from a string
* \param model_str string containing the model * \param model_str string containing the model
* \param out handle of created Booster * \return Booster handle
* \return R NULL value
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterLoadModelFromString_R( LIGHTGBM_C_EXPORT SEXP LGBM_BoosterLoadModelFromString_R(
SEXP model_str, SEXP model_str
LGBM_SE out
); );
/*! /*!
...@@ -267,8 +260,8 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterLoadModelFromString_R( ...@@ -267,8 +260,8 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterLoadModelFromString_R(
* \return R NULL value * \return R NULL value
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterMerge_R( LIGHTGBM_C_EXPORT SEXP LGBM_BoosterMerge_R(
LGBM_SE handle, SEXP handle,
LGBM_SE other_handle SEXP other_handle
); );
/*! /*!
...@@ -278,8 +271,8 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterMerge_R( ...@@ -278,8 +271,8 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterMerge_R(
* \return R NULL value * \return R NULL value
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterAddValidData_R( LIGHTGBM_C_EXPORT SEXP LGBM_BoosterAddValidData_R(
LGBM_SE handle, SEXP handle,
LGBM_SE valid_data SEXP valid_data
); );
/*! /*!
...@@ -289,8 +282,8 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterAddValidData_R( ...@@ -289,8 +282,8 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterAddValidData_R(
* \return R NULL value * \return R NULL value
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterResetTrainingData_R( LIGHTGBM_C_EXPORT SEXP LGBM_BoosterResetTrainingData_R(
LGBM_SE handle, SEXP handle,
LGBM_SE train_data SEXP train_data
); );
/*! /*!
...@@ -300,7 +293,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterResetTrainingData_R( ...@@ -300,7 +293,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterResetTrainingData_R(
* \return R NULL value * \return R NULL value
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterResetParameter_R( LIGHTGBM_C_EXPORT SEXP LGBM_BoosterResetParameter_R(
LGBM_SE handle, SEXP handle,
SEXP parameters SEXP parameters
); );
...@@ -311,7 +304,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterResetParameter_R( ...@@ -311,7 +304,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterResetParameter_R(
* \return R NULL value * \return R NULL value
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetNumClasses_R( LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetNumClasses_R(
LGBM_SE handle, SEXP handle,
SEXP out SEXP out
); );
...@@ -321,7 +314,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetNumClasses_R( ...@@ -321,7 +314,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetNumClasses_R(
* \return R NULL value * \return R NULL value
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterUpdateOneIter_R( LIGHTGBM_C_EXPORT SEXP LGBM_BoosterUpdateOneIter_R(
LGBM_SE handle SEXP handle
); );
/*! /*!
...@@ -334,7 +327,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterUpdateOneIter_R( ...@@ -334,7 +327,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterUpdateOneIter_R(
* \return R NULL value * \return R NULL value
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterUpdateOneIterCustom_R( LIGHTGBM_C_EXPORT SEXP LGBM_BoosterUpdateOneIterCustom_R(
LGBM_SE handle, SEXP handle,
SEXP grad, SEXP grad,
SEXP hess, SEXP hess,
SEXP len SEXP len
...@@ -346,7 +339,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterUpdateOneIterCustom_R( ...@@ -346,7 +339,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterUpdateOneIterCustom_R(
* \return R NULL value * \return R NULL value
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterRollbackOneIter_R( LIGHTGBM_C_EXPORT SEXP LGBM_BoosterRollbackOneIter_R(
LGBM_SE handle SEXP handle
); );
/*! /*!
...@@ -356,7 +349,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterRollbackOneIter_R( ...@@ -356,7 +349,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterRollbackOneIter_R(
* \return R NULL value * \return R NULL value
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetCurrentIteration_R( LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetCurrentIteration_R(
LGBM_SE handle, SEXP handle,
SEXP out SEXP out
); );
...@@ -367,7 +360,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetCurrentIteration_R( ...@@ -367,7 +360,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetCurrentIteration_R(
* \return R NULL value * \return R NULL value
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetUpperBoundValue_R( LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetUpperBoundValue_R(
LGBM_SE handle, SEXP handle,
SEXP out_result SEXP out_result
); );
...@@ -378,7 +371,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetUpperBoundValue_R( ...@@ -378,7 +371,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetUpperBoundValue_R(
* \return R NULL value * \return R NULL value
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetLowerBoundValue_R( LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetLowerBoundValue_R(
LGBM_SE handle, SEXP handle,
SEXP out_result SEXP out_result
); );
...@@ -388,7 +381,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetLowerBoundValue_R( ...@@ -388,7 +381,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetLowerBoundValue_R(
* \return R character vector with names of eval metrics * \return R character vector with names of eval metrics
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetEvalNames_R( LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetEvalNames_R(
LGBM_SE handle SEXP handle
); );
/*! /*!
...@@ -399,7 +392,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetEvalNames_R( ...@@ -399,7 +392,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetEvalNames_R(
* \return R NULL value * \return R NULL value
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetEval_R( LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetEval_R(
LGBM_SE handle, SEXP handle,
SEXP data_idx, SEXP data_idx,
SEXP out_result SEXP out_result
); );
...@@ -412,7 +405,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetEval_R( ...@@ -412,7 +405,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetEval_R(
* \return R NULL value * \return R NULL value
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetNumPredict_R( LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetNumPredict_R(
LGBM_SE handle, SEXP handle,
SEXP data_idx, SEXP data_idx,
SEXP out SEXP out
); );
...@@ -426,7 +419,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetNumPredict_R( ...@@ -426,7 +419,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetNumPredict_R(
* \return R NULL value * \return R NULL value
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetPredict_R( LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetPredict_R(
LGBM_SE handle, SEXP handle,
SEXP data_idx, SEXP data_idx,
SEXP out_result SEXP out_result
); );
...@@ -448,7 +441,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetPredict_R( ...@@ -448,7 +441,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetPredict_R(
* \return R NULL value * \return R NULL value
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterPredictForFile_R( LIGHTGBM_C_EXPORT SEXP LGBM_BoosterPredictForFile_R(
LGBM_SE handle, SEXP handle,
SEXP data_filename, SEXP data_filename,
SEXP data_has_header, SEXP data_has_header,
SEXP is_rawscore, SEXP is_rawscore,
...@@ -475,7 +468,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterPredictForFile_R( ...@@ -475,7 +468,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterPredictForFile_R(
* \return R NULL value * \return R NULL value
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterCalcNumPredict_R( LIGHTGBM_C_EXPORT SEXP LGBM_BoosterCalcNumPredict_R(
LGBM_SE handle, SEXP handle,
SEXP num_row, SEXP num_row,
SEXP is_rawscore, SEXP is_rawscore,
SEXP is_leafidx, SEXP is_leafidx,
...@@ -509,7 +502,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterCalcNumPredict_R( ...@@ -509,7 +502,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterCalcNumPredict_R(
* \return R NULL value * \return R NULL value
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterPredictForCSC_R( LIGHTGBM_C_EXPORT SEXP LGBM_BoosterPredictForCSC_R(
LGBM_SE handle, SEXP handle,
SEXP indptr, SEXP indptr,
SEXP indices, SEXP indices,
SEXP data, SEXP data,
...@@ -546,7 +539,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterPredictForCSC_R( ...@@ -546,7 +539,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterPredictForCSC_R(
* \return R NULL value * \return R NULL value
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterPredictForMat_R( LIGHTGBM_C_EXPORT SEXP LGBM_BoosterPredictForMat_R(
LGBM_SE handle, SEXP handle,
SEXP data, SEXP data,
SEXP num_row, SEXP num_row,
SEXP num_col, SEXP num_col,
...@@ -568,7 +561,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterPredictForMat_R( ...@@ -568,7 +561,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterPredictForMat_R(
* \return R NULL value * \return R NULL value
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterSaveModel_R( LIGHTGBM_C_EXPORT SEXP LGBM_BoosterSaveModel_R(
LGBM_SE handle, SEXP handle,
SEXP num_iteration, SEXP num_iteration,
SEXP feature_importance_type, SEXP feature_importance_type,
SEXP filename SEXP filename
...@@ -582,7 +575,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterSaveModel_R( ...@@ -582,7 +575,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterSaveModel_R(
* \return R character vector (length=1) with model string * \return R character vector (length=1) with model string
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterSaveModelToString_R( LIGHTGBM_C_EXPORT SEXP LGBM_BoosterSaveModelToString_R(
LGBM_SE handle, SEXP handle,
SEXP num_iteration, SEXP num_iteration,
SEXP feature_importance_type SEXP feature_importance_type
); );
...@@ -595,7 +588,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterSaveModelToString_R( ...@@ -595,7 +588,7 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterSaveModelToString_R(
* \return R character vector (length=1) with model JSON * \return R character vector (length=1) with model JSON
*/ */
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterDumpModel_R( LIGHTGBM_C_EXPORT SEXP LGBM_BoosterDumpModel_R(
LGBM_SE handle, SEXP handle,
SEXP num_iteration, SEXP num_iteration,
SEXP feature_importance_type SEXP feature_importance_type
); );
......
...@@ -73,16 +73,14 @@ test_that("lgb.Dataset: nrow is correct for a very sparse matrix", { ...@@ -73,16 +73,14 @@ test_that("lgb.Dataset: nrow is correct for a very sparse matrix", {
test_that("lgb.Dataset: Dataset should be able to construct from matrix and return non-null handle", { test_that("lgb.Dataset: Dataset should be able to construct from matrix and return non-null handle", {
rawData <- matrix(runif(1000L), ncol = 10L) rawData <- matrix(runif(1000L), ncol = 10L)
handle <- lgb.null.handle()
ref_handle <- NULL ref_handle <- NULL
.Call( handle <- .Call(
LGBM_DatasetCreateFromMat_R LGBM_DatasetCreateFromMat_R
, rawData , rawData
, nrow(rawData) , nrow(rawData)
, ncol(rawData) , ncol(rawData)
, lightgbm:::lgb.params2str(params = list()) , lightgbm:::lgb.params2str(params = list())
, ref_handle , ref_handle
, handle
) )
expect_false(is.na(handle)) expect_false(is.na(handle))
.Call(LGBM_DatasetFree_R, handle) .Call(LGBM_DatasetFree_R, handle)
......
...@@ -324,7 +324,7 @@ result <- file.copy( ...@@ -324,7 +324,7 @@ result <- file.copy(
, overwrite = TRUE , overwrite = TRUE
) )
.handle_result(result) .handle_result(result)
for (src_file in c("lightgbm_R.cpp", "lightgbm_R.h", "R_object_helper.h")) { for (src_file in c("lightgbm_R.cpp", "lightgbm_R.h")) {
result <- file.copy( result <- file.copy(
from = file.path(TEMP_SOURCE_DIR, src_file) from = file.path(TEMP_SOURCE_DIR, src_file)
, to = file.path(TEMP_SOURCE_DIR, "src", src_file) , to = file.path(TEMP_SOURCE_DIR, "src", src_file)
...@@ -386,11 +386,6 @@ dynlib_line <- grep( ...@@ -386,11 +386,6 @@ dynlib_line <- grep(
c_api_contents <- readLines(file.path(TEMP_SOURCE_DIR, "src", "lightgbm_R.h")) c_api_contents <- readLines(file.path(TEMP_SOURCE_DIR, "src", "lightgbm_R.h"))
c_api_contents <- c_api_contents[grepl("^LIGHTGBM_C_EXPORT", c_api_contents)] c_api_contents <- c_api_contents[grepl("^LIGHTGBM_C_EXPORT", c_api_contents)]
c_api_contents <- gsub(
pattern = "LIGHTGBM_C_EXPORT LGBM_SE "
, replacement = ""
, x = c_api_contents
)
c_api_contents <- gsub( c_api_contents <- gsub(
pattern = "LIGHTGBM_C_EXPORT SEXP " pattern = "LIGHTGBM_C_EXPORT SEXP "
, replacement = "" , replacement = ""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment