lightgbm_R.cpp 52.9 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2017 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
5
6

#include "lightgbm_R.h"
Guolin Ke's avatar
Guolin Ke committed
7

8
9
10
11
12
13
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>
#include <LightGBM/utils/text_reader.h>

#include <R_ext/Rdynload.h>
14
#include <R_ext/Altrep.h>
15

16
#ifndef R_NO_REMAP
17
#define R_NO_REMAP
18
19
20
#endif

#ifndef R_USE_C99_IN_CXX
21
#define R_USE_C99_IN_CXX
22
23
#endif

24
25
#include <R_ext/Error.h>

26
27
#include <string>
#include <cstdio>
28
#include <cstdlib>
29
30
31
32
#include <cstring>
#include <memory>
#include <utility>
#include <vector>
33
#include <algorithm>
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#include <type_traits>

// ALTREP class handles used by the helpers below:
//   - lgb_altrepped_char_vec: class passed to R_new_altrep() for raw (char) vectors
//   - lgb_altrepped_int_arr / lgb_altrepped_dbl_arr: classes selected by
//     get_altrep_class_for_type<T>() for integer and double arrays
// NOTE(review): presumably these are initialised when the package is loaded;
// the registration code is not visible in this part of the file.
R_altrep_class_t lgb_altrepped_char_vec;
R_altrep_class_t lgb_altrepped_int_arr;
R_altrep_class_t lgb_altrepped_dbl_arr;

// Finalizer for an external pointer whose address is a C++ array of T:
// releases the array with delete[] and then clears the stored address.
template <class T>
void delete_cpp_array(SEXP R_ptr) {
  void *address = R_ExternalPtrAddr(R_ptr);
  delete[] static_cast<T*>(address);
  R_ClearExternalPtr(R_ptr);
}

// Finalizer for an external pointer whose address is a heap-allocated
// std::vector<char>: deletes the vector and then clears the stored address.
void delete_cpp_char_vec(SEXP R_ptr) {
  void *address = R_ExternalPtrAddr(R_ptr);
  delete static_cast<std::vector<char>*>(address);
  R_ClearExternalPtr(R_ptr);
}

// Note: MSVC has issues with Altrep classes, so they are disabled for it.
// See: https://github.com/microsoft/LightGBM/pull/6213#issuecomment-2111025768
#ifdef _MSC_VER
#  define LGB_NO_ALTREP
#endif

#ifndef LGB_NO_ALTREP
// Wraps a std::vector<char> in an object of ALTREP class 'lgb_altrepped_char_vec'.
//
// 'void_ptr' must point to a std::unique_ptr<std::vector<char>>. The vector's
// address is stored in an external pointer (ALTREP data1) whose finalizer,
// delete_cpp_char_vec, deletes it. The unique_ptr gives up ownership only after
// the R objects exist and the finalizer has been registered.
SEXP make_altrepped_raw_vec(void *void_ptr) {
  auto *owner = static_cast<std::unique_ptr<std::vector<char>>*>(void_ptr);
  SEXP ext_ptr = Rf_protect(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  SEXP altrep_obj = Rf_protect(R_new_altrep(lgb_altrepped_char_vec, R_NilValue, R_NilValue));

  R_SetExternalPtrAddr(ext_ptr, owner->get());
  R_RegisterCFinalizerEx(ext_ptr, delete_cpp_char_vec, TRUE);
  // From here on the finalizer is responsible for deleting the vector.
  owner->release();

  R_set_altrep_data1(altrep_obj, ext_ptr);
  Rf_unprotect(2);
  return altrep_obj;
}
#else
// Non-ALTREP fallback: copies the bytes of a std::vector<char> into a freshly
// allocated R raw vector.
//
// 'void_ptr' must point to a std::unique_ptr<std::vector<char>>; the vector is
// only read here and stays owned by that unique_ptr.
SEXP make_r_raw_vec(void *void_ptr) {
  auto *owner = static_cast<std::unique_ptr<std::vector<char>>*>(void_ptr);
  const std::vector<char> &bytes = **owner;
  R_xlen_t len = bytes.size();
  SEXP out = Rf_protect(Rf_allocVector(RAWSXP, len));
  char *dest = reinterpret_cast<char*>(RAW(out));
  std::copy(bytes.begin(), bytes.end(), dest);
  Rf_unprotect(1);
  return out;
}
#define make_altrepped_raw_vec make_r_raw_vec
#endif

// Returns the std::vector<char> whose address is stored in the external
// pointer held in the ALTREP data1 slot of 'R_raw'.
std::vector<char>* get_ptr_from_altrepped_raw(SEXP R_raw) {
  SEXP ext_ptr = R_altrep_data1(R_raw);
  void *address = R_ExternalPtrAddr(ext_ptr);
  return static_cast<std::vector<char>*>(address);
}

// Length of the ALTREP raw vector: the size of the backing std::vector<char>.
R_xlen_t get_altrepped_raw_len(SEXP R_raw) {
  const std::vector<char> *bytes = get_ptr_from_altrepped_raw(R_raw);
  return bytes->size();
}

// Read-only data pointer of the ALTREP raw vector: the buffer of the backing
// std::vector<char>.
const void* get_altrepped_raw_dataptr_or_null(SEXP R_raw) {
  std::vector<char> *bytes = get_ptr_from_altrepped_raw(R_raw);
  return bytes->data();
}

// Data pointer of the ALTREP raw vector: the buffer of the backing
// std::vector<char>. 'writeable' is not consulted; the same buffer is
// returned in either case.
void* get_altrepped_raw_dataptr(SEXP R_raw, Rboolean writeable) {
  std::vector<char> *bytes = get_ptr_from_altrepped_raw(R_raw);
  return bytes->data();
}

#ifndef LGB_NO_ALTREP
// Picks the ALTREP class for element type T: the double-array class when T is
// double, the integer-array class for every other T.
template <class T>
R_altrep_class_t get_altrep_class_for_type() {
  return std::is_same<T, double>::value ? lgb_altrepped_dbl_arr
                                        : lgb_altrepped_int_arr;
}
#else
// Picks the R vector type for element type T: REALSXP when T is double,
// INTSXP for every other T.
template <class T>
SEXPTYPE get_sexptype_class_for_type() {
  return std::is_same<T, double>::value ? REALSXP : INTSXP;
}

// Returns the data pointer of R vector 'x' typed as T*: REAL(x) when T is
// double, INTEGER(x) for every other T.
template <class T>
T* get_r_vec_ptr(SEXP x) {
  void *data = std::is_same<T, double>::value
    ? static_cast<void*>(REAL(x))
    : static_cast<void*>(INTEGER(x));
  return static_cast<T*>(data);
}
#endif

// Pointer-plus-length pair handed (through a void*) to the vector-building
// helpers in this file.
template <typename T>
struct arr_and_len {
  T *arr;       // first element of the array
  int64_t len;  // number of elements reachable through 'arr'
};

#ifndef LGB_NO_ALTREP
// Wraps a C++ array in an ALTREP vector without copying it.
//
// 'void_ptr' must point to an arr_and_len<T>. The array address goes into an
// external pointer (ALTREP data1) whose finalizer, delete_cpp_array<T>, frees
// it with delete[]; the length is kept as a length-1 double vector (data2).
template <class T>
SEXP make_altrepped_vec_from_arr(void *void_ptr) {
  const arr_and_len<T> *input = static_cast<arr_and_len<T>*>(void_ptr);
  T *data = input->arr;
  uint64_t n_elements = input->len;
  SEXP ext_ptr = Rf_protect(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  SEXP len_holder = Rf_protect(Rf_allocVector(REALSXP, 1));
  SEXP altrep_obj = Rf_protect(R_new_altrep(get_altrep_class_for_type<T>(), R_NilValue, R_NilValue));

  *REAL(len_holder) = static_cast<double>(n_elements);
  R_SetExternalPtrAddr(ext_ptr, data);
  R_RegisterCFinalizerEx(ext_ptr, delete_cpp_array<T>, TRUE);

  R_set_altrep_data1(altrep_obj, ext_ptr);
  R_set_altrep_data2(altrep_obj, len_holder);
  Rf_unprotect(3);
  return altrep_obj;
}
#else
// Non-ALTREP fallback: copies a C++ array into a newly allocated R vector
// (REALSXP when T is double, INTSXP otherwise).
//
// 'void_ptr' must point to an arr_and_len<T>; the source array is only read.
template <class T>
SEXP make_R_vec_from_arr(void *void_ptr) {
  const arr_and_len<T> *input = static_cast<arr_and_len<T>*>(void_ptr);
  T *first = input->arr;
  uint64_t n_elements = input->len;
  SEXP out = Rf_protect(Rf_allocVector(get_sexptype_class_for_type<T>(), n_elements));
  T *dest = get_r_vec_ptr<T>(out);
  std::copy(first, first + n_elements, dest);
  Rf_unprotect(1);
  return out;
}
#define make_altrepped_vec_from_arr make_R_vec_from_arr
#endif

// Length of the ALTREP vector, read from the numeric value stored in its
// data2 slot.
R_xlen_t get_altrepped_vec_len(SEXP R_vec) {
  SEXP len_holder = R_altrep_data2(R_vec);
  const double len_as_real = Rf_asReal(len_holder);
  return static_cast<R_xlen_t>(len_as_real);
}

// Read-only data pointer of the ALTREP vector: the address stored in the
// external pointer held in its data1 slot.
const void* get_altrepped_vec_dataptr_or_null(SEXP R_vec) {
  SEXP ext_ptr = R_altrep_data1(R_vec);
  return R_ExternalPtrAddr(ext_ptr);
}

// Data pointer of the ALTREP vector: the address stored in the external
// pointer held in its data1 slot. 'writeable' is not consulted.
void* get_altrepped_vec_dataptr(SEXP R_vec, Rboolean writeable) {
  SEXP ext_ptr = R_altrep_data1(R_vec);
  return R_ExternalPtrAddr(ext_ptr);
}
178

Guolin Ke's avatar
Guolin Ke committed
179
180
#define COL_MAJOR (0)

181
182
183
184
185
186
// Capacity (in bytes, including the terminating NUL) of the error-message buffer.
#define MAX_LENGTH_ERR_MSG 1024
// Holds the message of the last caught C++ exception; it is filled by
// LGBM_R_save_exception_msg() and passed to Rf_error() by R_API_END().
char R_errmsg_buffer[MAX_LENGTH_ERR_MSG];
// Exception type thrown by throw_R_memerr(); carries the unwind continuation
// token that R_API_END() hands to R_ContinueUnwind().
struct LGBM_R_ErrorClass { SEXP cont_token; };
// Copy an exception's message (plus a trailing newline) into R_errmsg_buffer.
void LGBM_R_save_exception_msg(const std::exception &err);
void LGBM_R_save_exception_msg(const std::string &err);

Guolin Ke's avatar
Guolin Ke committed
187
188
189
// R_API_BEGIN() / R_API_END() bracket the body of the exported functions in
// this file with a try/catch:
//   - LGBM_R_ErrorClass: an R error intercepted through R_UnwindProtect();
//     the pending R unwind is resumed with R_ContinueUnwind().
//   - std::exception / std::string: the message is saved into R_errmsg_buffer.
//   - anything else: "unknown exception" is saved into R_errmsg_buffer.
// After the handlers have finished, the saved message is raised as an R error.
// Handlers take their argument by const reference, and the catch-all no longer
// calls Rf_error() (a long jump) from inside the handler; it falls through to
// the shared Rf_error() call instead. The reported messages are unchanged.
#define R_API_BEGIN() \
  try {
#define R_API_END() } \
  catch(const LGBM_R_ErrorClass &cont) { R_ContinueUnwind(cont.cont_token); } \
  catch(const std::exception &ex) { LGBM_R_save_exception_msg(ex); } \
  catch(const std::string &ex) { LGBM_R_save_exception_msg(ex); } \
  catch(...) { std::snprintf(R_errmsg_buffer, MAX_LENGTH_ERR_MSG, "%s", "unknown exception"); } \
  Rf_error("%s", R_errmsg_buffer); \
  return R_NilValue; /* <- won't be reached */
Guolin Ke's avatar
Guolin Ke committed
196
197
198

// Evaluates 'x' once; when the result is non-zero, throws a
// std::runtime_error carrying the text returned by LGBM_GetLastError().
#define CHECK_CALL(x) \
  if (0 != (x)) { \
    throw std::runtime_error(LGBM_GetLastError()); \
  }

202
203
204
205
206
207
208
209
210
211
212
213
214
215
// These are helper functions to allow doing a stack unwind
// after an R allocation error, which would trigger a long jump.
// Stores err.what() followed by a newline in R_errmsg_buffer, truncated to
// MAX_LENGTH_ERR_MSG bytes by snprintf.
void LGBM_R_save_exception_msg(const std::exception &err) {
  const char *message = err.what();
  std::snprintf(R_errmsg_buffer, MAX_LENGTH_ERR_MSG, "%s\n", message);
}

// Stores the string followed by a newline in R_errmsg_buffer, truncated to
// MAX_LENGTH_ERR_MSG bytes by snprintf.
void LGBM_R_save_exception_msg(const std::string &err) {
  const char *message = err.c_str();
  std::snprintf(R_errmsg_buffer, MAX_LENGTH_ERR_MSG, "%s\n", message);
}

// R_UnwindProtect callback: allocates a character vector whose length is the
// R_xlen_t that 'len' points to.
SEXP wrapped_R_string(void *len) {
  const R_xlen_t n = *static_cast<R_xlen_t*>(len);
  return Rf_allocVector(STRSXP, n);
}

216
217
218
219
// R_UnwindProtect callback: allocates a raw vector whose length is the
// R_xlen_t that 'len' points to.
SEXP wrapped_R_raw(void *len) {
  const R_xlen_t n = *static_cast<R_xlen_t*>(len);
  return Rf_allocVector(RAWSXP, n);
}

220
221
222
223
224
225
226
227
// R_UnwindProtect callback: allocates an integer vector whose length is the
// R_xlen_t that 'len' points to.
SEXP wrapped_R_int(void *len) {
  const R_xlen_t n = *static_cast<R_xlen_t*>(len);
  return Rf_allocVector(INTSXP, n);
}

// R_UnwindProtect callback: allocates a double vector whose length is the
// R_xlen_t that 'len' points to.
SEXP wrapped_R_real(void *len) {
  const R_xlen_t n = *static_cast<R_xlen_t*>(len);
  return Rf_allocVector(REALSXP, n);
}

228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
// R_UnwindProtect callback: builds a CHARSXP from the C string that 'txt'
// points to.
SEXP wrapped_Rf_mkChar(void *txt) {
  char *c_string = static_cast<char*>(txt);
  return Rf_mkChar(c_string);
}

void throw_R_memerr(void *ptr_cont_token, Rboolean jump) {
  if (jump) {
    LGBM_R_ErrorClass err{*(reinterpret_cast<SEXP*>(ptr_cont_token))};
    throw err;
  }
}

// Allocates a character vector of length 'len' under R_UnwindProtect, so an R
// jump during allocation surfaces as a C++ LGBM_R_ErrorClass exception
// (thrown by throw_R_memerr) instead of skipping C++ stack unwinding.
SEXP safe_R_string(R_xlen_t len, SEXP *cont_token) {
  void *len_arg = static_cast<void*>(&len);
  return R_UnwindProtect(wrapped_R_string, len_arg, throw_R_memerr, cont_token, *cont_token);
}

243
244
245
246
// Allocates a raw vector of length 'len' under R_UnwindProtect, so an R jump
// during allocation surfaces as a C++ LGBM_R_ErrorClass exception.
SEXP safe_R_raw(R_xlen_t len, SEXP *cont_token) {
  void *len_arg = static_cast<void*>(&len);
  return R_UnwindProtect(wrapped_R_raw, len_arg, throw_R_memerr, cont_token, *cont_token);
}

247
248
249
250
251
252
253
254
// Allocates an integer vector of length 'len' under R_UnwindProtect, so an R
// jump during allocation surfaces as a C++ LGBM_R_ErrorClass exception.
SEXP safe_R_int(R_xlen_t len, SEXP *cont_token) {
  void *len_arg = static_cast<void*>(&len);
  return R_UnwindProtect(wrapped_R_int, len_arg, throw_R_memerr, cont_token, *cont_token);
}

// Allocates a double vector of length 'len' under R_UnwindProtect, so an R
// jump during allocation surfaces as a C++ LGBM_R_ErrorClass exception.
SEXP safe_R_real(R_xlen_t len, SEXP *cont_token) {
  void *len_arg = static_cast<void*>(&len);
  return R_UnwindProtect(wrapped_R_real, len_arg, throw_R_memerr, cont_token, *cont_token);
}

255
256
257
258
// Creates a CHARSXP from 'txt' under R_UnwindProtect, so an R jump during
// creation surfaces as a C++ LGBM_R_ErrorClass exception.
SEXP safe_R_mkChar(char *txt, SEXP *cont_token) {
  void *text_arg = static_cast<void*>(txt);
  return R_UnwindProtect(wrapped_Rf_mkChar, text_arg, throw_R_memerr, cont_token, *cont_token);
}

259
260
using LightGBM::Common::Split;
using LightGBM::Log;
Guolin Ke's avatar
Guolin Ke committed
261

262
263
264
265
// Returns a length-1 logical: TRUE when the external pointer 'handle' holds a
// null address, FALSE otherwise.
SEXP LGBM_HandleIsNull_R(SEXP handle) {
  const bool address_is_null = (R_ExternalPtrAddr(handle) == nullptr);
  return Rf_ScalarLogical(address_is_null);
}

266
267
268
269
// Finalizer registered on Dataset external pointers; forwards to
// LGBM_DatasetFree_R() and ignores its return value.
void _DatasetFinalizer(SEXP handle) {
  static_cast<void>(LGBM_DatasetFree_R(handle));
}

270
271
272
273
274
275
276
277
// Raises the R error reported when a Booster handle is missing or null.
// Rf_error() does not return; the trailing return only satisfies the signature.
SEXP LGBM_NullBoosterHandleError_R() {
  const char *const message =
      "Attempting to use a Booster which no longer exists and/or cannot be restored. "
      "This can happen if you have called Booster$finalize() "
      "or if this Booster was saved through saveRDS() using 'serializable=FALSE'.";
  Rf_error("%s", message);
  return R_NilValue;
}

278
279
// Raises the "Booster no longer exists" R error unless 'handle' is a non-NULL
// R object whose external-pointer address is non-null.
void _AssertBoosterHandleNotNull(SEXP handle) {
  const bool usable = !Rf_isNull(handle) && R_ExternalPtrAddr(handle) != nullptr;
  if (!usable) {
    LGBM_NullBoosterHandleError_R();
  }
}

// Raises an R error unless 'handle' is a non-NULL R object whose
// external-pointer address is non-null.
void _AssertDatasetHandleNotNull(SEXP handle) {
  const bool usable = !Rf_isNull(handle) && R_ExternalPtrAddr(handle) != nullptr;
  if (usable) {
    return;
  }
  Rf_error(
    "Attempting to use a Dataset which no longer exists. "
    "This can happen if you have called Dataset$finalize() or if this Dataset was saved with saveRDS(). "
    "To avoid this error in the future, use lgb.Dataset.save() or Dataset$save_binary() to save lightgbm Datasets.");
}

293
294
// Creates a Dataset from a file via LGBM_DatasetCreateFromFile().
//   filename   - coerced to a single string with Rf_asChar()
//   parameters - coerced to a single string with Rf_asChar()
//   reference  - R NULL, or an external pointer whose address is passed as the
//                reference Dataset
// Returns an external pointer to the new Dataset, with _DatasetFinalizer
// registered on it.
SEXP LGBM_DatasetCreateFromFile_R(SEXP filename, SEXP parameters, SEXP reference) {
  R_API_BEGIN();
  SEXP ret = Rf_protect(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  DatasetHandle ref = Rf_isNull(reference) ? nullptr : R_ExternalPtrAddr(reference);
  SEXP filename_chr = Rf_protect(Rf_asChar(filename));
  SEXP parameters_chr = Rf_protect(Rf_asChar(parameters));
  DatasetHandle handle = nullptr;
  CHECK_CALL(LGBM_DatasetCreateFromFile(CHAR(filename_chr), CHAR(parameters_chr), ref, &handle));
  R_SetExternalPtrAddr(ret, handle);
  R_RegisterCFinalizerEx(ret, _DatasetFinalizer, TRUE);
  Rf_unprotect(3);
  return ret;
  R_API_END();
}

313
314
315
// Creates a Dataset from CSC sparse-matrix pieces via LGBM_DatasetCreateFromCSC().
//   indptr, indices - integer vectors (passed as int32)
//   data            - double vector (passed as float64)
//   num_indptr, nelem, num_row - scalars read with Rf_asInteger()
//   parameters      - coerced to a single string with Rf_asChar()
//   reference       - R NULL, or an external pointer to a reference Dataset
// Returns an external pointer to the new Dataset, with _DatasetFinalizer
// registered on it.
SEXP LGBM_DatasetCreateFromCSC_R(SEXP indptr,
                                 SEXP indices,
                                 SEXP data,
                                 SEXP num_indptr,
                                 SEXP nelem,
                                 SEXP num_row,
                                 SEXP parameters,
                                 SEXP reference) {
  R_API_BEGIN();
  SEXP ret = Rf_protect(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  const int* col_ptr = INTEGER(indptr);
  const int* row_idx = INTEGER(indices);
  const double* values = REAL(data);
  const int64_t n_col_ptr = static_cast<int64_t>(Rf_asInteger(num_indptr));
  const int64_t n_values = static_cast<int64_t>(Rf_asInteger(nelem));
  const int64_t n_rows = static_cast<int64_t>(Rf_asInteger(num_row));
  SEXP parameters_chr = Rf_protect(Rf_asChar(parameters));
  DatasetHandle handle = nullptr;
  DatasetHandle ref = Rf_isNull(reference) ? nullptr : R_ExternalPtrAddr(reference);
  CHECK_CALL(LGBM_DatasetCreateFromCSC(
    col_ptr, C_API_DTYPE_INT32, row_idx,
    values, C_API_DTYPE_FLOAT64, n_col_ptr, n_values,
    n_rows, CHAR(parameters_chr), ref, &handle));
  R_SetExternalPtrAddr(ret, handle);
  R_RegisterCFinalizerEx(ret, _DatasetFinalizer, TRUE);
  Rf_unprotect(2);
  return ret;
  R_API_END();
}

345
// Creates a Dataset from a dense double matrix via LGBM_DatasetCreateFromMat(),
// passing the data as float64 with the COL_MAJOR layout flag.
//   num_row, num_col - scalars read with Rf_asInteger()
//   parameters       - coerced to a single string with Rf_asChar()
//   reference        - R NULL, or an external pointer to a reference Dataset
// Returns an external pointer to the new Dataset, with _DatasetFinalizer
// registered on it.
SEXP LGBM_DatasetCreateFromMat_R(SEXP data,
                                 SEXP num_row,
                                 SEXP num_col,
                                 SEXP parameters,
                                 SEXP reference) {
  R_API_BEGIN();
  SEXP ret = Rf_protect(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  const int32_t n_rows = static_cast<int32_t>(Rf_asInteger(num_row));
  const int32_t n_cols = static_cast<int32_t>(Rf_asInteger(num_col));
  double* matrix = REAL(data);
  SEXP parameters_chr = Rf_protect(Rf_asChar(parameters));
  DatasetHandle handle = nullptr;
  DatasetHandle ref = Rf_isNull(reference) ? nullptr : R_ExternalPtrAddr(reference);
  CHECK_CALL(LGBM_DatasetCreateFromMat(
    matrix, C_API_DTYPE_FLOAT64, n_rows, n_cols, COL_MAJOR,
    CHAR(parameters_chr), ref, &handle));
  R_SetExternalPtrAddr(ret, handle);
  R_RegisterCFinalizerEx(ret, _DatasetFinalizer, TRUE);
  Rf_unprotect(2);
  return ret;
  R_API_END();
}

370
// Builds a new Dataset from a subset of rows via LGBM_DatasetGetSubset().
//   used_row_indices     - integer vector of one-based row indices
//   len_used_row_indices - number of indices, read with Rf_asInteger()
//   parameters           - coerced to a single string with Rf_asChar()
// Returns an external pointer to the subset Dataset, with _DatasetFinalizer
// registered on it.
SEXP LGBM_DatasetGetSubset_R(SEXP handle,
                             SEXP used_row_indices,
                             SEXP len_used_row_indices,
                             SEXP parameters) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  SEXP ret = Rf_protect(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  const int32_t n_rows = static_cast<int32_t>(Rf_asInteger(len_used_row_indices));
  std::unique_ptr<int32_t[]> zero_based(new int32_t[n_rows]);
  // R indices start at 1; the C API expects indices starting at 0
  const int *one_based = INTEGER(used_row_indices);
#ifndef _MSC_VER
#pragma omp simd
#endif
  for (int32_t row = 0; row < n_rows; ++row) {
    zero_based[row] = static_cast<int32_t>(one_based[row] - 1);
  }
  SEXP parameters_chr = Rf_protect(Rf_asChar(parameters));
  DatasetHandle subset = nullptr;
  CHECK_CALL(LGBM_DatasetGetSubset(
    R_ExternalPtrAddr(handle), zero_based.get(), n_rows, CHAR(parameters_chr), &subset));
  R_SetExternalPtrAddr(ret, subset);
  R_RegisterCFinalizerEx(ret, _DatasetFinalizer, TRUE);
  Rf_unprotect(2);
  return ret;
  R_API_END();
}

399
// Sets the Dataset's feature names via LGBM_DatasetSetFeatureNames().
// 'feature_names' is coerced to a single string and split on tab characters;
// each piece becomes one name.
SEXP LGBM_DatasetSetFeatureNames_R(SEXP handle, SEXP feature_names) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  SEXP joined_names = Rf_protect(Rf_asChar(feature_names));
  auto name_list = Split(CHAR(joined_names), '\t');
  const int n_names = static_cast<int>(name_list.size());
  // array of C-string views into 'name_list', which outlives the call below
  std::unique_ptr<const char*[]> name_ptrs(new const char*[n_names]);
  for (int idx = 0; idx < n_names; ++idx) {
    name_ptrs[idx] = name_list[idx].c_str();
  }
  CHECK_CALL(LGBM_DatasetSetFeatureNames(R_ExternalPtrAddr(handle), name_ptrs.get(), n_names));
  Rf_unprotect(1);
  return R_NilValue;
  R_API_END();
}

416
// Returns the Dataset's feature names as an R character vector.
// Names are first fetched into 256-byte buffers; when the C API reports that a
// larger size is required, the buffers are enlarged and the names fetched again.
// R allocations go through the safe_R_* helpers, using the unwind token
// created before the try block.
SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) {
  SEXP cont_token = Rf_protect(R_MakeUnwindCont());
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  SEXP feature_names;
  int len = 0;
  CHECK_CALL(LGBM_DatasetGetNumFeature(R_ExternalPtrAddr(handle), &len));
  const size_t reserved_string_size = 256;
  std::vector<std::vector<char>> names(len);
  std::vector<char*> ptr_names(len);
  // (re)size every name buffer and refresh the pointer table
  auto resize_buffers = [&names, &ptr_names](size_t width) {
    for (size_t idx = 0; idx < names.size(); ++idx) {
      names[idx].resize(width);
      ptr_names[idx] = names[idx].data();
    }
  };
  resize_buffers(reserved_string_size);
  int out_len;
  size_t required_string_size;
  CHECK_CALL(LGBM_DatasetGetFeatureNames(
    R_ExternalPtrAddr(handle),
    len, &out_len,
    reserved_string_size, &required_string_size,
    ptr_names.data()));
  // at least one name did not fit: enlarge the buffers and try again
  if (required_string_size > reserved_string_size) {
    resize_buffers(required_string_size);
    CHECK_CALL(LGBM_DatasetGetFeatureNames(
      R_ExternalPtrAddr(handle),
      len, &out_len,
      required_string_size, &required_string_size,
      ptr_names.data()));
  }
  CHECK_EQ(len, out_len);
  feature_names = Rf_protect(safe_R_string(static_cast<R_xlen_t>(len), &cont_token));
  for (int idx = 0; idx < len; ++idx) {
    SET_STRING_ELT(feature_names, idx, safe_R_mkChar(ptr_names[idx], &cont_token));
  }
  Rf_unprotect(2);
  return feature_names;
  R_API_END();
}

464
// Passes the Dataset and 'filename' (coerced to a single string) to
// LGBM_DatasetSaveBinary(). Returns R NULL.
SEXP LGBM_DatasetSaveBinary_R(SEXP handle, SEXP filename) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  SEXP filename_chr = Rf_protect(Rf_asChar(filename));
  CHECK_CALL(LGBM_DatasetSaveBinary(R_ExternalPtrAddr(handle), CHAR(filename_chr)));
  Rf_unprotect(1);
  return R_NilValue;
  R_API_END();
}

476
// Calls LGBM_DatasetFree() on the address held by 'handle' and then clears the
// external pointer. Does nothing when 'handle' is R NULL or already holds a
// null address. Returns R NULL.
SEXP LGBM_DatasetFree_R(SEXP handle) {
  R_API_BEGIN();
  const bool has_dataset = !Rf_isNull(handle) && R_ExternalPtrAddr(handle) != nullptr;
  if (!has_dataset) {
    return R_NilValue;
  }
  CHECK_CALL(LGBM_DatasetFree(R_ExternalPtrAddr(handle)));
  R_ClearExternalPtr(handle);
  return R_NilValue;
  R_API_END();
}

486
// Sets a Dataset field via LGBM_DatasetSetField(). The element type depends on
// the field name:
//   "group" / "query" - 'field_data' is read as integers and passed as int32
//   "init_score"      - 'field_data' is read as doubles and passed as float64
//   anything else     - 'field_data' is read as doubles, converted to a
//                       temporary float array and passed as float32
// 'num_element' is the number of elements, read with Rf_asInteger().
SEXP LGBM_DatasetSetField_R(SEXP handle,
                            SEXP field_name,
                            SEXP field_data,
                            SEXP num_element) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  const int n_elements = Rf_asInteger(num_element);
  const char* name = CHAR(Rf_protect(Rf_asChar(field_name)));
  const bool is_group = (strcmp("group", name) == 0) || (strcmp("query", name) == 0);
  if (is_group) {
    CHECK_CALL(LGBM_DatasetSetField(
      R_ExternalPtrAddr(handle), name, INTEGER(field_data), n_elements, C_API_DTYPE_INT32));
  } else if (strcmp("init_score", name) == 0) {
    CHECK_CALL(LGBM_DatasetSetField(
      R_ExternalPtrAddr(handle), name, REAL(field_data), n_elements, C_API_DTYPE_FLOAT64));
  } else {
    const double *source = REAL(field_data);
    std::unique_ptr<float[]> as_float(new float[n_elements]);
    std::copy(source, source + n_elements, as_float.get());
    CHECK_CALL(LGBM_DatasetSetField(
      R_ExternalPtrAddr(handle), name, as_float.get(), n_elements, C_API_DTYPE_FLOAT32));
  }
  Rf_unprotect(1);
  return R_NilValue;
  R_API_END();
}

508
// Copies a Dataset field, fetched with LGBM_DatasetGetField(), into the
// pre-allocated R vector 'field_data':
//   "group" / "query" - the int32 values are treated as boundaries and the
//                       differences of consecutive values are written to an
//                       integer vector (one element fewer than the field)
//   "init_score"      - doubles are copied into a double vector
//   anything else     - floats are copied (widened) into a double vector
SEXP LGBM_DatasetGetField_R(SEXP handle,
                            SEXP field_name,
                            SEXP field_data) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  const char* name = CHAR(Rf_protect(Rf_asChar(field_name)));
  int out_len = 0;
  int out_type = 0;
  const void* res;
  CHECK_CALL(LGBM_DatasetGetField(R_ExternalPtrAddr(handle), name, &out_len, &res, &out_type));
  const bool is_group = (strcmp("group", name) == 0) || (strcmp("query", name) == 0);
  if (is_group) {
    const int32_t *boundaries = reinterpret_cast<const int32_t*>(res);
    // convert from boundaries to size
    int *sizes = INTEGER(field_data);
#ifndef _MSC_VER
#pragma omp simd
#endif
    for (int i = 0; i < out_len - 1; ++i) {
      sizes[i] = boundaries[i + 1] - boundaries[i];
    }
  } else if (strcmp("init_score", name) == 0) {
    const double *values = reinterpret_cast<const double*>(res);
    std::copy(values, values + out_len, REAL(field_data));
  } else {
    const float *values = reinterpret_cast<const float*>(res);
    std::copy(values, values + out_len, REAL(field_data));
  }
  Rf_unprotect(1);
  return R_NilValue;
  R_API_END();
}

540
// Writes the number of elements of a Dataset field into out[0]. For "group" /
// "query" the reported size is one less than the length returned by
// LGBM_DatasetGetField(), matching the boundary-to-size conversion done in
// LGBM_DatasetGetField_R().
SEXP LGBM_DatasetGetFieldSize_R(SEXP handle,
                                SEXP field_name,
                                SEXP out) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  const char* name = CHAR(Rf_protect(Rf_asChar(field_name)));
  int out_len = 0;
  int out_type = 0;
  const void* res;
  CHECK_CALL(LGBM_DatasetGetField(R_ExternalPtrAddr(handle), name, &out_len, &res, &out_type));
  const bool is_group = (strcmp("group", name) == 0) || (strcmp("query", name) == 0);
  *INTEGER(out) = is_group ? out_len - 1 : out_len;
  Rf_unprotect(1);
  return R_NilValue;
  R_API_END();
}

559
560
// Coerces both arguments to single strings and passes them to
// LGBM_DatasetUpdateParamChecking(); a non-zero result is raised as an R error.
SEXP LGBM_DatasetUpdateParamChecking_R(SEXP old_params, SEXP new_params) {
  R_API_BEGIN();
  SEXP old_chr = Rf_protect(Rf_asChar(old_params));
  SEXP new_chr = Rf_protect(Rf_asChar(new_params));
  CHECK_CALL(LGBM_DatasetUpdateParamChecking(CHAR(old_chr), CHAR(new_chr)));
  Rf_unprotect(2);
  return R_NilValue;
  R_API_END();
}

570
// Writes the value reported by LGBM_DatasetGetNumData() into out[0]
// ('out' must be an integer vector). Returns R NULL.
SEXP LGBM_DatasetGetNumData_R(SEXP handle, SEXP out) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  void *dataset = R_ExternalPtrAddr(handle);
  int num_data;
  CHECK_CALL(LGBM_DatasetGetNumData(dataset, &num_data));
  *INTEGER(out) = num_data;
  return R_NilValue;
  R_API_END();
}

580
// Writes the value reported by LGBM_DatasetGetNumFeature() into out[0]
// ('out' must be an integer vector). Returns R NULL.
SEXP LGBM_DatasetGetNumFeature_R(SEXP handle, SEXP out) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  void *dataset = R_ExternalPtrAddr(handle);
  int num_features;
  CHECK_CALL(LGBM_DatasetGetNumFeature(dataset, &num_features));
  *INTEGER(out) = num_features;
  return R_NilValue;
  R_API_END();
}

591
592
593
594
595
596
597
598
599
600
601
// Writes the value reported by LGBM_DatasetGetFeatureNumBin() for feature
// 'feature_idx' (read with Rf_asInteger()) into out[0]. Returns R NULL.
SEXP LGBM_DatasetGetFeatureNumBin_R(SEXP handle, SEXP feature_idx, SEXP out) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  const int feature = Rf_asInteger(feature_idx);
  void *dataset = R_ExternalPtrAddr(handle);
  int num_bins;
  CHECK_CALL(LGBM_DatasetGetFeatureNumBin(dataset, feature, &num_bins));
  *INTEGER(out) = num_bins;
  return R_NilValue;
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
602
603
// --- start Booster interfaces

604
605
606
607
// Finalizer registered on Booster external pointers; forwards to
// LGBM_BoosterFree_R() and ignores its return value.
void _BoosterFinalizer(SEXP handle) {
  static_cast<void>(LGBM_BoosterFree_R(handle));
}

608
// Calls LGBM_BoosterFree() on the address held by 'handle' and then clears the
// external pointer. Does nothing when 'handle' is R NULL or already holds a
// null address. Returns R NULL.
SEXP LGBM_BoosterFree_R(SEXP handle) {
  R_API_BEGIN();
  const bool has_booster = !Rf_isNull(handle) && R_ExternalPtrAddr(handle) != nullptr;
  if (!has_booster) {
    return R_NilValue;
  }
  CHECK_CALL(LGBM_BoosterFree(R_ExternalPtrAddr(handle)));
  R_ClearExternalPtr(handle);
  return R_NilValue;
  R_API_END();
}

618
619
// Creates a Booster from a training Dataset via LGBM_BoosterCreate().
// 'parameters' is coerced to a single string. Returns an external pointer to
// the new Booster, with _BoosterFinalizer registered on it.
SEXP LGBM_BoosterCreate_R(SEXP train_data, SEXP parameters) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(train_data);
  SEXP ret = Rf_protect(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  SEXP parameters_chr = Rf_protect(Rf_asChar(parameters));
  BoosterHandle booster = nullptr;
  CHECK_CALL(LGBM_BoosterCreate(R_ExternalPtrAddr(train_data), CHAR(parameters_chr), &booster));
  R_SetExternalPtrAddr(ret, booster);
  R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE);
  Rf_unprotect(2);
  return ret;
  R_API_END();
}

633
// Creates a Booster from a model file via LGBM_BoosterCreateFromModelfile().
// 'filename' is coerced to a single string; the iteration count reported by
// the C API is not used. Returns an external pointer to the new Booster, with
// _BoosterFinalizer registered on it.
SEXP LGBM_BoosterCreateFromModelfile_R(SEXP filename) {
  R_API_BEGIN();
  SEXP ret = Rf_protect(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  int num_iterations = 0;
  SEXP filename_chr = Rf_protect(Rf_asChar(filename));
  BoosterHandle booster = nullptr;
  CHECK_CALL(LGBM_BoosterCreateFromModelfile(CHAR(filename_chr), &num_iterations, &booster));
  R_SetExternalPtrAddr(ret, booster);
  R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE);
  Rf_unprotect(2);
  return ret;
  R_API_END();
}

647
// Creates a Booster from an in-memory model via LGBM_BoosterLoadModelFromString().
//
// 'model_str' may be:
//   - a raw vector (its bytes are passed as the model text),
//   - a CHARSXP,
//   - a character vector, in which case its first element is used.
// Any other type, or an empty character vector, raises an R error instead of
// handing a null pointer to the C API.
//
// Returns an external pointer to the new Booster, with _BoosterFinalizer
// registered on it.
SEXP LGBM_BoosterLoadModelFromString_R(SEXP model_str) {
  R_API_BEGIN();
  SEXP ret = Rf_protect(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  SEXP temp = NULL;
  int n_protected = 1;
  int out_num_iterations = 0;
  const char* model_str_ptr = nullptr;
  switch (TYPEOF(model_str)) {
    case RAWSXP: {
      model_str_ptr = reinterpret_cast<const char*>(RAW(model_str));
      break;
    }
    case CHARSXP: {
      model_str_ptr = reinterpret_cast<const char*>(CHAR(model_str));
      break;
    }
    case STRSXP: {
      // STRING_ELT(x, 0) is only valid when the vector has at least one element
      if (Rf_xlength(model_str) < 1) {
        throw std::runtime_error("Model string must not be an empty character vector");
      }
      temp = Rf_protect(STRING_ELT(model_str, 0));
      n_protected++;
      model_str_ptr = reinterpret_cast<const char*>(CHAR(temp));
      break;
    }
    default: {
      // without this branch a null pointer would reach the C API
      throw std::runtime_error("Model string must be a raw vector or a character string");
    }
  }
  BoosterHandle handle = nullptr;
  CHECK_CALL(LGBM_BoosterLoadModelFromString(model_str_ptr, &out_num_iterations, &handle));
  R_SetExternalPtrAddr(ret, handle);
  R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE);
  Rf_unprotect(n_protected);
  return ret;
  R_API_END();
}

678
679
// Passes both Booster handles to LGBM_BoosterMerge() after checking that
// neither is null. Returns R NULL.
SEXP LGBM_BoosterMerge_R(SEXP handle, SEXP other_handle) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  _AssertBoosterHandleNotNull(other_handle);
  void *booster = R_ExternalPtrAddr(handle);
  void *other_booster = R_ExternalPtrAddr(other_handle);
  CHECK_CALL(LGBM_BoosterMerge(booster, other_booster));
  return R_NilValue;
  R_API_END();
}

688
689
// Passes the Booster and the validation Dataset to LGBM_BoosterAddValidData()
// after checking that neither handle is null. Returns R NULL.
SEXP LGBM_BoosterAddValidData_R(SEXP handle, SEXP valid_data) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  _AssertDatasetHandleNotNull(valid_data);
  void *booster = R_ExternalPtrAddr(handle);
  void *dataset = R_ExternalPtrAddr(valid_data);
  CHECK_CALL(LGBM_BoosterAddValidData(booster, dataset));
  return R_NilValue;
  R_API_END();
}

698
699
// Passes the Booster and the new training Dataset to
// LGBM_BoosterResetTrainingData() after checking that neither handle is null.
// Returns R NULL.
SEXP LGBM_BoosterResetTrainingData_R(SEXP handle, SEXP train_data) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  _AssertDatasetHandleNotNull(train_data);
  void *booster = R_ExternalPtrAddr(handle);
  void *dataset = R_ExternalPtrAddr(train_data);
  CHECK_CALL(LGBM_BoosterResetTrainingData(booster, dataset));
  return R_NilValue;
  R_API_END();
}

708
// Passes 'parameters' (coerced to a single string) to
// LGBM_BoosterResetParameter() for the given Booster. Returns R NULL.
SEXP LGBM_BoosterResetParameter_R(SEXP handle, SEXP parameters) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  SEXP parameters_chr = Rf_protect(Rf_asChar(parameters));
  CHECK_CALL(LGBM_BoosterResetParameter(R_ExternalPtrAddr(handle), CHAR(parameters_chr)));
  Rf_unprotect(1);
  return R_NilValue;
  R_API_END();
}

719
// Writes the value reported by LGBM_BoosterGetNumClasses() into out[0]
// ('out' must be an integer vector). Returns R NULL.
SEXP LGBM_BoosterGetNumClasses_R(SEXP handle, SEXP out) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  void *booster = R_ExternalPtrAddr(handle);
  int num_classes;
  CHECK_CALL(LGBM_BoosterGetNumClasses(booster, &num_classes));
  *INTEGER(out) = num_classes;
  return R_NilValue;
  R_API_END();
}

730
731
732
733
734
735
736
737
738
// Returns the value reported by LGBM_BoosterGetNumFeature() as a length-1
// integer vector.
SEXP LGBM_BoosterGetNumFeature_R(SEXP handle) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  void *booster = R_ExternalPtrAddr(handle);
  int num_features = 0;
  CHECK_CALL(LGBM_BoosterGetNumFeature(booster, &num_features));
  return Rf_ScalarInteger(num_features);
  R_API_END();
}

739
// Calls LGBM_BoosterUpdateOneIter() for the given Booster. The "is finished"
// flag reported by the C API is not returned to R. Returns R NULL.
SEXP LGBM_BoosterUpdateOneIter_R(SEXP handle) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  void *booster = R_ExternalPtrAddr(handle);
  int finished_flag = 0;
  CHECK_CALL(LGBM_BoosterUpdateOneIter(booster, &finished_flag));
  return R_NilValue;
  R_API_END();
}

748
// Calls LGBM_BoosterUpdateOneIterCustom() with user-supplied gradients and
// hessians.
//   grad, hess - double vectors; the first 'len' values of each are converted
//                to temporary float arrays before the call
//   len        - number of values, read with Rf_asInteger()
// The "is finished" flag reported by the C API is not returned to R.
SEXP LGBM_BoosterUpdateOneIterCustom_R(SEXP handle,
                                       SEXP grad,
                                       SEXP hess,
                                       SEXP len) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int finished_flag = 0;
  const int n_values = Rf_asInteger(len);
  std::unique_ptr<float[]> grad_f32(new float[n_values]);
  std::unique_ptr<float[]> hess_f32(new float[n_values]);
  const double *grad_f64 = REAL(grad);
  std::copy(grad_f64, grad_f64 + n_values, grad_f32.get());
  const double *hess_f64 = REAL(hess);
  std::copy(hess_f64, hess_f64 + n_values, hess_f32.get());
  CHECK_CALL(LGBM_BoosterUpdateOneIterCustom(
    R_ExternalPtrAddr(handle), grad_f32.get(), hess_f32.get(), &finished_flag));
  return R_NilValue;
  R_API_END();
}

764
// Calls LGBM_BoosterRollbackOneIter() for the given Booster. Returns R NULL.
SEXP LGBM_BoosterRollbackOneIter_R(SEXP handle) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  void *booster = R_ExternalPtrAddr(handle);
  CHECK_CALL(LGBM_BoosterRollbackOneIter(booster));
  return R_NilValue;
  R_API_END();
}

772
// Writes the value reported by LGBM_BoosterGetCurrentIteration() into out[0]
// ('out' must be an integer vector). Returns R NULL.
SEXP LGBM_BoosterGetCurrentIteration_R(SEXP handle, SEXP out) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  void *booster = R_ExternalPtrAddr(handle);
  int current_iteration;
  CHECK_CALL(LGBM_BoosterGetCurrentIteration(booster, &current_iteration));
  *INTEGER(out) = current_iteration;
  return R_NilValue;
  R_API_END();
}

782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
// Writes the value reported by LGBM_BoosterNumModelPerIteration() into out[0]
// ('out' must be an integer vector). Returns R NULL.
SEXP LGBM_BoosterNumModelPerIteration_R(SEXP handle, SEXP out) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  void *booster = R_ExternalPtrAddr(handle);
  int per_iteration;
  CHECK_CALL(LGBM_BoosterNumModelPerIteration(booster, &per_iteration));
  *INTEGER(out) = per_iteration;
  return R_NilValue;
  R_API_END();
}

// Writes the value reported by LGBM_BoosterNumberOfTotalModel() into out[0]
// ('out' must be an integer vector). Returns R NULL.
SEXP LGBM_BoosterNumberOfTotalModel_R(SEXP handle, SEXP out) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  void *booster = R_ExternalPtrAddr(handle);
  int model_count;
  CHECK_CALL(LGBM_BoosterNumberOfTotalModel(booster, &model_count));
  *INTEGER(out) = model_count;
  return R_NilValue;
  R_API_END();
}

802
// Lets LGBM_BoosterGetUpperBoundValue() write its result directly into the
// double vector 'out_result'. Returns R NULL.
SEXP LGBM_BoosterGetUpperBoundValue_R(SEXP handle, SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  double* destination = REAL(out_result);
  void *booster = R_ExternalPtrAddr(handle);
  CHECK_CALL(LGBM_BoosterGetUpperBoundValue(booster, destination));
  return R_NilValue;
  R_API_END();
}

812
// Lets LGBM_BoosterGetLowerBoundValue() write its result directly into the
// double vector 'out_result'. Returns R NULL.
SEXP LGBM_BoosterGetLowerBoundValue_R(SEXP handle, SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  double* destination = REAL(out_result);
  void *booster = R_ExternalPtrAddr(handle);
  CHECK_CALL(LGBM_BoosterGetLowerBoundValue(booster, destination));
  return R_NilValue;
  R_API_END();
}

822
// Returns the Booster's evaluation names as an R character vector.
// Names are first fetched into 128-byte buffers; when the C API reports that a
// larger size is required, the buffers are enlarged and the names fetched again.
// R allocations go through the safe_R_* helpers, using the unwind token
// created before the try block.
SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) {
  SEXP cont_token = Rf_protect(R_MakeUnwindCont());
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  SEXP eval_names;
  int len;
  CHECK_CALL(LGBM_BoosterGetEvalCounts(R_ExternalPtrAddr(handle), &len));
  const size_t reserved_string_size = 128;
  std::vector<std::vector<char>> names(len);
  std::vector<char*> ptr_names(len);
  // (re)size every name buffer and refresh the pointer table
  auto resize_buffers = [&names, &ptr_names](size_t width) {
    for (size_t idx = 0; idx < names.size(); ++idx) {
      names[idx].resize(width);
      ptr_names[idx] = names[idx].data();
    }
  };
  resize_buffers(reserved_string_size);

  int out_len;
  size_t required_string_size;
  CHECK_CALL(LGBM_BoosterGetEvalNames(
    R_ExternalPtrAddr(handle),
    len, &out_len,
    reserved_string_size, &required_string_size,
    ptr_names.data()));
  // at least one name did not fit: enlarge the buffers and try again
  if (required_string_size > reserved_string_size) {
    resize_buffers(required_string_size);
    CHECK_CALL(LGBM_BoosterGetEvalNames(
      R_ExternalPtrAddr(handle),
      len, &out_len,
      required_string_size, &required_string_size,
      ptr_names.data()));
  }
  CHECK_EQ(out_len, len);
  eval_names = Rf_protect(safe_R_string(static_cast<R_xlen_t>(len), &cont_token));
  for (int idx = 0; idx < len; ++idx) {
    SET_STRING_ELT(eval_names, idx, safe_R_mkChar(ptr_names[idx], &cont_token));
  }
  Rf_unprotect(2);
  return eval_names;
  R_API_END();
}

871
// Lets LGBM_BoosterGetEval() write the evaluation results for dataset
// 'data_idx' directly into the double vector 'out_result', then checks that
// the number of results equals the count from LGBM_BoosterGetEvalCounts().
SEXP LGBM_BoosterGetEval_R(SEXP handle,
                           SEXP data_idx,
                           SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int len;
  CHECK_CALL(LGBM_BoosterGetEvalCounts(R_ExternalPtrAddr(handle), &len));
  double* destination = REAL(out_result);
  int out_len;
  const int dataset_index = Rf_asInteger(data_idx);
  CHECK_CALL(LGBM_BoosterGetEval(R_ExternalPtrAddr(handle), dataset_index, &out_len, destination));
  CHECK_EQ(out_len, len);
  return R_NilValue;
  R_API_END();
}

886
// Writes the count reported by LGBM_BoosterGetNumPredict() for dataset
// 'data_idx' into out[0]. The C API reports a 64-bit count, which is narrowed
// to int for the R integer vector.
SEXP LGBM_BoosterGetNumPredict_R(SEXP handle,
                                 SEXP data_idx,
                                 SEXP out) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int64_t num_predict;
  const int dataset_index = Rf_asInteger(data_idx);
  CHECK_CALL(LGBM_BoosterGetNumPredict(R_ExternalPtrAddr(handle), dataset_index, &num_predict));
  *INTEGER(out) = static_cast<int>(num_predict);
  return R_NilValue;
  R_API_END();
}

898
// Lets LGBM_BoosterGetPredict() write the predictions for dataset 'data_idx'
// directly into the double vector 'out_result'. The length reported by the
// C API is not used. Returns R NULL.
SEXP LGBM_BoosterGetPredict_R(SEXP handle,
                              SEXP data_idx,
                              SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  double* destination = REAL(out_result);
  int64_t written;
  const int dataset_index = Rf_asInteger(data_idx);
  CHECK_CALL(LGBM_BoosterGetPredict(R_ExternalPtrAddr(handle), dataset_index, &written, destination));
  return R_NilValue;
  R_API_END();
}

910
// Maps the three R-side flags onto a single C_API_PREDICT_* constant.
// When several flags are non-zero the precedence is
// is_predcontrib > is_leafidx > is_rawscore; with none set the result is
// C_API_PREDICT_NORMAL.
int GetPredictType(SEXP is_rawscore, SEXP is_leafidx, SEXP is_predcontrib) {
  // convert all three flags up front, in declaration order
  const int want_raw = Rf_asInteger(is_rawscore);
  const int want_leaf = Rf_asInteger(is_leafidx);
  const int want_contrib = Rf_asInteger(is_predcontrib);
  if (want_contrib) {
    return C_API_PREDICT_CONTRIB;
  }
  if (want_leaf) {
    return C_API_PREDICT_LEAF_INDEX;
  }
  if (want_raw) {
    return C_API_PREDICT_RAW_SCORE;
  }
  return C_API_PREDICT_NORMAL;
}

924
// Forwards a file-based prediction request to LGBM_BoosterPredictForFile using
// the file names, header flag, iteration range and parameter string supplied
// from R. The prediction type is derived from the three is_* flags.
// Returns R_NilValue.
SEXP LGBM_BoosterPredictForFile_R(SEXP handle,
  SEXP data_filename,
  SEXP data_has_header,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter,
  SEXP result_filename) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  // each Rf_asChar() result is protected before the next conversion runs
  SEXP input_path_sexp = Rf_protect(Rf_asChar(data_filename));
  SEXP params_sexp = Rf_protect(Rf_asChar(parameter));
  SEXP output_path_sexp = Rf_protect(Rf_asChar(result_filename));
  const char* input_path = CHAR(input_path_sexp);
  const char* params = CHAR(params_sexp);
  const char* output_path = CHAR(output_path_sexp);
  const int predict_kind = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  const int has_header = Rf_asInteger(data_has_header);
  const int first_iter = Rf_asInteger(start_iteration);
  const int iter_count = Rf_asInteger(num_iteration);
  CHECK_CALL(LGBM_BoosterPredictForFile(R_ExternalPtrAddr(handle), input_path,
    has_header, predict_kind, first_iter, iter_count, params,
    output_path));
  // releases the three CHARSXPs protected above
  Rf_unprotect(3);
  return R_NilValue;
  R_API_END();
}

948
// Asks LGBM_BoosterCalcNumPredict how many values a prediction over `num_row`
// rows produces for the given prediction type and iteration range, and stores
// the answer in the first element of the R integer vector `out_len`.
// The C API reports the count as int64_t while `out_len` holds `int` values, so
// the narrowed value is verified to round-trip; a count that does not fit fails
// the CHECK_EQ instead of being stored as a silently wrapped number.
// Returns R_NilValue.
SEXP LGBM_BoosterCalcNumPredict_R(SEXP handle,
  SEXP num_row,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP out_len) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  int64_t len = 0;
  CHECK_CALL(LGBM_BoosterCalcNumPredict(R_ExternalPtrAddr(handle), Rf_asInteger(num_row),
    pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), &len));
  const int len_as_int = static_cast<int>(len);
  // guard against truncation when the count exceeds the range of int
  CHECK_EQ(static_cast<int64_t>(len_as_int), len);
  INTEGER(out_len)[0] = len_as_int;
  return R_NilValue;
  R_API_END();
}

967
// Forwards a prediction over CSC-format input (`indptr`, `indices`, `data`) to
// LGBM_BoosterPredictForCSC. Index arrays are passed as 32-bit integers and
// values as 64-bit floats; the storage of `out_result` is passed as the
// destination buffer.
// Returns R_NilValue.
SEXP LGBM_BoosterPredictForCSC_R(SEXP handle,
  SEXP indptr,
  SEXP indices,
  SEXP data,
  SEXP num_indptr,
  SEXP nelem,
  SEXP num_row,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  const int predict_kind = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  const int* col_ptr = INTEGER(indptr);
  const int32_t* row_idx = reinterpret_cast<const int32_t*>(INTEGER(indices));
  const double* values = REAL(data);
  const int64_t col_ptr_len = static_cast<int64_t>(Rf_asInteger(num_indptr));
  const int64_t value_count = static_cast<int64_t>(Rf_asInteger(nelem));
  const int64_t row_count = static_cast<int64_t>(Rf_asInteger(num_row));
  double* predictions = REAL(out_result);
  // the C API requires a length out-parameter; the value is not used here
  int64_t num_written;
  SEXP params_sexp = Rf_protect(Rf_asChar(parameter));
  const char* params = CHAR(params_sexp);
  const int first_iter = Rf_asInteger(start_iteration);
  const int iter_count = Rf_asInteger(num_iteration);
  CHECK_CALL(LGBM_BoosterPredictForCSC(R_ExternalPtrAddr(handle),
    col_ptr, C_API_DTYPE_INT32, row_idx,
    values, C_API_DTYPE_FLOAT64, col_ptr_len, value_count,
    row_count, predict_kind, first_iter, iter_count, params, &num_written, predictions));
  Rf_unprotect(1);
  return R_NilValue;
  R_API_END();
}

1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
// Forwards a prediction over CSR-format input (`indptr`, `indices`, `data`) to
// LGBM_BoosterPredictForCSR. The lengths handed to the C API are taken from the
// R vectors themselves (Rf_xlength of `indptr` and `data`); the storage of
// `out_result` is passed as the destination buffer.
// Returns R_NilValue.
SEXP LGBM_BoosterPredictForCSR_R(SEXP handle,
  SEXP indptr,
  SEXP indices,
  SEXP data,
  SEXP ncols,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  const int predict_kind = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  SEXP params_sexp = Rf_protect(Rf_asChar(parameter));
  const char* params = CHAR(params_sexp);
  // the C API requires a length out-parameter; the value is not used here
  int64_t num_written;
  CHECK_CALL(LGBM_BoosterPredictForCSR(R_ExternalPtrAddr(handle),
    INTEGER(indptr), C_API_DTYPE_INT32, INTEGER(indices),
    REAL(data), C_API_DTYPE_FLOAT64,
    Rf_xlength(indptr), Rf_xlength(data), Rf_asInteger(ncols),
    predict_kind, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration),
    params, &num_written, REAL(out_result)));
  Rf_unprotect(1);
  return R_NilValue;
  R_API_END();
}

// Predicts for a single sparse row given as (`indices`, `data`). A two-element
// row-pointer array {0, nnz} is built locally, where nnz is the length of
// `data`, and passed to LGBM_BoosterPredictForCSRSingleRow together with the
// storage of `out_result` as the destination buffer.
// Returns R_NilValue.
SEXP LGBM_BoosterPredictForCSRSingleRow_R(SEXP handle,
  SEXP indices,
  SEXP data,
  SEXP ncols,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  const int predict_kind = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  SEXP params_sexp = Rf_protect(Rf_asChar(parameter));
  const char* params = CHAR(params_sexp);
  const int nnz = static_cast<int>(Rf_xlength(data));
  // one row only: it starts at offset 0 and ends at offset nnz
  const int row_ptr[] = {0, nnz};
  // the C API requires a length out-parameter; the value is not used here
  int64_t num_written;
  CHECK_CALL(LGBM_BoosterPredictForCSRSingleRow(R_ExternalPtrAddr(handle),
    row_ptr, C_API_DTYPE_INT32, INTEGER(indices),
    REAL(data), C_API_DTYPE_FLOAT64,
    2, nnz, Rf_asInteger(ncols),
    predict_kind, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration),
    params, &num_written, REAL(out_result)));
  Rf_unprotect(1);
  return R_NilValue;
  R_API_END();
}

// Finalizer registered (via R_RegisterCFinalizerEx) on the external pointers
// returned by the *SingleRowFastInit_R wrappers in this file. Releases the fast
// prediction configuration stored in `handle`.
void LGBM_FastConfigFree_wrapped(SEXP handle) {
  // The external pointer holds the FastConfigHandle itself (it is stored with
  // R_SetExternalPtrAddr(ret, out_fastConfig)), so cast to the handle type, not
  // to a pointer-to-handle.
  LGBM_FastConfigFree(static_cast<FastConfigHandle>(R_ExternalPtrAddr(handle)));
  // do not leave the freed address behind in the R object
  R_ClearExternalPtr(handle);
}

// Creates a fast single-row CSR prediction configuration through
// LGBM_BoosterPredictForCSRSingleRowFastInit and returns it wrapped in an R
// external pointer. LGBM_FastConfigFree_wrapped is registered as the pointer's
// finalizer (with onexit = TRUE).
SEXP LGBM_BoosterPredictForCSRSingleRowFastInit_R(SEXP handle,
  SEXP ncols,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  const int predict_kind = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  // the external pointer is created empty first and filled in once the C API call has succeeded
  SEXP config_ptr = Rf_protect(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  SEXP params_sexp = Rf_protect(Rf_asChar(parameter));
  const char* params = CHAR(params_sexp);
  FastConfigHandle fast_config;
  CHECK_CALL(LGBM_BoosterPredictForCSRSingleRowFastInit(R_ExternalPtrAddr(handle),
    predict_kind, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration),
    C_API_DTYPE_FLOAT64, Rf_asInteger(ncols),
    params, &fast_config));
  R_SetExternalPtrAddr(config_ptr, fast_config);
  R_RegisterCFinalizerEx(config_ptr, LGBM_FastConfigFree_wrapped, TRUE);
  Rf_unprotect(2);
  return config_ptr;
  R_API_END();
}

// Predicts for a single sparse row (`indices`, `data`) using the configuration
// stored in the external pointer `handle_fastConfig`. A two-element row-pointer
// array {0, nnz} is built locally, where nnz is the length of `data`; the
// storage of `out_result` is passed as the destination buffer.
// Returns R_NilValue.
SEXP LGBM_BoosterPredictForCSRSingleRowFast_R(SEXP handle_fastConfig,
  SEXP indices,
  SEXP data,
  SEXP out_result) {
  R_API_BEGIN();
  const int nnz = static_cast<int>(Rf_xlength(data));
  // one row only: it starts at offset 0 and ends at offset nnz
  const int row_ptr[] = {0, nnz};
  // the C API requires a length out-parameter; the value is not used here
  int64_t num_written;
  CHECK_CALL(LGBM_BoosterPredictForCSRSingleRowFast(R_ExternalPtrAddr(handle_fastConfig),
    row_ptr, C_API_DTYPE_INT32, INTEGER(indices),
    REAL(data),
    2, nnz,
    &num_written, REAL(out_result)));
  return R_NilValue;
  R_API_END();
}

1105
// Forwards a prediction over a dense double matrix to LGBM_BoosterPredictForMat.
// The data is described to the C API as `num_row` x `num_col`, 64-bit float,
// with the COL_MAJOR layout flag; the storage of `out_result` is passed as the
// destination buffer.
// Returns R_NilValue.
SEXP LGBM_BoosterPredictForMat_R(SEXP handle,
  SEXP data,
  SEXP num_row,
  SEXP num_col,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  const int predict_kind = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  const int32_t row_count = static_cast<int32_t>(Rf_asInteger(num_row));
  const int32_t col_count = static_cast<int32_t>(Rf_asInteger(num_col));
  const double* matrix = REAL(data);
  double* predictions = REAL(out_result);
  SEXP params_sexp = Rf_protect(Rf_asChar(parameter));
  const char* params = CHAR(params_sexp);
  // the C API requires a length out-parameter; the value is not used here
  int64_t num_written;
  const int first_iter = Rf_asInteger(start_iteration);
  const int iter_count = Rf_asInteger(num_iteration);
  CHECK_CALL(LGBM_BoosterPredictForMat(R_ExternalPtrAddr(handle),
    matrix, C_API_DTYPE_FLOAT64, row_count, col_count, COL_MAJOR,
    predict_kind, first_iter, iter_count, params, &num_written, predictions));
  Rf_unprotect(1);
  return R_NilValue;
  R_API_END();
}

1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
// Bundles the three buffers produced by a sparse prediction so that they can be
// handed to a single cleanup routine (see delete_SparseOutputPointers).
struct SparseOutputPointers {
  void* indptr;      // row/column pointer buffer
  int32_t* indices;  // index buffer
  void* data;        // value buffer

  SparseOutputPointers(void* indptr_in, int32_t* indices_in, void* data_in)
  : indptr{indptr_in}, indices{indices_in}, data{data_in} {}
};

// Deleter for SparseOutputPointers: hands the three buffers to
// LGBM_BoosterFreePredictSparse (int32 index pointers, float64 values) and then
// destroys the holder object itself.
void delete_SparseOutputPointers(SparseOutputPointers *ptr) {
  const SparseOutputPointers& buffers = *ptr;
  LGBM_BoosterFreePredictSparse(
    buffers.indptr, buffers.indices, buffers.data,
    C_API_DTYPE_INT32, C_API_DTYPE_FLOAT64);
  delete ptr;
}

// Runs LGBM_BoosterPredictSparseOutput with C_API_PREDICT_CONTRIB on CSR or CSC
// input (selected by `is_csr`) and returns a named R list with elements
// "indptr", "indices" and "data" built from the buffers the C API produced.
// For CSR input `ncols` is forwarded as the column/row count argument, for CSC
// input `nrows` is.
SEXP LGBM_BoosterPredictSparseOutput_R(SEXP handle,
  SEXP indptr,
  SEXP indices,
  SEXP data,
  SEXP is_csr,
  SEXP nrows,
  SEXP ncols,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter) {
  SEXP cont_token = Rf_protect(R_MakeUnwindCont());
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  const char* out_names[] = {"indptr", "indices", "data", ""};
  SEXP out = Rf_protect(Rf_mkNamed(VECSXP, out_names));
  SEXP params_sexp = Rf_protect(Rf_asChar(parameter));
  const char* params = CHAR(params_sexp);

  // out_len[0] sizes the indices/data buffers, out_len[1] sizes the indptr buffer
  int64_t out_len[2];
  void *raw_indptr;
  int32_t *raw_indices;
  void *raw_data;

  const int input_is_csr = Rf_asLogical(is_csr);
  CHECK_CALL(LGBM_BoosterPredictSparseOutput(R_ExternalPtrAddr(handle),
    INTEGER(indptr), C_API_DTYPE_INT32, INTEGER(indices),
    REAL(data), C_API_DTYPE_FLOAT64,
    Rf_xlength(indptr), Rf_xlength(data),
    input_is_csr ? Rf_asInteger(ncols) : Rf_asInteger(nrows),
    C_API_PREDICT_CONTRIB, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration),
    params,
    input_is_csr ? C_API_MATRIX_TYPE_CSR : C_API_MATRIX_TYPE_CSC,
    out_len, &raw_indptr, &raw_indices, &raw_data));

  // Holds the C API buffers until they have been handed over to R objects.
  // Each field is reset to nullptr right after its buffer has been wrapped.
  std::unique_ptr<SparseOutputPointers, decltype(&delete_SparseOutputPointers)> pending_buffers(
    new SparseOutputPointers(raw_indptr, raw_indices, raw_data),
    &delete_SparseOutputPointers);

  arr_and_len<int> indptr_desc{static_cast<int*>(raw_indptr), out_len[1]};
  SEXP r_indptr = R_UnwindProtect(make_altrepped_vec_from_arr<int>,
    static_cast<void*>(&indptr_desc), throw_R_memerr, &cont_token, cont_token);
  SET_VECTOR_ELT(out, 0, r_indptr);
  pending_buffers->indptr = nullptr;

  arr_and_len<int> indices_desc{static_cast<int*>(raw_indices), out_len[0]};
  SEXP r_indices = R_UnwindProtect(make_altrepped_vec_from_arr<int>,
    static_cast<void*>(&indices_desc), throw_R_memerr, &cont_token, cont_token);
  SET_VECTOR_ELT(out, 1, r_indices);
  pending_buffers->indices = nullptr;

  arr_and_len<double> data_desc{static_cast<double*>(raw_data), out_len[0]};
  SEXP r_data = R_UnwindProtect(make_altrepped_vec_from_arr<double>,
    static_cast<void*>(&data_desc), throw_R_memerr, &cont_token, cont_token);
  SET_VECTOR_ELT(out, 2, r_data);
  pending_buffers->data = nullptr;

  // cont_token, out and the parameter CHARSXP
  Rf_unprotect(3);
  return out;
  R_API_END();
}

1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
// Predicts for a single dense row held in the R double vector `data`. The
// length of `data` is forwarded as the column count and the storage of
// `out_result` is passed as the destination buffer.
// Returns R_NilValue.
SEXP LGBM_BoosterPredictForMatSingleRow_R(SEXP handle,
  SEXP data,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  const int predict_kind = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  SEXP params_sexp = Rf_protect(Rf_asChar(parameter));
  const char* params = CHAR(params_sexp);
  double* predictions = REAL(out_result);
  // the C API requires a length out-parameter; the value is not used here
  int64_t num_written;
  CHECK_CALL(LGBM_BoosterPredictForMatSingleRow(R_ExternalPtrAddr(handle),
    REAL(data), C_API_DTYPE_FLOAT64, Rf_xlength(data), 1,
    predict_kind, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration),
    params, &num_written, predictions));
  Rf_unprotect(1);
  return R_NilValue;
  R_API_END();
}

// Creates a fast single-row dense prediction configuration through
// LGBM_BoosterPredictForMatSingleRowFastInit and returns it wrapped in an R
// external pointer. LGBM_FastConfigFree_wrapped is registered as the pointer's
// finalizer (with onexit = TRUE).
SEXP LGBM_BoosterPredictForMatSingleRowFastInit_R(SEXP handle,
  SEXP ncols,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  const int predict_kind = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  // the external pointer is created empty first and filled in once the C API call has succeeded
  SEXP config_ptr = Rf_protect(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  SEXP params_sexp = Rf_protect(Rf_asChar(parameter));
  const char* params = CHAR(params_sexp);
  FastConfigHandle fast_config;
  CHECK_CALL(LGBM_BoosterPredictForMatSingleRowFastInit(R_ExternalPtrAddr(handle),
    predict_kind, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration),
    C_API_DTYPE_FLOAT64, Rf_asInteger(ncols),
    params, &fast_config));
  R_SetExternalPtrAddr(config_ptr, fast_config);
  R_RegisterCFinalizerEx(config_ptr, LGBM_FastConfigFree_wrapped, TRUE);
  Rf_unprotect(2);
  return config_ptr;
  R_API_END();
}

// Predicts for a single dense row (`data`) using the configuration stored in the
// external pointer `handle_fastConfig`; the storage of `out_result` is passed as
// the destination buffer.
// Returns R_NilValue.
SEXP LGBM_BoosterPredictForMatSingleRowFast_R(SEXP handle_fastConfig,
  SEXP data,
  SEXP out_result) {
  R_API_BEGIN();
  // the C API requires a length out-parameter; the value is not used here
  int64_t num_written;
  const double* row_values = REAL(data);
  double* predictions = REAL(out_result);
  CHECK_CALL(LGBM_BoosterPredictForMatSingleRowFast(R_ExternalPtrAddr(handle_fastConfig),
    row_values, &num_written, predictions));
  return R_NilValue;
  R_API_END();
}

1272
// Forwards to LGBM_BoosterSaveModel with the given iteration range, feature
// importance type and target file name.
// Returns R_NilValue.
SEXP LGBM_BoosterSaveModel_R(SEXP handle,
  SEXP num_iteration,
  SEXP feature_importance_type,
  SEXP filename,
  SEXP start_iteration) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  SEXP path_sexp = Rf_protect(Rf_asChar(filename));
  const char* path = CHAR(path_sexp);
  const int first_iter = Rf_asInteger(start_iteration);
  const int iter_count = Rf_asInteger(num_iteration);
  const int importance_kind = Rf_asInteger(feature_importance_type);
  CHECK_CALL(LGBM_BoosterSaveModel(R_ExternalPtrAddr(handle), first_iter, iter_count, importance_kind, path));
  Rf_unprotect(1);
  return R_NilValue;
  R_API_END();
}

1286
1287
1288
1289
1290
1291
1292
1293
// Note: for some reason, MSVC crashes when an error is thrown here
// if the buffer variable is defined as 'std::unique_ptr<std::vector<char>>',
// but not if it is defined as 'std::vector<char>'.
// Hence two implementations: the first hands the buffer to
// make_altrepped_raw_vec, the second (MSVC) copies the model text into a raw
// vector obtained from safe_R_raw.
#ifndef _MSC_VER
// Serializes the model through LGBM_BoosterSaveModelToString and returns the
// object produced by make_altrepped_raw_vec from the filled buffer.
SEXP LGBM_BoosterSaveModelToString_R(SEXP handle,
  SEXP num_iteration,
  SEXP feature_importance_type,
  SEXP start_iteration) {
  SEXP cont_token = Rf_protect(R_MakeUnwindCont());
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int64_t needed_len = 0;
  const int64_t initial_len = 1024 * 1024;
  const int iter_count = Rf_asInteger(num_iteration);
  const int first_iter = Rf_asInteger(start_iteration);
  const int importance_kind = Rf_asInteger(feature_importance_type);
  std::unique_ptr<std::vector<char>> model_buf(new std::vector<char>(initial_len));
  CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), first_iter, iter_count, importance_kind, initial_len, &needed_len, model_buf->data()));
  // resize the buffer to the length reported by the first call
  model_buf->resize(needed_len);
  // the first attempt did not fit: repeat the call with the enlarged buffer
  if (needed_len > initial_len) {
    CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), first_iter, iter_count, importance_kind, needed_len, &needed_len, model_buf->data()));
  }
  SEXP out = R_UnwindProtect(make_altrepped_raw_vec, &model_buf, throw_R_memerr, &cont_token, cont_token);
  Rf_unprotect(1);
  return out;
  R_API_END();
}
#else
// Serializes the model through LGBM_BoosterSaveModelToString and returns it as
// an R raw vector of the length reported by the C API.
SEXP LGBM_BoosterSaveModelToString_R(SEXP handle,
  SEXP num_iteration,
  SEXP feature_importance_type,
  SEXP start_iteration) {
  SEXP cont_token = Rf_protect(R_MakeUnwindCont());
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int64_t needed_len = 0;
  const int64_t initial_len = 1024 * 1024;
  const int iter_count = Rf_asInteger(num_iteration);
  const int first_iter = Rf_asInteger(start_iteration);
  const int importance_kind = Rf_asInteger(feature_importance_type);
  std::vector<char> scratch(initial_len);
  CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), first_iter, iter_count, importance_kind, initial_len, &needed_len, scratch.data()));
  SEXP model_str = Rf_protect(safe_R_raw(needed_len, &cont_token));
  char* dest = reinterpret_cast<char*>(RAW(model_str));
  if (needed_len <= initial_len) {
    // the model fit into the scratch buffer: copy it over
    std::copy(scratch.begin(), scratch.begin() + needed_len, dest);
  } else {
    // the model was larger than the scratch buffer: call again, writing directly to the R object
    CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), first_iter, iter_count, importance_kind, needed_len, &needed_len, dest));
  }
  Rf_unprotect(2);
  return model_str;
  R_API_END();
}
#endif
1340

1341
// Obtains the model dump from LGBM_BoosterDumpModel and returns it as a
// length-1 R character vector.
SEXP LGBM_BoosterDumpModel_R(SEXP handle,
  SEXP num_iteration,
  SEXP feature_importance_type,
  SEXP start_iteration) {
  SEXP cont_token = Rf_protect(R_MakeUnwindCont());
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int64_t needed_len = 0;
  const int64_t initial_len = 1024 * 1024;
  const int iter_count = Rf_asInteger(num_iteration);
  const int first_iter = Rf_asInteger(start_iteration);
  const int importance_kind = Rf_asInteger(feature_importance_type);
  std::vector<char> dump_buf(initial_len);
  CHECK_CALL(LGBM_BoosterDumpModel(R_ExternalPtrAddr(handle), first_iter, iter_count, importance_kind, initial_len, &needed_len, dump_buf.data()));
  // the dump did not fit into the initial buffer: grow it and ask again
  if (needed_len > initial_len) {
    dump_buf.resize(needed_len);
    CHECK_CALL(LGBM_BoosterDumpModel(R_ExternalPtrAddr(handle), first_iter, iter_count, importance_kind, needed_len, &needed_len, dump_buf.data()));
  }
  SEXP model_str = Rf_protect(safe_R_string(static_cast<R_xlen_t>(1), &cont_token));
  SET_STRING_ELT(model_str, 0, safe_R_mkChar(dump_buf.data(), &cont_token));
  // cont_token and model_str
  Rf_unprotect(2);
  return model_str;
  R_API_END();
}
1367

1368
// Obtains the parameter-alias text from LGBM_DumpParamAliases and returns it as
// a length-1 R character vector.
SEXP LGBM_DumpParamAliases_R() {
  SEXP cont_token = Rf_protect(R_MakeUnwindCont());
  R_API_BEGIN();
  int64_t needed_len = 0;
  const int64_t initial_len = 1024 * 1024;
  std::vector<char> alias_buf(initial_len);
  CHECK_CALL(LGBM_DumpParamAliases(initial_len, &needed_len, alias_buf.data()));
  // the aliases text did not fit into the initial buffer: grow it and ask again
  if (needed_len > initial_len) {
    alias_buf.resize(needed_len);
    CHECK_CALL(LGBM_DumpParamAliases(needed_len, &needed_len, alias_buf.data()));
  }
  SEXP aliases_str = Rf_protect(safe_R_string(static_cast<R_xlen_t>(1), &cont_token));
  SET_STRING_ELT(aliases_str, 0, safe_R_mkChar(alias_buf.data(), &cont_token));
  // cont_token and aliases_str
  Rf_unprotect(2);
  return aliases_str;
  R_API_END();
}

1388
// Obtains the loaded-parameters text from LGBM_BoosterGetLoadedParam and returns
// it as a length-1 R character vector.
SEXP LGBM_BoosterGetLoadedParam_R(SEXP handle) {
  SEXP cont_token = Rf_protect(R_MakeUnwindCont());
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int64_t needed_len = 0;
  const int64_t initial_len = 1024 * 1024;
  std::vector<char> param_buf(initial_len);
  CHECK_CALL(LGBM_BoosterGetLoadedParam(R_ExternalPtrAddr(handle), initial_len, &needed_len, param_buf.data()));
  // the parameters text did not fit into the initial buffer: grow it and ask again
  if (needed_len > initial_len) {
    param_buf.resize(needed_len);
    CHECK_CALL(LGBM_BoosterGetLoadedParam(R_ExternalPtrAddr(handle), needed_len, &needed_len, param_buf.data()));
  }
  SEXP params_str = Rf_protect(safe_R_string(static_cast<R_xlen_t>(1), &cont_token));
  SET_STRING_ELT(params_str, 0, safe_R_mkChar(param_buf.data(), &cont_token));
  // cont_token and params_str
  Rf_unprotect(2);
  return params_str;
  R_API_END();
}

1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
// Stores the value reported by LGBM_GetMaxThreads in the first element of the R
// integer vector `out`.
// Returns R_NilValue.
SEXP LGBM_GetMaxThreads_R(SEXP out) {
  R_API_BEGIN();
  int max_threads;
  CHECK_CALL(LGBM_GetMaxThreads(&max_threads));
  int* out_slot = INTEGER(out);
  out_slot[0] = max_threads;
  return R_NilValue;
  R_API_END();
}

// Converts `num_threads` to an int and forwards it to LGBM_SetMaxThreads.
// Returns R_NilValue.
SEXP LGBM_SetMaxThreads_R(SEXP num_threads) {
  R_API_BEGIN();
  const int requested_threads = Rf_asInteger(num_threads);
  CHECK_CALL(LGBM_SetMaxThreads(requested_threads));
  return R_NilValue;
  R_API_END();
}

1426
1427
// Registration table for the routines callable from R through .Call().
// Each row is {name as seen from R, function pointer, number of arguments};
// the table is terminated by an all-null sentinel row.
static const R_CallMethodDef CallEntries[] = {
  // handles and Dataset routines
  {"LGBM_HandleIsNull_R"                         , (DL_FUNC) &LGBM_HandleIsNull_R                         , 1},
  {"LGBM_DatasetCreateFromFile_R"                , (DL_FUNC) &LGBM_DatasetCreateFromFile_R                , 3},
  {"LGBM_DatasetCreateFromCSC_R"                 , (DL_FUNC) &LGBM_DatasetCreateFromCSC_R                 , 8},
  {"LGBM_DatasetCreateFromMat_R"                 , (DL_FUNC) &LGBM_DatasetCreateFromMat_R                 , 5},
  {"LGBM_DatasetGetSubset_R"                     , (DL_FUNC) &LGBM_DatasetGetSubset_R                     , 4},
  {"LGBM_DatasetSetFeatureNames_R"               , (DL_FUNC) &LGBM_DatasetSetFeatureNames_R               , 2},
  {"LGBM_DatasetGetFeatureNames_R"               , (DL_FUNC) &LGBM_DatasetGetFeatureNames_R               , 1},
  {"LGBM_DatasetSaveBinary_R"                    , (DL_FUNC) &LGBM_DatasetSaveBinary_R                    , 2},
  {"LGBM_DatasetFree_R"                          , (DL_FUNC) &LGBM_DatasetFree_R                          , 1},
  {"LGBM_DatasetSetField_R"                      , (DL_FUNC) &LGBM_DatasetSetField_R                      , 4},
  {"LGBM_DatasetGetFieldSize_R"                  , (DL_FUNC) &LGBM_DatasetGetFieldSize_R                  , 3},
  {"LGBM_DatasetGetField_R"                      , (DL_FUNC) &LGBM_DatasetGetField_R                      , 3},
  {"LGBM_DatasetUpdateParamChecking_R"           , (DL_FUNC) &LGBM_DatasetUpdateParamChecking_R           , 2},
  {"LGBM_DatasetGetNumData_R"                    , (DL_FUNC) &LGBM_DatasetGetNumData_R                    , 2},
  {"LGBM_DatasetGetNumFeature_R"                 , (DL_FUNC) &LGBM_DatasetGetNumFeature_R                 , 2},
  {"LGBM_DatasetGetFeatureNumBin_R"              , (DL_FUNC) &LGBM_DatasetGetFeatureNumBin_R              , 3},
  // Booster lifecycle, training and inspection
  {"LGBM_BoosterCreate_R"                        , (DL_FUNC) &LGBM_BoosterCreate_R                        , 2},
  {"LGBM_BoosterFree_R"                          , (DL_FUNC) &LGBM_BoosterFree_R                          , 1},
  {"LGBM_BoosterCreateFromModelfile_R"           , (DL_FUNC) &LGBM_BoosterCreateFromModelfile_R           , 1},
  {"LGBM_BoosterLoadModelFromString_R"           , (DL_FUNC) &LGBM_BoosterLoadModelFromString_R           , 1},
  {"LGBM_BoosterMerge_R"                         , (DL_FUNC) &LGBM_BoosterMerge_R                         , 2},
  {"LGBM_BoosterAddValidData_R"                  , (DL_FUNC) &LGBM_BoosterAddValidData_R                  , 2},
  {"LGBM_BoosterResetTrainingData_R"             , (DL_FUNC) &LGBM_BoosterResetTrainingData_R             , 2},
  {"LGBM_BoosterResetParameter_R"                , (DL_FUNC) &LGBM_BoosterResetParameter_R                , 2},
  {"LGBM_BoosterGetNumClasses_R"                 , (DL_FUNC) &LGBM_BoosterGetNumClasses_R                 , 2},
  {"LGBM_BoosterGetNumFeature_R"                 , (DL_FUNC) &LGBM_BoosterGetNumFeature_R                 , 1},
  {"LGBM_BoosterGetLoadedParam_R"                , (DL_FUNC) &LGBM_BoosterGetLoadedParam_R                , 1},
  {"LGBM_BoosterUpdateOneIter_R"                 , (DL_FUNC) &LGBM_BoosterUpdateOneIter_R                 , 1},
  {"LGBM_BoosterUpdateOneIterCustom_R"           , (DL_FUNC) &LGBM_BoosterUpdateOneIterCustom_R           , 4},
  {"LGBM_BoosterRollbackOneIter_R"               , (DL_FUNC) &LGBM_BoosterRollbackOneIter_R               , 1},
  {"LGBM_BoosterGetCurrentIteration_R"           , (DL_FUNC) &LGBM_BoosterGetCurrentIteration_R           , 2},
  {"LGBM_BoosterNumModelPerIteration_R"          , (DL_FUNC) &LGBM_BoosterNumModelPerIteration_R          , 2},
  {"LGBM_BoosterNumberOfTotalModel_R"            , (DL_FUNC) &LGBM_BoosterNumberOfTotalModel_R            , 2},
  {"LGBM_BoosterGetUpperBoundValue_R"            , (DL_FUNC) &LGBM_BoosterGetUpperBoundValue_R            , 2},
  {"LGBM_BoosterGetLowerBoundValue_R"            , (DL_FUNC) &LGBM_BoosterGetLowerBoundValue_R            , 2},
  {"LGBM_BoosterGetEvalNames_R"                  , (DL_FUNC) &LGBM_BoosterGetEvalNames_R                  , 1},
  {"LGBM_BoosterGetEval_R"                       , (DL_FUNC) &LGBM_BoosterGetEval_R                       , 3},
  // prediction
  {"LGBM_BoosterGetNumPredict_R"                 , (DL_FUNC) &LGBM_BoosterGetNumPredict_R                 , 3},
  {"LGBM_BoosterGetPredict_R"                    , (DL_FUNC) &LGBM_BoosterGetPredict_R                    , 3},
  {"LGBM_BoosterPredictForFile_R"                , (DL_FUNC) &LGBM_BoosterPredictForFile_R                , 10},
  {"LGBM_BoosterCalcNumPredict_R"                , (DL_FUNC) &LGBM_BoosterCalcNumPredict_R                , 8},
  {"LGBM_BoosterPredictForCSC_R"                 , (DL_FUNC) &LGBM_BoosterPredictForCSC_R                 , 14},
  {"LGBM_BoosterPredictForCSR_R"                 , (DL_FUNC) &LGBM_BoosterPredictForCSR_R                 , 12},
  {"LGBM_BoosterPredictForCSRSingleRow_R"        , (DL_FUNC) &LGBM_BoosterPredictForCSRSingleRow_R        , 11},
  {"LGBM_BoosterPredictForCSRSingleRowFastInit_R", (DL_FUNC) &LGBM_BoosterPredictForCSRSingleRowFastInit_R, 8},
  {"LGBM_BoosterPredictForCSRSingleRowFast_R"    , (DL_FUNC) &LGBM_BoosterPredictForCSRSingleRowFast_R    , 4},
  {"LGBM_BoosterPredictSparseOutput_R"           , (DL_FUNC) &LGBM_BoosterPredictSparseOutput_R           , 10},
  {"LGBM_BoosterPredictForMat_R"                 , (DL_FUNC) &LGBM_BoosterPredictForMat_R                 , 11},
  {"LGBM_BoosterPredictForMatSingleRow_R"        , (DL_FUNC) &LGBM_BoosterPredictForMatSingleRow_R        , 9},
  {"LGBM_BoosterPredictForMatSingleRowFastInit_R", (DL_FUNC) &LGBM_BoosterPredictForMatSingleRowFastInit_R, 8},
  {"LGBM_BoosterPredictForMatSingleRowFast_R"    , (DL_FUNC) &LGBM_BoosterPredictForMatSingleRowFast_R    , 3},
  // model serialization
  {"LGBM_BoosterSaveModel_R"                     , (DL_FUNC) &LGBM_BoosterSaveModel_R                     , 5},
  {"LGBM_BoosterSaveModelToString_R"             , (DL_FUNC) &LGBM_BoosterSaveModelToString_R             , 4},
  {"LGBM_BoosterDumpModel_R"                     , (DL_FUNC) &LGBM_BoosterDumpModel_R                     , 4},
  // miscellaneous
  {"LGBM_NullBoosterHandleError_R"               , (DL_FUNC) &LGBM_NullBoosterHandleError_R               , 0},
  {"LGBM_DumpParamAliases_R"                     , (DL_FUNC) &LGBM_DumpParamAliases_R                     , 0},
  {"LGBM_GetMaxThreads_R"                        , (DL_FUNC) &LGBM_GetMaxThreads_R                        , 1},
  {"LGBM_SetMaxThreads_R"                        , (DL_FUNC) &LGBM_SetMaxThreads_R                        , 1},
  // sentinel
  {nullptr, nullptr, 0}
};

1489
1490
LIGHTGBM_C_EXPORT void R_init_lightgbm(DllInfo *dll);

1491
1492
1493
// Library initialization hook: registers the .Call() routines listed in
// CallEntries, turns off dynamic symbol lookup for this DLL and, unless
// LGB_NO_ALTREP is defined, creates the three ALTREP classes used by this file
// and attaches their Length / Dataptr / Dataptr_or_null methods.
void R_init_lightgbm(DllInfo *dll) {
  // only .Call routines are registered; the other routine tables are left empty
  R_registerRoutines(dll, nullptr, CallEntries, nullptr, nullptr);
  R_useDynamicSymbols(dll, FALSE);

#ifndef LGB_NO_ALTREP
  // raw vector class
  lgb_altrepped_char_vec = R_make_altraw_class("lgb_altrepped_char_vec", "lightgbm", dll);
  R_set_altrep_Length_method(lgb_altrepped_char_vec, get_altrepped_raw_len);
  R_set_altvec_Dataptr_method(lgb_altrepped_char_vec, get_altrepped_raw_dataptr);
  R_set_altvec_Dataptr_or_null_method(lgb_altrepped_char_vec, get_altrepped_raw_dataptr_or_null);

  // integer vector class
  lgb_altrepped_int_arr = R_make_altinteger_class("lgb_altrepped_int_arr", "lightgbm", dll);
  R_set_altrep_Length_method(lgb_altrepped_int_arr, get_altrepped_vec_len);
  R_set_altvec_Dataptr_method(lgb_altrepped_int_arr, get_altrepped_vec_dataptr);
  R_set_altvec_Dataptr_or_null_method(lgb_altrepped_int_arr, get_altrepped_vec_dataptr_or_null);

  // double vector class
  lgb_altrepped_dbl_arr = R_make_altreal_class("lgb_altrepped_dbl_arr", "lightgbm", dll);
  R_set_altrep_Length_method(lgb_altrepped_dbl_arr, get_altrepped_vec_len);
  R_set_altvec_Dataptr_method(lgb_altrepped_dbl_arr, get_altrepped_vec_dataptr);
  R_set_altvec_Dataptr_or_null_method(lgb_altrepped_dbl_arr, get_altrepped_vec_dataptr_or_null);
#endif
}