utils.R 3.93 KB
Newer Older
1
2
3
4
5
6
7
8
# Is `x` an lgb.Booster? TRUE only for an R6 object whose class vector
# contains "lgb.Booster"; the check is delegated to lgb.check.r6.class().
lgb.is.Booster <- function(x) {
  lgb.check.r6.class(object = x, name = "lgb.Booster")
}

# Is `x` an lgb.Dataset? TRUE only for an R6 object whose class vector
# contains "lgb.Dataset"; the check is delegated to lgb.check.r6.class().
lgb.is.Dataset <- function(x) {
  lgb.check.r6.class(object = x, name = "lgb.Dataset")
}

# Create an empty native handle: a single double (8 bytes of storage) that
# is meant to hold an address; 0 marks "no object yet".
#
# A new vector is allocated on every call instead of returning the literal
# `0.0`. Handles are presumably filled in place by native code (lgb.call()
# hands its `ret` buffer to .Call() and returns it unchanged); returning a
# literal would share one object with this function's body, so such a write
# could leak a previous address into the next "new" handle.
lgb.new.handle <- function() {
  numeric(1L)
}

# Does `x` refer to no native object?
#
# TRUE for NULL, for a zero-length value, for NA, and for the 0 value
# produced by lgb.new.handle(); FALSE for any other single value.
# isTRUE() keeps the result a plain TRUE/FALSE, so callers can always use it
# in if(); a bare `x == 0` would give NA (or logical(0)) for NA / empty input
# and make such an if() fail.
lgb.is.null.handle <- function(x) {
  is.null(x) || length(x) == 0L || isTRUE(is.na(x)) || isTRUE(x == 0)
}

# Decode the first `len` bytes of a raw buffer into one character string.
#
# arr: raw vector (e.g. a buffer filled by a native call).
# len: number of leading bytes to decode.
# Errors when `arr` is not a raw vector.
lgb.encode.char <- function(arr, len) {
  if (is.raw(arr)) {
    leading_bytes <- arr[seq_len(len)]
    return(rawToChar(leading_bytes))
  }
  stop("lgb.encode.char: Can only encode from raw type")
}

# Invoke a native routine of the "lightgbm" shared library via .Call() and
# turn a non-zero status into an R error.
#
# fun_name: name of the registered native routine.
# ret:      optional output buffer, passed after `...`. The R side returns it
#           unchanged, so the routine is expected to fill it in place.
#           Use NULL when the routine has no output buffer.
# ...:      input arguments, forwarded positionally before `ret`.
#
# The routine always receives a status slot as its last argument and returns
# it. Any value other than 0 makes this function fetch the library's last
# error message (LGBM_GetLastError_R) and raise it with stop().
lgb.call <- function(fun_name, ret, ...) {
  # Fresh status vector on every call rather than the literal `0L`: native
  # code that writes into its arguments can otherwise alter the constant
  # stored in this function's body and poison later calls.
  call_state <- integer(1L)
  if (is.null(ret)) {
    call_state <- .Call(fun_name, ..., call_state, PACKAGE = "lightgbm")
  } else {
    call_state <- .Call(fun_name, ..., ret, call_state, PACKAGE = "lightgbm")
  }

  if (call_state != 0L) {
    # Try a small buffer first. LGBM_GetLastError_R reports the required size
    # through `act_len`, which is read back below without being reassigned,
    # i.e. it is written in place -- so it must be a fresh vector too.
    buf_len <- 200L
    act_len <- integer(1L)
    err_msg <- .Call("LGBM_GetLastError_R", buf_len, act_len, raw(buf_len),
                     PACKAGE = "lightgbm")
    if (act_len > buf_len) {
      # The message did not fit: retry with a buffer of the reported size.
      buf_len <- act_len
      err_msg <- .Call("LGBM_GetLastError_R", buf_len, act_len, raw(buf_len),
                       PACKAGE = "lightgbm")
    }
    stop("api error: ", lgb.encode.char(err_msg, act_len), call. = FALSE)
  }

  ret
}


# Call a native routine that writes a string into a caller-supplied raw
# buffer and return that string.
#
# fun_name: native routine name, forwarded to lgb.call().
# ...:      routine inputs; the buffer length and an output-length slot are
#           appended after them, followed by the buffer itself (as `ret`).
#
# A 1 MiB buffer is tried first. If the routine reports (via `act_len`) that
# more space is needed, the call is repeated once with a buffer of that size.
lgb.call.return.str <- function(fun_name, ...) {
  buf_len <- 1024L * 1024L
  # Written in place by the native side (read back below without being
  # reassigned), so allocate a fresh vector instead of binding a literal.
  act_len <- integer(1L)

  buf <- lgb.call(fun_name, ret = raw(buf_len), ..., buf_len, act_len)
  if (act_len > buf_len) {
    buf_len <- act_len
    buf <- lgb.call(fun_name, ret = raw(buf_len), ..., buf_len, act_len)
  }

  lgb.encode.char(buf, act_len)
}

# Serialise a parameter list into a NUL-terminated "key=value key=value"
# raw buffer (built by lgb.c_str()).
#
# params: named list of parameters.
# ...:    extra named parameters merged into `params`; a name present in both
#         is an error.
#
# Dots in names are rewritten to underscores (so `num.leaves` becomes
# `num_leaves`). Multi-valued entries are joined with commas, and entries
# whose joined value is empty (e.g. NULL) are dropped.
lgb.params2str <- function(params, ...) {
  if (!is.list(params)) {
    stop("params must be a list")
  }
  names(params) <- gsub("\\.", "_", names(params))

  # Merge parameters from `params` and the dots-expansion.
  dot_params <- list(...)
  names(dot_params) <- gsub("\\.", "_", names(dot_params))
  if (length(intersect(names(params), names(dot_params))) > 0L) {
    stop(
      "Same parameters in ", sQuote("params"), " and in the call are not allowed. Please check your ", sQuote("params"), " list"
    )
  }
  params <- c(params, dot_params)

  # Values are looked up by key (as before), then joined vectorially instead
  # of growing a result list inside a loop.
  keys <- names(params)
  vals <- vapply(
    keys,
    function(key) paste0(params[[key]], collapse = ","),
    character(1L),
    USE.NAMES = FALSE
  )
  keep <- nchar(vals) > 0L

  # paste0() recycles zero-length inputs to "", which would yield "=" here,
  # so the "nothing to serialise" case is handled explicitly.
  if (!any(keep)) {
    return(lgb.c_str(""))
  }
  lgb.c_str(paste0(keys[keep], "=", vals[keep], collapse = " "))
}

# Convert `x` (coerced with as.character) to its raw bytes followed by a
# single NUL byte, i.e. a C-style string buffer for native calls.
lgb.c_str <- function(x) {
  text_bytes <- charToRaw(as.character(x))
  append(text_bytes, as.raw(0L))
}

# TRUE when `object`'s class vector contains both "R6" and `name`,
# FALSE otherwise. `name` is expected to be a single class name.
lgb.check.r6.class <- function(object, name) {
  object_classes <- class(object)
  ("R6" %in% object_classes) && (name %in% object_classes)
}

# Hook for validating training parameters. No validation is implemented yet
# (to-do), so `params` is handed back exactly as received.
lgb.check.params <- function(params) {
  validated <- params
  validated
}

# Resolve and validate the training objective.
#
# params: parameter list; its "objective" entry is checked.
# obj:    optional objective (name or function); when non-NULL it overrides
#         params$objective.
# Returns `params` (with the objective set when `obj` was given).
# Errors unless the objective is one of the built-in names below or a
# function (custom objective).
lgb.check.obj <- function(params, obj) {
  OBJECTIVES <- c("regression", "binary", "multiclass", "lambdarank")
  if (!is.null(obj)) {
    params$objective <- obj
  }

  # `[[` matches exactly; `params$objective` would partially match names
  # such as "objective_seed" when no "objective" entry exists.
  objective <- params[["objective"]]
  if (is.character(objective)) {
    # Require a single name so the `%in%` test is a scalar condition.
    if (length(objective) != 1L || !(objective %in% OBJECTIVES)) {
      stop("lgb.check.obj: objective name error should be one of (", paste0(OBJECTIVES, collapse = ", "), ")")
    }
  } else if (!is.function(objective)) {
    stop("lgb.check.obj: objective should be a character or a function")
  }

  params
}

# Merge evaluation metrics into params$metric and add a default metric when
# none is configured.
#
# params: parameter list (its "metric" and "objective" entries are used).
# eval:   NULL, metric name(s) as a character vector or list (appended to
#         params$metric), or a function (custom evaluation, not appended).
#
# When `eval` is not a function and no metric is configured, a default
# metric is chosen from a single objective name. If the objective is
# something else (e.g. a custom objective function), the metric is left
# empty instead of failing inside switch().
# Errors for an objective name with no default metric.
lgb.check.eval <- function(params, eval) {
  # Read with `[[` (exact match) so that e.g. "metric_freq" is not mistaken
  # for "metric"; `$<-` assignment is always exact.
  if (is.null(params[["metric"]])) {
    params$metric <- list()
  }
  if (is.character(eval) || is.list(eval)) {
    params$metric <- append(params[["metric"]], eval)
  }

  objective <- params[["objective"]]
  needs_default <- !is.function(eval) && length(params[["metric"]]) == 0L
  if (needs_default && is.character(objective) && length(objective) == 1L) {
    params$metric <- switch(
      objective,
      regression = "l2",
      binary     = "binary_logloss",
      multiclass = "multi_logloss",
      lambdarank = "ndcg",
      stop("lgb.check.eval: No default metric available for objective ", sQuote(objective))
    )
  }

  params
}