Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
841a8987
Commit
841a8987
authored
Mar 29, 2017
by
Guolin Ke
Browse files
support OVA multi-classification.
parent
14195876
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
166 additions
and
41 deletions
+166
-41
src/io/config.cpp
src/io/config.cpp
+11
-2
src/metric/metric.cpp
src/metric/metric.cpp
+3
-1
src/metric/multiclass_metric.hpp
src/metric/multiclass_metric.hpp
+81
-4
src/objective/binary_objective.hpp
src/objective/binary_objective.hpp
+11
-4
src/objective/multiclass_objective.hpp
src/objective/multiclass_objective.hpp
+57
-29
src/objective/objective_function.cpp
src/objective/objective_function.cpp
+3
-1
No files found.
src/io/config.cpp
View file @
841a8987
...
...
@@ -138,7 +138,8 @@ void OverallConfig::GetTaskType(const std::unordered_map<std::string, std::strin
void
OverallConfig
::
CheckParamConflict
()
{
// check if objective_type, metric_type, and num_class match
bool
objective_type_multiclass
=
(
objective_type
==
std
::
string
(
"multiclass"
));
bool
objective_type_multiclass
=
(
objective_type
==
std
::
string
(
"multiclass"
)
||
objective_type
==
std
::
string
(
"multiclassova"
));
int
num_class_check
=
boosting_config
.
num_class
;
if
(
objective_type_multiclass
)
{
if
(
num_class_check
<=
1
)
{
...
...
@@ -151,11 +152,19 @@ void OverallConfig::CheckParamConflict() {
}
if
(
boosting_config
.
is_provide_training_metric
||
!
io_config
.
valid_data_filenames
.
empty
())
{
for
(
std
::
string
metric_type
:
metric_types
)
{
bool
metric_type_multiclass
=
(
metric_type
==
std
::
string
(
"multi_logloss"
)
||
metric_type
==
std
::
string
(
"multi_error"
));
bool
metric_type_multiclass
=
(
metric_type
==
std
::
string
(
"multi_logloss"
)
||
metric_type
==
std
::
string
(
"multi_error"
)
||
metric_type
==
std
::
string
(
"multi_loglossova"
));
if
((
objective_type_multiclass
&&
!
metric_type_multiclass
)
||
(
!
objective_type_multiclass
&&
metric_type_multiclass
))
{
Log
::
Fatal
(
"Objective and metrics don't match"
);
}
if
(
objective_type
==
std
::
string
(
"multiclassova"
)
&&
metric_type
==
std
::
string
(
"multi_logloss"
))
{
Log
::
Fatal
(
"Wrong metric. For Multi-class with OVA, you should use multi_loglossova metric."
);
}
if
(
objective_type
==
std
::
string
(
"multiclass"
)
&&
metric_type
==
std
::
string
(
"multi_loglossova"
))
{
Log
::
Fatal
(
"Wrong metric. For Multi-class with softmax, you should use multi_logloss metric."
);
}
}
}
...
...
src/metric/metric.cpp
View file @
841a8987
...
...
@@ -29,7 +29,9 @@ Metric* Metric::CreateMetric(const std::string& type, const MetricConfig& config
}
else
if
(
type
==
std
::
string
(
"map"
))
{
return
new
MapMetric
(
config
);
}
else
if
(
type
==
std
::
string
(
"multi_logloss"
))
{
return
new
MultiLoglossMetric
(
config
);
return
new
MultiSoftmaxLoglossMetric
(
config
);
}
else
if
(
type
==
std
::
string
(
"multi_loglossova"
))
{
return
new
MultiOVALoglossMetric
(
config
);
}
else
if
(
type
==
std
::
string
(
"multi_error"
))
{
return
new
MultiErrorMetric
(
config
);
}
...
...
src/metric/multiclass_metric.hpp
View file @
841a8987
...
...
@@ -79,8 +79,6 @@ public:
}
private:
/*! \brief Output frequency */
int
output_freq_
;
/*! \brief Number of data */
data_size_t
num_data_
;
/*! \brief Number of classes */
...
...
@@ -116,9 +114,9 @@ public:
};
/*! \brief Logloss for multiclass task */
class
MultiLoglossMetric
:
public
MulticlassMetric
<
MultiLoglossMetric
>
{
class
Multi
Softmax
LoglossMetric
:
public
MulticlassMetric
<
Multi
Softmax
LoglossMetric
>
{
public:
explicit
MultiLoglossMetric
(
const
MetricConfig
&
config
)
:
MulticlassMetric
<
MultiLoglossMetric
>
(
config
)
{}
explicit
Multi
Softmax
LoglossMetric
(
const
MetricConfig
&
config
)
:
MulticlassMetric
<
Multi
Softmax
LoglossMetric
>
(
config
)
{}
inline
static
double
LossOnPoint
(
float
label
,
std
::
vector
<
double
>&
score
)
{
size_t
k
=
static_cast
<
size_t
>
(
label
);
...
...
@@ -135,5 +133,84 @@ public:
}
};
class
MultiOVALoglossMetric
:
public
Metric
{
public:
explicit
MultiOVALoglossMetric
(
const
MetricConfig
&
config
)
{
num_class_
=
config
.
num_class
;
sigmoid_
=
config
.
sigmoid
;
}
virtual
~
MultiOVALoglossMetric
()
{
}
void
Init
(
const
Metadata
&
metadata
,
data_size_t
num_data
)
override
{
name_
.
emplace_back
(
"multi_loglossova"
);
num_data_
=
num_data
;
// get label
label_
=
metadata
.
label
();
// get weights
weights_
=
metadata
.
weights
();
if
(
weights_
==
nullptr
)
{
sum_weights_
=
static_cast
<
double
>
(
num_data_
);
}
else
{
sum_weights_
=
0.0
f
;
for
(
data_size_t
i
=
0
;
i
<
num_data_
;
++
i
)
{
sum_weights_
+=
weights_
[
i
];
}
}
}
const
std
::
vector
<
std
::
string
>&
GetName
()
const
override
{
return
name_
;
}
double
factor_to_bigger_better
()
const
override
{
return
-
1.0
f
;
}
std
::
vector
<
double
>
Eval
(
const
double
*
score
)
const
override
{
double
sum_loss
=
0.0
;
if
(
weights_
==
nullptr
)
{
#pragma omp parallel for schedule(static) reduction(+:sum_loss)
for
(
data_size_t
i
=
0
;
i
<
num_data_
;
++
i
)
{
std
::
vector
<
double
>
rec
(
num_class_
);
size_t
idx
=
static_cast
<
size_t
>
(
num_data_
)
*
static_cast
<
int
>
(
label_
[
i
])
+
i
;
double
prob
=
1.0
f
/
(
1.0
f
+
std
::
exp
(
-
sigmoid_
*
score
[
idx
]));
if
(
prob
<
kEpsilon
)
{
prob
=
kEpsilon
;
}
// add loss
sum_loss
+=
-
std
::
log
(
prob
);
}
}
else
{
#pragma omp parallel for schedule(static) reduction(+:sum_loss)
for
(
data_size_t
i
=
0
;
i
<
num_data_
;
++
i
)
{
size_t
idx
=
static_cast
<
size_t
>
(
num_data_
)
*
static_cast
<
int
>
(
label_
[
i
])
+
i
;
double
prob
=
1.0
f
/
(
1.0
f
+
std
::
exp
(
-
sigmoid_
*
score
[
idx
]));
if
(
prob
<
kEpsilon
)
{
prob
=
kEpsilon
;
}
// add loss
sum_loss
+=
-
std
::
log
(
prob
)
*
weights_
[
i
];
}
}
double
loss
=
sum_loss
/
sum_weights_
;
return
std
::
vector
<
double
>
(
1
,
loss
);
}
private:
/*! \brief Number of data */
data_size_t
num_data_
;
/*! \brief Number of classes */
int
num_class_
;
/*! \brief Pointer of label */
const
float
*
label_
;
/*! \brief Pointer of weighs */
const
float
*
weights_
;
/*! \brief Sum weights */
double
sum_weights_
;
/*! \brief Name of this test set */
std
::
vector
<
std
::
string
>
name_
;
double
sigmoid_
;
};
}
// namespace LightGBM
#endif // LightGBM_METRIC_MULTICLASS_METRIC_HPP_
src/objective/binary_objective.hpp
View file @
841a8987
...
...
@@ -12,15 +12,21 @@ namespace LightGBM {
*/
class
BinaryLogloss
:
public
ObjectiveFunction
{
public:
explicit
BinaryLogloss
(
const
ObjectiveConfig
&
config
)
{
explicit
BinaryLogloss
(
const
ObjectiveConfig
&
config
,
std
::
function
<
bool
(
float
)
>
is_pos
=
nullptr
)
{
is_unbalance_
=
config
.
is_unbalance
;
sigmoid_
=
static_cast
<
double
>
(
config
.
sigmoid
);
if
(
sigmoid_
<=
0.0
)
{
Log
::
Fatal
(
"Sigmoid parameter %f should be greater than zero"
,
sigmoid_
);
}
scale_pos_weight_
=
static_cast
<
double
>
(
config
.
scale_pos_weight
);
is_pos_
=
is_pos
;
if
(
is_pos_
==
nullptr
)
{
is_pos_
=
[](
float
label
)
{
return
label
>
0
;
};
}
}
~
BinaryLogloss
()
{}
void
Init
(
const
Metadata
&
metadata
,
data_size_t
num_data
)
override
{
num_data_
=
num_data
;
label_
=
metadata
.
label
();
...
...
@@ -30,7 +36,7 @@ public:
// count for positive and negative samples
#pragma omp parallel for schedule(static) reduction(+:cnt_positive, cnt_negative)
for
(
data_size_t
i
=
0
;
i
<
num_data_
;
++
i
)
{
if
(
label_
[
i
]
>
0
)
{
if
(
is_pos_
(
label_
[
i
]
)
)
{
++
cnt_positive
;
}
else
{
++
cnt_negative
;
...
...
@@ -61,7 +67,7 @@ public:
#pragma omp parallel for schedule(static)
for
(
data_size_t
i
=
0
;
i
<
num_data_
;
++
i
)
{
// get label and label weights
const
int
is_pos
=
label_
[
i
]
>
0
;
const
int
is_pos
=
is_pos_
(
label_
[
i
]
)
;
const
int
label
=
label_val_
[
is_pos
];
const
double
label_weight
=
label_weights_
[
is_pos
];
// calculate gradients and hessians
...
...
@@ -74,7 +80,7 @@ public:
#pragma omp parallel for schedule(static)
for
(
data_size_t
i
=
0
;
i
<
num_data_
;
++
i
)
{
// get label and label weights
const
int
is_pos
=
label_
[
i
]
>
0
;
const
int
is_pos
=
is_pos_
(
label_
[
i
]
)
;
const
int
label
=
label_val_
[
is_pos
];
const
double
label_weight
=
label_weights_
[
is_pos
];
// calculate gradients and hessians
...
...
@@ -106,6 +112,7 @@ private:
/*! \brief Weights for data */
const
float
*
weights_
;
double
scale_pos_weight_
;
std
::
function
<
bool
(
float
)
>
is_pos_
;
};
}
// namespace LightGBM
...
...
src/objective/multiclass_objective.hpp
View file @
841a8987
...
...
@@ -5,19 +5,22 @@
#include <cstring>
#include <cmath>
#include <vector>
#include "binary_objective.hpp"
namespace
LightGBM
{
/*!
* \brief Objective function for multiclass classification
* \brief Objective function for multiclass classification
, use softmax as objective functions
*/
class
Multiclass
Logloss
:
public
ObjectiveFunction
{
class
Multiclass
Softmax
:
public
ObjectiveFunction
{
public:
explicit
Multiclass
Logloss
(
const
ObjectiveConfig
&
config
)
{
explicit
Multiclass
Softmax
(
const
ObjectiveConfig
&
config
)
{
num_class_
=
config
.
num_class
;
is_unbalance_
=
config
.
is_unbalance
;
}
~
MulticlassLogloss
()
{
~
MulticlassSoftmax
()
{
}
void
Init
(
const
Metadata
&
metadata
,
data_size_t
num_data
)
override
{
...
...
@@ -32,18 +35,6 @@ public:
Log
::
Fatal
(
"Label must be in [0, %d), but found %d in label"
,
num_class_
,
label_int_
[
i
]);
}
}
label_pos_weights_
=
std
::
vector
<
float
>
(
num_class_
,
1
);
if
(
is_unbalance_
)
{
std
::
vector
<
int
>
cnts
(
num_class_
,
0
);
for
(
int
i
=
0
;
i
<
num_data_
;
++
i
)
{
++
cnts
[
label_int_
[
i
]];
}
for
(
int
i
=
0
;
i
<
num_class_
;
++
i
)
{
int
cnt_cur
=
cnts
[
i
];
int
cnt_other
=
(
num_data_
-
cnts
[
i
]);
label_pos_weights_
[
i
]
=
static_cast
<
float
>
(
cnt_other
)
/
cnt_cur
;
}
}
}
void
GetGradients
(
const
double
*
score
,
score_t
*
gradients
,
score_t
*
hessians
)
const
override
{
...
...
@@ -52,7 +43,7 @@ public:
#pragma omp parallel for schedule(static) private(rec)
for
(
data_size_t
i
=
0
;
i
<
num_data_
;
++
i
)
{
rec
.
resize
(
num_class_
);
for
(
int
k
=
0
;
k
<
num_class_
;
++
k
){
for
(
int
k
=
0
;
k
<
num_class_
;
++
k
)
{
size_t
idx
=
static_cast
<
size_t
>
(
num_data_
)
*
k
+
i
;
rec
[
k
]
=
static_cast
<
double
>
(
score
[
idx
]);
}
...
...
@@ -61,12 +52,11 @@ public:
auto
p
=
rec
[
k
];
size_t
idx
=
static_cast
<
size_t
>
(
num_data_
)
*
k
+
i
;
if
(
label_int_
[
i
]
==
k
)
{
gradients
[
idx
]
=
static_cast
<
score_t
>
(
p
-
1.0
f
)
*
label_pos_weights_
[
k
];
hessians
[
idx
]
=
static_cast
<
score_t
>
(
p
*
(
1.0
f
-
p
))
*
label_pos_weights_
[
k
];
gradients
[
idx
]
=
static_cast
<
score_t
>
(
p
-
1.0
f
);
}
else
{
gradients
[
idx
]
=
static_cast
<
score_t
>
(
p
);
hessians
[
idx
]
=
static_cast
<
score_t
>
(
p
*
(
1.0
f
-
p
));
}
hessians
[
idx
]
=
static_cast
<
score_t
>
(
p
*
(
1.0
f
-
p
));
}
}
}
else
{
...
...
@@ -74,7 +64,7 @@ public:
#pragma omp parallel for schedule(static) private(rec)
for
(
data_size_t
i
=
0
;
i
<
num_data_
;
++
i
)
{
rec
.
resize
(
num_class_
);
for
(
int
k
=
0
;
k
<
num_class_
;
++
k
){
for
(
int
k
=
0
;
k
<
num_class_
;
++
k
)
{
size_t
idx
=
static_cast
<
size_t
>
(
num_data_
)
*
k
+
i
;
rec
[
k
]
=
static_cast
<
double
>
(
score
[
idx
]);
}
...
...
@@ -83,13 +73,11 @@ public:
auto
p
=
rec
[
k
];
size_t
idx
=
static_cast
<
size_t
>
(
num_data_
)
*
k
+
i
;
if
(
label_int_
[
i
]
==
k
)
{
gradients
[
idx
]
=
static_cast
<
score_t
>
((
p
-
1.0
f
)
*
weights_
[
i
])
*
label_pos_weights_
[
k
];
hessians
[
idx
]
=
static_cast
<
score_t
>
(
p
*
(
1.0
f
-
p
)
*
weights_
[
i
])
*
label_pos_weights_
[
k
];
gradients
[
idx
]
=
static_cast
<
score_t
>
((
p
-
1.0
f
)
*
weights_
[
i
]);
}
else
{
gradients
[
idx
]
=
static_cast
<
score_t
>
(
p
*
weights_
[
i
]);
hessians
[
idx
]
=
static_cast
<
score_t
>
(
p
*
(
1.0
f
-
p
)
*
weights_
[
i
]);
}
hessians
[
idx
]
=
static_cast
<
score_t
>
(
p
*
(
1.0
f
-
p
)
*
weights_
[
i
]);
}
}
}
...
...
@@ -110,9 +98,49 @@ private:
std
::
vector
<
int
>
label_int_
;
/*! \brief Weights for data */
const
float
*
weights_
;
/*! \brief Weights for label */
std
::
vector
<
float
>
label_pos_weights_
;
bool
is_unbalance_
;
};
/*!
* \brief Objective function for multiclass classification, use one-vs-all binary objective function
*/
class
MulticlassOVA
:
public
ObjectiveFunction
{
public:
explicit
MulticlassOVA
(
const
ObjectiveConfig
&
config
)
{
num_class_
=
config
.
num_class
;
for
(
int
i
=
0
;
i
<
num_class_
;
++
i
)
{
binary_loss_
.
emplace_back
(
new
BinaryLogloss
(
config
,
[
i
](
float
label
)
{
return
static_cast
<
int
>
(
label
)
==
i
;
}));
}
}
~
MulticlassOVA
()
{
}
void
Init
(
const
Metadata
&
metadata
,
data_size_t
num_data
)
override
{
num_data_
=
num_data
;
for
(
int
i
=
0
;
i
<
num_class_
;
++
i
)
{
binary_loss_
[
i
]
->
Init
(
metadata
,
num_data
);
}
}
void
GetGradients
(
const
double
*
score
,
score_t
*
gradients
,
score_t
*
hessians
)
const
override
{
for
(
int
i
=
0
;
i
<
num_class_
;
++
i
)
{
int64_t
bias
=
static_cast
<
int64_t
>
(
num_data_
)
*
i
;
binary_loss_
[
i
]
->
GetGradients
(
score
+
bias
,
gradients
+
bias
,
hessians
+
bias
);
}
}
const
char
*
GetName
()
const
override
{
return
"multiclassova"
;
}
private:
/*! \brief Number of data */
data_size_t
num_data_
;
/*! \brief Number of classes */
int
num_class_
;
std
::
vector
<
std
::
unique_ptr
<
BinaryLogloss
>>
binary_loss_
;
};
}
// namespace LightGBM
...
...
src/objective/objective_function.cpp
View file @
841a8987
...
...
@@ -23,7 +23,9 @@ ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string&
}
else
if
(
type
==
std
::
string
(
"lambdarank"
))
{
return
new
LambdarankNDCG
(
config
);
}
else
if
(
type
==
std
::
string
(
"multiclass"
))
{
return
new
MulticlassLogloss
(
config
);
return
new
MulticlassSoftmax
(
config
);
}
else
if
(
type
==
std
::
string
(
"multiclassova"
))
{
return
new
MulticlassOVA
(
config
);
}
return
nullptr
;
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment