# Basic walkthrough of the lightgbm R package: training on sparse/dense/
# lgb.Dataset inputs, prediction, saving/loading models and Datasets,
# and monitoring validation metrics.
library(lightgbm)

# We load in the agaricus dataset
# In this example, we are aiming to predict whether a mushroom is edible
data(agaricus.train, package = "lightgbm")
data(agaricus.test, package = "lightgbm")
train <- agaricus.train
test <- agaricus.test

# The loaded data is stored in sparseMatrix, and label is a numeric vector in {0,1}
# print() makes these visible even when the script is run via source()
print(class(train$label))
print(class(train$data))

# Set parameters for model training.
# The same parameter list is reused by every training call below.
train_params <- list(
    num_leaves = 4L
    , learning_rate = 1.0
    , objective = "binary"
    , nthread = 2L
)

#--------------------Basic Training using lightgbm----------------
# This is the basic usage of lightgbm: you can put a matrix in the data field
# Note: we are putting in a sparse matrix here; lightgbm naturally handles sparse input
# Use a sparse matrix when your features are sparse (e.g. when you are using one-hot encoding vectors)
print("Training lightgbm with sparseMatrix")
bst <- lightgbm(
    data = train$data
    , params = train_params
    , label = train$label
    , nrounds = 2L
)

# Alternatively, you can put in a dense matrix, i.e. a basic R matrix
print("Training lightgbm with Matrix")
bst <- lightgbm(
    data = as.matrix(train$data)
    , params = train_params
    , label = train$label
    , nrounds = 2L
)
# You can also put in an lgb.Dataset object, which stores the label, the data
# and other metadata needed for advanced features
print("Training lightgbm with lgb.Dataset")
dtrain <- lgb.Dataset(
    data = train$data
    , label = train$label
)
# No label argument here: the label is already stored in dtrain
bst <- lightgbm(
    data = dtrain
    , params = train_params
    , nrounds = 2L
)
# Verbose = 0,1,2
print("Train lightgbm with verbose 0, no message")
bst <- lightgbm(
    data = dtrain
    , params = train_params
    , nrounds = 2L
    , verbose = 0L
)

print("Train lightgbm with verbose 1, print evaluation metric")
bst <- lightgbm(
    data = dtrain
    , params = train_params
    , nrounds = 2L
    , verbose = 1L
)

print("Train lightgbm with verbose 2, also print information about tree")
bst <- lightgbm(
    data = dtrain
    , params = train_params
    , nrounds = 2L
    , verbose = 2L
)
# You can also specify data as a file path to a LibSVM/TSV/CSV format input
# Since we do not have this file with us, the following lines are just for illustration
# bst <- lightgbm(
#     data = "agaricus.train.svm"
#     , params = list(
#         num_leaves = 4L
#         , learning_rate = 1.0
#         , objective = "binary"
#     )
#     , nrounds = 2L
# )
#--------------------Basic prediction using lightgbm--------------
# You can do prediction using the following line.
# Here `newdata` is a sparse matrix; a dense matrix works too. See
# ?predict.lgb.Booster for the full list of accepted input types.
pred <- predict(bst, test$data)
# Threshold predictions at 0.5 to get 0/1 labels, then take the mismatch rate
err <- mean(as.numeric(pred > 0.5) != test$label)
print(paste("test-error=", err))

#--------------------Save and load models-------------------------
# Save the model to a local file
lgb.save(bst, "lightgbm.model")

# Load the saved model back into R
bst2 <- lgb.load("lightgbm.model")
pred2 <- predict(bst2, test$data)

# pred2 should be identical to pred, so this sum should be 0
print(paste("sum(abs(pred2-pred))=", sum(abs(pred2 - pred))))
#--------------------Advanced features ---------------------------
# To use advanced features, we need to put data in lgb.Dataset.
# free_raw_data = FALSE asks lgb.Dataset to keep the raw data on the R side
# (see ?lgb.Dataset for the trade-offs).
dtrain <- lgb.Dataset(data = train$data, label = train$label, free_raw_data = FALSE)
# The validation Dataset is created from dtrain via lgb.Dataset.create.valid()
dtest <- lgb.Dataset.create.valid(dtrain, data = test$data, label = test$label)

#--------------------Using validation set-------------------------
# valids is a list of lgb.Dataset, each of them is tagged with a name
valids <- list(train = dtrain, test = dtest)

# To train with valids, use lgb.train, which contains more advanced features
# valids allows us to monitor the evaluation result on all data in the list
print("Train lightgbm using lgb.train with valids")
bst <- lgb.train(
    data = dtrain
    , params = train_params
    , nrounds = 2L
    , valids = valids
)

# We can change evaluation metrics, or use multiple evaluation metrics
print("Train lightgbm using lgb.train with valids, watch logloss and error")
bst <- lgb.train(
    data = dtrain
    , params = train_params
    , nrounds = 2L
    , valids = valids
    , eval = c("binary_error", "binary_logloss")
)
# lgb.Dataset can also be saved using lgb.Dataset.save
lgb.Dataset.save(dtrain, "dtrain.buffer")

# To load it in, simply call lgb.Dataset with the saved file's path
dtrain2 <- lgb.Dataset("dtrain.buffer")
bst <- lgb.train(
    data = dtrain2
    , params = train_params
    , nrounds = 2L
    , valids = valids
)
# Information can be extracted from lgb.Dataset using get_field()
label <- get_field(dtest, "label")
pred <- predict(bst, test$data)
# Test error: fraction of thresholded predictions that disagree with the label
err <- mean(as.integer(pred > 0.5) != label)
print(paste("test-error=", err))