Commit ace9c99c authored by Maximilian Eber's avatar Maximilian Eber Committed by Laurae
Browse files

[R] Fix multiclass demo (#1940)

* Fix multiclass custom objective demo

* Use option not to boost from average instead of setting init score explicitly

* Reference #1846 when turning off boost_from_average

* Add trailing whitespace
parent e6a32c88
...@@ -8,18 +8,21 @@ data(iris) ...@@ -8,18 +8,21 @@ data(iris)
# For instance: 0, 1, 2, 3, 4, 5... # For instance: 0, 1, 2, 3, 4, 5...
iris$Species <- as.numeric(as.factor(iris$Species)) - 1 iris$Species <- as.numeric(as.factor(iris$Species)) - 1
# We cut the data set into 80% train and 20% validation # Create imbalanced training data (20, 30, 40 examples for classes 0, 1, 2)
train <- as.matrix(iris[c(1:20, 51:80, 101:140), ])
# The 10 last samples of each class are for validation # The 10 last samples of each class are for validation
train <- as.matrix(iris[c(1:40, 51:90, 101:140), ])
test <- as.matrix(iris[c(41:50, 91:100, 141:150), ]) test <- as.matrix(iris[c(41:50, 91:100, 141:150), ])
dtrain <- lgb.Dataset(data = train[, 1:4], label = train[, 5]) dtrain <- lgb.Dataset(data = train[, 1:4], label = train[, 5])
dtest <- lgb.Dataset.create.valid(dtrain, data = test[, 1:4], label = test[, 5]) dtest <- lgb.Dataset.create.valid(dtrain, data = test[, 1:4], label = test[, 5])
valids <- list(test = dtest) valids <- list(train = dtrain, test = dtest)
# Method 1 of training with built-in multiclass objective # Method 1 of training with built-in multiclass objective
# Note: need to turn off boost from average to match custom objective
# (https://github.com/Microsoft/LightGBM/issues/1846)
model_builtin <- lgb.train(list(), model_builtin <- lgb.train(list(),
dtrain, dtrain,
boost_from_average = FALSE,
100, 100,
valids, valids,
min_data = 1, min_data = 1,
...@@ -29,7 +32,8 @@ model_builtin <- lgb.train(list(), ...@@ -29,7 +32,8 @@ model_builtin <- lgb.train(list(),
metric = "multi_logloss", metric = "multi_logloss",
num_class = 3) num_class = 3)
preds_builtin <- predict(model_builtin, test[, 1:4], rawscore = TRUE) preds_builtin <- predict(model_builtin, test[, 1:4], rawscore = TRUE, reshape = TRUE)
probs_builtin <- exp(preds_builtin) / rowSums(exp(preds_builtin))
# Method 2 of training with custom objective function # Method 2 of training with custom objective function
...@@ -64,7 +68,6 @@ custom_multiclass_metric = function(preds, dtrain) { ...@@ -64,7 +68,6 @@ custom_multiclass_metric = function(preds, dtrain) {
return(list(name = "error", return(list(name = "error",
value = -mean(log(prob[cbind(1:length(labels), labels + 1)])), value = -mean(log(prob[cbind(1:length(labels), labels + 1)])),
higher_better = FALSE)) higher_better = FALSE))
} }
model_custom <- lgb.train(list(), model_custom <- lgb.train(list(),
...@@ -78,8 +81,10 @@ model_custom <- lgb.train(list(), ...@@ -78,8 +81,10 @@ model_custom <- lgb.train(list(),
eval = custom_multiclass_metric, eval = custom_multiclass_metric,
num_class = 3) num_class = 3)
preds_custom <- predict(model_custom, test[, 1:4], rawscore = TRUE) preds_custom <- predict(model_custom, test[, 1:4], rawscore = TRUE, reshape = TRUE)
probs_custom <- exp(preds_custom) / rowSums(exp(preds_custom))
# compare predictions # compare predictions
identical(preds_builtin, preds_custom) stopifnot(identical(probs_builtin, probs_custom))
stopifnot(identical(preds_builtin, preds_custom))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment