lightgbm_R.cpp 33.7 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2017 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
5
6

#include "lightgbm_R.h"
Guolin Ke's avatar
Guolin Ke committed
7

8
9
10
11
12
13
14
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>
#include <LightGBM/utils/text_reader.h>

#include <R_ext/Rdynload.h>

15
16
17
18
#define R_NO_REMAP
#define R_USE_C99_IN_CXX
#include <R_ext/Error.h>

19
20
21
22
23
24
#include <string>
#include <cstdio>
#include <cstring>
#include <memory>
#include <utility>
#include <vector>
25
#include <algorithm>
26

Guolin Ke's avatar
Guolin Ke committed
27
28
// Layout flag handed to LGBM_DatasetCreateFromMat() by LGBM_DatasetCreateFromMat_R().
// NOTE(review): presumably 0 means "not row-major", i.e. R's column-major layout - confirm against the C API.
#define COL_MAJOR (0)

29
30
31
32
33
34
// Capacity, in bytes, of the buffer that receives the text of a caught C++ exception.
#define MAX_LENGTH_ERR_MSG 1024
// Written by LGBM_R_save_exception_msg() and later passed to Rf_error() by R_API_END().
char R_errmsg_buffer[MAX_LENGTH_ERR_MSG];
// Exception type thrown by throw_R_memerr(); carries the continuation token
// that R_API_END() hands to R_ContinueUnwind().
struct LGBM_R_ErrorClass { SEXP cont_token; };
// Copy the message of a caught exception into R_errmsg_buffer.
void LGBM_R_save_exception_msg(const std::exception &err);
void LGBM_R_save_exception_msg(const std::string &err);

Guolin Ke's avatar
Guolin Ke committed
35
36
37
// R_API_BEGIN() / R_API_END() bracket the body of every exported function.
//
// R_API_BEGIN() opens a try block. R_API_END() closes it and translates any
// C++ exception that escaped the body into an R error:
//   * LGBM_R_ErrorClass      -> resume the interrupted R unwind via R_ContinueUnwind()
//   * std::exception / std::string -> message copied into R_errmsg_buffer, then
//                                     raised with Rf_error() after the handlers
//   * anything else          -> Rf_error("unknown exception")
//
// The buffered message is passed to Rf_error() as an argument of a fixed "%s"
// format rather than as the format string itself, so a '%' inside an exception
// message (for example in a file name) cannot be interpreted as a conversion
// specification.
#define R_API_BEGIN() \
  try {
#define R_API_END() } \
  catch(LGBM_R_ErrorClass &cont) { R_ContinueUnwind(cont.cont_token); } \
  catch(std::exception& ex) { LGBM_R_save_exception_msg(ex); } \
  catch(std::string& ex) { LGBM_R_save_exception_msg(ex); } \
  catch(...) { Rf_error("unknown exception"); } \
  Rf_error("%s", R_errmsg_buffer); \
  return R_NilValue; /* <- won't be reached */
Guolin Ke's avatar
Guolin Ke committed
44
45
46

// Evaluate a LightGBM C API call and, when it reports a non-zero status,
// throw a std::runtime_error carrying LGBM_GetLastError().
#define CHECK_CALL(x) \
  if (0 != (x)) { throw std::runtime_error(LGBM_GetLastError()); }

50
51
52
53
54
55
56
57
58
59
60
61
62
63
// These are helper functions to allow doing a stack unwind
// after an R allocation error, which would trigger a long jump.
// Write the what() text of a caught exception, followed by a newline, into
// R_errmsg_buffer; snprintf bounds the write to MAX_LENGTH_ERR_MSG bytes.
void LGBM_R_save_exception_msg(const std::exception &err) {
  const char* what_text = err.what();
  std::snprintf(R_errmsg_buffer, MAX_LENGTH_ERR_MSG, "%s\n", what_text);
}

// Overload for errors that were thrown as a plain std::string.
void LGBM_R_save_exception_msg(const std::string &err) {
  const char* text = err.c_str();
  std::snprintf(R_errmsg_buffer, MAX_LENGTH_ERR_MSG, "%s\n", text);
}

// R_UnwindProtect callback: allocate a character vector (STRSXP) whose length
// is read from the R_xlen_t that `len` points to.
SEXP wrapped_R_string(void *len) {
  const R_xlen_t n_elements = *static_cast<R_xlen_t*>(len);
  return Rf_allocVector(STRSXP, n_elements);
}

64
65
66
67
// R_UnwindProtect callback: allocate a raw vector (RAWSXP) whose length is
// read from the R_xlen_t that `len` points to.
SEXP wrapped_R_raw(void *len) {
  const R_xlen_t n_bytes = *static_cast<R_xlen_t*>(len);
  return Rf_allocVector(RAWSXP, n_bytes);
}

68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
// R_UnwindProtect callback: build a CHARSXP from the C string `txt` points to.
SEXP wrapped_Rf_mkChar(void *txt) {
  const char* c_string = static_cast<char*>(txt);
  return Rf_mkChar(c_string);
}

void throw_R_memerr(void *ptr_cont_token, Rboolean jump) {
  if (jump) {
    LGBM_R_ErrorClass err{*(reinterpret_cast<SEXP*>(ptr_cont_token))};
    throw err;
  }
}

// Allocate a character vector of `len` elements through R_UnwindProtect, with
// throw_R_memerr as the cleanup callback and `*cont_token` as the continuation token.
SEXP safe_R_string(R_xlen_t len, SEXP *cont_token) {
  void* alloc_arg = static_cast<void*>(&len);
  return R_UnwindProtect(wrapped_R_string, alloc_arg, throw_R_memerr, cont_token, *cont_token);
}

83
84
85
86
// Allocate a raw vector of `len` bytes through R_UnwindProtect, with
// throw_R_memerr as the cleanup callback and `*cont_token` as the continuation token.
SEXP safe_R_raw(R_xlen_t len, SEXP *cont_token) {
  void* alloc_arg = static_cast<void*>(&len);
  return R_UnwindProtect(wrapped_R_raw, alloc_arg, throw_R_memerr, cont_token, *cont_token);
}

87
88
89
90
// Create a CHARSXP from `txt` through R_UnwindProtect, with throw_R_memerr as
// the cleanup callback and `*cont_token` as the continuation token.
SEXP safe_R_mkChar(char *txt, SEXP *cont_token) {
  void* text_arg = static_cast<void*>(txt);
  return R_UnwindProtect(wrapped_Rf_mkChar, text_arg, throw_R_memerr, cont_token, *cont_token);
}

91
92
using LightGBM::Common::Split;
using LightGBM::Log;
Guolin Ke's avatar
Guolin Ke committed
93

94
95
96
97
// Return an R logical scalar that is TRUE when the external pointer `handle`
// holds a null address.
SEXP LGBM_HandleIsNull_R(SEXP handle) {
  const bool address_is_null = (R_ExternalPtrAddr(handle) == nullptr);
  return Rf_ScalarLogical(address_is_null);
}

98
99
100
101
// Finalizer registered on Dataset external pointers; frees the Dataset via
// LGBM_DatasetFree_R() and discards its return value.
void _DatasetFinalizer(SEXP handle) {
  static_cast<void>(LGBM_DatasetFree_R(handle));
}

102
103
104
105
106
107
108
109
// Raise the R error reported when a Booster handle is null.
// Rf_error() does not return; the trailing return only satisfies the signature.
SEXP LGBM_NullBoosterHandleError_R() {
  Rf_error(
      "Attempting to use a Booster which no longer exists and/or cannot be restored. "
      "This can happen if you have called Booster$finalize() "
      "or if this Booster was saved through saveRDS() using 'serializable=FALSE'.");

  return R_NilValue;
}

110
111
// Raise the "Booster no longer exists" R error unless `handle` is a non-NULL
// R object whose external pointer address is non-null.
void _AssertBoosterHandleNotNull(SEXP handle) {
  const bool usable = !Rf_isNull(handle) && R_ExternalPtrAddr(handle) != nullptr;
  if (!usable) {
    LGBM_NullBoosterHandleError_R();
  }
}

// Raise an R error unless `handle` is a non-NULL R object whose external
// pointer address is non-null.
void _AssertDatasetHandleNotNull(SEXP handle) {
  const bool usable = !Rf_isNull(handle) && R_ExternalPtrAddr(handle) != nullptr;
  if (usable) {
    return;
  }
  Rf_error(
    "Attempting to use a Dataset which no longer exists. "
    "This can happen if you have called Dataset$finalize() or if this Dataset was saved with saveRDS(). "
    "To avoid this error in the future, use lgb.Dataset.save() or Dataset$save_binary() to save lightgbm Datasets.");
}

125
126
// Create a Dataset from a file.
//   filename   - coerced with Rf_asChar(); path handed to LGBM_DatasetCreateFromFile()
//   parameters - coerced with Rf_asChar(); parameter string for the C API
//   reference  - R NULL, or an external pointer to a Dataset passed as the reference
// Returns an external pointer wrapping the new Dataset, with _DatasetFinalizer registered.
SEXP LGBM_DatasetCreateFromFile_R(SEXP filename,
  SEXP parameters,
  SEXP reference) {
  R_API_BEGIN();
  SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  DatasetHandle ref = Rf_isNull(reference) ? nullptr : R_ExternalPtrAddr(reference);
  SEXP filename_chr = PROTECT(Rf_asChar(filename));
  SEXP parameters_chr = PROTECT(Rf_asChar(parameters));
  DatasetHandle handle = nullptr;
  CHECK_CALL(LGBM_DatasetCreateFromFile(CHAR(filename_chr), CHAR(parameters_chr), ref, &handle));
  R_SetExternalPtrAddr(ret, handle);
  R_RegisterCFinalizerEx(ret, _DatasetFinalizer, TRUE);
  UNPROTECT(3);  // ret, filename_chr, parameters_chr
  return ret;
  R_API_END();
}

145
146
147
// Create a Dataset from sparse data via LGBM_DatasetCreateFromCSC().
//   indptr, indices - integer vectors (passed as C_API_DTYPE_INT32)
//   data            - double vector (passed as C_API_DTYPE_FLOAT64)
//   num_indptr, nelem, num_row - scalars read with Rf_asInteger()
//   parameters      - coerced with Rf_asChar()
//   reference       - R NULL, or an external pointer to a reference Dataset
// Returns an external pointer wrapping the new Dataset, with _DatasetFinalizer registered.
SEXP LGBM_DatasetCreateFromCSC_R(SEXP indptr,
  SEXP indices,
  SEXP data,
  SEXP num_indptr,
  SEXP nelem,
  SEXP num_row,
  SEXP parameters,
  SEXP reference) {
  R_API_BEGIN();
  SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  const int* col_ptr = INTEGER(indptr);
  const int* row_idx = INTEGER(indices);
  const double* values = REAL(data);
  const int64_t n_col_ptr = static_cast<int64_t>(Rf_asInteger(num_indptr));
  const int64_t n_values = static_cast<int64_t>(Rf_asInteger(nelem));
  const int64_t n_rows = static_cast<int64_t>(Rf_asInteger(num_row));
  SEXP parameters_chr = PROTECT(Rf_asChar(parameters));
  DatasetHandle ref = Rf_isNull(reference) ? nullptr : R_ExternalPtrAddr(reference);
  DatasetHandle handle = nullptr;
  CHECK_CALL(LGBM_DatasetCreateFromCSC(col_ptr, C_API_DTYPE_INT32, row_idx,
    values, C_API_DTYPE_FLOAT64, n_col_ptr, n_values,
    n_rows, CHAR(parameters_chr), ref, &handle));
  R_SetExternalPtrAddr(ret, handle);
  R_RegisterCFinalizerEx(ret, _DatasetFinalizer, TRUE);
  UNPROTECT(2);  // ret, parameters_chr
  return ret;
  R_API_END();
}

177
// Create a Dataset from a double matrix via LGBM_DatasetCreateFromMat().
//   data             - double vector/matrix (passed as C_API_DTYPE_FLOAT64 with COL_MAJOR)
//   num_row, num_col - scalars read with Rf_asInteger()
//   parameters       - coerced with Rf_asChar()
//   reference        - R NULL, or an external pointer to a reference Dataset
// Returns an external pointer wrapping the new Dataset, with _DatasetFinalizer registered.
SEXP LGBM_DatasetCreateFromMat_R(SEXP data,
  SEXP num_row,
  SEXP num_col,
  SEXP parameters,
  SEXP reference) {
  R_API_BEGIN();
  SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  const int32_t n_rows = static_cast<int32_t>(Rf_asInteger(num_row));
  const int32_t n_cols = static_cast<int32_t>(Rf_asInteger(num_col));
  double* matrix = REAL(data);
  SEXP parameters_chr = PROTECT(Rf_asChar(parameters));
  DatasetHandle ref = Rf_isNull(reference) ? nullptr : R_ExternalPtrAddr(reference);
  DatasetHandle handle = nullptr;
  CHECK_CALL(LGBM_DatasetCreateFromMat(matrix, C_API_DTYPE_FLOAT64, n_rows, n_cols, COL_MAJOR,
    CHAR(parameters_chr), ref, &handle));
  R_SetExternalPtrAddr(ret, handle);
  R_RegisterCFinalizerEx(ret, _DatasetFinalizer, TRUE);
  UNPROTECT(2);  // ret, parameters_chr
  return ret;
  R_API_END();
}

202
// Create a new Dataset holding a subset of the rows of `handle`.
//   handle               - external pointer to the source Dataset (must be non-null)
//   used_row_indices     - integer vector of one-based row indices
//   len_used_row_indices - number of indices, read with Rf_asInteger()
//   parameters           - coerced with Rf_asChar()
// Returns an external pointer wrapping the subset, with _DatasetFinalizer registered.
SEXP LGBM_DatasetGetSubset_R(SEXP handle,
  SEXP used_row_indices,
  SEXP len_used_row_indices,
  SEXP parameters) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  int32_t len = static_cast<int32_t>(Rf_asInteger(len_used_row_indices));
  // Resolve the data pointer once, on the calling thread and before any C++
  // object with a destructor exists: the R API must not be called from inside
  // the OpenMP region, and INTEGER() can raise an R error (a long jump).
  const int* p_used_row_indices = (len > 0) ? INTEGER(used_row_indices) : nullptr;
  std::vector<int32_t> idxvec(len);
  // convert from one-based to zero-based index
#pragma omp parallel for schedule(static, 512) if (len >= 1024)
  for (int32_t i = 0; i < len; ++i) {
    idxvec[i] = static_cast<int32_t>(p_used_row_indices[i] - 1);
  }
  const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters)));
  DatasetHandle res = nullptr;
  CHECK_CALL(LGBM_DatasetGetSubset(R_ExternalPtrAddr(handle),
    idxvec.data(), len, parameters_ptr,
    &res));
  R_SetExternalPtrAddr(ret, res);
  R_RegisterCFinalizerEx(ret, _DatasetFinalizer, TRUE);
  UNPROTECT(2);  // ret and the coerced parameters string
  return ret;
  R_API_END();
}

228
// Set the feature names of a Dataset.
//   handle        - external pointer to the Dataset (must be non-null)
//   feature_names - coerced with Rf_asChar(); a single string that is split on '\t'
// Returns R NULL.
SEXP LGBM_DatasetSetFeatureNames_R(SEXP handle,
  SEXP feature_names) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  SEXP joined_names = PROTECT(Rf_asChar(feature_names));
  auto split_names = Split(CHAR(joined_names), '\t');
  // The C API takes an array of C strings; these pointers borrow from split_names.
  std::vector<const char*> name_ptrs;
  name_ptrs.reserve(split_names.size());
  for (const auto& one_name : split_names) {
    name_ptrs.push_back(one_name.c_str());
  }
  const int n_names = static_cast<int>(split_names.size());
  CHECK_CALL(LGBM_DatasetSetFeatureNames(R_ExternalPtrAddr(handle),
    name_ptrs.data(), n_names));
  UNPROTECT(1);  // joined_names
  return R_NilValue;
  R_API_END();
}

245
// Return the feature names of a Dataset as an R character vector.
//   handle - external pointer to the Dataset (must be non-null)
// Names are first fetched into 256-byte buffers; if the C API reports that a
// larger buffer is required, every buffer is enlarged and the call is repeated.
SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) {
  SEXP cont_token = PROTECT(R_MakeUnwindCont());
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  int len = 0;
  CHECK_CALL(LGBM_DatasetGetNumFeature(R_ExternalPtrAddr(handle), &len));

  const size_t reserved_string_size = 256;
  std::vector<std::vector<char>> buffers(len);
  std::vector<char*> buffer_ptrs(len);
  for (int i = 0; i < len; ++i) {
    buffers[i].resize(reserved_string_size);
    buffer_ptrs[i] = buffers[i].data();
  }

  int out_len;
  size_t required_string_size;
  CHECK_CALL(
    LGBM_DatasetGetFeatureNames(
      R_ExternalPtrAddr(handle),
      len,
      &out_len,
      reserved_string_size,
      &required_string_size,
      buffer_ptrs.data()));

  // if any feature names were larger than allocated size,
  // allow for a larger size and try again
  if (required_string_size > reserved_string_size) {
    const size_t enlarged_size = required_string_size;
    for (int i = 0; i < len; ++i) {
      buffers[i].resize(enlarged_size);
      buffer_ptrs[i] = buffers[i].data();
    }
    CHECK_CALL(
      LGBM_DatasetGetFeatureNames(
        R_ExternalPtrAddr(handle),
        len,
        &out_len,
        enlarged_size,
        &required_string_size,
        buffer_ptrs.data()));
  }
  CHECK_EQ(len, out_len);

  // R allocations go through the safe_* wrappers, which use R_UnwindProtect.
  SEXP feature_names = PROTECT(safe_R_string(static_cast<R_xlen_t>(len), &cont_token));
  for (int i = 0; i < len; ++i) {
    SET_STRING_ELT(feature_names, i, safe_R_mkChar(buffer_ptrs[i], &cont_token));
  }
  UNPROTECT(2);  // cont_token, feature_names
  return feature_names;
  R_API_END();
}

293
// Save a Dataset through LGBM_DatasetSaveBinary().
//   handle   - external pointer to the Dataset (must be non-null)
//   filename - coerced with Rf_asChar(); destination passed to the C API
// Returns R NULL.
SEXP LGBM_DatasetSaveBinary_R(SEXP handle,
  SEXP filename) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  SEXP filename_chr = PROTECT(Rf_asChar(filename));
  CHECK_CALL(LGBM_DatasetSaveBinary(R_ExternalPtrAddr(handle), CHAR(filename_chr)));
  UNPROTECT(1);  // filename_chr
  return R_NilValue;
  R_API_END();
}

305
// Free the Dataset behind `handle` and clear the external pointer.
// Does nothing when `handle` is R NULL or its address is already null.
// Returns R NULL.
SEXP LGBM_DatasetFree_R(SEXP handle) {
  R_API_BEGIN();
  void* dataset = Rf_isNull(handle) ? nullptr : R_ExternalPtrAddr(handle);
  if (dataset != nullptr) {
    CHECK_CALL(LGBM_DatasetFree(dataset));
    R_ClearExternalPtr(handle);
  }
  return R_NilValue;
  R_API_END();
}

315
// Set a field of a Dataset.
//   handle      - external pointer to the Dataset (must be non-null)
//   field_name  - coerced with Rf_asChar()
//   field_data  - integer vector for "group"/"query", double vector otherwise
//   num_element - number of elements, read with Rf_asInteger()
// "group"/"query" are passed as int32, "init_score" as float64 without copying,
// and every other field is narrowed to float32 first.
// Returns R NULL.
SEXP LGBM_DatasetSetField_R(SEXP handle,
  SEXP field_name,
  SEXP field_data,
  SEXP num_element) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  int len = Rf_asInteger(num_element);
  const char* name = CHAR(PROTECT(Rf_asChar(field_name)));
  if (!strcmp("group", name) || !strcmp("query", name)) {
    // Resolve the data pointer once, on the calling thread and before the
    // vector is built: the R API must not be called from inside the OpenMP
    // region, and INTEGER() can raise an R error (a long jump).
    const int* src = (len > 0) ? INTEGER(field_data) : nullptr;
    std::vector<int32_t> vec(len);
#pragma omp parallel for schedule(static, 512) if (len >= 1024)
    for (int i = 0; i < len; ++i) {
      vec[i] = static_cast<int32_t>(src[i]);
    }
    CHECK_CALL(LGBM_DatasetSetField(R_ExternalPtrAddr(handle), name, vec.data(), len, C_API_DTYPE_INT32));
  } else if (!strcmp("init_score", name)) {
    CHECK_CALL(LGBM_DatasetSetField(R_ExternalPtrAddr(handle), name, REAL(field_data), len, C_API_DTYPE_FLOAT64));
  } else {
    // Same reasoning as above for REAL().
    const double* src = (len > 0) ? REAL(field_data) : nullptr;
    std::vector<float> vec(len);
#pragma omp parallel for schedule(static, 512) if (len >= 1024)
    for (int i = 0; i < len; ++i) {
      vec[i] = static_cast<float>(src[i]);
    }
    CHECK_CALL(LGBM_DatasetSetField(R_ExternalPtrAddr(handle), name, vec.data(), len, C_API_DTYPE_FLOAT32));
  }
  UNPROTECT(1);  // the coerced field name
  return R_NilValue;
  R_API_END();
}

345
// Copy a field of a Dataset into a caller-allocated R vector.
//   handle     - external pointer to the Dataset (must be non-null)
//   field_name - coerced with Rf_asChar()
//   field_data - output: integer vector for "group"/"query", double vector otherwise
// For "group"/"query" the C API returns boundaries; consecutive differences
// (out_len - 1 values) are written. "init_score" is read as float64 and every
// other field as float32.
// NOTE(review): the length of field_data is not checked here; the caller is
// presumably expected to size it using LGBM_DatasetGetFieldSize_R.
// Returns R NULL.
SEXP LGBM_DatasetGetField_R(SEXP handle,
  SEXP field_name,
  SEXP field_data) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  const char* name = CHAR(PROTECT(Rf_asChar(field_name)));
  int out_len = 0;
  int out_type = 0;
  const void* res;
  CHECK_CALL(LGBM_DatasetGetField(R_ExternalPtrAddr(handle), name, &out_len, &res, &out_type));
  // In each branch the destination pointer is resolved once, on the calling
  // thread, because the R API must not be called from inside an OpenMP region.
  if (!strcmp("group", name) || !strcmp("query", name)) {
    auto p_data = reinterpret_cast<const int32_t*>(res);
    const int num_sizes = out_len - 1;
    int* dst = (num_sizes > 0) ? INTEGER(field_data) : nullptr;
    // convert from boundaries to size
#pragma omp parallel for schedule(static, 512) if (out_len >= 1024)
    for (int i = 0; i < num_sizes; ++i) {
      dst[i] = p_data[i + 1] - p_data[i];
    }
  } else if (!strcmp("init_score", name)) {
    auto p_data = reinterpret_cast<const double*>(res);
    double* dst = (out_len > 0) ? REAL(field_data) : nullptr;
#pragma omp parallel for schedule(static, 512) if (out_len >= 1024)
    for (int i = 0; i < out_len; ++i) {
      dst[i] = p_data[i];
    }
  } else {
    auto p_data = reinterpret_cast<const float*>(res);
    double* dst = (out_len > 0) ? REAL(field_data) : nullptr;
#pragma omp parallel for schedule(static, 512) if (out_len >= 1024)
    for (int i = 0; i < out_len; ++i) {
      dst[i] = p_data[i];
    }
  }
  UNPROTECT(1);  // the coerced field name
  return R_NilValue;
  R_API_END();
}

380
// Write the number of elements of a Dataset field into out[0].
//   handle     - external pointer to the Dataset (must be non-null)
//   field_name - coerced with Rf_asChar()
//   out        - integer vector; element 0 receives the size
// For "group"/"query" the reported length is reduced by one.
// Returns R NULL.
SEXP LGBM_DatasetGetFieldSize_R(SEXP handle,
  SEXP field_name,
  SEXP out) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  SEXP field_name_chr = PROTECT(Rf_asChar(field_name));
  const char* name = CHAR(field_name_chr);
  int n_elements = 0;
  int dtype = 0;
  const void* field_ptr;
  CHECK_CALL(LGBM_DatasetGetField(R_ExternalPtrAddr(handle), name, &n_elements, &field_ptr, &dtype));
  const bool is_group_field = (strcmp("group", name) == 0) || (strcmp("query", name) == 0);
  if (is_group_field) {
    --n_elements;
  }
  INTEGER(out)[0] = n_elements;
  UNPROTECT(1);  // field_name_chr
  return R_NilValue;
  R_API_END();
}

399
400
// Forward two parameter strings to LGBM_DatasetUpdateParamChecking().
//   old_params, new_params - each coerced with Rf_asChar()
// Returns R NULL; a non-zero status from the C API becomes an R error.
SEXP LGBM_DatasetUpdateParamChecking_R(SEXP old_params,
  SEXP new_params) {
  R_API_BEGIN();
  SEXP old_chr = PROTECT(Rf_asChar(old_params));
  SEXP new_chr = PROTECT(Rf_asChar(new_params));
  CHECK_CALL(LGBM_DatasetUpdateParamChecking(CHAR(old_chr), CHAR(new_chr)));
  UNPROTECT(2);  // old_chr, new_chr
  return R_NilValue;
  R_API_END();
}

410
// Write the value reported by LGBM_DatasetGetNumData() into out[0].
//   handle - external pointer to the Dataset (must be non-null)
//   out    - integer vector; element 0 receives the count
// Returns R NULL.
SEXP LGBM_DatasetGetNumData_R(SEXP handle, SEXP out) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  int n_data = 0;
  CHECK_CALL(LGBM_DatasetGetNumData(R_ExternalPtrAddr(handle), &n_data));
  int* out_ptr = INTEGER(out);
  out_ptr[0] = n_data;
  return R_NilValue;
  R_API_END();
}

420
// Write the value reported by LGBM_DatasetGetNumFeature() into out[0].
//   handle - external pointer to the Dataset (must be non-null)
//   out    - integer vector; element 0 receives the count
// Returns R NULL.
SEXP LGBM_DatasetGetNumFeature_R(SEXP handle,
  SEXP out) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  int n_features = 0;
  CHECK_CALL(LGBM_DatasetGetNumFeature(R_ExternalPtrAddr(handle), &n_features));
  int* out_ptr = INTEGER(out);
  out_ptr[0] = n_features;
  return R_NilValue;
  R_API_END();
}

431
432
433
434
435
436
437
438
439
440
441
// Write the value reported by LGBM_DatasetGetFeatureNumBin() into out[0].
//   handle      - external pointer to the Dataset (must be non-null)
//   feature_idx - feature index, read with Rf_asInteger() and passed through unchanged
//   out         - integer vector; element 0 receives the bin count
// Returns R NULL.
SEXP LGBM_DatasetGetFeatureNumBin_R(SEXP handle, SEXP feature_idx, SEXP out) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  const int which_feature = Rf_asInteger(feature_idx);
  int bin_count = 0;
  CHECK_CALL(LGBM_DatasetGetFeatureNumBin(R_ExternalPtrAddr(handle), which_feature, &bin_count));
  int* out_ptr = INTEGER(out);
  out_ptr[0] = bin_count;
  return R_NilValue;
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
442
443
// --- start Booster interfaces

444
445
446
447
// Finalizer registered on Booster external pointers; frees the Booster via
// LGBM_BoosterFree_R() and discards its return value.
void _BoosterFinalizer(SEXP handle) {
  static_cast<void>(LGBM_BoosterFree_R(handle));
}

448
// Free the Booster behind `handle` and clear the external pointer.
// Does nothing when `handle` is R NULL or its address is already null.
// Returns R NULL.
SEXP LGBM_BoosterFree_R(SEXP handle) {
  R_API_BEGIN();
  void* booster = Rf_isNull(handle) ? nullptr : R_ExternalPtrAddr(handle);
  if (booster != nullptr) {
    CHECK_CALL(LGBM_BoosterFree(booster));
    R_ClearExternalPtr(handle);
  }
  return R_NilValue;
  R_API_END();
}

458
459
// Create a Booster from a training Dataset.
//   train_data - external pointer to the Dataset (must be non-null)
//   parameters - coerced with Rf_asChar()
// Returns an external pointer wrapping the Booster, with _BoosterFinalizer registered.
SEXP LGBM_BoosterCreate_R(SEXP train_data,
  SEXP parameters) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(train_data);
  SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  SEXP parameters_chr = PROTECT(Rf_asChar(parameters));
  BoosterHandle booster = nullptr;
  CHECK_CALL(LGBM_BoosterCreate(R_ExternalPtrAddr(train_data), CHAR(parameters_chr), &booster));
  R_SetExternalPtrAddr(ret, booster);
  R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE);
  UNPROTECT(2);  // ret, parameters_chr
  return ret;
  R_API_END();
}

473
// Create a Booster from a model file.
//   filename - coerced with Rf_asChar(); path handed to LGBM_BoosterCreateFromModelfile()
// The iteration count reported by the C API is received but not returned.
// Returns an external pointer wrapping the Booster, with _BoosterFinalizer registered.
SEXP LGBM_BoosterCreateFromModelfile_R(SEXP filename) {
  R_API_BEGIN();
  SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  SEXP filename_chr = PROTECT(Rf_asChar(filename));
  int n_iterations = 0;
  BoosterHandle booster = nullptr;
  CHECK_CALL(LGBM_BoosterCreateFromModelfile(CHAR(filename_chr), &n_iterations, &booster));
  R_SetExternalPtrAddr(ret, booster);
  R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE);
  UNPROTECT(2);  // ret, filename_chr
  return ret;
  R_API_END();
}

487
// Create a Booster from a serialized model held in an R object.
//   model_str - a raw vector, a CHARSXP, or a character vector whose first
//               element is the model text
// Any other type, or an empty character vector, raises an R error instead of
// handing a null or out-of-bounds pointer to the C API.
// The iteration count reported by the C API is received but not returned.
// Returns an external pointer wrapping the Booster, with _BoosterFinalizer registered.
SEXP LGBM_BoosterLoadModelFromString_R(SEXP model_str) {
  R_API_BEGIN();
  SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  SEXP temp = nullptr;
  int n_protected = 1;
  int out_num_iterations = 0;
  const char* model_str_ptr = nullptr;
  switch (TYPEOF(model_str)) {
    case RAWSXP: {
      model_str_ptr = reinterpret_cast<const char*>(RAW(model_str));
      break;
    }
    case CHARSXP: {
      model_str_ptr = reinterpret_cast<const char*>(CHAR(model_str));
      break;
    }
    case STRSXP: {
      // STRING_ELT(x, 0) is only valid when the vector has at least one element.
      if (Rf_xlength(model_str) < 1) {
        throw std::runtime_error("model_str is an empty character vector");
      }
      temp = PROTECT(STRING_ELT(model_str, 0));
      n_protected++;
      model_str_ptr = reinterpret_cast<const char*>(CHAR(temp));
      break;
    }
    default: {
      // Caught by R_API_END() and turned into an R error.
      throw std::runtime_error("model_str must be a raw vector or a character string");
    }
  }
  BoosterHandle handle = nullptr;
  CHECK_CALL(LGBM_BoosterLoadModelFromString(model_str_ptr, &out_num_iterations, &handle));
  R_SetExternalPtrAddr(ret, handle);
  R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE);
  UNPROTECT(n_protected);
  return ret;
  R_API_END();
}

518
519
// Call LGBM_BoosterMerge() on two Boosters.
//   handle, other_handle - external pointers to Boosters (both must be non-null)
// Returns R NULL.
SEXP LGBM_BoosterMerge_R(SEXP handle,
  SEXP other_handle) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  _AssertBoosterHandleNotNull(other_handle);
  void* target = R_ExternalPtrAddr(handle);
  void* source = R_ExternalPtrAddr(other_handle);
  CHECK_CALL(LGBM_BoosterMerge(target, source));
  return R_NilValue;
  R_API_END();
}

528
529
// Call LGBM_BoosterAddValidData() with a Booster and a Dataset.
//   handle     - external pointer to the Booster (must be non-null)
//   valid_data - external pointer to the Dataset (must be non-null)
// Returns R NULL.
SEXP LGBM_BoosterAddValidData_R(SEXP handle,
  SEXP valid_data) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  _AssertDatasetHandleNotNull(valid_data);
  void* booster = R_ExternalPtrAddr(handle);
  void* dataset = R_ExternalPtrAddr(valid_data);
  CHECK_CALL(LGBM_BoosterAddValidData(booster, dataset));
  return R_NilValue;
  R_API_END();
}

538
539
// Call LGBM_BoosterResetTrainingData() with a Booster and a Dataset.
//   handle     - external pointer to the Booster (must be non-null)
//   train_data - external pointer to the Dataset (must be non-null)
// Returns R NULL.
SEXP LGBM_BoosterResetTrainingData_R(SEXP handle,
  SEXP train_data) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  _AssertDatasetHandleNotNull(train_data);
  void* booster = R_ExternalPtrAddr(handle);
  void* dataset = R_ExternalPtrAddr(train_data);
  CHECK_CALL(LGBM_BoosterResetTrainingData(booster, dataset));
  return R_NilValue;
  R_API_END();
}

548
// Call LGBM_BoosterResetParameter() with a parameter string.
//   handle     - external pointer to the Booster (must be non-null)
//   parameters - coerced with Rf_asChar()
// Returns R NULL.
SEXP LGBM_BoosterResetParameter_R(SEXP handle,
  SEXP parameters) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  SEXP parameters_chr = PROTECT(Rf_asChar(parameters));
  CHECK_CALL(LGBM_BoosterResetParameter(R_ExternalPtrAddr(handle), CHAR(parameters_chr)));
  UNPROTECT(1);  // parameters_chr
  return R_NilValue;
  R_API_END();
}

559
// Write the value reported by LGBM_BoosterGetNumClasses() into out[0].
//   handle - external pointer to the Booster (must be non-null)
//   out    - integer vector; element 0 receives the count
// Returns R NULL.
SEXP LGBM_BoosterGetNumClasses_R(SEXP handle,
  SEXP out) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int n_classes = 0;
  CHECK_CALL(LGBM_BoosterGetNumClasses(R_ExternalPtrAddr(handle), &n_classes));
  int* out_ptr = INTEGER(out);
  out_ptr[0] = n_classes;
  return R_NilValue;
  R_API_END();
}

570
571
572
573
574
575
576
577
578
// Return the value reported by LGBM_BoosterGetNumFeature() as an R integer scalar.
//   handle - external pointer to the Booster (must be non-null)
SEXP LGBM_BoosterGetNumFeature_R(SEXP handle) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int n_features = 0;
  void* booster = R_ExternalPtrAddr(handle);
  CHECK_CALL(LGBM_BoosterGetNumFeature(booster, &n_features));
  return Rf_ScalarInteger(n_features);
  R_API_END();
}

579
// Call LGBM_BoosterUpdateOneIter() once.
//   handle - external pointer to the Booster (must be non-null)
// The "finished" flag reported by the C API is received but not returned.
// Returns R NULL.
SEXP LGBM_BoosterUpdateOneIter_R(SEXP handle) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int finished_flag = 0;
  void* booster = R_ExternalPtrAddr(handle);
  CHECK_CALL(LGBM_BoosterUpdateOneIter(booster, &finished_flag));
  return R_NilValue;
  R_API_END();
}

588
// Call LGBM_BoosterUpdateOneIterCustom() with caller-supplied gradients and hessians.
//   handle     - external pointer to the Booster (must be non-null)
//   grad, hess - double vectors; narrowed to float32 before the call
//   len        - number of elements to read from each, via Rf_asInteger()
// The "finished" flag reported by the C API is received but not returned.
// NOTE(review): `len` is not checked against what the Booster expects - confirm
// the caller guarantees it.
// Returns R NULL.
SEXP LGBM_BoosterUpdateOneIterCustom_R(SEXP handle,
  SEXP grad,
  SEXP hess,
  SEXP len) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int is_finished = 0;
  int int_len = Rf_asInteger(len);
  // Resolve the data pointers once, on the calling thread and before the
  // vectors are built: the R API must not be called from inside the OpenMP
  // region, and REAL() can raise an R error (a long jump).
  const double* p_grad = (int_len > 0) ? REAL(grad) : nullptr;
  const double* p_hess = (int_len > 0) ? REAL(hess) : nullptr;
  std::vector<float> tgrad(int_len), thess(int_len);
#pragma omp parallel for schedule(static, 512) if (int_len >= 1024)
  for (int j = 0; j < int_len; ++j) {
    tgrad[j] = static_cast<float>(p_grad[j]);
    thess[j] = static_cast<float>(p_hess[j]);
  }
  CHECK_CALL(LGBM_BoosterUpdateOneIterCustom(R_ExternalPtrAddr(handle), tgrad.data(), thess.data(), &is_finished));
  return R_NilValue;
  R_API_END();
}

607
// Call LGBM_BoosterRollbackOneIter() on a Booster.
//   handle - external pointer to the Booster (must be non-null)
// Returns R NULL.
SEXP LGBM_BoosterRollbackOneIter_R(SEXP handle) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  void* booster = R_ExternalPtrAddr(handle);
  CHECK_CALL(LGBM_BoosterRollbackOneIter(booster));
  return R_NilValue;
  R_API_END();
}

615
// Write the value reported by LGBM_BoosterGetCurrentIteration() into out[0].
//   handle - external pointer to the Booster (must be non-null)
//   out    - integer vector; element 0 receives the iteration
// Returns R NULL.
SEXP LGBM_BoosterGetCurrentIteration_R(SEXP handle,
  SEXP out) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int current_iter = 0;
  CHECK_CALL(LGBM_BoosterGetCurrentIteration(R_ExternalPtrAddr(handle), &current_iter));
  int* out_ptr = INTEGER(out);
  out_ptr[0] = current_iter;
  return R_NilValue;
  R_API_END();
}

626
// Let LGBM_BoosterGetUpperBoundValue() write its result into out_result.
//   handle     - external pointer to the Booster (must be non-null)
//   out_result - double vector whose storage receives the value
// Returns R NULL.
SEXP LGBM_BoosterGetUpperBoundValue_R(SEXP handle,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  double* destination = REAL(out_result);
  void* booster = R_ExternalPtrAddr(handle);
  CHECK_CALL(LGBM_BoosterGetUpperBoundValue(booster, destination));
  return R_NilValue;
  R_API_END();
}

636
// Let LGBM_BoosterGetLowerBoundValue() write its result into out_result.
//   handle     - external pointer to the Booster (must be non-null)
//   out_result - double vector whose storage receives the value
// Returns R NULL.
SEXP LGBM_BoosterGetLowerBoundValue_R(SEXP handle,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  double* destination = REAL(out_result);
  void* booster = R_ExternalPtrAddr(handle);
  CHECK_CALL(LGBM_BoosterGetLowerBoundValue(booster, destination));
  return R_NilValue;
  R_API_END();
}

646
// Return the evaluation names of a Booster as an R character vector.
//   handle - external pointer to the Booster (must be non-null)
// Names are first fetched into 128-byte buffers; if the C API reports that a
// larger buffer is required, every buffer is enlarged and the call is repeated.
SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) {
  SEXP cont_token = PROTECT(R_MakeUnwindCont());
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int len;
  CHECK_CALL(LGBM_BoosterGetEvalCounts(R_ExternalPtrAddr(handle), &len));

  const size_t reserved_string_size = 128;
  std::vector<std::vector<char>> buffers(len);
  std::vector<char*> buffer_ptrs(len);
  for (int i = 0; i < len; ++i) {
    buffers[i].resize(reserved_string_size);
    buffer_ptrs[i] = buffers[i].data();
  }

  int out_len;
  size_t required_string_size;
  CHECK_CALL(
    LGBM_BoosterGetEvalNames(
      R_ExternalPtrAddr(handle),
      len,
      &out_len,
      reserved_string_size,
      &required_string_size,
      buffer_ptrs.data()));

  // if any eval names were larger than allocated size,
  // allow for a larger size and try again
  if (required_string_size > reserved_string_size) {
    const size_t enlarged_size = required_string_size;
    for (int i = 0; i < len; ++i) {
      buffers[i].resize(enlarged_size);
      buffer_ptrs[i] = buffers[i].data();
    }
    CHECK_CALL(
      LGBM_BoosterGetEvalNames(
        R_ExternalPtrAddr(handle),
        len,
        &out_len,
        enlarged_size,
        &required_string_size,
        buffer_ptrs.data()));
  }
  CHECK_EQ(out_len, len);

  // R allocations go through the safe_* wrappers, which use R_UnwindProtect.
  SEXP eval_names = PROTECT(safe_R_string(static_cast<R_xlen_t>(len), &cont_token));
  for (int i = 0; i < len; ++i) {
    SET_STRING_ELT(eval_names, i, safe_R_mkChar(buffer_ptrs[i], &cont_token));
  }
  UNPROTECT(2);  // cont_token, eval_names
  return eval_names;
  R_API_END();
}

695
// Let LGBM_BoosterGetEval() write evaluation results into out_result.
//   handle     - external pointer to the Booster (must be non-null)
//   data_idx   - dataset index, read with Rf_asInteger() and passed through unchanged
//   out_result - double vector whose storage receives the results
// The number of values written is checked against LGBM_BoosterGetEvalCounts().
// Returns R NULL.
SEXP LGBM_BoosterGetEval_R(SEXP handle,
  SEXP data_idx,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int len;
  CHECK_CALL(LGBM_BoosterGetEvalCounts(R_ExternalPtrAddr(handle), &len));
  double* destination = REAL(out_result);
  int out_len;
  const int which_data = Rf_asInteger(data_idx);
  CHECK_CALL(LGBM_BoosterGetEval(R_ExternalPtrAddr(handle), which_data, &out_len, destination));
  CHECK_EQ(out_len, len);
  return R_NilValue;
  R_API_END();
}

710
// Write the value reported by LGBM_BoosterGetNumPredict() into out[0].
//   handle   - external pointer to the Booster (must be non-null)
//   data_idx - dataset index, read with Rf_asInteger() and passed through unchanged
//   out      - integer vector; element 0 receives the count (the 64-bit value is cast to int)
// Returns R NULL.
SEXP LGBM_BoosterGetNumPredict_R(SEXP handle,
  SEXP data_idx,
  SEXP out) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int64_t n_predictions;
  const int which_data = Rf_asInteger(data_idx);
  CHECK_CALL(LGBM_BoosterGetNumPredict(R_ExternalPtrAddr(handle), which_data, &n_predictions));
  int* out_ptr = INTEGER(out);
  out_ptr[0] = static_cast<int>(n_predictions);
  return R_NilValue;
  R_API_END();
}

722
// Let LGBM_BoosterGetPredict() write predictions into out_result.
//   handle     - external pointer to the Booster (must be non-null)
//   data_idx   - dataset index, read with Rf_asInteger() and passed through unchanged
//   out_result - double vector whose storage receives the predictions
// The length reported by the C API is received but not used.
// Returns R NULL.
SEXP LGBM_BoosterGetPredict_R(SEXP handle,
  SEXP data_idx,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  double* destination = REAL(out_result);
  int64_t n_written;
  const int which_data = Rf_asInteger(data_idx);
  CHECK_CALL(LGBM_BoosterGetPredict(R_ExternalPtrAddr(handle), which_data, &n_written, destination));
  return R_NilValue;
  R_API_END();
}

734
// Maps three R flags onto a single C API prediction type.
// Every flag is read (in the order raw score, leaf index, contribution)
// and any nonzero value counts as set. When several flags are set the
// precedence is: contribution > leaf index > raw score. With no flag set
// the result is C_API_PREDICT_NORMAL.
int GetPredictType(SEXP is_rawscore, SEXP is_leafidx, SEXP is_predcontrib) {
  const bool want_raw_score = Rf_asInteger(is_rawscore) != 0;
  const bool want_leaf_index = Rf_asInteger(is_leafidx) != 0;
  const bool want_contrib = Rf_asInteger(is_predcontrib) != 0;
  if (want_contrib) {
    return C_API_PREDICT_CONTRIB;
  }
  if (want_leaf_index) {
    return C_API_PREDICT_LEAF_INDEX;
  }
  if (want_raw_score) {
    return C_API_PREDICT_RAW_SCORE;
  }
  return C_API_PREDICT_NORMAL;
}

748
// Runs LGBM_BoosterPredictForFile with the input path `data_filename`,
// the parameter string `parameter` and the output path `result_filename`.
// The prediction type is derived from the three is_* flags through
// GetPredictType. Returns R_NilValue; C API failures are thrown by
// CHECK_CALL and turned into an R error by R_API_END.
SEXP LGBM_BoosterPredictForFile_R(SEXP handle,
  SEXP data_filename,
  SEXP data_has_header,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter,
  SEXP result_filename) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  // keep the three character objects protected while their C strings are in use
  SEXP input_path_sexp = PROTECT(Rf_asChar(data_filename));
  SEXP params_sexp = PROTECT(Rf_asChar(parameter));
  SEXP output_path_sexp = PROTECT(Rf_asChar(result_filename));
  const char* input_path = CHAR(input_path_sexp);
  const char* params = CHAR(params_sexp);
  const char* output_path = CHAR(output_path_sexp);
  const int predict_kind = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  const int has_header = Rf_asInteger(data_has_header);
  const int first_iter = Rf_asInteger(start_iteration);
  const int iter_count = Rf_asInteger(num_iteration);
  CHECK_CALL(LGBM_BoosterPredictForFile(R_ExternalPtrAddr(handle), input_path,
    has_header, predict_kind, first_iter, iter_count, params,
    output_path));
  UNPROTECT(3);
  return R_NilValue;
  R_API_END();
}

772
// Stores the prediction count computed by LGBM_BoosterCalcNumPredict for
// `num_row` rows in the first element of the integer vector `out_len`.
// The prediction type is derived from the three is_* flags through
// GetPredictType. Returns R_NilValue.
//
// The C API reports the count as int64_t while `out_len` holds plain ints,
// so a count that does not survive the narrowing conversion is reported as
// an error instead of being silently truncated (a truncated count would be
// handed back to the caller as if it were valid).
// C API failures are thrown by CHECK_CALL and turned into an R error by
// R_API_END, as is the range error raised here.
SEXP LGBM_BoosterCalcNumPredict_R(SEXP handle,
  SEXP num_row,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP out_len) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  int64_t len = 0;
  CHECK_CALL(LGBM_BoosterCalcNumPredict(R_ExternalPtrAddr(handle), Rf_asInteger(num_row),
    pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), &len));
  // round-trip through int: any value that changes does not fit
  const int len_as_int = static_cast<int>(len);
  if (static_cast<int64_t>(len_as_int) != len) {
    throw std::runtime_error("LGBM_BoosterCalcNumPredict_R: number of predictions exceeds the range of an R integer");
  }
  INTEGER(out_len)[0] = len_as_int;
  return R_NilValue;
  R_API_END();
}

791
// Runs LGBM_BoosterPredictForCSC on the arrays `indptr`, `indices` and
// `data`, writing the predictions into `out_result` (accessed as a double
// vector). `indptr` is passed as 32-bit integers and `data` as 64-bit
// floats. The prediction type is derived from the three is_* flags through
// GetPredictType. Returns R_NilValue; C API failures are thrown by
// CHECK_CALL and turned into an R error by R_API_END.
SEXP LGBM_BoosterPredictForCSC_R(SEXP handle,
  SEXP indptr,
  SEXP indices,
  SEXP data,
  SEXP num_indptr,
  SEXP nelem,
  SEXP num_row,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  const int predict_kind = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  const int* col_ptr = INTEGER(indptr);
  const int32_t* row_idx = reinterpret_cast<const int32_t*>(INTEGER(indices));
  const double* values = REAL(data);
  const int64_t col_ptr_count = static_cast<int64_t>(Rf_asInteger(num_indptr));
  const int64_t value_count = static_cast<int64_t>(Rf_asInteger(nelem));
  const int64_t row_count = static_cast<int64_t>(Rf_asInteger(num_row));
  double* const predictions = REAL(out_result);
  // the element count reported by the C API is received but not used here
  int64_t num_written;
  // keep the parameter string protected while its C string is in use
  SEXP params_sexp = PROTECT(Rf_asChar(parameter));
  const char* params = CHAR(params_sexp);
  const int first_iter = Rf_asInteger(start_iteration);
  const int iter_count = Rf_asInteger(num_iteration);
  CHECK_CALL(LGBM_BoosterPredictForCSC(R_ExternalPtrAddr(handle),
    col_ptr, C_API_DTYPE_INT32, row_idx,
    values, C_API_DTYPE_FLOAT64, col_ptr_count, value_count,
    row_count, predict_kind, first_iter, iter_count, params, &num_written, predictions));
  UNPROTECT(1);
  return R_NilValue;
  R_API_END();
}

826
// Runs LGBM_BoosterPredictForMat on the double matrix `data` with
// `num_row` rows and `num_col` columns, writing the predictions into
// `out_result` (accessed as a double vector). COL_MAJOR is passed as the
// layout flag. The prediction type is derived from the three is_* flags
// through GetPredictType. Returns R_NilValue; C API failures are thrown by
// CHECK_CALL and turned into an R error by R_API_END.
SEXP LGBM_BoosterPredictForMat_R(SEXP handle,
  SEXP data,
  SEXP num_row,
  SEXP num_col,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  const int predict_kind = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  const int32_t row_count = static_cast<int32_t>(Rf_asInteger(num_row));
  const int32_t col_count = static_cast<int32_t>(Rf_asInteger(num_col));
  const double* matrix = REAL(data);
  double* const predictions = REAL(out_result);
  // keep the parameter string protected while its C string is in use
  SEXP params_sexp = PROTECT(Rf_asChar(parameter));
  const char* params = CHAR(params_sexp);
  // the element count reported by the C API is received but not used here
  int64_t num_written;
  const int first_iter = Rf_asInteger(start_iteration);
  const int iter_count = Rf_asInteger(num_iteration);
  CHECK_CALL(LGBM_BoosterPredictForMat(R_ExternalPtrAddr(handle),
    matrix, C_API_DTYPE_FLOAT64, row_count, col_count, COL_MAJOR,
    predict_kind, first_iter, iter_count, params, &num_written, predictions));
  UNPROTECT(1);
  return R_NilValue;
  R_API_END();
}

854
// Calls LGBM_BoosterSaveModel with the path given by `filename`.
// A literal 0 is passed as the C API's second argument (presumably the
// start iteration -- confirm against c_api.h). Returns R_NilValue; C API
// failures are thrown by CHECK_CALL and turned into an R error by R_API_END.
SEXP LGBM_BoosterSaveModel_R(SEXP handle,
  SEXP num_iteration,
  SEXP feature_importance_type,
  SEXP filename) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  // keep the file name protected while its C string is in use
  SEXP path_sexp = PROTECT(Rf_asChar(filename));
  const char* path = CHAR(path_sexp);
  const int iter_count = Rf_asInteger(num_iteration);
  const int imp_type = Rf_asInteger(feature_importance_type);
  CHECK_CALL(LGBM_BoosterSaveModel(R_ExternalPtrAddr(handle), 0, iter_count, imp_type, path));
  UNPROTECT(1);
  return R_NilValue;
  R_API_END();
}

867
// Returns the model text produced by LGBM_BoosterSaveModelToString as an
// R raw vector. A first attempt writes into a 1 MiB scratch buffer; the C
// API reports the length it needs, and a raw vector of exactly that length
// is allocated. If the scratch buffer was large enough its contents are
// copied over, otherwise the C API is called a second time and writes
// straight into the raw vector.
SEXP LGBM_BoosterSaveModelToString_R(SEXP handle,
  SEXP num_iteration,
  SEXP feature_importance_type) {
  SEXP cont_token = PROTECT(R_MakeUnwindCont());
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  const int64_t scratch_capacity = 1024 * 1024;
  int64_t required_len = 0;
  const int iter_count = Rf_asInteger(num_iteration);
  const int imp_type = Rf_asInteger(feature_importance_type);
  std::vector<char> scratch(scratch_capacity);
  CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), 0, iter_count, imp_type, scratch_capacity, &required_len, scratch.data()));
  SEXP model_str = PROTECT(safe_R_raw(required_len, &cont_token));
  char* const dest = reinterpret_cast<char*>(RAW(model_str));
  if (required_len <= scratch_capacity) {
    // first attempt already holds the full model text
    std::copy(scratch.begin(), scratch.begin() + required_len, dest);
  } else {
    // scratch buffer was too small: redo the call, writing directly to the R object
    CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), 0, iter_count, imp_type, required_len, &required_len, dest));
  }
  UNPROTECT(2);
  return model_str;
  R_API_END();
}

891
// Returns the text produced by LGBM_BoosterDumpModel as a length-1 R
// character vector. A first attempt writes into a 1 MiB buffer; when the
// C API reports that it needs more room, the buffer is grown to the
// reported length and the call is repeated.
SEXP LGBM_BoosterDumpModel_R(SEXP handle,
  SEXP num_iteration,
  SEXP feature_importance_type) {
  SEXP cont_token = PROTECT(R_MakeUnwindCont());
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  const int64_t initial_capacity = 1024 * 1024;
  int64_t required_len = 0;
  const int iter_count = Rf_asInteger(num_iteration);
  const int imp_type = Rf_asInteger(feature_importance_type);
  std::vector<char> text_buf(initial_capacity);
  CHECK_CALL(LGBM_BoosterDumpModel(R_ExternalPtrAddr(handle), 0, iter_count, imp_type, initial_capacity, &required_len, text_buf.data()));
  if (required_len > initial_capacity) {
    // first buffer was too small: enlarge it and dump again
    text_buf.resize(required_len);
    CHECK_CALL(LGBM_BoosterDumpModel(R_ExternalPtrAddr(handle), 0, iter_count, imp_type, required_len, &required_len, text_buf.data()));
  }
  SEXP model_str = PROTECT(safe_R_string(static_cast<R_xlen_t>(1), &cont_token));
  SET_STRING_ELT(model_str, 0, safe_R_mkChar(text_buf.data(), &cont_token));
  UNPROTECT(2);
  return model_str;
  R_API_END();
}
915

916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
// Returns the text produced by LGBM_DumpParamAliases as a length-1 R
// character vector. A first attempt writes into a 1 MiB buffer; when the
// C API reports that it needs more room, the buffer is grown to the
// reported length and the call is repeated.
SEXP LGBM_DumpParamAliases_R() {
  SEXP cont_token = PROTECT(R_MakeUnwindCont());
  R_API_BEGIN();
  const int64_t initial_capacity = 1024 * 1024;
  int64_t required_len = 0;
  std::vector<char> text_buf(initial_capacity);
  CHECK_CALL(LGBM_DumpParamAliases(initial_capacity, &required_len, text_buf.data()));
  if (required_len > initial_capacity) {
    // first buffer was too small: enlarge it and dump again
    text_buf.resize(required_len);
    CHECK_CALL(LGBM_DumpParamAliases(required_len, &required_len, text_buf.data()));
  }
  SEXP aliases_str = PROTECT(safe_R_string(static_cast<R_xlen_t>(1), &cont_token));
  SET_STRING_ELT(aliases_str, 0, safe_R_mkChar(text_buf.data(), &cont_token));
  UNPROTECT(2);
  return aliases_str;
  R_API_END();
}

936
937
// .Call() calls
// Registration table handed to R_registerRoutines: one row per routine,
// giving its name, its address and its number of arguments.
// The all-null row terminates the table.
static const R_CallMethodDef CallEntries[] = {
  {"LGBM_HandleIsNull_R", reinterpret_cast<DL_FUNC>(&LGBM_HandleIsNull_R), 1},
  {"LGBM_DatasetCreateFromFile_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetCreateFromFile_R), 3},
  {"LGBM_DatasetCreateFromCSC_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetCreateFromCSC_R), 8},
  {"LGBM_DatasetCreateFromMat_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetCreateFromMat_R), 5},
  {"LGBM_DatasetGetSubset_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetSubset_R), 4},
  {"LGBM_DatasetSetFeatureNames_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetSetFeatureNames_R), 2},
  {"LGBM_DatasetGetFeatureNames_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetFeatureNames_R), 1},
  {"LGBM_DatasetSaveBinary_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetSaveBinary_R), 2},
  {"LGBM_DatasetFree_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetFree_R), 1},
  {"LGBM_DatasetSetField_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetSetField_R), 4},
  {"LGBM_DatasetGetFieldSize_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetFieldSize_R), 3},
  {"LGBM_DatasetGetField_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetField_R), 3},
  {"LGBM_DatasetUpdateParamChecking_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetUpdateParamChecking_R), 2},
  {"LGBM_DatasetGetNumData_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetNumData_R), 2},
  {"LGBM_DatasetGetNumFeature_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetNumFeature_R), 2},
  {"LGBM_DatasetGetFeatureNumBin_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetFeatureNumBin_R), 3},
  {"LGBM_BoosterCreate_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterCreate_R), 2},
  {"LGBM_BoosterFree_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterFree_R), 1},
  {"LGBM_BoosterCreateFromModelfile_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterCreateFromModelfile_R), 1},
  {"LGBM_BoosterLoadModelFromString_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterLoadModelFromString_R), 1},
  {"LGBM_BoosterMerge_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterMerge_R), 2},
  {"LGBM_BoosterAddValidData_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterAddValidData_R), 2},
  {"LGBM_BoosterResetTrainingData_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterResetTrainingData_R), 2},
  {"LGBM_BoosterResetParameter_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterResetParameter_R), 2},
  {"LGBM_BoosterGetNumClasses_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetNumClasses_R), 2},
  {"LGBM_BoosterGetNumFeature_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetNumFeature_R), 1},
  {"LGBM_BoosterUpdateOneIter_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterUpdateOneIter_R), 1},
  {"LGBM_BoosterUpdateOneIterCustom_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterUpdateOneIterCustom_R), 4},
  {"LGBM_BoosterRollbackOneIter_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterRollbackOneIter_R), 1},
  {"LGBM_BoosterGetCurrentIteration_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetCurrentIteration_R), 2},
  {"LGBM_BoosterGetUpperBoundValue_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetUpperBoundValue_R), 2},
  {"LGBM_BoosterGetLowerBoundValue_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetLowerBoundValue_R), 2},
  {"LGBM_BoosterGetEvalNames_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetEvalNames_R), 1},
  {"LGBM_BoosterGetEval_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetEval_R), 3},
  {"LGBM_BoosterGetNumPredict_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetNumPredict_R), 3},
  {"LGBM_BoosterGetPredict_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetPredict_R), 3},
  {"LGBM_BoosterPredictForFile_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterPredictForFile_R), 10},
  {"LGBM_BoosterCalcNumPredict_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterCalcNumPredict_R), 8},
  {"LGBM_BoosterPredictForCSC_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterPredictForCSC_R), 14},
  {"LGBM_BoosterPredictForMat_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterPredictForMat_R), 11},
  {"LGBM_BoosterSaveModel_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterSaveModel_R), 4},
  {"LGBM_BoosterSaveModelToString_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterSaveModelToString_R), 3},
  {"LGBM_BoosterDumpModel_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterDumpModel_R), 3},
  {"LGBM_NullBoosterHandleError_R", reinterpret_cast<DL_FUNC>(&LGBM_NullBoosterHandleError_R), 0},
  {"LGBM_DumpParamAliases_R", reinterpret_cast<DL_FUNC>(&LGBM_DumpParamAliases_R), 0},
  {nullptr, nullptr, 0}
};

986
987
LIGHTGBM_C_EXPORT void R_init_lightgbm(DllInfo *dll);

988
989
990
991
void R_init_lightgbm(DllInfo *dll) {
  R_registerRoutines(dll, NULL, CallEntries, NULL, NULL);
  R_useDynamicSymbols(dll, FALSE);
}