# We are going to look at how iterating too much can generate observation instability.
# Obviously, we are in a controlled environment where the underlying rules are real, so this causes no issues here.
# Do not do this in a real scenario.

# First, we load our libraries
library(lightgbm)
library(ggplot2)

# Second, we load our data
data(agaricus.train, package = "lightgbm")
train <- agaricus.train
dtrain <- lgb.Dataset(train$data, label = train$label)
data(agaricus.test, package = "lightgbm")
test <- agaricus.test
dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)

# Third, we set up the parameters and train a model
params <- list(objective = "regression", metric = "l2")
valids <- list(test = dtest)
model <- lgb.train(
    params = params
    , data = dtrain
    , nrounds = 50
    , valids = valids
    , min_data = 1
    , learning_rate = 0.1
    , bagging_fraction = 0.1
    , bagging_freq = 1
    , bagging_seed = 1
)
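
# Optionally, as a quick sanity check, we can inspect the validation l2 recorded
# during training (lgb.train records evaluation results by default);
# the exact values depend on the run
l2_history <- lgb.get.eval.result(model, "test", "l2")
print(tail(l2_history))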

# We create a data.frame with the following structure:
# X = average leaf index of the observation across all trees
# Y = prediction probability (clamped to [1e-15, 1 - 1e-15])
# Z = logloss
# binned = quantile bin of the average leaf index
new_data <- data.frame(
    X = rowMeans(predict(
        model
        , agaricus.test$data
        , predleaf = TRUE
    ))
    , Y = pmin(
        pmax(
            predict(model, agaricus.test$data)
            , 1e-15
        )
        , 1 - 1e-15
    )
)
new_data$Z <- -1 * (agaricus.test$label * log(new_data$Y) + (1 - agaricus.test$label) * log(1 - new_data$Y))
new_data$binned <- .bincode(
    x = new_data$X
    , breaks = quantile(
        x = new_data$X
        , probs = (1:9) / 10
    )
    , right = TRUE
    , include.lowest = TRUE
)
new_data$binned[is.na(new_data$binned)] <- 0
new_data$binned <- as.factor(new_data$binned)

# We can check the binned content
table(new_data$binned)
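
# Optionally, the mean logloss per bin gives a numeric counterpart to the plots
# below; the exact values depend on the run
aggregate(Z ~ binned, data = new_data, FUN = mean)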

# We can plot the binned content
# On the second plot, we clearly notice that the lower the bin (the lower the leaf value), the higher the loss.
# On the third plot, it is smooth!
ggplot(
    data = new_data
    , mapping = aes(x = X, y = Y, color = binned)
) + geom_point() +
  theme_bw() +
  labs(title = "Prediction Depth", x = "Leaf Bin", y = "Prediction Probability")
ggplot(
    data = new_data
    , mapping = aes(x = binned, y = Z, fill = binned, group = binned)
) + geom_boxplot() +
  theme_bw() +
  labs(title = "Prediction Depth Spread", x = "Leaf Bin", y = "Logloss")
ggplot(
    data = new_data
    , mapping = aes(x = Y, y = ..count.., fill = binned)
) + geom_density(position = "fill") +
  theme_bw() +
  labs(title = "Depth Density", x = "Prediction Probability", y = "Bin Density")


# Now, let's show the same thing with other parameters
model2 <- lgb.train(
    params = params
    , data = dtrain
    , nrounds = 100
    , valids = valids
    , min_data = 1
    , learning_rate = 1
)

# We create the data structure, but for model2
new_data2 <- data.frame(
    X = rowMeans(predict(
        model2
        , agaricus.test$data
        , predleaf = TRUE
    ))
    , Y = pmin(
        pmax(
            predict(
                model2
                , agaricus.test$data
            )
            , 1e-15
        )
        , 1 - 1e-15
    )
)
new_data2$Z <- -1 * (agaricus.test$label * log(new_data2$Y) + (1 - agaricus.test$label) * log(1 - new_data2$Y))
new_data2$binned <- .bincode(
    x = new_data2$X
    , breaks = quantile(
        x = new_data2$X
        , probs = (1:9) / 10
    )
    , right = TRUE
    , include.lowest = TRUE
)
new_data2$binned[is.na(new_data2$binned)] <- 0
new_data2$binned <- as.factor(new_data2$binned)

# We can check the binned content
table(new_data2$binned)
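
# Optionally, we can count how many test predictions were clamped to the
# [1e-15, 1 - 1e-15] range; extreme raw predictions hint at overfitting
sum(new_data2$Y <= 1e-15 | new_data2$Y >= 1 - 1e-15)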

# We can plot the binned content
# On the second plot, we clearly notice that the lower the bin (the lower the leaf value), the higher the loss.
# On the third plot, it is clearly not smooth! We are severely overfitting the data, but the rules are real, thus it is not an issue.
# However, if the rules were not true, the loss would explode.
ggplot(
    data = new_data2
    , mapping = aes(x = X, y = Y, color = binned)
) + geom_point() +
  theme_bw() +
  labs(title = "Prediction Depth", x = "Leaf Bin", y = "Prediction Probability")
ggplot(
    data = new_data2
    , mapping = aes(x = binned, y = Z, fill = binned, group = binned)
) + geom_boxplot() +
  theme_bw() +
  labs(title = "Prediction Depth Spread", x = "Leaf Bin", y = "Logloss")
ggplot(
    data = new_data2
    , mapping = aes(x = Y, y = ..count.., fill = binned)
) + geom_density(position = "fill") +
  theme_bw() +
  labs(title = "Depth Density", x = "Prediction Probability", y = "Bin Density")


# Now, try with very severe overfitting
model3 <- lgb.train(
    params = params
    , data = dtrain
    , nrounds = 1000
    , valids = valids
    , min_data = 1
    , learning_rate = 1
)

# We create the data structure, but for model3
new_data3 <- data.frame(
    X = rowMeans(predict(
        model3
        , agaricus.test$data
        , predleaf = TRUE
    ))
    , Y = pmin(
        pmax(
            predict(
                model3
                , agaricus.test$data
            )
            , 1e-15
        )
        , 1 - 1e-15
    )
)
new_data3$Z <- -1 * (agaricus.test$label * log(new_data3$Y) + (1 - agaricus.test$label) * log(1 - new_data3$Y))
new_data3$binned <- .bincode(
    x = new_data3$X
    , breaks = quantile(
        x = new_data3$X
        , probs = (1:9) / 10
    )
    , right = TRUE
    , include.lowest = TRUE
)
new_data3$binned[is.na(new_data3$binned)] <- 0
new_data3$binned <- as.factor(new_data3$binned)

# We can check the binned content
table(new_data3$binned)
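
# Optionally, we can compare the overall mean logloss of the three models;
# the exact values depend on the run
sapply(
    X = list(model_1 = new_data$Z, model_2 = new_data2$Z, model_3 = new_data3$Z)
    , FUN = mean
)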

# We can plot the binned content
# On the density plot below, it is clearly not smooth! We are severely overfitting the data, but the rules
# are real, thus it is not an issue.
# However, if the rules were not true, the loss would explode. See the sudden spikes?
ggplot(
    data = new_data3
    , mapping = aes(x = Y, y = ..count.., fill = binned)
) +
  geom_density(position = "fill") +
  theme_bw() +
  labs(title = "Depth Density", x = "Prediction Probability", y = "Bin Density")

# Compare with our second model: the difference is severe. This one is smooth by comparison.
ggplot(
    data = new_data2
    , mapping = aes(x = Y, y = ..count.., fill = binned)
) + geom_density(position = "fill") +
  theme_bw() +
  labs(title = "Depth Density", x = "Prediction Probability", y = "Bin Density")