Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
df358b2d
Commit
df358b2d
authored
Jan 19, 2017
by
cbecker
Committed by
Guolin Ke
Jan 19, 2017
Browse files
Added success return value to LoadFileToBoosting and SaveModelToFile (#234)
parent
82fcfa0e
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
23 additions
and
14 deletions
+23
-14
include/LightGBM/boosting.h
include/LightGBM/boosting.h
+5
-3
src/boosting/boosting.cpp
src/boosting/boosting.cpp
+5
-2
src/boosting/gbdt.cpp
src/boosting/gbdt.cpp
+11
-7
src/boosting/gbdt.h
src/boosting/gbdt.h
+2
-2
No files found.
include/LightGBM/boosting.h
View file @
df358b2d
...
@@ -143,14 +143,16 @@ public:
...
@@ -143,14 +143,16 @@ public:
* \param num_used_model Number of model that want to save, -1 means save all
* \param num_used_model Number of model that want to save, -1 means save all
* \param is_finish Is training finished or not
* \param is_finish Is training finished or not
* \param filename Filename that want to save to
* \param filename Filename that want to save to
* \return true if succeeded
*/
*/
virtual
void
SaveModelToFile
(
int
num_iterations
,
const
char
*
filename
)
const
=
0
;
virtual
bool
SaveModelToFile
(
int
num_iterations
,
const
char
*
filename
)
const
=
0
;
/*!
/*!
* \brief Restore from a serialized string
* \brief Restore from a serialized string
* \param model_str The string of model
* \param model_str The string of model
* \return true if succeeded
*/
*/
virtual
void
LoadModelFromString
(
const
std
::
string
&
model_str
)
=
0
;
virtual
bool
LoadModelFromString
(
const
std
::
string
&
model_str
)
=
0
;
/*!
/*!
* \brief Get max feature index of this model
* \brief Get max feature index of this model
...
@@ -192,7 +194,7 @@ public:
...
@@ -192,7 +194,7 @@ public:
/*! \brief Disable copy */
/*! \brief Disable copy */
Boosting
(
const
Boosting
&
)
=
delete
;
Boosting
(
const
Boosting
&
)
=
delete
;
static
void
LoadFileToBoosting
(
Boosting
*
boosting
,
const
char
*
filename
);
static
bool
LoadFileToBoosting
(
Boosting
*
boosting
,
const
char
*
filename
);
/*!
/*!
* \brief Create boosting object
* \brief Create boosting object
...
...
src/boosting/boosting.cpp
View file @
df358b2d
...
@@ -10,7 +10,7 @@ std::string GetBoostingTypeFromModelFile(const char* filename) {
...
@@ -10,7 +10,7 @@ std::string GetBoostingTypeFromModelFile(const char* filename) {
return
type
;
return
type
;
}
}
void
Boosting
::
LoadFileToBoosting
(
Boosting
*
boosting
,
const
char
*
filename
)
{
bool
Boosting
::
LoadFileToBoosting
(
Boosting
*
boosting
,
const
char
*
filename
)
{
if
(
boosting
!=
nullptr
)
{
if
(
boosting
!=
nullptr
)
{
TextReader
<
size_t
>
model_reader
(
filename
,
true
);
TextReader
<
size_t
>
model_reader
(
filename
,
true
);
model_reader
.
ReadAllLines
();
model_reader
.
ReadAllLines
();
...
@@ -18,8 +18,11 @@ void Boosting::LoadFileToBoosting(Boosting* boosting, const char* filename) {
...
@@ -18,8 +18,11 @@ void Boosting::LoadFileToBoosting(Boosting* boosting, const char* filename) {
for
(
auto
&
line
:
model_reader
.
Lines
())
{
for
(
auto
&
line
:
model_reader
.
Lines
())
{
str_buf
<<
line
<<
'\n'
;
str_buf
<<
line
<<
'\n'
;
}
}
boosting
->
LoadModelFromString
(
str_buf
.
str
());
if
(
!
boosting
->
LoadModelFromString
(
str_buf
.
str
()))
return
false
;
}
}
return
true
;
}
}
Boosting
*
Boosting
::
CreateBoosting
(
const
std
::
string
&
type
,
const
char
*
filename
)
{
Boosting
*
Boosting
::
CreateBoosting
(
const
std
::
string
&
type
,
const
char
*
filename
)
{
...
...
src/boosting/gbdt.cpp
View file @
df358b2d
...
@@ -509,7 +509,7 @@ std::string GBDT::DumpModel(int num_iteration) const {
...
@@ -509,7 +509,7 @@ std::string GBDT::DumpModel(int num_iteration) const {
return
str_buf
.
str
();
return
str_buf
.
str
();
}
}
void
GBDT
::
SaveModelToFile
(
int
num_iteration
,
const
char
*
filename
)
const
{
bool
GBDT
::
SaveModelToFile
(
int
num_iteration
,
const
char
*
filename
)
const
{
/*! \brief File to write models */
/*! \brief File to write models */
std
::
ofstream
output_file
;
std
::
ofstream
output_file
;
output_file
.
open
(
filename
);
output_file
.
open
(
filename
);
...
@@ -553,9 +553,11 @@ void GBDT::SaveModelToFile(int num_iteration, const char* filename) const {
...
@@ -553,9 +553,11 @@ void GBDT::SaveModelToFile(int num_iteration, const char* filename) const {
}
}
output_file
.
close
();
output_file
.
close
();
return
(
bool
)
output_file
;
}
}
void
GBDT
::
LoadModelFromString
(
const
std
::
string
&
model_str
)
{
bool
GBDT
::
LoadModelFromString
(
const
std
::
string
&
model_str
)
{
// use serialized string to restore this object
// use serialized string to restore this object
models_
.
clear
();
models_
.
clear
();
std
::
vector
<
std
::
string
>
lines
=
Common
::
Split
(
model_str
.
c_str
(),
'\n'
);
std
::
vector
<
std
::
string
>
lines
=
Common
::
Split
(
model_str
.
c_str
(),
'\n'
);
...
@@ -566,7 +568,7 @@ void GBDT::LoadModelFromString(const std::string& model_str) {
...
@@ -566,7 +568,7 @@ void GBDT::LoadModelFromString(const std::string& model_str) {
Common
::
Atoi
(
Common
::
Split
(
line
.
c_str
(),
'='
)[
1
].
c_str
(),
&
num_class_
);
Common
::
Atoi
(
Common
::
Split
(
line
.
c_str
(),
'='
)[
1
].
c_str
(),
&
num_class_
);
}
else
{
}
else
{
Log
::
Fatal
(
"Model file doesn't specify the number of classes"
);
Log
::
Fatal
(
"Model file doesn't specify the number of classes"
);
return
;
return
false
;
}
}
// get index of label
// get index of label
line
=
Common
::
FindFromLines
(
lines
,
"label_index="
);
line
=
Common
::
FindFromLines
(
lines
,
"label_index="
);
...
@@ -574,7 +576,7 @@ void GBDT::LoadModelFromString(const std::string& model_str) {
...
@@ -574,7 +576,7 @@ void GBDT::LoadModelFromString(const std::string& model_str) {
Common
::
Atoi
(
Common
::
Split
(
line
.
c_str
(),
'='
)[
1
].
c_str
(),
&
label_idx_
);
Common
::
Atoi
(
Common
::
Split
(
line
.
c_str
(),
'='
)[
1
].
c_str
(),
&
label_idx_
);
}
else
{
}
else
{
Log
::
Fatal
(
"Model file doesn't specify the label index"
);
Log
::
Fatal
(
"Model file doesn't specify the label index"
);
return
;
return
false
;
}
}
// get max_feature_idx first
// get max_feature_idx first
line
=
Common
::
FindFromLines
(
lines
,
"max_feature_idx="
);
line
=
Common
::
FindFromLines
(
lines
,
"max_feature_idx="
);
...
@@ -582,7 +584,7 @@ void GBDT::LoadModelFromString(const std::string& model_str) {
...
@@ -582,7 +584,7 @@ void GBDT::LoadModelFromString(const std::string& model_str) {
Common
::
Atoi
(
Common
::
Split
(
line
.
c_str
(),
'='
)[
1
].
c_str
(),
&
max_feature_idx_
);
Common
::
Atoi
(
Common
::
Split
(
line
.
c_str
(),
'='
)[
1
].
c_str
(),
&
max_feature_idx_
);
}
else
{
}
else
{
Log
::
Fatal
(
"Model file doesn't specify max_feature_idx"
);
Log
::
Fatal
(
"Model file doesn't specify max_feature_idx"
);
return
;
return
false
;
}
}
// get sigmoid parameter
// get sigmoid parameter
line
=
Common
::
FindFromLines
(
lines
,
"sigmoid="
);
line
=
Common
::
FindFromLines
(
lines
,
"sigmoid="
);
...
@@ -597,11 +599,11 @@ void GBDT::LoadModelFromString(const std::string& model_str) {
...
@@ -597,11 +599,11 @@ void GBDT::LoadModelFromString(const std::string& model_str) {
feature_names_
=
Common
::
Split
(
line
.
substr
(
std
::
strlen
(
"feature_names="
)).
c_str
(),
" "
);
feature_names_
=
Common
::
Split
(
line
.
substr
(
std
::
strlen
(
"feature_names="
)).
c_str
(),
" "
);
if
(
feature_names_
.
size
()
!=
static_cast
<
size_t
>
(
max_feature_idx_
+
1
))
{
if
(
feature_names_
.
size
()
!=
static_cast
<
size_t
>
(
max_feature_idx_
+
1
))
{
Log
::
Fatal
(
"Wrong size of feature_names"
);
Log
::
Fatal
(
"Wrong size of feature_names"
);
return
;
return
false
;
}
}
}
else
{
}
else
{
Log
::
Fatal
(
"Model file doesn't contain feature names"
);
Log
::
Fatal
(
"Model file doesn't contain feature names"
);
return
;
return
false
;
}
}
// get tree models
// get tree models
...
@@ -624,6 +626,8 @@ void GBDT::LoadModelFromString(const std::string& model_str) {
...
@@ -624,6 +626,8 @@ void GBDT::LoadModelFromString(const std::string& model_str) {
num_iteration_for_pred_
=
static_cast
<
int
>
(
models_
.
size
())
/
num_class_
;
num_iteration_for_pred_
=
static_cast
<
int
>
(
models_
.
size
())
/
num_class_
;
num_init_iteration_
=
num_iteration_for_pred_
;
num_init_iteration_
=
num_iteration_for_pred_
;
iter_
=
0
;
iter_
=
0
;
return
true
;
}
}
std
::
vector
<
std
::
pair
<
size_t
,
std
::
string
>>
GBDT
::
FeatureImportance
()
const
{
std
::
vector
<
std
::
pair
<
size_t
,
std
::
string
>>
GBDT
::
FeatureImportance
()
const
{
...
...
src/boosting/gbdt.h
View file @
df358b2d
...
@@ -156,12 +156,12 @@ public:
...
@@ -156,12 +156,12 @@ public:
* \param is_finish Is training finished or not
* \param is_finish Is training finished or not
* \param filename Filename that want to save to
* \param filename Filename that want to save to
*/
*/
virtual
void
SaveModelToFile
(
int
num_iterations
,
const
char
*
filename
)
const
override
;
virtual
bool
SaveModelToFile
(
int
num_iterations
,
const
char
*
filename
)
const
override
;
/*!
/*!
* \brief Restore from a serialized string
* \brief Restore from a serialized string
*/
*/
void
LoadModelFromString
(
const
std
::
string
&
model_str
)
override
;
bool
LoadModelFromString
(
const
std
::
string
&
model_str
)
override
;
/*!
/*!
* \brief Get max feature index of this model
* \brief Get max feature index of this model
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment