Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
d9064282
Unverified
Commit
d9064282
authored
Jan 25, 2020
by
James Lamb
Committed by
GitHub
Jan 25, 2020
Browse files
[R-package] moved parameter validations up earlier in function calls (#2663)
parent
08fd53cd
Changes
14
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
267 additions
and
79 deletions
+267
-79
R-package/R/callback.R
R-package/R/callback.R
+10
-8
R-package/R/lgb.Booster.R
R-package/R/lgb.Booster.R
+11
-5
R-package/R/lgb.Dataset.R
R-package/R/lgb.Dataset.R
+8
-14
R-package/R/lgb.cv.R
R-package/R/lgb.cv.R
+14
-13
R-package/R/lgb.importance.R
R-package/R/lgb.importance.R
+1
-1
R-package/R/lgb.train.R
R-package/R/lgb.train.R
+17
-28
R-package/R/lightgbm.R
R-package/R/lightgbm.R
+10
-4
R-package/man/lgb.cv.Rd
R-package/man/lgb.cv.Rd
+4
-2
R-package/man/lgb.train.Rd
R-package/man/lgb.train.Rd
+3
-1
R-package/man/lgb_shared_params.Rd
R-package/man/lgb_shared_params.Rd
+3
-1
R-package/man/lightgbm.Rd
R-package/man/lightgbm.Rd
+3
-1
R-package/tests/testthat/test_basic.R
R-package/tests/testthat/test_basic.R
+135
-1
R-package/tests/testthat/test_dataset.R
R-package/tests/testthat/test_dataset.R
+27
-0
R-package/tests/testthat/test_utils.R
R-package/tests/testthat/test_utils.R
+21
-0
No files found.
R-package/R/callback.R
View file @
d9064282
...
@@ -29,11 +29,10 @@ cb.reset.parameters <- function(new_params) {
...
@@ -29,11 +29,10 @@ cb.reset.parameters <- function(new_params) {
# Run some checks in the beginning
# Run some checks in the beginning
init
<-
function
(
env
)
{
init
<-
function
(
env
)
{
# Store boosting rounds
nrounds
<<-
env
$
end_iteration
-
env
$
begin_iteration
+
1L
# Check for model environment
# Check for model environment
if
(
is.null
(
env
$
model
))
{
stop
(
"Env should have a "
,
sQuote
(
"model"
))
}
if
(
is.null
(
env
$
model
))
{
stop
(
"Env should have a "
,
sQuote
(
"model"
))
}
# Some parameters are not allowed to be changed,
# Some parameters are not allowed to be changed,
# since changing them would simply wreck some chaos
# since changing them would simply wreck some chaos
...
@@ -50,6 +49,9 @@ cb.reset.parameters <- function(new_params) {
...
@@ -50,6 +49,9 @@ cb.reset.parameters <- function(new_params) {
)
)
}
}
# Store boosting rounds
nrounds
<<-
env
$
end_iteration
-
env
$
begin_iteration
+
1L
# Check parameter names
# Check parameter names
for
(
n
in
pnames
)
{
for
(
n
in
pnames
)
{
...
@@ -285,14 +287,14 @@ cb.early.stop <- function(stopping_rounds, verbose = TRUE) {
...
@@ -285,14 +287,14 @@ cb.early.stop <- function(stopping_rounds, verbose = TRUE) {
# Initialization function
# Initialization function
init
<-
function
(
env
)
{
init
<-
function
(
env
)
{
# Store evaluation length
eval_len
<<-
length
(
env
$
eval_list
)
# Early stopping cannot work without metrics
# Early stopping cannot work without metrics
if
(
eval_l
en
==
0L
)
{
if
(
length
(
env
$
eval_l
ist
)
==
0L
)
{
stop
(
"For early stopping, valids must have at least one element"
)
stop
(
"For early stopping, valids must have at least one element"
)
}
}
# Store evaluation length
eval_len
<<-
length
(
env
$
eval_list
)
# Check if verbose or not
# Check if verbose or not
if
(
isTRUE
(
verbose
))
{
if
(
isTRUE
(
verbose
))
{
cat
(
"Will train until there is no improvement in "
,
stopping_rounds
,
" rounds.\n\n"
,
sep
=
""
)
cat
(
"Will train until there is no improvement in "
,
stopping_rounds
,
" rounds.\n\n"
,
sep
=
""
)
...
...
R-package/R/lgb.Booster.R
View file @
d9064282
...
@@ -781,15 +781,21 @@ lgb.load <- function(filename = NULL, model_str = NULL) {
...
@@ -781,15 +781,21 @@ lgb.load <- function(filename = NULL, model_str = NULL) {
}
}
# Return new booster
# Return new booster
if
(
!
is.null
(
filename
)
&&
!
file.exists
(
filename
))
stop
(
"lgb.load: file does not exist for supplied filename"
)
if
(
!
is.null
(
filename
)
&&
!
file.exists
(
filename
))
{
if
(
!
is.null
(
filename
))
return
(
invisible
(
Booster
$
new
(
modelfile
=
filename
)))
stop
(
"lgb.load: file does not exist for supplied filename"
)
}
if
(
!
is.null
(
filename
))
{
return
(
invisible
(
Booster
$
new
(
modelfile
=
filename
)))
}
# Load from model_str
# Load from model_str
if
(
!
is.null
(
model_str
)
&&
!
is.character
(
model_str
))
{
if
(
!
is.null
(
model_str
)
&&
!
is.character
(
model_str
))
{
stop
(
"lgb.load: model_str should be character"
)
stop
(
"lgb.load: model_str should be character"
)
}
}
# Return new booster
# Return new booster
if
(
!
is.null
(
model_str
))
return
(
invisible
(
Booster
$
new
(
model_str
=
model_str
)))
if
(
!
is.null
(
model_str
))
{
return
(
invisible
(
Booster
$
new
(
model_str
=
model_str
)))
}
}
}
...
@@ -831,8 +837,8 @@ lgb.save <- function(booster, filename, num_iteration = NULL) {
...
@@ -831,8 +837,8 @@ lgb.save <- function(booster, filename, num_iteration = NULL) {
}
}
# Check if file name is character
# Check if file name is character
if
(
!
is.character
(
filename
))
{
if
(
!
(
is.character
(
filename
)
&&
length
(
filename
)
==
1L
)
)
{
stop
(
"lgb.save: filename should be a
character
"
)
stop
(
"lgb.save: filename should be a
string
"
)
}
}
# Store booster
# Store booster
...
...
R-package/R/lgb.Dataset.R
View file @
d9064282
...
@@ -32,6 +32,14 @@ Dataset <- R6::R6Class(
...
@@ -32,6 +32,14 @@ Dataset <- R6::R6Class(
info
=
list
(),
info
=
list
(),
...
)
{
...
)
{
# validate inputs early to avoid unnecessary computation
if
(
!
(
is.null
(
reference
)
||
lgb.check.r6.class
(
reference
,
"lgb.Dataset"
)))
{
stop
(
"lgb.Dataset: If provided, reference must be a "
,
sQuote
(
"lgb.Dataset"
))
}
if
(
!
(
is.null
(
predictor
)
||
lgb.check.r6.class
(
predictor
,
"lgb.Predictor"
)))
{
stop
(
"lgb.Dataset: If provided, predictor must be a "
,
sQuote
(
"lgb.Predictor"
))
}
# Check for additional parameters
# Check for additional parameters
additional_params
<-
list
(
...
)
additional_params
<-
list
(
...
)
...
@@ -56,20 +64,6 @@ Dataset <- R6::R6Class(
...
@@ -56,20 +64,6 @@ Dataset <- R6::R6Class(
}
}
# Check for dataset reference
if
(
!
is.null
(
reference
))
{
if
(
!
lgb.check.r6.class
(
reference
,
"lgb.Dataset"
))
{
stop
(
"lgb.Dataset: Can only use "
,
sQuote
(
"lgb.Dataset"
),
" as reference"
)
}
}
# Check for predictor reference
if
(
!
is.null
(
predictor
))
{
if
(
!
lgb.check.r6.class
(
predictor
,
"lgb.Predictor"
))
{
stop
(
"lgb.Dataset: Only can use "
,
sQuote
(
"lgb.Predictor"
),
" as predictor"
)
}
}
# Check for matrix format
# Check for matrix format
if
(
is.matrix
(
data
))
{
if
(
is.matrix
(
data
))
{
# Check whether matrix is the correct type first ("double")
# Check whether matrix is the correct type first ("double")
...
...
R-package/R/lgb.cv.R
View file @
d9064282
...
@@ -22,7 +22,7 @@ CVBooster <- R6::R6Class(
...
@@ -22,7 +22,7 @@ CVBooster <- R6::R6Class(
#' @description Cross validation logic used by LightGBM
#' @description Cross validation logic used by LightGBM
#' @inheritParams lgb_shared_params
#' @inheritParams lgb_shared_params
#' @param nfold the original dataset is randomly partitioned into \code{nfold} equal size subsamples.
#' @param nfold the original dataset is randomly partitioned into \code{nfold} equal size subsamples.
#' @param label
v
ector of
response values. Should be provided only when
data is
an R-matrix.
#' @param label
V
ector of
labels, used if \code{
data
}
is
not an \code{\link{lgb.Dataset}}
#' @param weight vector of response values. If not NULL, will set to dataset
#' @param weight vector of response values. If not NULL, will set to dataset
#' @param obj objective function, can be character or custom objective function. Examples include
#' @param obj objective function, can be character or custom objective function. Examples include
#' \code{regression}, \code{regression_l1}, \code{huber},
#' \code{regression}, \code{regression_l1}, \code{huber},
...
@@ -95,6 +95,19 @@ lgb.cv <- function(params = list()
...
@@ -95,6 +95,19 @@ lgb.cv <- function(params = list()
,
...
,
...
)
{
)
{
# validate parameters
if
(
nrounds
<=
0L
)
{
stop
(
"nrounds should be greater than zero"
)
}
# If 'data' is not an lgb.Dataset, try to construct one using 'label'
if
(
!
lgb.is.Dataset
(
data
))
{
if
(
is.null
(
label
))
{
stop
(
"'label' must be provided for lgb.cv if 'data' is not an 'lgb.Dataset'"
)
}
data
<-
lgb.Dataset
(
data
,
label
=
label
)
}
# Setup temporary variables
# Setup temporary variables
params
<-
append
(
params
,
list
(
...
))
params
<-
append
(
params
,
list
(
...
))
params
$
verbose
<-
verbose
params
$
verbose
<-
verbose
...
@@ -103,10 +116,6 @@ lgb.cv <- function(params = list()
...
@@ -103,10 +116,6 @@ lgb.cv <- function(params = list()
fobj
<-
NULL
fobj
<-
NULL
feval
<-
NULL
feval
<-
NULL
if
(
nrounds
<=
0L
)
{
stop
(
"nrounds should be greater than zero"
)
}
# Check for objective (function or not)
# Check for objective (function or not)
if
(
is.function
(
params
$
objective
))
{
if
(
is.function
(
params
$
objective
))
{
fobj
<-
params
$
objective
fobj
<-
params
$
objective
...
@@ -141,14 +150,6 @@ lgb.cv <- function(params = list()
...
@@ -141,14 +150,6 @@ lgb.cv <- function(params = list()
end_iteration
<-
begin_iteration
+
nrounds
-
1L
end_iteration
<-
begin_iteration
+
nrounds
-
1L
}
}
# Check for training dataset type correctness
if
(
!
lgb.is.Dataset
(
data
))
{
if
(
is.null
(
label
))
{
stop
(
"Labels must be provided for lgb.cv"
)
}
data
<-
lgb.Dataset
(
data
,
label
=
label
)
}
# Check for weights
# Check for weights
if
(
!
is.null
(
weight
))
{
if
(
!
is.null
(
weight
))
{
data
$
setinfo
(
"weight"
,
weight
)
data
$
setinfo
(
"weight"
,
weight
)
...
...
R-package/R/lgb.importance.R
View file @
d9064282
...
@@ -36,7 +36,7 @@
...
@@ -36,7 +36,7 @@
lgb.importance
<-
function
(
model
,
percentage
=
TRUE
)
{
lgb.importance
<-
function
(
model
,
percentage
=
TRUE
)
{
# Check if model is a lightgbm model
# Check if model is a lightgbm model
if
(
!
inherits
(
model
,
"
lgb.Booster
"
))
{
if
(
!
lgb
.is
.Booster
(
model
))
{
stop
(
"'model' has to be an object of class lgb.Booster"
)
stop
(
"'model' has to be an object of class lgb.Booster"
)
}
}
...
...
R-package/R/lgb.train.R
View file @
d9064282
...
@@ -65,6 +65,23 @@ lgb.train <- function(params = list(),
...
@@ -65,6 +65,23 @@ lgb.train <- function(params = list(),
reset_data
=
FALSE
,
reset_data
=
FALSE
,
...
)
{
...
)
{
# validate inputs early to avoid unnecessary computation
if
(
nrounds
<=
0L
)
{
stop
(
"nrounds should be greater than zero"
)
}
if
(
!
lgb.is.Dataset
(
data
))
{
stop
(
"lgb.train: data must be an lgb.Dataset instance"
)
}
if
(
length
(
valids
)
>
0L
)
{
if
(
!
is.list
(
valids
)
||
!
all
(
vapply
(
valids
,
lgb.is.Dataset
,
logical
(
1L
))))
{
stop
(
"lgb.train: valids must be a list of lgb.Dataset elements"
)
}
evnames
<-
names
(
valids
)
if
(
is.null
(
evnames
)
||
!
all
(
nzchar
(
evnames
)))
{
stop
(
"lgb.train: each element of valids must have a name"
)
}
}
# Setup temporary variables
# Setup temporary variables
additional_params
<-
list
(
...
)
additional_params
<-
list
(
...
)
params
<-
append
(
params
,
additional_params
)
params
<-
append
(
params
,
additional_params
)
...
@@ -74,10 +91,6 @@ lgb.train <- function(params = list(),
...
@@ -74,10 +91,6 @@ lgb.train <- function(params = list(),
fobj
<-
NULL
fobj
<-
NULL
feval
<-
NULL
feval
<-
NULL
if
(
nrounds
<=
0L
)
{
stop
(
"nrounds should be greater than zero"
)
}
# Check for objective (function or not)
# Check for objective (function or not)
if
(
is.function
(
params
$
objective
))
{
if
(
is.function
(
params
$
objective
))
{
fobj
<-
params
$
objective
fobj
<-
params
$
objective
...
@@ -112,30 +125,6 @@ lgb.train <- function(params = list(),
...
@@ -112,30 +125,6 @@ lgb.train <- function(params = list(),
end_iteration
<-
begin_iteration
+
nrounds
-
1L
end_iteration
<-
begin_iteration
+
nrounds
-
1L
}
}
# Check for training dataset type correctness
if
(
!
lgb.is.Dataset
(
data
))
{
stop
(
"lgb.train: data only accepts lgb.Dataset object"
)
}
# Check for validation dataset type correctness
if
(
length
(
valids
)
>
0L
)
{
# One or more validation dataset
# Check for list as input and type correctness by object
if
(
!
is.list
(
valids
)
||
!
all
(
vapply
(
valids
,
lgb.is.Dataset
,
logical
(
1L
))))
{
stop
(
"lgb.train: valids must be a list of lgb.Dataset elements"
)
}
# Attempt to get names
evnames
<-
names
(
valids
)
# Check for names existance
if
(
is.null
(
evnames
)
||
!
all
(
nzchar
(
evnames
)))
{
stop
(
"lgb.train: each element of the valids must have a name tag"
)
}
}
# Update parameters with parsed parameters
# Update parameters with parsed parameters
data
$
update_params
(
params
)
data
$
update_params
(
params
)
...
...
R-package/R/lightgbm.R
View file @
d9064282
#' @name lgb_shared_params
#' @name lgb_shared_params
#' @title Shared parameter docs
#' @title Shared parameter docs
#' @description Parameter docs shared by \code{lgb.train}, \code{lgb.cv}, and \code{lightgbm}
#' @description Parameter docs shared by \code{lgb.train}, \code{lgb.cv}, and \code{lightgbm}
#' @param callbacks List of callback functions that are applied at each iteration.
#' @param callbacks list of callback functions
#' @param data a \code{lgb.Dataset} object, used for training
#' List of callback functions that are applied at each iteration.
#' @param data a \code{lgb.Dataset} object, used for training. Some functions, such as \code{\link{lgb.cv}},
#' may allow you to pass other types of data like \code{matrix} and then separately supply
#' \code{label} as a keyword argument.
#' @param early_stopping_rounds int. Activates early stopping. Requires at least one validation data
#' @param early_stopping_rounds int. Activates early stopping. Requires at least one validation data
#' and one metric. If there's more than one, will check all of them
#' and one metric. If there's more than one, will check all of them
#' except the training data. Returns the model with (best_iter + early_stopping_rounds).
#' except the training data. Returns the model with (best_iter + early_stopping_rounds).
...
@@ -57,11 +60,14 @@ lightgbm <- function(data,
...
@@ -57,11 +60,14 @@ lightgbm <- function(data,
callbacks
=
list
(),
callbacks
=
list
(),
...
)
{
...
)
{
# Set data to a temporary variable
# validate inputs early to avoid unnecessary computation
dtrain
<-
data
if
(
nrounds
<=
0L
)
{
if
(
nrounds
<=
0L
)
{
stop
(
"nrounds should be greater than zero"
)
stop
(
"nrounds should be greater than zero"
)
}
}
# Set data to a temporary variable
dtrain
<-
data
# Check whether data is lgb.Dataset, if not then create lgb.Dataset manually
# Check whether data is lgb.Dataset, if not then create lgb.Dataset manually
if
(
!
lgb.is.Dataset
(
dtrain
))
{
if
(
!
lgb.is.Dataset
(
dtrain
))
{
dtrain
<-
lgb.Dataset
(
data
,
label
=
label
,
weight
=
weight
)
dtrain
<-
lgb.Dataset
(
data
,
label
=
label
,
weight
=
weight
)
...
...
R-package/man/lgb.cv.Rd
View file @
d9064282
...
@@ -31,13 +31,15 @@ lgb.cv(
...
@@ -31,13 +31,15 @@ lgb.cv(
\
arguments
{
\
arguments
{
\
item
{
params
}{
List
of
parameters
}
\
item
{
params
}{
List
of
parameters
}
\
item
{
data
}{
a
\
code
{
lgb
.
Dataset
}
object
,
used
for
training
}
\
item
{
data
}{
a
\
code
{
lgb
.
Dataset
}
object
,
used
for
training
.
Some
functions
,
such
as
\
code
{\
link
{
lgb
.
cv
}},
may
allow
you
to
pass
other
types
of
data
like
\
code
{
matrix
}
and
then
separately
supply
\
code
{
label
}
as
a
keyword
argument
.}
\
item
{
nrounds
}{
number
of
training
rounds
}
\
item
{
nrounds
}{
number
of
training
rounds
}
\
item
{
nfold
}{
the
original
dataset
is
randomly
partitioned
into
\
code
{
nfold
}
equal
size
subsamples
.}
\
item
{
nfold
}{
the
original
dataset
is
randomly
partitioned
into
\
code
{
nfold
}
equal
size
subsamples
.}
\
item
{
label
}{
v
ector
of
response
values
.
Should
be
provided
only
when
data
is
an
R
-
matrix
.
}
\
item
{
label
}{
V
ector
of
labels
,
used
if
\
code
{
data
}
is
not
an
\
code
{\
link
{
lgb
.
Dataset
}}
}
\
item
{
weight
}{
vector
of
response
values
.
If
not
NULL
,
will
set
to
dataset
}
\
item
{
weight
}{
vector
of
response
values
.
If
not
NULL
,
will
set
to
dataset
}
...
...
R-package/man/lgb.train.Rd
View file @
d9064282
...
@@ -26,7 +26,9 @@ lgb.train(
...
@@ -26,7 +26,9 @@ lgb.train(
\
arguments
{
\
arguments
{
\
item
{
params
}{
List
of
parameters
}
\
item
{
params
}{
List
of
parameters
}
\
item
{
data
}{
a
\
code
{
lgb
.
Dataset
}
object
,
used
for
training
}
\
item
{
data
}{
a
\
code
{
lgb
.
Dataset
}
object
,
used
for
training
.
Some
functions
,
such
as
\
code
{\
link
{
lgb
.
cv
}},
may
allow
you
to
pass
other
types
of
data
like
\
code
{
matrix
}
and
then
separately
supply
\
code
{
label
}
as
a
keyword
argument
.}
\
item
{
nrounds
}{
number
of
training
rounds
}
\
item
{
nrounds
}{
number
of
training
rounds
}
...
...
R-package/man/lgb_shared_params.Rd
View file @
d9064282
...
@@ -6,7 +6,9 @@
...
@@ -6,7 +6,9 @@
\arguments{
\arguments{
\item{callbacks}{List of callback functions that are applied at each iteration.}
\item{callbacks}{List of callback functions that are applied at each iteration.}
\item{data}{a \code{lgb.Dataset} object, used for training}
\item{data}{a \code{lgb.Dataset} object, used for training. Some functions, such as \code{\link{lgb.cv}},
may allow you to pass other types of data like \code{matrix} and then separately supply
\code{label} as a keyword argument.}
\item{early_stopping_rounds}{int. Activates early stopping. Requires at least one validation data
\item{early_stopping_rounds}{int. Activates early stopping. Requires at least one validation data
and one metric. If there's more than one, will check all of them
and one metric. If there's more than one, will check all of them
...
...
R-package/man/lightgbm.Rd
View file @
d9064282
...
@@ -20,7 +20,9 @@ lightgbm(
...
@@ -20,7 +20,9 @@ lightgbm(
)
)
}
}
\arguments{
\arguments{
\item{data}{a \code{lgb.Dataset} object, used for training}
\item{data}{a \code{lgb.Dataset} object, used for training. Some functions, such as \code{\link{lgb.cv}},
may allow you to pass other types of data like \code{matrix} and then separately supply
\code{label} as a keyword argument.}
\item{label}{Vector of labels, used if \code{data} is not an \code{\link{lgb.Dataset}}}
\item{label}{Vector of labels, used if \code{data} is not an \code{\link{lgb.Dataset}}}
...
...
R-package/tests/testthat/test_basic.R
View file @
d9064282
context
(
"
basic functions
"
)
context
(
"
lightgbm()
"
)
data
(
agaricus.train
,
package
=
"lightgbm"
)
data
(
agaricus.train
,
package
=
"lightgbm"
)
data
(
agaricus.test
,
package
=
"lightgbm"
)
data
(
agaricus.test
,
package
=
"lightgbm"
)
...
@@ -70,6 +70,20 @@ test_that("use of multiple eval metrics works", {
...
@@ -70,6 +70,20 @@ test_that("use of multiple eval metrics works", {
expect_false
(
is.null
(
bst
$
record_evals
))
expect_false
(
is.null
(
bst
$
record_evals
))
})
})
test_that
(
"lightgbm() rejects negative or 0 value passed to nrounds"
,
{
dtrain
<-
lgb.Dataset
(
train
$
data
,
label
=
train
$
label
)
params
<-
list
(
objective
=
"regression"
,
metric
=
"l2,l1"
)
for
(
nround_value
in
c
(
-10L
,
0L
))
{
expect_error
({
bst
<-
lightgbm
(
data
=
dtrain
,
params
=
params
,
nrounds
=
nround_value
)
},
"nrounds should be greater than zero"
)
}
})
test_that
(
"training continuation works"
,
{
test_that
(
"training continuation works"
,
{
testthat
::
skip
(
"This test is currently broken. See issue #2468 for details."
)
testthat
::
skip
(
"This test is currently broken. See issue #2468 for details."
)
...
@@ -103,6 +117,7 @@ test_that("training continuation works", {
...
@@ -103,6 +117,7 @@ test_that("training continuation works", {
expect_lt
(
abs
(
err_bst
-
err_bst2
),
0.01
)
expect_lt
(
abs
(
err_bst
-
err_bst2
),
0.01
)
})
})
context
(
"lgb.cv()"
)
test_that
(
"cv works"
,
{
test_that
(
"cv works"
,
{
dtrain
<-
lgb.Dataset
(
train
$
data
,
label
=
train
$
label
)
dtrain
<-
lgb.Dataset
(
train
$
data
,
label
=
train
$
label
)
...
@@ -118,3 +133,122 @@ test_that("cv works", {
...
@@ -118,3 +133,122 @@ test_that("cv works", {
)
)
expect_false
(
is.null
(
bst
$
record_evals
))
expect_false
(
is.null
(
bst
$
record_evals
))
})
})
test_that
(
"lgb.cv() rejects negative or 0 value passed to nrounds"
,
{
dtrain
<-
lgb.Dataset
(
train
$
data
,
label
=
train
$
label
)
params
<-
list
(
objective
=
"regression"
,
metric
=
"l2,l1"
)
for
(
nround_value
in
c
(
-10L
,
0L
))
{
expect_error
({
bst
<-
lgb.cv
(
params
,
dtrain
,
nround_value
,
nfold
=
5L
,
min_data
=
1L
)
},
"nrounds should be greater than zero"
)
}
})
test_that
(
"lgb.cv() throws an informative error is 'data' is not an lgb.Dataset and labels are not given"
,
{
bad_values
<-
list
(
4L
,
"hello"
,
list
(
a
=
TRUE
,
b
=
seq_len
(
10L
))
,
data.frame
(
x
=
seq_len
(
5L
),
y
=
seq_len
(
5L
))
,
data.table
::
data.table
(
x
=
seq_len
(
5L
),
y
=
seq_len
(
5L
))
,
matrix
(
data
=
seq_len
(
10L
),
2L
,
5L
)
)
for
(
val
in
bad_values
)
{
expect_error
({
bst
<-
lgb.cv
(
params
=
list
(
objective
=
"regression"
,
metric
=
"l2,l1"
)
,
data
=
val
,
10L
,
nfold
=
5L
,
min_data
=
1L
)
},
regexp
=
"'label' must be provided for lgb.cv if 'data' is not an 'lgb.Dataset'"
,
fixed
=
TRUE
)
}
})
context
(
"lgb.train()"
)
test_that
(
"lgb.train() rejects negative or 0 value passed to nrounds"
,
{
dtrain
<-
lgb.Dataset
(
train
$
data
,
label
=
train
$
label
)
params
<-
list
(
objective
=
"regression"
,
metric
=
"l2,l1"
)
for
(
nround_value
in
c
(
-10L
,
0L
))
{
expect_error
({
bst
<-
lgb.train
(
params
,
dtrain
,
nround_value
)
},
"nrounds should be greater than zero"
)
}
})
test_that
(
"lgb.train() throws an informative error if 'data' is not an lgb.Dataset"
,
{
bad_values
<-
list
(
4L
,
"hello"
,
list
(
a
=
TRUE
,
b
=
seq_len
(
10L
))
,
data.frame
(
x
=
seq_len
(
5L
),
y
=
seq_len
(
5L
))
,
data.table
::
data.table
(
x
=
seq_len
(
5L
),
y
=
seq_len
(
5L
))
,
matrix
(
data
=
seq_len
(
10L
),
2L
,
5L
)
)
for
(
val
in
bad_values
)
{
expect_error
({
bst
<-
lgb.train
(
params
=
list
(
objective
=
"regression"
,
metric
=
"l2,l1"
)
,
data
=
val
,
10L
)
},
regexp
=
"data must be an lgb.Dataset instance"
,
fixed
=
TRUE
)
}
})
test_that
(
"lgb.train() throws an informative error if 'valids' is not a list of lgb.Dataset objects"
,
{
valids
<-
list
(
"valid1"
=
data.frame
(
x
=
rnorm
(
5L
),
y
=
rnorm
(
5L
))
,
"valid2"
=
data.frame
(
x
=
rnorm
(
5L
),
y
=
rnorm
(
5L
))
)
expect_error
({
bst
<-
lgb.train
(
params
=
list
(
objective
=
"regression"
,
metric
=
"l2,l1"
)
,
data
=
lgb.Dataset
(
train
$
data
,
label
=
train
$
label
)
,
10L
,
valids
=
valids
)
},
regexp
=
"valids must be a list of lgb.Dataset elements"
)
})
test_that
(
"lgb.train() errors if 'valids' is a list of lgb.Dataset objects but some do not have names"
,
{
valids
<-
list
(
"valid1"
=
lgb.Dataset
(
matrix
(
rnorm
(
10L
),
5L
,
2L
))
,
lgb.Dataset
(
matrix
(
rnorm
(
10L
),
2L
,
5L
))
)
expect_error
({
bst
<-
lgb.train
(
params
=
list
(
objective
=
"regression"
,
metric
=
"l2,l1"
)
,
data
=
lgb.Dataset
(
train
$
data
,
label
=
train
$
label
)
,
10L
,
valids
=
valids
)
},
regexp
=
"each element of valids must have a name"
)
})
test_that
(
"lgb.train() throws an informative error if 'valids' contains lgb.Dataset objects but none have names"
,
{
valids
<-
list
(
lgb.Dataset
(
matrix
(
rnorm
(
10L
),
5L
,
2L
))
,
lgb.Dataset
(
matrix
(
rnorm
(
10L
),
2L
,
5L
))
)
expect_error
({
bst
<-
lgb.train
(
params
=
list
(
objective
=
"regression"
,
metric
=
"l2,l1"
)
,
data
=
lgb.Dataset
(
train
$
data
,
label
=
train
$
label
)
,
10L
,
valids
=
valids
)
},
regexp
=
"each element of valids must have a name"
)
})
R-package/tests/testthat/test_dataset.R
View file @
d9064282
...
@@ -99,3 +99,30 @@ test_that("lgb.Dataset$setinfo() should convert 'group' to integer", {
...
@@ -99,3 +99,30 @@ test_that("lgb.Dataset$setinfo() should convert 'group' to integer", {
ds
$
setinfo
(
"group"
,
group_as_numeric
)
ds
$
setinfo
(
"group"
,
group_as_numeric
)
expect_identical
(
ds
$
getinfo
(
"group"
),
as.integer
(
group_as_numeric
))
expect_identical
(
ds
$
getinfo
(
"group"
),
as.integer
(
group_as_numeric
))
})
})
test_that
(
"lgb.Dataset should throw an error if 'reference' is provided but of the wrong format"
,
{
data
(
agaricus.test
,
package
=
"lightgbm"
)
test_data
<-
agaricus.test
$
data
[
1L
:
100L
,
]
test_label
<-
agaricus.test
$
label
[
1L
:
100L
]
# Try to trick lgb.Dataset() into accepting bad input
expect_error
({
dtest
<-
lgb.Dataset
(
data
=
test_data
,
label
=
test_label
,
reference
=
data.frame
(
x
=
seq_len
(
10L
),
y
=
seq_len
(
10L
))
)
},
regexp
=
"reference must be a"
)
})
test_that
(
"Dataset$new() should throw an error if 'predictor' is provided but of the wrong format"
,
{
data
(
agaricus.test
,
package
=
"lightgbm"
)
test_data
<-
agaricus.test
$
data
[
1L
:
100L
,
]
test_label
<-
agaricus.test
$
label
[
1L
:
100L
]
expect_error
({
dtest
<-
Dataset
$
new
(
data
=
test_data
,
label
=
test_label
,
predictor
=
data.frame
(
x
=
seq_len
(
10L
),
y
=
seq_len
(
10L
))
)
},
regexp
=
"predictor must be a"
,
fixed
=
TRUE
)
})
R-package/tests/testthat/test_utils.R
0 → 100644
View file @
d9064282
context
(
"lgb.check.r6.class"
)
test_that
(
"lgb.check.r6.class() should return FALSE for NULL input"
,
{
expect_false
(
lgb.check.r6.class
(
NULL
,
"lgb.Dataset"
))
})
test_that
(
"lgb.check.r6.class() should return FALSE for non-R6 inputs"
,
{
x
<-
5L
class
(
x
)
<-
"lgb.Dataset"
expect_false
(
lgb.check.r6.class
(
x
,
"lgb.Dataset"
))
})
test_that
(
"lgb.check.r6.class() should correctly identify lgb.Dataset"
,
{
data
(
"agaricus.train"
,
package
=
"lightgbm"
)
train
<-
agaricus.train
ds
<-
lgb.Dataset
(
train
$
data
,
label
=
train
$
label
)
expect_true
(
lgb.check.r6.class
(
ds
,
"lgb.Dataset"
))
expect_false
(
lgb.check.r6.class
(
ds
,
"lgb.Predictor"
))
expect_false
(
lgb.check.r6.class
(
ds
,
"lgb.Booster"
))
})
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment