Unverified commit f8602826, authored by david-cortes, committed by GitHub
Browse files

[R-package] Keep row names in output from `predict` (#4977)



* keep row names in prediction output

* fix test

* fix test

* fix test

* more tests for row names

* redo badly solved merge conflict

* comments

* update docs

* missing closing parentheses

* explicitly list assertions in each test

* move assertions back into shared function
Co-authored-by: James Lamb <jaylamb20@gmail.com>
parent d811f3a0
......@@ -219,6 +219,15 @@ Predictor <- R6::R6Class(
preds <- matrix(preds, ncol = npred_per_case, byrow = TRUE)
}
# Keep row names if possible
if (NROW(row.names(data)) && NROW(data) == NROW(preds)) {
if (is.null(dim(preds))) {
names(preds) <- row.names(data)
} else {
row.names(preds) <- row.names(data)
}
}
return(preds)
}
......
......@@ -2,6 +2,8 @@ VERBOSITY <- as.integer(
Sys.getenv("LIGHTGBM_TEST_VERBOSITY", "-1")
)
library(Matrix)
test_that("Predictor$finalize() should not fail", {
X <- as.matrix(as.integer(iris[, "Species"]), ncol = 1L)
y <- iris[["Sepal.Length"]]
......@@ -112,6 +114,115 @@ test_that("start_iteration works correctly", {
expect_equal(pred_leaf1, pred_leaf2)
})
# Assert that the prediction output `pred` carries the row names of `X`.
#
# The names are read with names() when is.vector() reports `pred` as a plain
# vector and with row.names() otherwise. They must be non-NULL, a vector,
# non-empty, and equal to row.names(X).
.expect_has_row_names <- function(pred, X) {
  pred_names <- if (is.vector(pred)) names(pred) else row.names(pred)
  expect_false(is.null(pred_names))
  expect_true(is.vector(pred_names))
  expect_true(length(pred_names) > 0L)
  expect_equal(row.names(X), pred_names)
}
# Assert that the prediction output `pred` carries no row names.
#
# Anything that is.vector() does not report as a plain vector is checked
# through row.names(); plain vectors are checked through names().
.expect_doesnt_have_row_names <- function(pred) {
  if (!is.vector(pred)) {
    return(expect_null(row.names(pred)))
  }
  expect_null(names(pred))
}
# Run every row-name expectation for the model `bst` on the matrix `X`.
#
# `X` is checked as passed in and again after conversion to "CsparseMatrix".
# For each of the two representations:
#   * predictions with default arguments, `rawscore = TRUE`,
#     `predleaf = TRUE` and `predcontrib = TRUE` must carry the row names of
#     the input;
#   * default predictions on a copy with row names set to NULL must carry no
#     names.
.check_all_row_name_expectations <- function(bst, X) {

  # every prediction variant must keep the row names of `newdata`
  .check_names_kept <- function(newdata) {
    .expect_has_row_names(predict(bst, newdata), newdata)
    .expect_has_row_names(predict(bst, newdata, rawscore = TRUE), newdata)
    .expect_has_row_names(predict(bst, newdata, predleaf = TRUE), newdata)
    .expect_has_row_names(predict(bst, newdata, predcontrib = TRUE), newdata)
  }

  # `newdata` is a local copy here, so dropping its row names leaves the
  # caller's object untouched
  .check_names_absent <- function(newdata) {
    row.names(newdata) <- NULL
    .expect_doesnt_have_row_names(predict(bst, newdata))
  }

  # dense matrix: with row names, then without
  .check_names_kept(X)
  .check_names_absent(X)

  # sparse matrix: with row names, then without
  sparse_X <- as(X, "CsparseMatrix")
  .check_names_kept(sparse_X)
  .check_names_absent(sparse_X)
}
test_that("predict() keeps row names from data (regression)", {
  data("mtcars")
  # column 1 is used as the label, all remaining columns as features.
  # No row names are assigned in this test, so the checks below rely on the
  # row names that as.matrix() carries over from mtcars.
  features <- as.matrix(mtcars[, -1L])
  target <- as.numeric(mtcars[, 1L])
  dataset_params <- list(
    max_bins = 5L
    , min_data_in_bin = 1L
  )
  train_set <- lgb.Dataset(
    features
    , label = target
    , params = dataset_params
  )
  model <- lgb.train(
    data = train_set
    , obj = "regression"
    , nrounds = 5L
    , verbose = VERBOSITY
    , params = list(min_data_in_leaf = 1L)
  )
  .check_all_row_name_expectations(model, features)
})
test_that("predict() keeps row names from data (binary classification)", {
  data(agaricus.train, package = "lightgbm")
  X <- as.matrix(agaricus.train$data)
  y <- agaricus.train$label
  # give every row an explicit name ("rname1", "rname2", ...) so the
  # expectations have known row names to compare predictions against
  row.names(X) <- paste0("rname", seq_len(nrow(X)))
  dtrain <- lgb.Dataset(X, label = y, params = list(max_bins = 5L))
  bst <- lgb.train(
    data = dtrain
    , obj = "binary"
    , nrounds = 5L
    , verbose = VERBOSITY
  )
  # dense + sparse input, with and without row names
  .check_all_row_name_expectations(bst, X)
})
test_that("predict() keeps row names from data (multi-class classification)", {
  data(iris)
  # factor codes start at 1; subtracting 1 yields labels 0, 1, 2
  # (presumably the zero-based class labels LightGBM expects; see its docs)
  y <- as.numeric(iris$Species) - 1.0
  X <- as.matrix(iris[, names(iris) != "Species"])
  # give every row an explicit name ("rname1", "rname2", ...) so the
  # expectations have known row names to compare predictions against
  row.names(X) <- paste0("rname", seq_len(nrow(X)))
  dtrain <- lgb.Dataset(X, label = y, params = list(max_bins = 5L))
  bst <- lgb.train(
    data = dtrain
    , obj = "multiclass"
    , params = list(num_class = 3L)
    , nrounds = 5L
    , verbose = VERBOSITY
  )
  # dense + sparse input, with and without row names
  .check_all_row_name_expectations(bst, X)
})
test_that("predictions for regression and binary classification are returned as vectors", {
data(mtcars)
X <- as.matrix(mtcars[, -1L])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment