lightgbm_R.cpp 44.9 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2017 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
5
6

#include "lightgbm_R.h"
Guolin Ke's avatar
Guolin Ke committed
7

8
9
10
11
12
13
14
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>
#include <LightGBM/utils/text_reader.h>

#include <R_ext/Rdynload.h>

15
16
17
18
#define R_NO_REMAP
#define R_USE_C99_IN_CXX
#include <R_ext/Error.h>

19
20
#include <string>
#include <cstdio>
21
#include <cstdlib>
22
23
24
25
#include <cstring>
#include <memory>
#include <utility>
#include <vector>
26
#include <algorithm>
27

Guolin Ke's avatar
Guolin Ke committed
28
29
#define COL_MAJOR (0)

30
31
32
33
34
35
#define MAX_LENGTH_ERR_MSG 1024
char R_errmsg_buffer[MAX_LENGTH_ERR_MSG];
struct LGBM_R_ErrorClass { SEXP cont_token; };
void LGBM_R_save_exception_msg(const std::exception &err);
void LGBM_R_save_exception_msg(const std::string &err);

Guolin Ke's avatar
Guolin Ke committed
36
37
38
// R_API_BEGIN() / R_API_END() bracket the body of every exported function.
// R_API_BEGIN() opens a try block; R_API_END() closes it and translates
// whatever was thrown:
//   - LGBM_R_ErrorClass  -> resumes the pending R unwind via R_ContinueUnwind()
//   - std::exception / std::string -> message copied into R_errmsg_buffer
//   - anything else      -> "unknown exception" copied into R_errmsg_buffer
// After the handlers have finished, the buffered message is raised as an R
// error. The buffered cases deliberately call Rf_error() outside the catch
// handlers so that the jump it performs does not leave a handler mid-flight.
#define R_API_BEGIN() \
  try {
#define R_API_END() } \
  catch(LGBM_R_ErrorClass &cont) { R_ContinueUnwind(cont.cont_token); } \
  catch(std::exception& ex) { LGBM_R_save_exception_msg(ex); } \
  catch(std::string& ex) { LGBM_R_save_exception_msg(ex); } \
  catch(...) { std::snprintf(R_errmsg_buffer, MAX_LENGTH_ERR_MSG, "unknown exception"); } \
  Rf_error("%s", R_errmsg_buffer); \
  return R_NilValue; /* <- won't be reached */
Guolin Ke's avatar
Guolin Ke committed
45
46
47

// Evaluate a LightGBM C API call and throw std::runtime_error carrying
// LGBM_GetLastError() when it returns a non-zero status.
// Wrapped in do { } while (0) so the macro expands to a single statement and
// cannot capture a following `else`.
#define CHECK_CALL(x) \
  do { \
    if ((x) != 0) { \
      throw std::runtime_error(LGBM_GetLastError()); \
    } \
  } while (0)

51
52
53
54
55
56
57
58
59
60
61
62
63
64
// These are helper functions to allow doing a stack unwind
// after an R allocation error, which would trigger a long jump.
// Copy the what() text of a caught exception, followed by a newline, into the
// file-level R_errmsg_buffer (truncated to MAX_LENGTH_ERR_MSG bytes).
void LGBM_R_save_exception_msg(const std::exception &err) {
  const char* what_text = err.what();
  std::snprintf(R_errmsg_buffer, MAX_LENGTH_ERR_MSG, "%s\n", what_text);
}

// Copy a thrown std::string message, followed by a newline, into the
// file-level R_errmsg_buffer (truncated to MAX_LENGTH_ERR_MSG bytes).
void LGBM_R_save_exception_msg(const std::string &err) {
  const char* msg_text = err.c_str();
  std::snprintf(R_errmsg_buffer, MAX_LENGTH_ERR_MSG, "%s\n", msg_text);
}

// Callback for R_UnwindProtect: allocate a character vector (STRSXP).
// `len` points at an R_xlen_t holding the requested length.
SEXP wrapped_R_string(void *len) {
  const R_xlen_t n_elements = *static_cast<R_xlen_t*>(len);
  return Rf_allocVector(STRSXP, n_elements);
}

65
66
67
68
// Callback for R_UnwindProtect: allocate a raw vector (RAWSXP).
// `len` points at an R_xlen_t holding the requested length.
SEXP wrapped_R_raw(void *len) {
  const R_xlen_t n_elements = *static_cast<R_xlen_t*>(len);
  return Rf_allocVector(RAWSXP, n_elements);
}

69
70
71
72
73
74
75
76
// Callback for R_UnwindProtect: allocate an integer vector (INTSXP).
// `len` points at an R_xlen_t holding the requested length.
SEXP wrapped_R_int(void *len) {
  const R_xlen_t n_elements = *static_cast<R_xlen_t*>(len);
  return Rf_allocVector(INTSXP, n_elements);
}

// Callback for R_UnwindProtect: allocate a double vector (REALSXP).
// `len` points at an R_xlen_t holding the requested length.
SEXP wrapped_R_real(void *len) {
  const R_xlen_t n_elements = *static_cast<R_xlen_t*>(len);
  return Rf_allocVector(REALSXP, n_elements);
}

77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
// Callback for R_UnwindProtect: build a CHARSXP from the C string that
// `txt` points to.
SEXP wrapped_Rf_mkChar(void *txt) {
  const char* c_string = static_cast<char*>(txt);
  return Rf_mkChar(c_string);
}

void throw_R_memerr(void *ptr_cont_token, Rboolean jump) {
  if (jump) {
    LGBM_R_ErrorClass err{*(reinterpret_cast<SEXP*>(ptr_cont_token))};
    throw err;
  }
}

// Allocate a character vector of length `len` through R_UnwindProtect, with
// throw_R_memerr as the cleanup callback and *cont_token as the continuation
// token.
SEXP safe_R_string(R_xlen_t len, SEXP *cont_token) {
  void* len_arg = static_cast<void*>(&len);
  return R_UnwindProtect(wrapped_R_string, len_arg, throw_R_memerr, cont_token, *cont_token);
}

92
93
94
95
// Allocate a raw vector of length `len` through R_UnwindProtect, with
// throw_R_memerr as the cleanup callback and *cont_token as the continuation
// token.
SEXP safe_R_raw(R_xlen_t len, SEXP *cont_token) {
  void* len_arg = static_cast<void*>(&len);
  return R_UnwindProtect(wrapped_R_raw, len_arg, throw_R_memerr, cont_token, *cont_token);
}

96
97
98
99
100
101
102
103
// Allocate an integer vector of length `len` through R_UnwindProtect, with
// throw_R_memerr as the cleanup callback and *cont_token as the continuation
// token.
SEXP safe_R_int(R_xlen_t len, SEXP *cont_token) {
  void* len_arg = static_cast<void*>(&len);
  return R_UnwindProtect(wrapped_R_int, len_arg, throw_R_memerr, cont_token, *cont_token);
}

// Allocate a double vector of length `len` through R_UnwindProtect, with
// throw_R_memerr as the cleanup callback and *cont_token as the continuation
// token.
SEXP safe_R_real(R_xlen_t len, SEXP *cont_token) {
  void* len_arg = static_cast<void*>(&len);
  return R_UnwindProtect(wrapped_R_real, len_arg, throw_R_memerr, cont_token, *cont_token);
}

104
105
106
107
// Build a CHARSXP from `txt` through R_UnwindProtect, with throw_R_memerr as
// the cleanup callback and *cont_token as the continuation token.
SEXP safe_R_mkChar(char *txt, SEXP *cont_token) {
  void* txt_arg = static_cast<void*>(txt);
  return R_UnwindProtect(wrapped_Rf_mkChar, txt_arg, throw_R_memerr, cont_token, *cont_token);
}

108
109
using LightGBM::Common::Split;
using LightGBM::Log;
Guolin Ke's avatar
Guolin Ke committed
110

111
112
113
114
// Return an R logical scalar that is TRUE when the external pointer `handle`
// holds a null address.
SEXP LGBM_HandleIsNull_R(SEXP handle) {
  const bool address_is_null = (R_ExternalPtrAddr(handle) == nullptr);
  return Rf_ScalarLogical(address_is_null);
}

115
116
117
118
// Finalizer registered on Dataset external pointers; delegates to
// LGBM_DatasetFree_R and discards its return value.
void _DatasetFinalizer(SEXP handle) {
  static_cast<void>(LGBM_DatasetFree_R(handle));
}

119
120
121
122
123
124
125
126
// Raise the R error reported when a Booster handle is null.
// Rf_error() does not return normally; the trailing return only satisfies the
// SEXP return type.
SEXP LGBM_NullBoosterHandleError_R() {
  Rf_error(
      "Attempting to use a Booster which no longer exists and/or cannot be restored. "
      "This can happen if you have called Booster$finalize() "
      "or if this Booster was saved through saveRDS() using 'serializable=FALSE'.");
  return R_NilValue;  // not reached
}

127
128
// Raise an R error (via LGBM_NullBoosterHandleError_R) unless `handle` is a
// non-NULL R object whose external pointer address is non-null.
void _AssertBoosterHandleNotNull(SEXP handle) {
  // The address is only inspected when `handle` itself is not R's NULL.
  const bool handle_is_usable = !Rf_isNull(handle) && R_ExternalPtrAddr(handle) != nullptr;
  if (!handle_is_usable) {
    LGBM_NullBoosterHandleError_R();
  }
}

// Raise an R error unless `handle` is a non-NULL R object whose external
// pointer address is non-null.
void _AssertDatasetHandleNotNull(SEXP handle) {
  // The address is only inspected when `handle` itself is not R's NULL.
  const bool handle_is_usable = !Rf_isNull(handle) && R_ExternalPtrAddr(handle) != nullptr;
  if (handle_is_usable) {
    return;
  }
  Rf_error(
    "Attempting to use a Dataset which no longer exists. "
    "This can happen if you have called Dataset$finalize() or if this Dataset was saved with saveRDS(). "
    "To avoid this error in the future, use lgb.Dataset.save() or Dataset$save_binary() to save lightgbm Datasets.");
}

142
143
// Create a Dataset from a file and return it as an R external pointer.
//   filename   : coerced with Rf_asChar, path handed to the C API
//   parameters : coerced with Rf_asChar, parameter string
//   reference  : R NULL, or an external pointer to a reference Dataset
// The returned pointer has _DatasetFinalizer registered on it.
SEXP LGBM_DatasetCreateFromFile_R(SEXP filename,
  SEXP parameters,
  SEXP reference) {
  R_API_BEGIN();
  SEXP out_ptr = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  DatasetHandle ref_dataset = Rf_isNull(reference) ? nullptr : R_ExternalPtrAddr(reference);
  SEXP filename_chr = PROTECT(Rf_asChar(filename));
  SEXP parameters_chr = PROTECT(Rf_asChar(parameters));
  DatasetHandle new_dataset = nullptr;
  CHECK_CALL(LGBM_DatasetCreateFromFile(CHAR(filename_chr), CHAR(parameters_chr), ref_dataset, &new_dataset));
  R_SetExternalPtrAddr(out_ptr, new_dataset);
  R_RegisterCFinalizerEx(out_ptr, _DatasetFinalizer, TRUE);
  UNPROTECT(3);  // out_ptr, filename_chr, parameters_chr
  return out_ptr;
  R_API_END();
}

162
163
164
// Create a Dataset from CSC sparse-matrix components and return it as an R
// external pointer.
//   indptr, indices : integer vectors, passed as int32 arrays
//   data            : double vector, passed as a float64 array
//   num_indptr, nelem, num_row : scalars read with Rf_asInteger
//   parameters      : coerced with Rf_asChar, parameter string
//   reference       : R NULL, or an external pointer to a reference Dataset
// The returned pointer has _DatasetFinalizer registered on it.
SEXP LGBM_DatasetCreateFromCSC_R(SEXP indptr,
  SEXP indices,
  SEXP data,
  SEXP num_indptr,
  SEXP nelem,
  SEXP num_row,
  SEXP parameters,
  SEXP reference) {
  R_API_BEGIN();
  SEXP out_ptr = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  const int* col_ptr = INTEGER(indptr);
  const int* row_idx = INTEGER(indices);
  const double* values = REAL(data);
  const int64_t n_col_ptr = static_cast<int64_t>(Rf_asInteger(num_indptr));
  const int64_t n_values = static_cast<int64_t>(Rf_asInteger(nelem));
  const int64_t n_rows = static_cast<int64_t>(Rf_asInteger(num_row));
  SEXP parameters_chr = PROTECT(Rf_asChar(parameters));
  DatasetHandle ref_dataset = Rf_isNull(reference) ? nullptr : R_ExternalPtrAddr(reference);
  DatasetHandle new_dataset = nullptr;
  CHECK_CALL(LGBM_DatasetCreateFromCSC(
    col_ptr, C_API_DTYPE_INT32,
    row_idx,
    values, C_API_DTYPE_FLOAT64,
    n_col_ptr, n_values, n_rows,
    CHAR(parameters_chr), ref_dataset, &new_dataset));
  R_SetExternalPtrAddr(out_ptr, new_dataset);
  R_RegisterCFinalizerEx(out_ptr, _DatasetFinalizer, TRUE);
  UNPROTECT(2);  // out_ptr, parameters_chr
  return out_ptr;
  R_API_END();
}

194
// Create a Dataset from a dense double matrix and return it as an R external
// pointer. The data is handed to the C API as float64 with the COL_MAJOR flag.
//   num_row, num_col : scalars read with Rf_asInteger
//   parameters       : coerced with Rf_asChar, parameter string
//   reference        : R NULL, or an external pointer to a reference Dataset
// The returned pointer has _DatasetFinalizer registered on it.
SEXP LGBM_DatasetCreateFromMat_R(SEXP data,
  SEXP num_row,
  SEXP num_col,
  SEXP parameters,
  SEXP reference) {
  R_API_BEGIN();
  SEXP out_ptr = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  const int32_t n_rows = static_cast<int32_t>(Rf_asInteger(num_row));
  const int32_t n_cols = static_cast<int32_t>(Rf_asInteger(num_col));
  double* matrix_values = REAL(data);
  SEXP parameters_chr = PROTECT(Rf_asChar(parameters));
  DatasetHandle ref_dataset = Rf_isNull(reference) ? nullptr : R_ExternalPtrAddr(reference);
  DatasetHandle new_dataset = nullptr;
  CHECK_CALL(LGBM_DatasetCreateFromMat(
    matrix_values, C_API_DTYPE_FLOAT64,
    n_rows, n_cols, COL_MAJOR,
    CHAR(parameters_chr), ref_dataset, &new_dataset));
  R_SetExternalPtrAddr(out_ptr, new_dataset);
  R_RegisterCFinalizerEx(out_ptr, _DatasetFinalizer, TRUE);
  UNPROTECT(2);  // out_ptr, parameters_chr
  return out_ptr;
  R_API_END();
}

219
// Build a new Dataset holding a subset of the rows of `handle`.
//   used_row_indices     : integer vector of one-based row indices
//   len_used_row_indices : number of indices, read with Rf_asInteger
//   parameters           : coerced with Rf_asChar, parameter string
// Returns an external pointer with _DatasetFinalizer registered on it.
SEXP LGBM_DatasetGetSubset_R(SEXP handle,
  SEXP used_row_indices,
  SEXP len_used_row_indices,
  SEXP parameters) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  SEXP out_ptr = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  const int32_t n_indices = static_cast<int32_t>(Rf_asInteger(len_used_row_indices));
  std::unique_ptr<int32_t[]> zero_based(new int32_t[n_indices]);
  const int* one_based = INTEGER(used_row_indices);
  // shift the one-based indices to zero-based
#ifndef _MSC_VER
#pragma omp simd
#endif
  for (int32_t pos = 0; pos < n_indices; ++pos) {
    zero_based[pos] = static_cast<int32_t>(one_based[pos] - 1);
  }
  SEXP parameters_chr = PROTECT(Rf_asChar(parameters));
  DatasetHandle subset = nullptr;
  CHECK_CALL(LGBM_DatasetGetSubset(
    R_ExternalPtrAddr(handle),
    zero_based.get(), n_indices,
    CHAR(parameters_chr),
    &subset));
  R_SetExternalPtrAddr(out_ptr, subset);
  R_RegisterCFinalizerEx(out_ptr, _DatasetFinalizer, TRUE);
  UNPROTECT(2);  // out_ptr, parameters_chr
  return out_ptr;
  R_API_END();
}

248
// Set the feature names of a Dataset.
// `feature_names` is coerced with Rf_asChar and split on tab characters;
// each piece becomes one feature name.
SEXP LGBM_DatasetSetFeatureNames_R(SEXP handle,
  SEXP feature_names) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  SEXP joined_names = PROTECT(Rf_asChar(feature_names));
  auto split_names = Split(CHAR(joined_names), '\t');
  const int n_names = static_cast<int>(split_names.size());
  // The C API takes an array of C strings; these point into split_names.
  std::unique_ptr<const char*[]> name_ptrs(new const char*[n_names]);
  for (int pos = 0; pos < n_names; ++pos) {
    name_ptrs[pos] = split_names[pos].c_str();
  }
  CHECK_CALL(LGBM_DatasetSetFeatureNames(R_ExternalPtrAddr(handle), name_ptrs.get(), n_names));
  UNPROTECT(1);  // joined_names
  return R_NilValue;
  R_API_END();
}

265
// Return the feature names of a Dataset as an R character vector.
// Names are first fetched into 256-byte buffers; when the C API reports that
// a longer buffer is needed, every buffer is grown to that size and the names
// are fetched a second time. R allocations go through the safe_R_* wrappers.
SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) {
  SEXP cont_token = PROTECT(R_MakeUnwindCont());
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  SEXP result;
  int n_features = 0;
  CHECK_CALL(LGBM_DatasetGetNumFeature(R_ExternalPtrAddr(handle), &n_features));
  const size_t initial_buf_size = 256;
  std::vector<std::vector<char>> buffers(n_features, std::vector<char>(initial_buf_size));
  std::vector<char*> buffer_ptrs(n_features);
  for (int pos = 0; pos < n_features; ++pos) {
    buffer_ptrs[pos] = buffers[pos].data();
  }
  int n_written;
  size_t needed_buf_size;
  CHECK_CALL(LGBM_DatasetGetFeatureNames(
    R_ExternalPtrAddr(handle),
    n_features, &n_written,
    initial_buf_size, &needed_buf_size,
    buffer_ptrs.data()));
  // at least one name did not fit: enlarge all buffers and fetch again
  if (needed_buf_size > initial_buf_size) {
    for (int pos = 0; pos < n_features; ++pos) {
      buffers[pos].resize(needed_buf_size);
      buffer_ptrs[pos] = buffers[pos].data();
    }
    CHECK_CALL(LGBM_DatasetGetFeatureNames(
      R_ExternalPtrAddr(handle),
      n_features, &n_written,
      needed_buf_size, &needed_buf_size,
      buffer_ptrs.data()));
  }
  CHECK_EQ(n_features, n_written);
  result = PROTECT(safe_R_string(static_cast<R_xlen_t>(n_features), &cont_token));
  for (int pos = 0; pos < n_features; ++pos) {
    SET_STRING_ELT(result, pos, safe_R_mkChar(buffer_ptrs[pos], &cont_token));
  }
  UNPROTECT(2);  // cont_token, result
  return result;
  R_API_END();
}

313
// Save a Dataset through LGBM_DatasetSaveBinary to the path given by
// `filename` (coerced with Rf_asChar).
SEXP LGBM_DatasetSaveBinary_R(SEXP handle,
  SEXP filename) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  SEXP filename_chr = PROTECT(Rf_asChar(filename));
  CHECK_CALL(LGBM_DatasetSaveBinary(R_ExternalPtrAddr(handle), CHAR(filename_chr)));
  UNPROTECT(1);  // filename_chr
  return R_NilValue;
  R_API_END();
}

325
// Free the Dataset behind `handle` and clear the external pointer.
// Does nothing when `handle` is R NULL or its address is already null.
SEXP LGBM_DatasetFree_R(SEXP handle) {
  R_API_BEGIN();
  if (Rf_isNull(handle) || R_ExternalPtrAddr(handle) == nullptr) {
    return R_NilValue;
  }
  CHECK_CALL(LGBM_DatasetFree(R_ExternalPtrAddr(handle)));
  R_ClearExternalPtr(handle);
  return R_NilValue;
  R_API_END();
}

335
// Set a named field on a Dataset.
//   "group" / "query" : field_data is read as an integer vector (int32)
//   "init_score"      : field_data is read as a double vector (float64)
//   any other name    : field_data is read as doubles and converted to float32
// `num_element` gives the number of elements, read with Rf_asInteger.
SEXP LGBM_DatasetSetField_R(SEXP handle,
  SEXP field_name,
  SEXP field_data,
  SEXP num_element) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  const int n_elements = Rf_asInteger(num_element);
  SEXP field_name_chr = PROTECT(Rf_asChar(field_name));
  const char* field = CHAR(field_name_chr);
  const bool is_group_field = (strcmp("group", field) == 0) || (strcmp("query", field) == 0);
  if (is_group_field) {
    CHECK_CALL(LGBM_DatasetSetField(
      R_ExternalPtrAddr(handle), field, INTEGER(field_data), n_elements, C_API_DTYPE_INT32));
  } else if (strcmp("init_score", field) == 0) {
    CHECK_CALL(LGBM_DatasetSetField(
      R_ExternalPtrAddr(handle), field, REAL(field_data), n_elements, C_API_DTYPE_FLOAT64));
  } else {
    // narrow the R doubles into a temporary float buffer
    std::unique_ptr<float[]> as_float(new float[n_elements]);
    const double* src = REAL(field_data);
    std::copy(src, src + n_elements, as_float.get());
    CHECK_CALL(LGBM_DatasetSetField(
      R_ExternalPtrAddr(handle), field, as_float.get(), n_elements, C_API_DTYPE_FLOAT32));
  }
  UNPROTECT(1);  // field_name_chr
  return R_NilValue;
  R_API_END();
}

357
// Copy a named field of a Dataset into the caller-provided `field_data`.
//   "group" / "query" : the int32 result is treated as boundaries; the
//                       differences of consecutive entries are written to
//                       INTEGER(field_data)
//   "init_score"      : doubles are copied to REAL(field_data)
//   any other name    : floats are copied (widened) to REAL(field_data)
SEXP LGBM_DatasetGetField_R(SEXP handle,
  SEXP field_name,
  SEXP field_data) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  SEXP field_name_chr = PROTECT(Rf_asChar(field_name));
  const char* field = CHAR(field_name_chr);
  int n_values = 0;
  int value_type = 0;
  const void* raw_values;
  CHECK_CALL(LGBM_DatasetGetField(R_ExternalPtrAddr(handle), field, &n_values, &raw_values, &value_type));
  const bool is_group_field = (strcmp("group", field) == 0) || (strcmp("query", field) == 0);
  if (is_group_field) {
    const int32_t* boundaries = reinterpret_cast<const int32_t*>(raw_values);
    int* sizes = INTEGER(field_data);
    // convert from boundaries to size
#ifndef _MSC_VER
#pragma omp simd
#endif
    for (int upper = 1; upper < n_values; ++upper) {
      sizes[upper - 1] = boundaries[upper] - boundaries[upper - 1];
    }
  } else if (strcmp("init_score", field) == 0) {
    const double* scores = reinterpret_cast<const double*>(raw_values);
    std::copy(scores, scores + n_values, REAL(field_data));
  } else {
    const float* floats = reinterpret_cast<const float*>(raw_values);
    std::copy(floats, floats + n_values, REAL(field_data));
  }
  UNPROTECT(1);  // field_name_chr
  return R_NilValue;
  R_API_END();
}

389
// Write the length of a named Dataset field into INTEGER(out)[0].
// For "group" / "query" the reported length is one less than what the C API
// returns, matching the boundaries-to-sizes conversion in
// LGBM_DatasetGetField_R.
SEXP LGBM_DatasetGetFieldSize_R(SEXP handle,
  SEXP field_name,
  SEXP out) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  SEXP field_name_chr = PROTECT(Rf_asChar(field_name));
  const char* field = CHAR(field_name_chr);
  int n_values = 0;
  int value_type = 0;
  const void* raw_values;
  CHECK_CALL(LGBM_DatasetGetField(R_ExternalPtrAddr(handle), field, &n_values, &raw_values, &value_type));
  const bool is_group_field = (strcmp("group", field) == 0) || (strcmp("query", field) == 0);
  if (is_group_field) {
    --n_values;
  }
  INTEGER(out)[0] = n_values;
  UNPROTECT(1);  // field_name_chr
  return R_NilValue;
  R_API_END();
}

408
409
// Forward two parameter strings (each coerced with Rf_asChar) to
// LGBM_DatasetUpdateParamChecking; a non-zero status becomes an R error.
SEXP LGBM_DatasetUpdateParamChecking_R(SEXP old_params,
  SEXP new_params) {
  R_API_BEGIN();
  SEXP old_params_chr = PROTECT(Rf_asChar(old_params));
  SEXP new_params_chr = PROTECT(Rf_asChar(new_params));
  CHECK_CALL(LGBM_DatasetUpdateParamChecking(CHAR(old_params_chr), CHAR(new_params_chr)));
  UNPROTECT(2);  // old_params_chr, new_params_chr
  return R_NilValue;
  R_API_END();
}

419
// Write the value reported by LGBM_DatasetGetNumData into INTEGER(out)[0].
SEXP LGBM_DatasetGetNumData_R(SEXP handle, SEXP out) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  int n_data;
  CHECK_CALL(LGBM_DatasetGetNumData(R_ExternalPtrAddr(handle), &n_data));
  int* out_slot = INTEGER(out);
  out_slot[0] = n_data;
  return R_NilValue;
  R_API_END();
}

429
// Write the value reported by LGBM_DatasetGetNumFeature into INTEGER(out)[0].
SEXP LGBM_DatasetGetNumFeature_R(SEXP handle,
  SEXP out) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  int n_features;
  CHECK_CALL(LGBM_DatasetGetNumFeature(R_ExternalPtrAddr(handle), &n_features));
  int* out_slot = INTEGER(out);
  out_slot[0] = n_features;
  return R_NilValue;
  R_API_END();
}

440
441
442
443
444
445
446
447
448
449
450
// Write the value reported by LGBM_DatasetGetFeatureNumBin for the feature
// index `feature_idx` (read with Rf_asInteger) into INTEGER(out)[0].
SEXP LGBM_DatasetGetFeatureNumBin_R(SEXP handle, SEXP feature_idx, SEXP out) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  const int which_feature = Rf_asInteger(feature_idx);
  int n_bins;
  CHECK_CALL(LGBM_DatasetGetFeatureNumBin(R_ExternalPtrAddr(handle), which_feature, &n_bins));
  int* out_slot = INTEGER(out);
  out_slot[0] = n_bins;
  return R_NilValue;
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
451
452
// --- start Booster interfaces

453
454
455
456
// Finalizer registered on Booster external pointers; delegates to
// LGBM_BoosterFree_R and discards its return value.
void _BoosterFinalizer(SEXP handle) {
  static_cast<void>(LGBM_BoosterFree_R(handle));
}

457
// Free the Booster behind `handle` and clear the external pointer.
// Does nothing when `handle` is R NULL or its address is already null.
SEXP LGBM_BoosterFree_R(SEXP handle) {
  R_API_BEGIN();
  if (Rf_isNull(handle) || R_ExternalPtrAddr(handle) == nullptr) {
    return R_NilValue;
  }
  CHECK_CALL(LGBM_BoosterFree(R_ExternalPtrAddr(handle)));
  R_ClearExternalPtr(handle);
  return R_NilValue;
  R_API_END();
}

467
468
// Create a Booster for the Dataset `train_data` using the parameter string
// `parameters` (coerced with Rf_asChar). Returns an external pointer with
// _BoosterFinalizer registered on it.
SEXP LGBM_BoosterCreate_R(SEXP train_data,
  SEXP parameters) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(train_data);
  SEXP out_ptr = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  SEXP parameters_chr = PROTECT(Rf_asChar(parameters));
  BoosterHandle new_booster = nullptr;
  CHECK_CALL(LGBM_BoosterCreate(R_ExternalPtrAddr(train_data), CHAR(parameters_chr), &new_booster));
  R_SetExternalPtrAddr(out_ptr, new_booster);
  R_RegisterCFinalizerEx(out_ptr, _BoosterFinalizer, TRUE);
  UNPROTECT(2);  // out_ptr, parameters_chr
  return out_ptr;
  R_API_END();
}

482
// Create a Booster from the model file at `filename` (coerced with
// Rf_asChar). The iteration count reported by the C API is not returned.
// Returns an external pointer with _BoosterFinalizer registered on it.
SEXP LGBM_BoosterCreateFromModelfile_R(SEXP filename) {
  R_API_BEGIN();
  SEXP out_ptr = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  int n_iterations_unused = 0;
  SEXP filename_chr = PROTECT(Rf_asChar(filename));
  BoosterHandle new_booster = nullptr;
  CHECK_CALL(LGBM_BoosterCreateFromModelfile(CHAR(filename_chr), &n_iterations_unused, &new_booster));
  R_SetExternalPtrAddr(out_ptr, new_booster);
  R_RegisterCFinalizerEx(out_ptr, _BoosterFinalizer, TRUE);
  UNPROTECT(2);  // out_ptr, filename_chr
  return out_ptr;
  R_API_END();
}

496
// Create a Booster from a serialized model held in an R object.
// Accepted forms of `model_str`:
//   RAWSXP  : the raw bytes are passed to the C API as a C string
//             (NOTE(review): presumably the bytes include a terminating NUL;
//             confirm against the code that produces the raw vector)
//   CHARSXP : its character data is used directly
//   STRSXP  : its first element is used
// Any other type, or an empty raw/character vector, raises an R error instead
// of handing a null or out-of-bounds pointer to the C API.
// Returns an external pointer with _BoosterFinalizer registered on it.
SEXP LGBM_BoosterLoadModelFromString_R(SEXP model_str) {
  R_API_BEGIN();
  SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  int n_protected = 1;
  int out_num_iterations = 0;
  const char* model_str_ptr = nullptr;
  switch (TYPEOF(model_str)) {
    case RAWSXP: {
      if (Rf_xlength(model_str) < 1) {
        throw std::runtime_error("Cannot load a model from an empty raw vector");
      }
      model_str_ptr = reinterpret_cast<const char*>(RAW(model_str));
      break;
    }
    case CHARSXP: {
      model_str_ptr = CHAR(model_str);
      break;
    }
    case STRSXP: {
      if (Rf_xlength(model_str) < 1) {
        throw std::runtime_error("Cannot load a model from an empty character vector");
      }
      SEXP first_elt = PROTECT(STRING_ELT(model_str, 0));
      n_protected++;
      model_str_ptr = CHAR(first_elt);
      break;
    }
    default: {
      // previously fell through with model_str_ptr == nullptr
      throw std::runtime_error("Model string must be a raw vector or a character vector");
    }
  }
  BoosterHandle handle = nullptr;
  CHECK_CALL(LGBM_BoosterLoadModelFromString(model_str_ptr, &out_num_iterations, &handle));
  R_SetExternalPtrAddr(ret, handle);
  R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE);
  UNPROTECT(n_protected);
  return ret;
  R_API_END();
}

527
528
// Call LGBM_BoosterMerge on the two Boosters after checking that both
// handles are non-null.
SEXP LGBM_BoosterMerge_R(SEXP handle,
  SEXP other_handle) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  _AssertBoosterHandleNotNull(other_handle);
  BoosterHandle target = R_ExternalPtrAddr(handle);
  BoosterHandle source = R_ExternalPtrAddr(other_handle);
  CHECK_CALL(LGBM_BoosterMerge(target, source));
  return R_NilValue;
  R_API_END();
}

537
538
// Call LGBM_BoosterAddValidData with the Booster and the Dataset
// `valid_data`, after checking that both handles are non-null.
SEXP LGBM_BoosterAddValidData_R(SEXP handle,
  SEXP valid_data) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  _AssertDatasetHandleNotNull(valid_data);
  BoosterHandle booster = R_ExternalPtrAddr(handle);
  DatasetHandle validation_set = R_ExternalPtrAddr(valid_data);
  CHECK_CALL(LGBM_BoosterAddValidData(booster, validation_set));
  return R_NilValue;
  R_API_END();
}

547
548
// Call LGBM_BoosterResetTrainingData with the Booster and the Dataset
// `train_data`, after checking that both handles are non-null.
SEXP LGBM_BoosterResetTrainingData_R(SEXP handle,
  SEXP train_data) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  _AssertDatasetHandleNotNull(train_data);
  BoosterHandle booster = R_ExternalPtrAddr(handle);
  DatasetHandle training_set = R_ExternalPtrAddr(train_data);
  CHECK_CALL(LGBM_BoosterResetTrainingData(booster, training_set));
  return R_NilValue;
  R_API_END();
}

557
// Call LGBM_BoosterResetParameter with the parameter string `parameters`
// (coerced with Rf_asChar).
SEXP LGBM_BoosterResetParameter_R(SEXP handle,
  SEXP parameters) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  SEXP parameters_chr = PROTECT(Rf_asChar(parameters));
  CHECK_CALL(LGBM_BoosterResetParameter(R_ExternalPtrAddr(handle), CHAR(parameters_chr)));
  UNPROTECT(1);  // parameters_chr
  return R_NilValue;
  R_API_END();
}

568
// Write the value reported by LGBM_BoosterGetNumClasses into INTEGER(out)[0].
SEXP LGBM_BoosterGetNumClasses_R(SEXP handle,
  SEXP out) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int n_classes;
  CHECK_CALL(LGBM_BoosterGetNumClasses(R_ExternalPtrAddr(handle), &n_classes));
  int* out_slot = INTEGER(out);
  out_slot[0] = n_classes;
  return R_NilValue;
  R_API_END();
}

579
580
581
582
583
584
585
586
587
// Return the value reported by LGBM_BoosterGetNumFeature as an R integer
// scalar.
SEXP LGBM_BoosterGetNumFeature_R(SEXP handle) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int n_features = 0;
  CHECK_CALL(LGBM_BoosterGetNumFeature(R_ExternalPtrAddr(handle), &n_features));
  SEXP result = Rf_ScalarInteger(n_features);
  return result;
  R_API_END();
}

588
// Call LGBM_BoosterUpdateOneIter on the Booster. The "is finished" flag that
// the C API writes is received but not returned to R.
SEXP LGBM_BoosterUpdateOneIter_R(SEXP handle) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int finished_flag = 0;
  BoosterHandle booster = R_ExternalPtrAddr(handle);
  CHECK_CALL(LGBM_BoosterUpdateOneIter(booster, &finished_flag));
  return R_NilValue;
  R_API_END();
}

597
// Call LGBM_BoosterUpdateOneIterCustom with caller-supplied gradients and
// hessians.
//   grad, hess : double vectors; the first `len` values of each are converted
//                to float before being handed to the C API
//   len        : element count, read with Rf_asInteger
// The "is finished" flag that the C API writes is not returned to R.
SEXP LGBM_BoosterUpdateOneIterCustom_R(SEXP handle,
  SEXP grad,
  SEXP hess,
  SEXP len) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int finished_flag = 0;
  const int n_values = Rf_asInteger(len);
  std::unique_ptr<float[]> grad_f32(new float[n_values]);
  std::unique_ptr<float[]> hess_f32(new float[n_values]);
  const double* grad_src = REAL(grad);
  std::copy(grad_src, grad_src + n_values, grad_f32.get());
  const double* hess_src = REAL(hess);
  std::copy(hess_src, hess_src + n_values, hess_f32.get());
  CHECK_CALL(LGBM_BoosterUpdateOneIterCustom(
    R_ExternalPtrAddr(handle), grad_f32.get(), hess_f32.get(), &finished_flag));
  return R_NilValue;
  R_API_END();
}

613
// Call LGBM_BoosterRollbackOneIter on the Booster.
SEXP LGBM_BoosterRollbackOneIter_R(SEXP handle) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  BoosterHandle booster = R_ExternalPtrAddr(handle);
  CHECK_CALL(LGBM_BoosterRollbackOneIter(booster));
  return R_NilValue;
  R_API_END();
}

621
// Write the value reported by LGBM_BoosterGetCurrentIteration into
// INTEGER(out)[0].
SEXP LGBM_BoosterGetCurrentIteration_R(SEXP handle,
  SEXP out) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int current_iter;
  CHECK_CALL(LGBM_BoosterGetCurrentIteration(R_ExternalPtrAddr(handle), &current_iter));
  int* out_slot = INTEGER(out);
  out_slot[0] = current_iter;
  return R_NilValue;
  R_API_END();
}

632
// Have LGBM_BoosterGetUpperBoundValue write its result directly into the
// double storage of `out_result`.
SEXP LGBM_BoosterGetUpperBoundValue_R(SEXP handle,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  double* destination = REAL(out_result);
  BoosterHandle booster = R_ExternalPtrAddr(handle);
  CHECK_CALL(LGBM_BoosterGetUpperBoundValue(booster, destination));
  return R_NilValue;
  R_API_END();
}

642
// Have LGBM_BoosterGetLowerBoundValue write its result directly into the
// double storage of `out_result`.
SEXP LGBM_BoosterGetLowerBoundValue_R(SEXP handle,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  double* destination = REAL(out_result);
  BoosterHandle booster = R_ExternalPtrAddr(handle);
  CHECK_CALL(LGBM_BoosterGetLowerBoundValue(booster, destination));
  return R_NilValue;
  R_API_END();
}

652
// Return the evaluation names of a Booster as an R character vector.
// Names are first fetched into 128-byte buffers; when the C API reports that
// a longer buffer is needed, every buffer is grown to that size and the names
// are fetched a second time. R allocations go through the safe_R_* wrappers.
SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) {
  SEXP cont_token = PROTECT(R_MakeUnwindCont());
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  SEXP result;
  int n_evals;
  CHECK_CALL(LGBM_BoosterGetEvalCounts(R_ExternalPtrAddr(handle), &n_evals));
  const size_t initial_buf_size = 128;
  std::vector<std::vector<char>> buffers(n_evals, std::vector<char>(initial_buf_size));
  std::vector<char*> buffer_ptrs(n_evals);
  for (int pos = 0; pos < n_evals; ++pos) {
    buffer_ptrs[pos] = buffers[pos].data();
  }

  int n_written;
  size_t needed_buf_size;
  CHECK_CALL(LGBM_BoosterGetEvalNames(
    R_ExternalPtrAddr(handle),
    n_evals, &n_written,
    initial_buf_size, &needed_buf_size,
    buffer_ptrs.data()));
  // at least one name did not fit: enlarge all buffers and fetch again
  if (needed_buf_size > initial_buf_size) {
    for (int pos = 0; pos < n_evals; ++pos) {
      buffers[pos].resize(needed_buf_size);
      buffer_ptrs[pos] = buffers[pos].data();
    }
    CHECK_CALL(LGBM_BoosterGetEvalNames(
      R_ExternalPtrAddr(handle),
      n_evals, &n_written,
      needed_buf_size, &needed_buf_size,
      buffer_ptrs.data()));
  }
  CHECK_EQ(n_written, n_evals);
  result = PROTECT(safe_R_string(static_cast<R_xlen_t>(n_evals), &cont_token));
  for (int pos = 0; pos < n_evals; ++pos) {
    SET_STRING_ELT(result, pos, safe_R_mkChar(buffer_ptrs[pos], &cont_token));
  }
  UNPROTECT(2);  // cont_token, result
  return result;
  R_API_END();
}

701
// Have LGBM_BoosterGetEval write the evaluation results for the data index
// `data_idx` (read with Rf_asInteger) into the double storage of
// `out_result`. The number of values written is checked against
// LGBM_BoosterGetEvalCounts.
SEXP LGBM_BoosterGetEval_R(SEXP handle,
  SEXP data_idx,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int n_expected;
  CHECK_CALL(LGBM_BoosterGetEvalCounts(R_ExternalPtrAddr(handle), &n_expected));
  double* destination = REAL(out_result);
  int n_written;
  CHECK_CALL(LGBM_BoosterGetEval(
    R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &n_written, destination));
  CHECK_EQ(n_written, n_expected);
  return R_NilValue;
  R_API_END();
}

716
// Write the value reported by LGBM_BoosterGetNumPredict for the data index
// `data_idx` (read with Rf_asInteger) into INTEGER(out)[0]. The 64-bit count
// from the C API is cast to int.
SEXP LGBM_BoosterGetNumPredict_R(SEXP handle,
  SEXP data_idx,
  SEXP out) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int64_t n_predictions;
  CHECK_CALL(LGBM_BoosterGetNumPredict(
    R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &n_predictions));
  int* out_slot = INTEGER(out);
  out_slot[0] = static_cast<int>(n_predictions);
  return R_NilValue;
  R_API_END();
}

728
// Have LGBM_BoosterGetPredict write the predictions for the data index
// `data_idx` (read with Rf_asInteger) into the double storage of
// `out_result`. The count of values written is received but not returned.
SEXP LGBM_BoosterGetPredict_R(SEXP handle,
  SEXP data_idx,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  double* destination = REAL(out_result);
  int64_t n_written;
  CHECK_CALL(LGBM_BoosterGetPredict(
    R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &n_written, destination));
  return R_NilValue;
  R_API_END();
}

740
/*!
 * \brief Translate the three R-level prediction flags into a single C API prediction type.
 *
 * Precedence when several flags are set: is_predcontrib, then is_leafidx, then is_rawscore.
 * With no flag set the result is C_API_PREDICT_NORMAL.
 */
int GetPredictType(SEXP is_rawscore, SEXP is_leafidx, SEXP is_predcontrib) {
  // coerce every flag up front, in the same order as before, so each one is always evaluated
  const int want_raw = Rf_asInteger(is_rawscore);
  const int want_leaf = Rf_asInteger(is_leafidx);
  const int want_contrib = Rf_asInteger(is_predcontrib);
  if (want_contrib) {
    return C_API_PREDICT_CONTRIB;
  }
  if (want_leaf) {
    return C_API_PREDICT_LEAF_INDEX;
  }
  if (want_raw) {
    return C_API_PREDICT_RAW_SCORE;
  }
  return C_API_PREDICT_NORMAL;
}

754
/*!
 * \brief Run prediction on a data file and write the result to another file.
 * \param handle external pointer to the Booster; rejected if it is a null handle
 * \param data_filename path of the input file (coerced with Rf_asChar())
 * \param data_has_header header flag forwarded to the C API as an integer
 * \param is_rawscore, is_leafidx, is_predcontrib flags combined by GetPredictType()
 * \param start_iteration, num_iteration iteration arguments forwarded as integers
 * \param parameter extra parameter string (coerced with Rf_asChar())
 * \param result_filename path of the output file (coerced with Rf_asChar())
 * \return R_NilValue
 */
SEXP LGBM_BoosterPredictForFile_R(SEXP handle,
  SEXP data_filename,
  SEXP data_has_header,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter,
  SEXP result_filename) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  // keep the three CHARSXPs protected for as long as their C strings are in use
  SEXP input_path_chr = PROTECT(Rf_asChar(data_filename));
  SEXP params_chr = PROTECT(Rf_asChar(parameter));
  SEXP output_path_chr = PROTECT(Rf_asChar(result_filename));
  const int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  const int has_header = Rf_asInteger(data_has_header);
  const int first_iter = Rf_asInteger(start_iteration);
  const int iter_count = Rf_asInteger(num_iteration);
  CHECK_CALL(LGBM_BoosterPredictForFile(R_ExternalPtrAddr(handle), CHAR(input_path_chr),
    has_header, pred_type, first_iter, iter_count, CHAR(params_chr),
    CHAR(output_path_chr)));
  UNPROTECT(3);
  return R_NilValue;
  R_API_END();
}

778
/*!
 * \brief Compute how many values a prediction call will produce.
 * \param handle external pointer to the Booster; rejected if it is a null handle
 * \param num_row number of rows, coerced with Rf_asInteger()
 * \param is_rawscore, is_leafidx, is_predcontrib flags combined by GetPredictType()
 * \param start_iteration, num_iteration iteration arguments forwarded as integers
 * \param out_len integer vector; element 0 receives the count
 * \return R_NilValue; raises an R error if the count cannot be represented as an `int`
 */
SEXP LGBM_BoosterCalcNumPredict_R(SEXP handle,
  SEXP num_row,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP out_len) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  int64_t len = 0;
  CHECK_CALL(LGBM_BoosterCalcNumPredict(R_ExternalPtrAddr(handle), Rf_asInteger(num_row),
    pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), &len));
  // INTEGER() exposes `int` storage while the C API reports a 64-bit count: verify the value
  // survives the narrowing round trip instead of silently handing back a truncated number
  const int len_as_int = static_cast<int>(len);
  if (static_cast<int64_t>(len_as_int) != len) {
    throw std::runtime_error("Number of predictions is too large to be stored in an R integer");
  }
  INTEGER(out_len)[0] = len_as_int;
  return R_NilValue;
  R_API_END();
}

797
/*!
 * \brief Predict for sparse input passed as three arrays (indptr / indices / data) through the CSC entry point.
 * \param handle external pointer to the Booster; rejected if it is a null handle
 * \param indptr integer vector, \param indices integer vector, \param data numeric vector
 * \param num_indptr, nelem, num_row sizes, coerced with Rf_asInteger() and widened to int64_t
 * \param is_rawscore, is_leafidx, is_predcontrib flags combined by GetPredictType()
 * \param start_iteration, num_iteration iteration arguments forwarded as integers
 * \param parameter extra parameter string (coerced with Rf_asChar())
 * \param out_result numeric vector that receives the predictions
 * \return R_NilValue; the predictions are written into out_result
 */
SEXP LGBM_BoosterPredictForCSC_R(SEXP handle,
  SEXP indptr,
  SEXP indices,
  SEXP data,
  SEXP num_indptr,
  SEXP nelem,
  SEXP num_row,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  const int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  // raw views on the three input arrays
  const int* col_ptr = INTEGER(indptr);
  const int32_t* row_idx = reinterpret_cast<const int32_t*>(INTEGER(indices));
  const double* values = REAL(data);
  // sizes arrive as R integers; the C API takes them as 64-bit values
  const int64_t col_ptr_len = static_cast<int64_t>(Rf_asInteger(num_indptr));
  const int64_t value_count = static_cast<int64_t>(Rf_asInteger(nelem));
  const int64_t row_count = static_cast<int64_t>(Rf_asInteger(num_row));
  double* predictions = REAL(out_result);
  int64_t num_written = 0;
  SEXP params_chr = PROTECT(Rf_asChar(parameter));
  const int first_iter = Rf_asInteger(start_iteration);
  const int iter_count = Rf_asInteger(num_iteration);
  CHECK_CALL(LGBM_BoosterPredictForCSC(R_ExternalPtrAddr(handle),
    col_ptr, C_API_DTYPE_INT32, row_idx,
    values, C_API_DTYPE_FLOAT64, col_ptr_len, value_count,
    row_count, pred_type, first_iter, iter_count, CHAR(params_chr), &num_written, predictions));
  UNPROTECT(1);
  return R_NilValue;
  R_API_END();
}

832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
/*!
 * \brief Predict for sparse input passed as three arrays (indptr / indices / data) through the CSR entry point.
 * \param handle external pointer to the Booster; rejected if it is a null handle
 * \param indptr integer vector, \param indices integer vector, \param data numeric vector;
 *        the lengths of indptr and data are taken with Rf_xlength()
 * \param ncols number of columns, coerced with Rf_asInteger()
 * \param is_rawscore, is_leafidx, is_predcontrib flags combined by GetPredictType()
 * \param start_iteration, num_iteration iteration arguments forwarded as integers
 * \param parameter extra parameter string (coerced with Rf_asChar())
 * \param out_result numeric vector that receives the predictions
 * \return R_NilValue; the predictions are written into out_result
 */
SEXP LGBM_BoosterPredictForCSR_R(SEXP handle,
  SEXP indptr,
  SEXP indices,
  SEXP data,
  SEXP ncols,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  const int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  SEXP params_chr = PROTECT(Rf_asChar(parameter));
  // the C API reports how many values it wrote; that number is not used afterwards
  int64_t num_written = 0;
  void* const booster = R_ExternalPtrAddr(handle);
  CHECK_CALL(LGBM_BoosterPredictForCSR(
    booster,
    INTEGER(indptr),
    C_API_DTYPE_INT32,
    INTEGER(indices),
    REAL(data),
    C_API_DTYPE_FLOAT64,
    Rf_xlength(indptr),
    Rf_xlength(data),
    Rf_asInteger(ncols),
    pred_type,
    Rf_asInteger(start_iteration),
    Rf_asInteger(num_iteration),
    CHAR(params_chr),
    &num_written,
    REAL(out_result)));
  UNPROTECT(1);
  return R_NilValue;
  R_API_END();
}

/*!
 * \brief Predict for a single sparse row given only its indices and values.
 *
 * The two-entry index-pointer array {0, length(data)} that the C API expects is built here.
 * \param handle external pointer to the Booster; rejected if it is a null handle
 * \param indices integer vector, \param data numeric vector
 * \param ncols number of columns, coerced with Rf_asInteger()
 * \param is_rawscore, is_leafidx, is_predcontrib flags combined by GetPredictType()
 * \param start_iteration, num_iteration iteration arguments forwarded as integers
 * \param parameter extra parameter string (coerced with Rf_asChar())
 * \param out_result numeric vector that receives the predictions
 * \return R_NilValue; the predictions are written into out_result
 */
SEXP LGBM_BoosterPredictForCSRSingleRow_R(SEXP handle,
  SEXP indices,
  SEXP data,
  SEXP ncols,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  const int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  SEXP params_chr = PROTECT(Rf_asChar(parameter));
  const int num_nonzero = static_cast<int>(Rf_xlength(data));
  // one row only: it starts at offset 0 and ends after all supplied values
  const int row_bounds[] = {0, num_nonzero};
  int64_t num_written = 0;
  CHECK_CALL(LGBM_BoosterPredictForCSRSingleRow(
    R_ExternalPtrAddr(handle),
    row_bounds,
    C_API_DTYPE_INT32,
    INTEGER(indices),
    REAL(data),
    C_API_DTYPE_FLOAT64,
    2,  // number of entries in row_bounds
    num_nonzero,
    Rf_asInteger(ncols),
    pred_type,
    Rf_asInteger(start_iteration),
    Rf_asInteger(num_iteration),
    CHAR(params_chr),
    &num_written,
    REAL(out_result)));
  UNPROTECT(1);
  return R_NilValue;
  R_API_END();
}

/*!
 * \brief Release the fast-prediction configuration held by an R external pointer.
 *
 * Has the `void(SEXP)` shape of an R C finalizer and is registered as one by the
 * *SingleRowFastInit_R functions in this file.
 * \param handle external pointer whose address is a FastConfigHandle
 */
void LGBM_FastConfigFree_wrapped(SEXP handle) {
  // The external pointer stores the FastConfigHandle value itself (see R_SetExternalPtrAddr in the
  // FastInit functions), so cast the address back to the handle type, not to a pointer-to-handle.
  LGBM_FastConfigFree(static_cast<FastConfigHandle>(R_ExternalPtrAddr(handle)));
}

/*!
 * \brief Create a reusable fast-prediction configuration for single-row CSR prediction.
 * \param handle external pointer to the Booster; rejected if it is a null handle
 * \param ncols number of columns, coerced with Rf_asInteger()
 * \param is_rawscore, is_leafidx, is_predcontrib flags combined by GetPredictType()
 * \param start_iteration, num_iteration iteration arguments forwarded as integers
 * \param parameter extra parameter string (coerced with Rf_asChar())
 * \return external pointer holding the FastConfigHandle, with LGBM_FastConfigFree_wrapped as finalizer
 */
SEXP LGBM_BoosterPredictForCSRSingleRowFastInit_R(SEXP handle,
  SEXP ncols,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  const int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  // allocate the (still empty) result first so it is protected before anything else is created
  SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  SEXP params_chr = PROTECT(Rf_asChar(parameter));
  FastConfigHandle fast_config;
  CHECK_CALL(LGBM_BoosterPredictForCSRSingleRowFastInit(
    R_ExternalPtrAddr(handle),
    pred_type,
    Rf_asInteger(start_iteration),
    Rf_asInteger(num_iteration),
    C_API_DTYPE_FLOAT64,
    Rf_asInteger(ncols),
    CHAR(params_chr),
    &fast_config));
  // only attach the finalizer once the pointer really owns a configuration
  R_SetExternalPtrAddr(ret, fast_config);
  R_RegisterCFinalizerEx(ret, LGBM_FastConfigFree_wrapped, TRUE);
  UNPROTECT(2);
  return ret;
  R_API_END();
}

/*!
 * \brief Predict for a single sparse row using a previously created fast-prediction configuration.
 * \param handle_fastConfig external pointer whose address is passed to the C API as the configuration
 * \param indices integer vector, \param data numeric vector
 * \param out_result numeric vector that receives the predictions
 * \return R_NilValue; the predictions are written into out_result
 */
SEXP LGBM_BoosterPredictForCSRSingleRowFast_R(SEXP handle_fastConfig,
  SEXP indices,
  SEXP data,
  SEXP out_result) {
  R_API_BEGIN();
  const int num_nonzero = static_cast<int>(Rf_xlength(data));
  // one row only: it starts at offset 0 and ends after all supplied values
  const int row_bounds[] = {0, num_nonzero};
  int64_t num_written = 0;
  CHECK_CALL(LGBM_BoosterPredictForCSRSingleRowFast(
    R_ExternalPtrAddr(handle_fastConfig),
    row_bounds,
    C_API_DTYPE_INT32,
    INTEGER(indices),
    REAL(data),
    2,  // number of entries in row_bounds
    num_nonzero,
    &num_written,
    REAL(out_result)));
  return R_NilValue;
  R_API_END();
}

935
/*!
 * \brief Predict for a dense numeric matrix, passed to the C API with the COL_MAJOR layout flag.
 * \param handle external pointer to the Booster; rejected if it is a null handle
 * \param data numeric storage of the matrix
 * \param num_row, num_col matrix dimensions, coerced with Rf_asInteger()
 * \param is_rawscore, is_leafidx, is_predcontrib flags combined by GetPredictType()
 * \param start_iteration, num_iteration iteration arguments forwarded as integers
 * \param parameter extra parameter string (coerced with Rf_asChar())
 * \param out_result numeric vector that receives the predictions
 * \return R_NilValue; the predictions are written into out_result
 */
SEXP LGBM_BoosterPredictForMat_R(SEXP handle,
  SEXP data,
  SEXP num_row,
  SEXP num_col,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  const int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  const int32_t row_count = static_cast<int32_t>(Rf_asInteger(num_row));
  const int32_t col_count = static_cast<int32_t>(Rf_asInteger(num_col));
  const double* matrix_values = REAL(data);
  double* predictions = REAL(out_result);
  SEXP params_chr = PROTECT(Rf_asChar(parameter));
  // the C API reports how many values it wrote; that number is not used afterwards
  int64_t num_written = 0;
  const int first_iter = Rf_asInteger(start_iteration);
  const int iter_count = Rf_asInteger(num_iteration);
  CHECK_CALL(LGBM_BoosterPredictForMat(R_ExternalPtrAddr(handle),
    matrix_values, C_API_DTYPE_FLOAT64, row_count, col_count, COL_MAJOR,
    pred_type, first_iter, iter_count, CHAR(params_chr), &num_written, predictions));
  UNPROTECT(1);
  return R_NilValue;
  R_API_END();
}

963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
/*!
 * \brief Bundle of the three raw buffers that make up one sparse prediction result.
 *
 * The constructor only stores the three pointers; `indptr_type` and `data_type`
 * are not set by it.
 */
struct SparseOutputPointers {
  SparseOutputPointers(void* indptr_buf, int32_t* indices_buf, void* data_buf)
    : indptr(indptr_buf),
      indices(indices_buf),
      data(data_buf) {}

  void* indptr;
  int32_t* indices;
  void* data;
  int indptr_type;
  int data_type;
};

/*!
 * \brief Free the buffers referenced by a SparseOutputPointers object, then the object itself.
 *
 * The buffers are released through LGBM_BoosterFreePredictSparse() using the fixed
 * INT32 / FLOAT64 type codes; its return value is ignored.
 * \param ptr heap-allocated object to destroy; it is dereferenced unconditionally
 */
void delete_SparseOutputPointers(SparseOutputPointers *ptr) {
  void* const indptr_buf = ptr->indptr;
  int32_t* const indices_buf = ptr->indices;
  void* const data_buf = ptr->data;
  LGBM_BoosterFreePredictSparse(indptr_buf, indices_buf, data_buf, C_API_DTYPE_INT32, C_API_DTYPE_FLOAT64);
  delete ptr;
}

/*!
 * \brief Run a C_API_PREDICT_CONTRIB prediction on sparse input and return the sparse result as an R list.
 * \param handle external pointer to the Booster; rejected if it is a null handle
 * \param indptr integer vector, \param indices integer vector, \param data numeric vector
 * \param is_csr logical; selects CSR vs. CSC matrix type and whether ncols or nrows is forwarded
 * \param nrows, ncols dimensions, coerced with Rf_asInteger() (only one of them is used per call)
 * \param start_iteration, num_iteration iteration arguments forwarded as integers
 * \param parameter extra parameter string (coerced with Rf_asChar())
 * \return list with elements "indptr" (integer), "indices" (integer) and "data" (numeric)
 */
SEXP LGBM_BoosterPredictSparseOutput_R(SEXP handle,
  SEXP indptr,
  SEXP indices,
  SEXP data,
  SEXP is_csr,
  SEXP nrows,
  SEXP ncols,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter) {
  // created outside the try block; handed to the safe_R_* allocation helpers below
  SEXP cont_token = PROTECT(R_MakeUnwindCont());
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  // the trailing "" terminates the name list for Rf_mkNamed()
  const char* out_names[] = {"indptr", "indices", "data", ""};
  SEXP out = PROTECT(Rf_mkNamed(VECSXP, out_names));
  const char* parameter_ptr = CHAR(PROTECT(Rf_asChar(parameter)));

  // filled in by the C API: out_len[0] sizes indices/data, out_len[1] sizes indptr (see the copies below)
  int64_t out_len[2];
  void *out_indptr;
  int32_t *out_indices;
  void *out_data;

  CHECK_CALL(LGBM_BoosterPredictSparseOutput(R_ExternalPtrAddr(handle),
    INTEGER(indptr), C_API_DTYPE_INT32, INTEGER(indices),
    REAL(data), C_API_DTYPE_FLOAT64,
    Rf_xlength(indptr), Rf_xlength(data),
    Rf_asLogical(is_csr)? Rf_asInteger(ncols) : Rf_asInteger(nrows),
    C_API_PREDICT_CONTRIB, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration),
    parameter_ptr,
    Rf_asLogical(is_csr)? C_API_MATRIX_TYPE_CSR : C_API_MATRIX_TYPE_CSC,
    out_len, &out_indptr, &out_indices, &out_data));

  // take ownership of the returned buffers: delete_SparseOutputPointers() releases them when this
  // object goes out of scope, which includes leaving the try block through an exception
  std::unique_ptr<SparseOutputPointers, decltype(&delete_SparseOutputPointers)> pointers_struct = {
    new SparseOutputPointers(
      out_indptr,
      out_indices,
      out_data),
    &delete_SparseOutputPointers
  };

  // each new vector is stored in the protected list `out` right after it is allocated
  SEXP out_indptr_R = safe_R_int(out_len[1], &cont_token);
  SET_VECTOR_ELT(out, 0, out_indptr_R);
  SEXP out_indices_R = safe_R_int(out_len[0], &cont_token);
  SET_VECTOR_ELT(out, 1, out_indices_R);
  SEXP out_data_R = safe_R_real(out_len[0], &cont_token);
  SET_VECTOR_ELT(out, 2, out_data_R);
  // copy the results into the R vectors before the buffers are freed
  std::memcpy(INTEGER(out_indptr_R), out_indptr, out_len[1]*sizeof(int));
  std::memcpy(INTEGER(out_indices_R), out_indices, out_len[0]*sizeof(int));
  std::memcpy(REAL(out_data_R), out_data, out_len[0]*sizeof(double));

  // cont_token, out, and the CHARSXP behind parameter_ptr
  UNPROTECT(3);
  return out;
  R_API_END();
}

1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
/*!
 * \brief Predict for a single dense row; the row width is the length of `data`.
 * \param handle external pointer to the Booster; rejected if it is a null handle
 * \param data numeric vector holding the row
 * \param is_rawscore, is_leafidx, is_predcontrib flags combined by GetPredictType()
 * \param start_iteration, num_iteration iteration arguments forwarded as integers
 * \param parameter extra parameter string (coerced with Rf_asChar())
 * \param out_result numeric vector that receives the predictions
 * \return R_NilValue; the predictions are written into out_result
 */
SEXP LGBM_BoosterPredictForMatSingleRow_R(SEXP handle,
  SEXP data,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  const int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  SEXP params_chr = PROTECT(Rf_asChar(parameter));
  double* const predictions = REAL(out_result);
  // the C API reports how many values it wrote; that number is not used afterwards
  int64_t num_written = 0;
  CHECK_CALL(LGBM_BoosterPredictForMatSingleRow(
    R_ExternalPtrAddr(handle),
    REAL(data),
    C_API_DTYPE_FLOAT64,
    Rf_xlength(data),
    1,
    pred_type,
    Rf_asInteger(start_iteration),
    Rf_asInteger(num_iteration),
    CHAR(params_chr),
    &num_written,
    predictions));
  UNPROTECT(1);
  return R_NilValue;
  R_API_END();
}

/*!
 * \brief Create a reusable fast-prediction configuration for single-row dense prediction.
 * \param handle external pointer to the Booster; rejected if it is a null handle
 * \param ncols number of columns, coerced with Rf_asInteger()
 * \param is_rawscore, is_leafidx, is_predcontrib flags combined by GetPredictType()
 * \param start_iteration, num_iteration iteration arguments forwarded as integers
 * \param parameter extra parameter string (coerced with Rf_asChar())
 * \return external pointer holding the FastConfigHandle, with LGBM_FastConfigFree_wrapped as finalizer
 */
SEXP LGBM_BoosterPredictForMatSingleRowFastInit_R(SEXP handle,
  SEXP ncols,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  const int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  // allocate the (still empty) result first so it is protected before anything else is created
  SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  SEXP params_chr = PROTECT(Rf_asChar(parameter));
  FastConfigHandle fast_config;
  CHECK_CALL(LGBM_BoosterPredictForMatSingleRowFastInit(
    R_ExternalPtrAddr(handle),
    pred_type,
    Rf_asInteger(start_iteration),
    Rf_asInteger(num_iteration),
    C_API_DTYPE_FLOAT64,
    Rf_asInteger(ncols),
    CHAR(params_chr),
    &fast_config));
  // only attach the finalizer once the pointer really owns a configuration
  R_SetExternalPtrAddr(ret, fast_config);
  R_RegisterCFinalizerEx(ret, LGBM_FastConfigFree_wrapped, TRUE);
  UNPROTECT(2);
  return ret;
  R_API_END();
}

/*!
 * \brief Predict for a single dense row using a previously created fast-prediction configuration.
 * \param handle_fastConfig external pointer whose address is passed to the C API as the configuration
 * \param data numeric vector holding the row
 * \param out_result numeric vector that receives the predictions
 * \return R_NilValue; the predictions are written into out_result
 */
SEXP LGBM_BoosterPredictForMatSingleRowFast_R(SEXP handle_fastConfig,
  SEXP data,
  SEXP out_result) {
  R_API_BEGIN();
  // the C API reports how many values it wrote; that number is not used afterwards
  int64_t num_written = 0;
  void* const fast_config = R_ExternalPtrAddr(handle_fastConfig);
  CHECK_CALL(LGBM_BoosterPredictForMatSingleRowFast(fast_config,
    REAL(data), &num_written, REAL(out_result)));
  return R_NilValue;
  R_API_END();
}

1093
/*!
 * \brief Save a Booster's model to a file.
 * \param handle external pointer to the Booster; rejected if it is a null handle
 * \param num_iteration, start_iteration iteration arguments forwarded as integers
 * \param feature_importance_type importance type forwarded as an integer
 * \param filename path of the output file (coerced with Rf_asChar())
 * \return R_NilValue
 */
SEXP LGBM_BoosterSaveModel_R(SEXP handle,
  SEXP num_iteration,
  SEXP feature_importance_type,
  SEXP filename,
  SEXP start_iteration) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  // keep the CHARSXP protected while its C string is in use
  SEXP path_chr = PROTECT(Rf_asChar(filename));
  CHECK_CALL(LGBM_BoosterSaveModel(
    R_ExternalPtrAddr(handle),
    Rf_asInteger(start_iteration),
    Rf_asInteger(num_iteration),
    Rf_asInteger(feature_importance_type),
    CHAR(path_chr)));
  UNPROTECT(1);
  return R_NilValue;
  R_API_END();
}

1107
/*!
 * \brief Serialize a Booster's model into an R raw vector.
 *
 * A 1 MiB scratch buffer is tried first. If the model is larger, the C API is called a second
 * time and writes straight into the raw vector.
 * \param handle external pointer to the Booster; rejected if it is a null handle
 * \param num_iteration, start_iteration iteration arguments forwarded as integers
 * \param feature_importance_type importance type forwarded as an integer
 * \return raw vector of the length reported by the C API
 */
SEXP LGBM_BoosterSaveModelToString_R(SEXP handle,
  SEXP num_iteration,
  SEXP feature_importance_type,
  SEXP start_iteration) {
  SEXP cont_token = PROTECT(R_MakeUnwindCont());
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  const int64_t scratch_len = 1024 * 1024;
  int64_t out_len = 0;
  const int num_iter = Rf_asInteger(num_iteration);
  const int start_iter = Rf_asInteger(start_iteration);
  const int importance_type = Rf_asInteger(feature_importance_type);
  void* const booster = R_ExternalPtrAddr(handle);
  std::vector<char> scratch(scratch_len);
  CHECK_CALL(LGBM_BoosterSaveModelToString(
    booster, start_iter, num_iter, importance_type, scratch_len, &out_len, scratch.data()));
  SEXP model_str = PROTECT(safe_R_raw(out_len, &cont_token));
  char* const dest = reinterpret_cast<char*>(RAW(model_str));
  if (out_len <= scratch_len) {
    // everything fit into the scratch buffer: just copy it over
    std::copy_n(scratch.data(), out_len, dest);
  } else {
    // the model did not fit: call again, this time writing directly into the R object
    CHECK_CALL(LGBM_BoosterSaveModelToString(
      booster, start_iter, num_iter, importance_type, out_len, &out_len, dest));
  }
  UNPROTECT(2);
  return model_str;
  R_API_END();
}

1133
/*!
 * \brief Dump a Booster's model into a length-one R character vector.
 *
 * A 1 MiB buffer is tried first; if the dump is larger, the buffer is grown to the reported
 * size and the C API is called again.
 * \param handle external pointer to the Booster; rejected if it is a null handle
 * \param num_iteration, start_iteration iteration arguments forwarded as integers
 * \param feature_importance_type importance type forwarded as an integer
 * \return character vector of length one containing the dump
 */
SEXP LGBM_BoosterDumpModel_R(SEXP handle,
  SEXP num_iteration,
  SEXP feature_importance_type,
  SEXP start_iteration) {
  SEXP cont_token = PROTECT(R_MakeUnwindCont());
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  const int64_t initial_len = 1024 * 1024;
  int64_t out_len = 0;
  const int num_iter = Rf_asInteger(num_iteration);
  const int start_iter = Rf_asInteger(start_iteration);
  const int importance_type = Rf_asInteger(feature_importance_type);
  void* const booster = R_ExternalPtrAddr(handle);
  std::vector<char> text_buf(initial_len);
  CHECK_CALL(LGBM_BoosterDumpModel(
    booster, start_iter, num_iter, importance_type, initial_len, &out_len, text_buf.data()));
  if (out_len > initial_len) {
    // the dump did not fit: grow the buffer to the reported size and try again
    text_buf.resize(out_len);
    CHECK_CALL(LGBM_BoosterDumpModel(
      booster, start_iter, num_iter, importance_type, out_len, &out_len, text_buf.data()));
  }
  SEXP model_str = PROTECT(safe_R_string(static_cast<R_xlen_t>(1), &cont_token));
  SET_STRING_ELT(model_str, 0, safe_R_mkChar(text_buf.data(), &cont_token));
  UNPROTECT(2);
  return model_str;
  R_API_END();
}
1159

1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
/*!
 * \brief Return the parameter-alias dump produced by LGBM_DumpParamAliases() as an R string.
 *
 * A 1 MiB buffer is tried first; if the dump is larger, the buffer is grown to the reported
 * size and the C API is called again.
 * \return character vector of length one containing the dump
 */
SEXP LGBM_DumpParamAliases_R() {
  SEXP cont_token = PROTECT(R_MakeUnwindCont());
  R_API_BEGIN();
  const int64_t initial_len = 1024 * 1024;
  int64_t out_len = 0;
  std::vector<char> text_buf(initial_len);
  CHECK_CALL(LGBM_DumpParamAliases(initial_len, &out_len, text_buf.data()));
  if (out_len > initial_len) {
    // the aliases string did not fit: grow the buffer to the reported size and try again
    text_buf.resize(out_len);
    CHECK_CALL(LGBM_DumpParamAliases(out_len, &out_len, text_buf.data()));
  }
  SEXP aliases_str = PROTECT(safe_R_string(static_cast<R_xlen_t>(1), &cont_token));
  SET_STRING_ELT(aliases_str, 0, safe_R_mkChar(text_buf.data(), &cont_token));
  UNPROTECT(2);
  return aliases_str;
  R_API_END();
}

1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
/*!
 * \brief Return the string produced by LGBM_BoosterGetLoadedParam() for a Booster as an R string.
 *
 * A 1 MiB buffer is tried first; if the string is larger, the buffer is grown to the reported
 * size and the C API is called again.
 * \param handle external pointer to the Booster; rejected if it is a null handle
 * \return character vector of length one containing the parameter string
 */
SEXP LGBM_BoosterGetLoadedParam_R(SEXP handle) {
  SEXP cont_token = PROTECT(R_MakeUnwindCont());
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  const int64_t initial_len = 1024 * 1024;
  int64_t out_len = 0;
  void* const booster = R_ExternalPtrAddr(handle);
  std::vector<char> text_buf(initial_len);
  CHECK_CALL(LGBM_BoosterGetLoadedParam(booster, initial_len, &out_len, text_buf.data()));
  if (out_len > initial_len) {
    // the parameter string did not fit: grow the buffer to the reported size and try again
    text_buf.resize(out_len);
    CHECK_CALL(LGBM_BoosterGetLoadedParam(booster, out_len, &out_len, text_buf.data()));
  }
  SEXP params_str = PROTECT(safe_R_string(static_cast<R_xlen_t>(1), &cont_token));
  SET_STRING_ELT(params_str, 0, safe_R_mkChar(text_buf.data(), &cont_token));
  UNPROTECT(2);
  return params_str;
  R_API_END();
}

1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
/*!
 * \brief Write the value reported by LGBM_GetMaxThreads() into an R integer vector.
 * \param out integer vector; element 0 receives the value
 * \return R_NilValue
 */
SEXP LGBM_GetMaxThreads_R(SEXP out) {
  R_API_BEGIN();
  int max_threads = 0;
  CHECK_CALL(LGBM_GetMaxThreads(&max_threads));
  int* const out_values = INTEGER(out);
  out_values[0] = max_threads;
  return R_NilValue;
  R_API_END();
}

/*!
 * \brief Forward a thread count to LGBM_SetMaxThreads().
 * \param num_threads requested value, coerced with Rf_asInteger()
 * \return R_NilValue
 */
SEXP LGBM_SetMaxThreads_R(SEXP num_threads) {
  R_API_BEGIN();
  CHECK_CALL(LGBM_SetMaxThreads(Rf_asInteger(num_threads)));
  return R_NilValue;
  R_API_END();
}

1218
1219
// Registration table for the routines reachable through .Call().
// Each entry holds the routine's name, its address and the number of arguments it takes;
// the all-null entry terminates the table.
#define LGBM_R_CALLDEF(fn, nargs) {#fn, reinterpret_cast<DL_FUNC>(&fn), nargs}
static const R_CallMethodDef CallEntries[] = {
  LGBM_R_CALLDEF(LGBM_HandleIsNull_R, 1),
  LGBM_R_CALLDEF(LGBM_DatasetCreateFromFile_R, 3),
  LGBM_R_CALLDEF(LGBM_DatasetCreateFromCSC_R, 8),
  LGBM_R_CALLDEF(LGBM_DatasetCreateFromMat_R, 5),
  LGBM_R_CALLDEF(LGBM_DatasetGetSubset_R, 4),
  LGBM_R_CALLDEF(LGBM_DatasetSetFeatureNames_R, 2),
  LGBM_R_CALLDEF(LGBM_DatasetGetFeatureNames_R, 1),
  LGBM_R_CALLDEF(LGBM_DatasetSaveBinary_R, 2),
  LGBM_R_CALLDEF(LGBM_DatasetFree_R, 1),
  LGBM_R_CALLDEF(LGBM_DatasetSetField_R, 4),
  LGBM_R_CALLDEF(LGBM_DatasetGetFieldSize_R, 3),
  LGBM_R_CALLDEF(LGBM_DatasetGetField_R, 3),
  LGBM_R_CALLDEF(LGBM_DatasetUpdateParamChecking_R, 2),
  LGBM_R_CALLDEF(LGBM_DatasetGetNumData_R, 2),
  LGBM_R_CALLDEF(LGBM_DatasetGetNumFeature_R, 2),
  LGBM_R_CALLDEF(LGBM_DatasetGetFeatureNumBin_R, 3),
  LGBM_R_CALLDEF(LGBM_BoosterCreate_R, 2),
  LGBM_R_CALLDEF(LGBM_BoosterFree_R, 1),
  LGBM_R_CALLDEF(LGBM_BoosterCreateFromModelfile_R, 1),
  LGBM_R_CALLDEF(LGBM_BoosterLoadModelFromString_R, 1),
  LGBM_R_CALLDEF(LGBM_BoosterMerge_R, 2),
  LGBM_R_CALLDEF(LGBM_BoosterAddValidData_R, 2),
  LGBM_R_CALLDEF(LGBM_BoosterResetTrainingData_R, 2),
  LGBM_R_CALLDEF(LGBM_BoosterResetParameter_R, 2),
  LGBM_R_CALLDEF(LGBM_BoosterGetNumClasses_R, 2),
  LGBM_R_CALLDEF(LGBM_BoosterGetNumFeature_R, 1),
  LGBM_R_CALLDEF(LGBM_BoosterGetLoadedParam_R, 1),
  LGBM_R_CALLDEF(LGBM_BoosterUpdateOneIter_R, 1),
  LGBM_R_CALLDEF(LGBM_BoosterUpdateOneIterCustom_R, 4),
  LGBM_R_CALLDEF(LGBM_BoosterRollbackOneIter_R, 1),
  LGBM_R_CALLDEF(LGBM_BoosterGetCurrentIteration_R, 2),
  LGBM_R_CALLDEF(LGBM_BoosterGetUpperBoundValue_R, 2),
  LGBM_R_CALLDEF(LGBM_BoosterGetLowerBoundValue_R, 2),
  LGBM_R_CALLDEF(LGBM_BoosterGetEvalNames_R, 1),
  LGBM_R_CALLDEF(LGBM_BoosterGetEval_R, 3),
  LGBM_R_CALLDEF(LGBM_BoosterGetNumPredict_R, 3),
  LGBM_R_CALLDEF(LGBM_BoosterGetPredict_R, 3),
  LGBM_R_CALLDEF(LGBM_BoosterPredictForFile_R, 10),
  LGBM_R_CALLDEF(LGBM_BoosterCalcNumPredict_R, 8),
  LGBM_R_CALLDEF(LGBM_BoosterPredictForCSC_R, 14),
  LGBM_R_CALLDEF(LGBM_BoosterPredictForCSR_R, 12),
  LGBM_R_CALLDEF(LGBM_BoosterPredictForCSRSingleRow_R, 11),
  LGBM_R_CALLDEF(LGBM_BoosterPredictForCSRSingleRowFastInit_R, 8),
  LGBM_R_CALLDEF(LGBM_BoosterPredictForCSRSingleRowFast_R, 4),
  LGBM_R_CALLDEF(LGBM_BoosterPredictSparseOutput_R, 10),
  LGBM_R_CALLDEF(LGBM_BoosterPredictForMat_R, 11),
  LGBM_R_CALLDEF(LGBM_BoosterPredictForMatSingleRow_R, 9),
  LGBM_R_CALLDEF(LGBM_BoosterPredictForMatSingleRowFastInit_R, 8),
  LGBM_R_CALLDEF(LGBM_BoosterPredictForMatSingleRowFast_R, 3),
  LGBM_R_CALLDEF(LGBM_BoosterSaveModel_R, 5),
  LGBM_R_CALLDEF(LGBM_BoosterSaveModelToString_R, 4),
  LGBM_R_CALLDEF(LGBM_BoosterDumpModel_R, 4),
  LGBM_R_CALLDEF(LGBM_NullBoosterHandleError_R, 0),
  LGBM_R_CALLDEF(LGBM_DumpParamAliases_R, 0),
  LGBM_R_CALLDEF(LGBM_GetMaxThreads_R, 1),
  LGBM_R_CALLDEF(LGBM_SetMaxThreads_R, 1),
  {nullptr, nullptr, 0}
};
#undef LGBM_R_CALLDEF

1279
1280
LIGHTGBM_C_EXPORT void R_init_lightgbm(DllInfo *dll);

1281
1282
1283
1284
void R_init_lightgbm(DllInfo *dll) {
  R_registerRoutines(dll, NULL, CallEntries, NULL, NULL);
  R_useDynamicSymbols(dll, FALSE);
}