"R-package/vscode:/vscode.git/clone" did not exist on "82886ba64451bc5c0827c04155d4b03952338bae"
Unverified Commit 22d6d1fd authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[R-package] Added tests on creating a Booster from a Dataset (#3007)

parent 4667d503
......@@ -227,3 +227,55 @@ test_that("If a string and a file are both passed to lgb.load() the file is used
pred2 <- predict(bst2, test$data)
expect_identical(pred, pred2)
})
context("Booster")
test_that("Creating a Booster from a Dataset should work", {
set.seed(708L)
data(agaricus.train, package = "lightgbm")
data(agaricus.test, package = "lightgbm")
dtrain <- lgb.Dataset(
agaricus.train$data
, label = agaricus.train$label
)
bst <- Booster$new(
params = list(
objective = "binary"
),
train_set = dtrain
)
expect_true(lgb.is.Booster(bst))
expect_equal(bst$current_iter(), 0L)
expect_true(is.na(bst$best_score))
expect_true(all(bst$predict(agaricus.train$data) == 0.5))
})
test_that("Creating a Booster from a Dataset with an existing predictor should work", {
set.seed(708L)
data(agaricus.train, package = "lightgbm")
nrounds <- 2L
bst <- lightgbm(
data = as.matrix(agaricus.train$data)
, label = agaricus.train$label
, num_leaves = 4L
, learning_rate = 1.0
, nrounds = nrounds
, objective = "binary"
)
data(agaricus.test, package = "lightgbm")
dtest <- Dataset$new(
data = agaricus.test$data
, label = agaricus.test$label
, predictor = bst$to_predictor()
)
bst_from_ds <- Booster$new(
train_set = dtest
)
expect_true(lgb.is.Booster(bst))
expect_equal(bst$current_iter(), nrounds)
expect_equal(bst$eval_train()[[1L]][["value"]], 0.1115352)
expect_equal(bst_from_ds$current_iter(), nrounds)
dumped_model <- jsonlite::fromJSON(bst$dump_model())
expect_identical(bst_from_ds$eval_train(), list())
expect_equal(bst_from_ds$current_iter(), nrounds)
})
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment