Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
1886bf51
Unverified
Commit
1886bf51
authored
Jul 10, 2024
by
Christian Bourjau
Committed by
GitHub
Jul 10, 2024
Browse files
[c++] Avoid copy on Refit (#6478)
parent
cd4459a1
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
31 additions
and
22 deletions
+31
-22
include/LightGBM/boosting.h
include/LightGBM/boosting.h
+1
-1
src/application/application.cpp
src/application/application.cpp
+18
-5
src/boosting/gbdt.cpp
src/boosting/gbdt.cpp
+10
-8
src/boosting/gbdt.h
src/boosting/gbdt.h
+1
-1
src/c_api.cpp
src/c_api.cpp
+1
-7
No files found.
include/LightGBM/boosting.h
View file @
1886bf51
...
@@ -74,7 +74,7 @@ class LIGHTGBM_EXPORT Boosting {
...
@@ -74,7 +74,7 @@ class LIGHTGBM_EXPORT Boosting {
/*!
/*!
* \brief Update the tree output by new training data
* \brief Update the tree output by new training data
*/
*/
virtual
void
RefitTree
(
const
std
::
vector
<
std
::
vector
<
int
>>&
tree_leaf_prediction
)
=
0
;
virtual
void
RefitTree
(
const
int
*
tree_leaf_prediction
,
const
size_t
nrow
,
const
size_t
ncol
)
=
0
;
/*!
/*!
* \brief Training logic
* \brief Training logic
...
...
src/application/application.cpp
View file @
1886bf51
...
@@ -226,12 +226,24 @@ void Application::Predict() {
...
@@ -226,12 +226,24 @@ void Application::Predict() {
config_
.
precise_float_parser
);
config_
.
precise_float_parser
);
TextReader
<
int
>
result_reader
(
config_
.
output_result
.
c_str
(),
false
);
TextReader
<
int
>
result_reader
(
config_
.
output_result
.
c_str
(),
false
);
result_reader
.
ReadAllLines
();
result_reader
.
ReadAllLines
();
std
::
vector
<
std
::
vector
<
int
>>
pred_leaf
(
result_reader
.
Lines
().
size
());
size_t
nrow
=
result_reader
.
Lines
().
size
();
size_t
ncol
=
0
;
if
(
nrow
>
0
)
{
ncol
=
Common
::
StringToArray
<
int
>
(
result_reader
.
Lines
()[
0
],
'\t'
).
size
();
}
std
::
vector
<
int
>
pred_leaf
;
pred_leaf
.
resize
(
nrow
*
ncol
);
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
result_reader
.
Lines
().
size
());
++
i
)
{
for
(
int
irow
=
0
;
irow
<
static_cast
<
int
>
(
nrow
);
++
irow
)
{
pred_leaf
[
i
]
=
Common
::
StringToArray
<
int
>
(
result_reader
.
Lines
()[
i
],
'\t'
);
auto
line_vec
=
Common
::
StringToArray
<
int
>
(
result_reader
.
Lines
()[
irow
],
'\t'
);
CHECK_EQ
(
line_vec
.
size
(),
ncol
);
for
(
int
i_row_item
=
0
;
i_row_item
<
static_cast
<
int
>
(
ncol
);
++
i_row_item
)
{
pred_leaf
[
irow
*
ncol
+
i_row_item
]
=
line_vec
[
i_row_item
];
}
// Free memory
// Free memory
result_reader
.
Lines
()[
i
].
clear
();
result_reader
.
Lines
()[
i
row
].
clear
();
}
}
DatasetLoader
dataset_loader
(
config_
,
nullptr
,
DatasetLoader
dataset_loader
(
config_
,
nullptr
,
config_
.
num_class
,
config_
.
data
.
c_str
());
config_
.
num_class
,
config_
.
data
.
c_str
());
...
@@ -242,7 +254,8 @@ void Application::Predict() {
...
@@ -242,7 +254,8 @@ void Application::Predict() {
objective_fun_
->
Init
(
train_data_
->
metadata
(),
train_data_
->
num_data
());
objective_fun_
->
Init
(
train_data_
->
metadata
(),
train_data_
->
num_data
());
boosting_
->
Init
(
&
config_
,
train_data_
.
get
(),
objective_fun_
.
get
(),
boosting_
->
Init
(
&
config_
,
train_data_
.
get
(),
objective_fun_
.
get
(),
Common
::
ConstPtrInVectorWrapper
<
Metric
>
(
train_metric_
));
Common
::
ConstPtrInVectorWrapper
<
Metric
>
(
train_metric_
));
boosting_
->
RefitTree
(
pred_leaf
);
boosting_
->
RefitTree
(
pred_leaf
.
data
(),
nrow
,
ncol
);
boosting_
->
SaveModelToFile
(
0
,
-
1
,
config_
.
saved_feature_importance_type
,
boosting_
->
SaveModelToFile
(
0
,
-
1
,
config_
.
saved_feature_importance_type
,
config_
.
output_model
.
c_str
());
config_
.
output_model
.
c_str
());
Log
::
Info
(
"Finished RefitTree"
);
Log
::
Info
(
"Finished RefitTree"
);
...
...
src/boosting/gbdt.cpp
View file @
1886bf51
...
@@ -249,32 +249,34 @@ void GBDT::Train(int snapshot_freq, const std::string& model_output_path) {
...
@@ -249,32 +249,34 @@ void GBDT::Train(int snapshot_freq, const std::string& model_output_path) {
}
}
}
}
void
GBDT
::
RefitTree
(
const
std
::
vector
<
std
::
vector
<
int
>>&
tree_leaf_prediction
)
{
void
GBDT
::
RefitTree
(
const
int
*
tree_leaf_prediction
,
const
size_t
nrow
,
const
size_t
ncol
)
{
CHECK_GT
(
tree_leaf_prediction
.
size
(),
0
);
CHECK_GT
(
nrow
*
ncol
,
0
);
CHECK_EQ
(
static_cast
<
size_t
>
(
num_data_
),
tree_leaf_prediction
.
size
());
CHECK_EQ
(
static_cast
<
size_t
>
(
num_data_
),
nrow
);
CHECK_EQ
(
static_cast
<
size_t
>
(
models_
.
size
()),
tree_leaf_prediction
[
0
].
size
());
CHECK_EQ
(
models_
.
size
(),
ncol
);
int
num_iterations
=
static_cast
<
int
>
(
models_
.
size
()
/
num_tree_per_iteration_
);
int
num_iterations
=
static_cast
<
int
>
(
models_
.
size
()
/
num_tree_per_iteration_
);
std
::
vector
<
int
>
leaf_pred
(
num_data_
);
std
::
vector
<
int
>
leaf_pred
(
num_data_
);
if
(
linear_tree_
)
{
if
(
linear_tree_
)
{
std
::
vector
<
int
>
max_leaves_by_thread
=
std
::
vector
<
int
>
(
OMP_NUM_THREADS
(),
0
);
std
::
vector
<
int
>
max_leaves_by_thread
=
std
::
vector
<
int
>
(
OMP_NUM_THREADS
(),
0
);
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
tree_leaf_prediction
.
size
()
);
++
i
)
{
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
nrow
);
++
i
)
{
int
tid
=
omp_get_thread_num
();
int
tid
=
omp_get_thread_num
();
for
(
size_t
j
=
0
;
j
<
tree_leaf_prediction
[
i
].
size
()
;
++
j
)
{
for
(
size_t
j
=
0
;
j
<
ncol
;
++
j
)
{
max_leaves_by_thread
[
tid
]
=
std
::
max
(
max_leaves_by_thread
[
tid
],
tree_leaf_prediction
[
i
][
j
]);
max_leaves_by_thread
[
tid
]
=
std
::
max
(
max_leaves_by_thread
[
tid
],
tree_leaf_prediction
[
i
*
ncol
+
j
]);
}
}
}
}
int
max_leaves
=
*
std
::
max_element
(
max_leaves_by_thread
.
begin
(),
max_leaves_by_thread
.
end
());
int
max_leaves
=
*
std
::
max_element
(
max_leaves_by_thread
.
begin
(),
max_leaves_by_thread
.
end
());
max_leaves
+=
1
;
max_leaves
+=
1
;
tree_learner_
->
InitLinear
(
train_data_
,
max_leaves
);
tree_learner_
->
InitLinear
(
train_data_
,
max_leaves
);
}
}
for
(
int
iter
=
0
;
iter
<
num_iterations
;
++
iter
)
{
for
(
int
iter
=
0
;
iter
<
num_iterations
;
++
iter
)
{
Boosting
();
Boosting
();
for
(
int
tree_id
=
0
;
tree_id
<
num_tree_per_iteration_
;
++
tree_id
)
{
for
(
int
tree_id
=
0
;
tree_id
<
num_tree_per_iteration_
;
++
tree_id
)
{
int
model_index
=
iter
*
num_tree_per_iteration_
+
tree_id
;
int
model_index
=
iter
*
num_tree_per_iteration_
+
tree_id
;
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
for
(
int
i
=
0
;
i
<
num_data_
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num_data_
;
++
i
)
{
leaf_pred
[
i
]
=
tree_leaf_prediction
[
i
][
model_index
];
leaf_pred
[
i
]
=
tree_leaf_prediction
[
i
*
ncol
+
model_index
];
CHECK_LT
(
leaf_pred
[
i
],
models_
[
model_index
]
->
num_leaves
());
CHECK_LT
(
leaf_pred
[
i
],
models_
[
model_index
]
->
num_leaves
());
}
}
size_t
offset
=
static_cast
<
size_t
>
(
tree_id
)
*
num_data_
;
size_t
offset
=
static_cast
<
size_t
>
(
tree_id
)
*
num_data_
;
...
...
src/boosting/gbdt.h
View file @
1886bf51
...
@@ -143,7 +143,7 @@ class GBDT : public GBDTBase {
...
@@ -143,7 +143,7 @@ class GBDT : public GBDTBase {
*/
*/
void
Train
(
int
snapshot_freq
,
const
std
::
string
&
model_output_path
)
override
;
void
Train
(
int
snapshot_freq
,
const
std
::
string
&
model_output_path
)
override
;
void
RefitTree
(
const
std
::
vector
<
std
::
vector
<
int
>>&
tree_leaf_prediction
)
override
;
void
RefitTree
(
const
int
*
tree_leaf_prediction
,
const
size_t
nrow
,
const
size_t
ncol
)
override
;
/*!
/*!
* \brief Training logic
* \brief Training logic
...
...
src/c_api.cpp
View file @
1886bf51
...
@@ -409,13 +409,7 @@ class Booster {
...
@@ -409,13 +409,7 @@ class Booster {
void
Refit
(
const
int32_t
*
leaf_preds
,
int32_t
nrow
,
int32_t
ncol
)
{
void
Refit
(
const
int32_t
*
leaf_preds
,
int32_t
nrow
,
int32_t
ncol
)
{
UNIQUE_LOCK
(
mutex_
)
UNIQUE_LOCK
(
mutex_
)
std
::
vector
<
std
::
vector
<
int32_t
>>
v_leaf_preds
(
nrow
,
std
::
vector
<
int32_t
>
(
ncol
,
0
));
boosting_
->
RefitTree
(
leaf_preds
,
nrow
,
ncol
);
for
(
int
i
=
0
;
i
<
nrow
;
++
i
)
{
for
(
int
j
=
0
;
j
<
ncol
;
++
j
)
{
v_leaf_preds
[
i
][
j
]
=
leaf_preds
[
static_cast
<
size_t
>
(
i
)
*
static_cast
<
size_t
>
(
ncol
)
+
static_cast
<
size_t
>
(
j
)];
}
}
boosting_
->
RefitTree
(
v_leaf_preds
);
}
}
bool
TrainOneIter
(
const
score_t
*
gradients
,
const
score_t
*
hessians
)
{
bool
TrainOneIter
(
const
score_t
*
gradients
,
const
score_t
*
hessians
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment