Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
45e0da2c
Commit
45e0da2c
authored
Dec 28, 2016
by
Guolin Ke
Browse files
refine predictor logic in c_api
parent
728e50a9
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
21 deletions
+10
-21
src/c_api.cpp
src/c_api.cpp
+10
-21
No files found.
src/c_api.cpp
View file @
45e0da2c
...
@@ -154,7 +154,7 @@ public:
...
@@ -154,7 +154,7 @@ public:
boosting_
->
RollbackOneIter
();
boosting_
->
RollbackOneIter
();
}
}
void
PrepareFor
Predict
ion
(
int
num_iteration
,
int
predict_type
)
{
Predictor
New
Predict
or
(
int
num_iteration
,
int
predict_type
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
boosting_
->
SetNumIterationForPred
(
num_iteration
);
boosting_
->
SetNumIterationForPred
(
num_iteration
);
bool
is_predict_leaf
=
false
;
bool
is_predict_leaf
=
false
;
...
@@ -166,22 +166,15 @@ public:
...
@@ -166,22 +166,15 @@ public:
}
else
{
}
else
{
is_raw_score
=
false
;
is_raw_score
=
false
;
}
}
predictor_
.
reset
(
new
Predictor
(
boosting_
.
get
(),
is_raw_score
,
is_predict_leaf
));
// not threading safe now
// boosting_->SetNumIterationForPred may be set by other thread during prediction.
return
Predictor
(
boosting_
.
get
(),
is_raw_score
,
is_predict_leaf
);
}
}
void
GetPredictAt
(
int
data_idx
,
score_t
*
out_result
,
int64_t
*
out_len
)
{
void
GetPredictAt
(
int
data_idx
,
score_t
*
out_result
,
int64_t
*
out_len
)
{
boosting_
->
GetPredictAt
(
data_idx
,
out_result
,
out_len
);
boosting_
->
GetPredictAt
(
data_idx
,
out_result
,
out_len
);
}
}
std
::
vector
<
double
>
Predict
(
const
std
::
vector
<
std
::
pair
<
int
,
double
>>&
features
)
{
return
predictor_
->
GetPredictFunction
()(
features
);
}
void
PredictForFile
(
const
char
*
data_filename
,
const
char
*
result_filename
,
bool
data_has_header
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
predictor_
->
Predict
(
data_filename
,
result_filename
,
data_has_header
);
}
void
SaveModelToFile
(
int
num_iteration
,
const
char
*
filename
)
{
void
SaveModelToFile
(
int
num_iteration
,
const
char
*
filename
)
{
boosting_
->
SaveModelToFile
(
num_iteration
,
filename
);
boosting_
->
SaveModelToFile
(
num_iteration
,
filename
);
}
}
...
@@ -232,8 +225,6 @@ private:
...
@@ -232,8 +225,6 @@ private:
std
::
vector
<
std
::
vector
<
std
::
unique_ptr
<
Metric
>>>
valid_metrics_
;
std
::
vector
<
std
::
vector
<
std
::
unique_ptr
<
Metric
>>>
valid_metrics_
;
/*! \brief Training objective function */
/*! \brief Training objective function */
std
::
unique_ptr
<
ObjectiveFunction
>
objective_fun_
;
std
::
unique_ptr
<
ObjectiveFunction
>
objective_fun_
;
/*! \brief Using predictor for prediction task */
std
::
unique_ptr
<
Predictor
>
predictor_
;
/*! \brief mutex for threading safe call */
/*! \brief mutex for threading safe call */
std
::
mutex
mutex_
;
std
::
mutex
mutex_
;
};
};
...
@@ -692,9 +683,9 @@ DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle,
...
@@ -692,9 +683,9 @@ DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle,
const
char
*
result_filename
)
{
const
char
*
result_filename
)
{
API_BEGIN
();
API_BEGIN
();
Booster
*
ref_booster
=
reinterpret_cast
<
Booster
*>
(
handle
);
Booster
*
ref_booster
=
reinterpret_cast
<
Booster
*>
(
handle
);
ref_booster
->
PrepareFor
Predict
ion
(
static_cast
<
int
>
(
num_iteration
),
predict_type
);
auto
predictor
=
ref_booster
->
New
Predict
or
(
static_cast
<
int
>
(
num_iteration
),
predict_type
);
bool
bool_data_has_header
=
data_has_header
>
0
?
true
:
false
;
bool
bool_data_has_header
=
data_has_header
>
0
?
true
:
false
;
re
f_booster
->
PredictForFile
(
data_filename
,
result_filename
,
bool_data_has_header
);
p
re
dictor
.
Predict
(
data_filename
,
result_filename
,
bool_data_has_header
);
API_END
();
API_END
();
}
}
...
@@ -713,8 +704,7 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
...
@@ -713,8 +704,7 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
float
*
out_result
)
{
float
*
out_result
)
{
API_BEGIN
();
API_BEGIN
();
Booster
*
ref_booster
=
reinterpret_cast
<
Booster
*>
(
handle
);
Booster
*
ref_booster
=
reinterpret_cast
<
Booster
*>
(
handle
);
ref_booster
->
PrepareForPrediction
(
static_cast
<
int
>
(
num_iteration
),
predict_type
);
auto
predictor
=
ref_booster
->
NewPredictor
(
static_cast
<
int
>
(
num_iteration
),
predict_type
);
auto
get_row_fun
=
RowFunctionFromCSR
(
indptr
,
indptr_type
,
indices
,
data
,
data_type
,
nindptr
,
nelem
);
auto
get_row_fun
=
RowFunctionFromCSR
(
indptr
,
indptr_type
,
indices
,
data
,
data_type
,
nindptr
,
nelem
);
int
num_preb_in_one_row
=
ref_booster
->
GetBoosting
()
->
NumberOfClasses
();
int
num_preb_in_one_row
=
ref_booster
->
GetBoosting
()
->
NumberOfClasses
();
if
(
predict_type
==
C_API_PREDICT_LEAF_INDEX
)
{
if
(
predict_type
==
C_API_PREDICT_LEAF_INDEX
)
{
...
@@ -728,7 +718,7 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
...
@@ -728,7 +718,7 @@ DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
#pragma omp parallel for schedule(guided)
#pragma omp parallel for schedule(guided)
for
(
int
i
=
0
;
i
<
nrow
;
++
i
)
{
for
(
int
i
=
0
;
i
<
nrow
;
++
i
)
{
auto
one_row
=
get_row_fun
(
i
);
auto
one_row
=
get_row_fun
(
i
);
auto
predicton_result
=
re
f_booster
->
Predict
(
one_row
);
auto
predicton_result
=
p
re
dictor
.
GetPredictFunction
()
(
one_row
);
for
(
int
j
=
0
;
j
<
static_cast
<
int
>
(
predicton_result
.
size
());
++
j
)
{
for
(
int
j
=
0
;
j
<
static_cast
<
int
>
(
predicton_result
.
size
());
++
j
)
{
out_result
[
i
*
num_preb_in_one_row
+
j
]
=
static_cast
<
float
>
(
predicton_result
[
j
]);
out_result
[
i
*
num_preb_in_one_row
+
j
]
=
static_cast
<
float
>
(
predicton_result
[
j
]);
}
}
...
@@ -749,8 +739,7 @@ DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
...
@@ -749,8 +739,7 @@ DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
float
*
out_result
)
{
float
*
out_result
)
{
API_BEGIN
();
API_BEGIN
();
Booster
*
ref_booster
=
reinterpret_cast
<
Booster
*>
(
handle
);
Booster
*
ref_booster
=
reinterpret_cast
<
Booster
*>
(
handle
);
ref_booster
->
PrepareForPrediction
(
static_cast
<
int
>
(
num_iteration
),
predict_type
);
auto
predictor
=
ref_booster
->
NewPredictor
(
static_cast
<
int
>
(
num_iteration
),
predict_type
);
auto
get_row_fun
=
RowPairFunctionFromDenseMatric
(
data
,
nrow
,
ncol
,
data_type
,
is_row_major
);
auto
get_row_fun
=
RowPairFunctionFromDenseMatric
(
data
,
nrow
,
ncol
,
data_type
,
is_row_major
);
int
num_preb_in_one_row
=
ref_booster
->
GetBoosting
()
->
NumberOfClasses
();
int
num_preb_in_one_row
=
ref_booster
->
GetBoosting
()
->
NumberOfClasses
();
if
(
predict_type
==
C_API_PREDICT_LEAF_INDEX
)
{
if
(
predict_type
==
C_API_PREDICT_LEAF_INDEX
)
{
...
@@ -763,7 +752,7 @@ DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
...
@@ -763,7 +752,7 @@ DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
#pragma omp parallel for schedule(guided)
#pragma omp parallel for schedule(guided)
for
(
int
i
=
0
;
i
<
nrow
;
++
i
)
{
for
(
int
i
=
0
;
i
<
nrow
;
++
i
)
{
auto
one_row
=
get_row_fun
(
i
);
auto
one_row
=
get_row_fun
(
i
);
auto
predicton_result
=
re
f_booster
->
Predict
(
one_row
);
auto
predicton_result
=
p
re
dictor
.
GetPredictFunction
()
(
one_row
);
for
(
int
j
=
0
;
j
<
static_cast
<
int
>
(
predicton_result
.
size
());
++
j
)
{
for
(
int
j
=
0
;
j
<
static_cast
<
int
>
(
predicton_result
.
size
());
++
j
)
{
out_result
[
i
*
num_preb_in_one_row
+
j
]
=
static_cast
<
float
>
(
predicton_result
[
j
]);
out_result
[
i
*
num_preb_in_one_row
+
j
]
=
static_cast
<
float
>
(
predicton_result
[
j
]);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment