Unverified commit 417ba192 authored by James Lamb, committed by GitHub

[docs] Clarify the fact that predict() on a file does not support saved Datasets (fixes #4034) (#4545)

* documentation changes

* add list of supported formats to error message

* add unit tests

* Apply suggestions from code review
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* update per review comments

* make references consistent
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
parent 11d7608f
@@ -682,7 +682,8 @@ Booster <- R6::R6Class(
 #' @title Predict method for LightGBM model
 #' @description Predicted values based on class \code{lgb.Booster}
 #' @param object Object of class \code{lgb.Booster}
-#' @param data a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename
+#' @param data a \code{matrix} object, a \code{dgCMatrix} object or
+#' a character representing a path to a text file (CSV, TSV, or LibSVM)
 #' @param start_iteration int or None, optional (default=None)
 #' Start index of the iteration to predict.
 #' If None or <= 0, starts from the first iteration.
...
@@ -710,7 +710,9 @@ Dataset <- R6::R6Class(
 #' @title Construct \code{lgb.Dataset} object
 #' @description Construct \code{lgb.Dataset} object from dense matrix, sparse matrix
 #' or local file (that was created previously by saving an \code{lgb.Dataset}).
-#' @param data a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename
+#' @param data a \code{matrix} object, a \code{dgCMatrix} object,
+#' a character representing a path to a text file (CSV, TSV, or LibSVM),
+#' or a character representing a path to a binary \code{lgb.Dataset} file
 #' @param params a list of parameters. See
 #' \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html#dataset-parameters}{
 #' The "Dataset Parameters" section of the documentation} for a list of parameters
@@ -774,7 +776,9 @@ lgb.Dataset <- function(data,
 #' @title Construct validation data
 #' @description Construct validation data according to training data
 #' @param dataset \code{lgb.Dataset} object, training data
-#' @param data a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename
+#' @param data a \code{matrix} object, a \code{dgCMatrix} object,
+#' a character representing a path to a text file (CSV, TSV, or LibSVM),
+#' or a character representing a path to a binary \code{Dataset} file
 #' @param info a list of information of the \code{lgb.Dataset} object
 #' @param ... other information to pass to \code{info}.
 #'
...
@@ -16,7 +16,9 @@ lgb.Dataset(
 )
 }
 \arguments{
-\item{data}{a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename}
+\item{data}{a \code{matrix} object, a \code{dgCMatrix} object,
+a character representing a path to a text file (CSV, TSV, or LibSVM),
+or a character representing a path to a binary \code{lgb.Dataset} file}

 \item{params}{a list of parameters. See
 \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html#dataset-parameters}{
...
@@ -9,7 +9,9 @@ lgb.Dataset.create.valid(dataset, data, info = list(), ...)
 \arguments{
 \item{dataset}{\code{lgb.Dataset} object, training data}

-\item{data}{a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename}
+\item{data}{a \code{matrix} object, a \code{dgCMatrix} object,
+a character representing a path to a text file (CSV, TSV, or LibSVM),
+or a character representing a path to a binary \code{Dataset} file}

 \item{info}{a list of information of the \code{lgb.Dataset} object}
...
@@ -20,7 +20,8 @@
 \arguments{
 \item{object}{Object of class \code{lgb.Booster}}

-\item{data}{a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename}
+\item{data}{a \code{matrix} object, a \code{dgCMatrix} object or
+a character representing a path to a text file (CSV, TSV, or LibSVM)}

 \item{start_iteration}{int or None, optional (default=None)
 Start index of the iteration to predict.
...
 context("Booster")

+TOLERANCE <- 1e-6
+
 test_that("Booster$finalize() should not fail", {
   X <- as.matrix(as.integer(iris[, "Species"]), ncol = 1L)
   y <- iris[["Sepal.Length"]]
...
@@ -419,6 +421,54 @@ test_that("Creating a Booster from a Dataset with an existing predictor should w
   expect_equal(bst_from_ds$current_iter(), nrounds)
 })

+test_that("Booster$eval() should work on a Dataset stored in a binary file", {
+  set.seed(708L)
+
+  data(agaricus.train, package = "lightgbm")
+  train <- agaricus.train
+  dtrain <- lgb.Dataset(train$data, label = train$label)
+
+  bst <- lgb.train(
+    params = list(
+      objective = "regression"
+      , metric = "l2"
+      , num_leaves = 4L
+    )
+    , data = dtrain
+    , nrounds = 2L
+  )
+
+  data(agaricus.test, package = "lightgbm")
+  test <- agaricus.test
+  dtest <- lgb.Dataset.create.valid(
+    dataset = dtrain
+    , data = test$data
+    , label = test$label
+  )
+  dtest$construct()
+
+  eval_in_mem <- bst$eval(
+    data = dtest
+    , name = "test"
+  )
+
+  test_file <- tempfile(pattern = "lgb.Dataset_")
+  lgb.Dataset.save(
+    dataset = dtest
+    , fname = test_file
+  )
+  rm(dtest)
+
+  eval_from_file <- bst$eval(
+    data = lgb.Dataset(
+      data = test_file
+    )$construct()
+    , name = "test"
+  )
+
+  expect_true(abs(eval_in_mem[[1L]][["value"]] - 0.1744423) < TOLERANCE)
+
+  expect_identical(eval_in_mem, eval_from_file)
+})
+
 test_that("Booster$rollback_one_iter() should work as expected", {
   set.seed(708L)
   data(agaricus.train, package = "lightgbm")
...
@@ -33,7 +33,7 @@ Data Interface

 The LightGBM Python module can load data from:

-- LibSVM (zero-based) / TSV / CSV / TXT format file
+- LibSVM (zero-based) / TSV / CSV format text file

 - NumPy 2D array(s), pandas DataFrame, H2O DataTable's Frame, SciPy sparse matrix
...
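As a rough illustration of the text formats this doc change names (not code from the patch; the helper names and sample data are made up), here is what the same three rows of labeled data look like in CSV and in zero-based LibSVM, matching LightGBM's "LibSVM (zero-based)" convention:

```python
# Sketch: serialize (label, features) rows as CSV and as zero-based LibSVM.
rows = [
    (1, [0.5, 0.0, 1.2]),
    (0, [0.0, 3.1, 0.0]),
    (1, [2.0, 0.0, 0.0]),
]

def to_csv(rows):
    # label first, then dense feature values, comma-separated
    return "\n".join(
        ",".join(str(v) for v in (label, *feats)) for label, feats in rows
    )

def to_libsvm(rows):
    # label followed by index:value pairs for non-zero features only;
    # feature indices start at 0 (zero-based)
    lines = []
    for label, feats in rows:
        pairs = [f"{i}:{v}" for i, v in enumerate(feats) if v != 0.0]
        lines.append(" ".join([str(label)] + pairs))
    return "\n".join(lines)
```

A file written with either helper would be readable as a text data source; a binary file produced by `Dataset.save_binary()` is a different, LightGBM-specific format.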
@@ -744,7 +744,7 @@ class _InnerPredictor:
     ----------
     data : string, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame or scipy.sparse
         Data source for prediction.
-        When data type is string or pathlib.Path, it represents the path of txt file.
+        When data type is string or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM).
     start_iteration : int, optional (default=0)
         Start index of the iteration to predict.
     num_iteration : int, optional (default=-1)
@@ -1132,7 +1132,7 @@ class Dataset:
     ----------
     data : string, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequences or list of numpy arrays
         Data source of Dataset.
-        If string or pathlib.Path, it represents the path to txt file.
+        If string or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM) or a LightGBM Dataset binary file.
     label : list, numpy 1-D array, pandas Series / one-column DataFrame or None, optional (default=None)
         Label of the data.
     reference : Dataset or None, optional (default=None)
@@ -1776,7 +1776,7 @@ class Dataset:
     ----------
     data : string, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequences or list of numpy arrays
         Data source of Dataset.
-        If string or pathlib.Path, it represents the path to txt file.
+        If string or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM) or a LightGBM Dataset binary file.
     label : list, numpy 1-D array, pandas Series / one-column DataFrame or None, optional (default=None)
         Label of the data.
     weight : list, numpy 1-D array, pandas Series or None, optional (default=None)
@@ -3414,7 +3414,7 @@ class Booster:
     ----------
     data : string, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame or scipy.sparse
         Data source for prediction.
-        If string or pathlib.Path, it represents the path to txt file.
+        If string or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM).
     start_iteration : int, optional (default=0)
         Start index of the iteration to predict.
         If <= 0, starts from the first iteration.
@@ -3469,7 +3469,7 @@ class Booster:
     ----------
     data : string, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame or scipy.sparse
         Data source for refit.
-        If string or pathlib.Path, it represents the path to txt file.
+        If string or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM).
     label : list, numpy 1-D array or pandas Series / one-column DataFrame
         Label for refit.
     decay_rate : float, optional (default=0.9)
...
@@ -236,7 +236,7 @@ Parser* Parser::CreateParser(const char* filename, bool header, int num_features
   int num_col = 0;
   DataType type = GetDataType(filename, header, lines, &num_col);
   if (type == DataType::INVALID) {
-    Log::Fatal("Unknown format of training data.");
+    Log::Fatal("Unknown format of training data. Only CSV, TSV, and LibSVM (zero-based) formatted text files are supported.");
   }
   std::unique_ptr<Parser> ret;
   int output_label_index = -1;
...
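The clarified `Log::Fatal` message above fires when `GetDataType` cannot recognize the file. A hypothetical Python sketch of that detection idea (the function name and heuristics here are illustrative, not LightGBM's actual implementation) might look like:

```python
# Sketch: guess whether a sample line is LibSVM, TSV, or CSV; anything
# else is "invalid" and would trigger the fatal error message above.
def guess_format(line: str) -> str:
    tokens = line.strip().split()
    # LibSVM: "label idx:value idx:value ..." with colon-separated pairs
    if len(tokens) > 1 and all(":" in tok for tok in tokens[1:]):
        return "libsvm"
    if "\t" in line:
        return "tsv"
    if "," in line:
        return "csv"
    return "invalid"
```

The real parser inspects several lines and column counts before deciding; this sketch only shows why a saved binary `Dataset` file, which matches none of the text patterns, ends up at the `DataType::INVALID` branch.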