Unverified Commit 170a9304 authored by david-cortes's avatar david-cortes Committed by GitHub
Browse files

[R-package] Fix error when passing categorical features to lightgbm() (fixes #6000) (#6003)

parent 665c4731
...@@ -154,6 +154,9 @@ lgb.train <- function(params = list(), ...@@ -154,6 +154,9 @@ lgb.train <- function(params = list(),
# Construct datasets, if needed # Construct datasets, if needed
data$update_params(params = params) data$update_params(params = params)
if (!is.null(categorical_feature)) {
data$set_categorical_feature(categorical_feature)
}
data$construct() data$construct()
# Check interaction constraints # Check interaction constraints
...@@ -179,11 +182,6 @@ lgb.train <- function(params = list(), ...@@ -179,11 +182,6 @@ lgb.train <- function(params = list(),
data$set_colnames(colnames) data$set_colnames(colnames)
} }
# Write categorical features
if (!is.null(categorical_feature)) {
data$set_categorical_feature(categorical_feature)
}
valid_contain_train <- FALSE valid_contain_train <- FALSE
train_data_name <- "train" train_data_name <- "train"
reduced_valid_sets <- list() reduced_valid_sets <- list()
......
...@@ -3773,3 +3773,18 @@ test_that("lightgbm() model predictions retain factor levels for binary classifi ...@@ -3773,3 +3773,18 @@ test_that("lightgbm() model predictions retain factor levels for binary classifi
expect_true(is.numeric(pred)) expect_true(is.numeric(pred))
expect_false(any(pred %in% y)) expect_false(any(pred %in% y))
}) })
test_that("lightgbm() accepts named categorical_features", {
data(mtcars)
y <- mtcars$mpg
x <- as.matrix(mtcars[, -1L])
model <- lightgbm(
x
, y
, categorical_feature = "cyl"
, verbose = .LGB_VERBOSITY
, nrounds = 5L
, num_threads = .LGB_MAX_THREADS
)
expect_true(length(model$params$categorical_feature) > 0L)
})
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment