/*!
 * Copyright (c) 2017 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */

#include "lightgbm_R.h"

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>
#include <LightGBM/utils/text_reader.h>

#include <R_ext/Rdynload.h>

#define R_NO_REMAP
#define R_USE_C99_IN_CXX
#include <R_ext/Error.h>

#include <cstdint>
#include <cstdio>
#include <cstring>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
// Storage-order flag passed to the LGBM_*FromMat C API calls: 0 selects
// column-major order, which is how R lays out matrices.
#define COL_MAJOR (0)

// R_API_BEGIN() / R_API_END() bracket the body of every exported *_R entry
// point. Any C++ exception thrown inside the body is recorded with
// LGBM_SetLastError(). It is then re-raised as an R error, but only after the
// try block has been left, so every C++ local declared inside it has already
// been destroyed. Rf_error() long-jumps and would otherwise skip those
// destructors. Previously the exception was swallowed and R received NULL
// with no indication of failure.
#define R_API_BEGIN() \
  bool lgbm_r_api_failed = false; \
  try {
#define R_API_END() } \
  catch(std::exception& ex) { LGBM_SetLastError(ex.what()); lgbm_r_api_failed = true; } \
  catch(std::string& ex) { LGBM_SetLastError(ex.c_str()); lgbm_r_api_failed = true; } \
  catch(...) { LGBM_SetLastError("unknown exception"); lgbm_r_api_failed = true; } \
  if (lgbm_r_api_failed) { \
    Rf_error("%s", LGBM_GetLastError()); \
  } \
  return R_NilValue;

// Evaluate a LightGBM C API call. A non-zero return code is turned into an
// exception carrying LGBM_GetLastError(), which R_API_END() reports as an R
// error. The message is never used as a printf-style format string.
#define CHECK_CALL(x) \
  do { \
    if ((x) != 0) { \
      throw std::runtime_error(LGBM_GetLastError()); \
    } \
  } while (0)

// Helpers from LightGBM's utility headers used throughout this file.
using LightGBM::Common::Join;
using LightGBM::Common::Split;
using LightGBM::Log;

46
LGBM_SE EncodeChar(LGBM_SE dest, const char* src, SEXP buf_len, SEXP actual_len, size_t str_len) {
Guolin Ke's avatar
Guolin Ke committed
47
  if (str_len > INT32_MAX) {
48
    Log::Fatal("Don't support large string in R-package");
Guolin Ke's avatar
Guolin Ke committed
49
  }
50
  INTEGER(actual_len)[0] = static_cast<int>(str_len);
51
  if (Rf_asInteger(buf_len) < static_cast<int>(str_len)) {
52
53
    return dest;
  }
Guolin Ke's avatar
Guolin Ke committed
54
  auto ptr = R_CHAR_PTR(dest);
Guolin Ke's avatar
Guolin Ke committed
55
  std::memcpy(ptr, src, str_len);
Guolin Ke's avatar
Guolin Ke committed
56
57
58
  return dest;
}

59
60
61
62
63
64
SEXP LGBM_GetLastError_R() {
  SEXP out;
  out = PROTECT(Rf_allocVector(STRSXP, 1));
  SET_STRING_ELT(out, 0, Rf_mkChar(LGBM_GetLastError()));
  UNPROTECT(1);
  return out;
Guolin Ke's avatar
Guolin Ke committed
65
66
}

67
68
SEXP LGBM_DatasetCreateFromFile_R(SEXP filename,
  SEXP parameters,
Guolin Ke's avatar
Guolin Ke committed
69
  LGBM_SE reference,
70
  LGBM_SE out) {
Guolin Ke's avatar
Guolin Ke committed
71
  R_API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
72
  DatasetHandle handle = nullptr;
73
  CHECK_CALL(LGBM_DatasetCreateFromFile(CHAR(Rf_asChar(filename)), CHAR(Rf_asChar(parameters)),
Guolin Ke's avatar
Guolin Ke committed
74
75
76
77
78
    R_GET_PTR(reference), &handle));
  R_SET_PTR(out, handle);
  R_API_END();
}

79
80
81
SEXP LGBM_DatasetCreateFromCSC_R(SEXP indptr,
  SEXP indices,
  SEXP data,
82
83
84
  SEXP num_indptr,
  SEXP nelem,
  SEXP num_row,
85
  SEXP parameters,
Guolin Ke's avatar
Guolin Ke committed
86
  LGBM_SE reference,
87
  LGBM_SE out) {
Guolin Ke's avatar
Guolin Ke committed
88
  R_API_BEGIN();
89
90
91
  const int* p_indptr = INTEGER(indptr);
  const int* p_indices = INTEGER(indices);
  const double* p_data = REAL(data);
Guolin Ke's avatar
Guolin Ke committed
92

93
94
95
  int64_t nindptr = static_cast<int64_t>(Rf_asInteger(num_indptr));
  int64_t ndata = static_cast<int64_t>(Rf_asInteger(nelem));
  int64_t nrow = static_cast<int64_t>(Rf_asInteger(num_row));
Guolin Ke's avatar
Guolin Ke committed
96
  DatasetHandle handle = nullptr;
Guolin Ke's avatar
Guolin Ke committed
97
98
  CHECK_CALL(LGBM_DatasetCreateFromCSC(p_indptr, C_API_DTYPE_INT32, p_indices,
    p_data, C_API_DTYPE_FLOAT64, nindptr, ndata,
99
    nrow, CHAR(Rf_asChar(parameters)), R_GET_PTR(reference), &handle));
Guolin Ke's avatar
Guolin Ke committed
100
101
102
103
  R_SET_PTR(out, handle);
  R_API_END();
}

104
SEXP LGBM_DatasetCreateFromMat_R(SEXP data,
105
106
  SEXP num_row,
  SEXP num_col,
107
  SEXP parameters,
Guolin Ke's avatar
Guolin Ke committed
108
  LGBM_SE reference,
109
  LGBM_SE out) {
Guolin Ke's avatar
Guolin Ke committed
110
  R_API_BEGIN();
111
112
  int32_t nrow = static_cast<int32_t>(Rf_asInteger(num_row));
  int32_t ncol = static_cast<int32_t>(Rf_asInteger(num_col));
113
  double* p_mat = REAL(data);
Guolin Ke's avatar
Guolin Ke committed
114
  DatasetHandle handle = nullptr;
Guolin Ke's avatar
Guolin Ke committed
115
  CHECK_CALL(LGBM_DatasetCreateFromMat(p_mat, C_API_DTYPE_FLOAT64, nrow, ncol, COL_MAJOR,
116
    CHAR(Rf_asChar(parameters)), R_GET_PTR(reference), &handle));
Guolin Ke's avatar
Guolin Ke committed
117
118
119
120
  R_SET_PTR(out, handle);
  R_API_END();
}

121
SEXP LGBM_DatasetGetSubset_R(LGBM_SE handle,
122
  SEXP used_row_indices,
123
  SEXP len_used_row_indices,
124
  SEXP parameters,
125
  LGBM_SE out) {
Guolin Ke's avatar
Guolin Ke committed
126
  R_API_BEGIN();
127
  int len = Rf_asInteger(len_used_row_indices);
Guolin Ke's avatar
Guolin Ke committed
128
129
  std::vector<int> idxvec(len);
  // convert from one-based to  zero-based index
Guolin Ke's avatar
Guolin Ke committed
130
#pragma omp parallel for schedule(static, 512) if (len >= 1024)
Guolin Ke's avatar
Guolin Ke committed
131
  for (int i = 0; i < len; ++i) {
132
    idxvec[i] = INTEGER(used_row_indices)[i] - 1;
Guolin Ke's avatar
Guolin Ke committed
133
  }
Guolin Ke's avatar
Guolin Ke committed
134
  DatasetHandle res = nullptr;
Guolin Ke's avatar
Guolin Ke committed
135
  CHECK_CALL(LGBM_DatasetGetSubset(R_GET_PTR(handle),
136
    idxvec.data(), len, CHAR(Rf_asChar(parameters)),
Guolin Ke's avatar
Guolin Ke committed
137
138
139
140
141
    &res));
  R_SET_PTR(out, res);
  R_API_END();
}

142
SEXP LGBM_DatasetSetFeatureNames_R(LGBM_SE handle,
143
  SEXP feature_names) {
Guolin Ke's avatar
Guolin Ke committed
144
  R_API_BEGIN();
145
  auto vec_names = Split(CHAR(Rf_asChar(feature_names)), '\t');
Guolin Ke's avatar
Guolin Ke committed
146
147
148
149
150
151
152
153
154
155
  std::vector<const char*> vec_sptr;
  int len = static_cast<int>(vec_names.size());
  for (int i = 0; i < len; ++i) {
    vec_sptr.push_back(vec_names[i].c_str());
  }
  CHECK_CALL(LGBM_DatasetSetFeatureNames(R_GET_PTR(handle),
    vec_sptr.data(), len));
  R_API_END();
}

156
SEXP LGBM_DatasetGetFeatureNames_R(LGBM_SE handle,
157
  SEXP buf_len,
158
  SEXP actual_len,
159
  LGBM_SE feature_names) {
Guolin Ke's avatar
Guolin Ke committed
160
161
162
  R_API_BEGIN();
  int len = 0;
  CHECK_CALL(LGBM_DatasetGetNumFeature(R_GET_PTR(handle), &len));
163
  const size_t reserved_string_size = 256;
Guolin Ke's avatar
Guolin Ke committed
164
165
166
  std::vector<std::vector<char>> names(len);
  std::vector<char*> ptr_names(len);
  for (int i = 0; i < len; ++i) {
167
    names[i].resize(reserved_string_size);
Guolin Ke's avatar
Guolin Ke committed
168
169
170
    ptr_names[i] = names[i].data();
  }
  int out_len;
171
172
173
174
175
176
177
  size_t required_string_size;
  CHECK_CALL(
    LGBM_DatasetGetFeatureNames(
      R_GET_PTR(handle),
      len, &out_len,
      reserved_string_size, &required_string_size,
      ptr_names.data()));
Nikita Titov's avatar
Nikita Titov committed
178
  CHECK_EQ(len, out_len);
179
  CHECK_GE(reserved_string_size, required_string_size);
180
  auto merge_str = Join<char*>(ptr_names, "\t");
Guolin Ke's avatar
Guolin Ke committed
181
  EncodeChar(feature_names, merge_str.c_str(), buf_len, actual_len, merge_str.size() + 1);
Guolin Ke's avatar
Guolin Ke committed
182
183
184
  R_API_END();
}

185
SEXP LGBM_DatasetSaveBinary_R(LGBM_SE handle,
186
  SEXP filename) {
Guolin Ke's avatar
Guolin Ke committed
187
188
  R_API_BEGIN();
  CHECK_CALL(LGBM_DatasetSaveBinary(R_GET_PTR(handle),
189
    CHAR(Rf_asChar(filename))));
Guolin Ke's avatar
Guolin Ke committed
190
191
192
  R_API_END();
}

193
SEXP LGBM_DatasetFree_R(LGBM_SE handle) {
Guolin Ke's avatar
Guolin Ke committed
194
195
196
197
198
199
200
201
  R_API_BEGIN();
  if (R_GET_PTR(handle) != nullptr) {
    CHECK_CALL(LGBM_DatasetFree(R_GET_PTR(handle)));
    R_SET_PTR(handle, nullptr);
  }
  R_API_END();
}

202
SEXP LGBM_DatasetSetField_R(LGBM_SE handle,
203
  SEXP field_name,
204
  SEXP field_data,
205
  SEXP num_element) {
Guolin Ke's avatar
Guolin Ke committed
206
  R_API_BEGIN();
207
  int len = static_cast<int>(Rf_asInteger(num_element));
208
  const char* name = CHAR(Rf_asChar(field_name));
Guolin Ke's avatar
Guolin Ke committed
209
210
  if (!strcmp("group", name) || !strcmp("query", name)) {
    std::vector<int32_t> vec(len);
Guolin Ke's avatar
Guolin Ke committed
211
#pragma omp parallel for schedule(static, 512) if (len >= 1024)
Guolin Ke's avatar
Guolin Ke committed
212
    for (int i = 0; i < len; ++i) {
213
      vec[i] = static_cast<int32_t>(INTEGER(field_data)[i]);
Guolin Ke's avatar
Guolin Ke committed
214
215
    }
    CHECK_CALL(LGBM_DatasetSetField(R_GET_PTR(handle), name, vec.data(), len, C_API_DTYPE_INT32));
216
  } else if (!strcmp("init_score", name)) {
217
    CHECK_CALL(LGBM_DatasetSetField(R_GET_PTR(handle), name, REAL(field_data), len, C_API_DTYPE_FLOAT64));
Guolin Ke's avatar
Guolin Ke committed
218
219
  } else {
    std::vector<float> vec(len);
Guolin Ke's avatar
Guolin Ke committed
220
#pragma omp parallel for schedule(static, 512) if (len >= 1024)
Guolin Ke's avatar
Guolin Ke committed
221
    for (int i = 0; i < len; ++i) {
222
      vec[i] = static_cast<float>(REAL(field_data)[i]);
Guolin Ke's avatar
Guolin Ke committed
223
224
225
226
227
228
    }
    CHECK_CALL(LGBM_DatasetSetField(R_GET_PTR(handle), name, vec.data(), len, C_API_DTYPE_FLOAT32));
  }
  R_API_END();
}

229
SEXP LGBM_DatasetGetField_R(LGBM_SE handle,
230
  SEXP field_name,
231
  SEXP field_data) {
Guolin Ke's avatar
Guolin Ke committed
232
  R_API_BEGIN();
233
  const char* name = CHAR(Rf_asChar(field_name));
Guolin Ke's avatar
Guolin Ke committed
234
235
236
237
238
239
240
241
  int out_len = 0;
  int out_type = 0;
  const void* res;
  CHECK_CALL(LGBM_DatasetGetField(R_GET_PTR(handle), name, &out_len, &res, &out_type));

  if (!strcmp("group", name) || !strcmp("query", name)) {
    auto p_data = reinterpret_cast<const int32_t*>(res);
    // convert from boundaries to size
Guolin Ke's avatar
Guolin Ke committed
242
#pragma omp parallel for schedule(static, 512) if (out_len >= 1024)
Guolin Ke's avatar
Guolin Ke committed
243
    for (int i = 0; i < out_len - 1; ++i) {
244
      INTEGER(field_data)[i] = p_data[i + 1] - p_data[i];
Guolin Ke's avatar
Guolin Ke committed
245
    }
Guolin Ke's avatar
Guolin Ke committed
246
247
  } else if (!strcmp("init_score", name)) {
    auto p_data = reinterpret_cast<const double*>(res);
Guolin Ke's avatar
Guolin Ke committed
248
#pragma omp parallel for schedule(static, 512) if (out_len >= 1024)
Guolin Ke's avatar
Guolin Ke committed
249
    for (int i = 0; i < out_len; ++i) {
250
      REAL(field_data)[i] = p_data[i];
Guolin Ke's avatar
Guolin Ke committed
251
    }
Guolin Ke's avatar
Guolin Ke committed
252
253
  } else {
    auto p_data = reinterpret_cast<const float*>(res);
Guolin Ke's avatar
Guolin Ke committed
254
#pragma omp parallel for schedule(static, 512) if (out_len >= 1024)
Guolin Ke's avatar
Guolin Ke committed
255
    for (int i = 0; i < out_len; ++i) {
256
      REAL(field_data)[i] = p_data[i];
Guolin Ke's avatar
Guolin Ke committed
257
258
259
260
261
    }
  }
  R_API_END();
}

262
SEXP LGBM_DatasetGetFieldSize_R(LGBM_SE handle,
263
  SEXP field_name,
264
  SEXP out) {
Guolin Ke's avatar
Guolin Ke committed
265
  R_API_BEGIN();
266
  const char* name = CHAR(Rf_asChar(field_name));
Guolin Ke's avatar
Guolin Ke committed
267
268
269
270
271
272
273
  int out_len = 0;
  int out_type = 0;
  const void* res;
  CHECK_CALL(LGBM_DatasetGetField(R_GET_PTR(handle), name, &out_len, &res, &out_type));
  if (!strcmp("group", name) || !strcmp("query", name)) {
    out_len -= 1;
  }
274
  INTEGER(out)[0] = static_cast<int>(out_len);
Guolin Ke's avatar
Guolin Ke committed
275
276
277
  R_API_END();
}

278
279
SEXP LGBM_DatasetUpdateParamChecking_R(SEXP old_params,
  SEXP new_params) {
280
  R_API_BEGIN();
281
  CHECK_CALL(LGBM_DatasetUpdateParamChecking(CHAR(Rf_asChar(old_params)), CHAR(Rf_asChar(new_params))));
282
283
284
  R_API_END();
}

285
SEXP LGBM_DatasetGetNumData_R(LGBM_SE handle, SEXP out) {
Guolin Ke's avatar
Guolin Ke committed
286
287
288
  int nrow;
  R_API_BEGIN();
  CHECK_CALL(LGBM_DatasetGetNumData(R_GET_PTR(handle), &nrow));
289
  INTEGER(out)[0] = static_cast<int>(nrow);
Guolin Ke's avatar
Guolin Ke committed
290
291
292
  R_API_END();
}

293
SEXP LGBM_DatasetGetNumFeature_R(LGBM_SE handle,
294
  SEXP out) {
Guolin Ke's avatar
Guolin Ke committed
295
296
297
  int nfeature;
  R_API_BEGIN();
  CHECK_CALL(LGBM_DatasetGetNumFeature(R_GET_PTR(handle), &nfeature));
298
  INTEGER(out)[0] = static_cast<int>(nfeature);
Guolin Ke's avatar
Guolin Ke committed
299
300
301
302
303
  R_API_END();
}

// --- start Booster interfaces

304
SEXP LGBM_BoosterFree_R(LGBM_SE handle) {
Guolin Ke's avatar
Guolin Ke committed
305
306
307
308
309
310
311
312
  R_API_BEGIN();
  if (R_GET_PTR(handle) != nullptr) {
    CHECK_CALL(LGBM_BoosterFree(R_GET_PTR(handle)));
    R_SET_PTR(handle, nullptr);
  }
  R_API_END();
}

313
SEXP LGBM_BoosterCreate_R(LGBM_SE train_data,
314
  SEXP parameters,
315
  LGBM_SE out) {
Guolin Ke's avatar
Guolin Ke committed
316
  R_API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
317
  BoosterHandle handle = nullptr;
318
  CHECK_CALL(LGBM_BoosterCreate(R_GET_PTR(train_data), CHAR(Rf_asChar(parameters)), &handle));
Guolin Ke's avatar
Guolin Ke committed
319
320
321
322
  R_SET_PTR(out, handle);
  R_API_END();
}

323
SEXP LGBM_BoosterCreateFromModelfile_R(SEXP filename,
324
  LGBM_SE out) {
Guolin Ke's avatar
Guolin Ke committed
325
326
  R_API_BEGIN();
  int out_num_iterations = 0;
Guolin Ke's avatar
Guolin Ke committed
327
  BoosterHandle handle = nullptr;
328
  CHECK_CALL(LGBM_BoosterCreateFromModelfile(CHAR(Rf_asChar(filename)), &out_num_iterations, &handle));
Guolin Ke's avatar
Guolin Ke committed
329
330
331
332
  R_SET_PTR(out, handle);
  R_API_END();
}

333
SEXP LGBM_BoosterLoadModelFromString_R(SEXP model_str,
334
  LGBM_SE out) {
335
336
  R_API_BEGIN();
  int out_num_iterations = 0;
Guolin Ke's avatar
Guolin Ke committed
337
  BoosterHandle handle = nullptr;
338
  CHECK_CALL(LGBM_BoosterLoadModelFromString(CHAR(Rf_asChar(model_str)), &out_num_iterations, &handle));
339
340
341
342
  R_SET_PTR(out, handle);
  R_API_END();
}

343
344
SEXP LGBM_BoosterMerge_R(LGBM_SE handle,
  LGBM_SE other_handle) {
Guolin Ke's avatar
Guolin Ke committed
345
346
347
348
349
  R_API_BEGIN();
  CHECK_CALL(LGBM_BoosterMerge(R_GET_PTR(handle), R_GET_PTR(other_handle)));
  R_API_END();
}

350
351
SEXP LGBM_BoosterAddValidData_R(LGBM_SE handle,
  LGBM_SE valid_data) {
Guolin Ke's avatar
Guolin Ke committed
352
353
354
355
356
  R_API_BEGIN();
  CHECK_CALL(LGBM_BoosterAddValidData(R_GET_PTR(handle), R_GET_PTR(valid_data)));
  R_API_END();
}

357
358
SEXP LGBM_BoosterResetTrainingData_R(LGBM_SE handle,
  LGBM_SE train_data) {
Guolin Ke's avatar
Guolin Ke committed
359
360
361
362
363
  R_API_BEGIN();
  CHECK_CALL(LGBM_BoosterResetTrainingData(R_GET_PTR(handle), R_GET_PTR(train_data)));
  R_API_END();
}

364
SEXP LGBM_BoosterResetParameter_R(LGBM_SE handle,
365
  SEXP parameters) {
Guolin Ke's avatar
Guolin Ke committed
366
  R_API_BEGIN();
367
  CHECK_CALL(LGBM_BoosterResetParameter(R_GET_PTR(handle), CHAR(Rf_asChar(parameters))));
Guolin Ke's avatar
Guolin Ke committed
368
369
370
  R_API_END();
}

371
SEXP LGBM_BoosterGetNumClasses_R(LGBM_SE handle,
372
  SEXP out) {
Guolin Ke's avatar
Guolin Ke committed
373
374
375
  int num_class;
  R_API_BEGIN();
  CHECK_CALL(LGBM_BoosterGetNumClasses(R_GET_PTR(handle), &num_class));
376
  INTEGER(out)[0] = static_cast<int>(num_class);
Guolin Ke's avatar
Guolin Ke committed
377
378
379
  R_API_END();
}

380
SEXP LGBM_BoosterUpdateOneIter_R(LGBM_SE handle) {
Guolin Ke's avatar
Guolin Ke committed
381
382
383
384
385
386
  int is_finished = 0;
  R_API_BEGIN();
  CHECK_CALL(LGBM_BoosterUpdateOneIter(R_GET_PTR(handle), &is_finished));
  R_API_END();
}

387
SEXP LGBM_BoosterUpdateOneIterCustom_R(LGBM_SE handle,
388
389
  SEXP grad,
  SEXP hess,
390
  SEXP len) {
Guolin Ke's avatar
Guolin Ke committed
391
392
  int is_finished = 0;
  R_API_BEGIN();
393
  int int_len = Rf_asInteger(len);
Guolin Ke's avatar
Guolin Ke committed
394
  std::vector<float> tgrad(int_len), thess(int_len);
Guolin Ke's avatar
Guolin Ke committed
395
#pragma omp parallel for schedule(static, 512) if (int_len >= 1024)
Guolin Ke's avatar
Guolin Ke committed
396
  for (int j = 0; j < int_len; ++j) {
397
398
    tgrad[j] = static_cast<float>(REAL(grad)[j]);
    thess[j] = static_cast<float>(REAL(hess)[j]);
Guolin Ke's avatar
Guolin Ke committed
399
400
401
402
403
  }
  CHECK_CALL(LGBM_BoosterUpdateOneIterCustom(R_GET_PTR(handle), tgrad.data(), thess.data(), &is_finished));
  R_API_END();
}

404
SEXP LGBM_BoosterRollbackOneIter_R(LGBM_SE handle) {
Guolin Ke's avatar
Guolin Ke committed
405
406
407
408
409
  R_API_BEGIN();
  CHECK_CALL(LGBM_BoosterRollbackOneIter(R_GET_PTR(handle)));
  R_API_END();
}

410
SEXP LGBM_BoosterGetCurrentIteration_R(LGBM_SE handle,
411
  SEXP out) {
Guolin Ke's avatar
Guolin Ke committed
412
413
414
  int out_iteration;
  R_API_BEGIN();
  CHECK_CALL(LGBM_BoosterGetCurrentIteration(R_GET_PTR(handle), &out_iteration));
415
  INTEGER(out)[0] = static_cast<int>(out_iteration);
Guolin Ke's avatar
Guolin Ke committed
416
417
418
  R_API_END();
}

419
SEXP LGBM_BoosterGetUpperBoundValue_R(LGBM_SE handle,
420
  SEXP out_result) {
421
  R_API_BEGIN();
422
  double* ptr_ret = REAL(out_result);
423
424
425
426
  CHECK_CALL(LGBM_BoosterGetUpperBoundValue(R_GET_PTR(handle), ptr_ret));
  R_API_END();
}

427
SEXP LGBM_BoosterGetLowerBoundValue_R(LGBM_SE handle,
428
  SEXP out_result) {
429
  R_API_BEGIN();
430
  double* ptr_ret = REAL(out_result);
431
432
433
434
  CHECK_CALL(LGBM_BoosterGetLowerBoundValue(R_GET_PTR(handle), ptr_ret));
  R_API_END();
}

435
SEXP LGBM_BoosterGetEvalNames_R(LGBM_SE handle,
436
  SEXP buf_len,
437
  SEXP actual_len,
438
  LGBM_SE eval_names) {
Guolin Ke's avatar
Guolin Ke committed
439
440
441
  R_API_BEGIN();
  int len;
  CHECK_CALL(LGBM_BoosterGetEvalCounts(R_GET_PTR(handle), &len));
442
443

  const size_t reserved_string_size = 128;
Guolin Ke's avatar
Guolin Ke committed
444
445
446
  std::vector<std::vector<char>> names(len);
  std::vector<char*> ptr_names(len);
  for (int i = 0; i < len; ++i) {
447
    names[i].resize(reserved_string_size);
Guolin Ke's avatar
Guolin Ke committed
448
449
    ptr_names[i] = names[i].data();
  }
450

Guolin Ke's avatar
Guolin Ke committed
451
  int out_len;
452
453
454
455
456
457
458
  size_t required_string_size;
  CHECK_CALL(
    LGBM_BoosterGetEvalNames(
      R_GET_PTR(handle),
      len, &out_len,
      reserved_string_size, &required_string_size,
      ptr_names.data()));
Nikita Titov's avatar
Nikita Titov committed
459
  CHECK_EQ(out_len, len);
460
  CHECK_GE(reserved_string_size, required_string_size);
461
  auto merge_names = Join<char*>(ptr_names, "\t");
Guolin Ke's avatar
Guolin Ke committed
462
  EncodeChar(eval_names, merge_names.c_str(), buf_len, actual_len, merge_names.size() + 1);
Guolin Ke's avatar
Guolin Ke committed
463
464
465
  R_API_END();
}

466
SEXP LGBM_BoosterGetEval_R(LGBM_SE handle,
467
  SEXP data_idx,
468
  SEXP out_result) {
Guolin Ke's avatar
Guolin Ke committed
469
470
471
  R_API_BEGIN();
  int len;
  CHECK_CALL(LGBM_BoosterGetEvalCounts(R_GET_PTR(handle), &len));
472
  double* ptr_ret = REAL(out_result);
Guolin Ke's avatar
Guolin Ke committed
473
  int out_len;
474
  CHECK_CALL(LGBM_BoosterGetEval(R_GET_PTR(handle), Rf_asInteger(data_idx), &out_len, ptr_ret));
Nikita Titov's avatar
Nikita Titov committed
475
  CHECK_EQ(out_len, len);
Guolin Ke's avatar
Guolin Ke committed
476
477
478
  R_API_END();
}

479
SEXP LGBM_BoosterGetNumPredict_R(LGBM_SE handle,
480
  SEXP data_idx,
481
  SEXP out) {
Guolin Ke's avatar
Guolin Ke committed
482
483
  R_API_BEGIN();
  int64_t len;
484
  CHECK_CALL(LGBM_BoosterGetNumPredict(R_GET_PTR(handle), Rf_asInteger(data_idx), &len));
485
  INTEGER(out)[0] = static_cast<int>(len);
Guolin Ke's avatar
Guolin Ke committed
486
487
488
  R_API_END();
}

489
SEXP LGBM_BoosterGetPredict_R(LGBM_SE handle,
490
  SEXP data_idx,
491
  SEXP out_result) {
Guolin Ke's avatar
Guolin Ke committed
492
  R_API_BEGIN();
493
  double* ptr_ret = REAL(out_result);
Guolin Ke's avatar
Guolin Ke committed
494
  int64_t out_len;
495
  CHECK_CALL(LGBM_BoosterGetPredict(R_GET_PTR(handle), Rf_asInteger(data_idx), &out_len, ptr_ret));
Guolin Ke's avatar
Guolin Ke committed
496
497
498
  R_API_END();
}

499
int GetPredictType(SEXP is_rawscore, SEXP is_leafidx, SEXP is_predcontrib) {
Guolin Ke's avatar
Guolin Ke committed
500
  int pred_type = C_API_PREDICT_NORMAL;
501
  if (Rf_asInteger(is_rawscore)) {
Guolin Ke's avatar
Guolin Ke committed
502
503
    pred_type = C_API_PREDICT_RAW_SCORE;
  }
504
  if (Rf_asInteger(is_leafidx)) {
Guolin Ke's avatar
Guolin Ke committed
505
506
    pred_type = C_API_PREDICT_LEAF_INDEX;
  }
507
  if (Rf_asInteger(is_predcontrib)) {
508
509
    pred_type = C_API_PREDICT_CONTRIB;
  }
Guolin Ke's avatar
Guolin Ke committed
510
511
512
  return pred_type;
}

513
SEXP LGBM_BoosterPredictForFile_R(LGBM_SE handle,
514
  SEXP data_filename,
515
516
517
518
519
520
  SEXP data_has_header,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
521
522
  SEXP parameter,
  SEXP result_filename) {
Guolin Ke's avatar
Guolin Ke committed
523
  R_API_BEGIN();
524
  int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
525
526
527
  CHECK_CALL(LGBM_BoosterPredictForFile(R_GET_PTR(handle), CHAR(Rf_asChar(data_filename)),
    Rf_asInteger(data_has_header), pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), CHAR(Rf_asChar(parameter)),
    CHAR(Rf_asChar(result_filename))));
Guolin Ke's avatar
Guolin Ke committed
528
529
530
  R_API_END();
}

531
SEXP LGBM_BoosterCalcNumPredict_R(LGBM_SE handle,
532
533
534
535
536
537
  SEXP num_row,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
538
  SEXP out_len) {
Guolin Ke's avatar
Guolin Ke committed
539
  R_API_BEGIN();
540
  int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
Guolin Ke's avatar
Guolin Ke committed
541
  int64_t len = 0;
542
543
  CHECK_CALL(LGBM_BoosterCalcNumPredict(R_GET_PTR(handle), Rf_asInteger(num_row),
    pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), &len));
544
  INTEGER(out_len)[0] = static_cast<int>(len);
Guolin Ke's avatar
Guolin Ke committed
545
546
547
  R_API_END();
}

548
SEXP LGBM_BoosterPredictForCSC_R(LGBM_SE handle,
549
550
551
  SEXP indptr,
  SEXP indices,
  SEXP data,
552
553
554
555
556
557
558
559
  SEXP num_indptr,
  SEXP nelem,
  SEXP num_row,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
560
  SEXP parameter,
561
  SEXP out_result) {
Guolin Ke's avatar
Guolin Ke committed
562
  R_API_BEGIN();
563
  int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
Guolin Ke's avatar
Guolin Ke committed
564

565
566
567
  const int* p_indptr = INTEGER(indptr);
  const int* p_indices = INTEGER(indices);
  const double* p_data = REAL(data);
Guolin Ke's avatar
Guolin Ke committed
568

569
570
571
  int64_t nindptr = Rf_asInteger(num_indptr);
  int64_t ndata = Rf_asInteger(nelem);
  int64_t nrow = Rf_asInteger(num_row);
572
  double* ptr_ret = REAL(out_result);
Guolin Ke's avatar
Guolin Ke committed
573
574
575
576
  int64_t out_len;
  CHECK_CALL(LGBM_BoosterPredictForCSC(R_GET_PTR(handle),
    p_indptr, C_API_DTYPE_INT32, p_indices,
    p_data, C_API_DTYPE_FLOAT64, nindptr, ndata,
577
    nrow, pred_type,  Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), CHAR(Rf_asChar(parameter)), &out_len, ptr_ret));
Guolin Ke's avatar
Guolin Ke committed
578
579
580
  R_API_END();
}

581
SEXP LGBM_BoosterPredictForMat_R(LGBM_SE handle,
582
  SEXP data,
583
584
585
586
587
588
589
  SEXP num_row,
  SEXP num_col,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
590
  SEXP parameter,
591
  SEXP out_result) {
Guolin Ke's avatar
Guolin Ke committed
592
  R_API_BEGIN();
593
  int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
Guolin Ke's avatar
Guolin Ke committed
594

595
596
  int32_t nrow = Rf_asInteger(num_row);
  int32_t ncol = Rf_asInteger(num_col);
Guolin Ke's avatar
Guolin Ke committed
597

598
599
  const double* p_mat = REAL(data);
  double* ptr_ret = REAL(out_result);
Guolin Ke's avatar
Guolin Ke committed
600
601
602
  int64_t out_len;
  CHECK_CALL(LGBM_BoosterPredictForMat(R_GET_PTR(handle),
    p_mat, C_API_DTYPE_FLOAT64, nrow, ncol, COL_MAJOR,
603
    pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), CHAR(Rf_asChar(parameter)), &out_len, ptr_ret));
Guolin Ke's avatar
Guolin Ke committed
604
605
606
607

  R_API_END();
}

608
SEXP LGBM_BoosterSaveModel_R(LGBM_SE handle,
609
610
  SEXP num_iteration,
  SEXP feature_importance_type,
611
  SEXP filename) {
Guolin Ke's avatar
Guolin Ke committed
612
  R_API_BEGIN();
613
  CHECK_CALL(LGBM_BoosterSaveModel(R_GET_PTR(handle), 0, Rf_asInteger(num_iteration), Rf_asInteger(feature_importance_type), CHAR(Rf_asChar(filename))));
Guolin Ke's avatar
Guolin Ke committed
614
615
616
  R_API_END();
}

617
SEXP LGBM_BoosterSaveModelToString_R(LGBM_SE handle,
618
619
620
  SEXP num_iteration,
  SEXP feature_importance_type,
  SEXP buffer_len,
621
  SEXP actual_len,
622
  LGBM_SE out_str) {
623
  R_API_BEGIN();
624
  int64_t out_len = 0;
625
626
627
  int64_t buf_len = static_cast<int64_t>(Rf_asInteger(buffer_len));
  std::vector<char> inner_char_buf(buf_len);
  CHECK_CALL(LGBM_BoosterSaveModelToString(R_GET_PTR(handle), 0, Rf_asInteger(num_iteration), Rf_asInteger(feature_importance_type), buf_len, &out_len, inner_char_buf.data()));
Guolin Ke's avatar
Guolin Ke committed
628
  EncodeChar(out_str, inner_char_buf.data(), buffer_len, actual_len, static_cast<size_t>(out_len));
629
630
631
  R_API_END();
}

632
SEXP LGBM_BoosterDumpModel_R(LGBM_SE handle,
633
634
635
  SEXP num_iteration,
  SEXP feature_importance_type,
  SEXP buffer_len,
636
  SEXP actual_len,
637
  LGBM_SE out_str) {
Guolin Ke's avatar
Guolin Ke committed
638
  R_API_BEGIN();
639
  int64_t out_len = 0;
640
641
642
  int64_t buf_len = static_cast<int64_t>(Rf_asInteger(buffer_len));
  std::vector<char> inner_char_buf(buf_len);
  CHECK_CALL(LGBM_BoosterDumpModel(R_GET_PTR(handle), 0, Rf_asInteger(num_iteration), Rf_asInteger(feature_importance_type), buf_len, &out_len, inner_char_buf.data()));
Guolin Ke's avatar
Guolin Ke committed
643
  EncodeChar(out_str, inner_char_buf.data(), buffer_len, actual_len, static_cast<size_t>(out_len));
Guolin Ke's avatar
Guolin Ke committed
644
645
  R_API_END();
}
646
647
648

// .Call() calls
static const R_CallMethodDef CallEntries[] = {
649
  {"LGBM_GetLastError_R"              , (DL_FUNC) &LGBM_GetLastError_R              , 0},
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
  {"LGBM_DatasetCreateFromFile_R"     , (DL_FUNC) &LGBM_DatasetCreateFromFile_R     , 4},
  {"LGBM_DatasetCreateFromCSC_R"      , (DL_FUNC) &LGBM_DatasetCreateFromCSC_R      , 9},
  {"LGBM_DatasetCreateFromMat_R"      , (DL_FUNC) &LGBM_DatasetCreateFromMat_R      , 6},
  {"LGBM_DatasetGetSubset_R"          , (DL_FUNC) &LGBM_DatasetGetSubset_R          , 5},
  {"LGBM_DatasetSetFeatureNames_R"    , (DL_FUNC) &LGBM_DatasetSetFeatureNames_R    , 2},
  {"LGBM_DatasetGetFeatureNames_R"    , (DL_FUNC) &LGBM_DatasetGetFeatureNames_R    , 4},
  {"LGBM_DatasetSaveBinary_R"         , (DL_FUNC) &LGBM_DatasetSaveBinary_R         , 2},
  {"LGBM_DatasetFree_R"               , (DL_FUNC) &LGBM_DatasetFree_R               , 1},
  {"LGBM_DatasetSetField_R"           , (DL_FUNC) &LGBM_DatasetSetField_R           , 4},
  {"LGBM_DatasetGetFieldSize_R"       , (DL_FUNC) &LGBM_DatasetGetFieldSize_R       , 3},
  {"LGBM_DatasetGetField_R"           , (DL_FUNC) &LGBM_DatasetGetField_R           , 3},
  {"LGBM_DatasetUpdateParamChecking_R", (DL_FUNC) &LGBM_DatasetUpdateParamChecking_R, 2},
  {"LGBM_DatasetGetNumData_R"         , (DL_FUNC) &LGBM_DatasetGetNumData_R         , 2},
  {"LGBM_DatasetGetNumFeature_R"      , (DL_FUNC) &LGBM_DatasetGetNumFeature_R      , 2},
  {"LGBM_BoosterCreate_R"             , (DL_FUNC) &LGBM_BoosterCreate_R             , 3},
  {"LGBM_BoosterFree_R"               , (DL_FUNC) &LGBM_BoosterFree_R               , 1},
  {"LGBM_BoosterCreateFromModelfile_R", (DL_FUNC) &LGBM_BoosterCreateFromModelfile_R, 2},
  {"LGBM_BoosterLoadModelFromString_R", (DL_FUNC) &LGBM_BoosterLoadModelFromString_R, 2},
  {"LGBM_BoosterMerge_R"              , (DL_FUNC) &LGBM_BoosterMerge_R              , 2},
  {"LGBM_BoosterAddValidData_R"       , (DL_FUNC) &LGBM_BoosterAddValidData_R       , 2},
  {"LGBM_BoosterResetTrainingData_R"  , (DL_FUNC) &LGBM_BoosterResetTrainingData_R  , 2},
  {"LGBM_BoosterResetParameter_R"     , (DL_FUNC) &LGBM_BoosterResetParameter_R     , 2},
  {"LGBM_BoosterGetNumClasses_R"      , (DL_FUNC) &LGBM_BoosterGetNumClasses_R      , 2},
  {"LGBM_BoosterUpdateOneIter_R"      , (DL_FUNC) &LGBM_BoosterUpdateOneIter_R      , 1},
  {"LGBM_BoosterUpdateOneIterCustom_R", (DL_FUNC) &LGBM_BoosterUpdateOneIterCustom_R, 4},
  {"LGBM_BoosterRollbackOneIter_R"    , (DL_FUNC) &LGBM_BoosterRollbackOneIter_R    , 1},
  {"LGBM_BoosterGetCurrentIteration_R", (DL_FUNC) &LGBM_BoosterGetCurrentIteration_R, 2},
  {"LGBM_BoosterGetUpperBoundValue_R" , (DL_FUNC) &LGBM_BoosterGetUpperBoundValue_R , 2},
  {"LGBM_BoosterGetLowerBoundValue_R" , (DL_FUNC) &LGBM_BoosterGetLowerBoundValue_R , 2},
  {"LGBM_BoosterGetEvalNames_R"       , (DL_FUNC) &LGBM_BoosterGetEvalNames_R       , 4},
  {"LGBM_BoosterGetEval_R"            , (DL_FUNC) &LGBM_BoosterGetEval_R            , 3},
  {"LGBM_BoosterGetNumPredict_R"      , (DL_FUNC) &LGBM_BoosterGetNumPredict_R      , 3},
  {"LGBM_BoosterGetPredict_R"         , (DL_FUNC) &LGBM_BoosterGetPredict_R         , 3},
  {"LGBM_BoosterPredictForFile_R"     , (DL_FUNC) &LGBM_BoosterPredictForFile_R     , 10},
  {"LGBM_BoosterCalcNumPredict_R"     , (DL_FUNC) &LGBM_BoosterCalcNumPredict_R     , 8},
  {"LGBM_BoosterPredictForCSC_R"      , (DL_FUNC) &LGBM_BoosterPredictForCSC_R      , 14},
  {"LGBM_BoosterPredictForMat_R"      , (DL_FUNC) &LGBM_BoosterPredictForMat_R      , 11},
  {"LGBM_BoosterSaveModel_R"          , (DL_FUNC) &LGBM_BoosterSaveModel_R          , 4},
  {"LGBM_BoosterSaveModelToString_R"  , (DL_FUNC) &LGBM_BoosterSaveModelToString_R  , 6},
  {"LGBM_BoosterDumpModel_R"          , (DL_FUNC) &LGBM_BoosterDumpModel_R          , 6},
690
691
692
  {NULL, NULL, 0}
};

693
694
LIGHTGBM_C_EXPORT void R_init_lightgbm(DllInfo *dll);

695
696
697
698
void R_init_lightgbm(DllInfo *dll) {
  R_registerRoutines(dll, NULL, CallEntries, NULL, NULL);
  R_useDynamicSymbols(dll, FALSE);
}