utils.R 7.59 KB
Newer Older
1
2
3
# [description] Tell whether `x` is a LightGBM Booster object, i.e. whether its
#               class vector contains both "R6" and "lgb.Booster".
lgb.is.Booster <- function(x) {
  booster_class <- "lgb.Booster"
  return(lgb.check.r6.class(object = x, name = booster_class))
}
4

5
6
7
# [description] Tell whether `x` is a LightGBM Dataset object, i.e. whether its
#               class vector contains both "R6" and "lgb.Dataset".
lgb.is.Dataset <- function(x) {
  dataset_class <- "lgb.Dataset"
  return(lgb.check.r6.class(object = x, name = dataset_class))
}
8

9
# [description] Tell whether a handle is unset: TRUE when `x` is NULL or NA,
#               FALSE otherwise. `x` is expected to be NULL or a length-one value,
#               because the result feeds a scalar `||`.
lgb.is.null.handle <- function(x) {
  # is.null() is tested first so that is.na() is never evaluated on NULL
  is.null(x) || is.na(x)
}
Guolin Ke's avatar
Guolin Ke committed
12
13

# [description] Convert the first `len` bytes of a raw vector into a single
#               character string.
#               `arr` must be a raw vector, otherwise an error is raised.
#               `len` is the number of leading bytes of `arr` to decode.
lgb.encode.char <- function(arr, len) {

  # Only raw vectors can be decoded
  if (!is.raw(arr)) {
    stop("lgb.encode.char: Can only encode from raw type")
  }

  # seq_len() keeps a zero length safe: it selects nothing instead of c(1, 0)
  return(rawToChar(arr[seq_len(len)]))

}

22
23
# [description] Raise an error. Before raising that error, check for any error message
#               stored in a buffer on the C++ side.
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# Fetches the last error message through the native "LGBM_GetLastError_R" routine
# and raises it as an R error prefixed with "api error: ". This function always
# ends by signalling an error; it has no normal return value.
lgb.last_error <- function() {
  # First attempt: hand the native routine a 200-byte raw buffer
  buf_len <- 200L
  act_len <- 0L
  err_msg <- raw(buf_len)
  err_msg <- .Call(
    "LGBM_GetLastError_R"
    , buf_len
    , act_len
    , err_msg
    , PACKAGE = "lib_lightgbm"
  )

  # If the reported message length exceeds the buffer, retry with a buffer of
  # exactly that length.
  # NOTE(review): act_len is never reassigned on the R side, so this check
  # presumably relies on the native routine updating act_len in place - confirm
  # against the C++ implementation.
  if (act_len > buf_len) {
    buf_len <- act_len
    err_msg <- raw(buf_len)
    err_msg <- .Call(
      "LGBM_GetLastError_R"
      , buf_len
      , act_len
      , err_msg
      , PACKAGE = "lib_lightgbm"
    )
  }

  # Decode the first act_len bytes of the buffer and raise them as the error text
  stop("api error: ", lgb.encode.char(err_msg, act_len))
}

Guolin Ke's avatar
Guolin Ke committed
54
# [description] Invoke a native routine from "lib_lightgbm" and turn a non-zero
#               status into an R error via lgb.last_error().
#               `fun_name` is the name of the native routine to call.
#               `ret` is an optional object handed to the routine after the `...`
#               arguments; it is also what this function returns. When `ret` is
#               NULL it is not passed to the routine at all.
#               `...` are forwarded to the routine ahead of `ret`.
#               NOTE(review): `ret` is returned without being reassigned here, so
#               any result presumably reaches it through in-place modification on
#               the native side - confirm against the C++ implementation.
lgb.call <- function(fun_name, ret, ...) {

  # Status slot passed as the last argument of the native routine
  call_state <- 0L

  # The routine only receives `ret` when the caller supplied one
  if (!is.null(ret)) {
    call_state <- .Call(
      fun_name
      , ...
      , ret
      , call_state
      , PACKAGE = "lib_lightgbm"
    )
  } else {
    call_state <- .Call(
      fun_name
      , ...
      , call_state
      , PACKAGE = "lib_lightgbm"
    )
  }
  call_state <- as.integer(call_state)

  # Any non-zero status is an error; lgb.last_error() raises it
  if (call_state != 0L) {
    lgb.last_error()
  }

  return(ret)

}

# [description] Call a native routine that writes text into a raw buffer and
#               return that text as a character string.
#               `fun_name` is the native routine name, forwarded to lgb.call().
#               `...` are forwarded to the routine ahead of the buffer length and
#               actual-length arguments.
#               NOTE(review): `act_len` is never reassigned on the R side, so the
#               resize check presumably relies on the native routine updating it
#               in place - confirm against the C++ implementation.
lgb.call.return.str <- function(fun_name, ...) {

  # Start with a 1 MiB buffer
  buf_len <- as.integer(1024L * 1024L)
  act_len <- 0L
  buf <- raw(buf_len)

  # First attempt
  buf <- lgb.call(fun_name, ret = buf, ..., buf_len, act_len)

  # Retry with a buffer of the reported length when the first one was too small
  if (act_len > buf_len) {
    buf_len <- act_len
    buf <- raw(buf_len)
    buf <- lgb.call(fun_name, ret = buf, ..., buf_len, act_len)
  }

  # Decode the first act_len bytes of the buffer
  return(lgb.encode.char(buf, act_len))

}

# [description] Serialize parameters into one "key=value key=value" string and
#               return it as a NUL-terminated raw vector (via lgb.c_str()).
#               `params` must be a plain list (class exactly "list").
#               `...` are additional named parameters; a name may not appear in
#               both `params` and `...`.
#               Dots in parameter names are replaced by underscores. Values of
#               length greater than one are joined with commas. Parameters whose
#               formatted value is empty are dropped.
lgb.params2str <- function(params, ...) {

  # Check for a list as input
  if (!identical(class(params), "list")) {
    stop("params must be a list")
  }

  # Normalize parameter names: "a.b" becomes "a_b"
  names(params) <- gsub("\\.", "_", names(params))

  # Collect and normalize the parameters passed through the dots
  dot_params <- list(...)
  names(dot_params) <- gsub("\\.", "_", names(dot_params))

  # The same parameter may not be given in both places
  if (length(intersect(names(params), names(dot_params))) > 0L) {
    stop(
      "Same parameters in "
      , sQuote("params")
      , " and in the call are not allowed. Please check your "
      , sQuote("params")
      , " list"
    )
  }

  # Merge parameters
  params <- c(params, dot_params)
  keys <- names(params)

  # Format every value as a single string. If a parameter has multiple values,
  # join those values together with commas.
  # trimws() is necessary because format() will pad to make strings the same width
  vals <- vapply(
    X = keys
    , FUN = function(key) {
      paste0(
        trimws(
          format(
            x = params[[key]]
            , scientific = FALSE
          )
        )
        , collapse = ","
      )
    }
    , FUN.VALUE = character(1L)
    , USE.NAMES = FALSE
  )

  # Skip parameters whose value formatted to an empty string
  keep <- nchar(vals) > 0L
  pairs <- paste0(keys[keep], "=", vals[keep])

  # Return string separated by a space per element ("" when nothing is left)
  return(lgb.c_str(paste0(pairs, collapse = " ")))

}

170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
# [description] Validate params[["interaction_constraints"]] and convert every
#               constraint into a string of 0-based feature indices such as "[0,2]".
#               `params` is a list that may hold an "interaction_constraints"
#               element: a list whose elements are character vectors (feature
#               names) or numeric vectors (1-based feature indices).
#               `column_names` is the character vector of feature names used to
#               resolve names and to bound numeric indices.
#               Returns an unnamed list with one string per constraint, or an
#               empty list when no constraints are given. Raises an error for
#               unsupported types, unknown feature names, and missing or
#               out-of-range numeric indices.
lgb.check_interaction_constraints <- function(params, column_names) {

  constraints <- params[["interaction_constraints"]]

  # Nothing to convert when no constraints were supplied
  if (is.null(constraints)) {
    return(list())
  }

  # validation
  if (!methods::is(constraints, "list")) {
    stop("interaction_constraints must be a list")
  }
  is_supported <- vapply(
    X = constraints
    , FUN = function(x) {
      is.character(x) || is.numeric(x)
    }
    , FUN.VALUE = logical(1L)
  )
  if (!all(is_supported)) {
    stop("every element in interaction_constraints must be a character vector or numeric vector")
  }

  num_features <- length(column_names)

  string_constraints <- lapply(
    X = constraints
    , FUN = function(constraint) {

      if (is.character(constraint)) {

        # Resolve feature names to 0-based positions in column_names
        constraint_indices <- as.integer(match(constraint, column_names) - 1L)

        # Names that are not in column_names come back as NA
        unknown <- is.na(constraint_indices)
        if (any(unknown)) {
          stop(
            "supplied an unknown feature in interaction_constraints "
            , paste0(sQuote(constraint[unknown]), collapse = ", ")
          )
        }

      } else {

        if (anyNA(constraint)) {
          stop("supplied a missing value in interaction_constraints")
        }

        # Range checks only make sense for non-empty constraints;
        # max() and min() warn and return infinities on empty input
        if (length(constraint) > 0L) {

          # Check that constraint indices are at most number of features
          if (max(constraint) > num_features) {
            stop(
              "supplied a too large value in interaction_constraints: "
              , max(constraint)
              , " but only "
              , num_features
              , " features"
            )
          }

          # Indices are 1-based, so anything below 1 would turn into a
          # negative 0-based index
          if (min(constraint) < 1L) {
            stop(
              "supplied a too small value in interaction_constraints: "
              , min(constraint)
              , " but feature indices start at 1"
            )
          }

        }

        # Store indices as [0, n-1] indexed instead of [1, n] indexed
        constraint_indices <- as.integer(constraint - 1L)

      }

      # Convert constraint to string
      paste0("[", paste0(constraint_indices, collapse = ","), "]")

    }
  )

  # Drop any names carried over from the input list
  return(unname(string_constraints))

}

Guolin Ke's avatar
Guolin Ke committed
229
# [description] Convert a value into a C-style string: the bytes of
#               as.character(x) followed by a terminating zero byte, returned
#               as a raw vector.
lgb.c_str <- function(x) {

  # Perform character to raw conversion
  ret <- charToRaw(as.character(x))

  # Append the NUL terminator
  ret <- c(ret, as.raw(0L))

  return(ret)

}

# [description] Tell whether the class vector of `object` contains both "R6" and
#               every class named in `name`. Returns a single TRUE or FALSE.
lgb.check.r6.class <- function(object, name) {

  # %in% never yields NA, so all() always returns TRUE or FALSE here
  return(all(c("R6", name) %in% class(object)))

}

# [description] Store the objective in `params` and validate it.
#               `params` is a list of parameters.
#               `obj` is an objective name, an objective function, or NULL to keep
#               whatever `params` already holds under "objective".
#               Returns `params` with "objective" set. Raises an error when the
#               objective is a character value that is not a known objective
#               name, or when it is neither a character nor a function.
lgb.check.obj <- function(params, obj) {

  # List known objectives in a vector
  OBJECTIVES <- c(
    "regression"
    , "regression_l1"
    , "regression_l2"
    , "mean_squared_error"
    , "mse"
    , "l2_root"
    , "root_mean_squared_error"
    , "rmse"
    , "mean_absolute_error"
    , "mae"
    , "quantile"
    , "huber"
    , "fair"
    , "poisson"
    , "binary"
    , "lambdarank"
    , "multiclass"
    , "softmax"
    , "multiclassova"
    , "multiclass_ova"
    , "ova"
    , "ovr"
    , "xentropy"
    , "cross_entropy"
    , "xentlambda"
    , "cross_entropy_lambda"
    , "mean_absolute_percentage_error"
    , "mape"
    , "gamma"
    , "tweedie"
    , "rank_xendcg"
    , "xendcg"
    , "xe_ndcg"
    , "xe_ndcg_mart"
    , "xendcg_mart"
  )

  # An explicitly supplied objective overrides the one stored in params
  if (!is.null(obj)) {
    params[["objective"]] <- obj
  }

  # [[ is used instead of $ so that a different name starting with "objective"
  # is never picked up through partial matching
  objective <- params[["objective"]]

  if (is.character(objective)) {

    # If the objective is a character, check if it is a known objective
    if (!(objective %in% OBJECTIVES)) {
      stop("lgb.check.obj: objective name error should be one of (", paste0(OBJECTIVES, collapse = ", "), ")")
    }

  } else if (!is.function(objective)) {

    # If objective is not a character nor a function, then stop
    stop("lgb.check.obj: objective should be a character or a function")

  }

  # Return parameters
  return(params)

}

# [description] Merge evaluation metrics passed through `eval` into the "metric"
#               entry of `params`.
#               `params` is a list of parameters; a missing "metric" entry is
#               initialized to an empty list.
#               `eval` is appended to "metric" when it is a character vector or a
#               plain list (class exactly "list"); any other value is ignored here.
#               Returns `params` with "metric" set.
lgb.check.eval <- function(params, eval) {

  # [[ is used instead of $ so that another parameter whose name merely starts
  # with "metric" (e.g. "metric_freq") is never mistaken for the metric list
  # through partial matching
  if (is.null(params[["metric"]])) {
    params[["metric"]] <- list()
  }

  # If 'eval' is a list or character vector, store it in 'metric'
  if (is.character(eval) || identical(class(eval), "list")) {
    params[["metric"]] <- append(params[["metric"]], eval)
  }

  return(params)
}