"...git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "4162485da0abb583e548af640c61a8fd6c5a2b06"
Unverified Commit a08c37f6 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[R-package] avoid unnecessary computation and add tests for Dataset set_reference() method (#4587)



* [R-package] avoid unnecessary computation in Dataset set_reference() method

* re-arrange conditions

* do more validation upfront and add tests

* Update R-package/tests/testthat/test_dataset.R
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>

* Update R-package/tests/testthat/test_dataset.R
Co-authored-by: default avatarNikita Titov <nekit94-12@hotmail.com>
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
parent 79463dfb
...@@ -663,34 +663,27 @@ Dataset <- R6::R6Class( ...@@ -663,34 +663,27 @@ Dataset <- R6::R6Class(
# Set reference # Set reference
set_reference = function(reference) { set_reference = function(reference) {
# Set known references # setting reference to this same Dataset object doesn't require any changes
self$set_categorical_feature(categorical_feature = reference$.__enclos_env__$private$categorical_feature)
self$set_colnames(colnames = reference$get_colnames())
private$set_predictor(predictor = reference$.__enclos_env__$private$predictor)
# Check for identical references
if (identical(private$reference, reference)) { if (identical(private$reference, reference)) {
return(invisible(self)) return(invisible(self))
} }
# Check for empty data # changing the reference removes the Dataset object on the C++ side, so it should only
# be done if you still have the raw_data available, so that the new Dataset can be reconstructed
if (is.null(private$raw_data)) { if (is.null(private$raw_data)) {
stop("set_reference: cannot set reference after freeing raw data, stop("set_reference: cannot set reference after freeing raw data,
please set ", sQuote("free_raw_data = FALSE"), " when you construct lgb.Dataset") please set ", sQuote("free_raw_data = FALSE"), " when you construct lgb.Dataset")
} }
# Check for non-existing reference if (!lgb.is.Dataset(reference)) {
if (!is.null(reference)) { stop("set_reference: Can only use lgb.Dataset as a reference")
# Reference is unknown
if (!lgb.is.Dataset(reference)) {
stop("set_reference: Can only use lgb.Dataset as a reference")
}
} }
# Set known references
self$set_categorical_feature(categorical_feature = reference$.__enclos_env__$private$categorical_feature)
self$set_colnames(colnames = reference$get_colnames())
private$set_predictor(predictor = reference$.__enclos_env__$private$predictor)
# Store reference # Store reference
private$reference <- reference private$reference <- reference
......
context("testing lgb.Dataset functionality") context("testing lgb.Dataset functionality")
data(agaricus.train, package = "lightgbm")
train_data <- agaricus.train$data[seq_len(1000L), ]
train_label <- agaricus.train$label[seq_len(1000L)]
data(agaricus.test, package = "lightgbm") data(agaricus.test, package = "lightgbm")
test_data <- agaricus.test$data[1L:100L, ] test_data <- agaricus.test$data[1L:100L, ]
test_label <- agaricus.test$label[1L:100L] test_label <- agaricus.test$label[1L:100L]
...@@ -74,6 +78,118 @@ test_that("Dataset$slice() supports passing Dataset attributes through '...'", { ...@@ -74,6 +78,118 @@ test_that("Dataset$slice() supports passing Dataset attributes through '...'", {
expect_identical(dsub1$getinfo("init_score"), init_score) expect_identical(dsub1$getinfo("init_score"), init_score)
}) })
test_that("Dataset$set_reference() on a constructed Dataset fails if raw data has been freed", {
dtrain <- lgb.Dataset(train_data, label = train_label)
dtrain$construct()
dtest <- lgb.Dataset(test_data, label = test_label)
dtest$construct()
expect_error({
dtest$set_reference(dtrain)
}, regexp = "cannot set reference after freeing raw data")
})
test_that("Dataset$set_reference() fails if reference is not a Dataset", {
dtrain <- lgb.Dataset(
train_data
, label = train_label
, free_raw_data = FALSE
)
expect_error({
dtrain$set_reference(reference = data.frame(x = rnorm(10L)))
}, regexp = "Can only use lgb.Dataset as a reference")
# passing NULL when the Dataset already has a reference raises an error
dtest <- lgb.Dataset(
test_data
, label = test_label
, free_raw_data = FALSE
)
dtrain$set_reference(dtest)
expect_error({
dtrain$set_reference(reference = NULL)
}, regexp = "Can only use lgb.Dataset as a reference")
})
test_that("Dataset$set_reference() setting reference to the same Dataset has no side effects", {
dtrain <- lgb.Dataset(
train_data
, label = train_label
, free_raw_data = FALSE
, categorical_feature = c(2L, 3L)
)
dtrain$construct()
cat_features_before <- dtrain$.__enclos_env__$private$categorical_feature
colnames_before <- dtrain$get_colnames()
predictor_before <- dtrain$.__enclos_env__$private$predictor
dtrain$set_reference(dtrain)
expect_identical(
cat_features_before
, dtrain$.__enclos_env__$private$categorical_feature
)
expect_identical(
colnames_before
, dtrain$get_colnames()
)
expect_identical(
predictor_before
, dtrain$.__enclos_env__$private$predictor
)
})
test_that("Dataset$set_reference() updates categorical_feature, colnames, and predictor", {
dtrain <- lgb.Dataset(
train_data
, label = train_label
, free_raw_data = FALSE
, categorical_feature = c(2L, 3L)
)
dtrain$construct()
bst <- Booster$new(
train_set = dtrain
, params = list(verbose = -1L)
)
dtrain$.__enclos_env__$private$predictor <- bst$to_predictor()
test_original_feature_names <- paste0("feature_col_", seq_len(ncol(test_data)))
dtest <- lgb.Dataset(
test_data
, label = test_label
, free_raw_data = FALSE
, colnames = test_original_feature_names
)
dtest$construct()
# at this point, dtest should not have categorical_feature
expect_null(dtest$.__enclos_env__$private$predictor)
expect_null(dtest$.__enclos_env__$private$categorical_feature)
expect_identical(
dtest$get_colnames()
, test_original_feature_names
)
dtest$set_reference(dtrain)
# after setting reference to dtrain, those attributes should have dtrain's values
expect_is(dtest$.__enclos_env__$private$predictor, "lgb.Predictor")
expect_identical(
dtest$.__enclos_env__$private$predictor$.__enclos_env__$private$handle
, dtrain$.__enclos_env__$private$predictor$.__enclos_env__$private$handle
)
expect_identical(
dtest$.__enclos_env__$private$categorical_feature
, dtrain$.__enclos_env__$private$categorical_feature
)
expect_identical(
dtest$get_colnames()
, dtrain$get_colnames()
)
expect_false(
identical(dtest$get_colnames(), test_original_feature_names)
)
})
test_that("lgb.Dataset: colnames", { test_that("lgb.Dataset: colnames", {
dtest <- lgb.Dataset(test_data, label = test_label) dtest <- lgb.Dataset(test_data, label = test_label)
expect_equal(colnames(dtest), colnames(test_data)) expect_equal(colnames(dtest), colnames(test_data))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment