"include/git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "b752170ba950890dd7e211fdecedef4d8deffb8b"
Unverified Commit 05d89a1a authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[R-package] added tests on LGBM_BoosterResetTrainingData_R (#3020)

parent ad7f2851
...@@ -311,3 +311,66 @@ test_that("Booster$rollback_one_iter() should work as expected", { ...@@ -311,3 +311,66 @@ test_that("Booster$rollback_one_iter() should work as expected", {
logloss <- bst$eval_train()[[1L]][["value"]] logloss <- bst$eval_train()[[1L]][["value"]]
expect_equal(logloss, 0.027915146) expect_equal(logloss, 0.027915146)
}) })
test_that("Booster$update() passing a train_set works as expected", {
set.seed(708L)
data(agaricus.train, package = "lightgbm")
nrounds <- 2L
# train with 2 rounds and then update
bst <- lightgbm(
data = as.matrix(agaricus.train$data)
, label = agaricus.train$label
, num_leaves = 4L
, learning_rate = 1.0
, nrounds = nrounds
, objective = "binary"
)
expect_true(lgb.is.Booster(bst))
expect_equal(bst$current_iter(), nrounds)
bst$update(
train_set = Dataset$new(
data = agaricus.train$data
, label = agaricus.train$label
)
)
expect_true(lgb.is.Booster(bst))
expect_equal(bst$current_iter(), nrounds + 1L)
# train with 3 rounds directlry
bst2 <- lightgbm(
data = as.matrix(agaricus.train$data)
, label = agaricus.train$label
, num_leaves = 4L
, learning_rate = 1.0
, nrounds = nrounds + 1L
, objective = "binary"
)
expect_true(lgb.is.Booster(bst2))
expect_equal(bst2$current_iter(), nrounds + 1L)
# model with 2 rounds + 1 update should be identical to 3 rounds
expect_equal(bst2$eval_train()[[1L]][["value"]], 0.04806585)
expect_equal(bst$eval_train()[[1L]][["value"]], bst2$eval_train()[[1L]][["value"]])
})
test_that("Booster$update() throws an informative error if you provide a non-Dataset to update()", {
set.seed(708L)
data(agaricus.train, package = "lightgbm")
nrounds <- 2L
# train with 2 rounds and then update
bst <- lightgbm(
data = as.matrix(agaricus.train$data)
, label = agaricus.train$label
, num_leaves = 4L
, learning_rate = 1.0
, nrounds = nrounds
, objective = "binary"
)
expect_error({
bst$update(
train_set = data.frame(x = rnorm(10L))
)
}, regexp = "lgb.Booster.update: Only can use lgb.Dataset", fixed = TRUE)
})
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment