lightgbm_R.cpp 36.6 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2017 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
5
6

#include "lightgbm_R.h"
Guolin Ke's avatar
Guolin Ke committed
7

8
9
10
11
12
13
14
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>
#include <LightGBM/utils/text_reader.h>

#include <R_ext/Rdynload.h>

15
16
17
18
#define R_NO_REMAP
#define R_USE_C99_IN_CXX
#include <R_ext/Error.h>

19
20
21
22
23
24
#include <string>
#include <cstdio>
#include <cstring>
#include <memory>
#include <utility>
#include <vector>
25
#include <algorithm>
26

Guolin Ke's avatar
Guolin Ke committed
27
28
// Layout flag handed to LGBM_DatasetCreateFromMat() for matrices received from R.
#define COL_MAJOR (0)

29
30
31
32
33
34
// Capacity, in bytes, of the buffer that holds the most recently saved error message.
#define MAX_LENGTH_ERR_MSG 1024
// Filled by LGBM_R_save_exception_msg() and reported through Rf_error() in R_API_END().
char R_errmsg_buffer[MAX_LENGTH_ERR_MSG];
// Exception type thrown by throw_R_memerr(); carries the continuation token
// that R_API_END() hands to R_ContinueUnwind().
struct LGBM_R_ErrorClass { SEXP cont_token; };
// Store the message of a caught exception / string in R_errmsg_buffer.
void LGBM_R_save_exception_msg(const std::exception &err);
void LGBM_R_save_exception_msg(const std::string &err);

Guolin Ke's avatar
Guolin Ke committed
35
36
37
// R_API_BEGIN() / R_API_END() wrap the body of every function exposed to R in a
// try block. R_API_END() handles what escapes that block:
//   * LGBM_R_ErrorClass  -> resume the pending R unwind via R_ContinueUnwind()
//   * std::exception / std::string -> message is saved into R_errmsg_buffer
//   * anything else      -> reported as "unknown exception"
// After the handlers, the saved message is raised as an R error. The message is
// passed as an argument to a constant "%s" format so that any '%' characters
// contained in the exception text are printed literally instead of being
// interpreted as printf conversion specifiers.
#define R_API_BEGIN() \
  try {
#define R_API_END() } \
  catch(LGBM_R_ErrorClass &cont) { R_ContinueUnwind(cont.cont_token); } \
  catch(std::exception& ex) { LGBM_R_save_exception_msg(ex); } \
  catch(std::string& ex) { LGBM_R_save_exception_msg(ex); } \
  catch(...) { Rf_error("unknown exception"); } \
  Rf_error("%s", R_errmsg_buffer); \
  return R_NilValue; /* <- won't be reached */
Guolin Ke's avatar
Guolin Ke committed
44
45
46

// Evaluate a LightGBM C API call; a non-zero status is turned into a
// std::runtime_error carrying the text returned by LGBM_GetLastError().
#define CHECK_CALL(x) \
  if (0 != (x)) { \
    throw std::runtime_error(LGBM_GetLastError()); \
  }

50
51
52
53
54
55
56
57
58
59
60
61
62
63
// These are helper functions to allow doing a stack unwind
// after an R allocation error, which would trigger a long jump.
/*! \brief Write err.what() followed by a newline into R_errmsg_buffer (at most MAX_LENGTH_ERR_MSG bytes). */
void LGBM_R_save_exception_msg(const std::exception &err) {
  const char* what_text = err.what();
  std::snprintf(R_errmsg_buffer, MAX_LENGTH_ERR_MSG, "%s\n", what_text);
}

/*! \brief Write the given string followed by a newline into R_errmsg_buffer (at most MAX_LENGTH_ERR_MSG bytes). */
void LGBM_R_save_exception_msg(const std::string &err) {
  const char* text = err.c_str();
  std::snprintf(R_errmsg_buffer, MAX_LENGTH_ERR_MSG, "%s\n", text);
}

/*! \brief Callback for R_UnwindProtect: allocate a STRSXP whose length is read from the R_xlen_t pointed to by len. */
SEXP wrapped_R_string(void *len) {
  const R_xlen_t n_elements = *static_cast<R_xlen_t*>(len);
  return Rf_allocVector(STRSXP, n_elements);
}

64
65
66
67
/*! \brief Callback for R_UnwindProtect: allocate a RAWSXP whose length is read from the R_xlen_t pointed to by len. */
SEXP wrapped_R_raw(void *len) {
  const R_xlen_t n_elements = *static_cast<R_xlen_t*>(len);
  return Rf_allocVector(RAWSXP, n_elements);
}

68
69
70
71
72
73
74
75
/*! \brief Callback for R_UnwindProtect: allocate an INTSXP whose length is read from the R_xlen_t pointed to by len. */
SEXP wrapped_R_int(void *len) {
  const R_xlen_t n_elements = *static_cast<R_xlen_t*>(len);
  return Rf_allocVector(INTSXP, n_elements);
}

/*! \brief Callback for R_UnwindProtect: allocate a REALSXP whose length is read from the R_xlen_t pointed to by len. */
SEXP wrapped_R_real(void *len) {
  const R_xlen_t n_elements = *static_cast<R_xlen_t*>(len);
  return Rf_allocVector(REALSXP, n_elements);
}

76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
/*! \brief Callback for R_UnwindProtect: build a CHARSXP from the C string that txt points to. */
SEXP wrapped_Rf_mkChar(void *txt) {
  char* c_string = static_cast<char*>(txt);
  return Rf_mkChar(c_string);
}

void throw_R_memerr(void *ptr_cont_token, Rboolean jump) {
  if (jump) {
    LGBM_R_ErrorClass err{*(reinterpret_cast<SEXP*>(ptr_cont_token))};
    throw err;
  }
}

/*! \brief Allocate a STRSXP of length len through R_UnwindProtect, with throw_R_memerr as the cleanup callback. */
SEXP safe_R_string(R_xlen_t len, SEXP *cont_token) {
  void* len_arg = static_cast<void*>(&len);
  return R_UnwindProtect(wrapped_R_string, len_arg, throw_R_memerr, cont_token, *cont_token);
}

91
92
93
94
/*! \brief Allocate a RAWSXP of length len through R_UnwindProtect, with throw_R_memerr as the cleanup callback. */
SEXP safe_R_raw(R_xlen_t len, SEXP *cont_token) {
  void* len_arg = static_cast<void*>(&len);
  return R_UnwindProtect(wrapped_R_raw, len_arg, throw_R_memerr, cont_token, *cont_token);
}

95
96
97
98
99
100
101
102
/*! \brief Allocate an INTSXP of length len through R_UnwindProtect, with throw_R_memerr as the cleanup callback. */
SEXP safe_R_int(R_xlen_t len, SEXP *cont_token) {
  void* len_arg = static_cast<void*>(&len);
  return R_UnwindProtect(wrapped_R_int, len_arg, throw_R_memerr, cont_token, *cont_token);
}

/*! \brief Allocate a REALSXP of length len through R_UnwindProtect, with throw_R_memerr as the cleanup callback. */
SEXP safe_R_real(R_xlen_t len, SEXP *cont_token) {
  void* len_arg = static_cast<void*>(&len);
  return R_UnwindProtect(wrapped_R_real, len_arg, throw_R_memerr, cont_token, *cont_token);
}

103
104
105
106
/*! \brief Create a CHARSXP from txt through R_UnwindProtect, with throw_R_memerr as the cleanup callback. */
SEXP safe_R_mkChar(char *txt, SEXP *cont_token) {
  void* txt_arg = static_cast<void*>(txt);
  return R_UnwindProtect(wrapped_Rf_mkChar, txt_arg, throw_R_memerr, cont_token, *cont_token);
}

107
108
using LightGBM::Common::Split;
using LightGBM::Log;
Guolin Ke's avatar
Guolin Ke committed
109

110
111
112
113
/*! \brief Return an R logical scalar telling whether the external pointer's address is null. */
SEXP LGBM_HandleIsNull_R(SEXP handle) {
  const bool address_is_null = (R_ExternalPtrAddr(handle) == nullptr);
  return Rf_ScalarLogical(address_is_null);
}

114
115
116
117
/*! \brief Finalizer registered on Dataset external pointers; forwards to LGBM_DatasetFree_R and ignores its result. */
void _DatasetFinalizer(SEXP handle) {
  static_cast<void>(LGBM_DatasetFree_R(handle));
}

118
119
120
121
122
123
124
125
/*! \brief Raise the R error used when a Booster handle is missing. */
SEXP LGBM_NullBoosterHandleError_R() {
  static const char message[] =
      "Attempting to use a Booster which no longer exists and/or cannot be restored. "
      "This can happen if you have called Booster$finalize() "
      "or if this Booster was saved through saveRDS() using 'serializable=FALSE'.";
  Rf_error("%s", message);
  // kept so that the function has a value of its declared return type
  return R_NilValue;
}

126
127
/*! \brief Raise the "null Booster" R error when handle is R NULL or wraps a null address. */
void _AssertBoosterHandleNotNull(SEXP handle) {
  // short-circuit: the address is only inspected for a non-NULL SEXP
  const bool handle_is_usable = !Rf_isNull(handle) && R_ExternalPtrAddr(handle) != nullptr;
  if (handle_is_usable) {
    return;
  }
  LGBM_NullBoosterHandleError_R();
}

/*! \brief Raise an R error when handle is R NULL or wraps a null address. */
void _AssertDatasetHandleNotNull(SEXP handle) {
  // short-circuit: the address is only inspected for a non-NULL SEXP
  const bool handle_is_usable = !Rf_isNull(handle) && R_ExternalPtrAddr(handle) != nullptr;
  if (handle_is_usable) {
    return;
  }
  Rf_error(
    "Attempting to use a Dataset which no longer exists. "
    "This can happen if you have called Dataset$finalize() or if this Dataset was saved with saveRDS(). "
    "To avoid this error in the future, use lgb.Dataset.save() or Dataset$save_binary() to save lightgbm Datasets.");
}

141
142
/*!
 * \brief Create a Dataset from a file via LGBM_DatasetCreateFromFile.
 * \param filename   R value converted with Rf_asChar to the input file name
 * \param parameters R value converted with Rf_asChar to the parameter string
 * \param reference  external pointer to a reference Dataset, or R NULL for none
 * \return external pointer wrapping the new handle, with _DatasetFinalizer registered
 */
SEXP LGBM_DatasetCreateFromFile_R(SEXP filename,
  SEXP parameters,
  SEXP reference) {
  R_API_BEGIN();
  SEXP out_ptr = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  DatasetHandle ref_handle = Rf_isNull(reference) ? nullptr : R_ExternalPtrAddr(reference);
  SEXP filename_chr = PROTECT(Rf_asChar(filename));
  SEXP parameters_chr = PROTECT(Rf_asChar(parameters));
  DatasetHandle new_handle = nullptr;
  CHECK_CALL(LGBM_DatasetCreateFromFile(CHAR(filename_chr), CHAR(parameters_chr), ref_handle, &new_handle));
  R_SetExternalPtrAddr(out_ptr, new_handle);
  R_RegisterCFinalizerEx(out_ptr, _DatasetFinalizer, TRUE);
  UNPROTECT(3);  // out_ptr, filename_chr, parameters_chr
  return out_ptr;
  R_API_END();
}

161
162
163
/*!
 * \brief Create a Dataset from CSC-format data via LGBM_DatasetCreateFromCSC.
 *        indptr and indices are read as R integer vectors, data as an R double vector;
 *        num_indptr, nelem and num_row are read with Rf_asInteger.
 * \param parameters R value converted with Rf_asChar to the parameter string
 * \param reference  external pointer to a reference Dataset, or R NULL for none
 * \return external pointer wrapping the new handle, with _DatasetFinalizer registered
 */
SEXP LGBM_DatasetCreateFromCSC_R(SEXP indptr,
  SEXP indices,
  SEXP data,
  SEXP num_indptr,
  SEXP nelem,
  SEXP num_row,
  SEXP parameters,
  SEXP reference) {
  R_API_BEGIN();
  SEXP out_ptr = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  const int* indptr_values = INTEGER(indptr);
  const int* index_values = INTEGER(indices);
  const double* data_values = REAL(data);
  const int64_t indptr_count = static_cast<int64_t>(Rf_asInteger(num_indptr));
  const int64_t element_count = static_cast<int64_t>(Rf_asInteger(nelem));
  const int64_t row_count = static_cast<int64_t>(Rf_asInteger(num_row));
  SEXP parameters_chr = PROTECT(Rf_asChar(parameters));
  DatasetHandle ref_handle = Rf_isNull(reference) ? nullptr : R_ExternalPtrAddr(reference);
  DatasetHandle new_handle = nullptr;
  CHECK_CALL(LGBM_DatasetCreateFromCSC(indptr_values, C_API_DTYPE_INT32, index_values,
    data_values, C_API_DTYPE_FLOAT64, indptr_count, element_count,
    row_count, CHAR(parameters_chr), ref_handle, &new_handle));
  R_SetExternalPtrAddr(out_ptr, new_handle);
  R_RegisterCFinalizerEx(out_ptr, _DatasetFinalizer, TRUE);
  UNPROTECT(2);  // out_ptr, parameters_chr
  return out_ptr;
  R_API_END();
}

193
/*!
 * \brief Create a Dataset from an R double matrix via LGBM_DatasetCreateFromMat,
 *        passing COL_MAJOR as the layout flag.
 * \param data       R double vector holding the matrix values
 * \param num_row    number of rows (read with Rf_asInteger)
 * \param num_col    number of columns (read with Rf_asInteger)
 * \param parameters R value converted with Rf_asChar to the parameter string
 * \param reference  external pointer to a reference Dataset, or R NULL for none
 * \return external pointer wrapping the new handle, with _DatasetFinalizer registered
 */
SEXP LGBM_DatasetCreateFromMat_R(SEXP data,
  SEXP num_row,
  SEXP num_col,
  SEXP parameters,
  SEXP reference) {
  R_API_BEGIN();
  SEXP out_ptr = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  const int32_t row_count = static_cast<int32_t>(Rf_asInteger(num_row));
  const int32_t col_count = static_cast<int32_t>(Rf_asInteger(num_col));
  double* matrix_values = REAL(data);
  SEXP parameters_chr = PROTECT(Rf_asChar(parameters));
  DatasetHandle ref_handle = Rf_isNull(reference) ? nullptr : R_ExternalPtrAddr(reference);
  DatasetHandle new_handle = nullptr;
  CHECK_CALL(LGBM_DatasetCreateFromMat(matrix_values, C_API_DTYPE_FLOAT64, row_count, col_count, COL_MAJOR,
    CHAR(parameters_chr), ref_handle, &new_handle));
  R_SetExternalPtrAddr(out_ptr, new_handle);
  R_RegisterCFinalizerEx(out_ptr, _DatasetFinalizer, TRUE);
  UNPROTECT(2);  // out_ptr, parameters_chr
  return out_ptr;
  R_API_END();
}

218
/*!
 * \brief Create a Dataset holding a subset of the rows of an existing Dataset.
 * \param handle               external pointer to the source Dataset (must be non-null)
 * \param used_row_indices     R integer vector of one-based row indices
 * \param len_used_row_indices number of indices to read (read with Rf_asInteger)
 * \param parameters           R value converted with Rf_asChar to the parameter string
 * \return external pointer wrapping the new handle, with _DatasetFinalizer registered
 */
SEXP LGBM_DatasetGetSubset_R(SEXP handle,
  SEXP used_row_indices,
  SEXP len_used_row_indices,
  SEXP parameters) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  int32_t len = static_cast<int32_t>(Rf_asInteger(len_used_row_indices));
  std::vector<int32_t> idxvec(len);
  // Fetch the data pointer once, before the parallel region: this keeps R API
  // calls (which may raise an R error) out of the OpenMP worker threads and
  // avoids repeating the accessor call on every iteration.
  const int* r_indices = INTEGER(used_row_indices);
  // convert from one-based to zero-based index
#pragma omp parallel for schedule(static, 512) if (len >= 1024)
  for (int32_t i = 0; i < len; ++i) {
    idxvec[i] = static_cast<int32_t>(r_indices[i] - 1);
  }
  const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters)));
  DatasetHandle res = nullptr;
  CHECK_CALL(LGBM_DatasetGetSubset(R_ExternalPtrAddr(handle),
    idxvec.data(), len, parameters_ptr,
    &res));
  R_SetExternalPtrAddr(ret, res);
  R_RegisterCFinalizerEx(ret, _DatasetFinalizer, TRUE);
  UNPROTECT(2);  // ret, parameters string
  return ret;
  R_API_END();
}

244
/*!
 * \brief Set the feature names of a Dataset.
 * \param handle        external pointer to the Dataset (must be non-null)
 * \param feature_names R value converted with Rf_asChar; names are separated by tab characters
 * \return R NULL
 */
SEXP LGBM_DatasetSetFeatureNames_R(SEXP handle,
  SEXP feature_names) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  SEXP joined_names = PROTECT(Rf_asChar(feature_names));
  auto split_names = Split(CHAR(joined_names), '\t');
  const int n_names = static_cast<int>(split_names.size());
  // the C API takes an array of C strings; the pointers refer into split_names
  std::vector<const char*> name_ptrs;
  for (const auto& single_name : split_names) {
    name_ptrs.push_back(single_name.c_str());
  }
  CHECK_CALL(LGBM_DatasetSetFeatureNames(R_ExternalPtrAddr(handle),
    name_ptrs.data(), n_names));
  UNPROTECT(1);  // joined_names
  return R_NilValue;
  R_API_END();
}

261
/*!
 * \brief Fetch the feature names of a Dataset as an R character vector.
 *        Names are first read into 256-byte buffers; if the C API reports that a
 *        larger size is required, the buffers are grown and the call is repeated.
 * \param handle external pointer to the Dataset (must be non-null)
 * \return R character vector with one element per feature
 */
SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) {
  SEXP cont_token = PROTECT(R_MakeUnwindCont());
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  int len = 0;
  CHECK_CALL(LGBM_DatasetGetNumFeature(R_ExternalPtrAddr(handle), &len));
  const size_t initial_buffer_size = 256;
  std::vector<std::vector<char>> buffers(len);
  std::vector<char*> buffer_ptrs(len);
  // (re)size every buffer and refresh the pointer array passed to the C API
  auto resize_buffers = [&buffers, &buffer_ptrs](size_t new_size) {
    for (size_t k = 0; k < buffers.size(); ++k) {
      buffers[k].resize(new_size);
      buffer_ptrs[k] = buffers[k].data();
    }
  };
  resize_buffers(initial_buffer_size);
  int out_len;
  size_t needed_buffer_size;
  CHECK_CALL(
    LGBM_DatasetGetFeatureNames(
      R_ExternalPtrAddr(handle),
      len, &out_len,
      initial_buffer_size, &needed_buffer_size,
      buffer_ptrs.data()));
  // if any feature names were larger than allocated size,
  // allow for a larger size and try again
  if (needed_buffer_size > initial_buffer_size) {
    resize_buffers(needed_buffer_size);
    CHECK_CALL(
      LGBM_DatasetGetFeatureNames(
        R_ExternalPtrAddr(handle),
        len,
        &out_len,
        needed_buffer_size,
        &needed_buffer_size,
        buffer_ptrs.data()));
  }
  CHECK_EQ(len, out_len);
  SEXP feature_names = PROTECT(safe_R_string(static_cast<R_xlen_t>(len), &cont_token));
  for (int pos = 0; pos < len; ++pos) {
    SET_STRING_ELT(feature_names, pos, safe_R_mkChar(buffer_ptrs[pos], &cont_token));
  }
  UNPROTECT(2);  // cont_token, feature_names
  return feature_names;
  R_API_END();
}

309
/*!
 * \brief Save a Dataset through LGBM_DatasetSaveBinary.
 * \param handle   external pointer to the Dataset (must be non-null)
 * \param filename R value converted with Rf_asChar to the output file name
 * \return R NULL
 */
SEXP LGBM_DatasetSaveBinary_R(SEXP handle,
  SEXP filename) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  SEXP filename_chr = PROTECT(Rf_asChar(filename));
  DatasetHandle dataset = R_ExternalPtrAddr(handle);
  CHECK_CALL(LGBM_DatasetSaveBinary(dataset, CHAR(filename_chr)));
  UNPROTECT(1);  // filename_chr
  return R_NilValue;
  R_API_END();
}

321
/*!
 * \brief Free the Dataset behind an external pointer and clear the pointer.
 *        Does nothing when handle is R NULL or already wraps a null address.
 * \return R NULL
 */
SEXP LGBM_DatasetFree_R(SEXP handle) {
  R_API_BEGIN();
  const bool has_live_handle = !Rf_isNull(handle) && R_ExternalPtrAddr(handle) != nullptr;
  if (has_live_handle) {
    DatasetHandle dataset = R_ExternalPtrAddr(handle);
    CHECK_CALL(LGBM_DatasetFree(dataset));
    R_ClearExternalPtr(handle);
  }
  return R_NilValue;
  R_API_END();
}

331
/*!
 * \brief Set a named field of a Dataset.
 *        - "group" / "query": field_data is read as an R integer vector and passed as int32
 *        - "init_score":      field_data is read as an R double vector and passed as float64
 *        - anything else:     field_data is read as an R double vector and converted to float32
 * \param handle      external pointer to the Dataset (must be non-null)
 * \param field_name  R value converted with Rf_asChar to the field name
 * \param field_data  values to store
 * \param num_element number of elements to read (read with Rf_asInteger)
 * \return R NULL
 */
SEXP LGBM_DatasetSetField_R(SEXP handle,
  SEXP field_name,
  SEXP field_data,
  SEXP num_element) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  int len = Rf_asInteger(num_element);
  const char* name = CHAR(PROTECT(Rf_asChar(field_name)));
  if (!strcmp("group", name) || !strcmp("query", name)) {
    // Data pointers are fetched once, before the parallel region, so that no
    // R API call is made from the OpenMP worker threads.
    const int* src = INTEGER(field_data);
    std::vector<int32_t> vec(len);
#pragma omp parallel for schedule(static, 512) if (len >= 1024)
    for (int i = 0; i < len; ++i) {
      vec[i] = static_cast<int32_t>(src[i]);
    }
    CHECK_CALL(LGBM_DatasetSetField(R_ExternalPtrAddr(handle), name, vec.data(), len, C_API_DTYPE_INT32));
  } else if (!strcmp("init_score", name)) {
    CHECK_CALL(LGBM_DatasetSetField(R_ExternalPtrAddr(handle), name, REAL(field_data), len, C_API_DTYPE_FLOAT64));
  } else {
    const double* src = REAL(field_data);
    std::vector<float> vec(len);
#pragma omp parallel for schedule(static, 512) if (len >= 1024)
    for (int i = 0; i < len; ++i) {
      vec[i] = static_cast<float>(src[i]);
    }
    CHECK_CALL(LGBM_DatasetSetField(R_ExternalPtrAddr(handle), name, vec.data(), len, C_API_DTYPE_FLOAT32));
  }
  UNPROTECT(1);  // field name string
  return R_NilValue;
  R_API_END();
}

361
/*!
 * \brief Copy a named field of a Dataset into a pre-allocated R vector.
 *        - "group" / "query": the stored values are treated as boundaries; the
 *          out_len - 1 differences between consecutive boundaries are written
 *          into field_data (R integer vector)
 *        - "init_score":      out_len doubles are copied into field_data (R double vector)
 *        - anything else:     out_len floats are widened into field_data (R double vector)
 * \param handle     external pointer to the Dataset (must be non-null)
 * \param field_name R value converted with Rf_asChar to the field name
 * \param field_data destination vector, filled in place
 * \return R NULL
 */
SEXP LGBM_DatasetGetField_R(SEXP handle,
  SEXP field_name,
  SEXP field_data) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  const char* name = CHAR(PROTECT(Rf_asChar(field_name)));
  int out_len = 0;
  int out_type = 0;
  const void* res = nullptr;
  CHECK_CALL(LGBM_DatasetGetField(R_ExternalPtrAddr(handle), name, &out_len, &res, &out_type));
  // In each branch the destination pointer is fetched once, before the parallel
  // region, so that no R API call is made from the OpenMP worker threads.
  if (!strcmp("group", name) || !strcmp("query", name)) {
    auto p_data = reinterpret_cast<const int32_t*>(res);
    int* dest = INTEGER(field_data);
    // convert from boundaries to size
#pragma omp parallel for schedule(static, 512) if (out_len >= 1024)
    for (int i = 0; i < out_len - 1; ++i) {
      dest[i] = p_data[i + 1] - p_data[i];
    }
  } else if (!strcmp("init_score", name)) {
    auto p_data = reinterpret_cast<const double*>(res);
    double* dest = REAL(field_data);
#pragma omp parallel for schedule(static, 512) if (out_len >= 1024)
    for (int i = 0; i < out_len; ++i) {
      dest[i] = p_data[i];
    }
  } else {
    auto p_data = reinterpret_cast<const float*>(res);
    double* dest = REAL(field_data);
#pragma omp parallel for schedule(static, 512) if (out_len >= 1024)
    for (int i = 0; i < out_len; ++i) {
      dest[i] = p_data[i];
    }
  }
  UNPROTECT(1);  // field name string
  return R_NilValue;
  R_API_END();
}

396
/*!
 * \brief Write the number of elements of a named Dataset field into out[0].
 *        For "group" / "query" the reported length is reduced by one.
 * \param handle     external pointer to the Dataset (must be non-null)
 * \param field_name R value converted with Rf_asChar to the field name
 * \param out        R integer vector whose first element receives the size
 * \return R NULL
 */
SEXP LGBM_DatasetGetFieldSize_R(SEXP handle,
  SEXP field_name,
  SEXP out) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  SEXP field_name_chr = PROTECT(Rf_asChar(field_name));
  const char* field = CHAR(field_name_chr);
  int n_elements = 0;
  int element_type = 0;
  const void* field_ptr;
  CHECK_CALL(LGBM_DatasetGetField(R_ExternalPtrAddr(handle), field, &n_elements, &field_ptr, &element_type));
  const bool is_group_field = (strcmp("group", field) == 0) || (strcmp("query", field) == 0);
  if (is_group_field) {
    --n_elements;
  }
  INTEGER(out)[0] = n_elements;
  UNPROTECT(1);  // field_name_chr
  return R_NilValue;
  R_API_END();
}

415
416
/*!
 * \brief Forward two parameter strings to LGBM_DatasetUpdateParamChecking.
 * \param old_params R value converted with Rf_asChar to the current parameter string
 * \param new_params R value converted with Rf_asChar to the new parameter string
 * \return R NULL
 */
SEXP LGBM_DatasetUpdateParamChecking_R(SEXP old_params,
  SEXP new_params) {
  R_API_BEGIN();
  SEXP old_chr = PROTECT(Rf_asChar(old_params));
  SEXP new_chr = PROTECT(Rf_asChar(new_params));
  CHECK_CALL(LGBM_DatasetUpdateParamChecking(CHAR(old_chr), CHAR(new_chr)));
  UNPROTECT(2);  // old_chr, new_chr
  return R_NilValue;
  R_API_END();
}

426
/*!
 * \brief Write the value reported by LGBM_DatasetGetNumData into out[0].
 * \param handle external pointer to the Dataset (must be non-null)
 * \param out    R integer vector whose first element receives the value
 * \return R NULL
 */
SEXP LGBM_DatasetGetNumData_R(SEXP handle, SEXP out) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  int n_data = 0;
  CHECK_CALL(LGBM_DatasetGetNumData(R_ExternalPtrAddr(handle), &n_data));
  int* out_slot = INTEGER(out);
  out_slot[0] = n_data;
  return R_NilValue;
  R_API_END();
}

436
/*!
 * \brief Write the value reported by LGBM_DatasetGetNumFeature into out[0].
 * \param handle external pointer to the Dataset (must be non-null)
 * \param out    R integer vector whose first element receives the value
 * \return R NULL
 */
SEXP LGBM_DatasetGetNumFeature_R(SEXP handle,
  SEXP out) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  int n_features = 0;
  CHECK_CALL(LGBM_DatasetGetNumFeature(R_ExternalPtrAddr(handle), &n_features));
  int* out_slot = INTEGER(out);
  out_slot[0] = n_features;
  return R_NilValue;
  R_API_END();
}

447
448
449
450
451
452
453
454
455
456
457
/*!
 * \brief Write the value reported by LGBM_DatasetGetFeatureNumBin into out[0].
 * \param handle      external pointer to the Dataset (must be non-null)
 * \param feature_idx feature index (read with Rf_asInteger and passed on unchanged)
 * \param out         R integer vector whose first element receives the value
 * \return R NULL
 */
SEXP LGBM_DatasetGetFeatureNumBin_R(SEXP handle, SEXP feature_idx, SEXP out) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(handle);
  const int which_feature = Rf_asInteger(feature_idx);
  int bin_count = 0;
  CHECK_CALL(LGBM_DatasetGetFeatureNumBin(R_ExternalPtrAddr(handle), which_feature, &bin_count));
  int* out_slot = INTEGER(out);
  out_slot[0] = bin_count;
  return R_NilValue;
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
458
459
// --- start Booster interfaces

460
461
462
463
/*! \brief Finalizer registered on Booster external pointers; forwards to LGBM_BoosterFree_R and ignores its result. */
void _BoosterFinalizer(SEXP handle) {
  static_cast<void>(LGBM_BoosterFree_R(handle));
}

464
/*!
 * \brief Free the Booster behind an external pointer and clear the pointer.
 *        Does nothing when handle is R NULL or already wraps a null address.
 * \return R NULL
 */
SEXP LGBM_BoosterFree_R(SEXP handle) {
  R_API_BEGIN();
  const bool has_live_handle = !Rf_isNull(handle) && R_ExternalPtrAddr(handle) != nullptr;
  if (has_live_handle) {
    BoosterHandle booster = R_ExternalPtrAddr(handle);
    CHECK_CALL(LGBM_BoosterFree(booster));
    R_ClearExternalPtr(handle);
  }
  return R_NilValue;
  R_API_END();
}

474
475
/*!
 * \brief Create a Booster for a training Dataset via LGBM_BoosterCreate.
 * \param train_data external pointer to the training Dataset (must be non-null)
 * \param parameters R value converted with Rf_asChar to the parameter string
 * \return external pointer wrapping the new handle, with _BoosterFinalizer registered
 */
SEXP LGBM_BoosterCreate_R(SEXP train_data,
  SEXP parameters) {
  R_API_BEGIN();
  _AssertDatasetHandleNotNull(train_data);
  SEXP out_ptr = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  SEXP parameters_chr = PROTECT(Rf_asChar(parameters));
  BoosterHandle new_booster = nullptr;
  CHECK_CALL(LGBM_BoosterCreate(R_ExternalPtrAddr(train_data), CHAR(parameters_chr), &new_booster));
  R_SetExternalPtrAddr(out_ptr, new_booster);
  R_RegisterCFinalizerEx(out_ptr, _BoosterFinalizer, TRUE);
  UNPROTECT(2);  // out_ptr, parameters_chr
  return out_ptr;
  R_API_END();
}

489
/*!
 * \brief Create a Booster from a model file via LGBM_BoosterCreateFromModelfile.
 *        The iteration count reported by the C API is not returned to R.
 * \param filename R value converted with Rf_asChar to the model file name
 * \return external pointer wrapping the new handle, with _BoosterFinalizer registered
 */
SEXP LGBM_BoosterCreateFromModelfile_R(SEXP filename) {
  R_API_BEGIN();
  SEXP out_ptr = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  int n_iterations_unused = 0;
  SEXP filename_chr = PROTECT(Rf_asChar(filename));
  BoosterHandle new_booster = nullptr;
  CHECK_CALL(LGBM_BoosterCreateFromModelfile(CHAR(filename_chr), &n_iterations_unused, &new_booster));
  R_SetExternalPtrAddr(out_ptr, new_booster);
  R_RegisterCFinalizerEx(out_ptr, _BoosterFinalizer, TRUE);
  UNPROTECT(2);  // out_ptr, filename_chr
  return out_ptr;
  R_API_END();
}

503
/*!
 * \brief Create a Booster from an in-memory model via LGBM_BoosterLoadModelFromString.
 *        Accepted inputs are a raw vector, a CHARSXP, or a character vector (its
 *        first element is used). Any other type raises an error instead of
 *        handing a null pointer to the C API.
 *        The iteration count reported by the C API is not returned to R.
 * \param model_str model text as RAWSXP, CHARSXP or STRSXP
 * \return external pointer wrapping the new handle, with _BoosterFinalizer registered
 */
SEXP LGBM_BoosterLoadModelFromString_R(SEXP model_str) {
  R_API_BEGIN();
  SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  SEXP temp = NULL;
  int n_protected = 1;
  int out_num_iterations = 0;
  const char* model_str_ptr = nullptr;
  switch (TYPEOF(model_str)) {
    case RAWSXP: {
      model_str_ptr = reinterpret_cast<const char*>(RAW(model_str));
      break;
    }
    case CHARSXP: {
      model_str_ptr = reinterpret_cast<const char*>(CHAR(model_str));
      break;
    }
    case STRSXP: {
      temp = PROTECT(STRING_ELT(model_str, 0));
      n_protected++;
      model_str_ptr = reinterpret_cast<const char*>(CHAR(temp));
      break;
    }
    default: {
      // converted into an R error by R_API_END()
      throw std::runtime_error("model_str must be a raw vector or a character string");
    }
  }
  BoosterHandle handle = nullptr;
  CHECK_CALL(LGBM_BoosterLoadModelFromString(model_str_ptr, &out_num_iterations, &handle));
  R_SetExternalPtrAddr(ret, handle);
  R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE);
  UNPROTECT(n_protected);
  return ret;
  R_API_END();
}

534
535
/*!
 * \brief Forward two Booster handles to LGBM_BoosterMerge.
 * \param handle       external pointer to the first Booster (must be non-null)
 * \param other_handle external pointer to the second Booster (must be non-null)
 * \return R NULL
 */
SEXP LGBM_BoosterMerge_R(SEXP handle,
  SEXP other_handle) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  _AssertBoosterHandleNotNull(other_handle);
  BoosterHandle first = R_ExternalPtrAddr(handle);
  BoosterHandle second = R_ExternalPtrAddr(other_handle);
  CHECK_CALL(LGBM_BoosterMerge(first, second));
  return R_NilValue;
  R_API_END();
}

544
545
/*!
 * \brief Forward a Booster and a Dataset handle to LGBM_BoosterAddValidData.
 * \param handle     external pointer to the Booster (must be non-null)
 * \param valid_data external pointer to the Dataset (must be non-null)
 * \return R NULL
 */
SEXP LGBM_BoosterAddValidData_R(SEXP handle,
  SEXP valid_data) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  _AssertDatasetHandleNotNull(valid_data);
  BoosterHandle booster = R_ExternalPtrAddr(handle);
  DatasetHandle dataset = R_ExternalPtrAddr(valid_data);
  CHECK_CALL(LGBM_BoosterAddValidData(booster, dataset));
  return R_NilValue;
  R_API_END();
}

554
555
/*!
 * \brief Forward a Booster and a Dataset handle to LGBM_BoosterResetTrainingData.
 * \param handle     external pointer to the Booster (must be non-null)
 * \param train_data external pointer to the Dataset (must be non-null)
 * \return R NULL
 */
SEXP LGBM_BoosterResetTrainingData_R(SEXP handle,
  SEXP train_data) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  _AssertDatasetHandleNotNull(train_data);
  BoosterHandle booster = R_ExternalPtrAddr(handle);
  DatasetHandle dataset = R_ExternalPtrAddr(train_data);
  CHECK_CALL(LGBM_BoosterResetTrainingData(booster, dataset));
  return R_NilValue;
  R_API_END();
}

564
/*!
 * \brief Forward a parameter string to LGBM_BoosterResetParameter.
 * \param handle     external pointer to the Booster (must be non-null)
 * \param parameters R value converted with Rf_asChar to the parameter string
 * \return R NULL
 */
SEXP LGBM_BoosterResetParameter_R(SEXP handle,
  SEXP parameters) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  SEXP parameters_chr = PROTECT(Rf_asChar(parameters));
  BoosterHandle booster = R_ExternalPtrAddr(handle);
  CHECK_CALL(LGBM_BoosterResetParameter(booster, CHAR(parameters_chr)));
  UNPROTECT(1);  // parameters_chr
  return R_NilValue;
  R_API_END();
}

575
/*!
 * \brief Write the value reported by LGBM_BoosterGetNumClasses into out[0].
 * \param handle external pointer to the Booster (must be non-null)
 * \param out    R integer vector whose first element receives the value
 * \return R NULL
 */
SEXP LGBM_BoosterGetNumClasses_R(SEXP handle,
  SEXP out) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int class_count = 0;
  CHECK_CALL(LGBM_BoosterGetNumClasses(R_ExternalPtrAddr(handle), &class_count));
  int* out_slot = INTEGER(out);
  out_slot[0] = class_count;
  return R_NilValue;
  R_API_END();
}

586
587
588
589
590
591
592
593
594
/*!
 * \brief Return the value reported by LGBM_BoosterGetNumFeature.
 * \param handle external pointer to the Booster (must be non-null)
 * \return R integer scalar
 */
SEXP LGBM_BoosterGetNumFeature_R(SEXP handle) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int feature_count = 0;
  BoosterHandle booster = R_ExternalPtrAddr(handle);
  CHECK_CALL(LGBM_BoosterGetNumFeature(booster, &feature_count));
  return Rf_ScalarInteger(feature_count);
  R_API_END();
}

595
/*!
 * \brief Call LGBM_BoosterUpdateOneIter on the Booster.
 *        The "finished" flag reported by the C API is not returned to R.
 * \param handle external pointer to the Booster (must be non-null)
 * \return R NULL
 */
SEXP LGBM_BoosterUpdateOneIter_R(SEXP handle) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int finished_flag = 0;
  BoosterHandle booster = R_ExternalPtrAddr(handle);
  CHECK_CALL(LGBM_BoosterUpdateOneIter(booster, &finished_flag));
  return R_NilValue;
  R_API_END();
}

604
/*!
 * \brief Call LGBM_BoosterUpdateOneIterCustom with user-supplied gradients and hessians.
 *        Both inputs are R double vectors; they are converted to float before
 *        being passed to the C API. The "finished" flag is not returned to R.
 * \param handle external pointer to the Booster (must be non-null)
 * \param grad   R double vector of gradients
 * \param hess   R double vector of hessians
 * \param len    number of elements to read from grad and hess (read with Rf_asInteger)
 * \return R NULL
 */
SEXP LGBM_BoosterUpdateOneIterCustom_R(SEXP handle,
  SEXP grad,
  SEXP hess,
  SEXP len) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int is_finished = 0;
  int int_len = Rf_asInteger(len);
  // Fetch the data pointers once, before the parallel region, so that no
  // R API call is made from the OpenMP worker threads.
  const double* grad_src = REAL(grad);
  const double* hess_src = REAL(hess);
  std::vector<float> tgrad(int_len), thess(int_len);
#pragma omp parallel for schedule(static, 512) if (int_len >= 1024)
  for (int j = 0; j < int_len; ++j) {
    tgrad[j] = static_cast<float>(grad_src[j]);
    thess[j] = static_cast<float>(hess_src[j]);
  }
  CHECK_CALL(LGBM_BoosterUpdateOneIterCustom(R_ExternalPtrAddr(handle), tgrad.data(), thess.data(), &is_finished));
  return R_NilValue;
  R_API_END();
}

623
/*!
 * \brief Call LGBM_BoosterRollbackOneIter on the Booster.
 * \param handle external pointer to the Booster (must be non-null)
 * \return R NULL
 */
SEXP LGBM_BoosterRollbackOneIter_R(SEXP handle) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  BoosterHandle booster = R_ExternalPtrAddr(handle);
  CHECK_CALL(LGBM_BoosterRollbackOneIter(booster));
  return R_NilValue;
  R_API_END();
}

631
/*!
 * \brief Write the value reported by LGBM_BoosterGetCurrentIteration into out[0].
 * \param handle external pointer to the Booster (must be non-null)
 * \param out    R integer vector whose first element receives the value
 * \return R NULL
 */
SEXP LGBM_BoosterGetCurrentIteration_R(SEXP handle,
  SEXP out) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int current_iter = 0;
  CHECK_CALL(LGBM_BoosterGetCurrentIteration(R_ExternalPtrAddr(handle), &current_iter));
  int* out_slot = INTEGER(out);
  out_slot[0] = current_iter;
  return R_NilValue;
  R_API_END();
}

642
/*!
 * \brief Let LGBM_BoosterGetUpperBoundValue write its result into out_result.
 * \param handle     external pointer to the Booster (must be non-null)
 * \param out_result R double vector receiving the value
 * \return R NULL
 */
SEXP LGBM_BoosterGetUpperBoundValue_R(SEXP handle,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  double* destination = REAL(out_result);
  BoosterHandle booster = R_ExternalPtrAddr(handle);
  CHECK_CALL(LGBM_BoosterGetUpperBoundValue(booster, destination));
  return R_NilValue;
  R_API_END();
}

652
/*!
 * \brief Let LGBM_BoosterGetLowerBoundValue write its result into out_result.
 * \param handle     external pointer to the Booster (must be non-null)
 * \param out_result R double vector receiving the value
 * \return R NULL
 */
SEXP LGBM_BoosterGetLowerBoundValue_R(SEXP handle,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  double* destination = REAL(out_result);
  BoosterHandle booster = R_ExternalPtrAddr(handle);
  CHECK_CALL(LGBM_BoosterGetLowerBoundValue(booster, destination));
  return R_NilValue;
  R_API_END();
}

662
/*!
 * \brief Fetch the evaluation names of a Booster as an R character vector.
 *        Names are first read into 128-byte buffers; if the C API reports that a
 *        larger size is required, the buffers are grown and the call is repeated.
 * \param handle external pointer to the Booster (must be non-null)
 * \return R character vector with one element per evaluation name
 */
SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) {
  SEXP cont_token = PROTECT(R_MakeUnwindCont());
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int len;
  CHECK_CALL(LGBM_BoosterGetEvalCounts(R_ExternalPtrAddr(handle), &len));
  const size_t initial_buffer_size = 128;
  std::vector<std::vector<char>> buffers(len);
  std::vector<char*> buffer_ptrs(len);
  // (re)size every buffer and refresh the pointer array passed to the C API
  auto resize_buffers = [&buffers, &buffer_ptrs](size_t new_size) {
    for (size_t k = 0; k < buffers.size(); ++k) {
      buffers[k].resize(new_size);
      buffer_ptrs[k] = buffers[k].data();
    }
  };
  resize_buffers(initial_buffer_size);

  int out_len;
  size_t needed_buffer_size;
  CHECK_CALL(
    LGBM_BoosterGetEvalNames(
      R_ExternalPtrAddr(handle),
      len, &out_len,
      initial_buffer_size, &needed_buffer_size,
      buffer_ptrs.data()));
  // if any eval names were larger than allocated size,
  // allow for a larger size and try again
  if (needed_buffer_size > initial_buffer_size) {
    resize_buffers(needed_buffer_size);
    CHECK_CALL(
      LGBM_BoosterGetEvalNames(
        R_ExternalPtrAddr(handle),
        len,
        &out_len,
        needed_buffer_size,
        &needed_buffer_size,
        buffer_ptrs.data()));
  }
  CHECK_EQ(out_len, len);
  SEXP eval_names = PROTECT(safe_R_string(static_cast<R_xlen_t>(len), &cont_token));
  for (int pos = 0; pos < len; ++pos) {
    SET_STRING_ELT(eval_names, pos, safe_R_mkChar(buffer_ptrs[pos], &cont_token));
  }
  UNPROTECT(2);  // cont_token, eval_names
  return eval_names;
  R_API_END();
}

711
/*!
 * \brief Let LGBM_BoosterGetEval write evaluation results into out_result, then
 *        check that the number written equals LGBM_BoosterGetEvalCounts.
 * \param handle     external pointer to the Booster (must be non-null)
 * \param data_idx   dataset index (read with Rf_asInteger and passed on unchanged)
 * \param out_result R double vector receiving the results
 * \return R NULL
 */
SEXP LGBM_BoosterGetEval_R(SEXP handle,
  SEXP data_idx,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int len;
  CHECK_CALL(LGBM_BoosterGetEvalCounts(R_ExternalPtrAddr(handle), &len));
  double* destination = REAL(out_result);
  const int which_data = Rf_asInteger(data_idx);
  int out_len;
  CHECK_CALL(LGBM_BoosterGetEval(R_ExternalPtrAddr(handle), which_data, &out_len, destination));
  CHECK_EQ(out_len, len);
  return R_NilValue;
  R_API_END();
}

726
/*!
 * \brief Write the value reported by LGBM_BoosterGetNumPredict, cast to int, into out[0].
 * \param handle   external pointer to the Booster (must be non-null)
 * \param data_idx dataset index (read with Rf_asInteger and passed on unchanged)
 * \param out      R integer vector whose first element receives the value
 * \return R NULL
 */
SEXP LGBM_BoosterGetNumPredict_R(SEXP handle,
  SEXP data_idx,
  SEXP out) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  const int which_data = Rf_asInteger(data_idx);
  int64_t predict_count;
  CHECK_CALL(LGBM_BoosterGetNumPredict(R_ExternalPtrAddr(handle), which_data, &predict_count));
  int* out_slot = INTEGER(out);
  out_slot[0] = static_cast<int>(predict_count);
  return R_NilValue;
  R_API_END();
}

738
/*!
 * \brief Let LGBM_BoosterGetPredict write predictions into out_result.
 *        The length reported by the C API is not returned to R.
 * \param handle     external pointer to the Booster (must be non-null)
 * \param data_idx   dataset index (read with Rf_asInteger and passed on unchanged)
 * \param out_result R double vector receiving the predictions
 * \return R NULL
 */
SEXP LGBM_BoosterGetPredict_R(SEXP handle,
  SEXP data_idx,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  double* destination = REAL(out_result);
  const int which_data = Rf_asInteger(data_idx);
  int64_t written_unused;
  CHECK_CALL(LGBM_BoosterGetPredict(R_ExternalPtrAddr(handle), which_data, &written_unused, destination));
  return R_NilValue;
  R_API_END();
}

750
// Maps the three R-level flags onto a C_API_PREDICT_* constant.
// When several flags are set the precedence is: contributions, then leaf
// index, then raw score; with none set the result is C_API_PREDICT_NORMAL.
int GetPredictType(SEXP is_rawscore, SEXP is_leafidx, SEXP is_predcontrib) {
  // all three flags are always converted, in this order
  const bool want_raw_score = Rf_asInteger(is_rawscore) != 0;
  const bool want_leaf_index = Rf_asInteger(is_leafidx) != 0;
  const bool want_contrib = Rf_asInteger(is_predcontrib) != 0;
  if (want_contrib) {
    return C_API_PREDICT_CONTRIB;
  }
  if (want_leaf_index) {
    return C_API_PREDICT_LEAF_INDEX;
  }
  if (want_raw_score) {
    return C_API_PREDICT_RAW_SCORE;
  }
  return C_API_PREDICT_NORMAL;
}

764
// Forwards a file-based prediction request to LGBM_BoosterPredictForFile.
// The prediction type is derived from the three is_* flags via GetPredictType;
// the string arguments are converted with Rf_asChar and kept protected for the
// duration of the call. Returns R_NilValue.
SEXP LGBM_BoosterPredictForFile_R(SEXP handle,
  SEXP data_filename,
  SEXP data_has_header,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter,
  SEXP result_filename) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  SEXP data_filename_chr = PROTECT(Rf_asChar(data_filename));
  SEXP parameter_chr = PROTECT(Rf_asChar(parameter));
  SEXP result_filename_chr = PROTECT(Rf_asChar(result_filename));
  const int predict_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  CHECK_CALL(LGBM_BoosterPredictForFile(
    R_ExternalPtrAddr(handle),
    CHAR(data_filename_chr),
    Rf_asInteger(data_has_header),
    predict_type,
    Rf_asInteger(start_iteration),
    Rf_asInteger(num_iteration),
    CHAR(parameter_chr),
    CHAR(result_filename_chr)));
  // release the three CHARSXPs protected above
  UNPROTECT(3);
  return R_NilValue;
  R_API_END();
}

788
// Stores the value computed by LGBM_BoosterCalcNumPredict in the first element
// of the integer vector `out_len`. The 64-bit result is narrowed to int.
// Returns R_NilValue.
SEXP LGBM_BoosterCalcNumPredict_R(SEXP handle,
  SEXP num_row,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP out_len) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  const int predict_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  int64_t num_predictions = 0;
  CHECK_CALL(LGBM_BoosterCalcNumPredict(
    R_ExternalPtrAddr(handle),
    Rf_asInteger(num_row),
    predict_type,
    Rf_asInteger(start_iteration),
    Rf_asInteger(num_iteration),
    &num_predictions));
  int* out_slot = INTEGER(out_len);
  out_slot[0] = static_cast<int>(num_predictions);
  return R_NilValue;
  R_API_END();
}

807
// Forwards a prediction request to LGBM_BoosterPredictForCSC.
// `indptr` and `indices` are read as integer vectors and `data` as a numeric
// vector; results are written directly into the numeric vector `out_result`.
// Returns R_NilValue.
SEXP LGBM_BoosterPredictForCSC_R(SEXP handle,
  SEXP indptr,
  SEXP indices,
  SEXP data,
  SEXP num_indptr,
  SEXP nelem,
  SEXP num_row,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  const int predict_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  // raw views over the input vectors supplied from R
  const int* indptr_buf = INTEGER(indptr);
  const int32_t* indices_buf = reinterpret_cast<const int32_t*>(INTEGER(indices));
  const double* values_buf = REAL(data);
  // sizes are passed to the C API as 64-bit integers
  const int64_t indptr_len = static_cast<int64_t>(Rf_asInteger(num_indptr));
  const int64_t value_count = static_cast<int64_t>(Rf_asInteger(nelem));
  const int64_t row_count = static_cast<int64_t>(Rf_asInteger(num_row));
  double* result_buf = REAL(out_result);
  int64_t written_len;
  SEXP parameter_chr = PROTECT(Rf_asChar(parameter));
  CHECK_CALL(LGBM_BoosterPredictForCSC(
    R_ExternalPtrAddr(handle),
    indptr_buf, C_API_DTYPE_INT32,
    indices_buf,
    values_buf, C_API_DTYPE_FLOAT64,
    indptr_len, value_count, row_count,
    predict_type,
    Rf_asInteger(start_iteration),
    Rf_asInteger(num_iteration),
    CHAR(parameter_chr),
    &written_len,
    result_buf));
  UNPROTECT(1);
  return R_NilValue;
  R_API_END();
}

842
// Forwards a prediction request to LGBM_BoosterPredictForMat.
// `data` is read as a numeric vector and handed over with the COL_MAJOR flag;
// results are written directly into the numeric vector `out_result`.
// Returns R_NilValue.
SEXP LGBM_BoosterPredictForMat_R(SEXP handle,
  SEXP data,
  SEXP num_row,
  SEXP num_col,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter,
  SEXP out_result) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  const int predict_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  const int32_t row_count = static_cast<int32_t>(Rf_asInteger(num_row));
  const int32_t col_count = static_cast<int32_t>(Rf_asInteger(num_col));
  const double* matrix_buf = REAL(data);
  double* result_buf = REAL(out_result);
  SEXP parameter_chr = PROTECT(Rf_asChar(parameter));
  int64_t written_len;
  CHECK_CALL(LGBM_BoosterPredictForMat(
    R_ExternalPtrAddr(handle),
    matrix_buf, C_API_DTYPE_FLOAT64,
    row_count, col_count, COL_MAJOR,
    predict_type,
    Rf_asInteger(start_iteration),
    Rf_asInteger(num_iteration),
    CHAR(parameter_chr),
    &written_len,
    result_buf));
  UNPROTECT(1);
  return R_NilValue;
  R_API_END();
}

870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
// Bundles the three buffers of a sparse prediction result so that they can be
// handed around (and released) as a single unit.
struct SparseOutputPointers {
  void* indptr;
  int32_t* indices;
  void* data;
  // NOTE: the constructor does not assign the two type fields below
  int indptr_type;
  int data_type;

  // Stores the three buffer pointers as given; nothing is copied.
  SparseOutputPointers(void* indptr_buf, int32_t* indices_buf, void* data_buf)
  : indptr(indptr_buf),
    indices(indices_buf),
    data(data_buf) {}
};

// Releases the buffers referenced by `ptr` through LGBM_BoosterFreePredictSparse
// (always as int32 indptr / float64 data) and then destroys `ptr` itself.
void delete_SparseOutputPointers(SparseOutputPointers *ptr) {
  // takes over `ptr` so that it is deleted once the buffers have been freed
  std::unique_ptr<SparseOutputPointers> holder(ptr);
  LGBM_BoosterFreePredictSparse(
    holder->indptr,
    holder->indices,
    holder->data,
    C_API_DTYPE_INT32,
    C_API_DTYPE_FLOAT64);
}

// Runs LGBM_BoosterPredictSparseOutput with C_API_PREDICT_CONTRIB and returns
// the result as a named R list with elements "indptr", "indices" and "data".
// `is_csr` selects between CSR and CSC; for CSR `ncols` is passed to the C API,
// otherwise `nrows`. The buffers produced by the C API are copied into freshly
// allocated R vectors and then released.
SEXP LGBM_BoosterPredictSparseOutput_R(SEXP handle,
  SEXP indptr,
  SEXP indices,
  SEXP data,
  SEXP is_csr,
  SEXP nrows,
  SEXP ncols,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter) {
  SEXP cont_token = PROTECT(R_MakeUnwindCont());
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  const char* result_names[] = {"indptr", "indices", "data", ""};
  SEXP out = PROTECT(Rf_mkNamed(VECSXP, result_names));
  SEXP parameter_chr = PROTECT(Rf_asChar(parameter));

  const bool input_is_csr = Rf_asLogical(is_csr) != 0;
  const int num_col_or_row = input_is_csr ? Rf_asInteger(ncols) : Rf_asInteger(nrows);
  const int matrix_type = input_is_csr ? C_API_MATRIX_TYPE_CSR : C_API_MATRIX_TYPE_CSC;

  // res_sizes[0]: length of indices/data, res_sizes[1]: length of indptr
  int64_t res_sizes[2];
  void *res_indptr;
  int32_t *res_indices;
  void *res_data;

  CHECK_CALL(LGBM_BoosterPredictSparseOutput(R_ExternalPtrAddr(handle),
    INTEGER(indptr), C_API_DTYPE_INT32, INTEGER(indices),
    REAL(data), C_API_DTYPE_FLOAT64,
    Rf_xlength(indptr), Rf_xlength(data),
    num_col_or_row,
    C_API_PREDICT_CONTRIB, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration),
    CHAR(parameter_chr),
    matrix_type,
    res_sizes, &res_indptr, &res_indices, &res_data));

  // from here on the C API buffers are released when `buffers_guard` is destroyed
  std::unique_ptr<SparseOutputPointers, decltype(&delete_SparseOutputPointers)> buffers_guard(
    new SparseOutputPointers(res_indptr, res_indices, res_data),
    &delete_SparseOutputPointers);

  // each vector is attached to `out` right after it has been allocated
  SEXP indptr_vec = safe_R_int(res_sizes[1], &cont_token);
  SET_VECTOR_ELT(out, 0, indptr_vec);
  SEXP indices_vec = safe_R_int(res_sizes[0], &cont_token);
  SET_VECTOR_ELT(out, 1, indices_vec);
  SEXP data_vec = safe_R_real(res_sizes[0], &cont_token);
  SET_VECTOR_ELT(out, 2, data_vec);

  std::memcpy(INTEGER(indptr_vec), res_indptr, res_sizes[1] * sizeof(int));
  std::memcpy(INTEGER(indices_vec), res_indices, res_sizes[0] * sizeof(int));
  std::memcpy(REAL(data_vec), res_data, res_sizes[0] * sizeof(double));

  UNPROTECT(3);
  return out;
  R_API_END();
}

940
// Forwards to LGBM_BoosterSaveModel with a start iteration of 0, using the
// file name obtained from `filename` via Rf_asChar. Returns R_NilValue.
SEXP LGBM_BoosterSaveModel_R(SEXP handle,
  SEXP num_iteration,
  SEXP feature_importance_type,
  SEXP filename) {
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  SEXP filename_chr = PROTECT(Rf_asChar(filename));
  CHECK_CALL(LGBM_BoosterSaveModel(
    R_ExternalPtrAddr(handle),
    0,
    Rf_asInteger(num_iteration),
    Rf_asInteger(feature_importance_type),
    CHAR(filename_chr)));
  UNPROTECT(1);
  return R_NilValue;
  R_API_END();
}

953
// Serializes the Booster with LGBM_BoosterSaveModelToString and returns the
// bytes as an R raw vector. A 1 MiB scratch buffer is tried first; when the
// model needs more room the C API is called again, writing into the raw vector.
SEXP LGBM_BoosterSaveModelToString_R(SEXP handle,
  SEXP num_iteration,
  SEXP feature_importance_type) {
  SEXP cont_token = PROTECT(R_MakeUnwindCont());
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int64_t required_len = 0;
  const int64_t scratch_len = 1024 * 1024;
  const int iter_count = Rf_asInteger(num_iteration);
  const int importance_kind = Rf_asInteger(feature_importance_type);
  std::vector<char> scratch(scratch_len);
  CHECK_CALL(LGBM_BoosterSaveModelToString(
    R_ExternalPtrAddr(handle), 0, iter_count, importance_kind,
    scratch_len, &required_len, scratch.data()));
  SEXP model_str = PROTECT(safe_R_raw(required_len, &cont_token));
  char* dest = reinterpret_cast<char*>(RAW(model_str));
  if (required_len <= scratch_len) {
    // the first attempt already produced the complete model string
    std::copy_n(scratch.begin(), required_len, dest);
  } else {
    // scratch buffer was too small: serialize again, straight into the R object
    CHECK_CALL(LGBM_BoosterSaveModelToString(
      R_ExternalPtrAddr(handle), 0, iter_count, importance_kind,
      required_len, &required_len, dest));
  }
  UNPROTECT(2);
  return model_str;
  R_API_END();
}

977
// Dumps the Booster with LGBM_BoosterDumpModel and returns the text as a
// length-one R character vector. A 1 MiB buffer is tried first and enlarged to
// the size reported by the C API when that turns out to be too small.
SEXP LGBM_BoosterDumpModel_R(SEXP handle,
  SEXP num_iteration,
  SEXP feature_importance_type) {
  SEXP cont_token = PROTECT(R_MakeUnwindCont());
  R_API_BEGIN();
  _AssertBoosterHandleNotNull(handle);
  int64_t required_len = 0;
  const int64_t initial_len = 1024 * 1024;
  const int iter_count = Rf_asInteger(num_iteration);
  const int importance_kind = Rf_asInteger(feature_importance_type);
  std::vector<char> text_buf(initial_len);
  CHECK_CALL(LGBM_BoosterDumpModel(
    R_ExternalPtrAddr(handle), 0, iter_count, importance_kind,
    initial_len, &required_len, text_buf.data()));
  if (required_len > initial_len) {
    // first buffer was too small: grow it to the reported size and dump again
    text_buf.resize(required_len);
    CHECK_CALL(LGBM_BoosterDumpModel(
      R_ExternalPtrAddr(handle), 0, iter_count, importance_kind,
      required_len, &required_len, text_buf.data()));
  }
  SEXP model_str = PROTECT(safe_R_string(static_cast<R_xlen_t>(1), &cont_token));
  SET_STRING_ELT(model_str, 0, safe_R_mkChar(text_buf.data(), &cont_token));
  UNPROTECT(2);
  return model_str;
  R_API_END();
}
1001

1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
// Returns the output of LGBM_DumpParamAliases as a length-one R character
// vector. A 1 MiB buffer is tried first and enlarged to the size reported by
// the C API when that turns out to be too small.
SEXP LGBM_DumpParamAliases_R() {
  SEXP cont_token = PROTECT(R_MakeUnwindCont());
  R_API_BEGIN();
  int64_t required_len = 0;
  const int64_t initial_len = 1024 * 1024;
  std::vector<char> text_buf(initial_len);
  CHECK_CALL(LGBM_DumpParamAliases(initial_len, &required_len, text_buf.data()));
  if (required_len > initial_len) {
    // first buffer was too small: grow it to the reported size and dump again
    text_buf.resize(required_len);
    CHECK_CALL(LGBM_DumpParamAliases(required_len, &required_len, text_buf.data()));
  }
  SEXP aliases_str = PROTECT(safe_R_string(static_cast<R_xlen_t>(1), &cont_token));
  SET_STRING_ELT(aliases_str, 0, safe_R_mkChar(text_buf.data(), &cont_token));
  UNPROTECT(2);
  return aliases_str;
  R_API_END();
}

1022
1023
// Registration table for the routines reachable through .Call().
// Each entry holds the routine name, its address and its number of arguments;
// the table is terminated by an all-null sentinel entry.
static const R_CallMethodDef CallEntries[] = {
  {"LGBM_HandleIsNull_R", reinterpret_cast<DL_FUNC>(&LGBM_HandleIsNull_R), 1},
  {"LGBM_DatasetCreateFromFile_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetCreateFromFile_R), 3},
  {"LGBM_DatasetCreateFromCSC_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetCreateFromCSC_R), 8},
  {"LGBM_DatasetCreateFromMat_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetCreateFromMat_R), 5},
  {"LGBM_DatasetGetSubset_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetSubset_R), 4},
  {"LGBM_DatasetSetFeatureNames_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetSetFeatureNames_R), 2},
  {"LGBM_DatasetGetFeatureNames_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetFeatureNames_R), 1},
  {"LGBM_DatasetSaveBinary_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetSaveBinary_R), 2},
  {"LGBM_DatasetFree_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetFree_R), 1},
  {"LGBM_DatasetSetField_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetSetField_R), 4},
  {"LGBM_DatasetGetFieldSize_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetFieldSize_R), 3},
  {"LGBM_DatasetGetField_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetField_R), 3},
  {"LGBM_DatasetUpdateParamChecking_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetUpdateParamChecking_R), 2},
  {"LGBM_DatasetGetNumData_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetNumData_R), 2},
  {"LGBM_DatasetGetNumFeature_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetNumFeature_R), 2},
  {"LGBM_DatasetGetFeatureNumBin_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetFeatureNumBin_R), 3},
  {"LGBM_BoosterCreate_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterCreate_R), 2},
  {"LGBM_BoosterFree_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterFree_R), 1},
  {"LGBM_BoosterCreateFromModelfile_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterCreateFromModelfile_R), 1},
  {"LGBM_BoosterLoadModelFromString_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterLoadModelFromString_R), 1},
  {"LGBM_BoosterMerge_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterMerge_R), 2},
  {"LGBM_BoosterAddValidData_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterAddValidData_R), 2},
  {"LGBM_BoosterResetTrainingData_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterResetTrainingData_R), 2},
  {"LGBM_BoosterResetParameter_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterResetParameter_R), 2},
  {"LGBM_BoosterGetNumClasses_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetNumClasses_R), 2},
  {"LGBM_BoosterGetNumFeature_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetNumFeature_R), 1},
  {"LGBM_BoosterUpdateOneIter_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterUpdateOneIter_R), 1},
  {"LGBM_BoosterUpdateOneIterCustom_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterUpdateOneIterCustom_R), 4},
  {"LGBM_BoosterRollbackOneIter_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterRollbackOneIter_R), 1},
  {"LGBM_BoosterGetCurrentIteration_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetCurrentIteration_R), 2},
  {"LGBM_BoosterGetUpperBoundValue_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetUpperBoundValue_R), 2},
  {"LGBM_BoosterGetLowerBoundValue_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetLowerBoundValue_R), 2},
  {"LGBM_BoosterGetEvalNames_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetEvalNames_R), 1},
  {"LGBM_BoosterGetEval_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetEval_R), 3},
  {"LGBM_BoosterGetNumPredict_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetNumPredict_R), 3},
  {"LGBM_BoosterGetPredict_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetPredict_R), 3},
  {"LGBM_BoosterPredictForFile_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterPredictForFile_R), 10},
  {"LGBM_BoosterCalcNumPredict_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterCalcNumPredict_R), 8},
  {"LGBM_BoosterPredictForCSC_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterPredictForCSC_R), 14},
  {"LGBM_BoosterPredictForMat_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterPredictForMat_R), 11},
  {"LGBM_BoosterPredictSparseOutput_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterPredictSparseOutput_R), 10},
  {"LGBM_BoosterSaveModel_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterSaveModel_R), 4},
  {"LGBM_BoosterSaveModelToString_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterSaveModelToString_R), 3},
  {"LGBM_BoosterDumpModel_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterDumpModel_R), 3},
  {"LGBM_NullBoosterHandleError_R", reinterpret_cast<DL_FUNC>(&LGBM_NullBoosterHandleError_R), 0},
  {"LGBM_DumpParamAliases_R", reinterpret_cast<DL_FUNC>(&LGBM_DumpParamAliases_R), 0},
  {nullptr, nullptr, 0}
};

1073
1074
LIGHTGBM_C_EXPORT void R_init_lightgbm(DllInfo *dll);

1075
1076
1077
1078
void R_init_lightgbm(DllInfo *dll) {
  R_registerRoutines(dll, NULL, CallEntries, NULL, NULL);
  R_useDynamicSymbols(dll, FALSE);
}