Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
dd316895
Commit
dd316895
authored
Jan 10, 2017
by
Guolin Ke
Browse files
fix name of boosting type
parent
76c44d78
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
28 additions
and
43 deletions
+28
-43
include/LightGBM/boosting.h
include/LightGBM/boosting.h
+3
-3
include/LightGBM/config.h
include/LightGBM/config.h
+1
-7
src/application/application.cpp
src/application/application.cpp
+7
-6
src/boosting/boosting.cpp
src/boosting/boosting.cpp
+12
-17
src/boosting/dart.hpp
src/boosting/dart.hpp
+0
-5
src/boosting/gbdt.cpp
src/boosting/gbdt.cpp
+2
-2
src/boosting/gbdt.h
src/boosting/gbdt.h
+1
-1
src/io/config.cpp
src/io/config.cpp
+2
-2
No files found.
include/LightGBM/boosting.h
View file @
dd316895
...
@@ -182,9 +182,9 @@ public:
...
@@ -182,9 +182,9 @@ public:
virtual
void
SetNumIterationForPred
(
int
num_iteration
)
=
0
;
virtual
void
SetNumIterationForPred
(
int
num_iteration
)
=
0
;
/*!
/*!
* \brief
Get Type name of this boosting object
* \brief
Name of submodel
*/
*/
virtual
const
char
*
Name
()
const
=
0
;
virtual
const
char
*
SubModel
Name
()
const
=
0
;
Boosting
()
=
default
;
Boosting
()
=
default
;
/*! \brief Disable copy */
/*! \brief Disable copy */
...
@@ -201,7 +201,7 @@ public:
...
@@ -201,7 +201,7 @@ public:
* \param filename name of model file, if existing will continue to train from this model
* \param filename name of model file, if existing will continue to train from this model
* \return The boosting object
* \return The boosting object
*/
*/
static
Boosting
*
CreateBoosting
(
BoostingType
type
,
const
char
*
filename
);
static
Boosting
*
CreateBoosting
(
const
std
::
string
&
type
,
const
char
*
filename
);
/*!
/*!
* \brief Create boosting object from model file
* \brief Create boosting object from model file
...
...
include/LightGBM/config.h
View file @
dd316895
...
@@ -76,12 +76,6 @@ public:
...
@@ -76,12 +76,6 @@ public:
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
Str2Map
(
const
char
*
parameters
);
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
Str2Map
(
const
char
*
parameters
);
};
};
/*! \brief Types of boosting */
enum
BoostingType
{
kGBDT
,
kDART
,
kUnknow
};
/*! \brief Types of tasks */
/*! \brief Types of tasks */
enum
TaskType
{
enum
TaskType
{
kTrain
,
kPredict
kTrain
,
kPredict
...
@@ -240,7 +234,7 @@ public:
...
@@ -240,7 +234,7 @@ public:
bool
is_parallel
=
false
;
bool
is_parallel
=
false
;
bool
is_parallel_find_bin
=
false
;
bool
is_parallel_find_bin
=
false
;
IOConfig
io_config
;
IOConfig
io_config
;
Boo
sting
Type
boosting_type
=
BoostingType
::
kGBDT
;
std
::
st
r
ing
boosting_type
=
"gbdt"
;
BoostingConfig
boosting_config
;
BoostingConfig
boosting_config
;
std
::
string
objective_type
=
"regression"
;
std
::
string
objective_type
=
"regression"
;
ObjectiveConfig
objective_config
;
ObjectiveConfig
objective_config
;
...
...
src/application/application.cpp
View file @
dd316895
...
@@ -190,13 +190,14 @@ void Application::InitTrain() {
...
@@ -190,13 +190,14 @@ void Application::InitTrain() {
Network
::
Init
(
config_
.
network_config
);
Network
::
Init
(
config_
.
network_config
);
Log
::
Info
(
"Finished initializing network"
);
Log
::
Info
(
"Finished initializing network"
);
// sync global random seed for feature patition
// sync global random seed for feature patition
if
(
config_
.
boosting_
type
==
BoostingType
::
kGBDT
||
config_
.
boosting_type
==
BoostingType
::
kDART
)
{
config_
.
boosting_
config
.
tree_config
.
feature_fraction_seed
=
config_
.
boosting_config
.
tree_config
.
feature_fraction_seed
=
GlobalSyncUpByMin
<
int
>
(
config_
.
boosting_config
.
tree_config
.
feature_fraction_seed
);
GlobalSyncUpByMin
<
int
>
(
config_
.
boosting_config
.
tree_config
.
feature_fraction
_seed
);
config_
.
boosting_config
.
tree_config
.
feature_fraction
=
config_
.
boosting_config
.
tree_config
.
feature_fraction
=
GlobalSyncUpByMin
<
double
>
(
config_
.
boosting_config
.
tree_config
.
feature_fraction
);
GlobalSyncUpByMin
<
double
>
(
config_
.
boosting_config
.
tree_config
.
feature_fraction
);
config_
.
boosting_config
.
drop_seed
=
}
GlobalSyncUpByMin
<
int
>
(
config_
.
boosting_config
.
drop_seed
);
}
}
// create boosting
// create boosting
boosting_
.
reset
(
boosting_
.
reset
(
Boosting
::
CreateBoosting
(
config_
.
boosting_type
,
Boosting
::
CreateBoosting
(
config_
.
boosting_type
,
...
...
src/boosting/boosting.cpp
View file @
dd316895
...
@@ -4,15 +4,10 @@
...
@@ -4,15 +4,10 @@
namespace
LightGBM
{
namespace
LightGBM
{
Boo
sting
Type
GetBoostingTypeFromModelFile
(
const
char
*
filename
)
{
std
::
st
r
ing
GetBoostingTypeFromModelFile
(
const
char
*
filename
)
{
TextReader
<
size_t
>
model_reader
(
filename
,
true
);
TextReader
<
size_t
>
model_reader
(
filename
,
true
);
std
::
string
type
=
model_reader
.
first_line
();
std
::
string
type
=
model_reader
.
first_line
();
if
(
type
==
std
::
string
(
"gbdt"
))
{
return
type
;
return
BoostingType
::
kGBDT
;
}
else
if
(
type
==
std
::
string
(
"dart"
))
{
return
BoostingType
::
kDART
;
}
return
BoostingType
::
kUnknow
;
}
}
void
Boosting
::
LoadFileToBoosting
(
Boosting
*
boosting
,
const
char
*
filename
)
{
void
Boosting
::
LoadFileToBoosting
(
Boosting
*
boosting
,
const
char
*
filename
)
{
...
@@ -27,11 +22,11 @@ void Boosting::LoadFileToBoosting(Boosting* boosting, const char* filename) {
...
@@ -27,11 +22,11 @@ void Boosting::LoadFileToBoosting(Boosting* boosting, const char* filename) {
}
}
}
}
Boosting
*
Boosting
::
CreateBoosting
(
BoostingType
type
,
const
char
*
filename
)
{
Boosting
*
Boosting
::
CreateBoosting
(
const
std
::
string
&
type
,
const
char
*
filename
)
{
if
(
filename
==
nullptr
||
filename
[
0
]
==
'\0'
)
{
if
(
filename
==
nullptr
||
filename
[
0
]
==
'\0'
)
{
if
(
type
==
BoostingType
::
kGBDT
)
{
if
(
type
==
std
::
string
(
"gbdt"
)
)
{
return
new
GBDT
();
return
new
GBDT
();
}
else
if
(
type
==
BoostingType
::
kDART
)
{
}
else
if
(
type
==
std
::
string
(
"dart"
)
)
{
return
new
DART
();
return
new
DART
();
}
else
{
}
else
{
return
nullptr
;
return
nullptr
;
...
@@ -39,15 +34,15 @@ Boosting* Boosting::CreateBoosting(BoostingType type, const char* filename) {
...
@@ -39,15 +34,15 @@ Boosting* Boosting::CreateBoosting(BoostingType type, const char* filename) {
}
else
{
}
else
{
std
::
unique_ptr
<
Boosting
>
ret
;
std
::
unique_ptr
<
Boosting
>
ret
;
auto
type_in_file
=
GetBoostingTypeFromModelFile
(
filename
);
auto
type_in_file
=
GetBoostingTypeFromModelFile
(
filename
);
if
(
type_in_file
==
type
)
{
if
(
type_in_file
==
std
::
string
(
"tree"
)
)
{
if
(
type
==
BoostingType
::
kGBDT
)
{
if
(
type
==
std
::
string
(
"gbdt"
)
)
{
ret
.
reset
(
new
GBDT
());
ret
.
reset
(
new
GBDT
());
}
else
if
(
type
==
BoostingType
::
kDART
)
{
}
else
if
(
type
==
std
::
string
(
"dart"
)
)
{
ret
.
reset
(
new
DART
());
ret
.
reset
(
new
DART
());
}
}
LoadFileToBoosting
(
ret
.
get
(),
filename
);
LoadFileToBoosting
(
ret
.
get
(),
filename
);
}
else
{
}
else
{
Log
::
Fatal
(
"
Boosting type in parameter is not the same as the
type in
the
model file
"
);
Log
::
Fatal
(
"
unknow submodel
type in model file
%s"
,
filename
);
}
}
return
ret
.
release
();
return
ret
.
release
();
}
}
...
@@ -56,10 +51,10 @@ Boosting* Boosting::CreateBoosting(BoostingType type, const char* filename) {
...
@@ -56,10 +51,10 @@ Boosting* Boosting::CreateBoosting(BoostingType type, const char* filename) {
Boosting
*
Boosting
::
CreateBoosting
(
const
char
*
filename
)
{
Boosting
*
Boosting
::
CreateBoosting
(
const
char
*
filename
)
{
auto
type
=
GetBoostingTypeFromModelFile
(
filename
);
auto
type
=
GetBoostingTypeFromModelFile
(
filename
);
std
::
unique_ptr
<
Boosting
>
ret
;
std
::
unique_ptr
<
Boosting
>
ret
;
if
(
type
==
Boo
sting
Type
::
kGBDT
)
{
if
(
type
==
std
::
st
r
ing
(
"tree"
)
)
{
ret
.
reset
(
new
GBDT
());
ret
.
reset
(
new
GBDT
());
}
else
if
(
type
==
BoostingType
::
kDART
)
{
}
else
{
ret
.
reset
(
new
DART
()
);
Log
::
Fatal
(
"unknow submodel type in model file %s"
,
filename
);
}
}
LoadFileToBoosting
(
ret
.
get
(),
filename
);
LoadFileToBoosting
(
ret
.
get
(),
filename
);
return
ret
.
release
();
return
ret
.
release
();
...
...
src/boosting/dart.hpp
View file @
dd316895
...
@@ -72,11 +72,6 @@ public:
...
@@ -72,11 +72,6 @@ public:
return
train_score_updater_
->
score
();
return
train_score_updater_
->
score
();
}
}
/*!
* \brief Get Type name of this boosting object
*/
const
char
*
Name
()
const
override
{
return
"dart"
;
}
private:
private:
/*!
/*!
* \brief drop trees based on drop_rate
* \brief drop trees based on drop_rate
...
...
src/boosting/gbdt.cpp
View file @
dd316895
...
@@ -439,7 +439,7 @@ std::string GBDT::DumpModel(int num_iteration) const {
...
@@ -439,7 +439,7 @@ std::string GBDT::DumpModel(int num_iteration) const {
std
::
stringstream
str_buf
;
std
::
stringstream
str_buf
;
str_buf
<<
"{"
;
str_buf
<<
"{"
;
str_buf
<<
"
\"
name
\"
:
\"
"
<<
Name
()
<<
"
\"
,"
<<
std
::
endl
;
str_buf
<<
"
\"
name
\"
:
\"
"
<<
SubModel
Name
()
<<
"
\"
,"
<<
std
::
endl
;
str_buf
<<
"
\"
num_class
\"
:"
<<
num_class_
<<
","
<<
std
::
endl
;
str_buf
<<
"
\"
num_class
\"
:"
<<
num_class_
<<
","
<<
std
::
endl
;
str_buf
<<
"
\"
label_index
\"
:"
<<
label_idx_
<<
","
<<
std
::
endl
;
str_buf
<<
"
\"
label_index
\"
:"
<<
label_idx_
<<
","
<<
std
::
endl
;
str_buf
<<
"
\"
max_feature_idx
\"
:"
<<
max_feature_idx_
<<
","
<<
std
::
endl
;
str_buf
<<
"
\"
max_feature_idx
\"
:"
<<
max_feature_idx_
<<
","
<<
std
::
endl
;
...
@@ -481,7 +481,7 @@ void GBDT::SaveModelToFile(int num_iteration, const char* filename) const {
...
@@ -481,7 +481,7 @@ void GBDT::SaveModelToFile(int num_iteration, const char* filename) const {
std
::
ofstream
output_file
;
std
::
ofstream
output_file
;
output_file
.
open
(
filename
);
output_file
.
open
(
filename
);
// output model type
// output model type
output_file
<<
Name
()
<<
std
::
endl
;
output_file
<<
SubModel
Name
()
<<
std
::
endl
;
// output number of class
// output number of class
output_file
<<
"num_class="
<<
num_class_
<<
std
::
endl
;
output_file
<<
"num_class="
<<
num_class_
<<
std
::
endl
;
// output label index
// output label index
...
...
src/boosting/gbdt.h
View file @
dd316895
...
@@ -212,7 +212,7 @@ public:
...
@@ -212,7 +212,7 @@ public:
/*!
/*!
* \brief Get Type name of this boosting object
* \brief Get Type name of this boosting object
*/
*/
virtual
const
char
*
Name
()
const
override
{
return
"
gbdt
"
;
}
virtual
const
char
*
SubModel
Name
()
const
override
{
return
"
tree
"
;
}
protected:
protected:
/*!
/*!
...
...
src/io/config.cpp
View file @
dd316895
...
@@ -76,9 +76,9 @@ void OverallConfig::GetBoostingType(const std::unordered_map<std::string, std::s
...
@@ -76,9 +76,9 @@ void OverallConfig::GetBoostingType(const std::unordered_map<std::string, std::s
if
(
GetString
(
params
,
"boosting_type"
,
&
value
))
{
if
(
GetString
(
params
,
"boosting_type"
,
&
value
))
{
std
::
transform
(
value
.
begin
(),
value
.
end
(),
value
.
begin
(),
Common
::
tolower
);
std
::
transform
(
value
.
begin
(),
value
.
end
(),
value
.
begin
(),
Common
::
tolower
);
if
(
value
==
std
::
string
(
"gbdt"
)
||
value
==
std
::
string
(
"gbrt"
))
{
if
(
value
==
std
::
string
(
"gbdt"
)
||
value
==
std
::
string
(
"gbrt"
))
{
boosting_type
=
BoostingType
::
kGBDT
;
boosting_type
=
"gbdt"
;
}
else
if
(
value
==
std
::
string
(
"dart"
))
{
}
else
if
(
value
==
std
::
string
(
"dart"
))
{
boosting_type
=
BoostingType
::
kDART
;
boosting_type
=
"dart"
;
}
else
{
}
else
{
Log
::
Fatal
(
"Unknown boosting type %s"
,
value
.
c_str
());
Log
::
Fatal
(
"Unknown boosting type %s"
,
value
.
c_str
());
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment