Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
a6f47d00
Commit
a6f47d00
authored
Jan 09, 2017
by
Guolin Ke
Browse files
use std::string for tree_learner_type.
parent
9b2558d6
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
19 additions
and
23 deletions
+19
-23
include/LightGBM/config.h
include/LightGBM/config.h
+1
-7
include/LightGBM/tree_learner.h
include/LightGBM/tree_learner.h
+1
-1
src/io/config.cpp
src/io/config.cpp
+12
-10
src/treelearner/tree_learner.cpp
src/treelearner/tree_learner.cpp
+5
-5
No files found.
include/LightGBM/config.h
View file @
a6f47d00
...
...
@@ -187,12 +187,6 @@ public:
void
Set
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
override
;
};
/*! \brief Types of tree learning algorithms */
enum
TreeLearnerType
{
kSerialTreeLearner
,
kFeatureParallelTreelearner
,
kDataParallelTreeLearner
,
KVotingParallelTreeLearner
};
/*! \brief Config for Boosting */
struct
BoostingConfig
:
public
ConfigBase
{
public:
...
...
@@ -213,7 +207,7 @@ public:
bool
xgboost_dart_mode
=
false
;
bool
uniform_drop
=
false
;
int
drop_seed
=
4
;
TreeLearnerType
tree_learner_type
=
TreeLearnerType
::
kSerialTreeLearner
;
std
::
string
tree_learner_type
=
"serial"
;
TreeConfig
tree_config
;
void
Set
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
override
;
private:
...
...
include/LightGBM/tree_learner.h
View file @
a6f47d00
...
...
@@ -66,7 +66,7 @@ public:
* \param type Type of tree learner
* \param tree_config config of tree
*/
static
TreeLearner
*
CreateTreeLearner
(
TreeLearnerType
type
,
static
TreeLearner
*
CreateTreeLearner
(
const
std
::
string
&
type
,
const
TreeConfig
*
tree_config
);
};
...
...
src/io/config.cpp
View file @
a6f47d00
...
...
@@ -159,20 +159,22 @@ void OverallConfig::CheckParamConflict() {
is_parallel
=
true
;
}
else
{
is_parallel
=
false
;
boosting_config
.
tree_learner_type
=
TreeLearnerType
::
kSerialTreeLearner
;
boosting_config
.
tree_learner_type
=
"serial"
;
}
if
(
boosting_config
.
tree_learner_type
==
TreeLearnerType
::
kSerialTreeLearner
)
{
if
(
boosting_config
.
tree_learner_type
==
std
::
string
(
"serial"
)
)
{
is_parallel
=
false
;
network_config
.
num_machines
=
1
;
}
if
(
boosting_config
.
tree_learner_type
==
TreeLearnerType
::
kSerialTreeLearner
||
boosting_config
.
tree_learner_type
==
TreeLearnerType
::
kFeatureParallelTreelearner
)
{
if
(
boosting_config
.
tree_learner_type
==
std
::
string
(
"serial"
)
||
boosting_config
.
tree_learner_type
==
std
::
string
(
"feature"
)
)
{
is_parallel_find_bin
=
false
;
}
else
if
(
boosting_config
.
tree_learner_type
==
TreeLearnerType
::
kDataParallelTreeLearner
)
{
}
else
if
(
boosting_config
.
tree_learner_type
==
std
::
string
(
"data"
)
||
boosting_config
.
tree_learner_type
==
std
::
string
(
"voting"
))
{
is_parallel_find_bin
=
true
;
if
(
boosting_config
.
tree_config
.
histogram_pool_size
>=
0
)
{
if
(
boosting_config
.
tree_config
.
histogram_pool_size
>=
0
&&
boosting_config
.
tree_learner_type
==
std
::
string
(
"data"
))
{
Log
::
Warning
(
"Histogram LRU queue was enabled (histogram_pool_size=%f). Will disable this to reduce communication costs"
,
boosting_config
.
tree_config
.
histogram_pool_size
);
// Change pool size to -1 (not limit) when using data parallel to reduce communication costs
...
...
@@ -326,13 +328,13 @@ void BoostingConfig::GetTreeLearnerType(const std::unordered_map<std::string, st
if
(
GetString
(
params
,
"tree_learner"
,
&
value
))
{
std
::
transform
(
value
.
begin
(),
value
.
end
(),
value
.
begin
(),
Common
::
tolower
);
if
(
value
==
std
::
string
(
"serial"
))
{
tree_learner_type
=
TreeLearnerType
::
kSerialTreeLearner
;
tree_learner_type
=
"serial"
;
}
else
if
(
value
==
std
::
string
(
"feature"
)
||
value
==
std
::
string
(
"feature_parallel"
))
{
tree_learner_type
=
TreeLearnerType
::
kFeatureParallelTreelearner
;
tree_learner_type
=
"feature"
;
}
else
if
(
value
==
std
::
string
(
"data"
)
||
value
==
std
::
string
(
"data_parallel"
))
{
tree_learner_type
=
TreeLearnerType
::
kDataParallelTreeLearner
;
tree_learner_type
=
"data"
;
}
else
if
(
value
==
std
::
string
(
"voting"
)
||
value
==
std
::
string
(
"voting_parallel"
))
{
tree_learner_type
=
TreeLearnerType
::
KVotingParallelTreeLearner
;
tree_learner_type
=
"voting"
;
}
else
{
Log
::
Fatal
(
"Unknown tree learner type %s"
,
value
.
c_str
());
}
...
...
src/treelearner/tree_learner.cpp
View file @
a6f47d00
...
...
@@ -5,14 +5,14 @@
namespace
LightGBM
{
TreeLearner
*
TreeLearner
::
CreateTreeLearner
(
TreeLearnerType
type
,
const
TreeConfig
*
tree_config
)
{
if
(
type
==
TreeLearnerType
::
kSerialTreeLearner
)
{
TreeLearner
*
TreeLearner
::
CreateTreeLearner
(
const
std
::
string
&
type
,
const
TreeConfig
*
tree_config
)
{
if
(
type
==
std
::
string
(
"serial"
)
)
{
return
new
SerialTreeLearner
(
tree_config
);
}
else
if
(
type
==
TreeLearnerType
::
kFeatureParallelTreelearner
)
{
}
else
if
(
type
==
std
::
string
(
"feature"
)
)
{
return
new
FeatureParallelTreeLearner
(
tree_config
);
}
else
if
(
type
==
TreeLearnerType
::
kDataParallelTreeLearner
)
{
}
else
if
(
type
==
std
::
string
(
"data"
)
)
{
return
new
DataParallelTreeLearner
(
tree_config
);
}
else
if
(
type
==
TreeLearnerType
::
KVotingParallelTreeLearner
)
{
}
else
if
(
type
==
std
::
string
(
"voting"
)
)
{
return
new
VotingParallelTreeLearner
(
tree_config
);
}
return
nullptr
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment