Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
962b7eb0
Commit
962b7eb0
authored
Nov 26, 2016
by
Guolin Ke
Browse files
change to std::lock_guard
parent
3484e898
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
24 deletions
+11
-24
src/c_api.cpp
src/c_api.cpp
+11
-24
No files found.
src/c_api.cpp
View file @
962b7eb0
...
...
@@ -30,7 +30,7 @@ public:
Booster
(
const
Dataset
*
train_data
,
const
char
*
parameters
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
auto
param
=
ConfigBase
::
Str2Map
(
parameters
);
config_
.
Set
(
param
);
// create boosting
...
...
@@ -43,13 +43,11 @@ public:
// initialize the boosting
boosting_
->
Init
(
&
config_
.
boosting_config
,
train_data
,
objective_fun_
.
get
(),
Common
::
ConstPtrInVectorWrapper
<
Metric
>
(
train_metric_
));
lock
.
unlock
();
}
void
MergeFrom
(
const
Booster
*
other
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
boosting_
->
MergeFrom
(
other
->
boosting_
.
get
());
lock
.
unlock
();
}
~
Booster
()
{
...
...
@@ -57,17 +55,16 @@ public:
}
void
ResetTrainingData
(
const
Dataset
*
train_data
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
train_data_
=
train_data
;
ConstructObjectAndTrainingMetrics
(
train_data_
);
// initialize the boosting
boosting_
->
ResetTrainingData
(
&
config_
.
boosting_config
,
train_data_
,
objective_fun_
.
get
(),
Common
::
ConstPtrInVectorWrapper
<
Metric
>
(
train_metric_
));
lock
.
unlock
();
}
void
ResetConfig
(
const
char
*
parameters
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
auto
param
=
ConfigBase
::
Str2Map
(
parameters
);
if
(
param
.
count
(
"num_class"
))
{
Log
::
Fatal
(
"cannot change num class during training"
);
...
...
@@ -77,11 +74,10 @@ public:
}
config_
.
Set
(
param
);
ResetTrainingData
(
train_data_
);
lock
.
unlock
();
}
void
AddValidData
(
const
Dataset
*
valid_data
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
valid_metrics_
.
emplace_back
();
for
(
auto
metric_type
:
config_
.
metric_types
)
{
auto
metric
=
std
::
unique_ptr
<
Metric
>
(
Metric
::
CreateMetric
(
metric_type
,
config_
.
metric_config
));
...
...
@@ -92,30 +88,24 @@ public:
valid_metrics_
.
back
().
shrink_to_fit
();
boosting_
->
AddValidDataset
(
valid_data
,
Common
::
ConstPtrInVectorWrapper
<
Metric
>
(
valid_metrics_
.
back
()));
lock
.
unlock
();
}
bool
TrainOneIter
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
bool
ret
=
boosting_
->
TrainOneIter
(
nullptr
,
nullptr
,
false
);
lock
.
unlock
();
return
ret
;
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
return
boosting_
->
TrainOneIter
(
nullptr
,
nullptr
,
false
);
}
bool
TrainOneIter
(
const
float
*
gradients
,
const
float
*
hessians
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
bool
ret
=
boosting_
->
TrainOneIter
(
gradients
,
hessians
,
false
);
lock
.
unlock
();
return
ret
;
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
return
boosting_
->
TrainOneIter
(
gradients
,
hessians
,
false
);
}
void
RollbackOneIter
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
boosting_
->
RollbackOneIter
();
lock
.
unlock
();
}
void
PrepareForPrediction
(
int
num_iteration
,
int
predict_type
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
boosting_
->
SetNumIterationForPred
(
num_iteration
);
bool
is_predict_leaf
=
false
;
bool
is_raw_score
=
false
;
...
...
@@ -127,7 +117,6 @@ public:
is_raw_score
=
false
;
}
predictor_
.
reset
(
new
Predictor
(
boosting_
.
get
(),
is_raw_score
,
is_predict_leaf
));
lock
.
unlock
();
}
void
GetPredictAt
(
int
data_idx
,
score_t
*
out_result
,
data_size_t
*
out_len
)
{
...
...
@@ -143,9 +132,7 @@ public:
}
void
SaveModelToFile
(
int
num_iteration
,
const
char
*
filename
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
boosting_
->
SaveModelToFile
(
num_iteration
,
filename
);
lock
.
unlock
();
}
int
GetEvalCounts
()
const
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment