lgb.Dataset.R 30.8 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
Dataset <- R6Class(
2
  classname = "lgb.Dataset",
3
  cloneable = FALSE,
Guolin Ke's avatar
Guolin Ke committed
4
  public = list(
5
6
    
    # Release the native dataset handle, if one is currently held
    finalize = function() {

      # No handle stored: there is nothing to release
      if (lgb.is.null.handle(private$handle)) {
        return(invisible(NULL))
      }

      # Free the native dataset, then forget the handle
      lgb.call("LGBM_DatasetFree_R", ret = NULL, private$handle)
      private$handle <- NULL

    },
19
20
    
    # Store everything needed to build the dataset later; no native handle is
    # created here. Named arguments passed through `...` are routed into `info`
    # when they are known info fields, otherwise into `params`.
    initialize = function(data,
                          params = list(),
                          reference = NULL,
                          colnames = NULL,
                          categorical_feature = NULL,
                          predictor = NULL,
                          free_raw_data = TRUE,
                          used_indices = NULL,
                          info = list(),
                          ...) {

      # Split the extra named arguments between info fields and parameters
      extras <- list(...)
      info_fields <- c("label", "weight", "init_score", "group")
      for (arg_name in names(extras)) {
        arg_value <- extras[[arg_name]]
        if (arg_name %in% info_fields) {
          info[[arg_name]] <- arg_value
        } else {
          params[[arg_name]] <- arg_value
        }
      }

      # A reference, when given, has to be an lgb.Dataset
      if (!is.null(reference) && !lgb.check.r6.class(reference, "lgb.Dataset")) {
        stop("lgb.Dataset: Can only use ", sQuote("lgb.Dataset"), " as reference")
      }

      # A predictor, when given, has to be an lgb.Predictor
      if (!is.null(predictor) && !lgb.check.r6.class(predictor, "lgb.Predictor")) {
        stop("lgb.Dataset: Only can use ", sQuote("lgb.Predictor"), " as predictor")
      }

      # Dense matrices are stored with "double" storage mode
      if (is.matrix(data) && storage.mode(data) != "double") {
        storage.mode(data) <- "double"
      }

      # Keep everything in the private fields
      private$raw_data <- data
      private$params <- params
      private$reference <- reference
      private$colnames <- colnames
      private$categorical_feature <- categorical_feature
      private$predictor <- predictor
      private$free_raw_data <- free_raw_data
      private$used_indices <- used_indices
      private$info <- info

    },
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
    
    # Create a new Dataset for `data` that uses this dataset as its reference
    # and inherits its params, column names, categorical features, predictor
    # and free_raw_data setting.
    create_valid = function(data,
                            info = list(),
                            ...) {

      # Arguments are passed positionally, in the order expected by initialize
      valid_set <- Dataset$new(
        data,
        private$params,
        self,
        private$colnames,
        private$categorical_feature,
        private$predictor,
        private$free_raw_data,
        NULL,
        info,
        ...
      )

      invisible(valid_set)

    },
112
113
    
    # Build the native dataset handle from the stored raw data (file name,
    # dense matrix or dgCMatrix) or, when `used_indices` is set, as a subset of
    # the reference dataset. Does nothing when a handle already exists.
    # Returns the object itself, invisibly.
    construct = function() {

      # A handle already exists: nothing to build
      if (!lgb.is.null.handle(private$handle)) {
        return(invisible(self))
      }

      # Dense matrices and dgCMatrix objects expose dimensions and column names
      raw_is_matrix <- is.matrix(private$raw_data) || is(private$raw_data, "dgCMatrix")

      # Fall back to the raw data's column names when none were supplied
      if (is.null(private$colnames) && raw_is_matrix) {
        cnames <- colnames(private$raw_data)
        if (!is.null(cnames)) {
          private$colnames <- as.character(cnames)
        }
      }

      # Translate categorical features into 0-based indices
      if (!is.null(private$categorical_feature)) {

        if (is.character(private$categorical_feature)) {

          # Names are resolved against the known column names
          matched <- match(private$categorical_feature, private$colnames)
          unknown <- is.na(matched)
          if (any(unknown)) {
            stop("lgb.self.get.handle: supplied an unknown feature in categorical_feature: ",
                 paste(sQuote(private$categorical_feature[unknown]), collapse = ", "))
          }
          cate_indices <- as.list(matched - 1)

        } else {

          # Number of features: taken from the data itself when possible, so
          # that data without column names is not rejected; otherwise from the
          # column names; unknown (NA) when neither is available
          num_features <- if (raw_is_matrix) {
            ncol(private$raw_data)
          } else if (!is.null(private$colnames)) {
            length(private$colnames)
          } else {
            NA_integer_
          }

          # Range check, only possible when the number of features is known
          if (!is.na(num_features) &&
              length(private$categorical_feature) > 0 &&
              max(private$categorical_feature) > num_features) {
            stop("lgb.self.get.handle: supplied a too large value in categorical_feature: ",
                 max(private$categorical_feature), " but only ", num_features, " features")
          }

          # Store indices as [0, n-1] indexed instead of [1, n] indexed
          cate_indices <- as.list(private$categorical_feature - 1)

        }

        # Store indices for categorical features
        private$params$categorical_feature <- cate_indices

      }

      # Generate parameter str
      params_str <- lgb.params2str(private$params)

      # Get handle of reference dataset (constructs the reference if needed)
      ref_handle <- NULL
      if (!is.null(private$reference)) {
        ref_handle <- private$reference$.__enclos_env__$private$get_handle()
      }
      handle <- NA_real_

      if (is.null(private$used_indices)) {

        # Not subsetting: build from the raw data
        if (is.character(private$raw_data)) {

          # Data file
          handle <- lgb.call("LGBM_DatasetCreateFromFile_R",
                             ret = handle,
                             lgb.c_str(private$raw_data),
                             params_str,
                             ref_handle)

        } else if (is.matrix(private$raw_data)) {

          # Dense matrix
          handle <- lgb.call("LGBM_DatasetCreateFromMat_R",
                             ret = handle,
                             private$raw_data,
                             nrow(private$raw_data),
                             ncol(private$raw_data),
                             params_str,
                             ref_handle)

        } else if (is(private$raw_data, "dgCMatrix")) {

          if (length(private$raw_data@p) > 2147483647) {
            stop("Cannot support large CSC matrix")
          }

          # dgCMatrix (sparse matrix, column compressed)
          handle <- lgb.call("LGBM_DatasetCreateFromCSC_R",
                             ret = handle,
                             private$raw_data@p,
                             private$raw_data@i,
                             private$raw_data@x,
                             length(private$raw_data@p),
                             length(private$raw_data@x),
                             nrow(private$raw_data),
                             params_str,
                             ref_handle)

        } else {

          # Unknown data type
          stop("lgb.Dataset.construct: does not support constructing from ",
               paste(sQuote(class(private$raw_data)), collapse = ", "))

        }

      } else {

        # Subsetting requires a reference dataset to take the rows from
        if (is.null(private$reference)) {
          stop("lgb.Dataset.construct: reference cannot be NULL for constructing data subset")
        }

        # Construct subset
        handle <- lgb.call("LGBM_DatasetGetSubset_R",
                           ret = handle,
                           ref_handle,
                           private$used_indices,
                           length(private$used_indices),
                           params_str)

      }

      if (lgb.is.null.handle(handle)) {
        stop("lgb.Dataset.construct: cannot create Dataset handle")
      }

      # Setup class and private type
      class(handle) <- "lgb.Dataset.handle"
      private$handle <- handle

      # Set feature names
      if (!is.null(private$colnames)) {
        self$set_colnames(private$colnames)
      }

      # Load init score from the predictor (not done for subsets); this has to
      # happen before the raw data is released below
      if (!is.null(private$predictor) && is.null(private$used_indices)) {
        init_score <- private$predictor$predict(private$raw_data, rawscore = TRUE, reshape = TRUE)
        private$info$init_score <- as.vector(init_score)
      }

      # Should we free raw data?
      if (isTRUE(private$free_raw_data)) {
        private$raw_data <- NULL
      }

      # Push every stored info field to the newly created handle
      for (i in seq_along(private$info)) {
        p <- private$info[i]
        self$setinfo(names(p), p[[1]])
      }

      # A label is mandatory
      if (is.null(self$getinfo("label"))) {
        stop("lgb.Dataset.construct: label should be set")
      }

      # Return self
      return(invisible(self))

    },
293
294
    
    # Return c(number of rows, number of features) of the dataset
    dim = function() {

      # Constructed dataset: query the native handle
      if (!lgb.is.null.handle(private$handle)) {
        n_rows <- 0L
        n_cols <- 0L
        n_rows <- lgb.call("LGBM_DatasetGetNumData_R", ret = n_rows, private$handle)
        n_cols <- lgb.call("LGBM_DatasetGetNumFeature_R", ret = n_cols, private$handle)
        return(c(n_rows, n_cols))
      }

      # Not constructed yet, but the raw data is a matrix or dgCMatrix
      if (is.matrix(private$raw_data) || is(private$raw_data, "dgCMatrix")) {
        return(dim(private$raw_data))
      }

      # Neither a handle nor matrix-like raw data is available
      stop("dim: cannot get dimensions before dataset has been constructed, please call lgb.Dataset.construct explicitly")

    },
320
321
    
    # Return the column (feature) names of the dataset.
    # With a constructed handle the names are read from the native dataset and
    # cached in private$colnames; otherwise the column names of matrix-like raw
    # data are returned. Stops when neither source is available.
    get_colnames = function() {

      # Check for handle
      if (!lgb.is.null.handle(private$handle)) {

        # Feature names come back as a single tab-separated string
        cnames <- lgb.call.return.str("LGBM_DatasetGetFeatureNames_R", private$handle)
        private$colnames <- as.character(base::strsplit(cnames, "\t")[[1]])
        private$colnames

      } else if (is.matrix(private$raw_data) || is(private$raw_data, "dgCMatrix")) {

        # Dense matrix or dgCMatrix (sparse matrix column compressed)
        colnames(private$raw_data)

      } else {

        # Column names cannot be determined from a file name or missing data
        stop("get_colnames: cannot get column names before dataset has been constructed, please call lgb.Dataset.construct explicitly")

      }

    },
345
346
    
    # Store new column names and, when a handle exists, forward them to the
    # native dataset. NULL or empty input leaves everything unchanged.
    # Returns the object itself, invisibly.
    set_colnames = function(colnames) {

      # NULL converts to an empty character vector, so one test covers both
      # the NULL and the empty case
      new_names <- as.character(colnames)
      if (length(new_names) == 0) {
        return(invisible(self))
      }

      private$colnames <- new_names

      # Names are handed over as one tab-separated string
      if (!lgb.is.null.handle(private$handle)) {
        joined_names <- paste(new_names, collapse = "\t")
        lgb.call("LGBM_DatasetSetFeatureNames_R",
                 ret = NULL,
                 private$handle,
                 lgb.c_str(joined_names))
      }

      invisible(self)

    },
377
378
    
    # Return one info field ("label", "weight", "init_score" or "group").
    # A value already held in private$info is returned directly; otherwise the
    # field is read from the native handle and cached. Returns NULL when the
    # field has size zero in the native dataset.
    getinfo = function(name) {

      # Only a single, known field name is accepted
      INFONAMES <- c("label", "weight", "init_score", "group")
      name_ok <- is.character(name) && length(name) == 1 && name %in% INFONAMES
      if (!name_ok) {
        stop("getinfo: name must one of the following: ", paste0(sQuote(INFONAMES), collapse = ", "))
      }

      # Serve from the cache when possible
      cached <- private$info[[name]]
      if (!is.null(cached)) {
        return(cached)
      }

      # Reading from the native dataset requires a handle
      if (lgb.is.null.handle(private$handle)) {
        stop("Cannot perform getinfo before construct Dataset.")
      }

      # Ask for the field size first
      field_size <- 0L
      field_size <- lgb.call("LGBM_DatasetGetFieldSize_R",
                             ret = field_size,
                             private$handle,
                             lgb.c_str(name))

      if (field_size > 0) {

        # "group" is read into an integer buffer, every other field into a
        # numeric one
        buffer <- switch(name,
                         group = integer(field_size),
                         numeric(field_size))

        private$info[[name]] <- lgb.call("LGBM_DatasetGetField_R",
                                         ret = buffer,
                                         private$handle,
                                         lgb.c_str(name))

      }

      private$info[[name]]

    },
425
426
    
    # Store one info field ("label", "weight", "init_score" or "group") and,
    # when a handle exists and the value is non-empty, forward it to the native
    # dataset. Returns the object itself, invisibly.
    setinfo = function(name, info) {

      # Only a single, known field name is accepted
      INFONAMES <- c("label", "weight", "init_score", "group")
      name_ok <- is.character(name) && length(name) == 1 && name %in% INFONAMES
      if (!name_ok) {
        stop("setinfo: name must one of the following: ", paste0(sQuote(INFONAMES), collapse = ", "))
      }

      # "group" is stored as integer, every other field as numeric
      field_values <- switch(name,
                             group = as.integer(info),
                             as.numeric(info))

      # Keep the converted value in the private info list
      private$info[[name]] <- field_values

      # Forward to the native dataset when there is something to send
      has_handle <- !lgb.is.null.handle(private$handle)
      if (has_handle && !is.null(field_values) && length(field_values) > 0) {
        lgb.call("LGBM_DatasetSetField_R",
                 ret = NULL,
                 private$handle,
                 lgb.c_str(name),
                 field_values,
                 length(field_values))
      }

      invisible(self)

    },
466
467
    
    # Create a new Dataset describing the rows `idxset` of this dataset.
    # The new object has no raw data of its own; it uses this dataset as its
    # reference and `idxset` as its used indices.
    slice = function(idxset, ...) {

      # Arguments are passed positionally, in the order expected by initialize
      subset_data <- Dataset$new(
        NULL,
        private$params,
        self,
        private$colnames,
        private$categorical_feature,
        private$predictor,
        private$free_raw_data,
        idxset,
        NULL,
        ...
      )

      subset_data

    },
483
484
    
    # Merge `params` into the stored parameter list via modifyList.
    # Returns the object itself, invisibly.
    update_params = function(params) {

      merged_params <- modifyList(private$params, params)
      private$params <- merged_params

      invisible(self)

    },
492
493
    
    # Replace the categorical feature specification. When it actually changes,
    # the current handle is released through finalize(). This requires the raw
    # data to still be present. Returns the object itself, invisibly.
    set_categorical_feature = function(categorical_feature) {

      # Same specification as before: nothing to do
      unchanged <- identical(private$categorical_feature, categorical_feature)
      if (unchanged) {
        return(invisible(self))
      }

      # A change is refused once the raw data is gone
      raw_data_missing <- is.null(private$raw_data)
      if (raw_data_missing) {
        stop("set_categorical_feature: cannot set categorical feature after freeing raw data,
          please set ", sQuote("free_raw_data = FALSE"), " when you construct lgb.Dataset")
      }

      # Remember the new specification and drop the current handle
      private$categorical_feature <- categorical_feature
      self$finalize()

      invisible(self)

    },
515
516
    
    # Use another lgb.Dataset as the reference of this dataset.
    # The reference is validated before any state is touched. For a non-NULL
    # reference, its categorical features, column names and predictor are
    # copied over first. When the reference actually changes, the raw data must
    # still be present and the current handle is released through finalize().
    # A NULL reference clears the stored reference.
    # Returns the object itself, invisibly.
    set_reference = function(reference) {

      # Validate first, so an invalid reference cannot leave the object
      # partially modified
      if (!is.null(reference) && !lgb.check.r6.class(reference, "lgb.Dataset")) {
        stop("set_reference: Can only use lgb.Dataset as a reference")
      }

      # Copy the settings of the reference; there is nothing to copy for NULL
      if (!is.null(reference)) {
        self$set_categorical_feature(reference$.__enclos_env__$private$categorical_feature)
        self$set_colnames(reference$get_colnames())
        private$set_predictor(reference$.__enclos_env__$private$predictor)
      }

      # Check for identical references
      if (identical(private$reference, reference)) {
        return(invisible(self))
      }

      # Check for empty data
      if (is.null(private$raw_data)) {
        stop("set_reference: cannot set reference after freeing raw data,
          please set ", sQuote("free_raw_data = FALSE"), " when you construct lgb.Dataset")
      }

      # Store reference
      private$reference <- reference

      # Finalize and return self
      self$finalize()
      return(invisible(self))

    },
555
556
    
    # Construct the dataset if necessary, then save it to the file `fname`
    # through LGBM_DatasetSaveBinary_R. Returns the object itself, invisibly.
    save_binary = function(fname) {

      # Make sure a handle exists before saving
      self$construct()

      lgb.call(
        "LGBM_DatasetSaveBinary_R",
        ret = NULL,
        private$handle,
        lgb.c_str(fname)
      )

      invisible(self)

    }
567
    
Guolin Ke's avatar
Guolin Ke committed
568
569
  ),
  private = list(
570
571
572
573
574
    handle = NULL,
    raw_data = NULL,
    params = list(),
    reference = NULL,
    colnames = NULL,
575
    categorical_feature = NULL,
576
577
578
579
580
581
582
583
584
585
586
587
    predictor = NULL,
    free_raw_data = TRUE,
    used_indices = NULL,
    info = NULL,
    
    # Return the native handle, constructing the dataset first when needed
    get_handle = function() {

      handle_missing <- lgb.is.null.handle(private$handle)
      if (handle_missing) {
        self$construct()
      }

      private$handle

    },
591
592
    
    # Replace the predictor of this dataset. When it actually changes, the raw
    # data must still be present, the new predictor (if not NULL) must be an
    # lgb.Predictor, and the current handle is released through finalize().
    # Returns the object itself, invisibly.
    set_predictor = function(predictor) {

      # Same predictor as before: nothing to do
      unchanged <- identical(private$predictor, predictor)
      if (unchanged) {
        return(invisible(self))
      }

      # A change is refused once the raw data is gone
      raw_data_missing <- is.null(private$raw_data)
      if (raw_data_missing) {
        stop("set_predictor: cannot set predictor after free raw data,
          please set ", sQuote("free_raw_data = FALSE"), " when you construct lgb.Dataset")
      }

      # A non-NULL predictor has to be an lgb.Predictor
      if (!is.null(predictor) && !lgb.check.r6.class(predictor, "lgb.Predictor")) {
        stop("set_predictor: Can only use lgb.Predictor as predictor")
      }

      # Remember the predictor and drop the current handle
      private$predictor <- predictor
      self$finalize()

      invisible(self)

    }
624
    
Guolin Ke's avatar
Guolin Ke committed
625
626
627
  )
)

wxchan's avatar
wxchan committed
628
#' Construct lgb.Dataset object
#'
#' Construct lgb.Dataset object from dense matrix, sparse matrix
#' or local file (that was created previously by saving an \code{lgb.Dataset}).
#'
#' @param data a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename
#' @param params a list of parameters
#' @param reference reference dataset
#' @param colnames names of columns
#' @param categorical_feature categorical features
#' @param free_raw_data TRUE for need to free raw data after construct
#' @param info a list of information of the lgb.Dataset object
#' @param ... other information to pass to \code{info} or parameters pass to \code{params}
#'
#' @return constructed dataset
#'
#' @examples
#' \dontrun{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' lgb.Dataset.save(dtrain, "lgb.Dataset.data")
#' dtrain <- lgb.Dataset("lgb.Dataset.data")
#' lgb.Dataset.construct(dtrain)
#' }
#'
#' @export
lgb.Dataset <- function(data,
                        params = list(),
                        reference = NULL,
                        colnames = NULL,
                        categorical_feature = NULL,
                        free_raw_data = TRUE,
                        info = list(),
                        ...) {

  # Arguments are passed positionally, in the order expected by the class
  # initializer; no predictor and no used indices are supplied
  new_dataset <- Dataset$new(
    data,
    params,
    reference,
    colnames,
    categorical_feature,
    NULL,
    free_raw_data,
    NULL,
    info,
    ...
  )

  invisible(new_dataset)

}

wxchan's avatar
wxchan committed
679
#' Construct validation data
#'
#' Construct validation data according to training data
#'
#' @param dataset \code{lgb.Dataset} object, training data
#' @param data a \code{matrix} object, a \code{dgCMatrix} object or a character representing a filename
#' @param info a list of information of the lgb.Dataset object
#' @param ... other information to pass to \code{info}.
#'
#' @return constructed dataset
#'
#' @examples
#' \dontrun{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' }
#'
#' @export
lgb.Dataset.create.valid <- function(dataset, data, info = list(), ...) {

  # Only an lgb.Dataset can serve as the training data
  is_dataset <- lgb.is.Dataset(dataset)
  if (!is_dataset) {
    stop("lgb.Dataset.create.valid: input data should be an lgb.Dataset object")
  }

  # Delegate to the create_valid method of the training dataset
  valid_set <- dataset$create_valid(data, info, ...)
  invisible(valid_set)

}
Guolin Ke's avatar
Guolin Ke committed
713

714
#' Construct Dataset explicitly
#'
#' @param dataset Object of class \code{lgb.Dataset}
#'
#' @examples
#' \dontrun{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' lgb.Dataset.construct(dtrain)
#' }
#'
#' @export
lgb.Dataset.construct <- function(dataset) {

  # Only an lgb.Dataset can be constructed
  is_dataset <- lgb.is.Dataset(dataset)
  if (!is_dataset) {
    stop("lgb.Dataset.construct: input data should be an lgb.Dataset object")
  }

  # Delegate to the construct method of the dataset
  constructed <- dataset$construct()
  invisible(constructed)

}

740
#' Dimensions of an lgb.Dataset
#'
#' Returns a vector of numbers of rows and of columns in an \code{lgb.Dataset}.
#' @param x Object of class \code{lgb.Dataset}
#' @param ... other parameters
#'
#' @return a vector of numbers of rows and of columns
#'
#' @details
#' Note: since \code{nrow} and \code{ncol} internally use \code{dim}, they can also
#' be directly used with an \code{lgb.Dataset} object.
#'
#' @examples
#' \dontrun{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#'
#' stopifnot(nrow(dtrain) == nrow(train$data))
#' stopifnot(ncol(dtrain) == ncol(train$data))
#' stopifnot(all(dim(dtrain) == dim(train$data)))
#' }
#'
#' @rdname dim
#' @export
dim.lgb.Dataset <- function(x, ...) {

  # Only an lgb.Dataset is accepted
  is_dataset <- lgb.is.Dataset(x)
  if (!is_dataset) {
    stop("dim.lgb.Dataset: input data should be an lgb.Dataset object")
  }

  # Delegate to the dim method of the dataset
  x$dim()

}

#' Handling of column names of \code{lgb.Dataset}
#'
#' Only column names are supported for \code{lgb.Dataset}, thus setting of
781
#' row names would have no effect and returned row names would be NULL.
Guolin Ke's avatar
Guolin Ke committed
782
783
784
785
786
787
788
789
790
791
#'
#' @param x object of class \code{lgb.Dataset}
#' @param value a list of two elements: the first one is ignored
#'        and the second one is column names
#'
#' @details
#' Generic \code{dimnames} methods are used by \code{colnames}.
#' Since row names are irrelevant, it is recommended to use \code{colnames} directly.
#'
#' @examples
792
793
794
795
796
797
798
799
800
801
#' \dontrun{
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' lgb.Dataset.construct(dtrain)
#' dimnames(dtrain)
#' colnames(dtrain)
#' colnames(dtrain) <- make.names(1:ncol(train$data))
#' print(dtrain, verbose = TRUE)
802
#' }
803
#' 
Guolin Ke's avatar
Guolin Ke committed
804
805
806
#' @rdname dimnames.lgb.Dataset
#' @export
dimnames.lgb.Dataset <- function(x) {

  # Only objects passing the lgb.is.Dataset() check are accepted
  valid_input <- lgb.is.Dataset(x)
  if (!valid_input) {
    stop("dimnames.lgb.Dataset: input data should be an lgb.Dataset object")
  }

  # The row-name slot is always NULL; the column-name slot holds whatever
  # the object's get_colnames() method reports
  column_names <- x$get_colnames()
  list(NULL, column_names)

}

#' @rdname dimnames.lgb.Dataset
#' @export
`dimnames<-.lgb.Dataset` <- function(x, value) {

  # Replacement method for dimnames(): `value` must be list(NULL, <colnames>).
  # The object is updated through its set_colnames() method and returned.

  # Shape check: exactly a two-element list
  well_formed <- is.list(value) && length(value) == 2L
  if (!well_formed) {
    stop("invalid ", sQuote("value"), " given: must be a list of two elements")
  }

  row_names <- value[[1L]]
  col_names <- value[[2L]]

  # Row names cannot be stored, so any non-NULL first element is an error
  if (!is.null(row_names)) {
    stop("lgb.Dataset does not have rownames")
  }

  # A NULL second element means "no column names": forward NULL and finish
  if (is.null(col_names)) {
    x$set_colnames(NULL)
    return(x)
  }

  # The number of supplied names has to match the number of columns
  n_given <- length(col_names)
  if (ncol(x) != n_given) {
    stop(
      "can't assign ",
      sQuote(n_given),
      " colnames to an lgb.Dataset with ",
      sQuote(ncol(x)),
      " columns"
    )
  }

  # Store the names and hand the object back to the caller
  x$set_colnames(col_names)
  x

}

852
#' Slice a dataset
853
#' 
854
#' Get a new \code{lgb.Dataset} containing the specified rows of
Guolin Ke's avatar
Guolin Ke committed
855
#' original lgb.Dataset object
856
#' 
Guolin Ke's avatar
Guolin Ke committed
857
858
859
860
#' @param dataset Object of class "lgb.Dataset"
#' @param idxset an integer vector of indices of rows needed
#' @param ... other parameters (currently not used)
#' @return constructed sub dataset
861
#' 
Guolin Ke's avatar
Guolin Ke committed
862
#' @examples
863
#' \dontrun{
864
865
866
867
868
869
870
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' 
#' dsub <- lightgbm::slice(dtrain, 1:42)
#' labels <- lightgbm::getinfo(dsub, "label")
871
#' }
872
#' 
Guolin Ke's avatar
Guolin Ke committed
873
#' @export
874
875
876
slice <- function(dataset, ...) {
  # S3 generic: the method is chosen from the class of `dataset`
  UseMethod("slice", dataset)
}
Guolin Ke's avatar
Guolin Ke committed
877
878
879
880

#' @rdname slice
#' @export
slice.lgb.Dataset <- function(dataset, idxset, ...) {

  # Only objects passing the lgb.is.Dataset() check are accepted
  valid_input <- lgb.is.Dataset(dataset)
  if (!valid_input) {
    stop("slice.lgb.Dataset: input dataset should be an lgb.Dataset object")
  }

  # Delegate to the object's slice() method, forwarding `idxset` and `...`;
  # the result is returned invisibly
  subset_result <- dataset$slice(idxset, ...)
  invisible(subset_result)

}

#' Get information of an lgb.Dataset object
893
#' 
Guolin Ke's avatar
Guolin Ke committed
894
895
896
897
#' @param dataset Object of class \code{lgb.Dataset}
#' @param name the name of the information field to get (see details)
#' @param ... other parameters
#' @return info data
898
#' 
Guolin Ke's avatar
Guolin Ke committed
899
900
#' @details
#' The \code{name} field can be one of the following:
901
#' 
Guolin Ke's avatar
Guolin Ke committed
902
903
904
905
906
907
#' \itemize{
#'     \item \code{label}: the label lightgbm learns from;
#'     \item \code{weight}: used to do a weight rescale;
#'     \item \code{group}: group size;
#'     \item \code{init_score}: initial score, i.e. the base prediction lightgbm will boost from.
#' }
908
#' 
Guolin Ke's avatar
Guolin Ke committed
909
#' @examples
910
#' \dontrun{
911
912
913
914
915
916
917
918
919
920
921
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' lgb.Dataset.construct(dtrain)
#' 
#' labels <- lightgbm::getinfo(dtrain, "label")
#' lightgbm::setinfo(dtrain, "label", 1 - labels)
#' 
#' labels2 <- lightgbm::getinfo(dtrain, "label")
#' stopifnot(all(labels2 == 1 - labels))
922
#' }
923
#' 
Guolin Ke's avatar
Guolin Ke committed
924
#' @export
925
926
927
getinfo <- function(dataset, ...) {
  # S3 generic: the method is chosen from the class of `dataset`
  UseMethod("getinfo", dataset)
}
Guolin Ke's avatar
Guolin Ke committed
928
929
930
931

#' @rdname getinfo
#' @export
getinfo.lgb.Dataset <- function(dataset, name, ...) {

  # Look the field up through the object's getinfo() method; only `name`
  # is forwarded, arguments in `...` are not used here
  if (lgb.is.Dataset(dataset)) {
    return(dataset$getinfo(name))
  }

  # Anything that fails the lgb.is.Dataset() check is rejected
  stop("getinfo.lgb.Dataset: input dataset should be an lgb.Dataset object")

}

#' Set information of an lgb.Dataset object
944
#' 
Guolin Ke's avatar
Guolin Ke committed
945
946
947
948
949
#' @param dataset Object of class "lgb.Dataset"
#' @param name the name of the field to set
#' @param info the specific field of information to set
#' @param ... other parameters
#' @return passed object
950
#' 
Guolin Ke's avatar
Guolin Ke committed
951
952
#' @details
#' The \code{name} field can be one of the following:
953
#' 
Guolin Ke's avatar
Guolin Ke committed
954
955
956
957
958
959
#' \itemize{
#'     \item \code{label}: the label lightgbm learns from;
#'     \item \code{weight}: used to do a weight rescale;
#'     \item \code{init_score}: initial score, i.e. the base prediction lightgbm will boost from;
#'     \item \code{group}.
#' }
960
#' 
Guolin Ke's avatar
Guolin Ke committed
961
#' @examples
962
#' \dontrun{
963
964
965
966
967
968
969
970
971
972
973
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' lgb.Dataset.construct(dtrain)
#' 
#' labels <- lightgbm::getinfo(dtrain, "label")
#' lightgbm::setinfo(dtrain, "label", 1 - labels)
#' 
#' labels2 <- lightgbm::getinfo(dtrain, "label")
#' stopifnot(all.equal(labels2, 1 - labels))
974
#' }
975
#' 
Guolin Ke's avatar
Guolin Ke committed
976
#' @export
977
978
979
setinfo <- function(dataset, ...) {
  # S3 generic: the method is chosen from the class of `dataset`
  UseMethod("setinfo", dataset)
}
Guolin Ke's avatar
Guolin Ke committed
980
981
982
983

#' @rdname setinfo
#' @export
setinfo.lgb.Dataset <- function(dataset, name, info, ...) {

  # Only objects passing the lgb.is.Dataset() check are accepted
  valid_input <- lgb.is.Dataset(dataset)
  if (!valid_input) {
    stop("setinfo.lgb.Dataset: input dataset should be an lgb.Dataset object")
  }

  # Forward the field name and its new value to the object's setinfo()
  # method (arguments in `...` are not used here); return the result invisibly
  outcome <- dataset$setinfo(name, info)
  invisible(outcome)

}

994
#' Set categorical feature of \code{lgb.Dataset}
995
#' 
996
997
#' @param dataset object of class \code{lgb.Dataset}
#' @param categorical_feature categorical features
998
#' 
999
#' @return passed dataset
1000
#' 
1001
1002
#' @examples
#' \dontrun{
1003
1004
1005
1006
1007
1008
1009
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' lgb.Dataset.save(dtrain, "lgb.Dataset.data")
#' dtrain <- lgb.Dataset("lgb.Dataset.data")
#' lgb.Dataset.set.categorical(dtrain, 1:2)
1010
#' }
1011
#' 
1012
1013
1014
#' @rdname lgb.Dataset.set.categorical
#' @export
lgb.Dataset.set.categorical <- function(dataset, categorical_feature) {

  # Pass the categorical feature specification on to the object's
  # set_categorical_feature() method; the result is returned invisibly
  if (lgb.is.Dataset(dataset)) {
    return(invisible(dataset$set_categorical_feature(categorical_feature)))
  }

  # Anything that fails the lgb.is.Dataset() check is rejected
  stop("lgb.Dataset.set.categorical: input dataset should be an lgb.Dataset object")

}

1026
#' Set reference of \code{lgb.Dataset}
1027
#' 
1028
#' If you want to use validation data, you should set reference to training data
1029
#' 
Guolin Ke's avatar
Guolin Ke committed
1030
1031
#' @param dataset object of class \code{lgb.Dataset}
#' @param reference object of class \code{lgb.Dataset}
1032
#' 
Guolin Ke's avatar
Guolin Ke committed
1033
#' @return passed dataset
1034
#' 
Guolin Ke's avatar
Guolin Ke committed
1035
#' @examples
1036
#' \dontrun{
1037
1038
1039
1040
1041
1042
1043
1044
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset(test$data, label = test$label)
#' lgb.Dataset.set.reference(dtest, dtrain)
1045
#' }
1046
#' 
Guolin Ke's avatar
Guolin Ke committed
1047
1048
1049
#' @rdname lgb.Dataset.set.reference
#' @export
lgb.Dataset.set.reference <- function(dataset, reference) {

  # Only objects passing the lgb.is.Dataset() check are accepted as `dataset`;
  # `reference` itself is not checked here
  valid_input <- lgb.is.Dataset(dataset)
  if (!valid_input) {
    stop("lgb.Dataset.set.reference: input dataset should be an lgb.Dataset object")
  }

  # Hand `reference` to the object's set_reference() method and return
  # its result invisibly
  outcome <- dataset$set_reference(reference)
  invisible(outcome)

}

1060
#' Save \code{lgb.Dataset} to a binary file
1061
#' 
Guolin Ke's avatar
Guolin Ke committed
1062
1063
#' @param dataset object of class \code{lgb.Dataset}
#' @param fname filename of the output file
1064
#' 
Guolin Ke's avatar
Guolin Ke committed
1065
#' @return passed dataset
1066
#' 
Guolin Ke's avatar
Guolin Ke committed
1067
#' @examples
1068
#' 
1069
#' \dontrun{
1070
1071
1072
1073
1074
#' library(lightgbm)
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' lgb.Dataset.save(dtrain, "data.bin")
1075
#' }
1076
#' 
Guolin Ke's avatar
Guolin Ke committed
1077
1078
1079
#' @rdname lgb.Dataset.save
#' @export
lgb.Dataset.save <- function(dataset, fname) {

  # Writes `dataset` to the file named by `fname` through the object's
  # save_binary() method and returns that method's result invisibly.

  # Check if dataset is not a dataset
  if (!lgb.is.Dataset(dataset)) {
    stop("lgb.Dataset.save: input dataset should be an lgb.Dataset object")
  }

  # Only character file names pass this check; connection objects are not
  # character vectors, so they are rejected here as well
  if (!is.character(fname)) {
    stop(
      "lgb.Dataset.save: fname should be a character string ",
      "naming the output file"
    )
  }

  # Store binary
  invisible(dataset$save_binary(fname))

}