Commit 2367b463 authored by Guolin Ke's avatar Guolin Ke
Browse files

fix R crash problem. (#784)

* fix r random crash problem.

* fix R error msg.
parent 01a408e7
...@@ -31,8 +31,7 @@ Booster <- R6Class( ...@@ -31,8 +31,7 @@ Booster <- R6Class(
# Create parameters and handle # Create parameters and handle
params <- append(params, list(...)) params <- append(params, list(...))
params_str <- lgb.params2str(params) params_str <- lgb.params2str(params)
handle <- lgb.new.handle() handle <- 0.0
# Check if training dataset is not null # Check if training dataset is not null
if (!is.null(train_set)) { if (!is.null(train_set)) {
...@@ -40,10 +39,10 @@ Booster <- R6Class( ...@@ -40,10 +39,10 @@ Booster <- R6Class(
if (!lgb.check.r6.class(train_set, "lgb.Dataset")) { if (!lgb.check.r6.class(train_set, "lgb.Dataset")) {
stop("lgb.Booster: Can only use lgb.Dataset as training data") stop("lgb.Booster: Can only use lgb.Dataset as training data")
} }
# Store booster handle # Store booster handle
handle <- lgb.call("LGBM_BoosterCreate_R", ret = handle, train_set$.__enclos_env__$private$get_handle(), params_str) handle <- lgb.call("LGBM_BoosterCreate_R", ret = handle, train_set$.__enclos_env__$private$get_handle(), params_str)
# Create private booster information # Create private booster information
private$train_set <- train_set private$train_set <- train_set
private$num_dataset <- 1 private$num_dataset <- 1
...@@ -93,14 +92,17 @@ Booster <- R6Class( ...@@ -93,14 +92,17 @@ Booster <- R6Class(
stop("lgb.Booster: Need at least either training dataset, model file, or model_str to create booster instance") stop("lgb.Booster: Need at least either training dataset, model file, or model_str to create booster instance")
} }
if (lgb.is.null.handle(handle)) {
# Create class stop("lgb.Booster: cannot create Booster handle")
class(handle) <- "lgb.Booster.handle" } else {
private$handle <- handle # Create class
private$num_class <- 1L class(handle) <- "lgb.Booster.handle"
private$num_class <- lgb.call("LGBM_BoosterGetNumClasses_R", private$handle <- handle
ret = private$num_class, private$num_class <- 1L
private$handle) private$num_class <- lgb.call("LGBM_BoosterGetNumClasses_R",
ret = private$num_class,
private$handle)
}
}, },
......
...@@ -167,7 +167,7 @@ Dataset <- R6Class( ...@@ -167,7 +167,7 @@ Dataset <- R6Class(
if (!is.null(private$reference)) { if (!is.null(private$reference)) {
ref_handle <- private$reference$.__enclos_env__$private$get_handle() ref_handle <- private$reference$.__enclos_env__$private$get_handle()
} }
handle <- lgb.new.handle() handle <- 0.0
# Not subsetting # Not subsetting
if (is.null(private$used_indices)) { if (is.null(private$used_indices)) {
...@@ -229,7 +229,9 @@ Dataset <- R6Class( ...@@ -229,7 +229,9 @@ Dataset <- R6Class(
params_str) params_str)
} }
if (lgb.is.null.handle(handle)) {
stop("lgb.Dataset.construct: cannot create Dataset handle")
}
# Setup class and private type # Setup class and private type
class(handle) <- "lgb.Dataset.handle" class(handle) <- "lgb.Dataset.handle"
private$handle <- handle private$handle <- handle
......
...@@ -22,7 +22,7 @@ Predictor <- R6Class( ...@@ -22,7 +22,7 @@ Predictor <- R6Class(
params <- list(...) params <- list(...)
private$params <- lgb.params2str(params) private$params <- lgb.params2str(params)
# Create new lgb handle # Create new lgb handle
handle <- lgb.new.handle() handle <- 0.0
# Check if handle is a character # Check if handle is a character
if (is.character(modelfile)) { if (is.character(modelfile)) {
......
...@@ -6,13 +6,8 @@ lgb.is.Dataset <- function(x) { ...@@ -6,13 +6,8 @@ lgb.is.Dataset <- function(x) {
lgb.check.r6.class(x, "lgb.Dataset") # Checking if it is of class lgb.Dataset or not lgb.check.r6.class(x, "lgb.Dataset") # Checking if it is of class lgb.Dataset or not
} }
# use 64bit data to store address
lgb.new.handle <- function() {
0.0 # Return numeric type in R
}
lgb.is.null.handle <- function(x) { lgb.is.null.handle <- function(x) {
is.null(x) || x == 0 # Is it null or zero? is.null(x) || x == 0.0
} }
lgb.encode.char <- function(arr, len) { lgb.encode.char <- function(arr, len) {
...@@ -25,9 +20,8 @@ lgb.encode.char <- function(arr, len) { ...@@ -25,9 +20,8 @@ lgb.encode.char <- function(arr, len) {
} }
lgb.call <- function(fun_name, ret, ...) { lgb.call <- function(fun_name, ret, ...) {
# Set call state to a zero value # Set call state to a zero value
call_state <- 0L call_state <- as.integer(0L)
# Check for a ret call # Check for a ret call
if (!is.null(ret)) { if (!is.null(ret)) {
...@@ -35,10 +29,10 @@ lgb.call <- function(fun_name, ret, ...) { ...@@ -35,10 +29,10 @@ lgb.call <- function(fun_name, ret, ...) {
} else { } else {
call_state <- .Call(fun_name, ..., call_state, PACKAGE = "lib_lightgbm") # Call without ret call_state <- .Call(fun_name, ..., call_state, PACKAGE = "lib_lightgbm") # Call without ret
} }
call_state <- as.integer(call_state)
# Check for call state value post call # Check for call state value post call
if (call_state != 0L) { if (call_state != 0L) {
# Perform text error buffering # Perform text error buffering
buf_len <- 200L buf_len <- 200L
act_len <- 0L act_len <- 0L
...@@ -58,9 +52,9 @@ lgb.call <- function(fun_name, ret, ...) { ...@@ -58,9 +52,9 @@ lgb.call <- function(fun_name, ret, ...) {
# Return error # Return error
stop(paste0("api error: ", lgb.encode.char(err_msg, act_len))) stop(paste0("api error: ", lgb.encode.char(err_msg, act_len)))
} }
return(ret) return(ret)
} }
......
...@@ -26,8 +26,8 @@ ...@@ -26,8 +26,8 @@
#define CHECK_CALL(x) \ #define CHECK_CALL(x) \
if ((x) != 0) { \ if ((x) != 0) { \
R_INT_PTR(call_state)[0] = -1; \ R_INT_PTR(call_state)[0] = -1;\
return call_state; \ return call_state;\
} }
using namespace LightGBM; using namespace LightGBM;
...@@ -56,7 +56,7 @@ LGBM_SE LGBM_DatasetCreateFromFile_R(LGBM_SE filename, ...@@ -56,7 +56,7 @@ LGBM_SE LGBM_DatasetCreateFromFile_R(LGBM_SE filename,
LGBM_SE call_state) { LGBM_SE call_state) {
R_API_BEGIN(); R_API_BEGIN();
DatasetHandle handle; DatasetHandle handle = nullptr;
CHECK_CALL(LGBM_DatasetCreateFromFile(R_CHAR_PTR(filename), R_CHAR_PTR(parameters), CHECK_CALL(LGBM_DatasetCreateFromFile(R_CHAR_PTR(filename), R_CHAR_PTR(parameters),
R_GET_PTR(reference), &handle)); R_GET_PTR(reference), &handle));
R_SET_PTR(out, handle); R_SET_PTR(out, handle);
...@@ -81,7 +81,7 @@ LGBM_SE LGBM_DatasetCreateFromCSC_R(LGBM_SE indptr, ...@@ -81,7 +81,7 @@ LGBM_SE LGBM_DatasetCreateFromCSC_R(LGBM_SE indptr,
int64_t nindptr = static_cast<int64_t>(R_AS_INT(num_indptr)); int64_t nindptr = static_cast<int64_t>(R_AS_INT(num_indptr));
int64_t ndata = static_cast<int64_t>(R_AS_INT(nelem)); int64_t ndata = static_cast<int64_t>(R_AS_INT(nelem));
int64_t nrow = static_cast<int64_t>(R_AS_INT(num_row)); int64_t nrow = static_cast<int64_t>(R_AS_INT(num_row));
DatasetHandle handle; DatasetHandle handle = nullptr;
CHECK_CALL(LGBM_DatasetCreateFromCSC(p_indptr, C_API_DTYPE_INT32, p_indices, CHECK_CALL(LGBM_DatasetCreateFromCSC(p_indptr, C_API_DTYPE_INT32, p_indices,
p_data, C_API_DTYPE_FLOAT64, nindptr, ndata, p_data, C_API_DTYPE_FLOAT64, nindptr, ndata,
nrow, R_CHAR_PTR(parameters), R_GET_PTR(reference), &handle)); nrow, R_CHAR_PTR(parameters), R_GET_PTR(reference), &handle));
...@@ -101,7 +101,7 @@ LGBM_SE LGBM_DatasetCreateFromMat_R(LGBM_SE data, ...@@ -101,7 +101,7 @@ LGBM_SE LGBM_DatasetCreateFromMat_R(LGBM_SE data,
int32_t nrow = static_cast<int32_t>(R_AS_INT(num_row)); int32_t nrow = static_cast<int32_t>(R_AS_INT(num_row));
int32_t ncol = static_cast<int32_t>(R_AS_INT(num_col)); int32_t ncol = static_cast<int32_t>(R_AS_INT(num_col));
double* p_mat = R_REAL_PTR(data); double* p_mat = R_REAL_PTR(data);
DatasetHandle handle; DatasetHandle handle = nullptr;
CHECK_CALL(LGBM_DatasetCreateFromMat(p_mat, C_API_DTYPE_FLOAT64, nrow, ncol, COL_MAJOR, CHECK_CALL(LGBM_DatasetCreateFromMat(p_mat, C_API_DTYPE_FLOAT64, nrow, ncol, COL_MAJOR,
R_CHAR_PTR(parameters), R_GET_PTR(reference), &handle)); R_CHAR_PTR(parameters), R_GET_PTR(reference), &handle));
R_SET_PTR(out, handle); R_SET_PTR(out, handle);
...@@ -123,7 +123,7 @@ LGBM_SE LGBM_DatasetGetSubset_R(LGBM_SE handle, ...@@ -123,7 +123,7 @@ LGBM_SE LGBM_DatasetGetSubset_R(LGBM_SE handle,
for (int i = 0; i < len; ++i) { for (int i = 0; i < len; ++i) {
idxvec[i] = R_INT_PTR(used_row_indices)[i] - 1; idxvec[i] = R_INT_PTR(used_row_indices)[i] - 1;
} }
DatasetHandle res; DatasetHandle res = nullptr;
CHECK_CALL(LGBM_DatasetGetSubset(R_GET_PTR(handle), CHECK_CALL(LGBM_DatasetGetSubset(R_GET_PTR(handle),
idxvec.data(), len, R_CHAR_PTR(parameters), idxvec.data(), len, R_CHAR_PTR(parameters),
&res)); &res));
...@@ -306,7 +306,7 @@ LGBM_SE LGBM_BoosterCreate_R(LGBM_SE train_data, ...@@ -306,7 +306,7 @@ LGBM_SE LGBM_BoosterCreate_R(LGBM_SE train_data,
LGBM_SE out, LGBM_SE out,
LGBM_SE call_state) { LGBM_SE call_state) {
R_API_BEGIN(); R_API_BEGIN();
BoosterHandle handle; BoosterHandle handle = nullptr;
CHECK_CALL(LGBM_BoosterCreate(R_GET_PTR(train_data), R_CHAR_PTR(parameters), &handle)); CHECK_CALL(LGBM_BoosterCreate(R_GET_PTR(train_data), R_CHAR_PTR(parameters), &handle));
R_SET_PTR(out, handle); R_SET_PTR(out, handle);
R_API_END(); R_API_END();
...@@ -318,7 +318,7 @@ LGBM_SE LGBM_BoosterCreateFromModelfile_R(LGBM_SE filename, ...@@ -318,7 +318,7 @@ LGBM_SE LGBM_BoosterCreateFromModelfile_R(LGBM_SE filename,
R_API_BEGIN(); R_API_BEGIN();
int out_num_iterations = 0; int out_num_iterations = 0;
BoosterHandle handle; BoosterHandle handle = nullptr;
CHECK_CALL(LGBM_BoosterCreateFromModelfile(R_CHAR_PTR(filename), &out_num_iterations, &handle)); CHECK_CALL(LGBM_BoosterCreateFromModelfile(R_CHAR_PTR(filename), &out_num_iterations, &handle));
R_SET_PTR(out, handle); R_SET_PTR(out, handle);
R_API_END(); R_API_END();
...@@ -330,7 +330,7 @@ LGBM_SE LGBM_BoosterLoadModelFromString_R(LGBM_SE model_str, ...@@ -330,7 +330,7 @@ LGBM_SE LGBM_BoosterLoadModelFromString_R(LGBM_SE model_str,
R_API_BEGIN(); R_API_BEGIN();
int out_num_iterations = 0; int out_num_iterations = 0;
BoosterHandle handle; BoosterHandle handle = nullptr;
CHECK_CALL(LGBM_BoosterLoadModelFromString(R_CHAR_PTR(model_str), &out_num_iterations, &handle)); CHECK_CALL(LGBM_BoosterLoadModelFromString(R_CHAR_PTR(model_str), &out_num_iterations, &handle));
R_SET_PTR(out, handle); R_SET_PTR(out, handle);
R_API_END(); R_API_END();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment