tianlh / LightGBM-DCU
Commit 6522f538
Authored Feb 05, 2018 by Jan Tilly
Committed by Guolin Ke, Feb 05, 2018

Permit use of custom objective function in multiclass problem. (#1229)

Parent: 40602f9e

Changes: 2 changed files with 89 additions and 2 deletions

R-package/demo/multiclass_custom_objective.R (+85 -0)
src/io/config.cpp (+4 -2)

R-package/demo/multiclass_custom_objective.R
0 → 100644, view file @ 6522f538

require(lightgbm)

# We load the default iris dataset shipped with R
data(iris)

# We must convert factors to numeric
# They must be starting from number 0 to use multiclass
# For instance: 0, 1, 2, 3, 4, 5...
iris$Species <- as.numeric(as.factor(iris$Species)) - 1

# We cut the data set into 80% train and 20% validation
# The 10 last samples of each class are for validation
train <- as.matrix(iris[c(1:40, 51:90, 101:140), ])
test <- as.matrix(iris[c(41:50, 91:100, 141:150), ])
dtrain <- lgb.Dataset(data = train[, 1:4], label = train[, 5])
dtest <- lgb.Dataset.create.valid(dtrain, data = test[, 1:4], label = test[, 5])
valids <- list(test = dtest)

# Method 1 of training with built-in multiclass objective
model_builtin <- lgb.train(list(),
                           dtrain,
                           100,
                           valids,
                           min_data = 1,
                           learning_rate = 1,
                           early_stopping_rounds = 10,
                           objective = "multiclass",
                           metric = "multi_logloss",
                           num_class = 3)

preds_builtin <- predict(model_builtin, test[, 1:4], rawscore = TRUE)

# Method 2 of training with custom objective function

# User defined objective function, given prediction, return gradient and second order gradient
custom_multiclass_obj = function(preds, dtrain) {
    labels = getinfo(dtrain, "label")

    # preds is a matrix with rows corresponding to samples and columns corresponding to choices
    preds = matrix(preds, nrow = length(labels))

    # to prevent overflow, normalize preds by row
    preds = preds - apply(preds, 1, max)
    prob = exp(preds) / rowSums(exp(preds))

    # compute gradient
    grad = prob
    grad[cbind(1:length(labels), labels + 1)] = grad[cbind(1:length(labels), labels + 1)] - 1

    # compute hessian (approximation)
    hess = 2 * prob * (1 - prob)

    return(list(grad = grad, hess = hess))
}

# define custom metric
custom_multiclass_metric = function(preds, dtrain) {
    labels = getinfo(dtrain, "label")
    preds = matrix(preds, nrow = length(labels))
    preds = preds - apply(preds, 1, max)
    prob = exp(preds) / rowSums(exp(preds))

    return(list(name = "error",
                value = -mean(log(prob[cbind(1:length(labels), labels + 1)])),
                higher_better = FALSE))
}

model_custom <- lgb.train(list(),
                          dtrain,
                          100,
                          valids,
                          min_data = 1,
                          learning_rate = 1,
                          early_stopping_rounds = 10,
                          objective = custom_multiclass_obj,
                          eval = custom_multiclass_metric,
                          num_class = 3)

preds_custom <- predict(model_custom, test[, 1:4], rawscore = TRUE)

# compare predictions
identical(preds_builtin, preds_custom)
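
Note on the demo's math: the quantities computed in custom_multiclass_obj and custom_multiclass_metric appear to be those of the softmax cross-entropy loss. The sketch below writes them out in notation that is not part of the commit and is assumed here for illustration: s_{ik} is the raw score of sample i for class k, m_i the row maximum subtracted for numerical stability, p_{ik} the resulting probability, and y_i the label.

```latex
\begin{aligned}
p_{ik} &= \frac{\exp(s_{ik} - m_i)}{\sum_{j} \exp(s_{ij} - m_i)}, \qquad m_i = \max_{j} s_{ij}, \\
L_i &= -\log p_{i,y_i}, \\
\frac{\partial L_i}{\partial s_{ik}} &= p_{ik} - \mathbb{1}[k = y_i], \\
\frac{\partial^2 L_i}{\partial s_{ik}^2} &= p_{ik}\,(1 - p_{ik}).
\end{aligned}
```

The demo's grad matches the third line, and its metric value is the mean of L_i over samples. Its hess is 2 * prob * (1 - prob), twice the exact diagonal term in the last line, which is why the code comment calls it an approximation.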

src/io/config.cpp
View file @ 6522f538

@@ -194,9 +194,11 @@ void OverallConfig::Set(const std::unordered_map<std::string, std::string>& para
 void OverallConfig::CheckParamConflict() {
   // check if objective_type, metric_type, and num_class match
-  bool objective_type_multiclass = (objective_type == std::string("multiclass")
-                                    || objective_type == std::string("multiclassova"));
   int num_class_check = boosting_config.num_class;
+  bool objective_type_multiclass = (objective_type == std::string("multiclass")
+                                    || objective_type == std::string("multiclassova")
+                                    || (objective_type == std::string("none") && num_class_check > 1));
   if (objective_type_multiclass) {
     if (num_class_check <= 1) {
       Log::Fatal("Number of classes should be specified and greater than 1 for multiclass training");
...
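
The effect of the changed condition can be sketched in isolation. The following is a minimal standalone illustration, not LightGBM's actual code: the function names IsMulticlassTraining and CheckNumClass are hypothetical, the exception stands in for Log::Fatal, and it assumes (as the hunk suggests) that a custom objective reaches the config as the string "none". The remainder of CheckParamConflict is elided in the diff and is not reproduced here.

```cpp
#include <stdexcept>
#include <string>

// Hypothetical helper, not part of LightGBM: returns true when the
// configuration should be treated as multiclass training.
bool IsMulticlassTraining(const std::string& objective_type, int num_class) {
  return objective_type == "multiclass"
      || objective_type == "multiclassova"
      // Assumption: a custom objective is recorded as "none", so
      // num_class > 1 is the only signal that the problem is multiclass.
      || (objective_type == "none" && num_class > 1);
}

// Hypothetical stand-in for the num_class validation shown in the hunk;
// throws instead of calling Log::Fatal.
void CheckNumClass(const std::string& objective_type, int num_class) {
  if (IsMulticlassTraining(objective_type, num_class) && num_class <= 1) {
    throw std::invalid_argument(
        "Number of classes should be specified and greater than 1 "
        "for multiclass training");
  }
}
```

Under these assumptions, a custom objective with num_class = 3 is now classified as multiclass, while a custom objective with num_class = 1 is still treated as a non-multiclass problem, as before the change.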