lightgbm_R.cpp 52.9 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2017 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
5
6

#include "lightgbm_R.h"
Guolin Ke's avatar
Guolin Ke committed
7

8
9
10
11
12
13
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>
#include <LightGBM/utils/text_reader.h>

#include <R_ext/Rdynload.h>
14
#include <R_ext/Altrep.h>
15

16
#ifndef R_NO_REMAP
17
#define R_NO_REMAP
18
19
20
#endif

#ifndef R_USE_C99_IN_CXX
21
#define R_USE_C99_IN_CXX
22
23
#endif

24
25
#include <R_ext/Error.h>

26
27
#include <string>
#include <cstdio>
28
#include <cstdlib>
29
30
31
32
#include <cstring>
#include <memory>
#include <utility>
#include <vector>
33
#include <algorithm>
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#include <type_traits>

R_altrep_class_t lgb_altrepped_char_vec;
R_altrep_class_t lgb_altrepped_int_arr;
R_altrep_class_t lgb_altrepped_dbl_arr;

// Finalizer for an R external pointer whose address is a C++ array of T.
// Releases the array with delete[] and then clears the external pointer.
template <class T>
void delete_cpp_array(SEXP R_ptr) {
  void *raw_addr = R_ExternalPtrAddr(R_ptr);
  delete[] static_cast<T*>(raw_addr);
  R_ClearExternalPtr(R_ptr);
}

// Finalizer for an R external pointer whose address is a heap-allocated
// std::vector<char>. Deletes the vector and then clears the external pointer.
void delete_cpp_char_vec(SEXP R_ptr) {
  void *raw_addr = R_ExternalPtrAddr(R_ptr);
  delete static_cast<std::vector<char>*>(raw_addr);
  R_ClearExternalPtr(R_ptr);
}

// Note: MSVC has issues with Altrep classes, so they are disabled for it.
// See: https://github.com/microsoft/LightGBM/pull/6213#issuecomment-2111025768
#ifdef _MSC_VER
#  define LGB_NO_ALTREP
#endif

#ifndef LGB_NO_ALTREP
// Builds an ALTREP object of class 'lgb_altrepped_char_vec' backed by a C++ vector.
// 'void_ptr' points to a std::unique_ptr<std::vector<char>>; ownership of the
// vector is moved to an R external pointer (stored as the ALTREP 'data1' slot)
// whose finalizer is 'delete_cpp_char_vec'.
SEXP make_altrepped_raw_vec(void *void_ptr) {
  auto &owner = *static_cast<std::unique_ptr<std::vector<char>>*>(void_ptr);
  SEXP ext_ptr = Rf_protect(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  SEXP altrep_obj = Rf_protect(R_new_altrep(lgb_altrepped_char_vec, R_NilValue, R_NilValue));

  // The unique_ptr only gives up ownership after the finalizer is registered.
  R_SetExternalPtrAddr(ext_ptr, owner.get());
  R_RegisterCFinalizerEx(ext_ptr, delete_cpp_char_vec, TRUE);
  owner.release();

  R_set_altrep_data1(altrep_obj, ext_ptr);
  Rf_unprotect(2);
  return altrep_obj;
}
#else
// Non-ALTREP variant: copies the bytes of the C++ vector into a newly
// allocated R raw vector. 'void_ptr' points to a
// std::unique_ptr<std::vector<char>>, which keeps ownership of the vector.
SEXP make_r_raw_vec(void *void_ptr) {
  auto *holder = static_cast<std::unique_ptr<std::vector<char>>*>(void_ptr);
  const std::vector<char> &bytes = **holder;
  const R_xlen_t n_bytes = bytes.size();
  SEXP result = Rf_protect(Rf_allocVector(RAWSXP, n_bytes));
  std::copy(bytes.begin(), bytes.end(), reinterpret_cast<char*>(RAW(result)));
  Rf_unprotect(1);
  return result;
}
#define make_altrepped_raw_vec make_r_raw_vec
#endif

// Returns the std::vector<char> stored behind the external pointer that sits
// in the 'data1' slot of the ALTREP object 'R_raw'.
std::vector<char>* get_ptr_from_altrepped_raw(SEXP R_raw) {
  SEXP ext_ptr = R_altrep_data1(R_raw);
  void *addr = R_ExternalPtrAddr(ext_ptr);
  return static_cast<std::vector<char>*>(addr);
}

// Length of the ALTREP raw vector: the size of the backing std::vector<char>.
R_xlen_t get_altrepped_raw_len(SEXP R_raw) {
  const std::vector<char> *bytes = get_ptr_from_altrepped_raw(R_raw);
  return bytes->size();
}

// Read-only data pointer of the ALTREP raw vector: the buffer of the backing
// std::vector<char>.
const void* get_altrepped_raw_dataptr_or_null(SEXP R_raw) {
  const std::vector<char> *bytes = get_ptr_from_altrepped_raw(R_raw);
  return bytes->data();
}

// Data pointer of the ALTREP raw vector: the buffer of the backing
// std::vector<char>. The 'writeable' flag is not consulted.
void* get_altrepped_raw_dataptr(SEXP R_raw, Rboolean writeable) {
  std::vector<char> *bytes = get_ptr_from_altrepped_raw(R_raw);
  return bytes->data();
}

#ifndef LGB_NO_ALTREP
// Selects the ALTREP class matching element type T:
// 'lgb_altrepped_dbl_arr' for double, 'lgb_altrepped_int_arr' for anything else.
template <class T>
R_altrep_class_t get_altrep_class_for_type() {
  return std::is_same<T, double>::value ? lgb_altrepped_dbl_arr : lgb_altrepped_int_arr;
}
#else
// Selects the R vector type matching element type T:
// REALSXP for double, INTSXP for anything else.
template <class T>
SEXPTYPE get_sexptype_class_for_type() {
  return std::is_same<T, double>::value ? REALSXP : INTSXP;
}

// Returns the data pointer of R vector 'x' typed as T*:
// REAL(x) when T is double, INTEGER(x) otherwise.
template <class T>
T* get_r_vec_ptr(SEXP x) {
  void *data = std::is_same<T, double>::value
    ? static_cast<void*>(REAL(x))
    : static_cast<void*>(INTEGER(x));
  return static_cast<T*>(data);
}
#endif

// Pairs a raw C++ array with its element count. The vector-building helpers in
// this file receive a pointer to this struct through a 'void*' argument.
template <class T>
struct arr_and_len {
  T *arr;       // first element of the array
  int64_t len;  // number of elements in 'arr'
};

#ifndef LGB_NO_ALTREP
// Builds an ALTREP vector (class chosen by get_altrep_class_for_type<T>) on top
// of an existing C++ array. 'void_ptr' points to an arr_and_len<T>.
// The array address goes into an external pointer (ALTREP 'data1') whose
// finalizer is delete_cpp_array<T>; the length is stored as a length-1 REALSXP
// (ALTREP 'data2').
template <class T>
SEXP make_altrepped_vec_from_arr(void *void_ptr) {
  const arr_and_len<T> *input = static_cast<arr_and_len<T>*>(void_ptr);
  T *data = input->arr;
  const uint64_t n_elts = input->len;
  SEXP ext_ptr = Rf_protect(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  SEXP len_holder = Rf_protect(Rf_allocVector(REALSXP, 1));
  SEXP altrep_obj = Rf_protect(R_new_altrep(get_altrep_class_for_type<T>(), R_NilValue, R_NilValue));

  REAL(len_holder)[0] = static_cast<double>(n_elts);
  R_SetExternalPtrAddr(ext_ptr, data);
  R_RegisterCFinalizerEx(ext_ptr, delete_cpp_array<T>, TRUE);

  R_set_altrep_data1(altrep_obj, ext_ptr);
  R_set_altrep_data2(altrep_obj, len_holder);
  Rf_unprotect(3);
  return altrep_obj;
}
#else
// Non-ALTREP variant: copies a C++ array into a newly allocated R vector of the
// type given by get_sexptype_class_for_type<T>. 'void_ptr' points to an
// arr_and_len<T>; the source array is not freed here.
template <class T>
SEXP make_R_vec_from_arr(void *void_ptr) {
  const arr_and_len<T> *input = static_cast<arr_and_len<T>*>(void_ptr);
  T *data = input->arr;
  const uint64_t n_elts = input->len;
  SEXP result = Rf_protect(Rf_allocVector(get_sexptype_class_for_type<T>(), n_elts));
  std::copy(data, data + n_elts, get_r_vec_ptr<T>(result));
  Rf_unprotect(1);
  return result;
}
#define make_altrepped_vec_from_arr make_R_vec_from_arr
#endif

// Length of the ALTREP vector, read back from the numeric value kept in its
// 'data2' slot.
R_xlen_t get_altrepped_vec_len(SEXP R_vec) {
  const double stored_len = Rf_asReal(R_altrep_data2(R_vec));
  return static_cast<R_xlen_t>(stored_len);
}

// Read-only data pointer of the ALTREP vector: the address held by the
// external pointer in its 'data1' slot.
const void* get_altrepped_vec_dataptr_or_null(SEXP R_vec) {
  SEXP ext_ptr = R_altrep_data1(R_vec);
  return R_ExternalPtrAddr(ext_ptr);
}

// Data pointer of the ALTREP vector: the address held by the external pointer
// in its 'data1' slot. The 'writeable' flag is not consulted.
void* get_altrepped_vec_dataptr(SEXP R_vec, Rboolean writeable) {
  SEXP ext_ptr = R_altrep_data1(R_vec);
  return R_ExternalPtrAddr(ext_ptr);
}
178

Guolin Ke's avatar
Guolin Ke committed
179
180
#define COL_MAJOR (0)

181
182
183
184
185
186
#define MAX_LENGTH_ERR_MSG 1024
char R_errmsg_buffer[MAX_LENGTH_ERR_MSG];
struct LGBM_R_ErrorClass { SEXP cont_token; };
void LGBM_R_save_exception_msg(const std::exception &err);
void LGBM_R_save_exception_msg(const std::string &err);

Guolin Ke's avatar
Guolin Ke committed
187
188
189
// R_API_BEGIN() / R_API_END() bracket the body of an exported function so that
// no C++ exception escapes into R:
//   * LGBM_R_ErrorClass carries an R unwind token and resumes R's own unwinding;
//   * std::exception and std::string are turned into an R error carrying their
//     message;
//   * anything else becomes the R error "unknown exception".
// The handlers only store the message in 'R_errmsg_buffer'; Rf_error() (which
// does not return) is called once, after the handlers have been left, so the
// caught exception object is destroyed before R performs its long jump.
#define R_API_BEGIN() \
  try {
#define R_API_END() } \
  catch(LGBM_R_ErrorClass &cont) { R_ContinueUnwind(cont.cont_token); } \
  catch(const std::exception& ex) { LGBM_R_save_exception_msg(ex); } \
  catch(const std::string& ex) { LGBM_R_save_exception_msg(ex); } \
  catch(...) { std::snprintf(R_errmsg_buffer, MAX_LENGTH_ERR_MSG, "%s", "unknown exception"); } \
  Rf_error("%s", R_errmsg_buffer); \
  return R_NilValue; /* <- won't be reached */
Guolin Ke's avatar
Guolin Ke committed
196
197
198

// Evaluates a LightGBM C API call and throws std::runtime_error carrying
// LGBM_GetLastError() when the call returns a non-zero status.
// Wrapped in do { } while (0) so that 'CHECK_CALL(...);' is a single statement
// and stays safe inside an unbraced if / else.
#define CHECK_CALL(x) \
  do { \
    if ((x) != 0) { \
      throw std::runtime_error(LGBM_GetLastError()); \
    } \
  } while (0)

202
203
204
205
206
207
208
209
210
211
212
213
214
215
// These are helper functions to allow doing a stack unwind
// after an R allocation error, which would trigger a long jump.
// Stores the exception's what() text, followed by a newline, in the global
// 'R_errmsg_buffer' (at most MAX_LENGTH_ERR_MSG bytes including the terminator).
void LGBM_R_save_exception_msg(const std::exception &err) {
  const char *what_msg = err.what();
  std::snprintf(R_errmsg_buffer, MAX_LENGTH_ERR_MSG, "%s\n", what_msg);
}

// Stores the given message, followed by a newline, in the global
// 'R_errmsg_buffer' (at most MAX_LENGTH_ERR_MSG bytes including the terminator).
void LGBM_R_save_exception_msg(const std::string &err) {
  const char *msg = err.c_str();
  std::snprintf(R_errmsg_buffer, MAX_LENGTH_ERR_MSG, "%s\n", msg);
}

// Callback for R_UnwindProtect: allocates a character vector (STRSXP).
// 'len' points to an R_xlen_t holding the requested length.
SEXP wrapped_R_string(void *len) {
  const R_xlen_t n_elts = *static_cast<R_xlen_t*>(len);
  return Rf_allocVector(STRSXP, n_elts);
}

216
217
218
219
// Callback for R_UnwindProtect: allocates a raw vector (RAWSXP).
// 'len' points to an R_xlen_t holding the requested length.
SEXP wrapped_R_raw(void *len) {
  const R_xlen_t n_elts = *static_cast<R_xlen_t*>(len);
  return Rf_allocVector(RAWSXP, n_elts);
}

220
221
222
223
224
225
226
227
// Callback for R_UnwindProtect: allocates an integer vector (INTSXP).
// 'len' points to an R_xlen_t holding the requested length.
SEXP wrapped_R_int(void *len) {
  const R_xlen_t n_elts = *static_cast<R_xlen_t*>(len);
  return Rf_allocVector(INTSXP, n_elts);
}

// Callback for R_UnwindProtect: allocates a numeric vector (REALSXP).
// 'len' points to an R_xlen_t holding the requested length.
SEXP wrapped_R_real(void *len) {
  const R_xlen_t n_elts = *static_cast<R_xlen_t*>(len);
  return Rf_allocVector(REALSXP, n_elts);
}

228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
// Callback for R_UnwindProtect: creates a CHARSXP from the C string that
// 'txt' points to.
SEXP wrapped_Rf_mkChar(void *txt) {
  const char *c_str = static_cast<char*>(txt);
  return Rf_mkChar(c_str);
}

void throw_R_memerr(void *ptr_cont_token, Rboolean jump) {
  if (jump) {
    LGBM_R_ErrorClass err{*(reinterpret_cast<SEXP*>(ptr_cont_token))};
    throw err;
  }
}

// Allocates a character vector of length 'len' through R_UnwindProtect, with
// 'throw_R_memerr' as the cleanup callback and '*cont_token' as the
// continuation token.
SEXP safe_R_string(R_xlen_t len, SEXP *cont_token) {
  void *len_arg = static_cast<void*>(&len);
  return R_UnwindProtect(wrapped_R_string, len_arg, throw_R_memerr, cont_token, *cont_token);
}

243
244
245
246
// Allocates a raw vector of length 'len' through R_UnwindProtect, with
// 'throw_R_memerr' as the cleanup callback and '*cont_token' as the
// continuation token.
SEXP safe_R_raw(R_xlen_t len, SEXP *cont_token) {
  void *len_arg = static_cast<void*>(&len);
  return R_UnwindProtect(wrapped_R_raw, len_arg, throw_R_memerr, cont_token, *cont_token);
}

247
248
249
250
251
252
253
254
// Allocates an integer vector of length 'len' through R_UnwindProtect, with
// 'throw_R_memerr' as the cleanup callback and '*cont_token' as the
// continuation token.
SEXP safe_R_int(R_xlen_t len, SEXP *cont_token) {
  void *len_arg = static_cast<void*>(&len);
  return R_UnwindProtect(wrapped_R_int, len_arg, throw_R_memerr, cont_token, *cont_token);
}

// Allocates a numeric vector of length 'len' through R_UnwindProtect, with
// 'throw_R_memerr' as the cleanup callback and '*cont_token' as the
// continuation token.
SEXP safe_R_real(R_xlen_t len, SEXP *cont_token) {
  void *len_arg = static_cast<void*>(&len);
  return R_UnwindProtect(wrapped_R_real, len_arg, throw_R_memerr, cont_token, *cont_token);
}

255
256
257
258
// Creates a CHARSXP from 'txt' through R_UnwindProtect, with 'throw_R_memerr'
// as the cleanup callback and '*cont_token' as the continuation token.
SEXP safe_R_mkChar(char *txt, SEXP *cont_token) {
  void *txt_arg = static_cast<void*>(txt);
  return R_UnwindProtect(wrapped_Rf_mkChar, txt_arg, throw_R_memerr, cont_token, *cont_token);
}

259
260
using LightGBM::Common::Split;
using LightGBM::Log;
Guolin Ke's avatar
Guolin Ke committed
261

262
263
264
265
// Returns a length-1 R logical: TRUE when the external pointer 'handle' holds
// a null address, FALSE otherwise.
SEXP LGBM_HandleIsNull_R(SEXP handle) {
  const bool is_null = (R_ExternalPtrAddr(handle) == nullptr);
  return Rf_ScalarLogical(is_null);
}

266
267
268
269
// Finalizer registered on Dataset external pointers; delegates to
// LGBM_DatasetFree_R and discards its return value.
void _DatasetFinalizer(SEXP handle) {
  static_cast<void>(LGBM_DatasetFree_R(handle));
}

270
271
272
// Raises the R error reported when a Booster handle is no longer usable.
// Rf_error() does not return, so the trailing return statement is never reached.
SEXP LGBM_NullBoosterHandleError_R() {
  const char *message =
      "Attempting to use a Booster which no longer exists and/or cannot be restored. "
      "This can happen if the Booster's finalizer was called "
      "or if this Booster was saved through saveRDS() using 'serializable=FALSE'.";
  Rf_error("%s", message);
  return R_NilValue;
}

278
279
// Raises an R error (via LGBM_NullBoosterHandleError_R) unless 'handle' is a
// non-NULL R object whose external pointer address is non-null.
void _AssertBoosterHandleNotNull(SEXP handle) {
  const bool usable = !Rf_isNull(handle) && R_ExternalPtrAddr(handle) != nullptr;
  if (!usable) {
    LGBM_NullBoosterHandleError_R();
  }
}

// Raises an R error unless 'handle' is a non-NULL R object whose external
// pointer address is non-null.
void _AssertDatasetHandleNotNull(SEXP handle) {
  const bool usable = !Rf_isNull(handle) && R_ExternalPtrAddr(handle) != nullptr;
  if (usable) {
    return;
  }
  const char *message =
      "Attempting to use a Dataset which no longer exists. "
      "This can happen if the Dataset's finalizer was called or if this Dataset was saved with saveRDS(). "
      "To avoid this error in the future, use lgb.Dataset.save() or Dataset$save_binary() to save lightgbm Datasets.";
  Rf_error("%s", message);
}

293
294
// Creates a Dataset from a file through LGBM_DatasetCreateFromFile.
//   filename   - coerced to a string with Rf_asChar
//   parameters - coerced to a string with Rf_asChar
//   reference  - external pointer to a reference Dataset, or NULL for none
// Returns an external pointer holding the new Dataset handle, with
// _DatasetFinalizer registered on it.
SEXP LGBM_DatasetCreateFromFile_R(SEXP filename,
  SEXP parameters,
  SEXP reference) {
  R_API_BEGIN();
  SEXP out_ptr = Rf_protect(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  DatasetHandle ref_dataset = Rf_isNull(reference) ? nullptr : R_ExternalPtrAddr(reference);
  SEXP filename_chr = Rf_protect(Rf_asChar(filename));
  SEXP params_chr = Rf_protect(Rf_asChar(parameters));
  DatasetHandle new_dataset = nullptr;
  CHECK_CALL(LGBM_DatasetCreateFromFile(CHAR(filename_chr), CHAR(params_chr), ref_dataset, &new_dataset));
  R_SetExternalPtrAddr(out_ptr, new_dataset);
  R_RegisterCFinalizerEx(out_ptr, _DatasetFinalizer, TRUE);
  Rf_unprotect(3);
  return out_ptr;
  R_API_END();
}

313
314
315
// Creates a Dataset from CSC sparse-matrix components through
// LGBM_DatasetCreateFromCSC.
//   indptr, indices - integer vectors (passed as int32)
//   data            - numeric vector (passed as float64)
//   num_indptr, nelem, num_row - sizes, each read with Rf_asInteger
//   parameters      - coerced to a string with Rf_asChar
//   reference       - external pointer to a reference Dataset, or NULL for none
// Returns an external pointer holding the new Dataset handle, with
// _DatasetFinalizer registered on it.
SEXP LGBM_DatasetCreateFromCSC_R(SEXP indptr,
  SEXP indices,
  SEXP data,
  SEXP num_indptr,
  SEXP nelem,
  SEXP num_row,
  SEXP parameters,
  SEXP reference) {
  R_API_BEGIN();
  SEXP out_ptr = Rf_protect(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  const int* col_ptr = INTEGER(indptr);
  const int* row_idx = INTEGER(indices);
  const double* values = REAL(data);
  const int64_t n_col_ptr = static_cast<int64_t>(Rf_asInteger(num_indptr));
  const int64_t n_values = static_cast<int64_t>(Rf_asInteger(nelem));
  const int64_t n_rows = static_cast<int64_t>(Rf_asInteger(num_row));
  SEXP params_chr = Rf_protect(Rf_asChar(parameters));
  DatasetHandle ref_dataset = Rf_isNull(reference) ? nullptr : R_ExternalPtrAddr(reference);
  DatasetHandle new_dataset = nullptr;
  CHECK_CALL(LGBM_DatasetCreateFromCSC(col_ptr, C_API_DTYPE_INT32, row_idx,
    values, C_API_DTYPE_FLOAT64, n_col_ptr, n_values,
    n_rows, CHAR(params_chr), ref_dataset, &new_dataset));
  R_SetExternalPtrAddr(out_ptr, new_dataset);
  R_RegisterCFinalizerEx(out_ptr, _DatasetFinalizer, TRUE);
  Rf_unprotect(2);
  return out_ptr;
  R_API_END();
}

345
// Creates a Dataset from a dense numeric matrix through
// LGBM_DatasetCreateFromMat, passing the data as float64 with the COL_MAJOR
// layout flag.
//   data             - numeric vector holding the matrix values
//   num_row, num_col - dimensions, each read with Rf_asInteger
//   parameters       - coerced to a string with Rf_asChar
//   reference        - external pointer to a reference Dataset, or NULL for none
// Returns an external pointer holding the new Dataset handle, with
// _DatasetFinalizer registered on it.
SEXP LGBM_DatasetCreateFromMat_R(SEXP data,
  SEXP num_row,
  SEXP num_col,
  SEXP parameters,
  SEXP reference) {
  R_API_BEGIN();
  SEXP out_ptr = Rf_protect(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  const int32_t n_rows = static_cast<int32_t>(Rf_asInteger(num_row));
  const int32_t n_cols = static_cast<int32_t>(Rf_asInteger(num_col));
  double* values = REAL(data);
  SEXP params_chr = Rf_protect(Rf_asChar(parameters));
  DatasetHandle ref_dataset = Rf_isNull(reference) ? nullptr : R_ExternalPtrAddr(reference);
  DatasetHandle new_dataset = nullptr;
  CHECK_CALL(LGBM_DatasetCreateFromMat(values, C_API_DTYPE_FLOAT64, n_rows, n_cols, COL_MAJOR,
    CHAR(params_chr), ref_dataset, &new_dataset));
  R_SetExternalPtrAddr(out_ptr, new_dataset);
  R_RegisterCFinalizerEx(out_ptr, _DatasetFinalizer, TRUE);
  Rf_unprotect(2);
  return out_ptr;
  R_API_END();
}

370
// Creates a Dataset containing a subset of the rows of 'handle' through
// LGBM_DatasetGetSubset.
//   used_row_indices     - integer vector of one-based row indices
//   len_used_row_indices - number of indices, read with Rf_asInteger
//   parameters           - coerced to a string with Rf_asChar
// Returns an external pointer holding the new Dataset handle, with
// _DatasetFinalizer registered on it.
SEXP LGBM_DatasetGetSubset_R(SEXP handle,
  SEXP used_row_indices,
  SEXP len_used_row_indices,
  SEXP parameters) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  SEXP out_ptr = Rf_protect(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  const int32_t n_rows = static_cast<int32_t>(Rf_asInteger(len_used_row_indices));
  std::unique_ptr<int32_t[]> zero_based(new int32_t[n_rows]);
  const int *one_based = INTEGER(used_row_indices);
  // shift every index down by one: the values handed to the C API are zero-based
#ifndef _MSC_VER
#pragma omp simd
#endif
  for (int32_t row = 0; row < n_rows; ++row) {
    zero_based[row] = static_cast<int32_t>(one_based[row] - 1);
  }
  SEXP params_chr = Rf_protect(Rf_asChar(parameters));
  DatasetHandle subset = nullptr;
  CHECK_CALL(LGBM_DatasetGetSubset(R_ExternalPtrAddr(handle),
    zero_based.get(), n_rows, CHAR(params_chr),
    &subset));
  R_SetExternalPtrAddr(out_ptr, subset);
  R_RegisterCFinalizerEx(out_ptr, _DatasetFinalizer, TRUE);
  Rf_unprotect(2);
  return out_ptr;
  R_API_END();
}

399
// Sets the feature names of a Dataset through LGBM_DatasetSetFeatureNames.
// 'feature_names' is coerced to one string with Rf_asChar and split on tab
// characters; each piece becomes one feature name. Returns R NULL.
SEXP LGBM_DatasetSetFeatureNames_R(SEXP handle,
  SEXP feature_names) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  SEXP joined_chr = Rf_protect(Rf_asChar(feature_names));
  auto pieces = Split(CHAR(joined_chr), '\t');
  const int n_names = static_cast<int>(pieces.size());
  // array of C-string views into 'pieces', which stays alive until after the call
  std::unique_ptr<const char*[]> c_names(new const char*[n_names]);
  for (int idx = 0; idx < n_names; ++idx) {
    c_names[idx] = pieces[idx].c_str();
  }
  CHECK_CALL(LGBM_DatasetSetFeatureNames(R_ExternalPtrAddr(handle),
    c_names.get(), n_names));
  Rf_unprotect(1);
  return R_NilValue;
  R_API_END();
}

416
// Returns the feature names of a Dataset as an R character vector.
// One 256-byte buffer per feature is offered to LGBM_DatasetGetFeatureNames
// first; if the API reports that a larger size is required, every buffer is
// grown to that size and the call is repeated.
// R allocations go through the safe_R_* helpers together with 'cont_token'.
SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) {
  SEXP cont_token = Rf_protect(R_MakeUnwindCont());
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  int len = 0;
  CHECK_CALL(LGBM_DatasetGetNumFeature(R_ExternalPtrAddr(handle), &len));

  const size_t reserved_string_size = 256;
  std::vector<std::vector<char>> buffers(len);
  std::vector<char*> buffer_ptrs(len);
  for (int i = 0; i < len; ++i) {
    buffers[i].resize(reserved_string_size);
    buffer_ptrs[i] = buffers[i].data();
  }

  int out_len;
  size_t required_string_size;
  CHECK_CALL(
    LGBM_DatasetGetFeatureNames(
      R_ExternalPtrAddr(handle),
      len, &out_len,
      reserved_string_size, &required_string_size,
      buffer_ptrs.data()));

  // second attempt with bigger buffers when the first size was not enough
  if (required_string_size > reserved_string_size) {
    for (int i = 0; i < len; ++i) {
      buffers[i].resize(required_string_size);
      buffer_ptrs[i] = buffers[i].data();
    }
    CHECK_CALL(
      LGBM_DatasetGetFeatureNames(
        R_ExternalPtrAddr(handle),
        len,
        &out_len,
        required_string_size,
        &required_string_size,
        buffer_ptrs.data()));
  }
  CHECK_EQ(len, out_len);

  SEXP feature_names = Rf_protect(safe_R_string(static_cast<R_xlen_t>(len), &cont_token));
  for (int i = 0; i < len; ++i) {
    SET_STRING_ELT(feature_names, i, safe_R_mkChar(buffer_ptrs[i], &cont_token));
  }
  // releases 'cont_token' and 'feature_names'
  Rf_unprotect(2);
  return feature_names;
  R_API_END();
}

464
// Saves a Dataset through LGBM_DatasetSaveBinary. 'filename' is coerced to a
// string with Rf_asChar. Returns R NULL.
SEXP LGBM_DatasetSaveBinary_R(SEXP handle,
  SEXP filename) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  SEXP filename_chr = Rf_protect(Rf_asChar(filename));
  void *dataset = R_ExternalPtrAddr(handle);
  CHECK_CALL(LGBM_DatasetSaveBinary(dataset, CHAR(filename_chr)));
  Rf_unprotect(1);
  return R_NilValue;
  R_API_END();
}

476
// Frees the Dataset behind 'handle' through LGBM_DatasetFree and clears the
// external pointer. Does nothing when 'handle' is R NULL or already holds a
// null address. Returns R NULL.
SEXP LGBM_DatasetFree_R(SEXP handle) {
  R_API_BEGIN();
  void *dataset = Rf_isNull(handle) ? nullptr : R_ExternalPtrAddr(handle);
  if (dataset != nullptr) {
    CHECK_CALL(LGBM_DatasetFree(dataset));
    R_ClearExternalPtr(handle);
  }
  return R_NilValue;
  R_API_END();
}

486
// Sets a named field of a Dataset through LGBM_DatasetSetField.
//   field_name  - coerced to a string with Rf_asChar
//   field_data  - integer vector for "group"/"query", numeric vector otherwise
//   num_element - number of elements, read with Rf_asInteger
// "group"/"query" are passed as int32 and "init_score" as float64; every other
// field is first copied into a temporary float array and passed as float32.
// Returns R NULL.
SEXP LGBM_DatasetSetField_R(SEXP handle,
  SEXP field_name,
  SEXP field_data,
  SEXP num_element) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  const int n_elts = Rf_asInteger(num_element);
  SEXP name_chr = Rf_protect(Rf_asChar(field_name));
  const char* name = CHAR(name_chr);
  const bool is_group_field = !strcmp("group", name) || !strcmp("query", name);
  if (is_group_field) {
    CHECK_CALL(LGBM_DatasetSetField(R_ExternalPtrAddr(handle), name, INTEGER(field_data), n_elts, C_API_DTYPE_INT32));
  } else if (!strcmp("init_score", name)) {
    CHECK_CALL(LGBM_DatasetSetField(R_ExternalPtrAddr(handle), name, REAL(field_data), n_elts, C_API_DTYPE_FLOAT64));
  } else {
    // narrow the R doubles to single precision before handing them over
    std::unique_ptr<float[]> as_float(new float[n_elts]);
    std::copy(REAL(field_data), REAL(field_data) + n_elts, as_float.get());
    CHECK_CALL(LGBM_DatasetSetField(R_ExternalPtrAddr(handle), name, as_float.get(), n_elts, C_API_DTYPE_FLOAT32));
  }
  Rf_unprotect(1);
  return R_NilValue;
  R_API_END();
}

508
// Copies a named field of a Dataset into the pre-allocated R vector 'field_data'.
//   "group"/"query" - the result is read as int32 boundaries and the differences
//                     between consecutive boundaries are written to an integer vector
//   "init_score"    - the result is read as double and copied to a numeric vector
//   anything else   - the result is read as float and copied to a numeric vector
// Returns R NULL.
SEXP LGBM_DatasetGetField_R(SEXP handle,
  SEXP field_name,
  SEXP field_data) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  SEXP name_chr = Rf_protect(Rf_asChar(field_name));
  const char* name = CHAR(name_chr);
  int out_len = 0;
  int out_type = 0;
  const void* res;
  CHECK_CALL(LGBM_DatasetGetField(R_ExternalPtrAddr(handle), name, &out_len, &res, &out_type));

  const bool is_group_field = !strcmp("group", name) || !strcmp("query", name);
  if (is_group_field) {
    const int32_t *boundaries = reinterpret_cast<const int32_t*>(res);
    int *sizes = INTEGER(field_data);
    // out_len boundaries describe out_len - 1 sizes
#ifndef _MSC_VER
#pragma omp simd
#endif
    for (int i = 0; i < out_len - 1; ++i) {
      sizes[i] = boundaries[i + 1] - boundaries[i];
    }
  } else if (!strcmp("init_score", name)) {
    const double *scores = reinterpret_cast<const double*>(res);
    std::copy(scores, scores + out_len, REAL(field_data));
  } else {
    const float *values = reinterpret_cast<const float*>(res);
    std::copy(values, values + out_len, REAL(field_data));
  }
  Rf_unprotect(1);
  return R_NilValue;
  R_API_END();
}

540
// Writes the number of elements of a named Dataset field into the first
// element of the integer vector 'out'. For "group"/"query" the length reported
// by LGBM_DatasetGetField is reduced by one. Returns R NULL.
SEXP LGBM_DatasetGetFieldSize_R(SEXP handle,
  SEXP field_name,
  SEXP out) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  SEXP name_chr = Rf_protect(Rf_asChar(field_name));
  const char* name = CHAR(name_chr);
  int n_elts = 0;
  int elt_type = 0;
  const void* field_ptr;
  CHECK_CALL(LGBM_DatasetGetField(R_ExternalPtrAddr(handle), name, &n_elts, &field_ptr, &elt_type));
  const bool is_group_field = !strcmp("group", name) || !strcmp("query", name);
  if (is_group_field) {
    --n_elts;
  }
  INTEGER(out)[0] = n_elts;
  Rf_unprotect(1);
  return R_NilValue;
  R_API_END();
}

559
560
// Forwards two parameter strings (each coerced with Rf_asChar) to
// LGBM_DatasetUpdateParamChecking; a non-zero status becomes an R error.
// Returns R NULL.
SEXP LGBM_DatasetUpdateParamChecking_R(SEXP old_params,
  SEXP new_params) {
  R_API_BEGIN();
  SEXP old_chr = Rf_protect(Rf_asChar(old_params));
  SEXP new_chr = Rf_protect(Rf_asChar(new_params));
  CHECK_CALL(LGBM_DatasetUpdateParamChecking(CHAR(old_chr), CHAR(new_chr)));
  Rf_unprotect(2);
  return R_NilValue;
  R_API_END();
}

570
// Writes the value reported by LGBM_DatasetGetNumData into the first element
// of the integer vector 'out'. Returns R NULL.
SEXP LGBM_DatasetGetNumData_R(SEXP handle, SEXP out) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  void *dataset = R_ExternalPtrAddr(handle);
  int n_rows;
  CHECK_CALL(LGBM_DatasetGetNumData(dataset, &n_rows));
  *INTEGER(out) = n_rows;
  return R_NilValue;
  R_API_END();
}

580
// Writes the value reported by LGBM_DatasetGetNumFeature into the first
// element of the integer vector 'out'. Returns R NULL.
SEXP LGBM_DatasetGetNumFeature_R(SEXP handle,
  SEXP out) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  void *dataset = R_ExternalPtrAddr(handle);
  int n_features;
  CHECK_CALL(LGBM_DatasetGetNumFeature(dataset, &n_features));
  *INTEGER(out) = n_features;
  return R_NilValue;
  R_API_END();
}

591
592
593
594
595
596
597
598
599
600
601
// Writes the value reported by LGBM_DatasetGetFeatureNumBin for the feature
// index 'feature_idx' (read with Rf_asInteger) into the first element of the
// integer vector 'out'. Returns R NULL.
SEXP LGBM_DatasetGetFeatureNumBin_R(SEXP handle, SEXP feature_idx, SEXP out) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  const int which_feature = Rf_asInteger(feature_idx);
  void *dataset = R_ExternalPtrAddr(handle);
  int n_bins;
  CHECK_CALL(LGBM_DatasetGetFeatureNumBin(dataset, which_feature, &n_bins));
  *INTEGER(out) = n_bins;
  return R_NilValue;
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
602
603
// --- start Booster interfaces

604
605
606
607
// Finalizer registered on Booster external pointers; delegates to
// LGBM_BoosterFree_R and discards its return value.
void _BoosterFinalizer(SEXP handle) {
  static_cast<void>(LGBM_BoosterFree_R(handle));
}

608
// Frees the Booster behind 'handle' through LGBM_BoosterFree and clears the
// external pointer. Does nothing when 'handle' is R NULL or already holds a
// null address. Returns R NULL.
SEXP LGBM_BoosterFree_R(SEXP handle) {
  R_API_BEGIN();
  void *booster = Rf_isNull(handle) ? nullptr : R_ExternalPtrAddr(handle);
  if (booster != nullptr) {
    CHECK_CALL(LGBM_BoosterFree(booster));
    R_ClearExternalPtr(handle);
  }
  return R_NilValue;
  R_API_END();
}

618
619
// Creates a Booster for the Dataset 'train_data' through LGBM_BoosterCreate.
// 'parameters' is coerced to a string with Rf_asChar.
// Returns an external pointer holding the new Booster handle, with
// _BoosterFinalizer registered on it.
SEXP LGBM_BoosterCreate_R(SEXP train_data,
  SEXP parameters) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(train_data);
  SEXP out_ptr = Rf_protect(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  SEXP params_chr = Rf_protect(Rf_asChar(parameters));
  BoosterHandle new_booster = nullptr;
  CHECK_CALL(LGBM_BoosterCreate(R_ExternalPtrAddr(train_data), CHAR(params_chr), &new_booster));
  R_SetExternalPtrAddr(out_ptr, new_booster);
  R_RegisterCFinalizerEx(out_ptr, _BoosterFinalizer, TRUE);
  Rf_unprotect(2);
  return out_ptr;
  R_API_END();
}

633
// Creates a Booster from a model file through LGBM_BoosterCreateFromModelfile.
// 'filename' is coerced to a string with Rf_asChar; the iteration count
// reported by the C API is not returned to R.
// Returns an external pointer holding the new Booster handle, with
// _BoosterFinalizer registered on it.
SEXP LGBM_BoosterCreateFromModelfile_R(SEXP filename) {
  R_API_BEGIN();
  SEXP out_ptr = Rf_protect(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  int n_iterations = 0;
  SEXP filename_chr = Rf_protect(Rf_asChar(filename));
  BoosterHandle new_booster = nullptr;
  CHECK_CALL(LGBM_BoosterCreateFromModelfile(CHAR(filename_chr), &n_iterations, &new_booster));
  R_SetExternalPtrAddr(out_ptr, new_booster);
  R_RegisterCFinalizerEx(out_ptr, _BoosterFinalizer, TRUE);
  Rf_unprotect(2);
  return out_ptr;
  R_API_END();
}

647
// Creates a Booster from a serialized model held in an R object, through
// LGBM_BoosterLoadModelFromString.
// 'model_str' may be a raw vector, a CHARSXP, or a character vector (only its
// first element is used). Any other type, or an empty character vector, is
// reported as an R error instead of handing a null pointer to the C API.
// NOTE(review): the raw-vector branch forwards the bytes as a C string, so it
// presumably relies on the buffer being NUL-terminated -- confirm with callers.
// Returns an external pointer holding the new Booster handle, with
// _BoosterFinalizer registered on it.
SEXP LGBM_BoosterLoadModelFromString_R(SEXP model_str) {
  R_API_BEGIN();
  SEXP ret = Rf_protect(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  int n_protected = 1;
  int out_num_iterations = 0;
  const char* model_str_ptr = nullptr;
  switch (TYPEOF(model_str)) {
    case RAWSXP: {
      model_str_ptr = reinterpret_cast<const char*>(RAW(model_str));
      break;
    }
    case CHARSXP: {
      model_str_ptr = reinterpret_cast<const char*>(CHAR(model_str));
      break;
    }
    case STRSXP: {
      // reading element 0 of an empty vector would be out of bounds
      if (Rf_xlength(model_str) < 1) {
        throw std::runtime_error("model_str must not be an empty character vector");
      }
      SEXP first_elt = Rf_protect(STRING_ELT(model_str, 0));
      n_protected++;
      model_str_ptr = reinterpret_cast<const char*>(CHAR(first_elt));
      break;
    }
    default: {
      // without this, 'model_str_ptr' would stay null and reach the C API
      throw std::runtime_error("model_str must be a raw vector or a character vector");
    }
  }
  BoosterHandle handle = nullptr;
  CHECK_CALL(LGBM_BoosterLoadModelFromString(model_str_ptr, &out_num_iterations, &handle));
  R_SetExternalPtrAddr(ret, handle);
  R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE);
  Rf_unprotect(n_protected);
  return ret;
  R_API_END();
}

678
679
// Passes the Boosters behind 'handle' and 'other_handle' to LGBM_BoosterMerge.
// Both handles are validated first. Returns R NULL.
SEXP LGBM_BoosterMerge_R(SEXP handle,
  SEXP other_handle) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  _AssertBoosterHandleNotNull(other_handle);
  void *target = R_ExternalPtrAddr(handle);
  void *other = R_ExternalPtrAddr(other_handle);
  CHECK_CALL(LGBM_BoosterMerge(target, other));
  return R_NilValue;
  R_API_END();
}

688
689
// Passes the Booster behind 'handle' and the Dataset behind 'valid_data' to
// LGBM_BoosterAddValidData. Both handles are validated first. Returns R NULL.
SEXP LGBM_BoosterAddValidData_R(SEXP handle,
  SEXP valid_data) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  _AssertDatasetHandleNotNull(valid_data);
  void *booster = R_ExternalPtrAddr(handle);
  void *dataset = R_ExternalPtrAddr(valid_data);
  CHECK_CALL(LGBM_BoosterAddValidData(booster, dataset));
  return R_NilValue;
  R_API_END();
}

698
699
// Passes the Booster behind 'handle' and the Dataset behind 'train_data' to
// LGBM_BoosterResetTrainingData. Both handles are validated first.
// Returns R NULL.
SEXP LGBM_BoosterResetTrainingData_R(SEXP handle,
  SEXP train_data) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  _AssertDatasetHandleNotNull(train_data);
  void *booster = R_ExternalPtrAddr(handle);
  void *dataset = R_ExternalPtrAddr(train_data);
  CHECK_CALL(LGBM_BoosterResetTrainingData(booster, dataset));
  return R_NilValue;
  R_API_END();
}

708
// Passes a parameter string (coerced with Rf_asChar) for the Booster behind
// 'handle' to LGBM_BoosterResetParameter. Returns R NULL.
SEXP LGBM_BoosterResetParameter_R(SEXP handle,
  SEXP parameters) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  SEXP params_chr = Rf_protect(Rf_asChar(parameters));
  void *booster = R_ExternalPtrAddr(handle);
  CHECK_CALL(LGBM_BoosterResetParameter(booster, CHAR(params_chr)));
  Rf_unprotect(1);
  return R_NilValue;
  R_API_END();
}

719
// Writes the value reported by LGBM_BoosterGetNumClasses into the first
// element of the integer vector 'out'. Returns R NULL.
SEXP LGBM_BoosterGetNumClasses_R(SEXP handle,
  SEXP out) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  void *booster = R_ExternalPtrAddr(handle);
  int n_classes;
  CHECK_CALL(LGBM_BoosterGetNumClasses(booster, &n_classes));
  *INTEGER(out) = n_classes;
  return R_NilValue;
  R_API_END();
}

730
731
732
733
734
735
736
737
738
// Returns the value reported by LGBM_BoosterGetNumFeature as a length-1 R
// integer vector.
SEXP LGBM_BoosterGetNumFeature_R(SEXP handle) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  void *booster = R_ExternalPtrAddr(handle);
  int n_features = 0;
  CHECK_CALL(LGBM_BoosterGetNumFeature(booster, &n_features));
  return Rf_ScalarInteger(n_features);
  R_API_END();
}

739
// Calls LGBM_BoosterUpdateOneIter on the Booster behind 'handle'. The flag the
// C API writes back is not returned to R. Returns R NULL.
SEXP LGBM_BoosterUpdateOneIter_R(SEXP handle) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  void *booster = R_ExternalPtrAddr(handle);
  int finished_flag = 0;
  CHECK_CALL(LGBM_BoosterUpdateOneIter(booster, &finished_flag));
  return R_NilValue;
  R_API_END();
}

748
// Calls LGBM_BoosterUpdateOneIterCustom with caller-supplied gradients and
// hessians.
//   grad, hess - numeric vectors; their first 'len' values are copied into
//                temporary float arrays before the call
//   len        - number of values, read with Rf_asInteger
// The flag the C API writes back is not returned to R. Returns R NULL.
SEXP LGBM_BoosterUpdateOneIterCustom_R(SEXP handle,
  SEXP grad,
  SEXP hess,
  SEXP len) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int finished_flag = 0;
  const int n_values = Rf_asInteger(len);
  std::unique_ptr<float[]> grad_f32(new float[n_values]);
  std::unique_ptr<float[]> hess_f32(new float[n_values]);
  const double *grad_f64 = REAL(grad);
  std::copy(grad_f64, grad_f64 + n_values, grad_f32.get());
  const double *hess_f64 = REAL(hess);
  std::copy(hess_f64, hess_f64 + n_values, hess_f32.get());
  CHECK_CALL(LGBM_BoosterUpdateOneIterCustom(R_ExternalPtrAddr(handle), grad_f32.get(), hess_f32.get(),
    &finished_flag));
  return R_NilValue;
  R_API_END();
}

765
// Calls LGBM_BoosterRollbackOneIter on the Booster behind 'handle'.
// Returns R NULL.
SEXP LGBM_BoosterRollbackOneIter_R(SEXP handle) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  void *booster = R_ExternalPtrAddr(handle);
  CHECK_CALL(LGBM_BoosterRollbackOneIter(booster));
  return R_NilValue;
  R_API_END();
}

773
// Writes the value reported by LGBM_BoosterGetCurrentIteration into the first
// element of the integer vector 'out'. Returns R NULL.
SEXP LGBM_BoosterGetCurrentIteration_R(SEXP handle, SEXP out) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  void *booster = R_ExternalPtrAddr(handle);
  int current_iter;
  CHECK_CALL(LGBM_BoosterGetCurrentIteration(booster, &current_iter));
  *INTEGER(out) = current_iter;
  return R_NilValue;
  R_API_END();
}

783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
// Writes the value reported by LGBM_BoosterNumModelPerIteration into the first
// element of the integer vector 'out'. Returns R NULL.
SEXP LGBM_BoosterNumModelPerIteration_R(SEXP handle, SEXP out) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  void *booster = R_ExternalPtrAddr(handle);
  int per_iter;
  CHECK_CALL(LGBM_BoosterNumModelPerIteration(booster, &per_iter));
  *INTEGER(out) = per_iter;
  return R_NilValue;
  R_API_END();
}

// Writes the value reported by LGBM_BoosterNumberOfTotalModel into the first
// element of the integer vector 'out'. Returns R NULL.
SEXP LGBM_BoosterNumberOfTotalModel_R(SEXP handle, SEXP out) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  void *booster = R_ExternalPtrAddr(handle);
  int n_models;
  CHECK_CALL(LGBM_BoosterNumberOfTotalModel(booster, &n_models));
  *INTEGER(out) = n_models;
  return R_NilValue;
  R_API_END();
}

803
// Has LGBM_BoosterGetUpperBoundValue write its result directly into the
// numeric vector 'out_result'. Returns R NULL.
SEXP LGBM_BoosterGetUpperBoundValue_R(SEXP handle,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  double* destination = REAL(out_result);
  void *booster = R_ExternalPtrAddr(handle);
  CHECK_CALL(LGBM_BoosterGetUpperBoundValue(booster, destination));
  return R_NilValue;
  R_API_END();
}

813
// Has LGBM_BoosterGetLowerBoundValue write its result directly into the
// numeric vector 'out_result'. Returns R NULL.
SEXP LGBM_BoosterGetLowerBoundValue_R(SEXP handle,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  double* destination = REAL(out_result);
  void *booster = R_ExternalPtrAddr(handle);
  CHECK_CALL(LGBM_BoosterGetLowerBoundValue(booster, destination));
  return R_NilValue;
  R_API_END();
}

823
// Returns the evaluation names of a Booster as an R character vector.
// The number of names comes from LGBM_BoosterGetEvalCounts. One 128-byte buffer
// per name is offered to LGBM_BoosterGetEvalNames first; if the API reports
// that a larger size is required, every buffer is grown to that size and the
// call is repeated.
// R allocations go through the safe_R_* helpers together with 'cont_token'.
SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) {
  SEXP cont_token = Rf_protect(R_MakeUnwindCont());
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int len;
  CHECK_CALL(LGBM_BoosterGetEvalCounts(R_ExternalPtrAddr(handle), &len));

  const size_t reserved_string_size = 128;
  std::vector<std::vector<char>> buffers(len);
  std::vector<char*> buffer_ptrs(len);
  for (int i = 0; i < len; ++i) {
    buffers[i].resize(reserved_string_size);
    buffer_ptrs[i] = buffers[i].data();
  }

  int out_len;
  size_t required_string_size;
  CHECK_CALL(
    LGBM_BoosterGetEvalNames(
      R_ExternalPtrAddr(handle),
      len, &out_len,
      reserved_string_size, &required_string_size,
      buffer_ptrs.data()));

  // second attempt with bigger buffers when the first size was not enough
  if (required_string_size > reserved_string_size) {
    for (int i = 0; i < len; ++i) {
      buffers[i].resize(required_string_size);
      buffer_ptrs[i] = buffers[i].data();
    }
    CHECK_CALL(
      LGBM_BoosterGetEvalNames(
        R_ExternalPtrAddr(handle),
        len,
        &out_len,
        required_string_size,
        &required_string_size,
        buffer_ptrs.data()));
  }
  CHECK_EQ(out_len, len);

  SEXP eval_names = Rf_protect(safe_R_string(static_cast<R_xlen_t>(len), &cont_token));
  for (int i = 0; i < len; ++i) {
    SET_STRING_ELT(eval_names, i, safe_R_mkChar(buffer_ptrs[i], &cont_token));
  }
  // releases 'cont_token' and 'eval_names'
  Rf_unprotect(2);
  return eval_names;
  R_API_END();
}

872
// Has LGBM_BoosterGetEval write its results for the data index 'data_idx'
// (read with Rf_asInteger) into the numeric vector 'out_result'.
// The number of values written is checked against the count reported by
// LGBM_BoosterGetEvalCounts. Returns R NULL.
SEXP LGBM_BoosterGetEval_R(SEXP handle,
  SEXP data_idx,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int len;
  CHECK_CALL(LGBM_BoosterGetEvalCounts(R_ExternalPtrAddr(handle), &len));
  double* destination = REAL(out_result);
  int out_len;
  CHECK_CALL(LGBM_BoosterGetEval(R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &out_len, destination));
  CHECK_EQ(out_len, len);
  return R_NilValue;
  R_API_END();
}

887
// Writes the 64-bit count reported by LGBM_BoosterGetNumPredict for the data
// index 'data_idx' (read with Rf_asInteger), converted to int, into the first
// element of the integer vector 'out'. Returns R NULL.
SEXP LGBM_BoosterGetNumPredict_R(SEXP handle,
  SEXP data_idx,
  SEXP out) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int64_t n_predictions;
  CHECK_CALL(LGBM_BoosterGetNumPredict(R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &n_predictions));
  *INTEGER(out) = static_cast<int>(n_predictions);
  return R_NilValue;
  R_API_END();
}

899
// Has LGBM_BoosterGetPredict write its results for the data index 'data_idx'
// (read with Rf_asInteger) into the numeric vector 'out_result'.
// The count written back by the C API is not used. Returns R NULL.
SEXP LGBM_BoosterGetPredict_R(SEXP handle,
  SEXP data_idx,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  double* destination = REAL(out_result);
  int64_t n_written;
  CHECK_CALL(LGBM_BoosterGetPredict(R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &n_written, destination));
  return R_NilValue;
  R_API_END();
}

911
/*!
 * \brief Map three R flags onto one of the C_API_PREDICT_* constants.
 *
 * When several flags are set, the precedence is: contribution > leaf index > raw score.
 * With no flag set the result is C_API_PREDICT_NORMAL.
 *
 * \param is_rawscore Scalar coerced with Rf_asInteger(); non-zero selects C_API_PREDICT_RAW_SCORE
 * \param is_leafidx Scalar coerced with Rf_asInteger(); non-zero selects C_API_PREDICT_LEAF_INDEX
 * \param is_predcontrib Scalar coerced with Rf_asInteger(); non-zero selects C_API_PREDICT_CONTRIB
 * \return The selected C_API_PREDICT_* value
 */
int GetPredictType(SEXP is_rawscore, SEXP is_leafidx, SEXP is_predcontrib) {
  // Coerce every flag exactly once, in argument order, before deciding.
  const bool want_raw_score = Rf_asInteger(is_rawscore) != 0;
  const bool want_leaf_index = Rf_asInteger(is_leafidx) != 0;
  const bool want_contrib = Rf_asInteger(is_predcontrib) != 0;
  // Test the highest-priority flag first.
  if (want_contrib) {
    return C_API_PREDICT_CONTRIB;
  }
  if (want_leaf_index) {
    return C_API_PREDICT_LEAF_INDEX;
  }
  if (want_raw_score) {
    return C_API_PREDICT_RAW_SCORE;
  }
  return C_API_PREDICT_NORMAL;
}

925
/*!
 * \brief Forward a file-based prediction request to LGBM_BoosterPredictForFile().
 * \param handle External pointer to the Booster; asserted non-null before use
 * \param data_filename Coerced with Rf_asChar(); path of the input data file
 * \param data_has_header Scalar coerced with Rf_asInteger() and forwarded unchanged
 * \param is_rawscore Flag consumed by GetPredictType()
 * \param is_leafidx Flag consumed by GetPredictType()
 * \param is_predcontrib Flag consumed by GetPredictType()
 * \param start_iteration Scalar coerced with Rf_asInteger() and forwarded unchanged
 * \param num_iteration Scalar coerced with Rf_asInteger() and forwarded unchanged
 * \param parameter Coerced with Rf_asChar(); parameter string forwarded to the C API
 * \param result_filename Coerced with Rf_asChar(); path forwarded as the result file
 * \return R_NilValue
 */
SEXP LGBM_BoosterPredictForFile_R(SEXP handle,
  SEXP data_filename,
  SEXP data_has_header,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter,
  SEXP result_filename) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  // Three CHARSXPs are protected here (in this order) and released together below.
  SEXP input_path_chr = Rf_protect(Rf_asChar(data_filename));
  const char* input_path = CHAR(input_path_chr);
  SEXP parameter_chr = Rf_protect(Rf_asChar(parameter));
  const char* parameter_str = CHAR(parameter_chr);
  SEXP output_path_chr = Rf_protect(Rf_asChar(result_filename));
  const char* output_path = CHAR(output_path_chr);
  const int predict_kind = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  const int has_header = Rf_asInteger(data_has_header);
  const int first_iter = Rf_asInteger(start_iteration);
  const int iter_count = Rf_asInteger(num_iteration);
  CHECK_CALL(LGBM_BoosterPredictForFile(R_ExternalPtrAddr(handle), input_path,
    has_header, predict_kind, first_iter, iter_count, parameter_str,
    output_path));
  Rf_unprotect(3);
  return R_NilValue;
  R_API_END();
}

949
/*!
 * \brief Ask LGBM_BoosterCalcNumPredict() for the prediction count of a given row count and prediction type.
 * \param handle External pointer to the Booster; asserted non-null before use
 * \param num_row Scalar coerced with Rf_asInteger() and forwarded as the number of rows
 * \param is_rawscore Flag consumed by GetPredictType()
 * \param is_leafidx Flag consumed by GetPredictType()
 * \param is_predcontrib Flag consumed by GetPredictType()
 * \param start_iteration Scalar coerced with Rf_asInteger() and forwarded unchanged
 * \param num_iteration Scalar coerced with Rf_asInteger() and forwarded unchanged
 * \param out_len Integer vector; its first element receives the computed count
 * \return R_NilValue; the count is delivered through out_len
 */
SEXP LGBM_BoosterCalcNumPredict_R(SEXP handle,
  SEXP num_row,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP out_len) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  int64_t len = 0;
  CHECK_CALL(LGBM_BoosterCalcNumPredict(R_ExternalPtrAddr(handle), Rf_asInteger(num_row),
    pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), &len));
  // The C API reports a 64-bit count but 'out_len' is an R integer vector, so the value must
  // be narrowed. Verify that the narrowing round-trips instead of silently handing a truncated
  // count back to R. CHECK_EQ is the same check macro this file already uses to validate lengths.
  const int len_as_int = static_cast<int>(len);
  CHECK_EQ(static_cast<int64_t>(len_as_int), len);
  INTEGER(out_len)[0] = len_as_int;
  return R_NilValue;
  R_API_END();
}

968
/*!
 * \brief Forward a prediction request on CSC-style sparse input to LGBM_BoosterPredictForCSC().
 * \param handle External pointer to the Booster; asserted non-null before use
 * \param indptr Integer vector, passed as the index-pointer array (C_API_DTYPE_INT32)
 * \param indices Integer vector, passed as the indices array
 * \param data Double vector, passed as the values array (C_API_DTYPE_FLOAT64)
 * \param num_indptr Scalar coerced with Rf_asInteger(); length forwarded for indptr
 * \param nelem Scalar coerced with Rf_asInteger(); element count forwarded for data
 * \param num_row Scalar coerced with Rf_asInteger(); row count forwarded to the C API
 * \param is_rawscore Flag consumed by GetPredictType()
 * \param is_leafidx Flag consumed by GetPredictType()
 * \param is_predcontrib Flag consumed by GetPredictType()
 * \param start_iteration Scalar coerced with Rf_asInteger() and forwarded unchanged
 * \param num_iteration Scalar coerced with Rf_asInteger() and forwarded unchanged
 * \param parameter Coerced with Rf_asChar(); parameter string forwarded to the C API
 * \param out_result Double vector whose storage is passed as the destination buffer
 * \return R_NilValue; the predictions are delivered through out_result
 */
SEXP LGBM_BoosterPredictForCSC_R(SEXP handle,
  SEXP indptr,
  SEXP indices,
  SEXP data,
  SEXP num_indptr,
  SEXP nelem,
  SEXP num_row,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  const int predict_kind = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  // Raw views onto the three sparse-matrix component vectors.
  const int* col_ptr = INTEGER(indptr);
  const int32_t* row_idx = reinterpret_cast<const int32_t*>(INTEGER(indices));
  const double* values = REAL(data);
  // The sizes arrive as R integers and are widened to the 64-bit type used at the call below.
  const int64_t col_ptr_len = static_cast<int64_t>(Rf_asInteger(num_indptr));
  const int64_t value_count = static_cast<int64_t>(Rf_asInteger(nelem));
  const int64_t row_count = static_cast<int64_t>(Rf_asInteger(num_row));
  double* pred_buf = REAL(out_result);
  int64_t num_written;
  SEXP parameter_chr = Rf_protect(Rf_asChar(parameter));
  const char* parameter_str = CHAR(parameter_chr);
  const int first_iter = Rf_asInteger(start_iteration);
  const int iter_count = Rf_asInteger(num_iteration);
  CHECK_CALL(LGBM_BoosterPredictForCSC(R_ExternalPtrAddr(handle),
    col_ptr, C_API_DTYPE_INT32, row_idx,
    values, C_API_DTYPE_FLOAT64, col_ptr_len, value_count,
    row_count, predict_kind, first_iter, iter_count, parameter_str, &num_written, pred_buf));
  Rf_unprotect(1);
  return R_NilValue;
  R_API_END();
}

1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
/*!
 * \brief Forward a prediction request on CSR-style sparse input to LGBM_BoosterPredictForCSR().
 * \param handle External pointer to the Booster; asserted non-null before use
 * \param indptr Integer vector, passed as the index-pointer array; its length is taken with Rf_xlength()
 * \param indices Integer vector, passed as the indices array
 * \param data Double vector, passed as the values array; its length is taken with Rf_xlength()
 * \param ncols Scalar coerced with Rf_asInteger() and forwarded as the column count
 * \param is_rawscore Flag consumed by GetPredictType()
 * \param is_leafidx Flag consumed by GetPredictType()
 * \param is_predcontrib Flag consumed by GetPredictType()
 * \param start_iteration Scalar coerced with Rf_asInteger() and forwarded unchanged
 * \param num_iteration Scalar coerced with Rf_asInteger() and forwarded unchanged
 * \param parameter Coerced with Rf_asChar(); parameter string forwarded to the C API
 * \param out_result Double vector whose storage is passed as the destination buffer
 * \return R_NilValue; the predictions are delivered through out_result
 */
SEXP LGBM_BoosterPredictForCSR_R(SEXP handle,
  SEXP indptr,
  SEXP indices,
  SEXP data,
  SEXP ncols,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  const int predict_kind = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  SEXP parameter_chr = Rf_protect(Rf_asChar(parameter));
  const char* parameter_str = CHAR(parameter_chr);
  // The lengths come straight from the R vectors rather than from separate arguments.
  const R_xlen_t row_ptr_len = Rf_xlength(indptr);
  const R_xlen_t value_count = Rf_xlength(data);
  const int col_count = Rf_asInteger(ncols);
  const int first_iter = Rf_asInteger(start_iteration);
  const int iter_count = Rf_asInteger(num_iteration);
  int64_t num_written;
  CHECK_CALL(LGBM_BoosterPredictForCSR(R_ExternalPtrAddr(handle),
    INTEGER(indptr), C_API_DTYPE_INT32, INTEGER(indices),
    REAL(data), C_API_DTYPE_FLOAT64,
    row_ptr_len, value_count, col_count,
    predict_kind, first_iter, iter_count,
    parameter_str, &num_written, REAL(out_result)));
  Rf_unprotect(1);
  return R_NilValue;
  R_API_END();
}

/*!
 * \brief Forward a prediction request for one sparse row to LGBM_BoosterPredictForCSRSingleRow().
 *
 * No index-pointer vector is supplied by the caller; a two-element one, {0, length(data)},
 * is built locally to describe the single row.
 *
 * \param handle External pointer to the Booster; asserted non-null before use
 * \param indices Integer vector, passed as the indices array
 * \param data Double vector, passed as the values array; its length is the number of stored entries
 * \param ncols Scalar coerced with Rf_asInteger() and forwarded as the column count
 * \param is_rawscore Flag consumed by GetPredictType()
 * \param is_leafidx Flag consumed by GetPredictType()
 * \param is_predcontrib Flag consumed by GetPredictType()
 * \param start_iteration Scalar coerced with Rf_asInteger() and forwarded unchanged
 * \param num_iteration Scalar coerced with Rf_asInteger() and forwarded unchanged
 * \param parameter Coerced with Rf_asChar(); parameter string forwarded to the C API
 * \param out_result Double vector whose storage is passed as the destination buffer
 * \return R_NilValue; the predictions are delivered through out_result
 */
SEXP LGBM_BoosterPredictForCSRSingleRow_R(SEXP handle,
  SEXP indices,
  SEXP data,
  SEXP ncols,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  const int predict_kind = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  SEXP parameter_chr = Rf_protect(Rf_asChar(parameter));
  const char* parameter_str = CHAR(parameter_chr);
  const int entry_count = static_cast<int>(Rf_xlength(data));
  // One row: it starts at offset 0 and ends after the last stored entry.
  const int row_bounds[] = {0, entry_count};
  const int col_count = Rf_asInteger(ncols);
  const int first_iter = Rf_asInteger(start_iteration);
  const int iter_count = Rf_asInteger(num_iteration);
  int64_t num_written;
  CHECK_CALL(LGBM_BoosterPredictForCSRSingleRow(R_ExternalPtrAddr(handle),
    row_bounds, C_API_DTYPE_INT32, INTEGER(indices),
    REAL(data), C_API_DTYPE_FLOAT64,
    2, entry_count, col_count,
    predict_kind, first_iter, iter_count,
    parameter_str, &num_written, REAL(out_result)));
  Rf_unprotect(1);
  return R_NilValue;
  R_API_END();
}

/*!
 * \brief Finalizer registered on the external pointers returned by the *FastInit_R functions.
 * \param handle External pointer whose address is a FastConfigHandle
 */
void LGBM_FastConfigFree_wrapped(SEXP handle) {
  // The address stored in the external pointer IS the FastConfigHandle (the FastInit functions
  // store it with R_SetExternalPtrAddr(ret, out_fastConfig)), so cast to the handle type itself
  // rather than to a pointer-to-handle.
  LGBM_FastConfigFree(static_cast<FastConfigHandle>(R_ExternalPtrAddr(handle)));
  // Drop the now-dangling address, as the other finalizers in this file do.
  R_ClearExternalPtr(handle);
}

/*!
 * \brief Call LGBM_BoosterPredictForCSRSingleRowFastInit() and wrap the resulting handle for R.
 * \param handle External pointer to the Booster; asserted non-null before use
 * \param ncols Scalar coerced with Rf_asInteger() and forwarded as the column count
 * \param is_rawscore Flag consumed by GetPredictType()
 * \param is_leafidx Flag consumed by GetPredictType()
 * \param is_predcontrib Flag consumed by GetPredictType()
 * \param start_iteration Scalar coerced with Rf_asInteger() and forwarded unchanged
 * \param num_iteration Scalar coerced with Rf_asInteger() and forwarded unchanged
 * \param parameter Coerced with Rf_asChar(); parameter string forwarded to the C API
 * \return External pointer holding the FastConfigHandle, with LGBM_FastConfigFree_wrapped
 *         registered as its finalizer
 */
SEXP LGBM_BoosterPredictForCSRSingleRowFastInit_R(SEXP handle,
  SEXP ncols,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  const int predict_kind = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  // The R-side wrapper is created empty first; the handle is stored only after the C API succeeds.
  SEXP config_xptr = Rf_protect(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  SEXP parameter_chr = Rf_protect(Rf_asChar(parameter));
  const char* parameter_str = CHAR(parameter_chr);
  const int first_iter = Rf_asInteger(start_iteration);
  const int iter_count = Rf_asInteger(num_iteration);
  const int col_count = Rf_asInteger(ncols);
  FastConfigHandle fast_config;
  CHECK_CALL(LGBM_BoosterPredictForCSRSingleRowFastInit(R_ExternalPtrAddr(handle),
    predict_kind, first_iter, iter_count,
    C_API_DTYPE_FLOAT64, col_count,
    parameter_str, &fast_config));
  R_SetExternalPtrAddr(config_xptr, fast_config);
  // TRUE: request that the finalizer also run when R exits.
  R_RegisterCFinalizerEx(config_xptr, LGBM_FastConfigFree_wrapped, TRUE);
  Rf_unprotect(2);
  return config_xptr;
  R_API_END();
}

/*!
 * \brief Predict one sparse row through a previously initialised fast configuration.
 *
 * As in the non-fast single-row variant, a two-element index-pointer array
 * {0, length(data)} is built locally.
 *
 * \param handle_fastConfig External pointer whose address is forwarded as the fast-config handle;
 *        it is not checked for null here
 * \param indices Integer vector, passed as the indices array
 * \param data Double vector, passed as the values array; its length is the number of stored entries
 * \param out_result Double vector whose storage is passed as the destination buffer
 * \return R_NilValue; the predictions are delivered through out_result
 */
SEXP LGBM_BoosterPredictForCSRSingleRowFast_R(SEXP handle_fastConfig,
  SEXP indices,
  SEXP data,
  SEXP out_result) {
  R_API_BEGIN();
  const int entry_count = static_cast<int>(Rf_xlength(data));
  // One row: it starts at offset 0 and ends after the last stored entry.
  const int row_bounds[] = {0, entry_count};
  int64_t num_written;
  CHECK_CALL(LGBM_BoosterPredictForCSRSingleRowFast(R_ExternalPtrAddr(handle_fastConfig),
    row_bounds, C_API_DTYPE_INT32, INTEGER(indices),
    REAL(data),
    2, entry_count,
    &num_written, REAL(out_result)));
  return R_NilValue;
  R_API_END();
}

1106
/*!
 * \brief Forward a prediction request on a dense matrix to LGBM_BoosterPredictForMat().
 * \param handle External pointer to the Booster; asserted non-null before use
 * \param data Double vector holding the matrix values; passed with the COL_MAJOR layout flag
 * \param num_row Scalar coerced with Rf_asInteger(); row count
 * \param num_col Scalar coerced with Rf_asInteger(); column count
 * \param is_rawscore Flag consumed by GetPredictType()
 * \param is_leafidx Flag consumed by GetPredictType()
 * \param is_predcontrib Flag consumed by GetPredictType()
 * \param start_iteration Scalar coerced with Rf_asInteger() and forwarded unchanged
 * \param num_iteration Scalar coerced with Rf_asInteger() and forwarded unchanged
 * \param parameter Coerced with Rf_asChar(); parameter string forwarded to the C API
 * \param out_result Double vector whose storage is passed as the destination buffer
 * \return R_NilValue; the predictions are delivered through out_result
 */
SEXP LGBM_BoosterPredictForMat_R(SEXP handle,
  SEXP data,
  SEXP num_row,
  SEXP num_col,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  const int predict_kind = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  const int32_t row_count = static_cast<int32_t>(Rf_asInteger(num_row));
  const int32_t col_count = static_cast<int32_t>(Rf_asInteger(num_col));
  const double* matrix_values = REAL(data);
  double* pred_buf = REAL(out_result);
  SEXP parameter_chr = Rf_protect(Rf_asChar(parameter));
  const char* parameter_str = CHAR(parameter_chr);
  int64_t num_written;
  const int first_iter = Rf_asInteger(start_iteration);
  const int iter_count = Rf_asInteger(num_iteration);
  CHECK_CALL(LGBM_BoosterPredictForMat(R_ExternalPtrAddr(handle),
    matrix_values, C_API_DTYPE_FLOAT64, row_count, col_count, COL_MAJOR,
    predict_kind, first_iter, iter_count, parameter_str, &num_written, pred_buf));
  Rf_unprotect(1);
  return R_NilValue;
  R_API_END();
}

1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
/*!
 * \brief Plain bundle of the three raw pointers that make up a sparse prediction result.
 *
 * The struct only stores the pointers; it has no destructor and frees nothing by itself.
 */
struct SparseOutputPointers {
  void* indptr;      // index-pointer array (element type is not recorded here)
  int32_t* indices;  // indices array
  void* data;        // values array (element type is not recorded here)

  /*! \brief Store the three pointers as given. */
  SparseOutputPointers(void* indptr_in, int32_t* indices_in, void* data_in)
  : indptr(indptr_in), indices(indices_in), data(data_in) {}
};

/*!
 * \brief Release the arrays referenced by a SparseOutputPointers bundle, then the bundle itself.
 *
 * The arrays are handed to LGBM_BoosterFreePredictSparse() with the INT32 / FLOAT64 type tags.
 *
 * \param ptr Bundle to destroy; it is dereferenced unconditionally, so it must not be null
 */
void delete_SparseOutputPointers(SparseOutputPointers *ptr) {
  void* const indptr_arr = ptr->indptr;
  int32_t* const indices_arr = ptr->indices;
  void* const data_arr = ptr->data;
  LGBM_BoosterFreePredictSparse(indptr_arr, indices_arr, data_arr, C_API_DTYPE_INT32, C_API_DTYPE_FLOAT64);
  delete ptr;
}

/*!
 * \brief Run LGBM_BoosterPredictSparseOutput() (prediction type fixed to C_API_PREDICT_CONTRIB)
 *        and return the sparse result as a named R list.
 * \param handle External pointer to the Booster; asserted non-null before use
 * \param indptr Integer vector, passed as the index-pointer array; its length is taken with Rf_xlength()
 * \param indices Integer vector, passed as the indices array
 * \param data Double vector, passed as the values array; its length is taken with Rf_xlength()
 * \param is_csr Coerced with Rf_asLogical(); selects C_API_MATRIX_TYPE_CSR versus C_API_MATRIX_TYPE_CSC
 * \param nrows Scalar coerced with Rf_asInteger(); forwarded as the dimension when is_csr is false
 * \param ncols Scalar coerced with Rf_asInteger(); forwarded as the dimension when is_csr is true
 * \param start_iteration Scalar coerced with Rf_asInteger() and forwarded unchanged
 * \param num_iteration Scalar coerced with Rf_asInteger() and forwarded unchanged
 * \param parameter Coerced with Rf_asChar(); parameter string forwarded to the C API
 * \return List with elements "indptr", "indices" and "data"
 */
SEXP LGBM_BoosterPredictSparseOutput_R(SEXP handle,
  SEXP indptr,
  SEXP indices,
  SEXP data,
  SEXP is_csr,
  SEXP nrows,
  SEXP ncols,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter) {
  SEXP cont_token = Rf_protect(R_MakeUnwindCont());
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  // Rf_mkNamed() reads names up to the terminating empty string.
  const char* result_names[] = {"indptr", "indices", "data", ""};
  SEXP out = Rf_protect(Rf_mkNamed(VECSXP, result_names));
  SEXP parameter_chr = Rf_protect(Rf_asChar(parameter));
  const char* parameter_str = CHAR(parameter_chr);

  // Filled in by the C API: [0] sizes the indices/data arrays and [1] the indptr array (see below).
  int64_t result_lens[2];
  void *res_indptr;
  int32_t *res_indices;
  void *res_data;

  const bool input_is_csr = Rf_asLogical(is_csr);
  // Only the dimension matching the input layout is read.
  const int other_dim = input_is_csr ? Rf_asInteger(ncols) : Rf_asInteger(nrows);
  const int first_iter = Rf_asInteger(start_iteration);
  const int iter_count = Rf_asInteger(num_iteration);

  CHECK_CALL(LGBM_BoosterPredictSparseOutput(R_ExternalPtrAddr(handle),
    INTEGER(indptr), C_API_DTYPE_INT32, INTEGER(indices),
    REAL(data), C_API_DTYPE_FLOAT64,
    Rf_xlength(indptr), Rf_xlength(data),
    other_dim,
    C_API_PREDICT_CONTRIB, first_iter, iter_count,
    parameter_str,
    input_is_csr ? C_API_MATRIX_TYPE_CSR : C_API_MATRIX_TYPE_CSC,
    result_lens, &res_indptr, &res_indices, &res_data));

  // Guard that frees whatever arrays are still recorded in it if this scope is left early.
  // Each slot is cleared right after its array has been wrapped for R, so the guard no longer
  // frees it (presumably the R object owns the array from then on - confirm against
  // make_altrepped_vec_from_arr).
  std::unique_ptr<SparseOutputPointers, decltype(&delete_SparseOutputPointers)> cleanup_guard(
    new SparseOutputPointers(res_indptr, res_indices, res_data),
    &delete_SparseOutputPointers);

  arr_and_len<int> indptr_wrap{static_cast<int*>(res_indptr), result_lens[1]};
  SEXP r_indptr = R_UnwindProtect(make_altrepped_vec_from_arr<int>,
    static_cast<void*>(&indptr_wrap), throw_R_memerr, &cont_token, cont_token);
  SET_VECTOR_ELT(out, 0, r_indptr);
  cleanup_guard->indptr = nullptr;

  arr_and_len<int> indices_wrap{static_cast<int*>(res_indices), result_lens[0]};
  SEXP r_indices = R_UnwindProtect(make_altrepped_vec_from_arr<int>,
    static_cast<void*>(&indices_wrap), throw_R_memerr, &cont_token, cont_token);
  SET_VECTOR_ELT(out, 1, r_indices);
  cleanup_guard->indices = nullptr;

  arr_and_len<double> data_wrap{static_cast<double*>(res_data), result_lens[0]};
  SEXP r_data = R_UnwindProtect(make_altrepped_vec_from_arr<double>,
    static_cast<void*>(&data_wrap), throw_R_memerr, &cont_token, cont_token);
  SET_VECTOR_ELT(out, 2, r_data);
  cleanup_guard->data = nullptr;

  // Releases cont_token, out and the parameter CHARSXP.
  Rf_unprotect(3);
  return out;
  R_API_END();
}

1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
/*!
 * \brief Forward a prediction request for one dense row to LGBM_BoosterPredictForMatSingleRow().
 * \param handle External pointer to the Booster; asserted non-null before use
 * \param data Double vector holding the row; its length is forwarded as the column count
 * \param is_rawscore Flag consumed by GetPredictType()
 * \param is_leafidx Flag consumed by GetPredictType()
 * \param is_predcontrib Flag consumed by GetPredictType()
 * \param start_iteration Scalar coerced with Rf_asInteger() and forwarded unchanged
 * \param num_iteration Scalar coerced with Rf_asInteger() and forwarded unchanged
 * \param parameter Coerced with Rf_asChar(); parameter string forwarded to the C API
 * \param out_result Double vector whose storage is passed as the destination buffer
 * \return R_NilValue; the predictions are delivered through out_result
 */
SEXP LGBM_BoosterPredictForMatSingleRow_R(SEXP handle,
  SEXP data,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  const int predict_kind = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  SEXP parameter_chr = Rf_protect(Rf_asChar(parameter));
  const char* parameter_str = CHAR(parameter_chr);
  double* pred_buf = REAL(out_result);
  const R_xlen_t feature_count = Rf_xlength(data);
  const int first_iter = Rf_asInteger(start_iteration);
  const int iter_count = Rf_asInteger(num_iteration);
  int64_t num_written;
  // The literal 1 after the feature count is the layout flag passed to the C API.
  CHECK_CALL(LGBM_BoosterPredictForMatSingleRow(R_ExternalPtrAddr(handle),
    REAL(data), C_API_DTYPE_FLOAT64, feature_count, 1,
    predict_kind, first_iter, iter_count,
    parameter_str, &num_written, pred_buf));
  Rf_unprotect(1);
  return R_NilValue;
  R_API_END();
}

/*!
 * \brief Call LGBM_BoosterPredictForMatSingleRowFastInit() and wrap the resulting handle for R.
 * \param handle External pointer to the Booster; asserted non-null before use
 * \param ncols Scalar coerced with Rf_asInteger() and forwarded as the column count
 * \param is_rawscore Flag consumed by GetPredictType()
 * \param is_leafidx Flag consumed by GetPredictType()
 * \param is_predcontrib Flag consumed by GetPredictType()
 * \param start_iteration Scalar coerced with Rf_asInteger() and forwarded unchanged
 * \param num_iteration Scalar coerced with Rf_asInteger() and forwarded unchanged
 * \param parameter Coerced with Rf_asChar(); parameter string forwarded to the C API
 * \return External pointer holding the FastConfigHandle, with LGBM_FastConfigFree_wrapped
 *         registered as its finalizer
 */
SEXP LGBM_BoosterPredictForMatSingleRowFastInit_R(SEXP handle,
  SEXP ncols,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  const int predict_kind = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  // The R-side wrapper is created empty first; the handle is stored only after the C API succeeds.
  SEXP config_xptr = Rf_protect(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  SEXP parameter_chr = Rf_protect(Rf_asChar(parameter));
  const char* parameter_str = CHAR(parameter_chr);
  const int first_iter = Rf_asInteger(start_iteration);
  const int iter_count = Rf_asInteger(num_iteration);
  const int col_count = Rf_asInteger(ncols);
  FastConfigHandle fast_config;
  CHECK_CALL(LGBM_BoosterPredictForMatSingleRowFastInit(R_ExternalPtrAddr(handle),
    predict_kind, first_iter, iter_count,
    C_API_DTYPE_FLOAT64, col_count,
    parameter_str, &fast_config));
  R_SetExternalPtrAddr(config_xptr, fast_config);
  // TRUE: request that the finalizer also run when R exits.
  R_RegisterCFinalizerEx(config_xptr, LGBM_FastConfigFree_wrapped, TRUE);
  Rf_unprotect(2);
  return config_xptr;
  R_API_END();
}

/*!
 * \brief Predict one dense row through a previously initialised fast configuration.
 * \param handle_fastConfig External pointer whose address is forwarded as the fast-config handle;
 *        it is not checked for null here
 * \param data Double vector holding the row values
 * \param out_result Double vector whose storage is passed as the destination buffer
 * \return R_NilValue; the predictions are delivered through out_result
 */
SEXP LGBM_BoosterPredictForMatSingleRowFast_R(SEXP handle_fastConfig,
  SEXP data,
  SEXP out_result) {
  R_API_BEGIN();
  const double* row_values = REAL(data);
  double* pred_buf = REAL(out_result);
  int64_t num_written;
  CHECK_CALL(LGBM_BoosterPredictForMatSingleRowFast(R_ExternalPtrAddr(handle_fastConfig),
    row_values, &num_written, pred_buf));
  return R_NilValue;
  R_API_END();
}

1273
/*!
 * \brief Forward a save-to-file request to LGBM_BoosterSaveModel().
 * \param handle External pointer to the Booster; asserted non-null before use
 * \param num_iteration Scalar coerced with Rf_asInteger() and forwarded unchanged
 * \param feature_importance_type Scalar coerced with Rf_asInteger() and forwarded unchanged
 * \param filename Coerced with Rf_asChar(); destination path
 * \param start_iteration Scalar coerced with Rf_asInteger() and forwarded unchanged
 * \return R_NilValue
 * \note The C API takes start_iteration before num_iteration, i.e. not in this function's
 *       parameter order.
 */
SEXP LGBM_BoosterSaveModel_R(SEXP handle,
  SEXP num_iteration,
  SEXP feature_importance_type,
  SEXP filename,
  SEXP start_iteration) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  SEXP filename_chr = Rf_protect(Rf_asChar(filename));
  const char* output_path = CHAR(filename_chr);
  const int first_iter = Rf_asInteger(start_iteration);
  const int iter_count = Rf_asInteger(num_iteration);
  const int importance_kind = Rf_asInteger(feature_importance_type);
  CHECK_CALL(LGBM_BoosterSaveModel(R_ExternalPtrAddr(handle), first_iter, iter_count, importance_kind, output_path));
  Rf_unprotect(1);
  return R_NilValue;
  R_API_END();
}

1287
1288
1289
1290
1291
1292
1293
1294
// Note: for some reason, MSVC crashes when an error is thrown here
// if the buffer variable is defined as 'std::unique_ptr<std::vector<char>>',
// but not if it is defined as 'std::vector<char>'.
#ifndef _MSC_VER
/*!
 * \brief Serialise the Booster with LGBM_BoosterSaveModelToString() and return it as an R raw vector
 *        (variant compiled when _MSC_VER is not defined).
 *
 * The model is written into a heap-allocated std::vector<char>, which is then passed to
 * make_altrepped_raw_vec() to build the returned object.
 *
 * \param handle External pointer to the Booster; asserted non-null before use
 * \param num_iteration Scalar coerced with Rf_asInteger() and forwarded unchanged
 * \param feature_importance_type Scalar coerced with Rf_asInteger() and forwarded unchanged
 * \param start_iteration Scalar coerced with Rf_asInteger() and forwarded unchanged
 * \return Object produced by make_altrepped_raw_vec() from the filled buffer
 */
SEXP LGBM_BoosterSaveModelToString_R(SEXP handle,
  SEXP num_iteration,
  SEXP feature_importance_type,
  SEXP start_iteration) {
  SEXP cont_token = Rf_protect(R_MakeUnwindCont());
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int64_t out_len = 0;
  // First attempt uses a 1 MiB buffer; the C API reports the size it actually needs in out_len.
  const int64_t initial_capacity = 1024 * 1024;
  const int iter_count = Rf_asInteger(num_iteration);
  const int first_iter = Rf_asInteger(start_iteration);
  const int importance_kind = Rf_asInteger(feature_importance_type);
  // Must stay exactly std::unique_ptr<std::vector<char>>: its address is passed as void* to
  // make_altrepped_raw_vec(), which casts it back to that type.
  std::unique_ptr<std::vector<char>> model_buf(new std::vector<char>(initial_capacity));
  CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), first_iter, iter_count, importance_kind, initial_capacity, &out_len, model_buf->data()));
  // Shrink or grow the buffer to the reported size.
  model_buf->resize(out_len);
  const bool first_pass_too_small = out_len > initial_capacity;
  if (first_pass_too_small) {
    // The model did not fit the first time: repeat the call into the enlarged buffer.
    CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), first_iter, iter_count, importance_kind, out_len, &out_len, model_buf->data()));
  }
  SEXP out = R_UnwindProtect(make_altrepped_raw_vec, &model_buf, throw_R_memerr, &cont_token, cont_token);
  Rf_unprotect(1);
  return out;
  R_API_END();
}
#else
1315
/*!
 * \brief Serialise the Booster with LGBM_BoosterSaveModelToString() and return it as an R raw vector
 *        (variant compiled when _MSC_VER is defined).
 *
 * A stack-owned std::vector<char> takes the first attempt; the result always ends up in a raw
 * vector allocated with safe_R_raw().
 *
 * \param handle External pointer to the Booster; asserted non-null before use
 * \param num_iteration Scalar coerced with Rf_asInteger() and forwarded unchanged
 * \param feature_importance_type Scalar coerced with Rf_asInteger() and forwarded unchanged
 * \param start_iteration Scalar coerced with Rf_asInteger() and forwarded unchanged
 * \return Raw vector of the length reported by the C API, holding the model text
 */
SEXP LGBM_BoosterSaveModelToString_R(SEXP handle,
  SEXP num_iteration,
  SEXP feature_importance_type,
  SEXP start_iteration) {
  SEXP cont_token = Rf_protect(R_MakeUnwindCont());
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int64_t out_len = 0;
  // First attempt uses a 1 MiB scratch buffer; the C API reports the size it actually needs in out_len.
  const int64_t initial_capacity = 1024 * 1024;
  const int iter_count = Rf_asInteger(num_iteration);
  const int first_iter = Rf_asInteger(start_iteration);
  const int importance_kind = Rf_asInteger(feature_importance_type);
  std::vector<char> scratch_buf(initial_capacity);
  CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), first_iter, iter_count, importance_kind, initial_capacity, &out_len, scratch_buf.data()));
  SEXP model_str = Rf_protect(safe_R_raw(out_len, &cont_token));
  char* raw_dest = reinterpret_cast<char*>(RAW(model_str));
  if (out_len <= initial_capacity) {
    // Everything fit in the scratch buffer: copy the used prefix into the R object.
    std::copy_n(scratch_buf.data(), out_len, raw_dest);
  } else {
    // The model was larger than the scratch buffer: call again, writing directly into the R object.
    CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), first_iter, iter_count, importance_kind, out_len, &out_len, raw_dest));
  }
  // Releases cont_token and model_str.
  Rf_unprotect(2);
  return model_str;
  R_API_END();
}
1340
#endif
1341

1342
/*!
 * \brief Obtain the text produced by LGBM_BoosterDumpModel() and return it as a length-1 R character vector.
 * \param handle External pointer to the Booster; asserted non-null before use
 * \param num_iteration Scalar coerced with Rf_asInteger() and forwarded unchanged
 * \param feature_importance_type Scalar coerced with Rf_asInteger() and forwarded unchanged
 * \param start_iteration Scalar coerced with Rf_asInteger() and forwarded unchanged
 * \return Character vector of length 1 built from the buffer contents
 */
SEXP LGBM_BoosterDumpModel_R(SEXP handle,
  SEXP num_iteration,
  SEXP feature_importance_type,
  SEXP start_iteration) {
  SEXP cont_token = Rf_protect(R_MakeUnwindCont());
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int64_t out_len = 0;
  // First attempt uses a 1 MiB buffer; the C API reports the size it actually needs in out_len.
  const int64_t initial_capacity = 1024 * 1024;
  const int iter_count = Rf_asInteger(num_iteration);
  const int first_iter = Rf_asInteger(start_iteration);
  const int importance_kind = Rf_asInteger(feature_importance_type);
  std::vector<char> dump_buf(initial_capacity);
  CHECK_CALL(LGBM_BoosterDumpModel(R_ExternalPtrAddr(handle), first_iter, iter_count, importance_kind, initial_capacity, &out_len, dump_buf.data()));
  const bool first_pass_too_small = out_len > initial_capacity;
  if (first_pass_too_small) {
    // The dump did not fit: grow the buffer to the reported size and repeat the call.
    dump_buf.resize(out_len);
    CHECK_CALL(LGBM_BoosterDumpModel(R_ExternalPtrAddr(handle), first_iter, iter_count, importance_kind, out_len, &out_len, dump_buf.data()));
  }
  SEXP model_str = Rf_protect(safe_R_string(static_cast<R_xlen_t>(1), &cont_token));
  // The buffer is read as a C string by safe_R_mkChar().
  SET_STRING_ELT(model_str, 0, safe_R_mkChar(dump_buf.data(), &cont_token));
  // Releases cont_token and model_str.
  Rf_unprotect(2);
  return model_str;
  R_API_END();
}
1368

1369
/*!
 * \brief Obtain the text produced by LGBM_DumpParamAliases() and return it as a length-1 R character vector.
 * \return Character vector of length 1 built from the buffer contents
 */
SEXP LGBM_DumpParamAliases_R() {
  SEXP cont_token = Rf_protect(R_MakeUnwindCont());
  R_API_BEGIN();
  int64_t out_len = 0;
  // First attempt uses a 1 MiB buffer; the C API reports the size it actually needs in out_len.
  const int64_t initial_capacity = 1024 * 1024;
  std::vector<char> alias_buf(initial_capacity);
  CHECK_CALL(LGBM_DumpParamAliases(initial_capacity, &out_len, alias_buf.data()));
  const bool first_pass_too_small = out_len > initial_capacity;
  if (first_pass_too_small) {
    // The aliases string did not fit: grow the buffer to the reported size and repeat the call.
    alias_buf.resize(out_len);
    CHECK_CALL(LGBM_DumpParamAliases(out_len, &out_len, alias_buf.data()));
  }
  SEXP aliases_str = Rf_protect(safe_R_string(static_cast<R_xlen_t>(1), &cont_token));
  // The buffer is read as a C string by safe_R_mkChar().
  SET_STRING_ELT(aliases_str, 0, safe_R_mkChar(alias_buf.data(), &cont_token));
  // Releases cont_token and aliases_str.
  Rf_unprotect(2);
  return aliases_str;
  R_API_END();
}

1389
/*!
 * \brief Obtain the text produced by LGBM_BoosterGetLoadedParam() and return it as a length-1 R character vector.
 * \param handle External pointer to the Booster; asserted non-null before use
 * \return Character vector of length 1 built from the buffer contents
 */
SEXP LGBM_BoosterGetLoadedParam_R(SEXP handle) {
  SEXP cont_token = Rf_protect(R_MakeUnwindCont());
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int64_t out_len = 0;
  // First attempt uses a 1 MiB buffer; the C API reports the size it actually needs in out_len.
  const int64_t initial_capacity = 1024 * 1024;
  std::vector<char> param_buf(initial_capacity);
  CHECK_CALL(LGBM_BoosterGetLoadedParam(R_ExternalPtrAddr(handle), initial_capacity, &out_len, param_buf.data()));
  const bool first_pass_too_small = out_len > initial_capacity;
  if (first_pass_too_small) {
    // The parameter string did not fit: grow the buffer to the reported size and repeat the call.
    param_buf.resize(out_len);
    CHECK_CALL(LGBM_BoosterGetLoadedParam(R_ExternalPtrAddr(handle), out_len, &out_len, param_buf.data()));
  }
  SEXP params_str = Rf_protect(safe_R_string(static_cast<R_xlen_t>(1), &cont_token));
  // The buffer is read as a C string by safe_R_mkChar().
  SET_STRING_ELT(params_str, 0, safe_R_mkChar(param_buf.data(), &cont_token));
  // Releases cont_token and params_str.
  Rf_unprotect(2);
  return params_str;
  R_API_END();
}

1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
/*!
 * \brief Store the value reported by LGBM_GetMaxThreads() in an R integer vector.
 * \param out Integer vector; its first element receives the value
 * \return R_NilValue; the value is delivered through out
 */
SEXP LGBM_GetMaxThreads_R(SEXP out) {
  R_API_BEGIN();
  int max_threads;
  CHECK_CALL(LGBM_GetMaxThreads(&max_threads));
  int* out_slot = INTEGER(out);
  out_slot[0] = max_threads;
  return R_NilValue;
  R_API_END();
}

/*!
 * \brief Pass a thread count from R to LGBM_SetMaxThreads().
 * \param num_threads Scalar coerced with Rf_asInteger() and forwarded unchanged
 * \return R_NilValue
 */
SEXP LGBM_SetMaxThreads_R(SEXP num_threads) {
  R_API_BEGIN();
  const int requested_threads = Rf_asInteger(num_threads);
  CHECK_CALL(LGBM_SetMaxThreads(requested_threads));
  return R_NilValue;
  R_API_END();
}

1427
1428
// Registration table for the routines reachable through .Call().
// Each row is {name as seen from R, function address, number of SEXP arguments};
// the all-null row terminates the table.
static const R_CallMethodDef CallEntries[] = {
  {"LGBM_HandleIsNull_R", reinterpret_cast<DL_FUNC>(&LGBM_HandleIsNull_R), 1},
  {"LGBM_DatasetCreateFromFile_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetCreateFromFile_R), 3},
  {"LGBM_DatasetCreateFromCSC_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetCreateFromCSC_R), 8},
  {"LGBM_DatasetCreateFromMat_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetCreateFromMat_R), 5},
  {"LGBM_DatasetGetSubset_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetSubset_R), 4},
  {"LGBM_DatasetSetFeatureNames_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetSetFeatureNames_R), 2},
  {"LGBM_DatasetGetFeatureNames_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetFeatureNames_R), 1},
  {"LGBM_DatasetSaveBinary_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetSaveBinary_R), 2},
  {"LGBM_DatasetFree_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetFree_R), 1},
  {"LGBM_DatasetSetField_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetSetField_R), 4},
  {"LGBM_DatasetGetFieldSize_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetFieldSize_R), 3},
  {"LGBM_DatasetGetField_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetField_R), 3},
  {"LGBM_DatasetUpdateParamChecking_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetUpdateParamChecking_R), 2},
  {"LGBM_DatasetGetNumData_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetNumData_R), 2},
  {"LGBM_DatasetGetNumFeature_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetNumFeature_R), 2},
  {"LGBM_DatasetGetFeatureNumBin_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetFeatureNumBin_R), 3},
  {"LGBM_BoosterCreate_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterCreate_R), 2},
  {"LGBM_BoosterFree_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterFree_R), 1},
  {"LGBM_BoosterCreateFromModelfile_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterCreateFromModelfile_R), 1},
  {"LGBM_BoosterLoadModelFromString_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterLoadModelFromString_R), 1},
  {"LGBM_BoosterMerge_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterMerge_R), 2},
  {"LGBM_BoosterAddValidData_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterAddValidData_R), 2},
  {"LGBM_BoosterResetTrainingData_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterResetTrainingData_R), 2},
  {"LGBM_BoosterResetParameter_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterResetParameter_R), 2},
  {"LGBM_BoosterGetNumClasses_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetNumClasses_R), 2},
  {"LGBM_BoosterGetNumFeature_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetNumFeature_R), 1},
  {"LGBM_BoosterGetLoadedParam_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetLoadedParam_R), 1},
  {"LGBM_BoosterUpdateOneIter_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterUpdateOneIter_R), 1},
  {"LGBM_BoosterUpdateOneIterCustom_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterUpdateOneIterCustom_R), 4},
  {"LGBM_BoosterRollbackOneIter_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterRollbackOneIter_R), 1},
  {"LGBM_BoosterGetCurrentIteration_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetCurrentIteration_R), 2},
  {"LGBM_BoosterNumModelPerIteration_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterNumModelPerIteration_R), 2},
  {"LGBM_BoosterNumberOfTotalModel_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterNumberOfTotalModel_R), 2},
  {"LGBM_BoosterGetUpperBoundValue_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetUpperBoundValue_R), 2},
  {"LGBM_BoosterGetLowerBoundValue_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetLowerBoundValue_R), 2},
  {"LGBM_BoosterGetEvalNames_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetEvalNames_R), 1},
  {"LGBM_BoosterGetEval_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetEval_R), 3},
  {"LGBM_BoosterGetNumPredict_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetNumPredict_R), 3},
  {"LGBM_BoosterGetPredict_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetPredict_R), 3},
  {"LGBM_BoosterPredictForFile_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterPredictForFile_R), 10},
  {"LGBM_BoosterCalcNumPredict_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterCalcNumPredict_R), 8},
  {"LGBM_BoosterPredictForCSC_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterPredictForCSC_R), 14},
  {"LGBM_BoosterPredictForCSR_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterPredictForCSR_R), 12},
  {"LGBM_BoosterPredictForCSRSingleRow_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterPredictForCSRSingleRow_R), 11},
  {"LGBM_BoosterPredictForCSRSingleRowFastInit_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterPredictForCSRSingleRowFastInit_R), 8},
  {"LGBM_BoosterPredictForCSRSingleRowFast_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterPredictForCSRSingleRowFast_R), 4},
  {"LGBM_BoosterPredictSparseOutput_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterPredictSparseOutput_R), 10},
  {"LGBM_BoosterPredictForMat_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterPredictForMat_R), 11},
  {"LGBM_BoosterPredictForMatSingleRow_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterPredictForMatSingleRow_R), 9},
  {"LGBM_BoosterPredictForMatSingleRowFastInit_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterPredictForMatSingleRowFastInit_R), 8},
  {"LGBM_BoosterPredictForMatSingleRowFast_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterPredictForMatSingleRowFast_R), 3},
  {"LGBM_BoosterSaveModel_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterSaveModel_R), 5},
  {"LGBM_BoosterSaveModelToString_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterSaveModelToString_R), 4},
  {"LGBM_BoosterDumpModel_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterDumpModel_R), 4},
  {"LGBM_NullBoosterHandleError_R", reinterpret_cast<DL_FUNC>(&LGBM_NullBoosterHandleError_R), 0},
  {"LGBM_DumpParamAliases_R", reinterpret_cast<DL_FUNC>(&LGBM_DumpParamAliases_R), 0},
  {"LGBM_GetMaxThreads_R", reinterpret_cast<DL_FUNC>(&LGBM_GetMaxThreads_R), 1},
  {"LGBM_SetMaxThreads_R", reinterpret_cast<DL_FUNC>(&LGBM_SetMaxThreads_R), 1},
  {nullptr, nullptr, 0}
};

1490
1491
LIGHTGBM_C_EXPORT void R_init_lightgbm(DllInfo *dll);

1492
1493
1494
void R_init_lightgbm(DllInfo *dll) {
  R_registerRoutines(dll, NULL, CallEntries, NULL, NULL);
  R_useDynamicSymbols(dll, FALSE);
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511

#ifndef LGB_NO_ALTREP
  lgb_altrepped_char_vec = R_make_altraw_class("lgb_altrepped_char_vec", "lightgbm", dll);
  R_set_altrep_Length_method(lgb_altrepped_char_vec, get_altrepped_raw_len);
  R_set_altvec_Dataptr_method(lgb_altrepped_char_vec, get_altrepped_raw_dataptr);
  R_set_altvec_Dataptr_or_null_method(lgb_altrepped_char_vec, get_altrepped_raw_dataptr_or_null);

  lgb_altrepped_int_arr = R_make_altinteger_class("lgb_altrepped_int_arr", "lightgbm", dll);
  R_set_altrep_Length_method(lgb_altrepped_int_arr, get_altrepped_vec_len);
  R_set_altvec_Dataptr_method(lgb_altrepped_int_arr, get_altrepped_vec_dataptr);
  R_set_altvec_Dataptr_or_null_method(lgb_altrepped_int_arr, get_altrepped_vec_dataptr_or_null);

  lgb_altrepped_dbl_arr = R_make_altreal_class("lgb_altrepped_dbl_arr", "lightgbm", dll);
  R_set_altrep_Length_method(lgb_altrepped_dbl_arr, get_altrepped_vec_len);
  R_set_altvec_Dataptr_method(lgb_altrepped_dbl_arr, get_altrepped_vec_dataptr);
  R_set_altvec_Dataptr_or_null_method(lgb_altrepped_dbl_arr, get_altrepped_vec_dataptr_or_null);
#endif
1512
}