Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
5dfe7168
Unverified
Commit
5dfe7168
authored
Apr 10, 2024
by
Michael Mayer
Committed by
GitHub
Apr 10, 2024
Browse files
[R-package] Speed-up lgb.importance() (#6364)
parent
628e91a9
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
185 additions
and
29 deletions
+185
-29
R-package/R/lgb.model.dt.tree.R
R-package/R/lgb.model.dt.tree.R
+27
-29
R-package/tests/testthat/test_lgb.model.dt.tree.R
R-package/tests/testthat/test_lgb.model.dt.tree.R
+158
-0
No files found.
R-package/R/lgb.model.dt.tree.R
View file @
5dfe7168
...
...
@@ -90,6 +90,16 @@ lgb.model.dt.tree <- function(model, num_iteration = NULL) {
#' @importFrom data.table := data.table rbindlist
.single_tree_parse
<-
function
(
lgb_tree
)
{
tree_info_cols
<-
c
(
"split_index"
,
"split_feature"
,
"split_gain"
,
"threshold"
,
"decision_type"
,
"default_left"
,
"internal_value"
,
"internal_count"
)
# Traverse tree function
pre_order_traversal
<-
function
(
env
=
NULL
,
tree_node_leaf
,
current_depth
=
0L
,
parent_index
=
NA_integer_
)
{
...
...
@@ -97,7 +107,8 @@ lgb.model.dt.tree <- function(model, num_iteration = NULL) {
if
(
is.null
(
env
))
{
# Setup initial default data.table with default types
env
<-
new.env
(
parent
=
emptyenv
())
env
$
single_tree_dt
<-
data.table
::
data.table
(
env
$
single_tree_dt
<-
list
()
env
$
single_tree_dt
[[
1L
]]
<-
data.table
::
data.table
(
tree_index
=
integer
(
0L
)
,
depth
=
integer
(
0L
)
,
split_index
=
integer
(
0L
)
...
...
@@ -127,19 +138,10 @@ lgb.model.dt.tree <- function(model, num_iteration = NULL) {
if
(
!
is.null
(
tree_node_leaf
$
split_index
))
{
# update data.table
env
$
single_tree_dt
<-
data.table
::
rbindlist
(
l
=
list
(
env
$
single_tree_dt
,
c
(
tree_node_leaf
[
c
(
"split_index"
,
"split_feature"
,
"split_gain"
,
"threshold"
,
"decision_type"
,
"default_left"
,
"internal_value"
,
"internal_count"
)],
"depth"
=
current_depth
,
"node_parent"
=
parent_index
)),
use.names
=
TRUE
,
fill
=
TRUE
)
env
$
single_tree_dt
[[
length
(
env
$
single_tree_dt
)
+
1L
]]
<-
c
(
tree_node_leaf
[
tree_info_cols
]
,
list
(
"depth"
=
current_depth
,
"node_parent"
=
parent_index
)
)
# Traverse tree again both left and right
pre_order_traversal
(
...
...
@@ -154,31 +156,27 @@ lgb.model.dt.tree <- function(model, num_iteration = NULL) {
,
current_depth
=
current_depth
+
1L
,
parent_index
=
tree_node_leaf
$
split_index
)
}
else
if
(
!
is.null
(
tree_node_leaf
$
leaf_index
))
{
# update data.table
env
$
single_tree_dt
<-
data.table
::
rbindlist
(
l
=
list
(
env
$
single_tree_dt
,
c
(
tree_node_leaf
[
c
(
"leaf_index"
,
"leaf_value"
,
"leaf_count"
)],
"depth"
=
current_depth
,
"leaf_parent"
=
parent_index
)),
use.names
=
TRUE
,
fill
=
TRUE
)
# update list
env
$
single_tree_dt
[[
length
(
env
$
single_tree_dt
)
+
1L
]]
<-
c
(
tree_node_leaf
[
c
(
"leaf_index"
,
"leaf_value"
,
"leaf_count"
)]
,
list
(
"depth"
=
current_depth
,
"leaf_parent"
=
parent_index
)
)
}
}
return
(
env
$
single_tree_dt
)
}
# Traverse structure
single_tree_dt
<-
pre_order_traversal
(
tree_node_leaf
=
lgb_tree
$
tree_structure
)
# Traverse structure and rowbind everything
single_tree_dt
<-
data.table
::
rbindlist
(
pre_order_traversal
(
tree_node_leaf
=
lgb_tree
$
tree_structure
)
,
use.names
=
TRUE
,
fill
=
TRUE
)
# Store index
single_tree_dt
[,
tree_index
:=
lgb_tree
$
tree_index
]
return
(
single_tree_dt
)
}
R-package/tests/testthat/test_lgb.model.dt.tree.R
0 → 100644
View file @
5dfe7168
NROUNDS
<-
10L
MAX_DEPTH
<-
3L
N
<-
nrow
(
iris
)
X
<-
data.matrix
(
iris
[
2L
:
4L
])
FEAT
<-
colnames
(
X
)
NCLASS
<-
nlevels
(
iris
[,
5L
])
model_reg
<-
lgb.train
(
params
=
list
(
objective
=
"regression"
,
num_threads
=
.LGB_MAX_THREADS
,
max.depth
=
MAX_DEPTH
)
,
data
=
lgb.Dataset
(
X
,
label
=
iris
[,
1L
])
,
verbose
=
.LGB_VERBOSITY
,
nrounds
=
NROUNDS
)
model_binary
<-
lgb.train
(
params
=
list
(
objective
=
"binary"
,
num_threads
=
.LGB_MAX_THREADS
,
max.depth
=
MAX_DEPTH
)
,
data
=
lgb.Dataset
(
X
,
label
=
iris
[,
5L
]
==
"setosa"
)
,
verbose
=
.LGB_VERBOSITY
,
nrounds
=
NROUNDS
)
model_multiclass
<-
lgb.train
(
params
=
list
(
objective
=
"multiclass"
,
num_threads
=
.LGB_MAX_THREADS
,
max.depth
=
MAX_DEPTH
,
num_classes
=
NCLASS
)
,
data
=
lgb.Dataset
(
X
,
label
=
as.integer
(
iris
[,
5L
])
-
1L
)
,
verbose
=
.LGB_VERBOSITY
,
nrounds
=
NROUNDS
)
model_rank
<-
lgb.train
(
params
=
list
(
objective
=
"lambdarank"
,
num_threads
=
.LGB_MAX_THREADS
,
max.depth
=
MAX_DEPTH
,
lambdarank_truncation_level
=
3L
)
,
data
=
lgb.Dataset
(
X
,
label
=
as.integer
(
iris
[,
1L
]
>
5.8
)
,
group
=
rep
(
10L
,
times
=
15L
)
)
,
verbose
=
.LGB_VERBOSITY
,
nrounds
=
NROUNDS
)
models
<-
list
(
reg
=
model_reg
,
bin
=
model_binary
,
multi
=
model_multiclass
,
rank
=
model_rank
)
for
(
model_name
in
names
(
models
))
{
model
<-
models
[[
model_name
]]
expected_n_trees
<-
NROUNDS
if
(
model_name
==
"multi"
)
{
expected_n_trees
<-
NROUNDS
*
NCLASS
}
df
<-
as.data.frame
(
lgb.model.dt.tree
(
model
))
df_list
<-
split
(
df
,
f
=
df
$
tree_index
,
drop
=
TRUE
)
df_leaf
<-
df
[
!
is.na
(
df
$
leaf_index
),
]
df_internal
<-
df
[
is.na
(
df
$
leaf_index
),
]
test_that
(
"lgb.model.dt.tree() returns the right number of trees"
,
{
expect_equal
(
length
(
unique
(
df
$
tree_index
)),
expected_n_trees
)
})
test_that
(
"num_iteration can return less trees"
,
{
expect_equal
(
length
(
unique
(
lgb.model.dt.tree
(
model
,
num_iteration
=
2L
)
$
tree_index
))
,
2L
*
(
if
(
model_name
==
"multi"
)
NCLASS
else
1L
)
)
})
test_that
(
"Tree index from lgb.model.dt.tree() is in 0:(NROUNS-1)"
,
{
expect_equal
(
unique
(
df
$
tree_index
),
(
0L
:
(
expected_n_trees
-
1L
)))
})
test_that
(
"Depth calculated from lgb.model.dt.tree() respects max.depth"
,
{
expect_true
(
max
(
df
$
depth
)
<=
MAX_DEPTH
)
})
test_that
(
"Each tree from lgb.model.dt.tree() has single root node"
,
{
expect_equal
(
unname
(
sapply
(
df_list
,
function
(
df
)
sum
(
df
$
depth
==
0L
)))
,
rep
(
1L
,
expected_n_trees
)
)
})
test_that
(
"Each tree from lgb.model.dt.tree() has two depth 1 nodes"
,
{
expect_equal
(
unname
(
sapply
(
df_list
,
function
(
df
)
sum
(
df
$
depth
==
1L
)))
,
rep
(
2L
,
expected_n_trees
)
)
})
test_that
(
"leaves from lgb.model.dt.tree() do not have split info"
,
{
internal_node_cols
<-
c
(
"split_index"
,
"split_feature"
,
"split_gain"
,
"threshold"
,
"decision_type"
,
"default_left"
,
"internal_value"
,
"internal_count"
)
expect_true
(
all
(
is.na
(
df_leaf
[
internal_node_cols
])))
})
test_that
(
"leaves from lgb.model.dt.tree() have valid leaf info"
,
{
expect_true
(
all
(
df_leaf
$
leaf_index
%in%
0L
:
(
2.0
^
MAX_DEPTH
-
1.0
)))
expect_true
(
all
(
is.finite
(
df_leaf
$
leaf_value
)))
expect_true
(
all
(
df_leaf
$
leaf_count
>
0L
&
df_leaf
$
leaf_count
<=
N
))
})
test_that
(
"non-leaves from lgb.model.dt.tree() do not have leaf info"
,
{
leaf_node_cols
<-
c
(
"leaf_index"
,
"leaf_parent"
,
"leaf_value"
,
"leaf_count"
)
expect_true
(
all
(
is.na
(
df_internal
[
leaf_node_cols
])))
})
test_that
(
"non-leaves from lgb.model.dt.tree() have valid split info"
,
{
expect_true
(
all
(
sapply
(
split
(
df_internal
,
df_internal
$
tree_index
),
function
(
x
)
all
(
x
$
split_index
%in%
0L
:
(
nrow
(
x
)
-
1L
))
)
)
)
expect_true
(
all
(
df_internal
$
split_feature
%in%
FEAT
))
num_cols
<-
c
(
"split_gain"
,
"threshold"
,
"internal_value"
)
expect_true
(
all
(
is.finite
(
unlist
(
df_internal
[,
num_cols
]))))
# range of decision type?
expect_true
(
all
(
df_internal
$
default_left
%in%
c
(
TRUE
,
FALSE
)))
counts
<-
df_internal
$
internal_count
expect_true
(
all
(
counts
>
1L
&
counts
<=
N
))
})
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment