Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
7d9106d2
Unverified
Commit
7d9106d2
authored
Jul 04, 2024
by
Michael Mayer
Committed by
GitHub
Jul 03, 2024
Browse files
[R-package]: add num_trees_per_iter, num_trees, and num_iter methods (#6500)
parent
3a98ea13
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
253 additions
and
2 deletions
+253
-2
R-package/R/lgb.Booster.R
R-package/R/lgb.Booster.R
+40
-0
R-package/src/lightgbm_R.cpp
R-package/src/lightgbm_R.cpp
+23
-2
R-package/src/lightgbm_R.h
R-package/src/lightgbm_R.h
+22
-0
R-package/tests/testthat/test_lgb.Booster.R
R-package/tests/testthat/test_lgb.Booster.R
+168
-0
No files found.
R-package/R/lgb.Booster.R
View file @
7d9106d2
...
@@ -307,6 +307,46 @@ Booster <- R6::R6Class(
...
@@ -307,6 +307,46 @@ Booster <- R6::R6Class(
},
},
# Number of trees per iteration
num_trees_per_iter
=
function
()
{
self
$
restore_handle
()
trees_per_iter
<-
1L
.Call
(
LGBM_BoosterNumModelPerIteration_R
,
private
$
handle
,
trees_per_iter
)
return
(
trees_per_iter
)
},
# Total number of trees
num_trees
=
function
()
{
self
$
restore_handle
()
ntrees
<-
0L
.Call
(
LGBM_BoosterNumberOfTotalModel_R
,
private
$
handle
,
ntrees
)
return
(
ntrees
)
},
# Number of iterations (= rounds)
num_iter
=
function
()
{
ntrees
<-
self
$
num_trees
()
trees_per_iter
<-
self
$
num_trees_per_iter
()
return
(
ntrees
/
trees_per_iter
)
},
# Get upper bound
# Get upper bound
upper_bound
=
function
()
{
upper_bound
=
function
()
{
...
...
R-package/src/lightgbm_R.cpp
View file @
7d9106d2
...
@@ -763,8 +763,7 @@ SEXP LGBM_BoosterRollbackOneIter_R(SEXP handle) {
...
@@ -763,8 +763,7 @@ SEXP LGBM_BoosterRollbackOneIter_R(SEXP handle) {
R_API_END
();
R_API_END
();
}
}
SEXP
LGBM_BoosterGetCurrentIteration_R
(
SEXP
handle
,
SEXP
LGBM_BoosterGetCurrentIteration_R
(
SEXP
handle
,
SEXP
out
)
{
SEXP
out
)
{
R_API_BEGIN
();
R_API_BEGIN
();
_AssertBoosterHandleNotNull
(
handle
);
_AssertBoosterHandleNotNull
(
handle
);
int
out_iteration
;
int
out_iteration
;
...
@@ -774,6 +773,26 @@ SEXP LGBM_BoosterGetCurrentIteration_R(SEXP handle,
...
@@ -774,6 +773,26 @@ SEXP LGBM_BoosterGetCurrentIteration_R(SEXP handle,
R_API_END
();
R_API_END
();
}
}
SEXP
LGBM_BoosterNumModelPerIteration_R
(
SEXP
handle
,
SEXP
out
)
{
R_API_BEGIN
();
_AssertBoosterHandleNotNull
(
handle
);
int
models_per_iter
;
CHECK_CALL
(
LGBM_BoosterNumModelPerIteration
(
R_ExternalPtrAddr
(
handle
),
&
models_per_iter
));
INTEGER
(
out
)[
0
]
=
models_per_iter
;
return
R_NilValue
;
R_API_END
();
}
SEXP
LGBM_BoosterNumberOfTotalModel_R
(
SEXP
handle
,
SEXP
out
)
{
R_API_BEGIN
();
_AssertBoosterHandleNotNull
(
handle
);
int
total_models
;
CHECK_CALL
(
LGBM_BoosterNumberOfTotalModel
(
R_ExternalPtrAddr
(
handle
),
&
total_models
));
INTEGER
(
out
)[
0
]
=
total_models
;
return
R_NilValue
;
R_API_END
();
}
SEXP
LGBM_BoosterGetUpperBoundValue_R
(
SEXP
handle
,
SEXP
LGBM_BoosterGetUpperBoundValue_R
(
SEXP
handle
,
SEXP
out_result
)
{
SEXP
out_result
)
{
R_API_BEGIN
();
R_API_BEGIN
();
...
@@ -1431,6 +1450,8 @@ static const R_CallMethodDef CallEntries[] = {
...
@@ -1431,6 +1450,8 @@ static const R_CallMethodDef CallEntries[] = {
{
"LGBM_BoosterUpdateOneIterCustom_R"
,
(
DL_FUNC
)
&
LGBM_BoosterUpdateOneIterCustom_R
,
4
},
{
"LGBM_BoosterUpdateOneIterCustom_R"
,
(
DL_FUNC
)
&
LGBM_BoosterUpdateOneIterCustom_R
,
4
},
{
"LGBM_BoosterRollbackOneIter_R"
,
(
DL_FUNC
)
&
LGBM_BoosterRollbackOneIter_R
,
1
},
{
"LGBM_BoosterRollbackOneIter_R"
,
(
DL_FUNC
)
&
LGBM_BoosterRollbackOneIter_R
,
1
},
{
"LGBM_BoosterGetCurrentIteration_R"
,
(
DL_FUNC
)
&
LGBM_BoosterGetCurrentIteration_R
,
2
},
{
"LGBM_BoosterGetCurrentIteration_R"
,
(
DL_FUNC
)
&
LGBM_BoosterGetCurrentIteration_R
,
2
},
{
"LGBM_BoosterNumModelPerIteration_R"
,
(
DL_FUNC
)
&
LGBM_BoosterNumModelPerIteration_R
,
2
},
{
"LGBM_BoosterNumberOfTotalModel_R"
,
(
DL_FUNC
)
&
LGBM_BoosterNumberOfTotalModel_R
,
2
},
{
"LGBM_BoosterGetUpperBoundValue_R"
,
(
DL_FUNC
)
&
LGBM_BoosterGetUpperBoundValue_R
,
2
},
{
"LGBM_BoosterGetUpperBoundValue_R"
,
(
DL_FUNC
)
&
LGBM_BoosterGetUpperBoundValue_R
,
2
},
{
"LGBM_BoosterGetLowerBoundValue_R"
,
(
DL_FUNC
)
&
LGBM_BoosterGetLowerBoundValue_R
,
2
},
{
"LGBM_BoosterGetLowerBoundValue_R"
,
(
DL_FUNC
)
&
LGBM_BoosterGetLowerBoundValue_R
,
2
},
{
"LGBM_BoosterGetEvalNames_R"
,
(
DL_FUNC
)
&
LGBM_BoosterGetEvalNames_R
,
1
},
{
"LGBM_BoosterGetEvalNames_R"
,
(
DL_FUNC
)
&
LGBM_BoosterGetEvalNames_R
,
1
},
...
...
R-package/src/lightgbm_R.h
View file @
7d9106d2
...
@@ -384,6 +384,28 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetCurrentIteration_R(
...
@@ -384,6 +384,28 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetCurrentIteration_R(
SEXP
out
SEXP
out
);
);
/*!
* \brief Get number of trees per iteration
* \param handle Booster handle
* \param out Number of trees per iteration
* \return R NULL value
*/
LIGHTGBM_C_EXPORT
SEXP
LGBM_BoosterNumModelPerIteration_R
(
SEXP
handle
,
SEXP
out
);
/*!
* \brief Get total number of trees
* \param handle Booster handle
* \param out Total number of trees of Booster
* \return R NULL value
*/
LIGHTGBM_C_EXPORT
SEXP
LGBM_BoosterNumberOfTotalModel_R
(
SEXP
handle
,
SEXP
out
);
/*!
/*!
* \brief Get model upper bound value.
* \brief Get model upper bound value.
* \param handle Handle of Booster
* \param handle Handle of Booster
...
...
R-package/tests/testthat/test_lgb.Booster.R
View file @
7d9106d2
...
@@ -623,6 +623,174 @@ test_that("Booster$update() throws an informative error if you provide a non-Dat
...
@@ -623,6 +623,174 @@ test_that("Booster$update() throws an informative error if you provide a non-Dat
},
regexp
=
"lgb.Booster.update: Only can use lgb.Dataset"
,
fixed
=
TRUE
)
},
regexp
=
"lgb.Booster.update: Only can use lgb.Dataset"
,
fixed
=
TRUE
)
})
})
test_that
(
"Booster$num_trees_per_iter() works as expected"
,
{
set.seed
(
708L
)
X
<-
data.matrix
(
iris
[
2L
:
4L
])
y_reg
<-
iris
[,
1L
]
y_binary
<-
as.integer
(
y_reg
>
median
(
y_reg
))
y_class
<-
as.integer
(
iris
[,
5L
])
-
1L
num_class
<-
3L
nrounds
<-
10L
# Regression and binary probabilistic classification (1 iteration = 1 tree)
fit_reg
<-
lgb.train
(
params
=
list
(
objective
=
"mse"
,
verbose
=
.LGB_VERBOSITY
,
num_threads
=
.LGB_MAX_THREADS
)
,
data
=
lgb.Dataset
(
X
,
label
=
y_reg
)
,
nrounds
=
nrounds
)
fit_binary
<-
lgb.train
(
params
=
list
(
objective
=
"binary"
,
verbose
=
.LGB_VERBOSITY
,
num_threads
=
.LGB_MAX_THREADS
)
,
data
=
lgb.Dataset
(
X
,
label
=
y_binary
)
,
nrounds
=
nrounds
)
# Multiclass probabilistic classification (1 iteration = num_class trees)
fit_class
<-
lgb.train
(
params
=
list
(
objective
=
"multiclass"
,
verbose
=
.LGB_VERBOSITY
,
num_threads
=
.LGB_MAX_THREADS
,
num_class
=
num_class
)
,
data
=
lgb.Dataset
(
X
,
label
=
y_class
)
,
nrounds
=
nrounds
)
expect_equal
(
fit_reg
$
num_trees_per_iter
(),
1L
)
expect_equal
(
fit_binary
$
num_trees_per_iter
(),
1L
)
expect_equal
(
fit_class
$
num_trees_per_iter
(),
num_class
)
})
test_that
(
"Booster$num_trees() and $num_iter() works (no early stopping)"
,
{
set.seed
(
708L
)
X
<-
data.matrix
(
iris
[
2L
:
4L
])
y_reg
<-
iris
[,
1L
]
y_binary
<-
as.integer
(
y_reg
>
median
(
y_reg
))
y_class
<-
as.integer
(
iris
[,
5L
])
-
1L
num_class
<-
3L
nrounds
<-
10L
# Regression and binary probabilistic classification (1 iteration = 1 tree)
fit_reg
<-
lgb.train
(
params
=
list
(
objective
=
"mse"
,
verbose
=
.LGB_VERBOSITY
,
num_threads
=
.LGB_MAX_THREADS
)
,
data
=
lgb.Dataset
(
X
,
label
=
y_reg
)
,
nrounds
=
nrounds
)
fit_binary
<-
lgb.train
(
params
=
list
(
objective
=
"binary"
,
verbose
=
.LGB_VERBOSITY
,
num_threads
=
.LGB_MAX_THREADS
)
,
data
=
lgb.Dataset
(
X
,
label
=
y_binary
)
,
nrounds
=
nrounds
)
# Multiclass probabilistic classification (1 iteration = num_class trees)
fit_class
<-
lgb.train
(
params
=
list
(
objective
=
"multiclass"
,
verbose
=
.LGB_VERBOSITY
,
num_threads
=
.LGB_MAX_THREADS
,
num_class
=
num_class
)
,
data
=
lgb.Dataset
(
X
,
label
=
y_class
)
,
nrounds
=
nrounds
)
expect_equal
(
fit_reg
$
num_trees
(),
nrounds
)
expect_equal
(
fit_binary
$
num_trees
(),
nrounds
)
expect_equal
(
fit_class
$
num_trees
(),
num_class
*
nrounds
)
expect_equal
(
fit_reg
$
num_iter
(),
nrounds
)
expect_equal
(
fit_binary
$
num_iter
(),
nrounds
)
expect_equal
(
fit_class
$
num_iter
(),
nrounds
)
})
test_that
(
"Booster$num_trees() and $num_iter() work (with early stopping)"
,
{
set.seed
(
708L
)
X
<-
data.matrix
(
iris
[
2L
:
4L
])
y_reg
<-
iris
[,
1L
]
y_binary
<-
as.integer
(
y_reg
>
median
(
y_reg
))
y_class
<-
as.integer
(
iris
[,
5L
])
-
1L
train_ix
<-
c
(
1L
:
40L
,
51L
:
90L
,
101L
:
140L
)
X_train
<-
X
[
train_ix
,
]
X_valid
<-
X
[
-
train_ix
,
]
num_class
<-
3L
nrounds
<-
1000L
early_stopping
<-
2L
# Regression and binary probabilistic classification (1 iteration = 1 tree)
fit_reg
<-
lgb.train
(
params
=
list
(
objective
=
"mse"
,
verbose
=
.LGB_VERBOSITY
,
num_threads
=
.LGB_MAX_THREADS
)
,
data
=
lgb.Dataset
(
X_train
,
label
=
y_reg
[
train_ix
])
,
valids
=
list
(
valid
=
lgb.Dataset
(
X_valid
,
label
=
y_reg
[
-
train_ix
]))
,
nrounds
=
nrounds
,
early_stopping_round
=
early_stopping
)
fit_binary
<-
lgb.train
(
params
=
list
(
objective
=
"binary"
,
verbose
=
.LGB_VERBOSITY
,
num_threads
=
.LGB_MAX_THREADS
)
,
data
=
lgb.Dataset
(
X_train
,
label
=
y_binary
[
train_ix
])
,
valids
=
list
(
valid
=
lgb.Dataset
(
X_valid
,
label
=
y_binary
[
-
train_ix
]))
,
nrounds
=
nrounds
,
early_stopping_round
=
early_stopping
)
# Multiclass probabilistic classification (1 iteration = num_class trees)
fit_class
<-
lgb.train
(
params
=
list
(
objective
=
"multiclass"
,
verbose
=
.LGB_VERBOSITY
,
num_threads
=
.LGB_MAX_THREADS
,
num_class
=
num_class
)
,
data
=
lgb.Dataset
(
X_train
,
label
=
y_class
[
train_ix
])
,
valids
=
list
(
valid
=
lgb.Dataset
(
X_valid
,
label
=
y_class
[
-
train_ix
]))
,
nrounds
=
nrounds
,
early_stopping_round
=
early_stopping
)
expected_trees_reg
<-
fit_reg
$
best_iter
+
early_stopping
expected_trees_binary
<-
fit_binary
$
best_iter
+
early_stopping
expected_trees_class
<-
(
fit_class
$
best_iter
+
early_stopping
)
*
num_class
expect_equal
(
fit_reg
$
num_trees
(),
expected_trees_reg
)
expect_equal
(
fit_binary
$
num_trees
(),
expected_trees_binary
)
expect_equal
(
fit_class
$
num_trees
(),
expected_trees_class
)
expect_equal
(
fit_reg
$
num_iter
(),
expected_trees_reg
)
expect_equal
(
fit_binary
$
num_iter
(),
expected_trees_binary
)
expect_equal
(
fit_class
$
num_iter
(),
expected_trees_class
/
num_class
)
})
test_that
(
"Booster should store parameters and Booster$reset_parameter() should update them"
,
{
test_that
(
"Booster should store parameters and Booster$reset_parameter() should update them"
,
{
data
(
agaricus.train
,
package
=
"lightgbm"
)
data
(
agaricus.train
,
package
=
"lightgbm"
)
dtrain
<-
lgb.Dataset
(
dtrain
<-
lgb.Dataset
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment