Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
ac73638f
Commit
ac73638f
authored
Jul 06, 2017
by
Guolin Ke
Browse files
fix bug for csc prediction.
parent
5aa3ef4d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
44 additions
and
16 deletions
+44
-16
src/c_api.cpp
src/c_api.cpp
+44
-16
No files found.
src/c_api.cpp
View file @
ac73638f
...
@@ -171,7 +171,7 @@ public:
...
@@ -171,7 +171,7 @@ public:
void
Predict
(
int
num_iteration
,
int
predict_type
,
int
nrow
,
void
Predict
(
int
num_iteration
,
int
predict_type
,
int
nrow
,
std
::
function
<
std
::
vector
<
std
::
pair
<
int
,
double
>>
(
int
row_idx
)
>
get_row_fun
,
std
::
function
<
std
::
vector
<
std
::
pair
<
int
,
double
>>
(
int
row_idx
)
>
get_row_fun
,
const
char
*
parameter
,
const
IOConfig
&
config
,
double
*
out_result
,
int64_t
*
out_len
)
{
double
*
out_result
,
int64_t
*
out_len
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
bool
is_predict_leaf
=
false
;
bool
is_predict_leaf
=
false
;
...
@@ -183,9 +183,7 @@ public:
...
@@ -183,9 +183,7 @@ public:
}
else
{
}
else
{
is_raw_score
=
false
;
is_raw_score
=
false
;
}
}
auto
param
=
ConfigBase
::
Str2Map
(
parameter
);
IOConfig
config
;
config
.
Set
(
param
);
Predictor
predictor
(
boosting_
.
get
(),
num_iteration
,
is_raw_score
,
is_predict_leaf
,
Predictor
predictor
(
boosting_
.
get
(),
num_iteration
,
is_raw_score
,
is_predict_leaf
,
config
.
pred_early_stop
,
config
.
pred_early_stop_freq
,
config
.
pred_early_stop_margin
);
config
.
pred_early_stop
,
config
.
pred_early_stop_freq
,
config
.
pred_early_stop_margin
);
int64_t
num_preb_in_one_row
=
boosting_
->
NumPredictOneRow
(
num_iteration
,
is_predict_leaf
);
int64_t
num_preb_in_one_row
=
boosting_
->
NumPredictOneRow
(
num_iteration
,
is_predict_leaf
);
...
@@ -204,7 +202,7 @@ public:
...
@@ -204,7 +202,7 @@ public:
}
}
void
Predict
(
int
num_iteration
,
int
predict_type
,
const
char
*
data_filename
,
void
Predict
(
int
num_iteration
,
int
predict_type
,
const
char
*
data_filename
,
int
data_has_header
,
const
char
*
parameter
,
int
data_has_header
,
const
IOConfig
&
config
,
const
char
*
result_filename
)
{
const
char
*
result_filename
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
bool
is_predict_leaf
=
false
;
bool
is_predict_leaf
=
false
;
...
@@ -216,9 +214,6 @@ public:
...
@@ -216,9 +214,6 @@ public:
}
else
{
}
else
{
is_raw_score
=
false
;
is_raw_score
=
false
;
}
}
auto
param
=
ConfigBase
::
Str2Map
(
parameter
);
IOConfig
config
;
config
.
Set
(
param
);
Predictor
predictor
(
boosting_
.
get
(),
num_iteration
,
is_raw_score
,
is_predict_leaf
,
Predictor
predictor
(
boosting_
.
get
(),
num_iteration
,
is_raw_score
,
is_predict_leaf
,
config
.
pred_early_stop
,
config
.
pred_early_stop_freq
,
config
.
pred_early_stop_margin
);
config
.
pred_early_stop
,
config
.
pred_early_stop_freq
,
config
.
pred_early_stop_margin
);
bool
bool_data_has_header
=
data_has_header
>
0
?
true
:
false
;
bool
bool_data_has_header
=
data_has_header
>
0
?
true
:
false
;
...
@@ -981,9 +976,15 @@ int LGBM_BoosterPredictForFile(BoosterHandle handle,
...
@@ -981,9 +976,15 @@ int LGBM_BoosterPredictForFile(BoosterHandle handle,
const
char
*
parameter
,
const
char
*
parameter
,
const
char
*
result_filename
)
{
const
char
*
result_filename
)
{
API_BEGIN
();
API_BEGIN
();
auto
param
=
ConfigBase
::
Str2Map
(
parameter
);
OverallConfig
config
;
config
.
Set
(
param
);
if
(
config
.
num_threads
>
0
)
{
omp_set_num_threads
(
config
.
num_threads
);
}
Booster
*
ref_booster
=
reinterpret_cast
<
Booster
*>
(
handle
);
Booster
*
ref_booster
=
reinterpret_cast
<
Booster
*>
(
handle
);
ref_booster
->
Predict
(
num_iteration
,
predict_type
,
data_filename
,
data_has_header
,
ref_booster
->
Predict
(
num_iteration
,
predict_type
,
data_filename
,
data_has_header
,
parameter
,
result_filename
);
config
.
io_config
,
result_filename
);
API_END
();
API_END
();
}
}
...
@@ -1014,11 +1015,17 @@ int LGBM_BoosterPredictForCSR(BoosterHandle handle,
...
@@ -1014,11 +1015,17 @@ int LGBM_BoosterPredictForCSR(BoosterHandle handle,
int64_t
*
out_len
,
int64_t
*
out_len
,
double
*
out_result
)
{
double
*
out_result
)
{
API_BEGIN
();
API_BEGIN
();
auto
param
=
ConfigBase
::
Str2Map
(
parameter
);
OverallConfig
config
;
config
.
Set
(
param
);
if
(
config
.
num_threads
>
0
)
{
omp_set_num_threads
(
config
.
num_threads
);
}
Booster
*
ref_booster
=
reinterpret_cast
<
Booster
*>
(
handle
);
Booster
*
ref_booster
=
reinterpret_cast
<
Booster
*>
(
handle
);
auto
get_row_fun
=
RowFunctionFromCSR
(
indptr
,
indptr_type
,
indices
,
data
,
data_type
,
nindptr
,
nelem
);
auto
get_row_fun
=
RowFunctionFromCSR
(
indptr
,
indptr_type
,
indices
,
data
,
data_type
,
nindptr
,
nelem
);
int
nrow
=
static_cast
<
int
>
(
nindptr
-
1
);
int
nrow
=
static_cast
<
int
>
(
nindptr
-
1
);
ref_booster
->
Predict
(
num_iteration
,
predict_type
,
nrow
,
get_row_fun
,
ref_booster
->
Predict
(
num_iteration
,
predict_type
,
nrow
,
get_row_fun
,
parameter
,
out_result
,
out_len
);
config
.
io_config
,
out_result
,
out_len
);
API_END
();
API_END
();
}
}
...
@@ -1038,23 +1045,38 @@ int LGBM_BoosterPredictForCSC(BoosterHandle handle,
...
@@ -1038,23 +1045,38 @@ int LGBM_BoosterPredictForCSC(BoosterHandle handle,
double
*
out_result
)
{
double
*
out_result
)
{
API_BEGIN
();
API_BEGIN
();
Booster
*
ref_booster
=
reinterpret_cast
<
Booster
*>
(
handle
);
Booster
*
ref_booster
=
reinterpret_cast
<
Booster
*>
(
handle
);
auto
param
=
ConfigBase
::
Str2Map
(
parameter
);
OverallConfig
config
;
config
.
Set
(
param
);
if
(
config
.
num_threads
>
0
)
{
omp_set_num_threads
(
config
.
num_threads
);
}
int
num_threads
=
1
;
#pragma omp parallel
#pragma omp master
{
num_threads
=
omp_get_num_threads
();
}
int
ncol
=
static_cast
<
int
>
(
ncol_ptr
-
1
);
int
ncol
=
static_cast
<
int
>
(
ncol_ptr
-
1
);
std
::
vector
<
CSC_RowIterator
>
iterators
;
std
::
vector
<
std
::
vector
<
CSC_RowIterator
>>
iterators
(
num_threads
,
std
::
vector
<
CSC_RowIterator
>
());
for
(
int
j
=
0
;
j
<
ncol
;
++
j
)
{
for
(
int
i
=
0
;
i
<
num_threads
;
++
i
)
{
iterators
.
emplace_back
(
col_ptr
,
col_ptr_type
,
indices
,
data
,
data_type
,
ncol_ptr
,
nelem
,
j
);
for
(
int
j
=
0
;
j
<
ncol
;
++
j
)
{
iterators
[
i
].
emplace_back
(
col_ptr
,
col_ptr_type
,
indices
,
data
,
data_type
,
ncol_ptr
,
nelem
,
j
);
}
}
}
std
::
function
<
std
::
vector
<
std
::
pair
<
int
,
double
>>
(
int
row_idx
)
>
get_row_fun
=
std
::
function
<
std
::
vector
<
std
::
pair
<
int
,
double
>>
(
int
row_idx
)
>
get_row_fun
=
[
&
iterators
,
ncol
]
(
int
i
)
{
[
&
iterators
,
ncol
]
(
int
i
)
{
std
::
vector
<
std
::
pair
<
int
,
double
>>
one_row
;
std
::
vector
<
std
::
pair
<
int
,
double
>>
one_row
;
const
int
tid
=
omp_get_thread_num
();
for
(
int
j
=
0
;
j
<
ncol
;
++
j
)
{
for
(
int
j
=
0
;
j
<
ncol
;
++
j
)
{
auto
val
=
iterators
[
j
].
Get
(
i
);
auto
val
=
iterators
[
tid
][
j
].
Get
(
i
);
if
(
std
::
fabs
(
val
)
>
kEpsilon
)
{
if
(
std
::
fabs
(
val
)
>
kEpsilon
)
{
one_row
.
emplace_back
(
j
,
val
);
one_row
.
emplace_back
(
j
,
val
);
}
}
}
}
return
one_row
;
return
one_row
;
};
};
ref_booster
->
Predict
(
num_iteration
,
predict_type
,
static_cast
<
int
>
(
num_row
),
get_row_fun
,
parameter
,
ref_booster
->
Predict
(
num_iteration
,
predict_type
,
static_cast
<
int
>
(
num_row
),
get_row_fun
,
config
.
io_config
,
out_result
,
out_len
);
out_result
,
out_len
);
API_END
();
API_END
();
}
}
...
@@ -1071,10 +1093,16 @@ int LGBM_BoosterPredictForMat(BoosterHandle handle,
...
@@ -1071,10 +1093,16 @@ int LGBM_BoosterPredictForMat(BoosterHandle handle,
int64_t
*
out_len
,
int64_t
*
out_len
,
double
*
out_result
)
{
double
*
out_result
)
{
API_BEGIN
();
API_BEGIN
();
auto
param
=
ConfigBase
::
Str2Map
(
parameter
);
OverallConfig
config
;
config
.
Set
(
param
);
if
(
config
.
num_threads
>
0
)
{
omp_set_num_threads
(
config
.
num_threads
);
}
Booster
*
ref_booster
=
reinterpret_cast
<
Booster
*>
(
handle
);
Booster
*
ref_booster
=
reinterpret_cast
<
Booster
*>
(
handle
);
auto
get_row_fun
=
RowPairFunctionFromDenseMatric
(
data
,
nrow
,
ncol
,
data_type
,
is_row_major
);
auto
get_row_fun
=
RowPairFunctionFromDenseMatric
(
data
,
nrow
,
ncol
,
data_type
,
is_row_major
);
ref_booster
->
Predict
(
num_iteration
,
predict_type
,
nrow
,
get_row_fun
,
ref_booster
->
Predict
(
num_iteration
,
predict_type
,
nrow
,
get_row_fun
,
parameter
,
out_result
,
out_len
);
config
.
io_config
,
out_result
,
out_len
);
API_END
();
API_END
();
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment