utils.R 8.06 KB
Newer Older
1
# [description] Test whether an object is a LightGBM Booster.
# [param] x object to test
# [return] the result of lgb.check.r6.class() for the class name "lgb.Booster"
lgb.is.Booster <- function(x) {
  return(lgb.check.r6.class(object = x, name = "lgb.Booster"))
}
4

5
# [description] Test whether an object is a LightGBM Dataset.
# [param] x object to test
# [return] the result of lgb.check.r6.class() for the class name "lgb.Dataset"
lgb.is.Dataset <- function(x) {
  return(lgb.check.r6.class(object = x, name = "lgb.Dataset"))
}
8

Guolin Ke's avatar
Guolin Ke committed
9
10
11
12
13
14
15
16
# [description] Build the placeholder value used for a handle that has not been set.
# [return] NA_real_ when pointers are 8 bytes wide on this machine, NA_integer_ otherwise
lgb.null.handle <- function() {
  pointer_is_8_bytes <- .Machine$sizeof.pointer == 8L
  if (pointer_is_8_bytes) NA_real_ else NA_integer_
}

17
# [description] Test whether a handle is unset.
# [param] x handle value to test
# [return] TRUE when `x` is NULL or NA, FALSE otherwise
lgb.is.null.handle <- function(x) {
  # `&&` short-circuits, so is.na() is never evaluated on NULL
  has_value <- !is.null(x) && !is.na(x)
  return(!has_value)
}
Guolin Ke's avatar
Guolin Ke committed
20
21

# [description] Convert the leading bytes of a raw vector into a character string.
# [param] arr raw vector holding the bytes
# [param] len number of leading bytes of `arr` to convert
# [return] a character string built from the first `len` bytes of `arr`
lgb.encode.char <- function(arr, len) {
  if (is.raw(arr)) {
    leading_bytes <- arr[seq_len(len)]
    return(rawToChar(leading_bytes))
  }
  stop("lgb.encode.char: Can only encode from raw type")
}

28
29
# [description] Raise an error. Before raising that error, check for any error message
#               stored in a buffer on the C++ side.
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
lgb.last_error <- function() {
  # Never returns normally: every path through this function ends in stop().

  # Perform text error buffering
  # Start with a 200-byte raw buffer for the message
  buf_len <- 200L
  act_len <- 0L
  err_msg <- raw(buf_len)
  err_msg <- .Call(
    "LGBM_GetLastError_R"
    , buf_len
    , act_len
    , err_msg
    , PACKAGE = "lib_lightgbm"
  )

  # Check error buffer
  # NOTE(review): act_len is never reassigned on the R side, so this comparison
  # can only become TRUE if the native routine writes the required length into
  # act_len in place - confirm against the C++ implementation.
  if (act_len > buf_len) {
    # Retry with a buffer of the reported size
    buf_len <- act_len
    err_msg <- raw(buf_len)
    err_msg <- .Call(
      "LGBM_GetLastError_R"
      , buf_len
      , act_len
      , err_msg
      , PACKAGE = "lib_lightgbm"
    )
  }

  # Decode the first act_len bytes of the buffer and raise them as an R error
  stop("api error: ", lgb.encode.char(err_msg, act_len))
}

Guolin Ke's avatar
Guolin Ke committed
59
# [description] Invoke a routine from "lib_lightgbm" through .Call() and raise an
#               error via lgb.last_error() when the resulting status is non-zero.
# [param] fun_name name of the native routine, passed as the first argument of .Call()
# [param] ret object handed to the native routine after the arguments in `...`;
#             pass NULL when the routine takes no such argument
# [param] ... further arguments forwarded to the native routine, in order
# [return] `ret`
#          NOTE(review): presumably the native routine fills `ret` in place - confirm
lgb.call <- function(fun_name, ret, ...) {

  # Status starts at zero and is always the last argument of the native call
  status <- 0L

  if (is.null(ret)) {
    # No return object: forward only the caller's arguments
    status <- .Call(
      fun_name
      , ...
      , status
      , PACKAGE = "lib_lightgbm"
    )
  } else {
    # Return object goes right before the status argument
    status <- .Call(
      fun_name
      , ...
      , ret
      , status
      , PACKAGE = "lib_lightgbm"
    )
  }

  # Any non-zero status is reported as an error
  status <- as.integer(status)
  if (status != 0L) {
    lgb.last_error()
  }

  return(ret)

}

# [description] Call a native routine through lgb.call() with a raw buffer as the
#               return object, then decode that buffer into a character string.
# [param] fun_name name of the native routine, forwarded to lgb.call()
# [param] ... further arguments forwarded to lgb.call() ahead of the buffer length
#             and the actual length
# [return] character string decoded from the first `out_len` bytes of the buffer
lgb.call.return.str <- function(fun_name, ...) {

  # First attempt uses a 1 MiB buffer
  capacity <- as.integer(1024L * 1024L)
  out_len <- 0L
  out_buf <- raw(capacity)
  out_buf <- lgb.call(fun_name, ret = out_buf, ..., capacity, out_len)

  # NOTE(review): out_len is not reassigned in R; this branch relies on the
  # native routine updating it in place - confirm against the C++ side.
  needs_bigger_buffer <- out_len > capacity
  if (needs_bigger_buffer) {
    # Retry with a buffer of the reported size
    capacity <- out_len
    out_buf <- raw(capacity)
    out_buf <- lgb.call(fun_name, ret = out_buf, ..., capacity, out_len)
  }

  return(lgb.encode.char(out_buf, out_len))

}

# [description] Turn a list of parameters (plus any extra named arguments) into a
#               single "key=value key=value ..." string and hand it to lgb.c_str().
#               Dots in parameter names are replaced with underscores. A parameter
#               with several values has them joined with commas. Entries without a
#               name, or whose formatted value is empty, are left out.
# [param] params a plain list of parameters (class must be exactly "list")
# [param] ... additional parameters; names must not overlap with those in `params`
# [return] the result of lgb.c_str() applied to the joined string
lgb.params2str <- function(params, ...) {

  # Check for a list as input
  if (!identical(class(params), "list")) {
    stop("params must be a list")
  }

  # Parameter names use "_" where the caller wrote "."
  names(params) <- gsub("\\.", "_", names(params))

  # Same normalisation for parameters given through the dots-expansion
  dot_params <- list(...)
  names(dot_params) <- gsub("\\.", "_", names(dot_params))

  # The same parameter must not come from both sources
  if (length(intersect(names(params), names(dot_params))) > 0L) {
    stop(
      "Same parameters in "
      , sQuote("params")
      , " and in the call are not allowed. Please check your "
      , sQuote("params")
      , " list"
    )
  }

  # Merge parameters
  params <- c(params, dot_params)
  keys <- names(params)

  # Build one "key=value" string per entry. Entries are read by position rather
  # than by name: `params[[key]]` returns the FIRST element with that name, so a
  # name that occurs twice would otherwise emit the first value twice and drop
  # the second one.
  pairs <- vapply(
    X = seq_along(keys)
    , FUN = function(i) {

      key <- keys[[i]]

      # An entry without a usable name cannot be written as key=value
      if (is.na(key) || !nzchar(key)) {
        return("")
      }

      # If a parameter has multiple values, join those values together with commas.
      # trimws() is necessary because format() will pad to make strings the same width
      val <- paste0(
        trimws(
          format(
            x = params[[i]]
            , scientific = FALSE
          )
        )
        , collapse = ","
      )
      if (nchar(val) <= 0L) {
        return("")
      }

      return(paste0(c(key, val), collapse = "="))
    }
    , FUN.VALUE = character(1L)
  )

  # Drop skipped entries; an empty result collapses to ""
  pairs <- pairs[nzchar(pairs)]

  return(lgb.c_str(paste0(pairs, collapse = " ")))

}

173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
# [description] Validate params[["interaction_constraints"]] and convert every
#               constraint into a string of 0-based feature indices, e.g. "[0,2]".
# [param] params list of parameters; only the "interaction_constraints" entry is read
# [param] column_names character vector of feature names, used to resolve feature
#                      names to positions and to bound numeric indices
# [return] a list with one string per constraint; an empty list when no
#          interaction constraints were supplied
# [raises] an error when the entry is not a list, when an element is neither a
#          character nor a numeric vector, when a feature name is unknown, or when
#          a numeric index lies outside [1, length(column_names)]
lgb.check_interaction_constraints <- function(params, column_names) {

  constraints <- params[["interaction_constraints"]]

  # Nothing supplied: nothing to convert
  if (is.null(constraints)) {
    return(list())
  }

  if (!methods::is(constraints, "list")) {
    stop("interaction_constraints must be a list")
  }

  # vapply() guarantees a logical vector even for an empty list
  has_valid_type <- vapply(
    X = constraints
    , FUN = function(x) {
      is.character(x) || is.numeric(x)
    }
    , FUN.VALUE = logical(1L)
  )
  if (!all(has_valid_type)) {
    stop("every element in interaction_constraints must be a character vector or numeric vector")
  }

  num_features <- length(column_names)

  # unname() so that the result is an unnamed list regardless of the input names
  string_constraints <- lapply(
    X = unname(constraints)
    , FUN = function(constraint) {

      if (is.character(constraint)) {

        # Feature names -> 0-based positions in column_names
        constraint_indices <- as.integer(match(constraint, column_names) - 1L)

        # Report every unknown name, separated so that each one stays readable
        unknown_features <- constraint[is.na(constraint_indices)]
        if (length(unknown_features) > 0L) {
          stop(
            "supplied an unknown feature in interaction_constraints "
            , paste0(sQuote(unknown_features), collapse = ", ")
          )
        }

      } else {

        # Indices are 1-based on the R side, so anything below 1 would turn
        # into a negative index after the shift below
        if (length(constraint) > 0L && min(constraint) < 1L) {
          stop(
            "supplied a too small value in interaction_constraints: "
            , min(constraint)
            , " but feature indices start at 1"
          )
        }

        # Check that constraint indices are at most number of features
        if (max(constraint) > num_features) {
          stop(
            "supplied a too large value in interaction_constraints: "
            , max(constraint)
            , " but only "
            , num_features
            , " features"
          )
        }

        # Store indices as [0, n-1] indexed instead of [1, n] indexed
        constraint_indices <- as.integer(constraint - 1L)

      }

      # Convert constraint to string
      return(paste0("[", paste0(constraint_indices, collapse = ","), "]"))
    }
  )

  return(string_constraints)

}

