Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
694e41e4
Unverified
Commit
694e41e4
authored
Nov 13, 2023
by
James Lamb
Committed by
GitHub
Nov 13, 2023
Browse files
[R-package] standardize naming of internal functions (#6179)
parent
deb70773
Changes
23
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
149 additions
and
140 deletions
+149
-140
R-package/R/callback.R
R-package/R/callback.R
+4
-4
R-package/R/lgb.Booster.R
R-package/R/lgb.Booster.R
+17
-17
R-package/R/lgb.DataProcessor.R
R-package/R/lgb.DataProcessor.R
+1
-1
R-package/R/lgb.Dataset.R
R-package/R/lgb.Dataset.R
+27
-27
R-package/R/lgb.Predictor.R
R-package/R/lgb.Predictor.R
+5
-5
R-package/R/lgb.cv.R
R-package/R/lgb.cv.R
+19
-19
R-package/R/lgb.drop_serialized.R
R-package/R/lgb.drop_serialized.R
+1
-1
R-package/R/lgb.importance.R
R-package/R/lgb.importance.R
+1
-1
R-package/R/lgb.interprete.R
R-package/R/lgb.interprete.R
+4
-4
R-package/R/lgb.make_serializable.R
R-package/R/lgb.make_serializable.R
+1
-1
R-package/R/lgb.model.dt.tree.R
R-package/R/lgb.model.dt.tree.R
+5
-2
R-package/R/lgb.plot.interpretation.R
R-package/R/lgb.plot.interpretation.R
+3
-3
R-package/R/lgb.restore_handle.R
R-package/R/lgb.restore_handle.R
+1
-1
R-package/R/lgb.train.R
R-package/R/lgb.train.R
+20
-14
R-package/R/lightgbm.R
R-package/R/lightgbm.R
+5
-5
R-package/R/saveRDS.lgb.Booster.R
R-package/R/saveRDS.lgb.Booster.R
+1
-1
R-package/R/utils.R
R-package/R/utils.R
+10
-10
R-package/tests/testthat/test_Predictor.R
R-package/tests/testthat/test_Predictor.R
+5
-5
R-package/tests/testthat/test_basic.R
R-package/tests/testthat/test_basic.R
+11
-11
R-package/tests/testthat/test_dataset.R
R-package/tests/testthat/test_dataset.R
+8
-8
No files found.
R-package/R/callback.R
View file @
694e41e4
...
@@ -323,17 +323,17 @@ cb_early_stop <- function(stopping_rounds, first_metric_only, verbose) {
...
@@ -323,17 +323,17 @@ cb_early_stop <- function(stopping_rounds, first_metric_only, verbose) {
}
}
# Extract callback names from the list of callbacks
# Extract callback names from the list of callbacks
callback
.
names
<-
function
(
cb_list
)
{
.
callback
_
names
<-
function
(
cb_list
)
{
return
(
unlist
(
lapply
(
cb_list
,
attr
,
"name"
)))
return
(
unlist
(
lapply
(
cb_list
,
attr
,
"name"
)))
}
}
add
.
cb
<-
function
(
cb_list
,
cb
)
{
.
add
_
cb
<-
function
(
cb_list
,
cb
)
{
# Combine two elements
# Combine two elements
cb_list
<-
c
(
cb_list
,
cb
)
cb_list
<-
c
(
cb_list
,
cb
)
# Set names of elements
# Set names of elements
names
(
cb_list
)
<-
callback
.
names
(
cb_list
=
cb_list
)
names
(
cb_list
)
<-
.
callback
_
names
(
cb_list
=
cb_list
)
if
(
"cb_early_stop"
%in%
names
(
cb_list
))
{
if
(
"cb_early_stop"
%in%
names
(
cb_list
))
{
...
@@ -349,7 +349,7 @@ add.cb <- function(cb_list, cb) {
...
@@ -349,7 +349,7 @@ add.cb <- function(cb_list, cb) {
}
}
categorize
.
callbacks
<-
function
(
cb_list
)
{
.
categorize
_
callbacks
<-
function
(
cb_list
)
{
# Check for pre-iteration or post-iteration
# Check for pre-iteration or post-iteration
return
(
return
(
...
...
R-package/R/lgb.Booster.R
View file @
694e41e4
...
@@ -31,12 +31,12 @@ Booster <- R6::R6Class(
...
@@ -31,12 +31,12 @@ Booster <- R6::R6Class(
if
(
!
is.null
(
train_set
))
{
if
(
!
is.null
(
train_set
))
{
if
(
!
lgb
.is
.
Dataset
(
train_set
))
{
if
(
!
.is
_
Dataset
(
train_set
))
{
stop
(
"lgb.Booster: Can only use lgb.Dataset as training data"
)
stop
(
"lgb.Booster: Can only use lgb.Dataset as training data"
)
}
}
train_set_handle
<-
train_set
$
.__enclos_env__
$
private
$
get_handle
()
train_set_handle
<-
train_set
$
.__enclos_env__
$
private
$
get_handle
()
params
<-
utils
::
modifyList
(
params
,
train_set
$
get_params
())
params
<-
utils
::
modifyList
(
params
,
train_set
$
get_params
())
params_str
<-
lgb
.params2str
(
params
=
params
)
params_str
<-
.params2str
(
params
=
params
)
# Store booster handle
# Store booster handle
handle
<-
.Call
(
handle
<-
.Call
(
LGBM_BoosterCreate_R
LGBM_BoosterCreate_R
...
@@ -130,7 +130,7 @@ Booster <- R6::R6Class(
...
@@ -130,7 +130,7 @@ Booster <- R6::R6Class(
# Add validation data
# Add validation data
add_valid
=
function
(
data
,
name
)
{
add_valid
=
function
(
data
,
name
)
{
if
(
!
lgb
.is
.
Dataset
(
data
))
{
if
(
!
.is
_
Dataset
(
data
))
{
stop
(
"lgb.Booster.add_valid: Can only use lgb.Dataset as validation data"
)
stop
(
"lgb.Booster.add_valid: Can only use lgb.Dataset as validation data"
)
}
}
...
@@ -167,7 +167,7 @@ Booster <- R6::R6Class(
...
@@ -167,7 +167,7 @@ Booster <- R6::R6Class(
params
<-
utils
::
modifyList
(
self
$
params
,
params
)
params
<-
utils
::
modifyList
(
self
$
params
,
params
)
}
}
params_str
<-
lgb
.params2str
(
params
=
params
)
params_str
<-
.params2str
(
params
=
params
)
self
$
restore_handle
()
self
$
restore_handle
()
...
@@ -193,7 +193,7 @@ Booster <- R6::R6Class(
...
@@ -193,7 +193,7 @@ Booster <- R6::R6Class(
if
(
!
is.null
(
train_set
))
{
if
(
!
is.null
(
train_set
))
{
if
(
!
lgb
.is
.
Dataset
(
train_set
))
{
if
(
!
.is
_
Dataset
(
train_set
))
{
stop
(
"lgb.Booster.update: Only can use lgb.Dataset as training data"
)
stop
(
"lgb.Booster.update: Only can use lgb.Dataset as training data"
)
}
}
...
@@ -340,7 +340,7 @@ Booster <- R6::R6Class(
...
@@ -340,7 +340,7 @@ Booster <- R6::R6Class(
# Evaluate data on metrics
# Evaluate data on metrics
eval
=
function
(
data
,
name
,
feval
=
NULL
)
{
eval
=
function
(
data
,
name
,
feval
=
NULL
)
{
if
(
!
lgb
.is
.
Dataset
(
data
))
{
if
(
!
.is
_
Dataset
(
data
))
{
stop
(
"lgb.Booster.eval: Can only use lgb.Dataset to eval"
)
stop
(
"lgb.Booster.eval: Can only use lgb.Dataset to eval"
)
}
}
...
@@ -508,17 +508,17 @@ Booster <- R6::R6Class(
...
@@ -508,17 +508,17 @@ Booster <- R6::R6Class(
# NOTE: doing this here instead of in Predictor$predict() to keep
# NOTE: doing this here instead of in Predictor$predict() to keep
# Predictor$predict() as fast as possible
# Predictor$predict() as fast as possible
if
(
length
(
params
)
>
0L
)
{
if
(
length
(
params
)
>
0L
)
{
params
<-
lgb
.check
.
wrapper_param
(
params
<-
.check
_
wrapper_param
(
main_param_name
=
"predict_raw_score"
main_param_name
=
"predict_raw_score"
,
params
=
params
,
params
=
params
,
alternative_kwarg_value
=
rawscore
,
alternative_kwarg_value
=
rawscore
)
)
params
<-
lgb
.check
.
wrapper_param
(
params
<-
.check
_
wrapper_param
(
main_param_name
=
"predict_leaf_index"
main_param_name
=
"predict_leaf_index"
,
params
=
params
,
params
=
params
,
alternative_kwarg_value
=
predleaf
,
alternative_kwarg_value
=
predleaf
)
)
params
<-
lgb
.check
.
wrapper_param
(
params
<-
.check
_
wrapper_param
(
main_param_name
=
"predict_contrib"
main_param_name
=
"predict_contrib"
,
params
=
params
,
params
=
params
,
alternative_kwarg_value
=
predcontrib
,
alternative_kwarg_value
=
predcontrib
...
@@ -586,7 +586,7 @@ Booster <- R6::R6Class(
...
@@ -586,7 +586,7 @@ Booster <- R6::R6Class(
,
predcontrib
,
predcontrib
,
start_iteration
,
start_iteration
,
num_iteration
,
num_iteration
,
lgb
.params2str
(
params
=
params
)
,
.params2str
(
params
=
params
)
)
)
private
$
fast_predict_config
<-
list
(
private
$
fast_predict_config
<-
list
(
...
@@ -622,7 +622,7 @@ Booster <- R6::R6Class(
...
@@ -622,7 +622,7 @@ Booster <- R6::R6Class(
},
},
check_null_handle
=
function
()
{
check_null_handle
=
function
()
{
return
(
lgb
.is
.
null
.
handle
(
private
$
handle
))
return
(
.is
_
null
_
handle
(
private
$
handle
))
},
},
restore_handle
=
function
()
{
restore_handle
=
function
()
{
...
@@ -959,7 +959,7 @@ predict.lgb.Booster <- function(object,
...
@@ -959,7 +959,7 @@ predict.lgb.Booster <- function(object,
params
=
list
(),
params
=
list
(),
...
)
{
...
)
{
if
(
!
lgb
.is
.
Booster
(
x
=
object
))
{
if
(
!
.is
_
Booster
(
x
=
object
))
{
stop
(
"predict.lgb.Booster: object should be an "
,
sQuote
(
"lgb.Booster"
))
stop
(
"predict.lgb.Booster: object should be an "
,
sQuote
(
"lgb.Booster"
))
}
}
...
@@ -1114,7 +1114,7 @@ lgb.configure_fast_predict <- function(model,
...
@@ -1114,7 +1114,7 @@ lgb.configure_fast_predict <- function(model,
num_iteration
=
NULL
,
num_iteration
=
NULL
,
type
=
"response"
,
type
=
"response"
,
params
=
list
())
{
params
=
list
())
{
if
(
!
lgb
.is
.
Booster
(
x
=
model
))
{
if
(
!
.is
_
Booster
(
x
=
model
))
{
stop
(
"lgb.configure_fast_predict: model should be an "
,
sQuote
(
"lgb.Booster"
))
stop
(
"lgb.configure_fast_predict: model should be an "
,
sQuote
(
"lgb.Booster"
))
}
}
if
(
type
==
"class"
)
{
if
(
type
==
"class"
)
{
...
@@ -1160,7 +1160,7 @@ lgb.configure_fast_predict <- function(model,
...
@@ -1160,7 +1160,7 @@ lgb.configure_fast_predict <- function(model,
print.lgb.Booster
<-
function
(
x
,
...
)
{
print.lgb.Booster
<-
function
(
x
,
...
)
{
# nolint start
# nolint start
handle
<-
x
$
.__enclos_env__
$
private
$
handle
handle
<-
x
$
.__enclos_env__
$
private
$
handle
handle_is_null
<-
lgb
.is
.
null
.
handle
(
handle
)
handle_is_null
<-
.is
_
null
_
handle
(
handle
)
if
(
!
handle_is_null
)
{
if
(
!
handle_is_null
)
{
ntrees
<-
x
$
current_iter
()
ntrees
<-
x
$
current_iter
()
...
@@ -1316,7 +1316,7 @@ lgb.load <- function(filename = NULL, model_str = NULL) {
...
@@ -1316,7 +1316,7 @@ lgb.load <- function(filename = NULL, model_str = NULL) {
#' @export
#' @export
lgb.save
<-
function
(
booster
,
filename
,
num_iteration
=
NULL
)
{
lgb.save
<-
function
(
booster
,
filename
,
num_iteration
=
NULL
)
{
if
(
!
lgb
.is
.
Booster
(
x
=
booster
))
{
if
(
!
.is
_
Booster
(
x
=
booster
))
{
stop
(
"lgb.save: booster should be an "
,
sQuote
(
"lgb.Booster"
))
stop
(
"lgb.save: booster should be an "
,
sQuote
(
"lgb.Booster"
))
}
}
...
@@ -1372,7 +1372,7 @@ lgb.save <- function(booster, filename, num_iteration = NULL) {
...
@@ -1372,7 +1372,7 @@ lgb.save <- function(booster, filename, num_iteration = NULL) {
#' @export
#' @export
lgb.dump
<-
function
(
booster
,
num_iteration
=
NULL
)
{
lgb.dump
<-
function
(
booster
,
num_iteration
=
NULL
)
{
if
(
!
lgb
.is
.
Booster
(
x
=
booster
))
{
if
(
!
.is
_
Booster
(
x
=
booster
))
{
stop
(
"lgb.dump: booster should be an "
,
sQuote
(
"lgb.Booster"
))
stop
(
"lgb.dump: booster should be an "
,
sQuote
(
"lgb.Booster"
))
}
}
...
@@ -1430,7 +1430,7 @@ lgb.dump <- function(booster, num_iteration = NULL) {
...
@@ -1430,7 +1430,7 @@ lgb.dump <- function(booster, num_iteration = NULL) {
#' @export
#' @export
lgb.get.eval.result
<-
function
(
booster
,
data_name
,
eval_name
,
iters
=
NULL
,
is_err
=
FALSE
)
{
lgb.get.eval.result
<-
function
(
booster
,
data_name
,
eval_name
,
iters
=
NULL
,
is_err
=
FALSE
)
{
if
(
!
lgb
.is
.
Booster
(
x
=
booster
))
{
if
(
!
.is
_
Booster
(
x
=
booster
))
{
stop
(
"lgb.get.eval.result: Can only use "
,
sQuote
(
"lgb.Booster"
),
" to get eval result"
)
stop
(
"lgb.get.eval.result: Can only use "
,
sQuote
(
"lgb.Booster"
),
" to get eval result"
)
}
}
...
...
R-package/R/lgb.DataProcessor.R
View file @
694e41e4
...
@@ -39,7 +39,7 @@ DataProcessor <- R6::R6Class(
...
@@ -39,7 +39,7 @@ DataProcessor <- R6::R6Class(
)
)
}
}
data_num_class
<-
length
(
self
$
factor_levels
)
data_num_class
<-
length
(
self
$
factor_levels
)
params
<-
lgb
.check
.
wrapper_param
(
params
<-
.check
_
wrapper_param
(
main_param_name
=
"num_class"
main_param_name
=
"num_class"
,
params
=
params
,
params
=
params
,
alternative_kwarg_value
=
data_num_class
,
alternative_kwarg_value
=
data_num_class
...
...
R-package/R/lgb.Dataset.R
View file @
694e41e4
...
@@ -55,10 +55,10 @@ Dataset <- R6::R6Class(
...
@@ -55,10 +55,10 @@ Dataset <- R6::R6Class(
init_score
=
NULL
)
{
init_score
=
NULL
)
{
# validate inputs early to avoid unnecessary computation
# validate inputs early to avoid unnecessary computation
if
(
!
(
is.null
(
reference
)
||
lgb
.is
.
Dataset
(
reference
)))
{
if
(
!
(
is.null
(
reference
)
||
.is
_
Dataset
(
reference
)))
{
stop
(
"lgb.Dataset: If provided, reference must be a "
,
sQuote
(
"lgb.Dataset"
))
stop
(
"lgb.Dataset: If provided, reference must be a "
,
sQuote
(
"lgb.Dataset"
))
}
}
if
(
!
(
is.null
(
predictor
)
||
lgb
.is
.
Predictor
(
predictor
)))
{
if
(
!
(
is.null
(
predictor
)
||
.is
_
Predictor
(
predictor
)))
{
stop
(
"lgb.Dataset: If provided, predictor must be a "
,
sQuote
(
"lgb.Predictor"
))
stop
(
"lgb.Dataset: If provided, predictor must be a "
,
sQuote
(
"lgb.Predictor"
))
}
}
...
@@ -135,7 +135,7 @@ Dataset <- R6::R6Class(
...
@@ -135,7 +135,7 @@ Dataset <- R6::R6Class(
construct
=
function
()
{
construct
=
function
()
{
# Check for handle null
# Check for handle null
if
(
!
lgb
.is
.
null
.
handle
(
x
=
private
$
handle
))
{
if
(
!
.is
_
null
_
handle
(
x
=
private
$
handle
))
{
return
(
invisible
(
self
))
return
(
invisible
(
self
))
}
}
...
@@ -191,7 +191,7 @@ Dataset <- R6::R6Class(
...
@@ -191,7 +191,7 @@ Dataset <- R6::R6Class(
}
}
# Generate parameter str
# Generate parameter str
params_str
<-
lgb
.params2str
(
params
=
private
$
params
)
params_str
<-
.params2str
(
params
=
private
$
params
)
# Get handle of reference dataset
# Get handle of reference dataset
ref_handle
<-
NULL
ref_handle
<-
NULL
...
@@ -277,7 +277,7 @@ Dataset <- R6::R6Class(
...
@@ -277,7 +277,7 @@ Dataset <- R6::R6Class(
)
)
}
}
if
(
lgb
.is
.
null
.
handle
(
x
=
handle
))
{
if
(
.is
_
null
_
handle
(
x
=
handle
))
{
stop
(
"lgb.Dataset.construct: cannot create Dataset handle"
)
stop
(
"lgb.Dataset.construct: cannot create Dataset handle"
)
}
}
# Setup class and private type
# Setup class and private type
...
@@ -345,7 +345,7 @@ Dataset <- R6::R6Class(
...
@@ -345,7 +345,7 @@ Dataset <- R6::R6Class(
dim
=
function
()
{
dim
=
function
()
{
# Check for handle
# Check for handle
if
(
!
lgb
.is
.
null
.
handle
(
x
=
private
$
handle
))
{
if
(
!
.is
_
null
_
handle
(
x
=
private
$
handle
))
{
num_row
<-
0L
num_row
<-
0L
num_col
<-
0L
num_col
<-
0L
...
@@ -385,7 +385,7 @@ Dataset <- R6::R6Class(
...
@@ -385,7 +385,7 @@ Dataset <- R6::R6Class(
# Get number of bins for feature
# Get number of bins for feature
get_feature_num_bin
=
function
(
feature
)
{
get_feature_num_bin
=
function
(
feature
)
{
if
(
lgb
.is
.
null
.
handle
(
x
=
private
$
handle
))
{
if
(
.is
_
null
_
handle
(
x
=
private
$
handle
))
{
stop
(
"Cannot get number of bins in feature before constructing Dataset."
)
stop
(
"Cannot get number of bins in feature before constructing Dataset."
)
}
}
if
(
is.character
(
feature
))
{
if
(
is.character
(
feature
))
{
...
@@ -409,7 +409,7 @@ Dataset <- R6::R6Class(
...
@@ -409,7 +409,7 @@ Dataset <- R6::R6Class(
get_colnames
=
function
()
{
get_colnames
=
function
()
{
# Check for handle
# Check for handle
if
(
!
lgb
.is
.
null
.
handle
(
x
=
private
$
handle
))
{
if
(
!
.is
_
null
_
handle
(
x
=
private
$
handle
))
{
private
$
colnames
<-
.Call
(
private
$
colnames
<-
.Call
(
LGBM_DatasetGetFeatureNames_R
LGBM_DatasetGetFeatureNames_R
,
private
$
handle
,
private
$
handle
...
@@ -449,7 +449,7 @@ Dataset <- R6::R6Class(
...
@@ -449,7 +449,7 @@ Dataset <- R6::R6Class(
# Write column names
# Write column names
private
$
colnames
<-
colnames
private
$
colnames
<-
colnames
if
(
!
lgb
.is
.
null
.
handle
(
x
=
private
$
handle
))
{
if
(
!
.is
_
null
_
handle
(
x
=
private
$
handle
))
{
# Merge names with tab separation
# Merge names with tab separation
merged_name
<-
paste0
(
as.list
(
private
$
colnames
),
collapse
=
"\t"
)
merged_name
<-
paste0
(
as.list
(
private
$
colnames
),
collapse
=
"\t"
)
...
@@ -478,7 +478,7 @@ Dataset <- R6::R6Class(
...
@@ -478,7 +478,7 @@ Dataset <- R6::R6Class(
# Check for info name and handle
# Check for info name and handle
if
(
is.null
(
private
$
info
[[
field_name
]]))
{
if
(
is.null
(
private
$
info
[[
field_name
]]))
{
if
(
lgb
.is
.
null
.
handle
(
x
=
private
$
handle
))
{
if
(
.is
_
null
_
handle
(
x
=
private
$
handle
))
{
stop
(
"Cannot perform Dataset$get_field() before constructing Dataset."
)
stop
(
"Cannot perform Dataset$get_field() before constructing Dataset."
)
}
}
...
@@ -536,7 +536,7 @@ Dataset <- R6::R6Class(
...
@@ -536,7 +536,7 @@ Dataset <- R6::R6Class(
# Store information privately
# Store information privately
private
$
info
[[
field_name
]]
<-
data
private
$
info
[[
field_name
]]
<-
data
if
(
!
lgb
.is
.
null
.
handle
(
x
=
private
$
handle
)
&&
!
is.null
(
data
))
{
if
(
!
.is
_
null
_
handle
(
x
=
private
$
handle
)
&&
!
is.null
(
data
))
{
if
(
length
(
data
)
>
0L
)
{
if
(
length
(
data
)
>
0L
)
{
...
@@ -583,14 +583,14 @@ Dataset <- R6::R6Class(
...
@@ -583,14 +583,14 @@ Dataset <- R6::R6Class(
return
(
invisible
(
self
))
return
(
invisible
(
self
))
}
}
new_params
<-
utils
::
modifyList
(
private
$
params
,
params
)
new_params
<-
utils
::
modifyList
(
private
$
params
,
params
)
if
(
lgb
.is
.
null
.
handle
(
x
=
private
$
handle
))
{
if
(
.is
_
null
_
handle
(
x
=
private
$
handle
))
{
private
$
params
<-
new_params
private
$
params
<-
new_params
}
else
{
}
else
{
tryCatch
({
tryCatch
({
.Call
(
.Call
(
LGBM_DatasetUpdateParamChecking_R
LGBM_DatasetUpdateParamChecking_R
,
lgb
.params2str
(
params
=
private
$
params
)
,
.params2str
(
params
=
private
$
params
)
,
lgb
.params2str
(
params
=
new_params
)
,
.params2str
(
params
=
new_params
)
)
)
private
$
params
<-
new_params
private
$
params
<-
new_params
},
error
=
function
(
e
)
{
},
error
=
function
(
e
)
{
...
@@ -663,7 +663,7 @@ Dataset <- R6::R6Class(
...
@@ -663,7 +663,7 @@ Dataset <- R6::R6Class(
please set "
,
sQuote
(
"free_raw_data = FALSE"
),
" when you construct lgb.Dataset"
)
please set "
,
sQuote
(
"free_raw_data = FALSE"
),
" when you construct lgb.Dataset"
)
}
}
if
(
!
lgb
.is
.
Dataset
(
reference
))
{
if
(
!
.is
_
Dataset
(
reference
))
{
stop
(
"set_reference: Can only use lgb.Dataset as a reference"
)
stop
(
"set_reference: Can only use lgb.Dataset as a reference"
)
}
}
...
@@ -711,7 +711,7 @@ Dataset <- R6::R6Class(
...
@@ -711,7 +711,7 @@ Dataset <- R6::R6Class(
get_handle
=
function
()
{
get_handle
=
function
()
{
# Get handle and construct if needed
# Get handle and construct if needed
if
(
lgb
.is
.
null
.
handle
(
x
=
private
$
handle
))
{
if
(
.is
_
null
_
handle
(
x
=
private
$
handle
))
{
self
$
construct
()
self
$
construct
()
}
}
return
(
private
$
handle
)
return
(
private
$
handle
)
...
@@ -734,7 +734,7 @@ Dataset <- R6::R6Class(
...
@@ -734,7 +734,7 @@ Dataset <- R6::R6Class(
if
(
!
is.null
(
predictor
))
{
if
(
!
is.null
(
predictor
))
{
# Predictor is unknown
# Predictor is unknown
if
(
!
lgb
.is
.
Predictor
(
predictor
))
{
if
(
!
.is
_
Predictor
(
predictor
))
{
stop
(
"set_predictor: Can only use lgb.Predictor as predictor"
)
stop
(
"set_predictor: Can only use lgb.Predictor as predictor"
)
}
}
...
@@ -888,7 +888,7 @@ lgb.Dataset.create.valid <- function(dataset,
...
@@ -888,7 +888,7 @@ lgb.Dataset.create.valid <- function(dataset,
init_score
=
NULL
,
init_score
=
NULL
,
params
=
list
())
{
params
=
list
())
{
if
(
!
lgb
.is
.
Dataset
(
x
=
dataset
))
{
if
(
!
.is
_
Dataset
(
x
=
dataset
))
{
stop
(
"lgb.Dataset.create.valid: input data should be an lgb.Dataset object"
)
stop
(
"lgb.Dataset.create.valid: input data should be an lgb.Dataset object"
)
}
}
...
@@ -922,7 +922,7 @@ lgb.Dataset.create.valid <- function(dataset,
...
@@ -922,7 +922,7 @@ lgb.Dataset.create.valid <- function(dataset,
#' @export
#' @export
lgb.Dataset.construct
<-
function
(
dataset
)
{
lgb.Dataset.construct
<-
function
(
dataset
)
{
if
(
!
lgb
.is
.
Dataset
(
x
=
dataset
))
{
if
(
!
.is
_
Dataset
(
x
=
dataset
))
{
stop
(
"lgb.Dataset.construct: input data should be an lgb.Dataset object"
)
stop
(
"lgb.Dataset.construct: input data should be an lgb.Dataset object"
)
}
}
...
@@ -954,7 +954,7 @@ lgb.Dataset.construct <- function(dataset) {
...
@@ -954,7 +954,7 @@ lgb.Dataset.construct <- function(dataset) {
#' @export
#' @export
dim.lgb.Dataset
<-
function
(
x
)
{
dim.lgb.Dataset
<-
function
(
x
)
{
if
(
!
lgb
.is
.
Dataset
(
x
=
x
))
{
if
(
!
.is
_
Dataset
(
x
=
x
))
{
stop
(
"dim.lgb.Dataset: input data should be an lgb.Dataset object"
)
stop
(
"dim.lgb.Dataset: input data should be an lgb.Dataset object"
)
}
}
...
@@ -989,7 +989,7 @@ dim.lgb.Dataset <- function(x) {
...
@@ -989,7 +989,7 @@ dim.lgb.Dataset <- function(x) {
#' @export
#' @export
dimnames.lgb.Dataset
<-
function
(
x
)
{
dimnames.lgb.Dataset
<-
function
(
x
)
{
if
(
!
lgb
.is
.
Dataset
(
x
=
x
))
{
if
(
!
.is
_
Dataset
(
x
=
x
))
{
stop
(
"dimnames.lgb.Dataset: input data should be an lgb.Dataset object"
)
stop
(
"dimnames.lgb.Dataset: input data should be an lgb.Dataset object"
)
}
}
...
@@ -1062,7 +1062,7 @@ slice <- function(dataset, idxset) {
...
@@ -1062,7 +1062,7 @@ slice <- function(dataset, idxset) {
#' @export
#' @export
slice.lgb.Dataset
<-
function
(
dataset
,
idxset
)
{
slice.lgb.Dataset
<-
function
(
dataset
,
idxset
)
{
if
(
!
lgb
.is
.
Dataset
(
x
=
dataset
))
{
if
(
!
.is
_
Dataset
(
x
=
dataset
))
{
stop
(
"slice.lgb.Dataset: input dataset should be an lgb.Dataset object"
)
stop
(
"slice.lgb.Dataset: input dataset should be an lgb.Dataset object"
)
}
}
...
@@ -1110,7 +1110,7 @@ get_field <- function(dataset, field_name) {
...
@@ -1110,7 +1110,7 @@ get_field <- function(dataset, field_name) {
get_field.lgb.Dataset
<-
function
(
dataset
,
field_name
)
{
get_field.lgb.Dataset
<-
function
(
dataset
,
field_name
)
{
# Check if dataset is not a dataset
# Check if dataset is not a dataset
if
(
!
lgb
.is
.
Dataset
(
x
=
dataset
))
{
if
(
!
.is
_
Dataset
(
x
=
dataset
))
{
stop
(
"get_field.lgb.Dataset(): input dataset should be an lgb.Dataset object"
)
stop
(
"get_field.lgb.Dataset(): input dataset should be an lgb.Dataset object"
)
}
}
...
@@ -1158,7 +1158,7 @@ set_field <- function(dataset, field_name, data) {
...
@@ -1158,7 +1158,7 @@ set_field <- function(dataset, field_name, data) {
#' @export
#' @export
set_field.lgb.Dataset
<-
function
(
dataset
,
field_name
,
data
)
{
set_field.lgb.Dataset
<-
function
(
dataset
,
field_name
,
data
)
{
if
(
!
lgb
.is
.
Dataset
(
x
=
dataset
))
{
if
(
!
.is
_
Dataset
(
x
=
dataset
))
{
stop
(
"set_field.lgb.Dataset: input dataset should be an lgb.Dataset object"
)
stop
(
"set_field.lgb.Dataset: input dataset should be an lgb.Dataset object"
)
}
}
...
@@ -1189,7 +1189,7 @@ set_field.lgb.Dataset <- function(dataset, field_name, data) {
...
@@ -1189,7 +1189,7 @@ set_field.lgb.Dataset <- function(dataset, field_name, data) {
#' @export
#' @export
lgb.Dataset.set.categorical
<-
function
(
dataset
,
categorical_feature
)
{
lgb.Dataset.set.categorical
<-
function
(
dataset
,
categorical_feature
)
{
if
(
!
lgb
.is
.
Dataset
(
x
=
dataset
))
{
if
(
!
.is
_
Dataset
(
x
=
dataset
))
{
stop
(
"lgb.Dataset.set.categorical: input dataset should be an lgb.Dataset object"
)
stop
(
"lgb.Dataset.set.categorical: input dataset should be an lgb.Dataset object"
)
}
}
...
@@ -1222,7 +1222,7 @@ lgb.Dataset.set.categorical <- function(dataset, categorical_feature) {
...
@@ -1222,7 +1222,7 @@ lgb.Dataset.set.categorical <- function(dataset, categorical_feature) {
#' @export
#' @export
lgb.Dataset.set.reference
<-
function
(
dataset
,
reference
)
{
lgb.Dataset.set.reference
<-
function
(
dataset
,
reference
)
{
if
(
!
lgb
.is
.
Dataset
(
x
=
dataset
))
{
if
(
!
.is
_
Dataset
(
x
=
dataset
))
{
stop
(
"lgb.Dataset.set.reference: input dataset should be an lgb.Dataset object"
)
stop
(
"lgb.Dataset.set.reference: input dataset should be an lgb.Dataset object"
)
}
}
...
@@ -1248,7 +1248,7 @@ lgb.Dataset.set.reference <- function(dataset, reference) {
...
@@ -1248,7 +1248,7 @@ lgb.Dataset.set.reference <- function(dataset, reference) {
#' @export
#' @export
lgb.Dataset.save
<-
function
(
dataset
,
fname
)
{
lgb.Dataset.save
<-
function
(
dataset
,
fname
)
{
if
(
!
lgb
.is
.
Dataset
(
x
=
dataset
))
{
if
(
!
.is
_
Dataset
(
x
=
dataset
))
{
stop
(
"lgb.Dataset.save: input dataset should be an lgb.Dataset object"
)
stop
(
"lgb.Dataset.save: input dataset should be an lgb.Dataset object"
)
}
}
...
...
R-package/R/lgb.Predictor.R
View file @
694e41e4
...
@@ -28,7 +28,7 @@ Predictor <- R6::R6Class(
...
@@ -28,7 +28,7 @@ Predictor <- R6::R6Class(
# Initialize will create a starter model
# Initialize will create a starter model
initialize
=
function
(
modelfile
,
params
=
list
(),
fast_predict_config
=
list
())
{
initialize
=
function
(
modelfile
,
params
=
list
(),
fast_predict_config
=
list
())
{
private
$
params
<-
lgb
.params2str
(
params
=
params
)
private
$
params
<-
.params2str
(
params
=
params
)
handle
<-
NULL
handle
<-
NULL
if
(
is.character
(
modelfile
))
{
if
(
is.character
(
modelfile
))
{
...
@@ -46,7 +46,7 @@ Predictor <- R6::R6Class(
...
@@ -46,7 +46,7 @@ Predictor <- R6::R6Class(
handle
<-
modelfile
handle
<-
modelfile
private
$
need_free_handle
<-
FALSE
private
$
need_free_handle
<-
FALSE
}
else
if
(
lgb
.is
.
Booster
(
modelfile
))
{
}
else
if
(
.is
_
Booster
(
modelfile
))
{
handle
<-
modelfile
$
get_handle
()
handle
<-
modelfile
$
get_handle
()
private
$
need_free_handle
<-
FALSE
private
$
need_free_handle
<-
FALSE
...
@@ -512,7 +512,7 @@ Predictor <- R6::R6Class(
...
@@ -512,7 +512,7 @@ Predictor <- R6::R6Class(
return
(
FALSE
)
return
(
FALSE
)
}
}
if
(
lgb
.is
.
null
.
handle
(
private
$
fast_predict_config
$
handle
))
{
if
(
.is
_
null
_
handle
(
private
$
fast_predict_config
$
handle
))
{
warning
(
paste0
(
"Model had fast CSR predict configuration, but it is inactive."
warning
(
paste0
(
"Model had fast CSR predict configuration, but it is inactive."
,
" Try re-generating it through 'lgb.configure_fast_predict'."
))
,
" Try re-generating it through 'lgb.configure_fast_predict'."
))
return
(
FALSE
)
return
(
FALSE
)
...
@@ -527,8 +527,8 @@ Predictor <- R6::R6Class(
...
@@ -527,8 +527,8 @@ Predictor <- R6::R6Class(
private
$
fast_predict_config
$
rawscore
==
rawscore
&&
private
$
fast_predict_config
$
rawscore
==
rawscore
&&
private
$
fast_predict_config
$
predleaf
==
predleaf
&&
private
$
fast_predict_config
$
predleaf
==
predleaf
&&
private
$
fast_predict_config
$
predcontrib
==
predcontrib
&&
private
$
fast_predict_config
$
predcontrib
==
predcontrib
&&
lgb
.equal
.
or
.
both
.
null
(
private
$
fast_predict_config
$
start_iteration
,
start_iteration
)
&&
.equal
_
or
_
both
_
null
(
private
$
fast_predict_config
$
start_iteration
,
start_iteration
)
&&
lgb
.equal
.
or
.
both
.
null
(
private
$
fast_predict_config
$
num_iteration
,
num_iteration
)
.equal
_
or
_
both
_
null
(
private
$
fast_predict_config
$
num_iteration
,
num_iteration
)
)
)
}
}
)
)
...
...
R-package/R/lgb.cv.R
View file @
694e41e4
...
@@ -99,7 +99,7 @@ lgb.cv <- function(params = list()
...
@@ -99,7 +99,7 @@ lgb.cv <- function(params = list()
}
}
# If 'data' is not an lgb.Dataset, try to construct one using 'label'
# If 'data' is not an lgb.Dataset, try to construct one using 'label'
if
(
!
lgb
.is
.
Dataset
(
x
=
data
))
{
if
(
!
.is
_
Dataset
(
x
=
data
))
{
if
(
is.null
(
label
))
{
if
(
is.null
(
label
))
{
stop
(
"'label' must be provided for lgb.cv if 'data' is not an 'lgb.Dataset'"
)
stop
(
"'label' must be provided for lgb.cv if 'data' is not an 'lgb.Dataset'"
)
}
}
...
@@ -110,27 +110,27 @@ lgb.cv <- function(params = list()
...
@@ -110,27 +110,27 @@ lgb.cv <- function(params = list()
# in `params`.
# in `params`.
# this ensures that the model stored with Booster$save() correctly represents
# this ensures that the model stored with Booster$save() correctly represents
# what was passed in
# what was passed in
params
<-
lgb
.check
.
wrapper_param
(
params
<-
.check
_
wrapper_param
(
main_param_name
=
"verbosity"
main_param_name
=
"verbosity"
,
params
=
params
,
params
=
params
,
alternative_kwarg_value
=
verbose
,
alternative_kwarg_value
=
verbose
)
)
params
<-
lgb
.check
.
wrapper_param
(
params
<-
.check
_
wrapper_param
(
main_param_name
=
"num_iterations"
main_param_name
=
"num_iterations"
,
params
=
params
,
params
=
params
,
alternative_kwarg_value
=
nrounds
,
alternative_kwarg_value
=
nrounds
)
)
params
<-
lgb
.check
.
wrapper_param
(
params
<-
.check
_
wrapper_param
(
main_param_name
=
"metric"
main_param_name
=
"metric"
,
params
=
params
,
params
=
params
,
alternative_kwarg_value
=
NULL
,
alternative_kwarg_value
=
NULL
)
)
params
<-
lgb
.check
.
wrapper_param
(
params
<-
.check
_
wrapper_param
(
main_param_name
=
"objective"
main_param_name
=
"objective"
,
params
=
params
,
params
=
params
,
alternative_kwarg_value
=
obj
,
alternative_kwarg_value
=
obj
)
)
params
<-
lgb
.check
.
wrapper_param
(
params
<-
.check
_
wrapper_param
(
main_param_name
=
"early_stopping_round"
main_param_name
=
"early_stopping_round"
,
params
=
params
,
params
=
params
,
alternative_kwarg_value
=
early_stopping_rounds
,
alternative_kwarg_value
=
early_stopping_rounds
...
@@ -148,7 +148,7 @@ lgb.cv <- function(params = list()
...
@@ -148,7 +148,7 @@ lgb.cv <- function(params = list()
# (for backwards compatibility). If it is a list of functions, store
# (for backwards compatibility). If it is a list of functions, store
# all of them. This makes it possible to pass any mix of strings like "auc"
# all of them. This makes it possible to pass any mix of strings like "auc"
# and custom functions to eval
# and custom functions to eval
params
<-
lgb
.check
.
eval
(
params
=
params
,
eval
=
eval
)
params
<-
.check
_
eval
(
params
=
params
,
eval
=
eval
)
eval_functions
<-
list
(
NULL
)
eval_functions
<-
list
(
NULL
)
if
(
is.function
(
eval
))
{
if
(
is.function
(
eval
))
{
eval_functions
<-
list
(
eval
)
eval_functions
<-
list
(
eval
)
...
@@ -166,7 +166,7 @@ lgb.cv <- function(params = list()
...
@@ -166,7 +166,7 @@ lgb.cv <- function(params = list()
# Check for boosting from a trained model
# Check for boosting from a trained model
if
(
is.character
(
init_model
))
{
if
(
is.character
(
init_model
))
{
predictor
<-
Predictor
$
new
(
modelfile
=
init_model
)
predictor
<-
Predictor
$
new
(
modelfile
=
init_model
)
}
else
if
(
lgb
.is
.
Booster
(
x
=
init_model
))
{
}
else
if
(
.is
_
Booster
(
x
=
init_model
))
{
predictor
<-
init_model
$
to_predictor
()
predictor
<-
init_model
$
to_predictor
()
}
}
...
@@ -193,7 +193,7 @@ lgb.cv <- function(params = list()
...
@@ -193,7 +193,7 @@ lgb.cv <- function(params = list()
}
else
if
(
!
is.null
(
data
$
get_colnames
()))
{
}
else
if
(
!
is.null
(
data
$
get_colnames
()))
{
cnames
<-
data
$
get_colnames
()
cnames
<-
data
$
get_colnames
()
}
}
params
[[
"interaction_constraints"
]]
<-
lgb
.check_interaction_constraints
(
params
[[
"interaction_constraints"
]]
<-
.check_interaction_constraints
(
interaction_constraints
=
interaction_constraints
interaction_constraints
=
interaction_constraints
,
column_names
=
cnames
,
column_names
=
cnames
)
)
...
@@ -232,7 +232,7 @@ lgb.cv <- function(params = list()
...
@@ -232,7 +232,7 @@ lgb.cv <- function(params = list()
}
}
# Create folds
# Create folds
folds
<-
generate
.
cv
.
folds
(
folds
<-
.
generate
_
cv
_
folds
(
nfold
=
nfold
nfold
=
nfold
,
nrows
=
nrow
(
data
)
,
nrows
=
nrow
(
data
)
,
stratified
=
stratified
,
stratified
=
stratified
...
@@ -245,12 +245,12 @@ lgb.cv <- function(params = list()
...
@@ -245,12 +245,12 @@ lgb.cv <- function(params = list()
# Add printing log callback
# Add printing log callback
if
(
params
[[
"verbosity"
]]
>
0L
&&
eval_freq
>
0L
)
{
if
(
params
[[
"verbosity"
]]
>
0L
&&
eval_freq
>
0L
)
{
callbacks
<-
add
.
cb
(
cb_list
=
callbacks
,
cb
=
cb_print_evaluation
(
period
=
eval_freq
))
callbacks
<-
.
add
_
cb
(
cb_list
=
callbacks
,
cb
=
cb_print_evaluation
(
period
=
eval_freq
))
}
}
# Add evaluation log callback
# Add evaluation log callback
if
(
record
)
{
if
(
record
)
{
callbacks
<-
add
.
cb
(
cb_list
=
callbacks
,
cb
=
cb_record_evaluation
())
callbacks
<-
.
add
_
cb
(
cb_list
=
callbacks
,
cb
=
cb_record_evaluation
())
}
}
# Did user pass parameters that indicate they want to use early stopping?
# Did user pass parameters that indicate they want to use early stopping?
...
@@ -282,7 +282,7 @@ lgb.cv <- function(params = list()
...
@@ -282,7 +282,7 @@ lgb.cv <- function(params = list()
# If user supplied early_stopping_rounds, add the early stopping callback
# If user supplied early_stopping_rounds, add the early stopping callback
if
(
using_early_stopping
)
{
if
(
using_early_stopping
)
{
callbacks
<-
add
.
cb
(
callbacks
<-
.
add
_
cb
(
cb_list
=
callbacks
cb_list
=
callbacks
,
cb
=
cb_early_stop
(
,
cb
=
cb_early_stop
(
stopping_rounds
=
early_stopping_rounds
stopping_rounds
=
early_stopping_rounds
...
@@ -292,7 +292,7 @@ lgb.cv <- function(params = list()
...
@@ -292,7 +292,7 @@ lgb.cv <- function(params = list()
)
)
}
}
cb
<-
categorize
.
callbacks
(
cb_list
=
callbacks
)
cb
<-
.
categorize
_
callbacks
(
cb_list
=
callbacks
)
# Construct booster for each fold. The data.table() code below is used to
# Construct booster for each fold. The data.table() code below is used to
# guarantee that indices are sorted while keeping init_score and weight together
# guarantee that indices are sorted while keeping init_score and weight together
...
@@ -387,7 +387,7 @@ lgb.cv <- function(params = list()
...
@@ -387,7 +387,7 @@ lgb.cv <- function(params = list()
})
})
# Prepare collection of evaluation results
# Prepare collection of evaluation results
merged_msg
<-
lgb
.merge
.
cv
.
result
(
merged_msg
<-
.merge
_
cv
_
result
(
msg
=
msg
msg
=
msg
,
showsd
=
showsd
,
showsd
=
showsd
)
)
...
@@ -463,7 +463,7 @@ lgb.cv <- function(params = list()
...
@@ -463,7 +463,7 @@ lgb.cv <- function(params = list()
}
}
# Generates random (stratified if needed) CV folds
# Generates random (stratified if needed) CV folds
generate
.
cv
.
folds
<-
function
(
nfold
,
nrows
,
stratified
,
label
,
group
,
params
)
{
.
generate
_
cv
_
folds
<-
function
(
nfold
,
nrows
,
stratified
,
label
,
group
,
params
)
{
# Check for group existence
# Check for group existence
if
(
is.null
(
group
))
{
if
(
is.null
(
group
))
{
...
@@ -476,7 +476,7 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, group, params) {
...
@@ -476,7 +476,7 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, group, params) {
y
<-
label
[
rnd_idx
]
y
<-
label
[
rnd_idx
]
y
<-
as.factor
(
y
)
y
<-
as.factor
(
y
)
folds
<-
lgb
.stratified
.
folds
(
y
=
y
,
k
=
nfold
)
folds
<-
.stratified
_
folds
(
y
=
y
,
k
=
nfold
)
}
else
{
}
else
{
...
@@ -528,7 +528,7 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, group, params) {
...
@@ -528,7 +528,7 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, group, params) {
# It was borrowed from caret::createFolds and simplified
# It was borrowed from caret::createFolds and simplified
# by always returning an unnamed list of fold indices.
# by always returning an unnamed list of fold indices.
#' @importFrom stats quantile
#' @importFrom stats quantile
lgb
.stratified
.
folds
<-
function
(
y
,
k
)
{
.stratified
_
folds
<-
function
(
y
,
k
)
{
# Group the numeric data based on their magnitudes
# Group the numeric data based on their magnitudes
# and sample within those groups.
# and sample within those groups.
...
@@ -594,7 +594,7 @@ lgb.stratified.folds <- function(y, k) {
...
@@ -594,7 +594,7 @@ lgb.stratified.folds <- function(y, k) {
return
(
out
)
return
(
out
)
}
}
lgb
.merge
.
cv
.
result
<-
function
(
msg
,
showsd
)
{
.merge
_
cv
_
result
<-
function
(
msg
,
showsd
)
{
if
(
length
(
msg
)
==
0L
)
{
if
(
length
(
msg
)
==
0L
)
{
stop
(
"lgb.cv: size of cv result error"
)
stop
(
"lgb.cv: size of cv result error"
)
...
...
R-package/R/lgb.drop_serialized.R
View file @
694e41e4
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
#' @seealso \link{lgb.restore_handle}, \link{lgb.make_serializable}.
#' @seealso \link{lgb.restore_handle}, \link{lgb.make_serializable}.
#' @export
#' @export
lgb.drop_serialized
<-
function
(
model
)
{
lgb.drop_serialized
<-
function
(
model
)
{
if
(
!
lgb
.is
.
Booster
(
x
=
model
))
{
if
(
!
.is
_
Booster
(
x
=
model
))
{
stop
(
"lgb.drop_serialized: model should be an "
,
sQuote
(
"lgb.Booster"
))
stop
(
"lgb.drop_serialized: model should be an "
,
sQuote
(
"lgb.Booster"
))
}
}
model
$
drop_raw
()
model
$
drop_raw
()
...
...
R-package/R/lgb.importance.R
View file @
694e41e4
...
@@ -39,7 +39,7 @@
...
@@ -39,7 +39,7 @@
#' @export
#' @export
lgb.importance
<-
function
(
model
,
percentage
=
TRUE
)
{
lgb.importance
<-
function
(
model
,
percentage
=
TRUE
)
{
if
(
!
lgb
.is
.
Booster
(
x
=
model
))
{
if
(
!
.is
_
Booster
(
x
=
model
))
{
stop
(
"'model' has to be an object of class lgb.Booster"
)
stop
(
"'model' has to be an object of class lgb.Booster"
)
}
}
...
...
R-package/R/lgb.interprete.R
View file @
694e41e4
...
@@ -86,7 +86,7 @@ lgb.interprete <- function(model,
...
@@ -86,7 +86,7 @@ lgb.interprete <- function(model,
)
)
for
(
i
in
seq_along
(
idxset
))
{
for
(
i
in
seq_along
(
idxset
))
{
tree_interpretation_dt_list
[[
i
]]
<-
single
.
row
.
interprete
(
tree_interpretation_dt_list
[[
i
]]
<-
.
single
_
row
_
interprete
(
tree_dt
=
tree_dt
tree_dt
=
tree_dt
,
num_class
=
num_class
,
num_class
=
num_class
,
tree_index_mat
=
tree_index_mat_list
[[
i
]]
,
tree_index_mat
=
tree_index_mat_list
[[
i
]]
...
@@ -151,7 +151,7 @@ single.tree.interprete <- function(tree_dt,
...
@@ -151,7 +151,7 @@ single.tree.interprete <- function(tree_dt,
}
}
#' @importFrom data.table := rbindlist setorder
#' @importFrom data.table := rbindlist setorder
multiple
.
tree
.
interprete
<-
function
(
tree_dt
,
.
multiple
_
tree
_
interprete
<-
function
(
tree_dt
,
tree_index
,
tree_index
,
leaf_index
)
{
leaf_index
)
{
...
@@ -186,7 +186,7 @@ multiple.tree.interprete <- function(tree_dt,
...
@@ -186,7 +186,7 @@ multiple.tree.interprete <- function(tree_dt,
}
}
#' @importFrom data.table set setnames
#' @importFrom data.table set setnames
single
.
row
.
interprete
<-
function
(
tree_dt
,
num_class
,
tree_index_mat
,
leaf_index_mat
)
{
.
single
_
row
_
interprete
<-
function
(
tree_dt
,
num_class
,
tree_index_mat
,
leaf_index_mat
)
{
# Prepare vector list
# Prepare vector list
tree_interpretation
<-
vector
(
mode
=
"list"
,
length
=
num_class
)
tree_interpretation
<-
vector
(
mode
=
"list"
,
length
=
num_class
)
...
@@ -194,7 +194,7 @@ single.row.interprete <- function(tree_dt, num_class, tree_index_mat, leaf_index
...
@@ -194,7 +194,7 @@ single.row.interprete <- function(tree_dt, num_class, tree_index_mat, leaf_index
# Loop throughout each class
# Loop throughout each class
for
(
i
in
seq_len
(
num_class
))
{
for
(
i
in
seq_len
(
num_class
))
{
next_interp_dt
<-
multiple
.
tree
.
interprete
(
next_interp_dt
<-
.
multiple
_
tree
_
interprete
(
tree_dt
=
tree_dt
tree_dt
=
tree_dt
,
tree_index
=
tree_index_mat
[,
i
]
,
tree_index
=
tree_index_mat
[,
i
]
,
leaf_index
=
leaf_index_mat
[,
i
]
,
leaf_index
=
leaf_index_mat
[,
i
]
...
...
R-package/R/lgb.make_serializable.R
View file @
694e41e4
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
#' @seealso \link{lgb.restore_handle}, \link{lgb.drop_serialized}.
#' @seealso \link{lgb.restore_handle}, \link{lgb.drop_serialized}.
#' @export
#' @export
lgb.make_serializable
<-
function
(
model
)
{
lgb.make_serializable
<-
function
(
model
)
{
if
(
!
lgb
.is
.
Booster
(
x
=
model
))
{
if
(
!
.is
_
Booster
(
x
=
model
))
{
stop
(
"lgb.make_serializable: model should be an "
,
sQuote
(
"lgb.Booster"
))
stop
(
"lgb.make_serializable: model should be an "
,
sQuote
(
"lgb.Booster"
))
}
}
model
$
save_raw
()
model
$
save_raw
()
...
...
R-package/R/lgb.model.dt.tree.R
View file @
694e41e4
...
@@ -62,7 +62,10 @@ lgb.model.dt.tree <- function(model, num_iteration = NULL) {
...
@@ -62,7 +62,10 @@ lgb.model.dt.tree <- function(model, num_iteration = NULL) {
)
)
# Parse tree model
# Parse tree model
tree_list
<-
lapply
(
parsed_json_model
$
tree_info
,
single.tree.parse
)
tree_list
<-
lapply
(
X
=
parsed_json_model
$
tree_info
,
FUN
=
.single_tree_parse
)
# Combine into single data.table
# Combine into single data.table
tree_dt
<-
data.table
::
rbindlist
(
l
=
tree_list
,
use.names
=
TRUE
)
tree_dt
<-
data.table
::
rbindlist
(
l
=
tree_list
,
use.names
=
TRUE
)
...
@@ -84,7 +87,7 @@ lgb.model.dt.tree <- function(model, num_iteration = NULL) {
...
@@ -84,7 +87,7 @@ lgb.model.dt.tree <- function(model, num_iteration = NULL) {
#' @importFrom data.table := data.table rbindlist
#' @importFrom data.table := data.table rbindlist
single
.
tree
.
parse
<-
function
(
lgb_tree
)
{
.
single
_
tree
_
parse
<-
function
(
lgb_tree
)
{
# Traverse tree function
# Traverse tree function
pre_order_traversal
<-
function
(
env
=
NULL
,
tree_node_leaf
,
current_depth
=
0L
,
parent_index
=
NA_integer_
)
{
pre_order_traversal
<-
function
(
env
=
NULL
,
tree_node_leaf
,
current_depth
=
0L
,
parent_index
=
NA_integer_
)
{
...
...
R-package/R/lgb.plot.interpretation.R
View file @
694e41e4
...
@@ -89,7 +89,7 @@ lgb.plot.interpretation <- function(tree_interpretation_dt,
...
@@ -89,7 +89,7 @@ lgb.plot.interpretation <- function(tree_interpretation_dt,
if
(
num_class
==
1L
)
{
if
(
num_class
==
1L
)
{
# Only one class, plot straight away
# Only one class, plot straight away
multiple
.
tree
.
plot
.
interpretation
(
.
multiple
_
tree
_
plot
_
interpretation
(
tree_interpretation
=
tree_interpretation_dt
tree_interpretation
=
tree_interpretation_dt
,
top_n
=
top_n
,
top_n
=
top_n
,
title
=
NULL
,
title
=
NULL
...
@@ -118,7 +118,7 @@ lgb.plot.interpretation <- function(tree_interpretation_dt,
...
@@ -118,7 +118,7 @@ lgb.plot.interpretation <- function(tree_interpretation_dt,
,
old
=
names
(
plot_dt
)
,
old
=
names
(
plot_dt
)
,
new
=
c
(
"Feature"
,
"Contribution"
)
,
new
=
c
(
"Feature"
,
"Contribution"
)
)
)
multiple
.
tree
.
plot
.
interpretation
(
.
multiple
_
tree
_
plot
_
interpretation
(
tree_interpretation
=
plot_dt
tree_interpretation
=
plot_dt
,
top_n
=
top_n
,
top_n
=
top_n
,
title
=
paste
(
"Class"
,
i
-
1L
)
,
title
=
paste
(
"Class"
,
i
-
1L
)
...
@@ -131,7 +131,7 @@ lgb.plot.interpretation <- function(tree_interpretation_dt,
...
@@ -131,7 +131,7 @@ lgb.plot.interpretation <- function(tree_interpretation_dt,
}
}
#' @importFrom graphics barplot
#' @importFrom graphics barplot
multiple
.
tree
.
plot
.
interpretation
<-
function
(
tree_interpretation
,
.
multiple
_
tree
_
plot
_
interpretation
<-
function
(
tree_interpretation
,
top_n
,
top_n
,
title
,
title
,
cex
)
{
cex
)
{
...
...
R-package/R/lgb.restore_handle.R
View file @
694e41e4
...
@@ -35,7 +35,7 @@
...
@@ -35,7 +35,7 @@
#' model_new$check_null_handle()
#' model_new$check_null_handle()
#' @export
#' @export
lgb.restore_handle
<-
function
(
model
)
{
lgb.restore_handle
<-
function
(
model
)
{
if
(
!
lgb
.is
.
Booster
(
x
=
model
))
{
if
(
!
.is
_
Booster
(
x
=
model
))
{
stop
(
"lgb.restore_handle: model should be an "
,
sQuote
(
"lgb.Booster"
))
stop
(
"lgb.restore_handle: model should be an "
,
sQuote
(
"lgb.Booster"
))
}
}
model
$
restore_handle
()
model
$
restore_handle
()
...
...
R-package/R/lgb.train.R
View file @
694e41e4
...
@@ -63,11 +63,11 @@ lgb.train <- function(params = list(),
...
@@ -63,11 +63,11 @@ lgb.train <- function(params = list(),
if
(
nrounds
<=
0L
)
{
if
(
nrounds
<=
0L
)
{
stop
(
"nrounds should be greater than zero"
)
stop
(
"nrounds should be greater than zero"
)
}
}
if
(
!
lgb
.is
.
Dataset
(
x
=
data
))
{
if
(
!
.is
_
Dataset
(
x
=
data
))
{
stop
(
"lgb.train: data must be an lgb.Dataset instance"
)
stop
(
"lgb.train: data must be an lgb.Dataset instance"
)
}
}
if
(
length
(
valids
)
>
0L
)
{
if
(
length
(
valids
)
>
0L
)
{
if
(
!
identical
(
class
(
valids
),
"list"
)
||
!
all
(
vapply
(
valids
,
lgb
.is
.
Dataset
,
logical
(
1L
))))
{
if
(
!
identical
(
class
(
valids
),
"list"
)
||
!
all
(
vapply
(
valids
,
.is
_
Dataset
,
logical
(
1L
))))
{
stop
(
"lgb.train: valids must be a list of lgb.Dataset elements"
)
stop
(
"lgb.train: valids must be a list of lgb.Dataset elements"
)
}
}
evnames
<-
names
(
valids
)
evnames
<-
names
(
valids
)
...
@@ -80,27 +80,27 @@ lgb.train <- function(params = list(),
...
@@ -80,27 +80,27 @@ lgb.train <- function(params = list(),
# in `params`.
# in `params`.
# this ensures that the model stored with Booster$save() correctly represents
# this ensures that the model stored with Booster$save() correctly represents
# what was passed in
# what was passed in
params
<-
lgb
.check
.
wrapper_param
(
params
<-
.check
_
wrapper_param
(
main_param_name
=
"verbosity"
main_param_name
=
"verbosity"
,
params
=
params
,
params
=
params
,
alternative_kwarg_value
=
verbose
,
alternative_kwarg_value
=
verbose
)
)
params
<-
lgb
.check
.
wrapper_param
(
params
<-
.check
_
wrapper_param
(
main_param_name
=
"num_iterations"
main_param_name
=
"num_iterations"
,
params
=
params
,
params
=
params
,
alternative_kwarg_value
=
nrounds
,
alternative_kwarg_value
=
nrounds
)
)
params
<-
lgb
.check
.
wrapper_param
(
params
<-
.check
_
wrapper_param
(
main_param_name
=
"metric"
main_param_name
=
"metric"
,
params
=
params
,
params
=
params
,
alternative_kwarg_value
=
NULL
,
alternative_kwarg_value
=
NULL
)
)
params
<-
lgb
.check
.
wrapper_param
(
params
<-
.check
_
wrapper_param
(
main_param_name
=
"objective"
main_param_name
=
"objective"
,
params
=
params
,
params
=
params
,
alternative_kwarg_value
=
obj
,
alternative_kwarg_value
=
obj
)
)
params
<-
lgb
.check
.
wrapper_param
(
params
<-
.check
_
wrapper_param
(
main_param_name
=
"early_stopping_round"
main_param_name
=
"early_stopping_round"
,
params
=
params
,
params
=
params
,
alternative_kwarg_value
=
early_stopping_rounds
,
alternative_kwarg_value
=
early_stopping_rounds
...
@@ -118,7 +118,7 @@ lgb.train <- function(params = list(),
...
@@ -118,7 +118,7 @@ lgb.train <- function(params = list(),
# (for backwards compatibility). If it is a list of functions, store
# (for backwards compatibility). If it is a list of functions, store
# all of them. This makes it possible to pass any mix of strings like "auc"
# all of them. This makes it possible to pass any mix of strings like "auc"
# and custom functions to eval
# and custom functions to eval
params
<-
lgb
.check
.
eval
(
params
=
params
,
eval
=
eval
)
params
<-
.check
_
eval
(
params
=
params
,
eval
=
eval
)
eval_functions
<-
list
(
NULL
)
eval_functions
<-
list
(
NULL
)
if
(
is.function
(
eval
))
{
if
(
is.function
(
eval
))
{
eval_functions
<-
list
(
eval
)
eval_functions
<-
list
(
eval
)
...
@@ -136,7 +136,7 @@ lgb.train <- function(params = list(),
...
@@ -136,7 +136,7 @@ lgb.train <- function(params = list(),
# Check for boosting from a trained model
# Check for boosting from a trained model
if
(
is.character
(
init_model
))
{
if
(
is.character
(
init_model
))
{
predictor
<-
Predictor
$
new
(
modelfile
=
init_model
)
predictor
<-
Predictor
$
new
(
modelfile
=
init_model
)
}
else
if
(
lgb
.is
.
Booster
(
x
=
init_model
))
{
}
else
if
(
.is
_
Booster
(
x
=
init_model
))
{
predictor
<-
init_model
$
to_predictor
()
predictor
<-
init_model
$
to_predictor
()
}
}
...
@@ -166,7 +166,7 @@ lgb.train <- function(params = list(),
...
@@ -166,7 +166,7 @@ lgb.train <- function(params = list(),
}
else
if
(
!
is.null
(
data
$
get_colnames
()))
{
}
else
if
(
!
is.null
(
data
$
get_colnames
()))
{
cnames
<-
data
$
get_colnames
()
cnames
<-
data
$
get_colnames
()
}
}
params
[[
"interaction_constraints"
]]
<-
lgb
.check_interaction_constraints
(
params
[[
"interaction_constraints"
]]
<-
.check_interaction_constraints
(
interaction_constraints
=
interaction_constraints
interaction_constraints
=
interaction_constraints
,
column_names
=
cnames
,
column_names
=
cnames
)
)
...
@@ -212,12 +212,18 @@ lgb.train <- function(params = list(),
...
@@ -212,12 +212,18 @@ lgb.train <- function(params = list(),
# Add printing log callback
# Add printing log callback
if
(
params
[[
"verbosity"
]]
>
0L
&&
eval_freq
>
0L
)
{
if
(
params
[[
"verbosity"
]]
>
0L
&&
eval_freq
>
0L
)
{
callbacks
<-
add.cb
(
cb_list
=
callbacks
,
cb
=
cb_print_evaluation
(
period
=
eval_freq
))
callbacks
<-
.add_cb
(
cb_list
=
callbacks
,
cb
=
cb_print_evaluation
(
period
=
eval_freq
)
)
}
}
# Add evaluation log callback
# Add evaluation log callback
if
(
record
&&
length
(
valids
)
>
0L
)
{
if
(
record
&&
length
(
valids
)
>
0L
)
{
callbacks
<-
add.cb
(
cb_list
=
callbacks
,
cb
=
cb_record_evaluation
())
callbacks
<-
.add_cb
(
cb_list
=
callbacks
,
cb
=
cb_record_evaluation
()
)
}
}
# Did user pass parameters that indicate they want to use early stopping?
# Did user pass parameters that indicate they want to use early stopping?
...
@@ -249,7 +255,7 @@ lgb.train <- function(params = list(),
...
@@ -249,7 +255,7 @@ lgb.train <- function(params = list(),
# If user supplied early_stopping_rounds, add the early stopping callback
# If user supplied early_stopping_rounds, add the early stopping callback
if
(
using_early_stopping
)
{
if
(
using_early_stopping
)
{
callbacks
<-
add
.
cb
(
callbacks
<-
.
add
_
cb
(
cb_list
=
callbacks
cb_list
=
callbacks
,
cb
=
cb_early_stop
(
,
cb
=
cb_early_stop
(
stopping_rounds
=
early_stopping_rounds
stopping_rounds
=
early_stopping_rounds
...
@@ -259,7 +265,7 @@ lgb.train <- function(params = list(),
...
@@ -259,7 +265,7 @@ lgb.train <- function(params = list(),
)
)
}
}
cb
<-
categorize
.
callbacks
(
cb_list
=
callbacks
)
cb
<-
.
categorize
_
callbacks
(
cb_list
=
callbacks
)
# Construct booster with datasets
# Construct booster with datasets
booster
<-
Booster
$
new
(
params
=
params
,
train_set
=
data
)
booster
<-
Booster
$
new
(
params
=
params
,
train_set
=
data
)
...
...
R-package/R/lightgbm.R
View file @
694e41e4
...
@@ -184,21 +184,21 @@ lightgbm <- function(data,
...
@@ -184,21 +184,21 @@ lightgbm <- function(data,
}
}
if
(
is.null
(
num_threads
))
{
if
(
is.null
(
num_threads
))
{
num_threads
<-
lgb
.get
.
default
.
num
.
threads
()
num_threads
<-
.get
_
default
_
num
_
threads
()
}
}
params
<-
lgb
.check
.
wrapper_param
(
params
<-
.check
_
wrapper_param
(
main_param_name
=
"num_threads"
main_param_name
=
"num_threads"
,
params
=
params
,
params
=
params
,
alternative_kwarg_value
=
num_threads
,
alternative_kwarg_value
=
num_threads
)
)
params
<-
lgb
.check
.
wrapper_param
(
params
<-
.check
_
wrapper_param
(
main_param_name
=
"verbosity"
main_param_name
=
"verbosity"
,
params
=
params
,
params
=
params
,
alternative_kwarg_value
=
verbose
,
alternative_kwarg_value
=
verbose
)
)
# Process factors as labels and auto-determine objective
# Process factors as labels and auto-determine objective
if
(
!
lgb
.is
.
Dataset
(
data
))
{
if
(
!
.is
_
Dataset
(
data
))
{
data_processor
<-
DataProcessor
$
new
()
data_processor
<-
DataProcessor
$
new
()
temp
<-
data_processor
$
process_label
(
temp
<-
data_processor
$
process_label
(
label
=
label
label
=
label
...
@@ -220,7 +220,7 @@ lightgbm <- function(data,
...
@@ -220,7 +220,7 @@ lightgbm <- function(data,
dtrain
<-
data
dtrain
<-
data
# Check whether data is lgb.Dataset, if not then create lgb.Dataset manually
# Check whether data is lgb.Dataset, if not then create lgb.Dataset manually
if
(
!
lgb
.is
.
Dataset
(
x
=
dtrain
))
{
if
(
!
.is
_
Dataset
(
x
=
dtrain
))
{
dtrain
<-
lgb.Dataset
(
data
=
data
,
label
=
label
,
weight
=
weights
,
init_score
=
init_score
)
dtrain
<-
lgb.Dataset
(
data
=
data
,
label
=
label
,
weight
=
weights
,
init_score
=
init_score
)
}
}
...
...
R-package/R/saveRDS.lgb.Booster.R
View file @
694e41e4
...
@@ -57,7 +57,7 @@ saveRDS.lgb.Booster <- function(object,
...
@@ -57,7 +57,7 @@ saveRDS.lgb.Booster <- function(object,
warning
(
"'saveRDS.lgb.Booster' is deprecated and will be removed in a future release. Use saveRDS() instead."
)
warning
(
"'saveRDS.lgb.Booster' is deprecated and will be removed in a future release. Use saveRDS() instead."
)
if
(
!
lgb
.is
.
Booster
(
x
=
object
))
{
if
(
!
.is
_
Booster
(
x
=
object
))
{
stop
(
"saveRDS.lgb.Booster: object should be an "
,
sQuote
(
"lgb.Booster"
))
stop
(
"saveRDS.lgb.Booster: object should be an "
,
sQuote
(
"lgb.Booster"
))
}
}
...
...
R-package/R/utils.R
View file @
694e41e4
lgb
.is
.
Booster
<-
function
(
x
)
{
.is
_
Booster
<-
function
(
x
)
{
return
(
all
(
c
(
"R6"
,
"lgb.Booster"
)
%in%
class
(
x
)))
# nolint: class_equals
return
(
all
(
c
(
"R6"
,
"lgb.Booster"
)
%in%
class
(
x
)))
# nolint: class_equals
}
}
lgb
.is
.
Dataset
<-
function
(
x
)
{
.is
_
Dataset
<-
function
(
x
)
{
return
(
all
(
c
(
"R6"
,
"lgb.Dataset"
)
%in%
class
(
x
)))
# nolint: class_equals
return
(
all
(
c
(
"R6"
,
"lgb.Dataset"
)
%in%
class
(
x
)))
# nolint: class_equals
}
}
lgb
.is
.
Predictor
<-
function
(
x
)
{
.is
_
Predictor
<-
function
(
x
)
{
return
(
all
(
c
(
"R6"
,
"lgb.Predictor"
)
%in%
class
(
x
)))
# nolint: class_equals
return
(
all
(
c
(
"R6"
,
"lgb.Predictor"
)
%in%
class
(
x
)))
# nolint: class_equals
}
}
lgb
.is
.
null
.
handle
<-
function
(
x
)
{
.is
_
null
_
handle
<-
function
(
x
)
{
if
(
is.null
(
x
))
{
if
(
is.null
(
x
))
{
return
(
TRUE
)
return
(
TRUE
)
}
}
...
@@ -19,7 +19,7 @@ lgb.is.null.handle <- function(x) {
...
@@ -19,7 +19,7 @@ lgb.is.null.handle <- function(x) {
)
)
}
}
lgb
.params2str
<-
function
(
params
)
{
.params2str
<-
function
(
params
)
{
if
(
!
identical
(
class
(
params
),
"list"
))
{
if
(
!
identical
(
class
(
params
),
"list"
))
{
stop
(
"params must be a list"
)
stop
(
"params must be a list"
)
...
@@ -59,7 +59,7 @@ lgb.params2str <- function(params) {
...
@@ -59,7 +59,7 @@ lgb.params2str <- function(params) {
}
}
lgb
.check_interaction_constraints
<-
function
(
interaction_constraints
,
column_names
)
{
.check_interaction_constraints
<-
function
(
interaction_constraints
,
column_names
)
{
# Convert interaction constraints to feature numbers
# Convert interaction constraints to feature numbers
string_constraints
<-
list
()
string_constraints
<-
list
()
...
@@ -129,7 +129,7 @@ lgb.check_interaction_constraints <- function(interaction_constraints, column_na
...
@@ -129,7 +129,7 @@ lgb.check_interaction_constraints <- function(interaction_constraints, column_na
# This has to account for the fact that `eval` could be a character vector,
# This has to account for the fact that `eval` could be a character vector,
# a function, a list of functions, or a list with a mix of strings and
# a function, a list of functions, or a list with a mix of strings and
# functions
# functions
lgb
.check
.
eval
<-
function
(
params
,
eval
)
{
.check
_
eval
<-
function
(
params
,
eval
)
{
if
(
is.null
(
params
$
metric
))
{
if
(
is.null
(
params
$
metric
))
{
params
$
metric
<-
list
()
params
$
metric
<-
list
()
...
@@ -194,7 +194,7 @@ lgb.check.eval <- function(params, eval) {
...
@@ -194,7 +194,7 @@ lgb.check.eval <- function(params, eval) {
# [return]
# [return]
# params with num_iterations set to the chosen value, and other aliases
# params with num_iterations set to the chosen value, and other aliases
# of num_iterations removed
# of num_iterations removed
lgb
.check
.
wrapper_param
<-
function
(
main_param_name
,
params
,
alternative_kwarg_value
)
{
.check
_
wrapper_param
<-
function
(
main_param_name
,
params
,
alternative_kwarg_value
)
{
aliases
<-
.PARAMETER_ALIASES
()[[
main_param_name
]]
aliases
<-
.PARAMETER_ALIASES
()[[
main_param_name
]]
aliases_provided
<-
aliases
[
aliases
%in%
names
(
params
)]
aliases_provided
<-
aliases
[
aliases
%in%
names
(
params
)]
...
@@ -225,7 +225,7 @@ lgb.check.wrapper_param <- function(main_param_name, params, alternative_kwarg_v
...
@@ -225,7 +225,7 @@ lgb.check.wrapper_param <- function(main_param_name, params, alternative_kwarg_v
}
}
#' @importFrom parallel detectCores
#' @importFrom parallel detectCores
lgb
.get
.
default
.
num
.
threads
<-
function
()
{
.get
_
default
_
num
_
threads
<-
function
()
{
if
(
requireNamespace
(
"RhpcBLASctl"
,
quietly
=
TRUE
))
{
# nolint: undesirable_function
if
(
requireNamespace
(
"RhpcBLASctl"
,
quietly
=
TRUE
))
{
# nolint: undesirable_function
return
(
RhpcBLASctl
::
get_num_cores
())
return
(
RhpcBLASctl
::
get_num_cores
())
}
else
{
}
else
{
...
@@ -247,7 +247,7 @@ lgb.get.default.num.threads <- function() {
...
@@ -247,7 +247,7 @@ lgb.get.default.num.threads <- function() {
}
}
}
}
lgb
.equal
.
or
.
both
.
null
<-
function
(
a
,
b
)
{
.equal
_
or
_
both
_
null
<-
function
(
a
,
b
)
{
if
(
is.null
(
a
))
{
if
(
is.null
(
a
))
{
if
(
!
is.null
(
b
))
{
if
(
!
is.null
(
b
))
{
return
(
FALSE
)
return
(
FALSE
)
...
...
R-package/tests/testthat/test_Predictor.R
View file @
694e41e4
...
@@ -17,16 +17,16 @@ test_that("Predictor$finalize() should not fail", {
...
@@ -17,16 +17,16 @@ test_that("Predictor$finalize() should not fail", {
bst
$
save_model
(
filename
=
model_file
)
bst
$
save_model
(
filename
=
model_file
)
predictor
<-
Predictor
$
new
(
modelfile
=
model_file
)
predictor
<-
Predictor
$
new
(
modelfile
=
model_file
)
expect_true
(
lgb
.is
.
Predictor
(
predictor
))
expect_true
(
.is
_
Predictor
(
predictor
))
expect_false
(
lgb
.is
.
null
.
handle
(
predictor
$
.__enclos_env__
$
private
$
handle
))
expect_false
(
.is
_
null
_
handle
(
predictor
$
.__enclos_env__
$
private
$
handle
))
predictor
$
finalize
()
predictor
$
finalize
()
expect_true
(
lgb
.is
.
null
.
handle
(
predictor
$
.__enclos_env__
$
private
$
handle
))
expect_true
(
.is
_
null
_
handle
(
predictor
$
.__enclos_env__
$
private
$
handle
))
# calling finalize() a second time shouldn't cause any issues
# calling finalize() a second time shouldn't cause any issues
predictor
$
finalize
()
predictor
$
finalize
()
expect_true
(
lgb
.is
.
null
.
handle
(
predictor
$
.__enclos_env__
$
private
$
handle
))
expect_true
(
.is
_
null
_
handle
(
predictor
$
.__enclos_env__
$
private
$
handle
))
})
})
test_that
(
"predictions do not fail for integer input"
,
{
test_that
(
"predictions do not fail for integer input"
,
{
...
@@ -79,7 +79,7 @@ test_that("start_iteration works correctly", {
...
@@ -79,7 +79,7 @@ test_that("start_iteration works correctly", {
,
valids
=
list
(
"test"
=
dtest
)
,
valids
=
list
(
"test"
=
dtest
)
,
early_stopping_rounds
=
2L
,
early_stopping_rounds
=
2L
)
)
expect_true
(
lgb
.is
.
Booster
(
bst
))
expect_true
(
.is
_
Booster
(
bst
))
pred1
<-
predict
(
bst
,
newdata
=
test
$
data
,
type
=
"raw"
)
pred1
<-
predict
(
bst
,
newdata
=
test
$
data
,
type
=
"raw"
)
pred_contrib1
<-
predict
(
bst
,
test
$
data
,
type
=
"contrib"
)
pred_contrib1
<-
predict
(
bst
,
test
$
data
,
type
=
"contrib"
)
pred2
<-
rep
(
0.0
,
length
(
pred1
))
pred2
<-
rep
(
0.0
,
length
(
pred1
))
...
...
R-package/tests/testthat/test_basic.R
View file @
694e41e4
...
@@ -1094,7 +1094,7 @@ test_that("lgb.train() works as expected with sparse features", {
...
@@ -1094,7 +1094,7 @@ test_that("lgb.train() works as expected with sparse features", {
,
nrounds
=
nrounds
,
nrounds
=
nrounds
)
)
expect_true
(
lgb
.is
.
Booster
(
bst
))
expect_true
(
.is
_
Booster
(
bst
))
expect_equal
(
bst
$
current_iter
(),
nrounds
)
expect_equal
(
bst
$
current_iter
(),
nrounds
)
parsed_model
<-
jsonlite
::
fromJSON
(
bst
$
dump_model
())
parsed_model
<-
jsonlite
::
fromJSON
(
bst
$
dump_model
())
expect_equal
(
parsed_model
$
objective
,
"binary sigmoid:1"
)
expect_equal
(
parsed_model
$
objective
,
"binary sigmoid:1"
)
...
@@ -1816,7 +1816,7 @@ test_that("lgb.train() supports non-ASCII feature names", {
...
@@ -1816,7 +1816,7 @@ test_that("lgb.train() supports non-ASCII feature names", {
)
)
,
colnames
=
feature_names
,
colnames
=
feature_names
)
)
expect_true
(
lgb
.is
.
Booster
(
bst
))
expect_true
(
.is
_
Booster
(
bst
))
dumped_model
<-
jsonlite
::
fromJSON
(
bst
$
dump_model
())
dumped_model
<-
jsonlite
::
fromJSON
(
bst
$
dump_model
())
# UTF-8 strings are not well-supported on Windows
# UTF-8 strings are not well-supported on Windows
...
@@ -2522,7 +2522,7 @@ test_that("lgb.train() fit on linearly-relatead data improves when using linear
...
@@ -2522,7 +2522,7 @@ test_that("lgb.train() fit on linearly-relatead data improves when using linear
,
params
=
params
,
params
=
params
,
valids
=
list
(
"train"
=
dtrain
)
,
valids
=
list
(
"train"
=
dtrain
)
)
)
expect_true
(
lgb
.is
.
Booster
(
bst
))
expect_true
(
.is
_
Booster
(
bst
))
dtrain
<-
.new_dataset
()
dtrain
<-
.new_dataset
()
bst_linear
<-
lgb.train
(
bst_linear
<-
lgb.train
(
...
@@ -2531,7 +2531,7 @@ test_that("lgb.train() fit on linearly-relatead data improves when using linear
...
@@ -2531,7 +2531,7 @@ test_that("lgb.train() fit on linearly-relatead data improves when using linear
,
params
=
utils
::
modifyList
(
params
,
list
(
linear_tree
=
TRUE
))
,
params
=
utils
::
modifyList
(
params
,
list
(
linear_tree
=
TRUE
))
,
valids
=
list
(
"train"
=
dtrain
)
,
valids
=
list
(
"train"
=
dtrain
)
)
)
expect_true
(
lgb
.is
.
Booster
(
bst_linear
))
expect_true
(
.is
_
Booster
(
bst_linear
))
bst_last_mse
<-
bst
$
record_evals
[[
"train"
]][[
"l2"
]][[
"eval"
]][[
10L
]]
bst_last_mse
<-
bst
$
record_evals
[[
"train"
]][[
"l2"
]][[
"eval"
]][[
10L
]]
bst_lin_last_mse
<-
bst_linear
$
record_evals
[[
"train"
]][[
"l2"
]][[
"eval"
]][[
10L
]]
bst_lin_last_mse
<-
bst_linear
$
record_evals
[[
"train"
]][[
"l2"
]][[
"eval"
]][[
10L
]]
...
@@ -2599,7 +2599,7 @@ test_that("lgb.train() works with linear learners even if Dataset has missing va
...
@@ -2599,7 +2599,7 @@ test_that("lgb.train() works with linear learners even if Dataset has missing va
,
params
=
params
,
params
=
params
,
valids
=
list
(
"train"
=
dtrain
)
,
valids
=
list
(
"train"
=
dtrain
)
)
)
expect_true
(
lgb
.is
.
Booster
(
bst
))
expect_true
(
.is
_
Booster
(
bst
))
dtrain
<-
.new_dataset
()
dtrain
<-
.new_dataset
()
bst_linear
<-
lgb.train
(
bst_linear
<-
lgb.train
(
...
@@ -2608,7 +2608,7 @@ test_that("lgb.train() works with linear learners even if Dataset has missing va
...
@@ -2608,7 +2608,7 @@ test_that("lgb.train() works with linear learners even if Dataset has missing va
,
params
=
utils
::
modifyList
(
params
,
list
(
linear_tree
=
TRUE
))
,
params
=
utils
::
modifyList
(
params
,
list
(
linear_tree
=
TRUE
))
,
valids
=
list
(
"train"
=
dtrain
)
,
valids
=
list
(
"train"
=
dtrain
)
)
)
expect_true
(
lgb
.is
.
Booster
(
bst_linear
))
expect_true
(
.is
_
Booster
(
bst_linear
))
bst_last_mse
<-
bst
$
record_evals
[[
"train"
]][[
"l2"
]][[
"eval"
]][[
10L
]]
bst_last_mse
<-
bst
$
record_evals
[[
"train"
]][[
"l2"
]][[
"eval"
]][[
10L
]]
bst_lin_last_mse
<-
bst_linear
$
record_evals
[[
"train"
]][[
"l2"
]][[
"eval"
]][[
10L
]]
bst_lin_last_mse
<-
bst_linear
$
record_evals
[[
"train"
]][[
"l2"
]][[
"eval"
]][[
10L
]]
...
@@ -2649,7 +2649,7 @@ test_that("lgb.train() works with linear learners, bagging, and a Dataset that h
...
@@ -2649,7 +2649,7 @@ test_that("lgb.train() works with linear learners, bagging, and a Dataset that h
,
params
=
params
,
params
=
params
,
valids
=
list
(
"train"
=
dtrain
)
,
valids
=
list
(
"train"
=
dtrain
)
)
)
expect_true
(
lgb
.is
.
Booster
(
bst
))
expect_true
(
.is
_
Booster
(
bst
))
dtrain
<-
.new_dataset
()
dtrain
<-
.new_dataset
()
bst_linear
<-
lgb.train
(
bst_linear
<-
lgb.train
(
...
@@ -2658,7 +2658,7 @@ test_that("lgb.train() works with linear learners, bagging, and a Dataset that h
...
@@ -2658,7 +2658,7 @@ test_that("lgb.train() works with linear learners, bagging, and a Dataset that h
,
params
=
utils
::
modifyList
(
params
,
list
(
linear_tree
=
TRUE
))
,
params
=
utils
::
modifyList
(
params
,
list
(
linear_tree
=
TRUE
))
,
valids
=
list
(
"train"
=
dtrain
)
,
valids
=
list
(
"train"
=
dtrain
)
)
)
expect_true
(
lgb
.is
.
Booster
(
bst_linear
))
expect_true
(
.is
_
Booster
(
bst_linear
))
bst_last_mse
<-
bst
$
record_evals
[[
"train"
]][[
"l2"
]][[
"eval"
]][[
10L
]]
bst_last_mse
<-
bst
$
record_evals
[[
"train"
]][[
"l2"
]][[
"eval"
]][[
10L
]]
bst_lin_last_mse
<-
bst_linear
$
record_evals
[[
"train"
]][[
"l2"
]][[
"eval"
]][[
10L
]]
bst_lin_last_mse
<-
bst_linear
$
record_evals
[[
"train"
]][[
"l2"
]][[
"eval"
]][[
10L
]]
...
@@ -2699,7 +2699,7 @@ test_that("lgb.train() works with linear learners and data where a feature has o
...
@@ -2699,7 +2699,7 @@ test_that("lgb.train() works with linear learners and data where a feature has o
,
nrounds
=
10L
,
nrounds
=
10L
,
params
=
utils
::
modifyList
(
params
,
list
(
linear_tree
=
TRUE
))
,
params
=
utils
::
modifyList
(
params
,
list
(
linear_tree
=
TRUE
))
)
)
expect_true
(
lgb
.is
.
Booster
(
bst_linear
))
expect_true
(
.is
_
Booster
(
bst_linear
))
})
})
test_that
(
"lgb.train() works with linear learners when Dataset has categorical features"
,
{
test_that
(
"lgb.train() works with linear learners when Dataset has categorical features"
,
{
...
@@ -2732,7 +2732,7 @@ test_that("lgb.train() works with linear learners when Dataset has categorical f
...
@@ -2732,7 +2732,7 @@ test_that("lgb.train() works with linear learners when Dataset has categorical f
,
params
=
params
,
params
=
params
,
valids
=
list
(
"train"
=
dtrain
)
,
valids
=
list
(
"train"
=
dtrain
)
)
)
expect_true
(
lgb
.is
.
Booster
(
bst
))
expect_true
(
.is
_
Booster
(
bst
))
dtrain
<-
.new_dataset
()
dtrain
<-
.new_dataset
()
bst_linear
<-
lgb.train
(
bst_linear
<-
lgb.train
(
...
@@ -2741,7 +2741,7 @@ test_that("lgb.train() works with linear learners when Dataset has categorical f
...
@@ -2741,7 +2741,7 @@ test_that("lgb.train() works with linear learners when Dataset has categorical f
,
params
=
utils
::
modifyList
(
params
,
list
(
linear_tree
=
TRUE
))
,
params
=
utils
::
modifyList
(
params
,
list
(
linear_tree
=
TRUE
))
,
valids
=
list
(
"train"
=
dtrain
)
,
valids
=
list
(
"train"
=
dtrain
)
)
)
expect_true
(
lgb
.is
.
Booster
(
bst_linear
))
expect_true
(
.is
_
Booster
(
bst_linear
))
bst_last_mse
<-
bst
$
record_evals
[[
"train"
]][[
"l2"
]][[
"eval"
]][[
10L
]]
bst_last_mse
<-
bst
$
record_evals
[[
"train"
]][[
"l2"
]][[
"eval"
]][[
10L
]]
bst_lin_last_mse
<-
bst_linear
$
record_evals
[[
"train"
]][[
"l2"
]][[
"eval"
]][[
10L
]]
bst_lin_last_mse
<-
bst_linear
$
record_evals
[[
"train"
]][[
"l2"
]][[
"eval"
]][[
10L
]]
...
...
R-package/tests/testthat/test_dataset.R
View file @
694e41e4
...
@@ -206,7 +206,7 @@ test_that("lgb.Dataset: Dataset should be able to construct from matrix and retu
...
@@ -206,7 +206,7 @@ test_that("lgb.Dataset: Dataset should be able to construct from matrix and retu
,
rawData
,
rawData
,
nrow
(
rawData
)
,
nrow
(
rawData
)
,
ncol
(
rawData
)
,
ncol
(
rawData
)
,
lightgbm
:::
lgb
.params2str
(
params
=
list
())
,
lightgbm
:::
.params2str
(
params
=
list
())
,
ref_handle
,
ref_handle
)
)
expect_true
(
methods
::
is
(
handle
,
"externalptr"
))
expect_true
(
methods
::
is
(
handle
,
"externalptr"
))
...
@@ -322,7 +322,7 @@ test_that("Dataset$update_parameters() does nothing for empty inputs", {
...
@@ -322,7 +322,7 @@ test_that("Dataset$update_parameters() does nothing for empty inputs", {
res
<-
ds
$
update_params
(
res
<-
ds
$
update_params
(
params
=
list
()
params
=
list
()
)
)
expect_true
(
lgb
.is
.
Dataset
(
res
))
expect_true
(
.is
_
Dataset
(
res
))
new_params
<-
ds
$
get_params
()
new_params
<-
ds
$
get_params
()
expect_identical
(
new_params
,
initial_params
)
expect_identical
(
new_params
,
initial_params
)
...
@@ -343,7 +343,7 @@ test_that("Dataset$update_params() works correctly for recognized Dataset parame
...
@@ -343,7 +343,7 @@ test_that("Dataset$update_params() works correctly for recognized Dataset parame
res
<-
ds
$
update_params
(
res
<-
ds
$
update_params
(
params
=
new_params
params
=
new_params
)
)
expect_true
(
lgb
.is
.
Dataset
(
res
))
expect_true
(
.is
_
Dataset
(
res
))
updated_params
<-
ds
$
get_params
()
updated_params
<-
ds
$
get_params
()
for
(
param_name
in
names
(
new_params
))
{
for
(
param_name
in
names
(
new_params
))
{
...
@@ -356,17 +356,17 @@ test_that("Dataset$finalize() should not fail on an already-finalized Dataset",
...
@@ -356,17 +356,17 @@ test_that("Dataset$finalize() should not fail on an already-finalized Dataset",
data
=
test_data
data
=
test_data
,
label
=
test_label
,
label
=
test_label
)
)
expect_true
(
lgb
.is
.
null
.
handle
(
dtest
$
.__enclos_env__
$
private
$
handle
))
expect_true
(
.is
_
null
_
handle
(
dtest
$
.__enclos_env__
$
private
$
handle
))
dtest
$
construct
()
dtest
$
construct
()
expect_false
(
lgb
.is
.
null
.
handle
(
dtest
$
.__enclos_env__
$
private
$
handle
))
expect_false
(
.is
_
null
_
handle
(
dtest
$
.__enclos_env__
$
private
$
handle
))
dtest
$
finalize
()
dtest
$
finalize
()
expect_true
(
lgb
.is
.
null
.
handle
(
dtest
$
.__enclos_env__
$
private
$
handle
))
expect_true
(
.is
_
null
_
handle
(
dtest
$
.__enclos_env__
$
private
$
handle
))
# calling finalize() a second time shouldn't cause any issues
# calling finalize() a second time shouldn't cause any issues
dtest
$
finalize
()
dtest
$
finalize
()
expect_true
(
lgb
.is
.
null
.
handle
(
dtest
$
.__enclos_env__
$
private
$
handle
))
expect_true
(
.is
_
null
_
handle
(
dtest
$
.__enclos_env__
$
private
$
handle
))
})
})
test_that
(
"lgb.Dataset: should be able to run lgb.train() immediately after using lgb.Dataset() on a file"
,
{
test_that
(
"lgb.Dataset: should be able to run lgb.train() immediately after using lgb.Dataset() on a file"
,
{
...
@@ -401,7 +401,7 @@ test_that("lgb.Dataset: should be able to run lgb.train() immediately after usin
...
@@ -401,7 +401,7 @@ test_that("lgb.Dataset: should be able to run lgb.train() immediately after usin
,
data
=
dtest_read_in
,
data
=
dtest_read_in
)
)
expect_true
(
lgb
.is
.
Booster
(
x
=
bst
))
expect_true
(
.is
_
Booster
(
x
=
bst
))
})
})
test_that
(
"lgb.Dataset: should be able to run lgb.cv() immediately after using lgb.Dataset() on a file"
,
{
test_that
(
"lgb.Dataset: should be able to run lgb.cv() immediately after using lgb.Dataset() on a file"
,
{
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment