# lightgbm.R
#' Simple interface for training a LightGBM model.
#' Its documentation is combined with lgb.train.
#'
#' @rdname lgb.train
#' @export
lightgbm <- function(data,
                     label = NULL,
                     weight = NULL,
                     params = list(),
                     nrounds = 10,
                     verbose = 1,
                     eval_freq = 1L,
                     early_stopping_rounds = NULL,
                     save_name = "lightgbm.model",
                     init_model = NULL,
                     callbacks = list(),
                     ...) {

  # Validate nrounds before anything else. A NULL, NA or multi-element value
  # would otherwise reach the `if ()` below and fail with an opaque base-R
  # error ("argument is of length zero" / "missing value where TRUE/FALSE
  # needed") instead of a message that names the offending argument.
  if (length(nrounds) != 1L || is.na(nrounds)) {
    stop("nrounds should be a single number")
  }
  if (nrounds <= 0) {
    stop("nrounds should be greater than zero")
  }

  # Wrap raw input in an lgb.Dataset when needed. NOTE: if `data` already is
  # an lgb.Dataset, it is used as-is and `label` / `weight` are not applied.
  dtrain <- data
  if (!lgb.is.Dataset(dtrain)) {
    dtrain <- lgb.Dataset(data, label = label, weight = weight)
  }

  # Register the training data itself as a validation set, but only when
  # output is requested (verbose > 0); otherwise no validation set is passed.
  valids <- list()
  if (verbose > 0) {
    valids$train <- dtrain
  }

  # Train through the regular lgb.train() path.
  bst <- lgb.train(
    params,
    dtrain,
    nrounds,
    valids,
    verbose = verbose,
    eval_freq = eval_freq,
    early_stopping_rounds = early_stopping_rounds,
    init_model = init_model,
    callbacks = callbacks,
    ...
  )

  # Persist the model under `save_name`; passing save_name = NULL skips
  # writing a file to disk.
  if (!is.null(save_name)) {
    bst$save_model(save_name)
  }

  bst
}

#' Training part from Mushroom Data Set
#'
#' This data set is originally from the Mushroom data set,
#' UCI Machine Learning Repository.
#'
#' This data set includes the following fields:
#'
#' \itemize{
#'  \item \code{label} the label for each record
#'  \item \code{data} a sparse Matrix of \code{dgCMatrix} class, with 126 columns.
#' }
#'
#' @references
#' https://archive.ics.uci.edu/ml/datasets/Mushroom
#'
#' Bache, K. & Lichman, M. (2013). UCI Machine Learning Repository
#' [https://archive.ics.uci.edu/ml]. Irvine, CA: University of California,
#' School of Information and Computer Science.
#'
#' @docType data
#' @keywords datasets
#' @name agaricus.train
#' @usage data(agaricus.train)
#' @format A list containing a label vector, and a dgCMatrix object with 6513
#' rows and 126 variables
NULL

#' Test part from Mushroom Data Set
#'
#' This data set is originally from the Mushroom data set,
#' UCI Machine Learning Repository.
#'
#' This data set includes the following fields:
#'
#' \itemize{
#'  \item \code{label} the label for each record
#'  \item \code{data} a sparse Matrix of \code{dgCMatrix} class, with 126 columns.
#' }
#'
#' @references
#' https://archive.ics.uci.edu/ml/datasets/Mushroom
#'
#' Bache, K. & Lichman, M. (2013). UCI Machine Learning Repository
#' [https://archive.ics.uci.edu/ml]. Irvine, CA: University of California,
#' School of Information and Computer Science.
#'
#' @docType data
#' @keywords datasets
#' @name agaricus.test
#' @usage data(agaricus.test)
#' @format A list containing a label vector, and a dgCMatrix object with 1611
#' rows and 126 variables
NULL

#' Bank Marketing Data Set
#'
#' This data set is originally from the Bank Marketing data set,
#' UCI Machine Learning Repository.
#'
#' It contains only bank.csv: 10% of the examples with 17 inputs, randomly
#' selected from bank-full.csv (an older version of this dataset with fewer
#' inputs).
#'
#' @references
#' http://archive.ics.uci.edu/ml/datasets/Bank+Marketing
#'
#' S. Moro, P. Cortez and P. Rita. (2014)
#' A Data-Driven Approach to Predict the Success of Bank Telemarketing.
#' Decision Support Systems, 62:22-31.
#'
#' @docType data
#' @keywords datasets
#' @name bank
#' @usage data(bank)
#' @format A data.table with 4521 rows and 17 variables
NULL

# Various imports: these roxygen tags must stay in one contiguous block
# anchored to the NULL below so they all end up in NAMESPACE.
#' @import methods
#' @importFrom R6 R6Class
#' @useDynLib lib_lightgbm
NULL
# Suppress false positive warnings from R CMD CHECK about
# "unrecognized global variable". These names are column names / special
# symbols used inside data.table expressions elsewhere in the package.
globalVariables(c(
    "."
    , ".N"
    , ".SD"
    , "Contribution"
    , "Cover"
    , "Feature"
    , "Frequency"
    , "Gain"
    , "internal_count"
    , "internal_value"
    , "leaf_index"
    , "leaf_parent"
    , "leaf_value"
    , "node_parent"
    , "split_feature"
    , "split_gain"
    , "split_index"
    , "tree_index"
))