Guolin Ke's avatar
Guolin Ke committed
231
# [description] Convert a value to the raw bytes of its character form, followed
#               by a single zero byte.
# [param] x value to convert; coerced with as.character()
# [return] raw vector: the bytes of as.character(x), then one 0 byte
lgb.c_str <- function(x) {
  text_bytes <- charToRaw(as.character(x))
  return(c(text_bytes, as.raw(0L)))
}

# [description] Check that an object's class vector contains "R6" together with
#               the given class name(s).
# [param] object object whose class vector is inspected
# [param] name class name(s) that must be present in addition to "R6"
# [return] TRUE when "R6" and every entry of `name` appear in class(object),
#          FALSE otherwise
lgb.check.r6.class <- function(object, name) {
  required_classes <- c("R6", name)
  return(all(is.element(required_classes, class(object))))
}

# [description] Store `obj` in params$objective (when given) and validate the
#               resulting objective.
# [param] params list of parameters
# [param] obj objective to use; when NULL the objective already in `params` is checked
# [return] `params`, with `objective` overwritten by `obj` when `obj` is not NULL
# [raises] an error when a character objective is not one of the known names, or
#          when the objective is neither a character nor a function
lgb.check.obj <- function(params, obj) {

  # Objective names accepted when the objective is given as a string
  known_objectives <- c(
    "regression"
    , "regression_l1"
    , "regression_l2"
    , "mean_squared_error"
    , "mse"
    , "l2_root"
    , "root_mean_squared_error"
    , "rmse"
    , "mean_absolute_error"
    , "mae"
    , "quantile"
    , "huber"
    , "fair"
    , "poisson"
    , "binary"
    , "lambdarank"
    , "multiclass"
    , "softmax"
    , "multiclassova"
    , "multiclass_ova"
    , "ova"
    , "ovr"
    , "xentropy"
    , "cross_entropy"
    , "xentlambda"
    , "cross_entropy_lambda"
    , "mean_absolute_percentage_error"
    , "mape"
    , "gamma"
    , "tweedie"
    , "rank_xendcg"
    , "xendcg"
    , "xe_ndcg"
    , "xe_ndcg_mart"
    , "xendcg_mart"
  )

  # An explicitly supplied objective overrides the one stored in params
  if (!is.null(obj)) {
    params$objective <- obj
  }
  objective <- params$objective

  # A function objective needs no further validation
  if (is.function(objective)) {
    return(params)
  }

  # Anything that is neither a function nor a string is rejected
  if (!is.character(objective)) {
    stop("lgb.check.obj: objective should be a character or a function")
  }

  # A string objective has to be one of the known names
  if (!(objective %in% known_objectives)) {
    stop("lgb.check.obj: objective name error should be one of (", paste0(known_objectives, collapse = ", "), ")")
  }

  return(params)

}

312
# [description]
313
314
315
316
#     Take any character values from eval and store them in params$metric.
#     This has to account for the fact that `eval` could be a character vector,
#     a function, a list of functions, or a list with a mix of strings and
#     functions
Guolin Ke's avatar
Guolin Ke committed
317
lgb.check.eval <- function(params, eval) {

  # [param] params list of parameters; its `metric` entry is read and rewritten
  # [param] eval character vector, function, or list mixing strings and functions
  # [return] `params` with `metric` set to a list of unique metric names

  # Begin with the metrics already stored in params, always handled as a list.
  # NOTE(review): `$` partially matches list names, so an entry whose name merely
  # starts with "metric" can be picked up here - confirm this is intended.
  metrics <- params$metric
  if (is.null(metrics)) {
    metrics <- list()
  } else if (is.character(metrics)) {
    metrics <- as.list(metrics)
  }

  # Collect the character elements of 'eval'; functions are left alone
  if (!is.function(eval)) {
    for (i in seq_along(eval)) {
      candidate <- eval[[i]]
      if (is.character(candidate)) {
        metrics <- append(metrics, candidate)
      }
    }
  }

  # With more than one metric, the "no metric" placeholders are dropped
  if (length(metrics) > 1L) {
    is_real_metric <- function(metric) {
      !(metric %in% .NO_METRIC_STRINGS())
    }
    metrics <- Filter(f = is_real_metric, x = metrics)
  }

  # Keep each metric only once
  params$metric <- as.list(unique(unlist(metrics)))

  return(params)
}