# lgb.Dataset.R: R6 implementation of lgb.Dataset and its exported helper functions
# R6 class behind the exported lgb.Dataset() constructor. It stores the raw
# input and its settings; the native dataset handle is only created when
# construct() is called.
Dataset <- R6Class(
  "lgb.Dataset",
  cloneable = FALSE,
  public = list(

    # Free the native dataset handle, if one exists, and forget it
    finalize = function() {
      if (!lgb.is.null.handle(private$handle)) {
        cat("free dataset handle\n")
        lgb.call("LGBM_DatasetFree_R", ret = NULL, private$handle)
        private$handle <- NULL
      }
    },

    # Store the inputs; nothing native is created here.
    # Arguments passed through `...` named label / weight / init_score / group
    # are stored in `info`, every other extra argument is stored in `params`.
    initialize = function(data,
                          params              = list(),
                          reference           = NULL,
                          colnames            = NULL,
                          categorical_feature = NULL,
                          predictor           = NULL,
                          free_raw_data       = TRUE,
                          used_indices        = NULL,
                          info                = list(),
                          ...) {
      additional_params <- list(...)
      INFO_KEYS <- c('label', 'weight', 'init_score', 'group')
      for (key in names(additional_params)) {
        if (key %in% INFO_KEYS) {
          info[[key]] <- additional_params[[key]]
        } else {
          params[[key]] <- additional_params[[key]]
        }
      }

      # Validate the optional collaborators
      if (!is.null(reference)) {
        if (!lgb.check.r6.class(reference, "lgb.Dataset")) {
          stop("lgb.Dataset: Can only use ", sQuote("lgb.Dataset"), " as reference")
        }
      }
      if (!is.null(predictor)) {
        if (!lgb.check.r6.class(predictor, "lgb.Predictor")) {
          stop("lgb.Dataset: Only can use ", sQuote("lgb.Predictor"), " as predictor")
        }
      }

      private$raw_data            <- data
      private$params              <- params
      private$reference           <- reference
      private$colnames            <- colnames
      private$categorical_feature <- categorical_feature
      private$predictor           <- predictor
      private$free_raw_data       <- free_raw_data
      private$used_indices        <- used_indices
      private$info                <- info
    },

    # Create a new Dataset from `data` that uses this dataset as its reference
    # and reuses its params, colnames, categorical_feature, predictor and
    # free_raw_data settings.
    create_valid = function(data, info = list(),  ...) {
      ret <- Dataset$new(
        data,
        private$params,
        self,
        private$colnames,
        private$categorical_feature,
        private$predictor,
        private$free_raw_data,
        NULL,
        info,
        ...
      )
      ret
    },

    # Create the native dataset handle (no-op if it already exists), then push
    # column names and info fields to it. Returns self.
    construct = function() {
      if (!lgb.is.null.handle(private$handle)) {
        return(self)
      }

      # Get feature names
      cnames <- NULL
      if (is.matrix(private$raw_data) || is(private$raw_data, "dgCMatrix")) {
        cnames <- colnames(private$raw_data)
      }
      # set feature names if not exist
      if (is.null(private$colnames) && !is.null(cnames)) {
        private$colnames <- as.character(cnames)
      }

      # Get categorical feature index
      if (!is.null(private$categorical_feature)) {
        fname_dict <- list()
        if (!is.null(private$colnames)) {
          # One list element per column (name -> zero-based index). as.list()
          # is required: list(vec) would be a length-one list and `names<-`
          # fails when given more names than elements.
          fname_dict <- `names<-`(
            as.list(seq_along(private$colnames) - 1),
            private$colnames
          )
        }
        cate_indices <- list()
        for (key in private$categorical_feature) {
          if (is.character(key)) {
            # `[[` on a list returns NULL for an unknown name
            idx <- fname_dict[[key]]
            if (is.null(idx)) {
              stop("lgb.self.get.handle: cannot find feature name ", sQuote(key))
            }
            cate_indices <- c(cate_indices, idx)
          } else {
            # one-based indices to zero-based
            idx <- as.integer(key - 1)
            cate_indices <- c(cate_indices, idx)
          }
        }
        private$params$categorical_feature <- cate_indices
      }

      # Generate parameter str
      params_str <- lgb.params2str(private$params)

      # get handle of reference dataset
      ref_handle <- NULL
      if (!is.null(private$reference)) {
        ref_handle <- private$reference$.__enclos_env__$private$get_handle()
      }

      handle <- lgb.new.handle()
      if (is.null(private$used_indices)) {
        # not subset: build from a file name, a dense matrix or a dgCMatrix
        if (is.character(private$raw_data)) {
          handle <- lgb.call(
            "LGBM_DatasetCreateFromFile_R",
            ret = handle,
            lgb.c_str(private$raw_data),
            params_str,
            ref_handle
          )
        } else if (is.matrix(private$raw_data)) {
          handle <- lgb.call(
            "LGBM_DatasetCreateFromMat_R",
            ret = handle,
            private$raw_data,
            nrow(private$raw_data),
            ncol(private$raw_data),
            params_str,
            ref_handle
          )
        } else if (is(private$raw_data, "dgCMatrix")) {
          handle <- lgb.call(
            "LGBM_DatasetCreateFromCSC_R",
            ret = handle,
            private$raw_data@p,
            private$raw_data@i,
            private$raw_data@x,
            length(private$raw_data@p),
            length(private$raw_data@x),
            nrow(private$raw_data),
            params_str,
            ref_handle
          )
        } else {
          # class() may return several names; collapse them into one message
          stop(
            "lgb.Dataset.construct: does not support constructing from ",
            paste(sQuote(class(private$raw_data)), collapse = ", ")
          )
        }
      } else {
        # construct subset
        if (is.null(private$reference)) {
          stop("lgb.Dataset.construct: reference cannot be NULL for constructing data subset")
        }
        handle <- lgb.call(
          "LGBM_DatasetGetSubset_R",
          ret = handle,
          ref_handle,
          private$used_indices,
          length(private$used_indices),
          params_str
        )
      }
      class(handle) <- "lgb.Dataset.handle"
      private$handle <- handle

      # set feature names
      if (!is.null(private$colnames)) { self$set_colnames(private$colnames) }

      # load init score
      if (!is.null(private$predictor) &&
          is.null(private$used_indices)) {
        init_score <- private$predictor$predict(private$raw_data, rawscore = TRUE, reshape = TRUE)
        # no transpose needed: the prediction is column-major
        init_score <- as.vector(init_score)
        private$info$init_score <- init_score
      }

      if (isTRUE(private$free_raw_data)) { private$raw_data <- NULL }

      # set infos: forward every stored field to the handle through setinfo()
      for (i in seq_along(private$info)) {
        p <- private$info[i]
        self$setinfo(names(p), p[[1]])
      }

      if (is.null(self$getinfo("label"))) {
        stop("lgb.Dataset.construct: label should be set")
      }
      self
    },

    # c(number of rows, number of columns). Read from the handle once the
    # dataset is constructed, otherwise from a matrix / dgCMatrix raw input.
    dim = function() {
      if (!lgb.is.null.handle(private$handle)) {
        # NOTE(review): presumably lgb.call fills `ret` from native code, so
        # freshly allocated vectors are passed - confirm against lgb.call
        num_row <- as.integer(0)
        num_col <- as.integer(0)

        c(
          lgb.call("LGBM_DatasetGetNumData_R",    ret = num_row, private$handle),
          lgb.call("LGBM_DatasetGetNumFeature_R", ret = num_col, private$handle)
        )
      } else if (is.matrix(private$raw_data) || is(private$raw_data, "dgCMatrix")) {
        dim(private$raw_data)
      } else {
        stop(
          "dim: cannot get dimensions before dataset has been constructed, please call lgb.Dataset.construct explicitly"
        )
      }
    },

    # Column names. Read from the handle once the dataset is constructed
    # (and cached in private$colnames), otherwise from a matrix / dgCMatrix
    # raw input.
    get_colnames = function() {
      if (!lgb.is.null.handle(private$handle)) {
        cnames <- lgb.call.return.str("LGBM_DatasetGetFeatureNames_R", private$handle)
        # the names come back as one tab-separated string
        private$colnames <- as.character(base::strsplit(cnames, "\t")[[1]])
        private$colnames
      } else if (is.matrix(private$raw_data) || is(private$raw_data, "dgCMatrix")) {
        colnames(private$raw_data)
      } else {
        stop(
          "get_colnames: cannot get column names before dataset has been constructed, please call lgb.Dataset.construct explicitly"
        )
      }
    },

    # Store column names and, when a handle exists, send them to it as one
    # tab-separated string. NULL or empty input is ignored. Returns self.
    set_colnames = function(colnames) {
      if (is.null(colnames)) { return(self) }
      colnames <- as.character(colnames)
      if (length(colnames) == 0) { return(self) }
      private$colnames <- colnames
      if (!lgb.is.null.handle(private$handle)) {
        merged_name <- paste0(as.list(private$colnames), collapse = "\t")
        lgb.call("LGBM_DatasetSetFeatureNames_R",
                 ret = NULL,
                 private$handle,
                 lgb.c_str(merged_name))
      }
      self
    },

    # Return the info field `name` ("label", "weight", "init_score" or
    # "group"). A field that is not cached yet is fetched from the handle,
    # when there is one, and cached.
    getinfo = function(name) {
      INFONAMES <- c("label", "weight", "init_score", "group")
      if (!is.character(name) ||
          length(name) != 1   ||
          !name %in% INFONAMES) {
        stop(
          "getinfo: name must one of the following: ", paste0(sQuote(INFONAMES), collapse = ", ")
        )
      }
      if (is.null(private$info[[name]]) && !lgb.is.null.handle(private$handle)) {
        info_len <- as.integer(0)
        info_len <- lgb.call("LGBM_DatasetGetFieldSize_R",
                             ret = info_len,
                             private$handle,
                             lgb.c_str(name))
        if (info_len > 0) {
          # "group" is an integer field, the other fields are doubles
          ret <- if (name == "group") { integer(info_len) } else { rep(0.0, info_len) }
          ret <- lgb.call("LGBM_DatasetGetField_R",
                          ret = ret,
                          private$handle,
                          lgb.c_str(name))
          private$info[[name]] <- ret
        }
      }
      private$info[[name]]
    },

    # Store the info field `name` (coerced to integer for "group", numeric
    # otherwise) and, when a handle exists and the value is non-empty, send it
    # to the handle. Returns self.
    setinfo = function(name, info) {
      INFONAMES <- c("label", "weight", "init_score", "group")
      if (!is.character(name) ||
          length(name) != 1   ||
          !name %in% INFONAMES) {
        stop(
          "setinfo: name must one of the following: ", paste0(sQuote(INFONAMES), collapse = ", ")
        )
      }
      info <- if (name == "group") { as.integer(info) } else { as.numeric(info) }
      private$info[[name]] <- info
      if (!lgb.is.null.handle(private$handle) && !is.null(info)) {
        if (length(info) > 0) {
          lgb.call(
            "LGBM_DatasetSetField_R",
            ret = NULL,
            private$handle,
            lgb.c_str(name),
            info,
            length(info)
          )
        }
      }
      self
    },

    # New Dataset restricted to the rows in `idxset`; it has no raw data of
    # its own and uses this dataset as its reference.
    slice = function(idxset, ...) {
      Dataset$new(
        NULL,
        private$params,
        self,
        private$colnames,
        private$categorical_feature,
        private$predictor,
        private$free_raw_data,
        idxset,
        NULL,
        ...
      )
    },

    # Merge `params` into the stored parameter list. Returns self.
    update_params = function(params) {
      private$params <- modifyList(private$params, params)
      self
    },

    # Replace the categorical feature specification and drop the handle so the
    # dataset is rebuilt on next use. Requires the raw data to still be held.
    set_categorical_feature = function(categorical_feature) {
      if (identical(private$categorical_feature, categorical_feature)) { return(self) }
      if (is.null(private$raw_data)) {
        stop(
          "set_categorical_feature: cannot set categorical feature after freeing raw data,
          please set ", sQuote("free_raw_data = FALSE"), " when you construct lgb.Dataset"
        )
      }
      private$categorical_feature <- categorical_feature
      self$finalize()
      self
    },

    # Use `reference` as the reference dataset: copy its categorical features,
    # column names and predictor, then drop the handle so the dataset is
    # rebuilt on next use. Requires the raw data to still be held.
    set_reference = function(reference) {
      if (!is.null(reference)) {
        # validate first, so an invalid reference fails with a clear message
        # before any state of this dataset has been changed
        if (!lgb.check.r6.class(reference, "lgb.Dataset")) {
          stop("set_reference: Can only use lgb.Dataset as a reference")
        }
        self$set_categorical_feature(reference$.__enclos_env__$private$categorical_feature)
        self$set_colnames(reference$get_colnames())
        private$set_predictor(reference$.__enclos_env__$private$predictor)
      }
      if (identical(private$reference, reference)) { return(self) }
      if (is.null(private$raw_data)) {
        stop(
          "set_reference: cannot set reference after freeing raw data,
          please set ", sQuote("free_raw_data = FALSE"), " when you construct lgb.Dataset"
        )
      }
      private$reference <- reference
      self$finalize()
      self
    },

    # Construct the dataset if needed and save it to the binary file `fname`.
    # Returns self.
    save_binary = function(fname) {
      self$construct()
      lgb.call("LGBM_DatasetSaveBinary_R",
               ret = NULL,
               private$handle,
               lgb.c_str(fname))
      self
    }
  ),
  private = list(
    handle              = NULL,
    raw_data            = NULL,
    params              = list(),
    reference           = NULL,
    colnames            = NULL,
    categorical_feature = NULL,
    predictor           = NULL,
    free_raw_data       = TRUE,
    used_indices        = NULL,
    info                = NULL,

    # Return the native handle, constructing the dataset first if necessary
    get_handle          = function() {
      if (lgb.is.null.handle(private$handle)) { self$construct() }
      private$handle
    },

    # Replace the predictor and drop the handle so the dataset is rebuilt on
    # next use. Requires the raw data to still be held.
    set_predictor = function(predictor) {
      if (identical(private$predictor, predictor)) { return(self) }
      if (is.null(private$raw_data)) {
        stop(
          "set_predictor: cannot set predictor after free raw data,
          please set ", sQuote("free_raw_data = FALSE"), " when you construct lgb.Dataset"
        )
      }
      if (!is.null(predictor)) {
        if (!lgb.check.r6.class(predictor, "lgb.Predictor")) {
          stop("set_predictor: Can only use lgb.Predictor as predictor")
        }
      }
      private$predictor <- predictor
      self$finalize()
      self
    }
  )
)

#' Construct lgb.Dataset object
#'
#' Construct lgb.Dataset object from dense matrix, sparse matrix
#' or local file (that was created previously by saving an \code{lgb.Dataset}).
#'
#' @param data a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename
#' @param params a list of parameters
#' @param reference reference dataset
#' @param colnames names of columns
#' @param categorical_feature categorical features
#' @param free_raw_data TRUE for need to free raw data after construct
#' @param info a list of information of the lgb.Dataset object
#' @param ... other information to pass to \code{info} or parameters pass to \code{params}
#' @return constructed dataset
#' @examples
#' \dontrun{
#'   data(agaricus.train, package='lightgbm')
#'   train <- agaricus.train
#'   dtrain <- lgb.Dataset(train$data, label=train$label)
#'   lgb.Dataset.save(dtrain, 'lgb.Dataset.data')
#'   dtrain <- lgb.Dataset('lgb.Dataset.data')
#'   lgb.Dataset.construct(dtrain)
#' }
#' @export
lgb.Dataset <- function(data,
                        params              = list(),
                        reference           = NULL,
                        colnames            = NULL,
                        categorical_feature = NULL,
                        free_raw_data       = TRUE,
                        info                = list(),
                        ...) {
  # A user-created dataset has no predictor and is not a row subset, so both
  # of those constructor arguments are passed as NULL.
  Dataset$new(
    data                = data,
    params              = params,
    reference           = reference,
    colnames            = colnames,
    categorical_feature = categorical_feature,
    predictor           = NULL,
    free_raw_data       = free_raw_data,
    used_indices        = NULL,
    info                = info,
    ...
  )
}


#' Construct validation data
#'
#' Construct validation data according to training data
#'
#' @param dataset \code{lgb.Dataset} object, training data
#' @param data a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename
#' @param info a list of information of the lgb.Dataset object
#' @param ... other information to pass to \code{info}.
#' @return constructed dataset
#' @examples
#' \dontrun{
#'   data(agaricus.train, package='lightgbm')
#'   train <- agaricus.train
#'   dtrain <- lgb.Dataset(train$data, label=train$label)
#'   data(agaricus.test, package='lightgbm')
#'   test <- agaricus.test
#'   dtest <- lgb.Dataset.create.valid(dtrain, test$data, label=test$label)
#' }
#' @export
lgb.Dataset.create.valid <- function(dataset, data, info = list(),  ...) {
  if (!lgb.is.Dataset(dataset)) {
    stop("lgb.Dataset.create.valid: input data should be an lgb.Dataset object")
  }
  # The training dataset builds the validation dataset
  dataset$create_valid(data, info, ...)
}

#' Construct Dataset explicitly
#'
#' @param dataset Object of class \code{lgb.Dataset}
#' @examples
#' \dontrun{
#'   data(agaricus.train, package='lightgbm')
#'   train <- agaricus.train
#'   dtrain <- lgb.Dataset(train$data, label=train$label)
#'   lgb.Dataset.construct(dtrain)
#' }
#' @export
lgb.Dataset.construct <- function(dataset) {
  if (!lgb.is.Dataset(dataset)) {
    stop("lgb.Dataset.construct: input data should be an lgb.Dataset object")
  }
  # Returns whatever the dataset's construct() method returns
  dataset$construct()
}

#' Dimensions of an lgb.Dataset
#'
#' Returns a vector of numbers of rows and of columns in an \code{lgb.Dataset}.
#' @param x Object of class \code{lgb.Dataset}
#' @param ... other parameters
#' @return a vector of numbers of rows and of columns
#'
#' @details
#' Note: since \code{nrow} and \code{ncol} internally use \code{dim}, they can also
#' be directly used with an \code{lgb.Dataset} object.
#'
#' @examples
#' \dontrun{
#'   data(agaricus.train, package='lightgbm')
#'   train <- agaricus.train
#'   dtrain <- lgb.Dataset(train$data, label=train$label)
#'
#'   stopifnot(nrow(dtrain) == nrow(train$data))
#'   stopifnot(ncol(dtrain) == ncol(train$data))
#'   stopifnot(all(dim(dtrain) == dim(train$data)))
#' }
#' @rdname dim
#' @export
dim.lgb.Dataset <- function(x, ...) {
  if (!lgb.is.Dataset(x)) {
    stop("dim.lgb.Dataset: input data should be an lgb.Dataset object")
  }
  # Delegate to the dataset's own dim() method
  x$dim()
}

#' Handling of column names of \code{lgb.Dataset}
#'
#' Only column names are supported for \code{lgb.Dataset}, thus setting of
#' row names would have no effect and returned row names would be NULL.
#'
#' @param x object of class \code{lgb.Dataset}
#' @param value a list of two elements: the first one is ignored
#'        and the second one is column names
#'
#' @details
#' Generic \code{dimnames} methods are used by \code{colnames}.
#' Since row names are irrelevant, it is recommended to use \code{colnames} directly.
#'
#' @examples
#' \dontrun{
#'   data(agaricus.train, package='lightgbm')
#'   train <- agaricus.train
#'   dtrain <- lgb.Dataset(train$data, label=train$label)
#'   lgb.Dataset.construct(dtrain)
#'   dimnames(dtrain)
#'   colnames(dtrain)
#'   colnames(dtrain) <- make.names(1:ncol(train$data))
#'   print(dtrain, verbose=TRUE)
#' }
#' @rdname dimnames.lgb.Dataset
#' @export
dimnames.lgb.Dataset <- function(x) {
  if (!lgb.is.Dataset(x)) {
    stop("dimnames.lgb.Dataset: input data should be an lgb.Dataset object")
  }
  # Row names are never available, so the first element is always NULL
  list(NULL, x$get_colnames())
}

#' @rdname dimnames.lgb.Dataset
#' @export
`dimnames<-.lgb.Dataset` <- function(x, value) {
  if (!is.list(value) || length(value) != 2L) {
    stop("invalid ", sQuote("value"), " given: must be a list of two elements")
  }
  # Row names are not supported, so the first element must be NULL
  if (!is.null(value[[1L]])) {
    stop("lgb.Dataset does not have rownames")
  }
  if (is.null(value[[2L]])) {
    x$set_colnames(NULL)
    return(x)
  }
  # The number of new names must match the number of columns
  if (ncol(x) != length(value[[2L]])) {
    stop("can't assign ",
         sQuote(length(value[[2L]])),
         " colnames to an lgb.Dataset with ",
         sQuote(ncol(x)), " columns")
  }
  x$set_colnames(value[[2L]])
  x
}

#' Slice a dataset
#'
#' Get a new \code{lgb.Dataset} containing the specified rows of
#' original lgb.Dataset object
#'
#' @param dataset Object of class "lgb.Dataset"
#' @param idxset a integer vector of indices of rows needed
#' @param ... other parameters (currently not used)
#' @return constructed sub dataset
#'
#' @examples
#' \dontrun{
#'   data(agaricus.train, package='lightgbm')
#'   train <- agaricus.train
#'   dtrain <- lgb.Dataset(train$data, label=train$label)
#'
#'   dsub <- slice(dtrain, 1:42)
#'   labels1 <- getinfo(dsub, 'label')
#' }
#' @export
slice <- function(dataset, ...) {
  # S3 generic: dispatches on the class of `dataset`
  UseMethod("slice")
}

#' @rdname slice
#' @export
slice.lgb.Dataset <- function(dataset, idxset, ...) {
  if (!lgb.is.Dataset(dataset)) {
    stop("slice.lgb.Dataset: input dataset should be an lgb.Dataset object")
  }
  # Delegate to the dataset's own slice() method
  dataset$slice(idxset, ...)
}


#' Get information of an lgb.Dataset object
#'
#' @param dataset Object of class \code{lgb.Dataset}
#' @param name the name of the information field to get (see details)
#' @param ... other parameters
#' @return info data
#'
#' @details
#' The \code{name} field can be one of the following:
#'
#' \itemize{
#'     \item \code{label}: label lightgbm learn from ;
#'     \item \code{weight}: to do a weight rescale ;
#'     \item \code{group}: group size
#'     \item \code{init_score}: initial score is the base prediction lightgbm will boost from ;
#' }
#'
#' @examples
#' \dontrun{
#'   data(agaricus.train, package='lightgbm')
#'   train <- agaricus.train
#'   dtrain <- lgb.Dataset(train$data, label=train$label)
#'   lgb.Dataset.construct(dtrain)
#'   labels <- getinfo(dtrain, 'label')
#'   setinfo(dtrain, 'label', 1-labels)
#'
#'   labels2 <- getinfo(dtrain, 'label')
#'   stopifnot(all(labels2 == 1-labels))
#' }
#' @export
getinfo <- function(dataset, ...) {
  # S3 generic: dispatches on the class of `dataset`
  UseMethod("getinfo")
}

#' @rdname getinfo
#' @export
getinfo.lgb.Dataset <- function(dataset, name, ...) {
  if (!lgb.is.Dataset(dataset)) {
    stop("getinfo.lgb.Dataset: input dataset should be an lgb.Dataset object")
  }
  # Delegate to the dataset's own getinfo() method
  dataset$getinfo(name)
}

#' Set information of an lgb.Dataset object
#'
#' @param dataset Object of class "lgb.Dataset"
#' @param name the name of the field to set
#' @param info the specific field of information to set
#' @param ... other parameters
#' @return passed object
#'
#' @details
#' The \code{name} field can be one of the following:
#'
#' \itemize{
#'     \item \code{label}: label lightgbm learn from ;
#'     \item \code{weight}: to do a weight rescale ;
#'     \item \code{init_score}: initial score is the base prediction lightgbm will boost from ;
#'     \item \code{group}.
#' }
#'
#' @examples
#' \dontrun{
#'   data(agaricus.train, package='lightgbm')
#'   train <- agaricus.train
#'   dtrain <- lgb.Dataset(train$data, label=train$label)
#'   lgb.Dataset.construct(dtrain)
#'   labels <- getinfo(dtrain, 'label')
#'   setinfo(dtrain, 'label', 1-labels)
#'   labels2 <- getinfo(dtrain, 'label')
#'   stopifnot(all.equal(labels2, 1-labels))
#' }
#' @export
setinfo <- function(dataset, ...) {
  # S3 generic: dispatches on the class of `dataset`
  UseMethod("setinfo")
}

#' @rdname setinfo
#' @export
setinfo.lgb.Dataset <- function(dataset, name, info, ...) {
  if (!lgb.is.Dataset(dataset)) {
    stop("setinfo.lgb.Dataset: input dataset should be an lgb.Dataset object")
  }
  # Delegate to the dataset's own setinfo() method
  dataset$setinfo(name, info)
}

#' Set categorical feature of \code{lgb.Dataset}
#'
#' @param dataset object of class \code{lgb.Dataset}
#' @param categorical_feature categorical features
#' @return passed dataset
#' @examples
#' \dontrun{
#'   data(agaricus.train, package='lightgbm')
#'   train <- agaricus.train
#'   dtrain <- lgb.Dataset(train$data, label=train$label)
#'   lgb.Dataset.save(dtrain, 'lgb.Dataset.data')
#'   dtrain <- lgb.Dataset('lgb.Dataset.data')
#'   lgb.Dataset.set.categorical(dtrain, 1:2)
#' }
#' @rdname lgb.Dataset.set.categorical
#' @export
lgb.Dataset.set.categorical <- function(dataset, categorical_feature) {
  if (!lgb.is.Dataset(dataset)) {
    stop("lgb.Dataset.set.categorical: input dataset should be an lgb.Dataset object")
  }
  # Delegate to the dataset's own set_categorical_feature() method
  dataset$set_categorical_feature(categorical_feature)
}

#' Set reference of \code{lgb.Dataset}
#'
#' If you want to use validation data, you should set reference to training data
#'
#' @param dataset object of class \code{lgb.Dataset}
#' @param reference object of class \code{lgb.Dataset}
#' @return passed dataset
#' @examples
#' \dontrun{
#'   data(agaricus.train, package='lightgbm')
#'   train <- agaricus.train
#'   dtrain <- lgb.Dataset(train$data, label=train$label)
#'   data(agaricus.test, package='lightgbm')
#'   test <- agaricus.test
#'   dtest <- lgb.Dataset(test$data, label=test$label)
#'   lgb.Dataset.set.reference(dtest, dtrain)
#' }
#' @rdname lgb.Dataset.set.reference
#' @export
lgb.Dataset.set.reference <- function(dataset, reference) {
  if (!lgb.is.Dataset(dataset)) {
    stop("lgb.Dataset.set.reference: input dataset should be an lgb.Dataset object")
  }
  # Delegate to the dataset's own set_reference() method
  dataset$set_reference(reference)
}

#' Save \code{lgb.Dataset} to a binary file
#'
#' @param dataset object of class \code{lgb.Dataset}
#' @param fname object filename of output file
#' @return passed dataset
#' @examples
#' \dontrun{
#'   data(agaricus.train, package='lightgbm')
#'   train <- agaricus.train
#'   dtrain <- lgb.Dataset(train$data, label=train$label)
#'   lgb.Dataset.save(dtrain, "data.bin")
#' }
#' @rdname lgb.Dataset.save
#' @export
lgb.Dataset.save <- function(dataset, fname) {
  if (!lgb.is.Dataset(dataset)) {
    stop("lgb.Dataset.save: input dataset should be an lgb.Dataset object")
  }
  # Only a character file name is accepted
  if (!is.character(fname)) {
    stop("lgb.Dataset.save: fname should be a character string naming the output file")
  }
  dataset$save_binary(fname)
}