Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
cccd0587
Commit
cccd0587
authored
Mar 31, 2017
by
Guolin Ke
Browse files
skip the training of empty class in classification.
parent
e404d7cf
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
55 additions
and
3 deletions
+55
-3
src/boosting/gbdt.cpp
src/boosting/gbdt.cpp
+48
-3
src/boosting/gbdt.h
src/boosting/gbdt.h
+2
-0
src/objective/binary_objective.hpp
src/objective/binary_objective.hpp
+5
-0
No files found.
src/boosting/gbdt.cpp
View file @
cccd0587
...
@@ -162,6 +162,37 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
...
@@ -162,6 +162,37 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
if
(
train_data_
!=
nullptr
)
{
if
(
train_data_
!=
nullptr
)
{
// reset config for tree learner
// reset config for tree learner
tree_learner_
->
ResetConfig
(
&
new_config
->
tree_config
);
tree_learner_
->
ResetConfig
(
&
new_config
->
tree_config
);
class_need_train_
=
std
::
vector
<
bool
>
(
num_class_
,
true
);
if
(
num_class_
>
1
||
sigmoid_
>
0
)
{
// + 1 here for the binary classification
class_default_output_
=
std
::
vector
<
double
>
(
num_class_
+
1
,
0.0
f
);
std
::
vector
<
data_size_t
>
cnt_per_class
(
num_class_
,
0
);
auto
label
=
train_data_
->
metadata
().
label
();
for
(
int
i
=
0
;
i
<
num_data_
;
++
i
)
{
++
cnt_per_class
[
static_cast
<
int
>
(
label
[
i
])];
}
if
(
num_class_
>
1
)
{
for
(
int
i
=
0
;
i
<
num_class_
;
++
i
)
{
if
(
cnt_per_class
[
i
]
==
num_data_
)
{
Log
::
Warning
(
"Only contain one class."
);
class_need_train_
[
i
]
=
false
;
class_default_output_
[
i
]
=
-
std
::
log
(
kEpsilon
);
}
else
if
(
cnt_per_class
[
i
]
==
0
)
{
class_need_train_
[
i
]
=
false
;
class_default_output_
[
i
]
=
-
std
::
log
(
1.0
f
/
kEpsilon
-
1.0
f
);
}
}
}
else
{
// binary classification.
if
(
cnt_per_class
[
1
]
==
0
)
{
class_need_train_
[
0
]
=
false
;
class_default_output_
[
0
]
=
-
std
::
log
(
1.0
f
/
kEpsilon
-
1.0
f
);
}
else
if
(
cnt_per_class
[
1
]
==
num_data_
)
{
class_need_train_
[
0
]
=
false
;
class_default_output_
[
0
]
=
-
std
::
log
(
kEpsilon
);
}
}
}
}
}
gbdt_config_
.
reset
(
new_config
.
release
());
gbdt_config_
.
reset
(
new_config
.
release
());
}
}
...
@@ -370,8 +401,11 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
...
@@ -370,8 +401,11 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
#ifdef TIMETAG
#ifdef TIMETAG
start_time
=
std
::
chrono
::
steady_clock
::
now
();
start_time
=
std
::
chrono
::
steady_clock
::
now
();
#endif
#endif
std
::
unique_ptr
<
Tree
>
new_tree
(
std
::
unique_ptr
<
Tree
>
new_tree
(
new
Tree
(
2
));
if
(
class_need_train_
[
curr_class
])
{
new_tree
.
reset
(
tree_learner_
->
Train
(
gradient
+
curr_class
*
num_data_
,
hessian
+
curr_class
*
num_data_
));
tree_learner_
->
Train
(
gradient
+
curr_class
*
num_data_
,
hessian
+
curr_class
*
num_data_
));
}
#ifdef TIMETAG
#ifdef TIMETAG
tree_time
+=
std
::
chrono
::
steady_clock
::
now
()
-
start_time
;
tree_time
+=
std
::
chrono
::
steady_clock
::
now
()
-
start_time
;
#endif
#endif
...
@@ -383,6 +417,17 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
...
@@ -383,6 +417,17 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
// update score
// update score
UpdateScore
(
new_tree
.
get
(),
curr_class
);
UpdateScore
(
new_tree
.
get
(),
curr_class
);
UpdateScoreOutOfBag
(
new_tree
.
get
(),
curr_class
);
UpdateScoreOutOfBag
(
new_tree
.
get
(),
curr_class
);
}
else
{
// only add default score one-time
if
(
!
class_need_train_
[
curr_class
]
&&
models_
.
size
()
<
num_class_
)
{
auto
output
=
class_default_output_
[
curr_class
];
new_tree
->
Split
(
0
,
0
,
BinType
::
NumericalBin
,
0
,
0
,
0
,
output
,
output
,
0
,
num_data_
,
1
);
train_score_updater_
->
AddScore
(
output
,
curr_class
);
for
(
auto
&
score_updater
:
valid_score_updater_
)
{
score_updater
->
AddScore
(
output
,
curr_class
);
}
}
}
}
// add model
// add model
models_
.
push_back
(
std
::
move
(
new_tree
));
models_
.
push_back
(
std
::
move
(
new_tree
));
...
...
src/boosting/gbdt.h
View file @
cccd0587
...
@@ -343,6 +343,8 @@ protected:
...
@@ -343,6 +343,8 @@ protected:
std
::
unique_ptr
<
Dataset
>
tmp_subset_
;
std
::
unique_ptr
<
Dataset
>
tmp_subset_
;
bool
is_use_subset_
;
bool
is_use_subset_
;
bool
boost_from_average_
;
bool
boost_from_average_
;
std
::
vector
<
bool
>
class_need_train_
;
std
::
vector
<
double
>
class_default_output_
;
};
};
}
// namespace LightGBM
}
// namespace LightGBM
...
...
src/objective/binary_objective.hpp
View file @
cccd0587
...
@@ -42,6 +42,11 @@ public:
...
@@ -42,6 +42,11 @@ public:
++
cnt_negative
;
++
cnt_negative
;
}
}
}
}
if
(
cnt_negative
==
0
||
cnt_positive
==
0
)
{
Log
::
Warning
(
"Only contain one class."
);
// not need to boost.
num_data_
=
0
;
}
Log
::
Info
(
"Number of positive: %d, number of negative: %d"
,
cnt_positive
,
cnt_negative
);
Log
::
Info
(
"Number of positive: %d, number of negative: %d"
,
cnt_positive
,
cnt_negative
);
// use -1 for negative class, and 1 for positive class
// use -1 for negative class, and 1 for positive class
label_val_
[
0
]
=
-
1
;
label_val_
[
0
]
=
-
1
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment