Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
c676a7ea
Unverified
Commit
c676a7ea
authored
Feb 14, 2023
by
david-cortes
Committed by
GitHub
Feb 13, 2023
Browse files
[R-package] Accept factor labels and use their levels (#5341)
parent
9713ff40
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
313 additions
and
10 deletions
+313
-10
R-package/DESCRIPTION
R-package/DESCRIPTION
+1
-1
R-package/R/aliases.R
R-package/R/aliases.R
+24
-0
R-package/R/lgb.Booster.R
R-package/R/lgb.Booster.R
+19
-2
R-package/R/lgb.DataProcessor.R
R-package/R/lgb.DataProcessor.R
+94
-0
R-package/R/lightgbm.R
R-package/R/lightgbm.R
+27
-1
R-package/man/lgb.configure_fast_predict.Rd
R-package/man/lgb.configure_fast_predict.Rd
+6
-1
R-package/man/lightgbm.Rd
R-package/man/lightgbm.Rd
+13
-4
R-package/man/predict.lgb.Booster.Rd
R-package/man/predict.lgb.Booster.Rd
+12
-1
R-package/tests/testthat/test_basic.R
R-package/tests/testthat/test_basic.R
+117
-0
No files found.
R-package/DESCRIPTION
View file @
c676a7ea
...
...
@@ -63,4 +63,4 @@ Imports:
utils
SystemRequirements:
C++11
RoxygenNote: 7.2.
1
RoxygenNote: 7.2.
3
R-package/R/aliases.R
View file @
c676a7ea
...
...
@@ -78,3 +78,27 @@
)
)
}
# Objective names (including aliases) that LightGBM treats as multi-class
# classification. Used to validate 'objective' when labels are factors
# with more than two levels.
.MULTICLASS_OBJECTIVES <- function() {
    multiclass_aliases <- c(
        "multi_logloss"
        , "multiclass"
        , "softmax"
        , "multiclassova"
        , "multiclass_ova"
        , "ova"
        , "ovr"
    )
    return(multiclass_aliases)
}
# Objective names (including aliases) that LightGBM treats as binary
# classification. Used to validate 'objective' when labels are two-level
# factors.
.BINARY_OBJECTIVES <- function() {
    binary_aliases <- c(
        "binary_logloss"
        , "binary"
        , "binary_error"
    )
    return(binary_aliases)
}
R-package/R/lgb.Booster.R
View file @
c676a7ea
...
...
@@ -9,6 +9,7 @@ Booster <- R6::R6Class(
best_score
=
NA_real_
,
params
=
list
(),
record_evals
=
list
(),
data_processor
=
NULL
,
# Finalize will free up the handles
finalize
=
function
()
{
...
...
@@ -837,6 +838,11 @@ Booster <- R6::R6Class(
#'
#' Note that, if using custom objectives, types "class" and "response" will not be available and will
#' default towards using "raw" instead.
#'
#' If the model was fit through function \link{lightgbm} and it was passed a factor as labels,
#' passing the prediction type through \code{params} instead of through this argument might
#' result in factor levels for classification objectives not being applied correctly to the
#' resulting output.
#' @param start_iteration int or None, optional (default=None)
#' Start index of the iteration to predict.
#' If None or <= 0, starts from the first iteration.
...
...
@@ -895,6 +901,11 @@ NULL
#' in the order "feature contributions for first class, feature contributions for second class, feature
#' contributions for third class, etc.".
#'
#' If the model was fit through function \link{lightgbm} and it was passed a factor as labels, predictions
#' returned from this function will retain the factor levels (either as values for \code{type="class"}, or
#' as column names for \code{type="response"} and \code{type="raw"} for multi-class objectives). Note that
#' passing the requested prediction type under \code{params} instead of through \code{type} might result in
#' the factor levels not being present in the output.
#' @examples
#' \donttest{
#' data(agaricus.train, package = "lightgbm")
...
...
@@ -996,12 +1007,18 @@ predict.lgb.Booster <- function(object,
,
params
=
params
)
if
(
type
==
"class"
)
{
if
(
object
$
params
$
objective
==
"binary"
)
{
if
(
object
$
params
$
objective
%in%
.BINARY_OBJECTIVES
()
)
{
pred
<-
as.integer
(
pred
>=
0.5
)
}
else
if
(
object
$
params
$
objective
%in%
c
(
"multiclass"
,
"multiclassova"
))
{
}
else
if
(
object
$
params
$
objective
%in%
.MULTICLASS_OBJECTIVES
(
))
{
pred
<-
max.col
(
pred
)
-
1L
}
}
if
(
!
is.null
(
object
$
data_processor
))
{
pred
<-
object
$
data_processor
$
process_predictions
(
pred
=
pred
,
type
=
type
)
}
return
(
pred
)
}
...
...
R-package/R/lgb.DataProcessor.R
0 → 100644
View file @
c676a7ea
# Internal helper class used by lightgbm() when 'data' is not an lgb.Dataset.
#
# On the way in (process_label), it converts character/factor labels into the
# zero-based numeric encoding LightGBM expects, records the factor levels, and
# resolves objective="auto" into "binary" / "multiclass" / "regression"
# (inferring 'num_class' from the number of levels for multi-class).
# On the way out (process_predictions), it re-applies the stored levels to
# predictions, so users get factors back for type="class" and named columns
# for multi-class type="response"/"raw".
DataProcessor <- R6::R6Class(
    classname = "lgb.DataProcessor",
    public = list(
        # Character vector with the label's factor levels, or NULL if the
        # model was not fit on a factor/character label.
        factor_levels = NULL,

        # Pre-process the 'label' argument of lightgbm().
        #
        # label:     vector of labels (numeric, character, or factor).
        # objective: objective string as passed by the user ("auto" allowed).
        # params:    list of training parameters (may carry 'num_class').
        #
        # Returns a list with elements 'label' (numeric encoding), 'objective'
        # (resolved), and 'params' (possibly with 'num_class' overridden).
        # As a side effect, stores the factor levels in self$factor_levels.
        process_label = function(label, objective, params) {
            # Character labels are treated the same as factors.
            if (is.character(label)) {
                label <- factor(label)
            }
            if (is.factor(label)) {
                self$factor_levels <- levels(label)
                if (length(self$factor_levels) <= 1L) {
                    stop("Labels to predict is a factor with <2 possible values.")
                }
                # Factor codes are 1-based in R; LightGBM expects classes
                # starting at zero.
                label <- as.numeric(label) - 1.0
                out <- list(label = label)
                if (length(self$factor_levels) == 2L) {
                    # Two levels -> binary classification.
                    if (objective == "auto") {
                        objective <- "binary"
                    }
                    if (!(objective %in% .BINARY_OBJECTIVES())) {
                        stop("Two-level factors as labels only allowed for objective='binary' or objective='auto'.")
                    }
                } else {
                    # More than two levels -> multi-class classification.
                    if (objective == "auto") {
                        objective <- "multiclass"
                    }
                    if (!(objective %in% .MULTICLASS_OBJECTIVES())) {
                        stop(
                            sprintf(
                                "Factors with >2 levels as labels only allowed for multi-class objectives. Got: %s (allowed: %s)"
                                , objective
                                , toString(.MULTICLASS_OBJECTIVES())
                            )
                        )
                    }
                    # The number of classes comes from the data; resolve it
                    # against anything passed through 'params' (project helper
                    # handles parameter aliases).
                    data_num_class <- length(self$factor_levels)
                    params <- lgb.check.wrapper_param(
                        main_param_name = "num_class"
                        , params = params
                        , alternative_kwarg_value = data_num_class
                    )
                    # A user-supplied 'num_class' that disagrees with the data
                    # is overridden, with a warning.
                    if (params[["num_class"]] != data_num_class) {
                        warning(
                            sprintf(
                                "Found num_class=%d in params, but 'label' is a factor with %d levels. 'num_class' will be ignored."
                                , params[["num_class"]]
                                , data_num_class
                            )
                        )
                        params$num_class <- data_num_class
                    }
                }
                out$objective <- objective
                out$params <- params
                return(out)
            } else {
                # Non-factor labels: plain numeric target, default to
                # regression when objective="auto".
                label <- as.numeric(label)
                if (objective == "auto") {
                    objective <- "regression"
                }
                out <- list(label = label, objective = objective, params = params)
                return(out)
            }
        },

        # Post-process predictions from predict.lgb.Booster so that factor
        # levels recorded at training time are reflected in the output.
        #
        # pred: predictions as produced by the booster.
        # type: prediction type requested by the user ("class", "response",
        #       "raw", ...).
        #
        # Returns 'pred', possibly converted to a factor (type="class") or
        # with column names set to the levels (multi-class matrix outputs).
        process_predictions = function(pred, type) {
            # Only act when the model was fit on a factor label
            # (NROW(NULL) is 0, which is falsy here).
            if (NROW(self$factor_levels)) {
                if (type == "class") {
                    # Back to 1-based codes, then build a factor in place by
                    # assigning the 'levels' and 'class' attributes.
                    pred <- as.integer(pred) + 1L
                    attributes(pred)$levels <- self$factor_levels
                    attributes(pred)$class <- "factor"
                } else if (type %in% c("response", "raw")) {
                    # Multi-class outputs are a matrix with one column per
                    # class; label the columns with the original levels.
                    if (is.matrix(pred) && ncol(pred) == length(self$factor_levels)) {
                        colnames(pred) <- self$factor_levels
                    }
                }
            }
            return(pred)
        }
    )
)
R-package/R/lightgbm.R
View file @
c676a7ea
...
...
@@ -103,6 +103,15 @@ NULL
#' For a list of accepted objectives, see
#' \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html#objective}{
#' the "objective" item of the "Parameters" section of the documentation}.
#'
#' If passing \code{"auto"} and \code{data} is not of type \code{lgb.Dataset}, the objective will
#' be determined according to what is passed for \code{label}:\itemize{
#' \item If passing a factor with two levels, will use objective \code{"binary"}.
#' \item If passing a factor with more than two levels, will use objective \code{"multiclass"}
#' (note that parameter \code{num_class} in this case will also be determined automatically from
#' \code{label}).
#' \item Otherwise, will use objective \code{"regression"}.
#' }
#' @param init_score initial score is the base prediction lightgbm will boost from
#' @param num_threads Number of parallel threads to use. For best speed, this should be set to the number of
#' physical cores in the CPU - in a typical x86-64 machine, this corresponds to half the
...
...
@@ -149,7 +158,7 @@ lightgbm <- function(data,
init_model
=
NULL
,
callbacks
=
list
(),
serializable
=
TRUE
,
objective
=
"
regression
"
,
objective
=
"
auto
"
,
init_score
=
NULL
,
num_threads
=
NULL
,
...
)
{
...
...
@@ -173,6 +182,22 @@ lightgbm <- function(data,
,
alternative_kwarg_value
=
verbose
)
# Process factors as labels and auto-determine objective
if
(
!
lgb.is.Dataset
(
data
))
{
data_processor
<-
DataProcessor
$
new
()
temp
<-
data_processor
$
process_label
(
label
=
label
,
objective
=
objective
,
params
=
params
)
label
<-
temp
$
label
objective
<-
temp
$
objective
params
<-
temp
$
params
rm
(
temp
)
}
else
{
data_processor
<-
NULL
}
# Set data to a temporary variable
dtrain
<-
data
...
...
@@ -204,6 +229,7 @@ lightgbm <- function(data,
what
=
lgb.train
,
args
=
train_args
)
bst
$
data_processor
<-
data_processor
return
(
bst
)
}
...
...
R-package/man/lgb.configure_fast_predict.Rd
View file @
c676a7ea
...
...
@@ -51,7 +51,12 @@ If <= 0, all iterations from start_iteration are used (no limits).}
}
Note that, if using custom objectives, types "class" and "response" will not be available and will
default towards using "raw" instead.}
default towards using "raw" instead.
If the model was fit through function \link{lightgbm} and it was passed a factor as labels,
passing the prediction type through \code{params} instead of through this argument might
result in factor levels for classification objectives not being applied correctly to the
resulting output.}
\item{params}{a list of additional named parameters. See
\href{https://lightgbm.readthedocs.io/en/latest/Parameters.html#predict-parameters}{
...
...
R-package/man/lightgbm.Rd
View file @
c676a7ea
...
...
@@ -16,7 +16,7 @@ lightgbm(
init_model = NULL,
callbacks = list(),
serializable = TRUE,
objective = "
regression
",
objective = "
auto
",
init_score = NULL,
num_threads = NULL,
...
...
...
@@ -56,9 +56,18 @@ set to the iteration number of the best iteration.}
\code{save} or \code{saveRDS} (see section "Model serialization").}
\item{objective}{Optimization objective (e.g. `"regression"`, `"binary"`, etc.).
For a list of accepted objectives, see
\href{https://lightgbm.readthedocs.io/en/latest/Parameters.html#objective}{
the "objective" item of the "Parameters" section of the documentation}.}
For a list of accepted objectives, see
\href{https://lightgbm.readthedocs.io/en/latest/Parameters.html#objective}{
the "objective" item of the "Parameters" section of the documentation}.
If passing \code{"auto"} and \code{data} is not of type \code{lgb.Dataset}, the objective will
be determined according to what is passed for \code{label}:\itemize{
\item If passing a factor with two levels, will use objective \code{"binary"}.
\item If passing a factor with more than two levels, will use objective \code{"multiclass"}
(note that parameter \code{num_class} in this case will also be determined automatically from
\code{label}).
\item Otherwise, will use objective \code{"regression"}.
}}
\item{init_score}{initial score is the base prediction lightgbm will boost from}
...
...
R-package/man/predict.lgb.Booster.Rd
View file @
c676a7ea
...
...
@@ -49,7 +49,12 @@
}
Note that, if using custom objectives, types "class" and "response" will not be available and will
default towards using "raw" instead.}
default towards using "raw" instead.
If the model was fit through function \link{lightgbm} and it was passed a factor as labels,
passing the prediction type through \code{params} instead of through this argument might
result in factor levels for classification objectives not being applied correctly to the
resulting output.}
\item{start_iteration}{int or None, optional (default=None)
Start index of the iteration to predict.
...
...
@@ -92,6 +97,12 @@ For prediction types that are meant to always return one output per observation
Shapley base value. For multiclass objectives, this matrix will represent \code{num_classes} such matrices,
in the order "feature contributions for first class, feature contributions for second class, feature
contributions for third class, etc.".
If the model was fit through function \link{lightgbm} and it was passed a factor as labels, predictions
returned from this function will retain the factor levels (either as values for \code{type="class"}, or
as column names for \code{type="response"} and \code{type="raw"} for multi-class objectives). Note that
passing the requested prediction type under \code{params} instead of through \code{type} might result in
the factor levels not being present in the output.
}
\description{
Predicted values based on class \code{lgb.Booster}
...
...
R-package/tests/testthat/test_basic.R
View file @
c676a7ea
...
...
@@ -3529,3 +3529,120 @@ test_that("lgb.cv() only prints eval metrics when expected to", {
fitted_model
=
out
[[
"booster"
]]
)
})
test_that("lightgbm() changes objective='auto' appropriately", {
    # Helper: dump the fitted model to its text representation, one line per
    # element, so the objective stored in the model file can be inspected.
    dump_model_lines <- function(bst) {
        strsplit(x = bst$save_model_to_string(), split = "\n", fixed = TRUE)[[1L]]
    }

    # Numeric labels -> regression
    data("mtcars")
    label_reg <- mtcars$mpg
    feats_reg <- as.matrix(mtcars[, -1L])
    bst <- lightgbm(feats_reg, label_reg, objective = "auto", verbose = VERBOSITY, nrounds = 5L)
    expect_equal(bst$params$objective, "regression")
    txt <- dump_model_lines(bst)
    expect_true(any(grepl("objective=regression", txt, fixed = TRUE)))
    expect_false(any(grepl("objective=regression_l1", txt, fixed = TRUE)))

    # Two-level factor labels -> binary classification
    feats_bin <- train$data
    label_bin <- factor(train$label)
    bst <- lightgbm(feats_bin, label_bin, objective = "auto", verbose = VERBOSITY, nrounds = 5L)
    expect_equal(bst$params$objective, "binary")
    txt <- dump_model_lines(bst)
    expect_true(any(grepl("objective=binary", txt, fixed = TRUE)))

    # Factor labels with >2 levels -> multiclass, with num_class inferred
    data("iris")
    label_multi <- factor(iris$Species)
    feats_multi <- as.matrix(iris[, -5L])
    bst <- lightgbm(feats_multi, label_multi, objective = "auto", verbose = VERBOSITY, nrounds = 5L)
    expect_equal(bst$params$objective, "multiclass")
    expect_equal(bst$params$num_class, 3L)
    txt <- dump_model_lines(bst)
    expect_true(any(grepl("objective=multiclass", txt, fixed = TRUE)))
})
test_that("lightgbm() determines number of classes for non-default multiclass objectives", {
    # Even when a non-default multi-class objective alias is passed
    # explicitly, num_class should be inferred from the factor levels.
    data("iris")
    species_labels <- factor(iris$Species)
    feature_mat <- as.matrix(iris[, -5L])
    bst <- lightgbm(feature_mat, species_labels, objective = "multiclassova", verbose = VERBOSITY, nrounds = 5L)
    expect_equal(bst$params$objective, "multiclassova")
    expect_equal(bst$params$num_class, 3L)
    model_lines <- strsplit(
        x = bst$save_model_to_string()
        , split = "\n"
        , fixed = TRUE
    )[[1L]]
    expect_true(any(grepl("objective=multiclassova", model_lines, fixed = TRUE)))
})
test_that("lightgbm() doesn't accept binary classification with non-binary factors", {
    # A three-level factor must be rejected when objective='binary'.
    data("iris")
    species_labels <- factor(iris$Species)
    feature_mat <- as.matrix(iris[, -5L])
    expect_error({
        lightgbm(feature_mat, species_labels, objective = "binary", verbose = VERBOSITY, nrounds = 5L)
    }, regexp = "Factors with >2 levels as labels only allowed for multi-class objectives")
})
test_that("lightgbm() doesn't accept multi-class classification with binary factors", {
    # Collapse iris species down to two levels, then request a multi-class
    # objective: this mismatch must be rejected.
    data("iris")
    raw_labels <- as.character(iris$Species)
    raw_labels[raw_labels == "setosa"] <- "versicolor"
    two_level_labels <- factor(raw_labels)
    feature_mat <- as.matrix(iris[, -5L])
    expect_error({
        lightgbm(feature_mat, two_level_labels, objective = "multiclass", verbose = VERBOSITY, nrounds = 5L)
    }, regexp = "Two-level factors as labels only allowed for objective='binary'")
})
test_that("lightgbm() model predictions retain factor levels for multiclass classification", {
    data("iris")
    species_labels <- factor(iris$Species)
    feature_mat <- as.matrix(iris[, -5L])
    bst <- lightgbm(feature_mat, species_labels, objective = "auto", verbose = VERBOSITY, nrounds = 5L)

    # type="class" should come back as a factor with the training levels
    class_pred <- predict(bst, feature_mat, type = "class")
    expect_true(is.factor(class_pred))
    expect_equal(levels(class_pred), levels(species_labels))

    # Matrix outputs should carry the levels as column names
    prob_pred <- predict(bst, feature_mat, type = "response")
    expect_equal(colnames(prob_pred), levels(species_labels))

    raw_pred <- predict(bst, feature_mat, type = "raw")
    expect_equal(colnames(raw_pred), levels(species_labels))
})
test_that("lightgbm() model predictions retain factor levels for binary classification", {
    # Build a two-level factor label from iris species
    data("iris")
    raw_labels <- as.character(iris$Species)
    raw_labels[raw_labels == "setosa"] <- "versicolor"
    binary_labels <- factor(raw_labels)
    feature_mat <- as.matrix(iris[, -5L])
    bst <- lightgbm(feature_mat, binary_labels, objective = "auto", verbose = VERBOSITY, nrounds = 5L)

    # type="class" should come back as a factor with the training levels
    class_pred <- predict(bst, feature_mat, type = "class")
    expect_true(is.factor(class_pred))
    expect_equal(levels(class_pred), levels(binary_labels))

    # Binary "response"/"raw" predictions stay numeric vectors: no factor
    # values should leak into them
    prob_pred <- predict(bst, feature_mat, type = "response")
    expect_true(is.vector(prob_pred))
    expect_true(is.numeric(prob_pred))
    expect_false(any(prob_pred %in% binary_labels))

    raw_pred <- predict(bst, feature_mat, type = "raw")
    expect_true(is.vector(raw_pred))
    expect_true(is.numeric(raw_pred))
    expect_false(any(raw_pred %in% binary_labels))
})
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment