Unverified Commit 9597326e authored by Tony Kenny's avatar Tony Kenny Committed by GitHub
Browse files

[R-package] construct dataset earlier in lgb.train and lgb.cv (fixes #3583) (#3598)



* construct dataset earlier in lgb.train and lgb.cv

* Update R-package/tests/testthat/test_dataset.R
Co-authored-by: default avatarJames Lamb <jaylamb20@gmail.com>

* Update R-package/R/lgb.cv.R
Co-authored-by: default avatarJames Lamb <jaylamb20@gmail.com>

* Update R-package/R/lgb.train.R
Co-authored-by: default avatarJames Lamb <jaylamb20@gmail.com>

* Update R-package/tests/testthat/test_dataset.R
Co-authored-by: default avatarJames Lamb <jaylamb20@gmail.com>

* fixing lint issues

* styling updates

* fix failing test
Co-authored-by: default avatarJames Lamb <jaylamb20@gmail.com>
parent c02917e4
...@@ -164,6 +164,10 @@ lgb.cv <- function(params = list() ...@@ -164,6 +164,10 @@ lgb.cv <- function(params = list()
} }
end_iteration <- begin_iteration + params[["num_iterations"]] - 1L end_iteration <- begin_iteration + params[["num_iterations"]] - 1L
# Construct datasets, if needed
data$update_params(params = params)
data$construct()
# Check interaction constraints # Check interaction constraints
cnames <- NULL cnames <- NULL
if (!is.null(colnames)) { if (!is.null(colnames)) {
...@@ -194,9 +198,6 @@ lgb.cv <- function(params = list() ...@@ -194,9 +198,6 @@ lgb.cv <- function(params = list()
data$set_categorical_feature(categorical_feature) data$set_categorical_feature(categorical_feature)
} }
# Construct datasets, if needed
data$construct()
# Check for folds # Check for folds
if (!is.null(folds)) { if (!is.null(folds)) {
......
...@@ -142,6 +142,10 @@ lgb.train <- function(params = list(), ...@@ -142,6 +142,10 @@ lgb.train <- function(params = list(),
} }
end_iteration <- begin_iteration + params[["num_iterations"]] - 1L end_iteration <- begin_iteration + params[["num_iterations"]] - 1L
# Construct datasets, if needed
data$update_params(params = params)
data$construct()
# Check interaction constraints # Check interaction constraints
cnames <- NULL cnames <- NULL
if (!is.null(colnames)) { if (!is.null(colnames)) {
...@@ -167,8 +171,6 @@ lgb.train <- function(params = list(), ...@@ -167,8 +171,6 @@ lgb.train <- function(params = list(),
data$set_categorical_feature(categorical_feature) data$set_categorical_feature(categorical_feature)
} }
# Construct datasets, if needed
data$construct()
valid_contain_train <- FALSE valid_contain_train <- FALSE
train_data_name <- "train" train_data_name <- "train"
reduced_valid_sets <- list() reduced_valid_sets <- list()
......
...@@ -205,3 +205,63 @@ test_that("Dataset$update_params() works correctly for recognized Dataset parame ...@@ -205,3 +205,63 @@ test_that("Dataset$update_params() works correctly for recognized Dataset parame
expect_identical(new_params[[param_name]], updated_params[[param_name]]) expect_identical(new_params[[param_name]], updated_params[[param_name]])
} }
}) })
test_that("lgb.Dataset: should be able to run lgb.train() immediately after using lgb.Dataset() on a file", {
dtest <- lgb.Dataset(
data = test_data
, label = test_label
)
tmp_file <- tempfile(pattern = "lgb.Dataset_")
lgb.Dataset.save(
dataset = dtest
, fname = tmp_file
)
# read from a local file
dtest_read_in <- lgb.Dataset(data = tmp_file)
param <- list(
objective = "binary"
, metric = "binary_logloss"
, num_leaves = 5L
, learning_rate = 1.0
)
# should be able to train right away
bst <- lgb.train(
params = param
, data = dtest_read_in
)
expect_true(lgb.is.Booster(x = bst))
})
test_that("lgb.Dataset: should be able to run lgb.cv() immediately after using lgb.Dataset() on a file", {
dtest <- lgb.Dataset(
data = test_data
, label = test_label
)
tmp_file <- tempfile(pattern = "lgb.Dataset_")
lgb.Dataset.save(
dataset = dtest
, fname = tmp_file
)
# read from a local file
dtest_read_in <- lgb.Dataset(data = tmp_file)
param <- list(
objective = "binary"
, metric = "binary_logloss"
, num_leaves = 5L
, learning_rate = 1.0
)
# should be able to train right away
bst <- lgb.cv(
params = param
, data = dtest_read_in
)
expect_is(bst, "lgb.CVBooster")
})
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment