lightgbm_R.cpp 52.8 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2017 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
5
6

#include "lightgbm_R.h"
Guolin Ke's avatar
Guolin Ke committed
7

8
9
10
11
12
13
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>
#include <LightGBM/utils/text_reader.h>

#include <R_ext/Rdynload.h>
14
#include <R_ext/Altrep.h>
15

16
17
18
19
#define R_NO_REMAP
#define R_USE_C99_IN_CXX
#include <R_ext/Error.h>

20
21
#include <string>
#include <cstdio>
22
#include <cstdlib>
23
24
25
26
#include <cstring>
#include <memory>
#include <utility>
#include <vector>
27
#include <algorithm>
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
#include <type_traits>

// ALTREP class handles used by the helpers in this file:
//   - lgb_altrepped_char_vec: class of the objects created by make_altrepped_raw_vec()
//   - lgb_altrepped_int_arr / lgb_altrepped_dbl_arr: classes picked by
//     get_altrep_class_for_type<T>() (double -> dbl_arr, any other T -> int_arr)
// NOTE(review): these are only declared here; presumably they are initialized
// when the package is loaded - that code is not visible in this part of the file.
R_altrep_class_t lgb_altrepped_char_vec;
R_altrep_class_t lgb_altrepped_int_arr;
R_altrep_class_t lgb_altrepped_dbl_arr;

// Finalizer-style callback: runs delete[] on the T array whose address is stored
// in the external pointer 'R_ptr', then clears the stored address.
template <class T>
void delete_cpp_array(SEXP R_ptr) {
  void *stored_addr = R_ExternalPtrAddr(R_ptr);
  delete[] static_cast<T*>(stored_addr);
  R_ClearExternalPtr(R_ptr);
}

// Finalizer-style callback: deletes the std::vector<char> whose address is stored
// in the external pointer 'R_ptr', then clears the stored address.
void delete_cpp_char_vec(SEXP R_ptr) {
  void *stored_addr = R_ExternalPtrAddr(R_ptr);
  delete static_cast<std::vector<char>*>(stored_addr);
  R_ClearExternalPtr(R_ptr);
}

// Note: MSVC has issues with Altrep classes, so they are disabled for it.
// See: https://github.com/microsoft/LightGBM/pull/6213#issuecomment-2111025768
#ifdef _MSC_VER
#  define LGB_NO_ALTREP
#endif

#ifndef LGB_NO_ALTREP
// Creates an ALTREP object of class 'lgb_altrepped_char_vec' whose data1 slot is an
// external pointer to a C++ std::vector<char>.
//
// 'void_ptr' must point to a std::unique_ptr<std::vector<char>>. The unique_ptr gives
// up ownership here: the vector is deleted by the finalizer 'delete_cpp_char_vec'
// registered on the external pointer.
SEXP make_altrepped_raw_vec(void *void_ptr) {
  auto *owner = static_cast<std::unique_ptr<std::vector<char>>*>(void_ptr);
  SEXP ext_ptr = Rf_protect(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  SEXP altrep_obj = Rf_protect(R_new_altrep(lgb_altrepped_char_vec, R_NilValue, R_NilValue));

  // the finalizer is registered before the unique_ptr lets go of the vector
  R_SetExternalPtrAddr(ext_ptr, owner->get());
  R_RegisterCFinalizerEx(ext_ptr, delete_cpp_char_vec, TRUE);
  owner->release();

  R_set_altrep_data1(altrep_obj, ext_ptr);
  Rf_unprotect(2);
  return altrep_obj;
}
#else
// Non-ALTREP fallback: allocates a regular R raw vector and copies into it the
// contents of the std::vector<char> held by the std::unique_ptr that 'void_ptr'
// points to. The unique_ptr is left untouched.
SEXP make_r_raw_vec(void *void_ptr) {
  auto *owner = static_cast<std::unique_ptr<std::vector<char>>*>(void_ptr);
  std::vector<char> *bytes = owner->get();
  R_xlen_t n_bytes = bytes->size();
  SEXP raw_vec = Rf_protect(Rf_allocVector(RAWSXP, n_bytes));
  char *destination = reinterpret_cast<char*>(RAW(raw_vec));
  std::copy(bytes->begin(), bytes->end(), destination);
  Rf_unprotect(1);
  return raw_vec;
}
#define make_altrepped_raw_vec make_r_raw_vec
#endif

// Returns the std::vector<char> stored behind the external pointer kept in the
// data1 slot of the ALTREP object 'R_raw'.
std::vector<char>* get_ptr_from_altrepped_raw(SEXP R_raw) {
  SEXP ext_ptr = R_altrep_data1(R_raw);
  void *stored_addr = R_ExternalPtrAddr(ext_ptr);
  return static_cast<std::vector<char>*>(stored_addr);
}

// Length of the ALTREP raw object: the size of the backing std::vector<char>.
R_xlen_t get_altrepped_raw_len(SEXP R_raw) {
  const std::vector<char> *bytes = get_ptr_from_altrepped_raw(R_raw);
  return bytes->size();
}

// Read-only data pointer of the ALTREP raw object: the buffer of the backing
// std::vector<char>.
const void* get_altrepped_raw_dataptr_or_null(SEXP R_raw) {
  std::vector<char> *bytes = get_ptr_from_altrepped_raw(R_raw);
  return bytes->data();
}

// Data pointer of the ALTREP raw object: the buffer of the backing
// std::vector<char>. The 'writeable' flag is not consulted.
void* get_altrepped_raw_dataptr(SEXP R_raw, Rboolean writeable) {
  std::vector<char> *bytes = get_ptr_from_altrepped_raw(R_raw);
  return bytes->data();
}

#ifndef LGB_NO_ALTREP
// Selects the ALTREP class for element type T:
// 'lgb_altrepped_dbl_arr' for double, 'lgb_altrepped_int_arr' for every other type.
template <class T>
R_altrep_class_t get_altrep_class_for_type() {
  return std::is_same<T, double>::value ? lgb_altrepped_dbl_arr
                                        : lgb_altrepped_int_arr;
}
#else
// Selects the R vector type for element type T:
// REALSXP for double, INTSXP for every other type.
template <class T>
SEXPTYPE get_sexptype_class_for_type() {
  return std::is_same<T, double>::value ? REALSXP : INTSXP;
}

// Returns the data pointer of the R vector 'x' viewed as T*:
// REAL(x) when T is double, INTEGER(x) for every other type.
template <class T>
T* get_r_vec_ptr(SEXP x) {
  void *untyped = std::is_same<T, double>::value
    ? static_cast<void*>(REAL(x))
    : static_cast<void*>(INTEGER(x));
  return static_cast<T*>(untyped);
}
#endif

// Pairs a raw array pointer with its element count so that both can be passed
// through a single 'void*' argument (see make_altrepped_vec_from_arr /
// make_R_vec_from_arr, which cast their argument back to this struct).
template <class T>
struct arr_and_len {
  T *arr;       // pointer to the first element
  int64_t len;  // number of elements in 'arr'
};

#ifndef LGB_NO_ALTREP
// Creates an ALTREP vector (class chosen by get_altrep_class_for_type<T>()) on top of
// an existing C++ array.
//
// 'void_ptr' must point to an arr_and_len<T>. The array address goes into an external
// pointer (data1 slot) with 'delete_cpp_array<T>' registered as finalizer; the length
// is stored as a length-1 double vector (data2 slot).
template <class T>
SEXP make_altrepped_vec_from_arr(void *void_ptr) {
  auto *payload = static_cast<arr_and_len<T>*>(void_ptr);
  T *elements = payload->arr;
  uint64_t n_elements = payload->len;

  SEXP ext_ptr = Rf_protect(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  SEXP len_holder = Rf_protect(Rf_allocVector(REALSXP, 1));
  SEXP altrep_vec = Rf_protect(R_new_altrep(get_altrep_class_for_type<T>(), R_NilValue, R_NilValue));

  REAL(len_holder)[0] = static_cast<double>(n_elements);
  R_SetExternalPtrAddr(ext_ptr, elements);
  R_RegisterCFinalizerEx(ext_ptr, delete_cpp_array<T>, TRUE);

  R_set_altrep_data1(altrep_vec, ext_ptr);
  R_set_altrep_data2(altrep_vec, len_holder);
  Rf_unprotect(3);
  return altrep_vec;
}
#else
// Non-ALTREP fallback: allocates a regular R vector (type chosen by
// get_sexptype_class_for_type<T>()) and copies into it the array described by the
// arr_and_len<T> that 'void_ptr' points to. The source array is not freed here.
template <class T>
SEXP make_R_vec_from_arr(void *void_ptr) {
  auto *payload = static_cast<arr_and_len<T>*>(void_ptr);
  T *elements = payload->arr;
  uint64_t n_elements = payload->len;
  SEXP r_vec = Rf_protect(Rf_allocVector(get_sexptype_class_for_type<T>(), n_elements));
  T *destination = get_r_vec_ptr<T>(r_vec);
  std::copy(elements, elements + n_elements, destination);
  Rf_unprotect(1);
  return r_vec;
}
#define make_altrepped_vec_from_arr make_R_vec_from_arr
#endif

// Length of the ALTREP vector: the number stored in its data2 slot.
R_xlen_t get_altrepped_vec_len(SEXP R_vec) {
  SEXP len_holder = R_altrep_data2(R_vec);
  const double stored_len = Rf_asReal(len_holder);
  return static_cast<R_xlen_t>(stored_len);
}

// Read-only data pointer of the ALTREP vector: the address held by the external
// pointer in its data1 slot.
const void* get_altrepped_vec_dataptr_or_null(SEXP R_vec) {
  SEXP ext_ptr = R_altrep_data1(R_vec);
  return R_ExternalPtrAddr(ext_ptr);
}

// Data pointer of the ALTREP vector: the address held by the external pointer in
// its data1 slot. The 'writeable' flag is not consulted.
void* get_altrepped_vec_dataptr(SEXP R_vec, Rboolean writeable) {
  SEXP ext_ptr = R_altrep_data1(R_vec);
  return R_ExternalPtrAddr(ext_ptr);
}
172

Guolin Ke's avatar
Guolin Ke committed
173
174
// Layout flag handed to LGBM_DatasetCreateFromMat() together with the matrix data.
// NOTE(review): presumably this fills the C API's row-major flag, with 0 meaning
// column-major storage - confirm against c_api.h.
#define COL_MAJOR (0)

175
176
177
178
179
180
// Size in bytes of the buffer that holds the message of a pending R error.
#define MAX_LENGTH_ERR_MSG 1024
// Written by LGBM_R_save_exception_msg() and read by R_API_END(), which passes it
// to Rf_error() once the C++ try/catch has been left.
char R_errmsg_buffer[MAX_LENGTH_ERR_MSG];
// Exception type thrown by throw_R_memerr(); carries the R unwind continuation
// token that R_API_END() hands to R_ContinueUnwind().
struct LGBM_R_ErrorClass { SEXP cont_token; };
// Store the message of a caught exception / error string in R_errmsg_buffer.
void LGBM_R_save_exception_msg(const std::exception &err);
void LGBM_R_save_exception_msg(const std::string &err);

Guolin Ke's avatar
Guolin Ke committed
181
182
183
// Opens the try block that wraps the body of a function exported to R.
// Must be paired with R_API_END().
#define R_API_BEGIN() \
  try {
// Closes the try block opened by R_API_BEGIN() and turns C++ exceptions into R errors:
//   - LGBM_R_ErrorClass: resumes the R unwind that was interrupted by throw_R_memerr()
//   - std::exception / std::string / anything else: the message is first saved in
//     R_errmsg_buffer, and Rf_error() is only called after the handler has finished,
//     so that the long jump performed by Rf_error() does not start inside a catch block.
// The trailing 'return' only exists to satisfy the compiler.
#define R_API_END() } \
  catch(LGBM_R_ErrorClass &cont) { R_ContinueUnwind(cont.cont_token); } \
  catch(std::exception& ex) { LGBM_R_save_exception_msg(ex); } \
  catch(std::string& ex) { LGBM_R_save_exception_msg(ex); } \
  catch(...) { std::snprintf(R_errmsg_buffer, MAX_LENGTH_ERR_MSG, "%s", "unknown exception"); } \
  Rf_error("%s", R_errmsg_buffer); \
  return R_NilValue; /* <- won't be reached */
Guolin Ke's avatar
Guolin Ke committed
190
191
192

// Evaluates a LightGBM C API call; when it returns a non-zero status, throws a
// std::runtime_error whose message is LGBM_GetLastError().
#define CHECK_CALL(call) \
  if ((call) != 0) { \
    throw std::runtime_error(LGBM_GetLastError()); \
  }

196
197
198
199
200
201
202
203
204
205
206
207
208
209
// These are helper functions to allow doing a stack unwind
// after an R allocation error, which would trigger a long jump.
// Copies err.what(), followed by a newline, into R_errmsg_buffer
// (truncated to MAX_LENGTH_ERR_MSG bytes).
void LGBM_R_save_exception_msg(const std::exception &err) {
  const char *message = err.what();
  std::snprintf(R_errmsg_buffer, MAX_LENGTH_ERR_MSG, "%s\n", message);
}

// Copies the string 'err', followed by a newline, into R_errmsg_buffer
// (truncated to MAX_LENGTH_ERR_MSG bytes).
void LGBM_R_save_exception_msg(const std::string &err) {
  const char *message = err.c_str();
  std::snprintf(R_errmsg_buffer, MAX_LENGTH_ERR_MSG, "%s\n", message);
}

// R_UnwindProtect-compatible callback: allocates a character vector whose length
// is the R_xlen_t that 'len' points to.
SEXP wrapped_R_string(void *len) {
  const R_xlen_t n_elements = *static_cast<R_xlen_t*>(len);
  return Rf_allocVector(STRSXP, n_elements);
}

210
211
212
213
// R_UnwindProtect-compatible callback: allocates a raw vector whose length is the
// R_xlen_t that 'len' points to.
SEXP wrapped_R_raw(void *len) {
  const R_xlen_t n_elements = *static_cast<R_xlen_t*>(len);
  return Rf_allocVector(RAWSXP, n_elements);
}

214
215
216
217
218
219
220
221
// R_UnwindProtect-compatible callback: allocates an integer vector whose length is
// the R_xlen_t that 'len' points to.
SEXP wrapped_R_int(void *len) {
  const R_xlen_t n_elements = *static_cast<R_xlen_t*>(len);
  return Rf_allocVector(INTSXP, n_elements);
}

// R_UnwindProtect-compatible callback: allocates a double vector whose length is
// the R_xlen_t that 'len' points to.
SEXP wrapped_R_real(void *len) {
  const R_xlen_t n_elements = *static_cast<R_xlen_t*>(len);
  return Rf_allocVector(REALSXP, n_elements);
}

222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
// R_UnwindProtect-compatible callback: builds a CHARSXP from the C string that
// 'txt' points to.
SEXP wrapped_Rf_mkChar(void *txt) {
  char *c_string = static_cast<char*>(txt);
  return Rf_mkChar(c_string);
}

void throw_R_memerr(void *ptr_cont_token, Rboolean jump) {
  if (jump) {
    LGBM_R_ErrorClass err{*(reinterpret_cast<SEXP*>(ptr_cont_token))};
    throw err;
  }
}

// Allocates a character vector of length 'len' through R_UnwindProtect(), with
// throw_R_memerr() as the cleanup callback for the given continuation token.
SEXP safe_R_string(R_xlen_t len, SEXP *cont_token) {
  void *alloc_arg = static_cast<void*>(&len);
  return R_UnwindProtect(wrapped_R_string, alloc_arg, throw_R_memerr, cont_token, *cont_token);
}

237
238
239
240
// Allocates a raw vector of length 'len' through R_UnwindProtect(), with
// throw_R_memerr() as the cleanup callback for the given continuation token.
SEXP safe_R_raw(R_xlen_t len, SEXP *cont_token) {
  void *alloc_arg = static_cast<void*>(&len);
  return R_UnwindProtect(wrapped_R_raw, alloc_arg, throw_R_memerr, cont_token, *cont_token);
}

241
242
243
244
245
246
247
248
// Allocates an integer vector of length 'len' through R_UnwindProtect(), with
// throw_R_memerr() as the cleanup callback for the given continuation token.
SEXP safe_R_int(R_xlen_t len, SEXP *cont_token) {
  void *alloc_arg = static_cast<void*>(&len);
  return R_UnwindProtect(wrapped_R_int, alloc_arg, throw_R_memerr, cont_token, *cont_token);
}

// Allocates a double vector of length 'len' through R_UnwindProtect(), with
// throw_R_memerr() as the cleanup callback for the given continuation token.
SEXP safe_R_real(R_xlen_t len, SEXP *cont_token) {
  void *alloc_arg = static_cast<void*>(&len);
  return R_UnwindProtect(wrapped_R_real, alloc_arg, throw_R_memerr, cont_token, *cont_token);
}

249
250
251
252
// Builds a CHARSXP from 'txt' through R_UnwindProtect(), with throw_R_memerr() as
// the cleanup callback for the given continuation token.
SEXP safe_R_mkChar(char *txt, SEXP *cont_token) {
  void *text_arg = static_cast<void*>(txt);
  return R_UnwindProtect(wrapped_Rf_mkChar, text_arg, throw_R_memerr, cont_token, *cont_token);
}

253
254
using LightGBM::Common::Split;
using LightGBM::Log;
Guolin Ke's avatar
Guolin Ke committed
255

256
257
258
259
// Returns a length-1 logical: TRUE when the external pointer 'handle' holds a
// null address, FALSE otherwise.
SEXP LGBM_HandleIsNull_R(SEXP handle) {
  const bool address_is_null = (R_ExternalPtrAddr(handle) == nullptr);
  return Rf_ScalarLogical(address_is_null);
}

260
261
262
263
// Finalizer registered on Dataset external pointers: forwards to
// LGBM_DatasetFree_R() and discards its return value.
void _DatasetFinalizer(SEXP handle) {
  static_cast<void>(LGBM_DatasetFree_R(handle));
}

264
265
266
267
268
269
270
271
// Raises the R error reported when a Booster handle is null.
// Rf_error() does not return; the 'return' only satisfies the signature.
SEXP LGBM_NullBoosterHandleError_R() {
  static const char kNullBoosterMessage[] =
      "Attempting to use a Booster which no longer exists and/or cannot be restored. "
      "This can happen if you have called Booster$finalize() "
      "or if this Booster was saved through saveRDS() using 'serializable=FALSE'.";
  Rf_error("%s", kNullBoosterMessage);
  return R_NilValue;
}

272
273
// Raises an R error (via LGBM_NullBoosterHandleError_R) unless 'handle' is a
// non-NULL R object whose external pointer address is non-null.
void _AssertBoosterHandleNotNull(SEXP handle) {
  const bool handle_is_usable = !Rf_isNull(handle) && R_ExternalPtrAddr(handle) != nullptr;
  if (handle_is_usable) {
    return;
  }
  LGBM_NullBoosterHandleError_R();
}

// Raises an R error unless 'handle' is a non-NULL R object whose external pointer
// address is non-null.
void _AssertDatasetHandleNotNull(SEXP handle) {
  const bool handle_is_usable = !Rf_isNull(handle) && R_ExternalPtrAddr(handle) != nullptr;
  if (handle_is_usable) {
    return;
  }
  Rf_error(
    "Attempting to use a Dataset which no longer exists. "
    "This can happen if you have called Dataset$finalize() or if this Dataset was saved with saveRDS(). "
    "To avoid this error in the future, use lgb.Dataset.save() or Dataset$save_binary() to save lightgbm Datasets.");
}

287
288
// Creates a Dataset from a file via LGBM_DatasetCreateFromFile().
//   filename   - path of the data file (converted with Rf_asChar)
//   parameters - parameter string (converted with Rf_asChar)
//   reference  - external pointer to a reference Dataset, or R NULL for none
// Returns an external pointer holding the new Dataset handle, with
// _DatasetFinalizer registered on it.
SEXP LGBM_DatasetCreateFromFile_R(SEXP filename,
  SEXP parameters,
  SEXP reference) {
  R_API_BEGIN();
  SEXP dataset_ptr = Rf_protect(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  DatasetHandle reference_handle = Rf_isNull(reference) ? nullptr : R_ExternalPtrAddr(reference);
  SEXP filename_chr = Rf_protect(Rf_asChar(filename));
  SEXP parameters_chr = Rf_protect(Rf_asChar(parameters));
  DatasetHandle new_handle = nullptr;
  CHECK_CALL(LGBM_DatasetCreateFromFile(CHAR(filename_chr), CHAR(parameters_chr), reference_handle, &new_handle));
  R_SetExternalPtrAddr(dataset_ptr, new_handle);
  R_RegisterCFinalizerEx(dataset_ptr, _DatasetFinalizer, TRUE);
  Rf_unprotect(3);
  return dataset_ptr;
  R_API_END();
}

307
308
309
// Creates a Dataset from CSC sparse-matrix components via LGBM_DatasetCreateFromCSC().
//   indptr, indices - integer vectors; data - double vector
//   num_indptr, nelem, num_row - sizes, read with Rf_asInteger
//   parameters - parameter string; reference - reference Dataset pointer or R NULL
// Returns an external pointer holding the new Dataset handle, with
// _DatasetFinalizer registered on it.
SEXP LGBM_DatasetCreateFromCSC_R(SEXP indptr,
  SEXP indices,
  SEXP data,
  SEXP num_indptr,
  SEXP nelem,
  SEXP num_row,
  SEXP parameters,
  SEXP reference) {
  R_API_BEGIN();
  SEXP dataset_ptr = Rf_protect(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  const int* col_ptr = INTEGER(indptr);
  const int* row_idx = INTEGER(indices);
  const double* values = REAL(data);
  const int64_t n_col_ptr = static_cast<int64_t>(Rf_asInteger(num_indptr));
  const int64_t n_values = static_cast<int64_t>(Rf_asInteger(nelem));
  const int64_t n_rows = static_cast<int64_t>(Rf_asInteger(num_row));
  SEXP parameters_chr = Rf_protect(Rf_asChar(parameters));
  DatasetHandle reference_handle = Rf_isNull(reference) ? nullptr : R_ExternalPtrAddr(reference);
  DatasetHandle new_handle = nullptr;
  CHECK_CALL(LGBM_DatasetCreateFromCSC(col_ptr, C_API_DTYPE_INT32, row_idx,
    values, C_API_DTYPE_FLOAT64, n_col_ptr, n_values,
    n_rows, CHAR(parameters_chr), reference_handle, &new_handle));
  R_SetExternalPtrAddr(dataset_ptr, new_handle);
  R_RegisterCFinalizerEx(dataset_ptr, _DatasetFinalizer, TRUE);
  Rf_unprotect(2);
  return dataset_ptr;
  R_API_END();
}

339
// Creates a Dataset from a dense double matrix via LGBM_DatasetCreateFromMat(),
// passing COL_MAJOR as the layout flag.
//   data - double matrix; num_row, num_col - dimensions, read with Rf_asInteger
//   parameters - parameter string; reference - reference Dataset pointer or R NULL
// Returns an external pointer holding the new Dataset handle, with
// _DatasetFinalizer registered on it.
SEXP LGBM_DatasetCreateFromMat_R(SEXP data,
  SEXP num_row,
  SEXP num_col,
  SEXP parameters,
  SEXP reference) {
  R_API_BEGIN();
  SEXP dataset_ptr = Rf_protect(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  const int32_t n_rows = static_cast<int32_t>(Rf_asInteger(num_row));
  const int32_t n_cols = static_cast<int32_t>(Rf_asInteger(num_col));
  double* matrix_values = REAL(data);
  SEXP parameters_chr = Rf_protect(Rf_asChar(parameters));
  DatasetHandle reference_handle = Rf_isNull(reference) ? nullptr : R_ExternalPtrAddr(reference);
  DatasetHandle new_handle = nullptr;
  CHECK_CALL(LGBM_DatasetCreateFromMat(matrix_values, C_API_DTYPE_FLOAT64, n_rows, n_cols, COL_MAJOR,
    CHAR(parameters_chr), reference_handle, &new_handle));
  R_SetExternalPtrAddr(dataset_ptr, new_handle);
  R_RegisterCFinalizerEx(dataset_ptr, _DatasetFinalizer, TRUE);
  Rf_unprotect(2);
  return dataset_ptr;
  R_API_END();
}

364
// Creates a new Dataset from a subset of the rows of 'handle' via LGBM_DatasetGetSubset().
//   used_row_indices     - integer vector of one-based row indices
//   len_used_row_indices - number of indices to use
//   parameters           - parameter string
// Returns an external pointer holding the subset Dataset handle, with
// _DatasetFinalizer registered on it.
SEXP LGBM_DatasetGetSubset_R(SEXP handle,
  SEXP used_row_indices,
  SEXP len_used_row_indices,
  SEXP parameters) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  SEXP subset_ptr = Rf_protect(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  const int32_t n_rows = static_cast<int32_t>(Rf_asInteger(len_used_row_indices));
  std::unique_ptr<int32_t[]> zero_based(new int32_t[n_rows]);
  const int *one_based = INTEGER(used_row_indices);
  // R indices start at 1, the C API expects indices starting at 0
#ifndef _MSC_VER
#pragma omp simd
#endif
  for (int32_t row = 0; row < n_rows; ++row) {
    zero_based[row] = static_cast<int32_t>(one_based[row] - 1);
  }
  SEXP parameters_chr = Rf_protect(Rf_asChar(parameters));
  DatasetHandle subset_handle = nullptr;
  CHECK_CALL(LGBM_DatasetGetSubset(R_ExternalPtrAddr(handle),
    zero_based.get(), n_rows, CHAR(parameters_chr),
    &subset_handle));
  R_SetExternalPtrAddr(subset_ptr, subset_handle);
  R_RegisterCFinalizerEx(subset_ptr, _DatasetFinalizer, TRUE);
  Rf_unprotect(2);
  return subset_ptr;
  R_API_END();
}

393
// Sets the feature names of a Dataset via LGBM_DatasetSetFeatureNames().
// 'feature_names' is a single string in which the names are separated by tabs.
// Returns R NULL.
SEXP LGBM_DatasetSetFeatureNames_R(SEXP handle,
  SEXP feature_names) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  SEXP joined_names = Rf_protect(Rf_asChar(feature_names));
  auto names = Split(CHAR(joined_names), '\t');
  const int n_names = static_cast<int>(names.size());
  // the C API takes an array of C strings; 'names' keeps the storage alive
  std::unique_ptr<const char*[]> name_ptrs(new const char*[n_names]);
  for (int idx = 0; idx < n_names; ++idx) {
    name_ptrs[idx] = names[idx].c_str();
  }
  CHECK_CALL(LGBM_DatasetSetFeatureNames(R_ExternalPtrAddr(handle),
    name_ptrs.get(), n_names));
  Rf_unprotect(1);
  return R_NilValue;
  R_API_END();
}

410
// Returns the feature names of a Dataset as an R character vector.
// Names are first fetched into 256-byte buffers; when the C API reports that a
// larger buffer is required, the buffers are enlarged and the names fetched again.
// R allocations go through the safe_R_* helpers so that an R-level failure unwinds
// the C++ stack.
SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) {
  SEXP cont_token = Rf_protect(R_MakeUnwindCont());
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  SEXP feature_names;
  int n_features = 0;
  CHECK_CALL(LGBM_DatasetGetNumFeature(R_ExternalPtrAddr(handle), &n_features));
  const size_t initial_buffer_size = 256;
  std::vector<std::vector<char>> buffers(n_features);
  std::vector<char*> buffer_ptrs(n_features);
  // (re)sizes every name buffer and refreshes the pointer table handed to the C API
  const auto resize_buffers = [&buffers, &buffer_ptrs](size_t n_chars) {
    for (size_t idx = 0; idx < buffers.size(); ++idx) {
      buffers[idx].resize(n_chars);
      buffer_ptrs[idx] = buffers[idx].data();
    }
  };
  resize_buffers(initial_buffer_size);
  int n_returned;
  size_t needed_buffer_size;
  CHECK_CALL(
    LGBM_DatasetGetFeatureNames(
      R_ExternalPtrAddr(handle),
      n_features, &n_returned,
      initial_buffer_size, &needed_buffer_size,
      buffer_ptrs.data()));
  // if any feature names were larger than allocated size,
  // allow for a larger size and try again
  if (needed_buffer_size > initial_buffer_size) {
    resize_buffers(needed_buffer_size);
    CHECK_CALL(
      LGBM_DatasetGetFeatureNames(
        R_ExternalPtrAddr(handle),
        n_features,
        &n_returned,
        needed_buffer_size,
        &needed_buffer_size,
        buffer_ptrs.data()));
  }
  CHECK_EQ(n_features, n_returned);
  feature_names = Rf_protect(safe_R_string(static_cast<R_xlen_t>(n_features), &cont_token));
  for (int idx = 0; idx < n_features; ++idx) {
    SET_STRING_ELT(feature_names, idx, safe_R_mkChar(buffer_ptrs[idx], &cont_token));
  }
  Rf_unprotect(2);
  return feature_names;
  R_API_END();
}

458
// Saves a Dataset to the file 'filename' via LGBM_DatasetSaveBinary(). Returns R NULL.
SEXP LGBM_DatasetSaveBinary_R(SEXP handle,
  SEXP filename) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  SEXP filename_chr = Rf_protect(Rf_asChar(filename));
  CHECK_CALL(LGBM_DatasetSaveBinary(R_ExternalPtrAddr(handle), CHAR(filename_chr)));
  Rf_unprotect(1);
  return R_NilValue;
  R_API_END();
}

470
// Frees the Dataset behind 'handle' via LGBM_DatasetFree() and clears the external
// pointer. Does nothing when 'handle' is R NULL or already holds a null address.
// Returns R NULL.
SEXP LGBM_DatasetFree_R(SEXP handle) {
  R_API_BEGIN();
  const bool has_live_dataset = !Rf_isNull(handle) && R_ExternalPtrAddr(handle) != nullptr;
  if (has_live_dataset) {
    CHECK_CALL(LGBM_DatasetFree(R_ExternalPtrAddr(handle)));
    R_ClearExternalPtr(handle);
  }
  return R_NilValue;
  R_API_END();
}

480
// Sets a field of a Dataset via LGBM_DatasetSetField().
//   "group" / "query" : 'field_data' is read as integers and passed as int32
//   "init_score"      : 'field_data' is read as doubles and passed as float64
//   any other name    : 'field_data' is read as doubles, converted to float and
//                       passed as float32
// 'num_element' is the number of values in 'field_data'. Returns R NULL.
SEXP LGBM_DatasetSetField_R(SEXP handle,
  SEXP field_name,
  SEXP field_data,
  SEXP num_element) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  const int n_values = Rf_asInteger(num_element);
  const char* field = CHAR(Rf_protect(Rf_asChar(field_name)));
  const bool is_group_field = (strcmp("group", field) == 0) || (strcmp("query", field) == 0);
  if (is_group_field) {
    CHECK_CALL(LGBM_DatasetSetField(R_ExternalPtrAddr(handle), field, INTEGER(field_data), n_values, C_API_DTYPE_INT32));
  } else if (strcmp("init_score", field) == 0) {
    CHECK_CALL(LGBM_DatasetSetField(R_ExternalPtrAddr(handle), field, REAL(field_data), n_values, C_API_DTYPE_FLOAT64));
  } else {
    // narrow the R doubles to float before handing them over
    std::unique_ptr<float[]> as_float(new float[n_values]);
    std::copy(REAL(field_data), REAL(field_data) + n_values, as_float.get());
    CHECK_CALL(LGBM_DatasetSetField(R_ExternalPtrAddr(handle), field, as_float.get(), n_values, C_API_DTYPE_FLOAT32));
  }
  Rf_unprotect(1);
  return R_NilValue;
  R_API_END();
}

502
// Copies a field of a Dataset (fetched with LGBM_DatasetGetField()) into the
// pre-allocated R vector 'field_data'.
//   "group" / "query" : the int32 boundaries are converted to sizes
//                       (out_len - 1 integers are written)
//   "init_score"      : copied as doubles
//   any other name    : read as float and widened to double
// Returns R NULL.
SEXP LGBM_DatasetGetField_R(SEXP handle,
  SEXP field_name,
  SEXP field_data) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  const char* field = CHAR(Rf_protect(Rf_asChar(field_name)));
  int n_values = 0;
  int value_type = 0;
  const void* raw_values;
  CHECK_CALL(LGBM_DatasetGetField(R_ExternalPtrAddr(handle), field, &n_values, &raw_values, &value_type));
  const bool is_group_field = (strcmp("group", field) == 0) || (strcmp("query", field) == 0);
  if (is_group_field) {
    const int32_t* boundaries = reinterpret_cast<const int32_t*>(raw_values);
    // convert from boundaries to size
    int *sizes = INTEGER(field_data);
#ifndef _MSC_VER
#pragma omp simd
#endif
    for (int idx = 0; idx < n_values - 1; ++idx) {
      sizes[idx] = boundaries[idx + 1] - boundaries[idx];
    }
  } else if (strcmp("init_score", field) == 0) {
    const double* scores = reinterpret_cast<const double*>(raw_values);
    std::copy(scores, scores + n_values, REAL(field_data));
  } else {
    const float* values = reinterpret_cast<const float*>(raw_values);
    std::copy(values, values + n_values, REAL(field_data));
  }
  Rf_unprotect(1);
  return R_NilValue;
  R_API_END();
}

534
// Writes the number of values of a Dataset field into the first element of the
// integer vector 'out'. For "group" / "query" the length reported by
// LGBM_DatasetGetField() is reduced by one. Returns R NULL.
SEXP LGBM_DatasetGetFieldSize_R(SEXP handle,
  SEXP field_name,
  SEXP out) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  const char* field = CHAR(Rf_protect(Rf_asChar(field_name)));
  int n_values = 0;
  int value_type = 0;
  const void* raw_values;
  CHECK_CALL(LGBM_DatasetGetField(R_ExternalPtrAddr(handle), field, &n_values, &raw_values, &value_type));
  const bool is_group_field = (strcmp("group", field) == 0) || (strcmp("query", field) == 0);
  if (is_group_field) {
    --n_values;
  }
  *INTEGER(out) = n_values;
  Rf_unprotect(1);
  return R_NilValue;
  R_API_END();
}

553
554
// Passes the old and new parameter strings to LGBM_DatasetUpdateParamChecking();
// a non-zero status from that call becomes an R error. Returns R NULL.
SEXP LGBM_DatasetUpdateParamChecking_R(SEXP old_params,
  SEXP new_params) {
  R_API_BEGIN();
  SEXP old_chr = Rf_protect(Rf_asChar(old_params));
  SEXP new_chr = Rf_protect(Rf_asChar(new_params));
  CHECK_CALL(LGBM_DatasetUpdateParamChecking(CHAR(old_chr), CHAR(new_chr)));
  Rf_unprotect(2);
  return R_NilValue;
  R_API_END();
}

564
// Writes the value reported by LGBM_DatasetGetNumData() into the first element of
// the integer vector 'out'. Returns R NULL.
SEXP LGBM_DatasetGetNumData_R(SEXP handle, SEXP out) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  int n_data;
  CHECK_CALL(LGBM_DatasetGetNumData(R_ExternalPtrAddr(handle), &n_data));
  *INTEGER(out) = n_data;
  return R_NilValue;
  R_API_END();
}

574
// Writes the value reported by LGBM_DatasetGetNumFeature() into the first element
// of the integer vector 'out'. Returns R NULL.
SEXP LGBM_DatasetGetNumFeature_R(SEXP handle,
  SEXP out) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  int n_features;
  CHECK_CALL(LGBM_DatasetGetNumFeature(R_ExternalPtrAddr(handle), &n_features));
  *INTEGER(out) = n_features;
  return R_NilValue;
  R_API_END();
}

585
586
587
588
589
590
591
592
593
594
595
// Writes the value reported by LGBM_DatasetGetFeatureNumBin() for feature
// 'feature_idx' into the first element of the integer vector 'out'. Returns R NULL.
SEXP LGBM_DatasetGetFeatureNumBin_R(SEXP handle, SEXP feature_idx, SEXP out) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  const int which_feature = Rf_asInteger(feature_idx);
  int n_bins;
  CHECK_CALL(LGBM_DatasetGetFeatureNumBin(R_ExternalPtrAddr(handle), which_feature, &n_bins));
  *INTEGER(out) = n_bins;
  return R_NilValue;
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
596
597
// --- start Booster interfaces

598
599
600
601
// Finalizer registered on Booster external pointers: forwards to
// LGBM_BoosterFree_R() and discards its return value.
void _BoosterFinalizer(SEXP handle) {
  static_cast<void>(LGBM_BoosterFree_R(handle));
}

602
// Frees the Booster behind 'handle' via LGBM_BoosterFree() and clears the external
// pointer. Does nothing when 'handle' is R NULL or already holds a null address.
// Returns R NULL.
SEXP LGBM_BoosterFree_R(SEXP handle) {
  R_API_BEGIN();
  const bool has_live_booster = !Rf_isNull(handle) && R_ExternalPtrAddr(handle) != nullptr;
  if (has_live_booster) {
    CHECK_CALL(LGBM_BoosterFree(R_ExternalPtrAddr(handle)));
    R_ClearExternalPtr(handle);
  }
  return R_NilValue;
  R_API_END();
}

612
613
// Creates a Booster for the Dataset 'train_data' with the given parameter string
// via LGBM_BoosterCreate(). Returns an external pointer holding the Booster handle,
// with _BoosterFinalizer registered on it.
SEXP LGBM_BoosterCreate_R(SEXP train_data,
  SEXP parameters) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(train_data);
  SEXP booster_ptr = Rf_protect(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  SEXP parameters_chr = Rf_protect(Rf_asChar(parameters));
  BoosterHandle new_handle = nullptr;
  CHECK_CALL(LGBM_BoosterCreate(R_ExternalPtrAddr(train_data), CHAR(parameters_chr), &new_handle));
  R_SetExternalPtrAddr(booster_ptr, new_handle);
  R_RegisterCFinalizerEx(booster_ptr, _BoosterFinalizer, TRUE);
  Rf_unprotect(2);
  return booster_ptr;
  R_API_END();
}

627
// Creates a Booster from the model file 'filename' via
// LGBM_BoosterCreateFromModelfile(). The iteration count reported by the C API is
// not used. Returns an external pointer holding the Booster handle, with
// _BoosterFinalizer registered on it.
SEXP LGBM_BoosterCreateFromModelfile_R(SEXP filename) {
  R_API_BEGIN();
  SEXP booster_ptr = Rf_protect(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  SEXP filename_chr = Rf_protect(Rf_asChar(filename));
  int n_iterations = 0;
  BoosterHandle new_handle = nullptr;
  CHECK_CALL(LGBM_BoosterCreateFromModelfile(CHAR(filename_chr), &n_iterations, &new_handle));
  R_SetExternalPtrAddr(booster_ptr, new_handle);
  R_RegisterCFinalizerEx(booster_ptr, _BoosterFinalizer, TRUE);
  Rf_unprotect(2);
  return booster_ptr;
  R_API_END();
}

641
// Creates a Booster from an in-memory model via LGBM_BoosterLoadModelFromString().
//
// 'model_str' may be
//   - a raw vector (its bytes are passed as the model text),
//   - a CHARSXP, or
//   - a character vector, of which the first element is used.
// Any other type, or an empty character vector, raises an R error instead of
// handing a null / out-of-range pointer to the C API.
//
// Returns an external pointer holding the Booster handle, with _BoosterFinalizer
// registered on it.
SEXP LGBM_BoosterLoadModelFromString_R(SEXP model_str) {
  R_API_BEGIN();
  SEXP ret = Rf_protect(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  SEXP temp = NULL;
  int n_protected = 1;
  int out_num_iterations = 0;
  const char* model_str_ptr = nullptr;
  switch (TYPEOF(model_str)) {
    case RAWSXP: {
      model_str_ptr = reinterpret_cast<const char*>(RAW(model_str));
      break;
    }
    case CHARSXP: {
      model_str_ptr = reinterpret_cast<const char*>(CHAR(model_str));
      break;
    }
    case STRSXP: {
      // STRING_ELT(x, 0) is only valid when the vector has at least one element
      if (Rf_xlength(model_str) < 1) {
        throw std::runtime_error("Cannot load a model from an empty character vector");
      }
      temp = Rf_protect(STRING_ELT(model_str, 0));
      n_protected++;
      model_str_ptr = reinterpret_cast<const char*>(CHAR(temp));
      break;
    }
    default: {
      // thrown inside R_API_BEGIN/R_API_END, so it is reported as a regular R error
      throw std::runtime_error("Model string must be a raw vector or a character vector");
    }
  }
  BoosterHandle handle = nullptr;
  CHECK_CALL(LGBM_BoosterLoadModelFromString(model_str_ptr, &out_num_iterations, &handle));
  R_SetExternalPtrAddr(ret, handle);
  R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE);
  Rf_unprotect(n_protected);
  return ret;
  R_API_END();
}

672
673
// Calls LGBM_BoosterMerge() with the Boosters behind 'handle' and 'other_handle',
// after checking that both handles are non-null. Returns R NULL.
SEXP LGBM_BoosterMerge_R(SEXP handle,
  SEXP other_handle) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  _AssertBoosterHandleNotNull(other_handle);
  BoosterHandle target = R_ExternalPtrAddr(handle);
  BoosterHandle other = R_ExternalPtrAddr(other_handle);
  CHECK_CALL(LGBM_BoosterMerge(target, other));
  return R_NilValue;
  R_API_END();
}

682
683
// Calls LGBM_BoosterAddValidData() with the Booster behind 'handle' and the Dataset
// behind 'valid_data', after checking that both handles are non-null. Returns R NULL.
SEXP LGBM_BoosterAddValidData_R(SEXP handle,
  SEXP valid_data) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  _AssertDatasetHandleNotNull(valid_data);
  BoosterHandle booster = R_ExternalPtrAddr(handle);
  DatasetHandle validation_set = R_ExternalPtrAddr(valid_data);
  CHECK_CALL(LGBM_BoosterAddValidData(booster, validation_set));
  return R_NilValue;
  R_API_END();
}

692
693
// Calls LGBM_BoosterResetTrainingData() with the Booster behind 'handle' and the
// Dataset behind 'train_data', after checking that both handles are non-null.
// Returns R NULL.
SEXP LGBM_BoosterResetTrainingData_R(SEXP handle,
  SEXP train_data) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  _AssertDatasetHandleNotNull(train_data);
  BoosterHandle booster = R_ExternalPtrAddr(handle);
  DatasetHandle training_set = R_ExternalPtrAddr(train_data);
  CHECK_CALL(LGBM_BoosterResetTrainingData(booster, training_set));
  return R_NilValue;
  R_API_END();
}

702
// Passes the parameter string 'parameters' to LGBM_BoosterResetParameter() for the
// Booster behind 'handle'. Returns R NULL.
SEXP LGBM_BoosterResetParameter_R(SEXP handle,
  SEXP parameters) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  SEXP parameters_chr = Rf_protect(Rf_asChar(parameters));
  CHECK_CALL(LGBM_BoosterResetParameter(R_ExternalPtrAddr(handle), CHAR(parameters_chr)));
  Rf_unprotect(1);
  return R_NilValue;
  R_API_END();
}

713
// Writes the value reported by LGBM_BoosterGetNumClasses() into the first element
// of the integer vector 'out'. Returns R NULL.
SEXP LGBM_BoosterGetNumClasses_R(SEXP handle,
  SEXP out) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int n_classes;
  CHECK_CALL(LGBM_BoosterGetNumClasses(R_ExternalPtrAddr(handle), &n_classes));
  *INTEGER(out) = n_classes;
  return R_NilValue;
  R_API_END();
}

724
725
726
727
728
729
730
731
732
// Returns the value reported by LGBM_BoosterGetNumFeature() as a length-1 integer
// vector.
SEXP LGBM_BoosterGetNumFeature_R(SEXP handle) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int n_features = 0;
  CHECK_CALL(LGBM_BoosterGetNumFeature(R_ExternalPtrAddr(handle), &n_features));
  return Rf_ScalarInteger(n_features);
  R_API_END();
}

733
// Calls LGBM_BoosterUpdateOneIter() for the Booster behind 'handle'. The "finished"
// flag reported by the C API is not returned to R. Returns R NULL.
SEXP LGBM_BoosterUpdateOneIter_R(SEXP handle) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int finished_flag = 0;
  CHECK_CALL(LGBM_BoosterUpdateOneIter(R_ExternalPtrAddr(handle), &finished_flag));
  return R_NilValue;
  R_API_END();
}

742
// Calls LGBM_BoosterUpdateOneIterCustom() with user-supplied gradients and hessians.
//   grad, hess - double vectors; the first 'len' values of each are converted to float
//   len        - number of values to use
// The "finished" flag reported by the C API is not returned to R. Returns R NULL.
SEXP LGBM_BoosterUpdateOneIterCustom_R(SEXP handle,
  SEXP grad,
  SEXP hess,
  SEXP len) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int finished_flag = 0;
  const int n_values = Rf_asInteger(len);
  std::unique_ptr<float[]> grad_f32(new float[n_values]);
  std::unique_ptr<float[]> hess_f32(new float[n_values]);
  const double *grad_src = REAL(grad);
  std::copy(grad_src, grad_src + n_values, grad_f32.get());
  const double *hess_src = REAL(hess);
  std::copy(hess_src, hess_src + n_values, hess_f32.get());
  CHECK_CALL(LGBM_BoosterUpdateOneIterCustom(R_ExternalPtrAddr(handle), grad_f32.get(), hess_f32.get(), &finished_flag));
  return R_NilValue;
  R_API_END();
}

758
// Calls LGBM_BoosterRollbackOneIter() for the Booster behind 'handle'. Returns R NULL.
SEXP LGBM_BoosterRollbackOneIter_R(SEXP handle) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  BoosterHandle booster = R_ExternalPtrAddr(handle);
  CHECK_CALL(LGBM_BoosterRollbackOneIter(booster));
  return R_NilValue;
  R_API_END();
}

766
// Writes the value reported by LGBM_BoosterGetCurrentIteration() into the first
// element of the integer vector 'out'. Returns R NULL.
SEXP LGBM_BoosterGetCurrentIteration_R(SEXP handle, SEXP out) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int current_iteration;
  CHECK_CALL(LGBM_BoosterGetCurrentIteration(R_ExternalPtrAddr(handle), &current_iteration));
  *INTEGER(out) = current_iteration;
  return R_NilValue;
  R_API_END();
}

776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
// Writes the value reported by LGBM_BoosterNumModelPerIteration() into the first
// element of the integer vector 'out'. Returns R NULL.
SEXP LGBM_BoosterNumModelPerIteration_R(SEXP handle, SEXP out) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int n_models_per_iteration;
  CHECK_CALL(LGBM_BoosterNumModelPerIteration(R_ExternalPtrAddr(handle), &n_models_per_iteration));
  *INTEGER(out) = n_models_per_iteration;
  return R_NilValue;
  R_API_END();
}

// Writes the value reported by LGBM_BoosterNumberOfTotalModel() into the first
// element of the integer vector 'out'. Returns R NULL.
SEXP LGBM_BoosterNumberOfTotalModel_R(SEXP handle, SEXP out) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int n_models_total;
  CHECK_CALL(LGBM_BoosterNumberOfTotalModel(R_ExternalPtrAddr(handle), &n_models_total));
  *INTEGER(out) = n_models_total;
  return R_NilValue;
  R_API_END();
}

796
// Lets LGBM_BoosterGetUpperBoundValue() write its result directly into the double
// vector 'out_result'. Returns R NULL.
SEXP LGBM_BoosterGetUpperBoundValue_R(SEXP handle,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  double* result_slot = REAL(out_result);
  CHECK_CALL(LGBM_BoosterGetUpperBoundValue(R_ExternalPtrAddr(handle), result_slot));
  return R_NilValue;
  R_API_END();
}

806
// Lets LGBM_BoosterGetLowerBoundValue() write its result directly into the double
// vector 'out_result'. Returns R NULL.
SEXP LGBM_BoosterGetLowerBoundValue_R(SEXP handle,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  double* result_slot = REAL(out_result);
  CHECK_CALL(LGBM_BoosterGetLowerBoundValue(R_ExternalPtrAddr(handle), result_slot));
  return R_NilValue;
  R_API_END();
}

816
// Returns the evaluation names of a Booster as an R character vector.
// Names are first fetched into 128-byte buffers; when the C API reports that a
// larger buffer is required, the buffers are enlarged and the names fetched again.
// R allocations go through the safe_R_* helpers so that an R-level failure unwinds
// the C++ stack.
SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) {
  SEXP cont_token = Rf_protect(R_MakeUnwindCont());
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  SEXP eval_names;
  int n_evals;
  CHECK_CALL(LGBM_BoosterGetEvalCounts(R_ExternalPtrAddr(handle), &n_evals));
  const size_t initial_buffer_size = 128;
  std::vector<std::vector<char>> buffers(n_evals);
  std::vector<char*> buffer_ptrs(n_evals);
  // (re)sizes every name buffer and refreshes the pointer table handed to the C API
  const auto resize_buffers = [&buffers, &buffer_ptrs](size_t n_chars) {
    for (size_t idx = 0; idx < buffers.size(); ++idx) {
      buffers[idx].resize(n_chars);
      buffer_ptrs[idx] = buffers[idx].data();
    }
  };
  resize_buffers(initial_buffer_size);

  int n_returned;
  size_t needed_buffer_size;
  CHECK_CALL(
    LGBM_BoosterGetEvalNames(
      R_ExternalPtrAddr(handle),
      n_evals, &n_returned,
      initial_buffer_size, &needed_buffer_size,
      buffer_ptrs.data()));
  // if any eval names were larger than allocated size,
  // allow for a larger size and try again
  if (needed_buffer_size > initial_buffer_size) {
    resize_buffers(needed_buffer_size);
    CHECK_CALL(
      LGBM_BoosterGetEvalNames(
        R_ExternalPtrAddr(handle),
        n_evals,
        &n_returned,
        needed_buffer_size,
        &needed_buffer_size,
        buffer_ptrs.data()));
  }
  CHECK_EQ(n_returned, n_evals);
  eval_names = Rf_protect(safe_R_string(static_cast<R_xlen_t>(n_evals), &cont_token));
  for (int idx = 0; idx < n_evals; ++idx) {
    SET_STRING_ELT(eval_names, idx, safe_R_mkChar(buffer_ptrs[idx], &cont_token));
  }
  Rf_unprotect(2);
  return eval_names;
  R_API_END();
}

865
// Lets LGBM_BoosterGetEval() write the evaluation results for dataset index
// 'data_idx' into the double vector 'out_result', and checks that the number of
// results equals the count reported by LGBM_BoosterGetEvalCounts(). Returns R NULL.
SEXP LGBM_BoosterGetEval_R(SEXP handle,
  SEXP data_idx,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int n_expected;
  CHECK_CALL(LGBM_BoosterGetEvalCounts(R_ExternalPtrAddr(handle), &n_expected));
  double* results = REAL(out_result);
  int n_written;
  const int which_dataset = Rf_asInteger(data_idx);
  CHECK_CALL(LGBM_BoosterGetEval(R_ExternalPtrAddr(handle), which_dataset, &n_written, results));
  CHECK_EQ(n_written, n_expected);
  return R_NilValue;
  R_API_END();
}

880
// Writes the value reported by LGBM_BoosterGetNumPredict() for dataset index
// 'data_idx' into the first element of the integer vector 'out'. The 64-bit count
// is narrowed to int. Returns R NULL.
SEXP LGBM_BoosterGetNumPredict_R(SEXP handle,
  SEXP data_idx,
  SEXP out) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int64_t n_predictions;
  const int which_dataset = Rf_asInteger(data_idx);
  CHECK_CALL(LGBM_BoosterGetNumPredict(R_ExternalPtrAddr(handle), which_dataset, &n_predictions));
  *INTEGER(out) = static_cast<int>(n_predictions);
  return R_NilValue;
  R_API_END();
}

892
// Lets LGBM_BoosterGetPredict() write the predictions for dataset index 'data_idx'
// into the double vector 'out_result'. The length reported by the C API is not
// used. Returns R NULL.
SEXP LGBM_BoosterGetPredict_R(SEXP handle,
  SEXP data_idx,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  double* predictions = REAL(out_result);
  int64_t n_written;
  const int which_dataset = Rf_asInteger(data_idx);
  CHECK_CALL(LGBM_BoosterGetPredict(R_ExternalPtrAddr(handle), which_dataset, &n_written, predictions));
  return R_NilValue;
  R_API_END();
}

904
// Maps the three R-side prediction flags onto a single C_API_PREDICT_* constant.
// When several flags are set, the precedence is: contributions over leaf index
// over raw score. With no flag set the result is C_API_PREDICT_NORMAL.
int GetPredictType(SEXP is_rawscore, SEXP is_leafidx, SEXP is_predcontrib) {
  // all three flags are always converted, in argument order
  const bool want_raw_score = Rf_asInteger(is_rawscore) != 0;
  const bool want_leaf_index = Rf_asInteger(is_leafidx) != 0;
  const bool want_contrib = Rf_asInteger(is_predcontrib) != 0;
  if (want_contrib) {
    return C_API_PREDICT_CONTRIB;
  }
  if (want_leaf_index) {
    return C_API_PREDICT_LEAF_INDEX;
  }
  if (want_raw_score) {
    return C_API_PREDICT_RAW_SCORE;
  }
  return C_API_PREDICT_NORMAL;
}

918
// Runs prediction on a data file and has the C API write the result to another file.
//   handle          - external pointer holding the booster handle; must not be null
//   data_filename   - input path (converted with Rf_asChar())
//   data_has_header - forwarded via Rf_asInteger()
//   is_rawscore, is_leafidx, is_predcontrib - flags combined by GetPredictType()
//   start_iteration, num_iteration          - forwarded via Rf_asInteger()
//   parameter       - parameter string (converted with Rf_asChar())
//   result_filename - output path (converted with Rf_asChar())
// Always returns R_NilValue.
SEXP LGBM_BoosterPredictForFile_R(SEXP handle,
  SEXP data_filename,
  SEXP data_has_header,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter,
  SEXP result_filename) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  // the three CHARSXPs stay protected while their C strings are in use
  SEXP input_path = Rf_protect(Rf_asChar(data_filename));
  SEXP params = Rf_protect(Rf_asChar(parameter));
  SEXP output_path = Rf_protect(Rf_asChar(result_filename));
  const int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  CHECK_CALL(LGBM_BoosterPredictForFile(
    R_ExternalPtrAddr(handle),
    CHAR(input_path),
    Rf_asInteger(data_has_header),
    pred_type,
    Rf_asInteger(start_iteration),
    Rf_asInteger(num_iteration),
    CHAR(params),
    CHAR(output_path)));
  Rf_unprotect(3);
  return R_NilValue;
  R_API_END();
}

942
// Asks the C API how many values a prediction call will produce.
//   handle    - external pointer holding the booster handle; must not be null
//   num_row   - row count, forwarded via Rf_asInteger()
//   is_rawscore, is_leafidx, is_predcontrib - flags combined by GetPredictType()
//   start_iteration, num_iteration          - forwarded via Rf_asInteger()
//   out_len   - integer vector; its first element receives the count
// Always returns R_NilValue.
// The C API reports the count as int64_t while 'out_len' is a 32-bit R integer,
// so a count that does not survive the narrowing is reported as an error
// instead of being silently truncated.
SEXP LGBM_BoosterCalcNumPredict_R(SEXP handle,
  SEXP num_row,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP out_len) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  int64_t len = 0;
  CHECK_CALL(LGBM_BoosterCalcNumPredict(R_ExternalPtrAddr(handle), Rf_asInteger(num_row),
    pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), &len));
  const int len_as_int = static_cast<int>(len);
  // round-trip check: fails when 'len' is outside the range of int
  CHECK_EQ(static_cast<int64_t>(len_as_int), len);
  INTEGER(out_len)[0] = len_as_int;
  return R_NilValue;
  R_API_END();
}

961
// Predicts on a sparse matrix supplied in CSC form.
//   handle               - external pointer holding the booster handle; must not be null
//   indptr, indices      - integer vectors, passed as int32 arrays
//   data                 - double vector, passed as a float64 array
//   num_indptr, nelem, num_row - sizes, read with Rf_asInteger() and stored as int64_t
//   is_rawscore, is_leafidx, is_predcontrib - flags combined by GetPredictType()
//   start_iteration, num_iteration          - forwarded via Rf_asInteger()
//   parameter            - parameter string (converted with Rf_asChar())
//   out_result           - double vector whose storage is used as the output buffer
//                          (presumably pre-sized by the caller)
// Always returns R_NilValue.
SEXP LGBM_BoosterPredictForCSC_R(SEXP handle,
  SEXP indptr,
  SEXP indices,
  SEXP data,
  SEXP num_indptr,
  SEXP nelem,
  SEXP num_row,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  const int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  const int* col_ptr = INTEGER(indptr);
  const int32_t* row_idx = reinterpret_cast<const int32_t*>(INTEGER(indices));
  const double* values = REAL(data);
  const int64_t n_col_ptr = static_cast<int64_t>(Rf_asInteger(num_indptr));
  const int64_t n_values = static_cast<int64_t>(Rf_asInteger(nelem));
  const int64_t n_rows = static_cast<int64_t>(Rf_asInteger(num_row));
  double* result_buf = REAL(out_result);
  // required by the C API signature; the reported length is not used here
  int64_t out_len;
  // the CHARSXP stays protected while its C string is in use
  SEXP params = Rf_protect(Rf_asChar(parameter));
  CHECK_CALL(LGBM_BoosterPredictForCSC(
    R_ExternalPtrAddr(handle),
    col_ptr, C_API_DTYPE_INT32,
    row_idx,
    values, C_API_DTYPE_FLOAT64,
    n_col_ptr, n_values, n_rows,
    pred_type,
    Rf_asInteger(start_iteration),
    Rf_asInteger(num_iteration),
    CHAR(params),
    &out_len,
    result_buf));
  Rf_unprotect(1);
  return R_NilValue;
  R_API_END();
}

996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
// Predicts on a sparse matrix supplied in CSR form.
//   handle          - external pointer holding the booster handle; must not be null
//   indptr, indices - integer vectors, passed as int32 arrays
//   data            - double vector, passed as a float64 array
//   ncols           - column count, forwarded via Rf_asInteger()
//   is_rawscore, is_leafidx, is_predcontrib - flags combined by GetPredictType()
//   start_iteration, num_iteration          - forwarded via Rf_asInteger()
//   parameter       - parameter string (converted with Rf_asChar())
//   out_result      - double vector whose storage is used as the output buffer
//                     (presumably pre-sized by the caller)
// The lengths of 'indptr' and 'data' are taken from the vectors themselves.
// Always returns R_NilValue.
SEXP LGBM_BoosterPredictForCSR_R(SEXP handle,
  SEXP indptr,
  SEXP indices,
  SEXP data,
  SEXP ncols,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  const int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  // the CHARSXP stays protected while its C string is in use
  SEXP params = Rf_protect(Rf_asChar(parameter));
  // required by the C API signature; the reported length is not used here
  int64_t out_len;
  CHECK_CALL(LGBM_BoosterPredictForCSR(
    R_ExternalPtrAddr(handle),
    INTEGER(indptr), C_API_DTYPE_INT32,
    INTEGER(indices),
    REAL(data), C_API_DTYPE_FLOAT64,
    Rf_xlength(indptr),
    Rf_xlength(data),
    Rf_asInteger(ncols),
    pred_type,
    Rf_asInteger(start_iteration),
    Rf_asInteger(num_iteration),
    CHAR(params),
    &out_len,
    REAL(out_result)));
  Rf_unprotect(1);
  return R_NilValue;
  R_API_END();
}

// Predicts on a single sparse row.
//   handle     - external pointer holding the booster handle; must not be null
//   indices    - integer vector, passed as an int32 array
//   data       - double vector, passed as a float64 array; its length is used
//                as the number of non-zero entries
//   ncols      - column count, forwarded via Rf_asInteger()
//   is_rawscore, is_leafidx, is_predcontrib - flags combined by GetPredictType()
//   start_iteration, num_iteration          - forwarded via Rf_asInteger()
//   parameter  - parameter string (converted with Rf_asChar())
//   out_result - double vector whose storage is used as the output buffer
//                (presumably pre-sized by the caller)
// Always returns R_NilValue.
SEXP LGBM_BoosterPredictForCSRSingleRow_R(SEXP handle,
  SEXP indices,
  SEXP data,
  SEXP ncols,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  const int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  // the CHARSXP stays protected while its C string is in use
  SEXP params = Rf_protect(Rf_asChar(parameter));
  const int nnz = static_cast<int>(Rf_xlength(data));
  // row-pointer array built locally for the one row: it spans [0, nnz)
  const int row_ptr[] = {0, nnz};
  // required by the C API signature; the reported length is not used here
  int64_t out_len;
  CHECK_CALL(LGBM_BoosterPredictForCSRSingleRow(
    R_ExternalPtrAddr(handle),
    row_ptr, C_API_DTYPE_INT32,
    INTEGER(indices),
    REAL(data), C_API_DTYPE_FLOAT64,
    2, nnz,
    Rf_asInteger(ncols),
    pred_type,
    Rf_asInteger(start_iteration),
    Rf_asInteger(num_iteration),
    CHAR(params),
    &out_len,
    REAL(out_result)));
  Rf_unprotect(1);
  return R_NilValue;
  R_API_END();
}

// Finalizer for external pointers that hold a FastConfigHandle (registered by
// the *FastInit_R functions in this file via R_RegisterCFinalizerEx()).
// The address stored in the external pointer is the handle itself, so it is
// passed on as a FastConfigHandle rather than as a pointer to one. The
// external pointer is cleared afterwards so that no stale address remains,
// matching the other finalizers in this file.
void LGBM_FastConfigFree_wrapped(SEXP handle) {
  LGBM_FastConfigFree(static_cast<FastConfigHandle>(R_ExternalPtrAddr(handle)));
  R_ClearExternalPtr(handle);
}

// Creates a reusable "fast config" for single-row CSR prediction.
//   handle    - external pointer holding the booster handle; must not be null
//   ncols     - column count, forwarded via Rf_asInteger()
//   is_rawscore, is_leafidx, is_predcontrib - flags combined by GetPredictType()
//   start_iteration, num_iteration          - forwarded via Rf_asInteger()
//   parameter - parameter string (converted with Rf_asChar())
// Returns an external pointer that holds the FastConfigHandle and has
// LGBM_FastConfigFree_wrapped registered as its finalizer.
SEXP LGBM_BoosterPredictForCSRSingleRowFastInit_R(SEXP handle,
  SEXP ncols,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  const int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  // the result object is allocated empty first and filled once the C API call succeeded
  SEXP ret = Rf_protect(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  SEXP params = Rf_protect(Rf_asChar(parameter));
  FastConfigHandle fast_config;
  CHECK_CALL(LGBM_BoosterPredictForCSRSingleRowFastInit(
    R_ExternalPtrAddr(handle),
    pred_type,
    Rf_asInteger(start_iteration),
    Rf_asInteger(num_iteration),
    C_API_DTYPE_FLOAT64,
    Rf_asInteger(ncols),
    CHAR(params),
    &fast_config));
  R_SetExternalPtrAddr(ret, fast_config);
  // the finalizer is attached only after the handle has been stored
  R_RegisterCFinalizerEx(ret, LGBM_FastConfigFree_wrapped, TRUE);
  Rf_unprotect(2);
  return ret;
  R_API_END();
}

// Predicts on a single sparse row using a previously created fast config.
//   handle_fastConfig - external pointer holding the FastConfigHandle
//   indices           - integer vector, passed as an int32 array
//   data              - double vector; its length is used as the number of
//                       non-zero entries
//   out_result        - double vector whose storage is used as the output buffer
//                       (presumably pre-sized by the caller)
// Always returns R_NilValue.
SEXP LGBM_BoosterPredictForCSRSingleRowFast_R(SEXP handle_fastConfig,
  SEXP indices,
  SEXP data,
  SEXP out_result) {
  R_API_BEGIN();
  const int nnz = static_cast<int>(Rf_xlength(data));
  // row-pointer array built locally for the one row: it spans [0, nnz)
  const int row_ptr[] = {0, nnz};
  // required by the C API signature; the reported length is not used here
  int64_t out_len;
  CHECK_CALL(LGBM_BoosterPredictForCSRSingleRowFast(
    R_ExternalPtrAddr(handle_fastConfig),
    row_ptr, C_API_DTYPE_INT32,
    INTEGER(indices),
    REAL(data),
    2, nnz,
    &out_len,
    REAL(out_result)));
  return R_NilValue;
  R_API_END();
}

1099
// Predicts on a dense matrix, passed to the C API as float64 with the COL_MAJOR layout flag.
//   handle           - external pointer holding the booster handle; must not be null
//   data             - double vector/matrix holding the values
//   num_row, num_col - dimensions, read with Rf_asInteger()
//   is_rawscore, is_leafidx, is_predcontrib - flags combined by GetPredictType()
//   start_iteration, num_iteration          - forwarded via Rf_asInteger()
//   parameter        - parameter string (converted with Rf_asChar())
//   out_result       - double vector whose storage is used as the output buffer
//                      (presumably pre-sized by the caller)
// Always returns R_NilValue.
SEXP LGBM_BoosterPredictForMat_R(SEXP handle,
  SEXP data,
  SEXP num_row,
  SEXP num_col,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  const int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  const int32_t n_rows = static_cast<int32_t>(Rf_asInteger(num_row));
  const int32_t n_cols = static_cast<int32_t>(Rf_asInteger(num_col));
  const double* matrix = REAL(data);
  double* result_buf = REAL(out_result);
  // the CHARSXP stays protected while its C string is in use
  SEXP params = Rf_protect(Rf_asChar(parameter));
  // required by the C API signature; the reported length is not used here
  int64_t out_len;
  CHECK_CALL(LGBM_BoosterPredictForMat(
    R_ExternalPtrAddr(handle),
    matrix, C_API_DTYPE_FLOAT64,
    n_rows, n_cols, COL_MAJOR,
    pred_type,
    Rf_asInteger(start_iteration),
    Rf_asInteger(num_iteration),
    CHAR(params),
    &out_len,
    result_buf));
  Rf_unprotect(1);
  return R_NilValue;
  R_API_END();
}

1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
// Bundles the three raw arrays produced by a sparse prediction call so that
// they can be released together (see delete_SparseOutputPointers).
// The struct only stores the pointers; it does not free anything itself.
struct SparseOutputPointers {
  void* indptr;
  int32_t* indices;
  void* data;

  SparseOutputPointers(void* indptr_in, int32_t* indices_in, void* data_in)
      : indptr(indptr_in),
        indices(indices_in),
        data(data_in) {}
};

// Deleter for a heap-allocated SparseOutputPointers: hands the three stored
// arrays to LGBM_BoosterFreePredictSparse() (declared as int32 indptr and
// float64 data) and then destroys the struct itself.
// 'ptr' is dereferenced unconditionally, so it must not be null.
void delete_SparseOutputPointers(SparseOutputPointers *ptr) {
  LGBM_BoosterFreePredictSparse(
    ptr->indptr,
    ptr->indices,
    ptr->data,
    C_API_DTYPE_INT32,
    C_API_DTYPE_FLOAT64);
  delete ptr;
}

// Computes C_API_PREDICT_CONTRIB predictions for a sparse input and returns
// them in sparse form.
//   handle          - external pointer holding the booster handle; must not be null
//   indptr, indices - integer vectors, passed as int32 arrays
//   data            - double vector, passed as a float64 array
//   is_csr          - selects CSR (true) or CSC (false) as the matrix type
//   nrows, ncols    - 'ncols' is forwarded for CSR input, 'nrows' for CSC input
//   start_iteration, num_iteration - forwarded via Rf_asInteger()
//   parameter       - parameter string (converted with Rf_asChar())
// Returns a named list with elements "indptr", "indices" and "data".
SEXP LGBM_BoosterPredictSparseOutput_R(SEXP handle,
  SEXP indptr,
  SEXP indices,
  SEXP data,
  SEXP is_csr,
  SEXP nrows,
  SEXP ncols,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter) {
  SEXP cont_token = Rf_protect(R_MakeUnwindCont());
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  const char* out_names[] = {"indptr", "indices", "data", ""};
  SEXP out = Rf_protect(Rf_mkNamed(VECSXP, out_names));
  SEXP params = Rf_protect(Rf_asChar(parameter));

  const bool input_is_csr = Rf_asLogical(is_csr) != 0;
  const int num_col_or_row = input_is_csr ? Rf_asInteger(ncols) : Rf_asInteger(nrows);
  const int matrix_type = input_is_csr ? C_API_MATRIX_TYPE_CSR : C_API_MATRIX_TYPE_CSC;

  // result_len[0] is used below as the length of 'indices' and 'data',
  // result_len[1] as the length of 'indptr'
  int64_t result_len[2];
  void *res_indptr;
  int32_t *res_indices;
  void *res_data;

  CHECK_CALL(LGBM_BoosterPredictSparseOutput(
    R_ExternalPtrAddr(handle),
    INTEGER(indptr), C_API_DTYPE_INT32,
    INTEGER(indices),
    REAL(data), C_API_DTYPE_FLOAT64,
    Rf_xlength(indptr),
    Rf_xlength(data),
    num_col_or_row,
    C_API_PREDICT_CONTRIB,
    Rf_asInteger(start_iteration),
    Rf_asInteger(num_iteration),
    CHAR(params),
    matrix_type,
    result_len, &res_indptr, &res_indices, &res_data));

  // Guard that releases whatever has not yet been handed over to an R object
  // if one of the conversions below fails part-way through.
  std::unique_ptr<SparseOutputPointers, decltype(&delete_SparseOutputPointers)> guard(
    new SparseOutputPointers(res_indptr, res_indices, res_data),
    &delete_SparseOutputPointers);

  // Each array is wrapped into an R vector; once the wrapping succeeded, the
  // matching entry in the guard is reset so the guard no longer frees it.
  arr_and_len<int> indptr_desc{static_cast<int*>(res_indptr), result_len[1]};
  SEXP r_indptr = R_UnwindProtect(make_altrepped_vec_from_arr<int>,
    static_cast<void*>(&indptr_desc), throw_R_memerr, &cont_token, cont_token);
  SET_VECTOR_ELT(out, 0, r_indptr);
  guard->indptr = nullptr;

  arr_and_len<int> indices_desc{static_cast<int*>(res_indices), result_len[0]};
  SEXP r_indices = R_UnwindProtect(make_altrepped_vec_from_arr<int>,
    static_cast<void*>(&indices_desc), throw_R_memerr, &cont_token, cont_token);
  SET_VECTOR_ELT(out, 1, r_indices);
  guard->indices = nullptr;

  arr_and_len<double> data_desc{static_cast<double*>(res_data), result_len[0]};
  SEXP r_data = R_UnwindProtect(make_altrepped_vec_from_arr<double>,
    static_cast<void*>(&data_desc), throw_R_memerr, &cont_token, cont_token);
  SET_VECTOR_ELT(out, 2, r_data);
  guard->data = nullptr;

  // releases cont_token, out and the parameter CHARSXP
  Rf_unprotect(3);
  return out;
  R_API_END();
}

1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
// Predicts on a single dense row.
//   handle     - external pointer holding the booster handle; must not be null
//   data       - double vector with the row values; its length is forwarded
//                as the column count
//   is_rawscore, is_leafidx, is_predcontrib - flags combined by GetPredictType()
//   start_iteration, num_iteration          - forwarded via Rf_asInteger()
//   parameter  - parameter string (converted with Rf_asChar())
//   out_result - double vector whose storage is used as the output buffer
//                (presumably pre-sized by the caller)
// Always returns R_NilValue.
SEXP LGBM_BoosterPredictForMatSingleRow_R(SEXP handle,
  SEXP data,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  const int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  // the CHARSXP stays protected while its C string is in use
  SEXP params = Rf_protect(Rf_asChar(parameter));
  double* result_buf = REAL(out_result);
  // required by the C API signature; the reported length is not used here
  int64_t out_len;
  CHECK_CALL(LGBM_BoosterPredictForMatSingleRow(
    R_ExternalPtrAddr(handle),
    REAL(data), C_API_DTYPE_FLOAT64,
    Rf_xlength(data),
    1,
    pred_type,
    Rf_asInteger(start_iteration),
    Rf_asInteger(num_iteration),
    CHAR(params),
    &out_len,
    result_buf));
  Rf_unprotect(1);
  return R_NilValue;
  R_API_END();
}

// Creates a reusable "fast config" for single-row dense prediction.
//   handle    - external pointer holding the booster handle; must not be null
//   ncols     - column count, forwarded via Rf_asInteger()
//   is_rawscore, is_leafidx, is_predcontrib - flags combined by GetPredictType()
//   start_iteration, num_iteration          - forwarded via Rf_asInteger()
//   parameter - parameter string (converted with Rf_asChar())
// Returns an external pointer that holds the FastConfigHandle and has
// LGBM_FastConfigFree_wrapped registered as its finalizer.
SEXP LGBM_BoosterPredictForMatSingleRowFastInit_R(SEXP handle,
  SEXP ncols,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  const int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  // the result object is allocated empty first and filled once the C API call succeeded
  SEXP ret = Rf_protect(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  SEXP params = Rf_protect(Rf_asChar(parameter));
  FastConfigHandle fast_config;
  CHECK_CALL(LGBM_BoosterPredictForMatSingleRowFastInit(
    R_ExternalPtrAddr(handle),
    pred_type,
    Rf_asInteger(start_iteration),
    Rf_asInteger(num_iteration),
    C_API_DTYPE_FLOAT64,
    Rf_asInteger(ncols),
    CHAR(params),
    &fast_config));
  R_SetExternalPtrAddr(ret, fast_config);
  // the finalizer is attached only after the handle has been stored
  R_RegisterCFinalizerEx(ret, LGBM_FastConfigFree_wrapped, TRUE);
  Rf_unprotect(2);
  return ret;
  R_API_END();
}

// Predicts on a single dense row using a previously created fast config.
//   handle_fastConfig - external pointer holding the FastConfigHandle
//   data              - double vector with the row values
//   out_result        - double vector whose storage is used as the output buffer
//                       (presumably pre-sized by the caller)
// Always returns R_NilValue.
SEXP LGBM_BoosterPredictForMatSingleRowFast_R(SEXP handle_fastConfig,
  SEXP data,
  SEXP out_result) {
  R_API_BEGIN();
  // required by the C API signature; the reported length is not used here
  int64_t out_len;
  CHECK_CALL(LGBM_BoosterPredictForMatSingleRowFast(
    R_ExternalPtrAddr(handle_fastConfig),
    REAL(data),
    &out_len,
    REAL(out_result)));
  return R_NilValue;
  R_API_END();
}

1266
// Saves a booster to a file through LGBM_BoosterSaveModel().
//   handle                  - external pointer holding the booster handle; must not be null
//   num_iteration           - forwarded via Rf_asInteger()
//   feature_importance_type - forwarded via Rf_asInteger()
//   filename                - target path (converted with Rf_asChar())
//   start_iteration         - forwarded via Rf_asInteger()
// Always returns R_NilValue.
SEXP LGBM_BoosterSaveModel_R(SEXP handle,
  SEXP num_iteration,
  SEXP feature_importance_type,
  SEXP filename,
  SEXP start_iteration) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  // the CHARSXP stays protected while its C string is in use
  SEXP target_path = Rf_protect(Rf_asChar(filename));
  CHECK_CALL(LGBM_BoosterSaveModel(
    R_ExternalPtrAddr(handle),
    Rf_asInteger(start_iteration),
    Rf_asInteger(num_iteration),
    Rf_asInteger(feature_importance_type),
    CHAR(target_path)));
  Rf_unprotect(1);
  return R_NilValue;
  R_API_END();
}

1280
1281
1282
1283
1284
1285
1286
1287
// Note: for some reason, MSVC crashes when an error is thrown here
// if the buffer variable is defined as 'std::unique_ptr<std::vector<char>>',
// but not if it is defined as 'std::vector<char>'.
#ifndef _MSC_VER
// Serializes a booster and returns the bytes as an R object created by
// make_altrepped_raw_vec().
//   handle                  - external pointer holding the booster handle; must not be null
//   num_iteration           - read with Rf_asInteger()
//   feature_importance_type - read with Rf_asInteger()
//   start_iteration         - read with Rf_asInteger()
SEXP LGBM_BoosterSaveModelToString_R(SEXP handle,
  SEXP num_iteration,
  SEXP feature_importance_type,
  SEXP start_iteration) {
  SEXP cont_token = Rf_protect(R_MakeUnwindCont());
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int64_t out_len = 0;
  // initial guess for the buffer size: 1 MiB
  int64_t buf_len = 1024 * 1024;
  const int last_iter = Rf_asInteger(num_iteration);
  const int first_iter = Rf_asInteger(start_iteration);
  const int imp_type = Rf_asInteger(feature_importance_type);
  std::unique_ptr<std::vector<char>> model_buf(new std::vector<char>(buf_len));
  CHECK_CALL(LGBM_BoosterSaveModelToString(
    R_ExternalPtrAddr(handle), first_iter, last_iter, imp_type,
    buf_len, &out_len, model_buf->data()));
  // adjust the buffer to the length reported by the first call ...
  model_buf->resize(out_len);
  // ... and repeat the call if that length exceeded the initial buffer
  if (out_len > buf_len) {
    CHECK_CALL(LGBM_BoosterSaveModelToString(
      R_ExternalPtrAddr(handle), first_iter, last_iter, imp_type,
      out_len, &out_len, model_buf->data()));
  }
  SEXP out = R_UnwindProtect(make_altrepped_raw_vec, &model_buf, throw_R_memerr, &cont_token, cont_token);
  Rf_unprotect(1);
  return out;
  R_API_END();
}
#else
// MSVC variant: same interface, but the model bytes end up in a raw vector
// allocated through safe_R_raw().
SEXP LGBM_BoosterSaveModelToString_R(SEXP handle,
  SEXP num_iteration,
  SEXP feature_importance_type,
  SEXP start_iteration) {
  SEXP cont_token = Rf_protect(R_MakeUnwindCont());
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int64_t out_len = 0;
  // initial guess for the buffer size: 1 MiB
  int64_t buf_len = 1024 * 1024;
  const int last_iter = Rf_asInteger(num_iteration);
  const int first_iter = Rf_asInteger(start_iteration);
  const int imp_type = Rf_asInteger(feature_importance_type);
  std::vector<char> scratch_buf(buf_len);
  CHECK_CALL(LGBM_BoosterSaveModelToString(
    R_ExternalPtrAddr(handle), first_iter, last_iter, imp_type,
    buf_len, &out_len, scratch_buf.data()));
  SEXP model_str = Rf_protect(safe_R_raw(out_len, &cont_token));
  char* dest = reinterpret_cast<char*>(RAW(model_str));
  if (out_len > buf_len) {
    // the scratch buffer was too small: call again, writing straight into the R object
    CHECK_CALL(LGBM_BoosterSaveModelToString(
      R_ExternalPtrAddr(handle), first_iter, last_iter, imp_type,
      out_len, &out_len, dest));
  } else {
    // the first call already produced everything: copy it over
    std::copy_n(scratch_buf.begin(), out_len, dest);
  }
  Rf_unprotect(2);
  return model_str;
  R_API_END();
}
#endif
1334

1335
// Dumps a booster through LGBM_BoosterDumpModel() and returns the text as a
// length-one character vector.
//   handle                  - external pointer holding the booster handle; must not be null
//   num_iteration           - read with Rf_asInteger()
//   feature_importance_type - read with Rf_asInteger()
//   start_iteration         - read with Rf_asInteger()
SEXP LGBM_BoosterDumpModel_R(SEXP handle,
  SEXP num_iteration,
  SEXP feature_importance_type,
  SEXP start_iteration) {
  SEXP cont_token = Rf_protect(R_MakeUnwindCont());
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int64_t out_len = 0;
  // initial guess for the buffer size: 1 MiB
  int64_t buf_len = 1024 * 1024;
  const int last_iter = Rf_asInteger(num_iteration);
  const int first_iter = Rf_asInteger(start_iteration);
  const int imp_type = Rf_asInteger(feature_importance_type);
  std::vector<char> dump_buf(buf_len);
  CHECK_CALL(LGBM_BoosterDumpModel(
    R_ExternalPtrAddr(handle), first_iter, last_iter, imp_type,
    buf_len, &out_len, dump_buf.data()));
  // the reported length exceeded the initial buffer: grow it and call again
  if (out_len > buf_len) {
    dump_buf.resize(out_len);
    CHECK_CALL(LGBM_BoosterDumpModel(
      R_ExternalPtrAddr(handle), first_iter, last_iter, imp_type,
      out_len, &out_len, dump_buf.data()));
  }
  SEXP model_str = Rf_protect(safe_R_string(static_cast<R_xlen_t>(1), &cont_token));
  SET_STRING_ELT(model_str, 0, safe_R_mkChar(dump_buf.data(), &cont_token));
  Rf_unprotect(2);
  return model_str;
  R_API_END();
}
1361

1362
// Returns the text produced by LGBM_DumpParamAliases() as a length-one
// character vector.
SEXP LGBM_DumpParamAliases_R() {
  SEXP cont_token = Rf_protect(R_MakeUnwindCont());
  R_API_BEGIN();
  int64_t out_len = 0;
  // initial guess for the buffer size: 1 MiB
  int64_t buf_len = 1024 * 1024;
  std::vector<char> alias_buf(buf_len);
  CHECK_CALL(LGBM_DumpParamAliases(buf_len, &out_len, alias_buf.data()));
  // the reported length exceeded the initial buffer: grow it and call again
  if (out_len > buf_len) {
    alias_buf.resize(out_len);
    CHECK_CALL(LGBM_DumpParamAliases(out_len, &out_len, alias_buf.data()));
  }
  SEXP aliases_str = Rf_protect(safe_R_string(static_cast<R_xlen_t>(1), &cont_token));
  SET_STRING_ELT(aliases_str, 0, safe_R_mkChar(alias_buf.data(), &cont_token));
  Rf_unprotect(2);
  return aliases_str;
  R_API_END();
}

1382
// Returns the text produced by LGBM_BoosterGetLoadedParam() for a booster as
// a length-one character vector.
//   handle - external pointer holding the booster handle; must not be null
SEXP LGBM_BoosterGetLoadedParam_R(SEXP handle) {
  SEXP cont_token = Rf_protect(R_MakeUnwindCont());
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int64_t out_len = 0;
  // initial guess for the buffer size: 1 MiB
  int64_t buf_len = 1024 * 1024;
  std::vector<char> param_buf(buf_len);
  CHECK_CALL(LGBM_BoosterGetLoadedParam(
    R_ExternalPtrAddr(handle), buf_len, &out_len, param_buf.data()));
  // the params string was larger than the initial buffer: grow it and call again
  if (out_len > buf_len) {
    param_buf.resize(out_len);
    CHECK_CALL(LGBM_BoosterGetLoadedParam(
      R_ExternalPtrAddr(handle), out_len, &out_len, param_buf.data()));
  }
  SEXP params_str = Rf_protect(safe_R_string(static_cast<R_xlen_t>(1), &cont_token));
  SET_STRING_ELT(params_str, 0, safe_R_mkChar(param_buf.data(), &cont_token));
  Rf_unprotect(2);
  return params_str;
  R_API_END();
}

1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
// Stores the value reported by LGBM_GetMaxThreads() in the first element of
// the integer vector 'out'. Always returns R_NilValue.
SEXP LGBM_GetMaxThreads_R(SEXP out) {
  R_API_BEGIN();
  int max_threads;
  CHECK_CALL(LGBM_GetMaxThreads(&max_threads));
  int* out_slot = INTEGER(out);
  out_slot[0] = max_threads;
  return R_NilValue;
  R_API_END();
}

// Converts 'num_threads' with Rf_asInteger() and passes it to
// LGBM_SetMaxThreads(). Always returns R_NilValue.
SEXP LGBM_SetMaxThreads_R(SEXP num_threads) {
  R_API_BEGIN();
  CHECK_CALL(LGBM_SetMaxThreads(Rf_asInteger(num_threads)));
  return R_NilValue;
  R_API_END();
}

1420
1421
// .Call() calls
// Each row is {name string, function pointer, number of arguments}.
// LGBM_R_CALL_ENTRY derives the name string from the function identifier so
// the two cannot drift apart; the table ends with an all-null sentinel row.
#define LGBM_R_CALL_ENTRY(fn, nargs) {#fn, reinterpret_cast<DL_FUNC>(&fn), nargs}
static const R_CallMethodDef CallEntries[] = {
  LGBM_R_CALL_ENTRY(LGBM_HandleIsNull_R, 1),
  LGBM_R_CALL_ENTRY(LGBM_DatasetCreateFromFile_R, 3),
  LGBM_R_CALL_ENTRY(LGBM_DatasetCreateFromCSC_R, 8),
  LGBM_R_CALL_ENTRY(LGBM_DatasetCreateFromMat_R, 5),
  LGBM_R_CALL_ENTRY(LGBM_DatasetGetSubset_R, 4),
  LGBM_R_CALL_ENTRY(LGBM_DatasetSetFeatureNames_R, 2),
  LGBM_R_CALL_ENTRY(LGBM_DatasetGetFeatureNames_R, 1),
  LGBM_R_CALL_ENTRY(LGBM_DatasetSaveBinary_R, 2),
  LGBM_R_CALL_ENTRY(LGBM_DatasetFree_R, 1),
  LGBM_R_CALL_ENTRY(LGBM_DatasetSetField_R, 4),
  LGBM_R_CALL_ENTRY(LGBM_DatasetGetFieldSize_R, 3),
  LGBM_R_CALL_ENTRY(LGBM_DatasetGetField_R, 3),
  LGBM_R_CALL_ENTRY(LGBM_DatasetUpdateParamChecking_R, 2),
  LGBM_R_CALL_ENTRY(LGBM_DatasetGetNumData_R, 2),
  LGBM_R_CALL_ENTRY(LGBM_DatasetGetNumFeature_R, 2),
  LGBM_R_CALL_ENTRY(LGBM_DatasetGetFeatureNumBin_R, 3),
  LGBM_R_CALL_ENTRY(LGBM_BoosterCreate_R, 2),
  LGBM_R_CALL_ENTRY(LGBM_BoosterFree_R, 1),
  LGBM_R_CALL_ENTRY(LGBM_BoosterCreateFromModelfile_R, 1),
  LGBM_R_CALL_ENTRY(LGBM_BoosterLoadModelFromString_R, 1),
  LGBM_R_CALL_ENTRY(LGBM_BoosterMerge_R, 2),
  LGBM_R_CALL_ENTRY(LGBM_BoosterAddValidData_R, 2),
  LGBM_R_CALL_ENTRY(LGBM_BoosterResetTrainingData_R, 2),
  LGBM_R_CALL_ENTRY(LGBM_BoosterResetParameter_R, 2),
  LGBM_R_CALL_ENTRY(LGBM_BoosterGetNumClasses_R, 2),
  LGBM_R_CALL_ENTRY(LGBM_BoosterGetNumFeature_R, 1),
  LGBM_R_CALL_ENTRY(LGBM_BoosterGetLoadedParam_R, 1),
  LGBM_R_CALL_ENTRY(LGBM_BoosterUpdateOneIter_R, 1),
  LGBM_R_CALL_ENTRY(LGBM_BoosterUpdateOneIterCustom_R, 4),
  LGBM_R_CALL_ENTRY(LGBM_BoosterRollbackOneIter_R, 1),
  LGBM_R_CALL_ENTRY(LGBM_BoosterGetCurrentIteration_R, 2),
  LGBM_R_CALL_ENTRY(LGBM_BoosterNumModelPerIteration_R, 2),
  LGBM_R_CALL_ENTRY(LGBM_BoosterNumberOfTotalModel_R, 2),
  LGBM_R_CALL_ENTRY(LGBM_BoosterGetUpperBoundValue_R, 2),
  LGBM_R_CALL_ENTRY(LGBM_BoosterGetLowerBoundValue_R, 2),
  LGBM_R_CALL_ENTRY(LGBM_BoosterGetEvalNames_R, 1),
  LGBM_R_CALL_ENTRY(LGBM_BoosterGetEval_R, 3),
  LGBM_R_CALL_ENTRY(LGBM_BoosterGetNumPredict_R, 3),
  LGBM_R_CALL_ENTRY(LGBM_BoosterGetPredict_R, 3),
  LGBM_R_CALL_ENTRY(LGBM_BoosterPredictForFile_R, 10),
  LGBM_R_CALL_ENTRY(LGBM_BoosterCalcNumPredict_R, 8),
  LGBM_R_CALL_ENTRY(LGBM_BoosterPredictForCSC_R, 14),
  LGBM_R_CALL_ENTRY(LGBM_BoosterPredictForCSR_R, 12),
  LGBM_R_CALL_ENTRY(LGBM_BoosterPredictForCSRSingleRow_R, 11),
  LGBM_R_CALL_ENTRY(LGBM_BoosterPredictForCSRSingleRowFastInit_R, 8),
  LGBM_R_CALL_ENTRY(LGBM_BoosterPredictForCSRSingleRowFast_R, 4),
  LGBM_R_CALL_ENTRY(LGBM_BoosterPredictSparseOutput_R, 10),
  LGBM_R_CALL_ENTRY(LGBM_BoosterPredictForMat_R, 11),
  LGBM_R_CALL_ENTRY(LGBM_BoosterPredictForMatSingleRow_R, 9),
  LGBM_R_CALL_ENTRY(LGBM_BoosterPredictForMatSingleRowFastInit_R, 8),
  LGBM_R_CALL_ENTRY(LGBM_BoosterPredictForMatSingleRowFast_R, 3),
  LGBM_R_CALL_ENTRY(LGBM_BoosterSaveModel_R, 5),
  LGBM_R_CALL_ENTRY(LGBM_BoosterSaveModelToString_R, 4),
  LGBM_R_CALL_ENTRY(LGBM_BoosterDumpModel_R, 4),
  LGBM_R_CALL_ENTRY(LGBM_NullBoosterHandleError_R, 0),
  LGBM_R_CALL_ENTRY(LGBM_DumpParamAliases_R, 0),
  LGBM_R_CALL_ENTRY(LGBM_GetMaxThreads_R, 1),
  LGBM_R_CALL_ENTRY(LGBM_SetMaxThreads_R, 1),
  {nullptr, nullptr, 0}
};
#undef LGBM_R_CALL_ENTRY

1483
1484
LIGHTGBM_C_EXPORT void R_init_lightgbm(DllInfo *dll);

1485
1486
1487
void R_init_lightgbm(DllInfo *dll) {
  R_registerRoutines(dll, NULL, CallEntries, NULL, NULL);
  R_useDynamicSymbols(dll, FALSE);
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504

#ifndef LGB_NO_ALTREP
  lgb_altrepped_char_vec = R_make_altraw_class("lgb_altrepped_char_vec", "lightgbm", dll);
  R_set_altrep_Length_method(lgb_altrepped_char_vec, get_altrepped_raw_len);
  R_set_altvec_Dataptr_method(lgb_altrepped_char_vec, get_altrepped_raw_dataptr);
  R_set_altvec_Dataptr_or_null_method(lgb_altrepped_char_vec, get_altrepped_raw_dataptr_or_null);

  lgb_altrepped_int_arr = R_make_altinteger_class("lgb_altrepped_int_arr", "lightgbm", dll);
  R_set_altrep_Length_method(lgb_altrepped_int_arr, get_altrepped_vec_len);
  R_set_altvec_Dataptr_method(lgb_altrepped_int_arr, get_altrepped_vec_dataptr);
  R_set_altvec_Dataptr_or_null_method(lgb_altrepped_int_arr, get_altrepped_vec_dataptr_or_null);

  lgb_altrepped_dbl_arr = R_make_altreal_class("lgb_altrepped_dbl_arr", "lightgbm", dll);
  R_set_altrep_Length_method(lgb_altrepped_dbl_arr, get_altrepped_vec_len);
  R_set_altvec_Dataptr_method(lgb_altrepped_dbl_arr, get_altrepped_vec_dataptr);
  R_set_altvec_Dataptr_or_null_method(lgb_altrepped_dbl_arr, get_altrepped_vec_dataptr_or_null);
#endif
1505
}