basic_walkthrough.R 4.54 KB
Newer Older
1
2
library(lightgbm)
library(methods)
Guolin Ke's avatar
Guolin Ke committed
3

4
# Load the agaricus (mushroom) train and test sets from the lightgbm package.
# The goal of this walkthrough is to predict whether a mushroom is edible.
data(list = c("agaricus.train", "agaricus.test"), package = "lightgbm")
test <- agaricus.test
train <- agaricus.train
10
11

# The features are stored in a sparse matrix and the label is a numeric
# vector with values in {0, 1}; inspect the class of each
class(train[["label"]])
class(train[["data"]])

15
16
17
18
19
20
21
22
# Parameters passed to the training calls in this walkthrough
train_params <- list(
  num_leaves = 4L,
  learning_rate = 1.0,
  objective = "binary",
  nthread = 2L
)

23
24
25
26
#--------------------Basic Training using lightgbm----------------
# The most basic usage of lightgbm is to pass a matrix as the data argument.
# Here a sparse matrix is passed, which lightgbm handles natively.
# Use a sparse matrix when your features are sparse (e.g. one-hot encoded vectors)
print("Training lightgbm with sparseMatrix")
bst <- lightgbm(
  data = train$data,
  label = train$label,
  params = train_params,
  nrounds = 2L
)
34
35

# A dense matrix, i.e. a basic R matrix, can be passed in the same way
print("Training lightgbm with Matrix")
bst <- lightgbm(
  data = as.matrix(train$data),
  label = train$label,
  params = train_params,
  nrounds = 2L
)
Guolin Ke's avatar
Guolin Ke committed
43

44
# An lgb.Dataset object can be passed as well; it stores the label, the data
# and other metadata needed for advanced features
print("Training lightgbm with lgb.Dataset")
dtrain <- lgb.Dataset(label = train$label, data = train$data)
bst <- lightgbm(
  data = dtrain,
  nrounds = 2L,
  params = train_params
)
Guolin Ke's avatar
Guolin Ke committed
55
56
57

# The verbose argument is demonstrated below with the values 0, 1 and 2
print("Train lightgbm with verbose 0, no message")
bst <- lightgbm(
  data = dtrain,
  verbose = 0L,
  params = train_params,
  nrounds = 2L
)
64

Guolin Ke's avatar
Guolin Ke committed
65
# Same training call, this time with verbose set to 1
print("Train lightgbm with verbose 1, print evaluation metric")
bst <- lightgbm(
  data = dtrain,
  verbose = 1L,
  params = train_params,
  nrounds = 2L
)
72

Guolin Ke's avatar
Guolin Ke committed
73
# Same training call, this time with verbose set to 2
print("Train lightgbm with verbose 2, also print information about tree")
bst <- lightgbm(
  data = dtrain,
  verbose = 2L,
  params = train_params,
  nrounds = 2L
)
Guolin Ke's avatar
Guolin Ke committed
80

81
82
# You can also specify data as a file path to a LibSVM/TSV/CSV format input
# Since we do not have this file with us, the following lines are just for illustration
# bst <- lightgbm(
#     data = "agaricus.train.svm"
#     , num_leaves = 4L
#     , learning_rate = 1.0
#     , nrounds = 2L
#     , objective = "binary"
# )
Guolin Ke's avatar
Guolin Ke committed
90

91
92
93
#--------------------Basic prediction using lightgbm--------------
# You can do prediction using the following line
# You can put in Matrix, sparseMatrix, or lgb.Dataset
Guolin Ke's avatar
Guolin Ke committed
94
95
96
97
# Threshold the predictions at 0.5 and report the share of test rows
# whose thresholded prediction differs from the true label
pred <- predict(bst, test$data)
err <- mean(test$label != as.numeric(pred > 0.5))
print(paste("test-error=", err, sep = " "))

98
99
#--------------------Save and load models-------------------------
# Save the model to a file. A temporary path is used so that running this
# walkthrough does not leave files behind in the working directory
model_file <- tempfile(fileext = ".model")
lgb.save(bst, model_file)

# Load the saved model back into R and predict with it
bst2 <- lgb.load(model_file)
pred2 <- predict(bst2, test$data)

# pred2 should be identical to pred
print(paste("sum(abs(pred2-pred))=", sum(abs(pred2 - pred))))
Guolin Ke's avatar
Guolin Ke committed
108

109
110
111
#--------------------Advanced features ---------------------------
# Advanced features require the data to be wrapped in an lgb.Dataset
dtrain <- lgb.Dataset(
  data = train$data,
  label = train$label,
  free_raw_data = FALSE
)
# The test set is built from dtrain with lgb.Dataset.create.valid
dtest <- lgb.Dataset.create.valid(
  dtrain,
  data = test$data,
  label = test$label
)
Guolin Ke's avatar
Guolin Ke committed
113

114
#--------------------Using validation set-------------------------
# valids is a named list of lgb.Dataset objects
valids <- list(train = dtrain, test = dtest)

# Training with valids is done through lgb.train, which offers more advanced
# features; valids lets us monitor the evaluation result on every dataset in the list
print("Train lightgbm using lgb.train with valids")
bst <- lgb.train(
  params = train_params,
  data = dtrain,
  valids = valids,
  nrounds = 2L
)
127
128
129

# The evaluation metric can be changed, and several metrics can be used at once
print("Train lightgbm using lgb.train with valids, watch logloss and error")
bst <- lgb.train(
  params = train_params,
  data = dtrain,
  valids = valids,
  eval = c("binary_error", "binary_logloss"),
  nrounds = 2L
)
Guolin Ke's avatar
Guolin Ke committed
137
138
139

# An lgb.Dataset can be written to disk with lgb.Dataset.save. A fresh temporary
# path is used so that the walkthrough neither leaves files in the working
# directory nor collides with a file left over from an earlier run
dtrain_file <- tempfile(fileext = ".buffer")
lgb.Dataset.save(dtrain, dtrain_file)

# To load it back in, pass the file path to lgb.Dataset
dtrain2 <- lgb.Dataset(dtrain_file)
bst <- lgb.train(
  data = dtrain2,
  params = train_params,
  nrounds = 2L,
  valids = valids
)
149

Guolin Ke's avatar
Guolin Ke committed
150
# Information can be extracted from an lgb.Dataset using getinfo
label <- getinfo(dtest, "label")
pred <- predict(bst, test$data)
# Share of test rows whose thresholded prediction differs from the label;
# mean() of the logical mismatch vector replaces the manual sum / length
err <- mean(as.integer(pred > 0.5) != label)
print(paste("test-error=", err))