lightgbm_R.cpp 25.4 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2017 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
5
6

#include "lightgbm_R.h"
Guolin Ke's avatar
Guolin Ke committed
7

8
9
10
11
12
13
14
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>
#include <LightGBM/utils/text_reader.h>

#include <R_ext/Rdynload.h>

15
16
17
18
#define R_NO_REMAP
#define R_USE_C99_IN_CXX
#include <R_ext/Error.h>

19
20
21
22
23
24
25
#include <string>
#include <cstdio>
#include <cstring>
#include <memory>
#include <utility>
#include <vector>

Guolin Ke's avatar
Guolin Ke committed
26
27
28
29
30
// Layout flag handed to the LightGBM matrix entry points used in this file.
#define COL_MAJOR (0)

// R_API_BEGIN() / R_API_END() bracket the body of a .Call() entry point in a
// try block. Any C++ exception escaping the body is recorded through
// LGBM_SetLastError() and the entry point returns R_NilValue; a body that
// falls through without returning also yields R_NilValue.
// Exceptions are caught by const reference since the handlers only read them.
#define R_API_BEGIN() \
  try {
#define R_API_END() } \
  catch (const std::exception& ex) { LGBM_SetLastError(ex.what()); return R_NilValue; } \
  catch (const std::string& ex) { LGBM_SetLastError(ex.c_str()); return R_NilValue; } \
  catch (...) { LGBM_SetLastError("unknown exception"); return R_NilValue; } \
  return R_NilValue;
Guolin Ke's avatar
Guolin Ke committed
35
36
37

// Evaluates a LightGBM C API call `x`; when it yields a non-zero status, an R
// error carrying LGBM_GetLastError() is raised.
// The message is passed as an argument to a literal "%s" format so that any
// '%' characters inside it are never interpreted as printf conversions.
// NOTE(review): Rf_error is documented as not returning, so the `return`
// below is presumably only a fallback for the compiler - confirm.
#define CHECK_CALL(x) \
  if ((x) != 0) { \
    Rf_error("%s", LGBM_GetLastError()); \
    return R_NilValue; \
  }

42
43
using LightGBM::Common::Split;
using LightGBM::Log;
Guolin Ke's avatar
Guolin Ke committed
44

45
46
47
48
// Returns an R logical scalar: TRUE when the external pointer `handle` holds
// a null address, FALSE otherwise.
SEXP LGBM_HandleIsNull_R(SEXP handle) {
  const bool address_is_null = (R_ExternalPtrAddr(handle) == nullptr);
  return Rf_ScalarLogical(address_is_null);
}

49
50
51
52
// Finalizer registered on Dataset external pointers; delegates to
// LGBM_DatasetFree_R and discards its return value.
void _DatasetFinalizer(SEXP handle) {
  static_cast<void>(LGBM_DatasetFree_R(handle));
}

53
54
// Creates a Dataset through LGBM_DatasetCreateFromFile.
//   filename   - coerced to a string and passed as the file name
//   parameters - coerced to a string and passed as the parameter string
//   reference  - R NULL, or an external pointer whose address is passed as
//                the reference dataset handle
// Returns an external pointer wrapping the new handle, with
// _DatasetFinalizer registered on it.
SEXP LGBM_DatasetCreateFromFile_R(SEXP filename,
  SEXP parameters,
  SEXP reference) {
  R_API_BEGIN();
  DatasetHandle ref_handle = Rf_isNull(reference) ? nullptr : R_ExternalPtrAddr(reference);
  const char* file_path = CHAR(Rf_asChar(filename));
  const char* param_str = CHAR(Rf_asChar(parameters));
  DatasetHandle new_handle = nullptr;
  CHECK_CALL(LGBM_DatasetCreateFromFile(file_path, param_str, ref_handle, &new_handle));
  SEXP ext_ptr = PROTECT(R_MakeExternalPtr(new_handle, R_NilValue, R_NilValue));
  R_RegisterCFinalizerEx(ext_ptr, _DatasetFinalizer, TRUE);
  UNPROTECT(1);
  return ext_ptr;
  R_API_END();
}

72
73
74
// Creates a Dataset through LGBM_DatasetCreateFromCSC.
// `indptr` and `indices` are read as R integer vectors and `data` as an R
// double vector; they are forwarded with the INT32 / FLOAT64 type tags.
// `num_indptr`, `nelem` and `num_row` are coerced to integers and widened to
// int64_t. `reference` may be R NULL, otherwise its external pointer address
// is used as the reference dataset.
// Returns an external pointer wrapping the new handle, with
// _DatasetFinalizer registered on it.
SEXP LGBM_DatasetCreateFromCSC_R(SEXP indptr,
  SEXP indices,
  SEXP data,
  SEXP num_indptr,
  SEXP nelem,
  SEXP num_row,
  SEXP parameters,
  SEXP reference) {
  R_API_BEGIN();
  const int* indptr_values = INTEGER(indptr);
  const int* index_values = INTEGER(indices);
  const double* data_values = REAL(data);

  const int64_t indptr_count = static_cast<int64_t>(Rf_asInteger(num_indptr));
  const int64_t element_count = static_cast<int64_t>(Rf_asInteger(nelem));
  const int64_t row_count = static_cast<int64_t>(Rf_asInteger(num_row));

  DatasetHandle ref_handle = Rf_isNull(reference) ? nullptr : R_ExternalPtrAddr(reference);
  const char* param_str = CHAR(Rf_asChar(parameters));
  DatasetHandle new_handle = nullptr;
  CHECK_CALL(LGBM_DatasetCreateFromCSC(indptr_values, C_API_DTYPE_INT32, index_values,
    data_values, C_API_DTYPE_FLOAT64, indptr_count, element_count,
    row_count, param_str, ref_handle, &new_handle));

  SEXP ext_ptr = PROTECT(R_MakeExternalPtr(new_handle, R_NilValue, R_NilValue));
  R_RegisterCFinalizerEx(ext_ptr, _DatasetFinalizer, TRUE);
  UNPROTECT(1);
  return ext_ptr;
  R_API_END();
}

104
// Creates a Dataset through LGBM_DatasetCreateFromMat.
// `data` is read as an R double vector and forwarded with the FLOAT64 type
// tag and the COL_MAJOR layout flag; `num_row` / `num_col` are coerced to
// integers. `reference` may be R NULL, otherwise its external pointer address
// is used as the reference dataset.
// Returns an external pointer wrapping the new handle, with
// _DatasetFinalizer registered on it.
SEXP LGBM_DatasetCreateFromMat_R(SEXP data,
  SEXP num_row,
  SEXP num_col,
  SEXP parameters,
  SEXP reference) {
  R_API_BEGIN();
  const int32_t row_count = static_cast<int32_t>(Rf_asInteger(num_row));
  const int32_t col_count = static_cast<int32_t>(Rf_asInteger(num_col));
  double* matrix_values = REAL(data);

  DatasetHandle ref_handle = Rf_isNull(reference) ? nullptr : R_ExternalPtrAddr(reference);
  const char* param_str = CHAR(Rf_asChar(parameters));
  DatasetHandle new_handle = nullptr;
  CHECK_CALL(LGBM_DatasetCreateFromMat(matrix_values, C_API_DTYPE_FLOAT64, row_count, col_count,
    COL_MAJOR, param_str, ref_handle, &new_handle));

  SEXP ext_ptr = PROTECT(R_MakeExternalPtr(new_handle, R_NilValue, R_NilValue));
  R_RegisterCFinalizerEx(ext_ptr, _DatasetFinalizer, TRUE);
  UNPROTECT(1);
  return ext_ptr;
  R_API_END();
}

128
// Creates a Dataset holding a subset of the rows of `handle` through
// LGBM_DatasetGetSubset.
//   used_row_indices     - R integer vector of one-based row indices
//   len_used_row_indices - number of entries of used_row_indices to use
//   parameters           - coerced to a string and passed as parameters
// Returns an external pointer wrapping the new handle, with
// _DatasetFinalizer registered on it.
SEXP LGBM_DatasetGetSubset_R(SEXP handle,
  SEXP used_row_indices,
  SEXP len_used_row_indices,
  SEXP parameters) {
  R_API_BEGIN();
  const int32_t len = static_cast<int32_t>(Rf_asInteger(len_used_row_indices));
  std::vector<int32_t> idxvec(len);
  // Fetch the data pointer once on the calling thread: the R API must not be
  // invoked from inside the OpenMP worker threads below. The accessor is only
  // called when the loop would have touched it, as before.
  const int* one_based = (len > 0) ? INTEGER(used_row_indices) : nullptr;
  // convert from one-based to zero-based index
#pragma omp parallel for schedule(static, 512) if (len >= 1024)
  for (int32_t i = 0; i < len; ++i) {
    idxvec[i] = static_cast<int32_t>(one_based[i] - 1);
  }
  DatasetHandle res = nullptr;
  CHECK_CALL(LGBM_DatasetGetSubset(R_ExternalPtrAddr(handle),
    idxvec.data(), len, CHAR(Rf_asChar(parameters)),
    &res));
  SEXP ret = PROTECT(R_MakeExternalPtr(res, R_NilValue, R_NilValue));
  R_RegisterCFinalizerEx(ret, _DatasetFinalizer, TRUE);
  UNPROTECT(1);
  return ret;
  R_API_END();
}

152
// Sets the feature names of the Dataset behind `handle`.
// `feature_names` is coerced to a single string and split on tab characters;
// the resulting pieces are handed to LGBM_DatasetSetFeatureNames.
// Returns R_NilValue.
SEXP LGBM_DatasetSetFeatureNames_R(SEXP handle,
  SEXP feature_names) {
  R_API_BEGIN();
  auto name_list = Split(CHAR(Rf_asChar(feature_names)), '\t');
  const int name_count = static_cast<int>(name_list.size());
  // Table of C-string pointers into name_list, which stays alive for the call.
  std::vector<const char*> name_ptrs;
  name_ptrs.reserve(name_list.size());
  for (int k = 0; k < name_count; ++k) {
    name_ptrs.push_back(name_list[k].c_str());
  }
  CHECK_CALL(LGBM_DatasetSetFeatureNames(R_ExternalPtrAddr(handle),
    name_ptrs.data(), name_count));
  R_API_END();
}

166
// Returns the feature names of the Dataset behind `handle` as an R character
// vector. Each name is first fetched into a 256-byte buffer; when the C API
// reports a larger required size, all buffers are grown to that size and the
// names are fetched again.
SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) {
  R_API_BEGIN();
  int len = 0;
  CHECK_CALL(LGBM_DatasetGetNumFeature(R_ExternalPtrAddr(handle), &len));
  const size_t initial_buffer_size = 256;
  std::vector<std::vector<char>> buffers(len);
  std::vector<char*> buffer_ptrs(len);
  // Resizes every name buffer and refreshes the raw pointer table, since
  // resizing may move the underlying storage.
  auto resize_buffers = [&buffers, &buffer_ptrs](size_t new_size) {
    for (size_t k = 0; k < buffers.size(); ++k) {
      buffers[k].resize(new_size);
      buffer_ptrs[k] = buffers[k].data();
    }
  };
  resize_buffers(initial_buffer_size);
  int out_len;
  size_t needed_buffer_size;
  CHECK_CALL(
    LGBM_DatasetGetFeatureNames(
      R_ExternalPtrAddr(handle),
      len, &out_len,
      initial_buffer_size, &needed_buffer_size,
      buffer_ptrs.data()));
  // if any feature names were larger than allocated size,
  // allow for a larger size and try again
  if (needed_buffer_size > initial_buffer_size) {
    resize_buffers(needed_buffer_size);
    CHECK_CALL(
      LGBM_DatasetGetFeatureNames(
        R_ExternalPtrAddr(handle),
        len,
        &out_len,
        needed_buffer_size,
        &needed_buffer_size,
        buffer_ptrs.data()));
  }
  CHECK_EQ(len, out_len);
  SEXP feature_names = PROTECT(Rf_allocVector(STRSXP, len));
  for (int k = 0; k < len; ++k) {
    SET_STRING_ELT(feature_names, k, Rf_mkChar(buffer_ptrs[k]));
  }
  UNPROTECT(1);
  return feature_names;
  R_API_END();
}

212
// Forwards to LGBM_DatasetSaveBinary with the Dataset behind `handle` and
// `filename` coerced to a string. Returns R_NilValue.
SEXP LGBM_DatasetSaveBinary_R(SEXP handle,
  SEXP filename) {
  R_API_BEGIN();
  DatasetHandle dataset = R_ExternalPtrAddr(handle);
  const char* out_path = CHAR(Rf_asChar(filename));
  CHECK_CALL(LGBM_DatasetSaveBinary(dataset, out_path));
  R_API_END();
}

220
// Frees the Dataset behind `handle` via LGBM_DatasetFree and then clears the
// external pointer. Nothing is done when `handle` is R NULL or already holds
// a null address. Returns R_NilValue.
SEXP LGBM_DatasetFree_R(SEXP handle) {
  R_API_BEGIN();
  if (Rf_isNull(handle)) {
    return R_NilValue;
  }
  if (R_ExternalPtrAddr(handle) != nullptr) {
    CHECK_CALL(LGBM_DatasetFree(R_ExternalPtrAddr(handle)));
    R_ClearExternalPtr(handle);
  }
  R_API_END();
}

229
// Sets a named field on the Dataset behind `handle` via LGBM_DatasetSetField.
//   field_name  - coerced to a string; selects how field_data is read:
//                   "group" / "query" : R integer vector, copied to int32
//                   "init_score"      : R double vector, passed as float64
//                   anything else     : R double vector, narrowed to float32
//   field_data  - the values
//   num_element - number of elements of field_data to use
// Returns R_NilValue.
SEXP LGBM_DatasetSetField_R(SEXP handle,
  SEXP field_name,
  SEXP field_data,
  SEXP num_element) {
  R_API_BEGIN();
  const int len = Rf_asInteger(num_element);
  const char* name = CHAR(Rf_asChar(field_name));
  if (!strcmp("group", name) || !strcmp("query", name)) {
    std::vector<int32_t> vec(len);
    // Resolve the data pointer once on the calling thread: the R API must not
    // be invoked from inside the OpenMP worker threads. The accessor is only
    // called when the loop would have touched it, as before.
    const int* src = (len > 0) ? INTEGER(field_data) : nullptr;
#pragma omp parallel for schedule(static, 512) if (len >= 1024)
    for (int i = 0; i < len; ++i) {
      vec[i] = static_cast<int32_t>(src[i]);
    }
    CHECK_CALL(LGBM_DatasetSetField(R_ExternalPtrAddr(handle), name, vec.data(), len, C_API_DTYPE_INT32));
  } else if (!strcmp("init_score", name)) {
    CHECK_CALL(LGBM_DatasetSetField(R_ExternalPtrAddr(handle), name, REAL(field_data), len, C_API_DTYPE_FLOAT64));
  } else {
    std::vector<float> vec(len);
    // Same reasoning as above: no R API calls inside the parallel region.
    const double* src = (len > 0) ? REAL(field_data) : nullptr;
#pragma omp parallel for schedule(static, 512) if (len >= 1024)
    for (int i = 0; i < len; ++i) {
      vec[i] = static_cast<float>(src[i]);
    }
    CHECK_CALL(LGBM_DatasetSetField(R_ExternalPtrAddr(handle), name, vec.data(), len, C_API_DTYPE_FLOAT32));
  }
  R_API_END();
}

256
// Copies a named field of the Dataset behind `handle` into the preallocated
// R vector `field_data`.
//   "group" / "query" : the stored int32 values are treated as boundaries and
//                       their successive differences (out_len - 1 values) are
//                       written into an R integer vector
//   "init_score"      : stored doubles are copied into an R double vector
//   anything else     : stored floats are widened into an R double vector
// NOTE(review): field_data is not resized here; it presumably has to be
// allocated by the caller with a sufficient length - confirm.
// Returns R_NilValue.
SEXP LGBM_DatasetGetField_R(SEXP handle,
  SEXP field_name,
  SEXP field_data) {
  R_API_BEGIN();
  const char* name = CHAR(Rf_asChar(field_name));
  int out_len = 0;
  int out_type = 0;
  const void* res = nullptr;
  CHECK_CALL(LGBM_DatasetGetField(R_ExternalPtrAddr(handle), name, &out_len, &res, &out_type));

  // In every branch the destination pointer is resolved once on the calling
  // thread, because the R API must not be invoked from inside the OpenMP
  // worker threads. The accessor is only called when the loop would have
  // touched it, as before.
  if (!strcmp("group", name) || !strcmp("query", name)) {
    auto p_data = reinterpret_cast<const int32_t*>(res);
    int* dst = (out_len > 1) ? INTEGER(field_data) : nullptr;
    // convert from boundaries to size
#pragma omp parallel for schedule(static, 512) if (out_len >= 1024)
    for (int i = 0; i < out_len - 1; ++i) {
      dst[i] = p_data[i + 1] - p_data[i];
    }
  } else if (!strcmp("init_score", name)) {
    auto p_data = reinterpret_cast<const double*>(res);
    double* dst = (out_len > 0) ? REAL(field_data) : nullptr;
#pragma omp parallel for schedule(static, 512) if (out_len >= 1024)
    for (int i = 0; i < out_len; ++i) {
      dst[i] = p_data[i];
    }
  } else {
    auto p_data = reinterpret_cast<const float*>(res);
    double* dst = (out_len > 0) ? REAL(field_data) : nullptr;
#pragma omp parallel for schedule(static, 512) if (out_len >= 1024)
    for (int i = 0; i < out_len; ++i) {
      dst[i] = p_data[i];
    }
  }
  R_API_END();
}

289
// Writes the element count of a named Dataset field into the first element
// of the R integer vector `out`. The count is the length reported by
// LGBM_DatasetGetField, reduced by one for the "group" and "query" fields.
// Returns R_NilValue.
SEXP LGBM_DatasetGetFieldSize_R(SEXP handle,
  SEXP field_name,
  SEXP out) {
  R_API_BEGIN();
  const char* name = CHAR(Rf_asChar(field_name));
  int stored_len = 0;
  int stored_type = 0;
  const void* stored_data;
  CHECK_CALL(LGBM_DatasetGetField(R_ExternalPtrAddr(handle), name, &stored_len, &stored_data, &stored_type));
  const bool is_group_field = (strcmp("group", name) == 0) || (strcmp("query", name) == 0);
  INTEGER(out)[0] = is_group_field ? stored_len - 1 : stored_len;
  R_API_END();
}

305
306
// Forwards both parameter strings to LGBM_DatasetUpdateParamChecking; a
// non-zero status from it raises an R error. Returns R_NilValue.
SEXP LGBM_DatasetUpdateParamChecking_R(SEXP old_params,
  SEXP new_params) {
  R_API_BEGIN();
  const char* previous = CHAR(Rf_asChar(old_params));
  const char* updated = CHAR(Rf_asChar(new_params));
  CHECK_CALL(LGBM_DatasetUpdateParamChecking(previous, updated));
  R_API_END();
}

312
// Writes the value reported by LGBM_DatasetGetNumData for the Dataset behind
// `handle` into the first element of the R integer vector `out`.
// Returns R_NilValue.
SEXP LGBM_DatasetGetNumData_R(SEXP handle, SEXP out) {
  R_API_BEGIN();
  int row_count;
  CHECK_CALL(LGBM_DatasetGetNumData(R_ExternalPtrAddr(handle), &row_count));
  int* out_slot = INTEGER(out);
  out_slot[0] = row_count;
  R_API_END();
}

320
// Writes the value reported by LGBM_DatasetGetNumFeature for the Dataset
// behind `handle` into the first element of the R integer vector `out`.
// Returns R_NilValue.
SEXP LGBM_DatasetGetNumFeature_R(SEXP handle,
  SEXP out) {
  R_API_BEGIN();
  int feature_count;
  CHECK_CALL(LGBM_DatasetGetNumFeature(R_ExternalPtrAddr(handle), &feature_count));
  int* out_slot = INTEGER(out);
  out_slot[0] = feature_count;
  R_API_END();
}

// --- start Booster interfaces

331
332
333
334
// Finalizer registered on Booster external pointers; delegates to
// LGBM_BoosterFree_R and discards its return value.
void _BoosterFinalizer(SEXP handle) {
  static_cast<void>(LGBM_BoosterFree_R(handle));
}

335
// Frees the Booster behind `handle` via LGBM_BoosterFree and then clears the
// external pointer. Nothing is done when `handle` is R NULL or already holds
// a null address. Returns R_NilValue.
SEXP LGBM_BoosterFree_R(SEXP handle) {
  R_API_BEGIN();
  if (Rf_isNull(handle)) {
    return R_NilValue;
  }
  if (R_ExternalPtrAddr(handle) != nullptr) {
    CHECK_CALL(LGBM_BoosterFree(R_ExternalPtrAddr(handle)));
    R_ClearExternalPtr(handle);
  }
  R_API_END();
}

344
345
346
// Creates a Booster through LGBM_BoosterCreate from the Dataset behind
// `train_data` and the parameter string `parameters`.
// Returns an external pointer wrapping the new handle, with
// _BoosterFinalizer registered on it.
SEXP LGBM_BoosterCreate_R(SEXP train_data,
  SEXP parameters) {
  R_API_BEGIN();
  DatasetHandle training_set = R_ExternalPtrAddr(train_data);
  const char* param_str = CHAR(Rf_asChar(parameters));
  BoosterHandle new_booster = nullptr;
  CHECK_CALL(LGBM_BoosterCreate(training_set, param_str, &new_booster));
  SEXP ext_ptr = PROTECT(R_MakeExternalPtr(new_booster, R_NilValue, R_NilValue));
  R_RegisterCFinalizerEx(ext_ptr, _BoosterFinalizer, TRUE);
  UNPROTECT(1);
  return ext_ptr;
  R_API_END();
}

357
358
// Creates a Booster through LGBM_BoosterCreateFromModelfile using `filename`
// coerced to a string. The iteration count reported by the C API is received
// but not used. Returns an external pointer wrapping the new handle, with
// _BoosterFinalizer registered on it.
SEXP LGBM_BoosterCreateFromModelfile_R(SEXP filename) {
  R_API_BEGIN();
  int iteration_count = 0;
  BoosterHandle new_booster = nullptr;
  const char* model_path = CHAR(Rf_asChar(filename));
  CHECK_CALL(LGBM_BoosterCreateFromModelfile(model_path, &iteration_count, &new_booster));
  SEXP ext_ptr = PROTECT(R_MakeExternalPtr(new_booster, R_NilValue, R_NilValue));
  R_RegisterCFinalizerEx(ext_ptr, _BoosterFinalizer, TRUE);
  UNPROTECT(1);
  return ext_ptr;
  R_API_END();
}

370
371
// Creates a Booster through LGBM_BoosterLoadModelFromString using `model_str`
// coerced to a string. The iteration count reported by the C API is received
// but not used. Returns an external pointer wrapping the new handle, with
// _BoosterFinalizer registered on it.
SEXP LGBM_BoosterLoadModelFromString_R(SEXP model_str) {
  R_API_BEGIN();
  int iteration_count = 0;
  BoosterHandle new_booster = nullptr;
  const char* model_text = CHAR(Rf_asChar(model_str));
  CHECK_CALL(LGBM_BoosterLoadModelFromString(model_text, &iteration_count, &new_booster));
  SEXP ext_ptr = PROTECT(R_MakeExternalPtr(new_booster, R_NilValue, R_NilValue));
  R_RegisterCFinalizerEx(ext_ptr, _BoosterFinalizer, TRUE);
  UNPROTECT(1);
  return ext_ptr;
  R_API_END();
}

383
384
// Forwards the addresses held by `handle` and `other_handle` to
// LGBM_BoosterMerge. Returns R_NilValue.
SEXP LGBM_BoosterMerge_R(SEXP handle,
  SEXP other_handle) {
  R_API_BEGIN();
  BoosterHandle target = R_ExternalPtrAddr(handle);
  BoosterHandle source = R_ExternalPtrAddr(other_handle);
  CHECK_CALL(LGBM_BoosterMerge(target, source));
  R_API_END();
}

390
391
// Forwards the Booster behind `handle` and the Dataset behind `valid_data`
// to LGBM_BoosterAddValidData. Returns R_NilValue.
SEXP LGBM_BoosterAddValidData_R(SEXP handle,
  SEXP valid_data) {
  R_API_BEGIN();
  BoosterHandle booster = R_ExternalPtrAddr(handle);
  DatasetHandle validation_set = R_ExternalPtrAddr(valid_data);
  CHECK_CALL(LGBM_BoosterAddValidData(booster, validation_set));
  R_API_END();
}

397
398
// Forwards the Booster behind `handle` and the Dataset behind `train_data`
// to LGBM_BoosterResetTrainingData. Returns R_NilValue.
SEXP LGBM_BoosterResetTrainingData_R(SEXP handle,
  SEXP train_data) {
  R_API_BEGIN();
  BoosterHandle booster = R_ExternalPtrAddr(handle);
  DatasetHandle training_set = R_ExternalPtrAddr(train_data);
  CHECK_CALL(LGBM_BoosterResetTrainingData(booster, training_set));
  R_API_END();
}

404
// Forwards the Booster behind `handle` and `parameters` coerced to a string
// to LGBM_BoosterResetParameter. Returns R_NilValue.
SEXP LGBM_BoosterResetParameter_R(SEXP handle,
  SEXP parameters) {
  R_API_BEGIN();
  BoosterHandle booster = R_ExternalPtrAddr(handle);
  const char* param_str = CHAR(Rf_asChar(parameters));
  CHECK_CALL(LGBM_BoosterResetParameter(booster, param_str));
  R_API_END();
}

411
// Writes the value reported by LGBM_BoosterGetNumClasses for the Booster
// behind `handle` into the first element of the R integer vector `out`.
// Returns R_NilValue.
SEXP LGBM_BoosterGetNumClasses_R(SEXP handle,
  SEXP out) {
  R_API_BEGIN();
  int class_count;
  CHECK_CALL(LGBM_BoosterGetNumClasses(R_ExternalPtrAddr(handle), &class_count));
  int* out_slot = INTEGER(out);
  out_slot[0] = class_count;
  R_API_END();
}

420
// Calls LGBM_BoosterUpdateOneIter on the Booster behind `handle`. The
// "finished" flag reported by the C API is received but not returned to R.
// Returns R_NilValue.
SEXP LGBM_BoosterUpdateOneIter_R(SEXP handle) {
  R_API_BEGIN();
  int finished_flag = 0;
  BoosterHandle booster = R_ExternalPtrAddr(handle);
  CHECK_CALL(LGBM_BoosterUpdateOneIter(booster, &finished_flag));
  R_API_END();
}

427
// Calls LGBM_BoosterUpdateOneIterCustom on the Booster behind `handle`.
//   grad, hess - R double vectors; the first `len` entries of each are
//                narrowed to float before being handed to the C API
//   len        - number of entries to read, coerced to an integer
// The "finished" flag reported by the C API is received but not returned to
// R. Returns R_NilValue.
SEXP LGBM_BoosterUpdateOneIterCustom_R(SEXP handle,
  SEXP grad,
  SEXP hess,
  SEXP len) {
  R_API_BEGIN();
  int is_finished = 0;
  const int int_len = Rf_asInteger(len);
  std::vector<float> tgrad(int_len), thess(int_len);
  // Resolve both data pointers once on the calling thread: the R API must not
  // be invoked from inside the OpenMP worker threads. The accessors are only
  // called when the loop would have touched them, as before.
  const double* grad_src = (int_len > 0) ? REAL(grad) : nullptr;
  const double* hess_src = (int_len > 0) ? REAL(hess) : nullptr;
#pragma omp parallel for schedule(static, 512) if (int_len >= 1024)
  for (int j = 0; j < int_len; ++j) {
    tgrad[j] = static_cast<float>(grad_src[j]);
    thess[j] = static_cast<float>(hess_src[j]);
  }
  CHECK_CALL(LGBM_BoosterUpdateOneIterCustom(R_ExternalPtrAddr(handle), tgrad.data(), thess.data(), &is_finished));
  R_API_END();
}

444
// Calls LGBM_BoosterRollbackOneIter on the Booster behind `handle`.
// Returns R_NilValue.
SEXP LGBM_BoosterRollbackOneIter_R(SEXP handle) {
  R_API_BEGIN();
  BoosterHandle booster = R_ExternalPtrAddr(handle);
  CHECK_CALL(LGBM_BoosterRollbackOneIter(booster));
  R_API_END();
}

450
// Writes the value reported by LGBM_BoosterGetCurrentIteration for the
// Booster behind `handle` into the first element of the R integer vector
// `out`. Returns R_NilValue.
SEXP LGBM_BoosterGetCurrentIteration_R(SEXP handle,
  SEXP out) {
  R_API_BEGIN();
  int current_iteration;
  CHECK_CALL(LGBM_BoosterGetCurrentIteration(R_ExternalPtrAddr(handle), &current_iteration));
  int* out_slot = INTEGER(out);
  out_slot[0] = current_iteration;
  R_API_END();
}

459
// Lets LGBM_BoosterGetUpperBoundValue write its result directly into the
// storage of the R double vector `out_result`. Returns R_NilValue.
SEXP LGBM_BoosterGetUpperBoundValue_R(SEXP handle,
  SEXP out_result) {
  R_API_BEGIN();
  double* result_slot = REAL(out_result);
  BoosterHandle booster = R_ExternalPtrAddr(handle);
  CHECK_CALL(LGBM_BoosterGetUpperBoundValue(booster, result_slot));
  R_API_END();
}

467
// Lets LGBM_BoosterGetLowerBoundValue write its result directly into the
// storage of the R double vector `out_result`. Returns R_NilValue.
SEXP LGBM_BoosterGetLowerBoundValue_R(SEXP handle,
  SEXP out_result) {
  R_API_BEGIN();
  double* result_slot = REAL(out_result);
  BoosterHandle booster = R_ExternalPtrAddr(handle);
  CHECK_CALL(LGBM_BoosterGetLowerBoundValue(booster, result_slot));
  R_API_END();
}

475
// Returns the evaluation names of the Booster behind `handle` as an R
// character vector. Each name is first fetched into a 128-byte buffer; when
// the C API reports a larger required size, all buffers are grown to that
// size and the names are fetched again.
SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) {
  R_API_BEGIN();
  int len;
  CHECK_CALL(LGBM_BoosterGetEvalCounts(R_ExternalPtrAddr(handle), &len));

  const size_t initial_buffer_size = 128;
  std::vector<std::vector<char>> buffers(len);
  std::vector<char*> buffer_ptrs(len);
  // Resizes every name buffer and refreshes the raw pointer table, since
  // resizing may move the underlying storage.
  auto resize_buffers = [&buffers, &buffer_ptrs](size_t new_size) {
    for (size_t k = 0; k < buffers.size(); ++k) {
      buffers[k].resize(new_size);
      buffer_ptrs[k] = buffers[k].data();
    }
  };
  resize_buffers(initial_buffer_size);

  int out_len;
  size_t needed_buffer_size;
  CHECK_CALL(
    LGBM_BoosterGetEvalNames(
      R_ExternalPtrAddr(handle),
      len, &out_len,
      initial_buffer_size, &needed_buffer_size,
      buffer_ptrs.data()));
  // if any eval names were larger than allocated size,
  // allow for a larger size and try again
  if (needed_buffer_size > initial_buffer_size) {
    resize_buffers(needed_buffer_size);
    CHECK_CALL(
      LGBM_BoosterGetEvalNames(
        R_ExternalPtrAddr(handle),
        len,
        &out_len,
        needed_buffer_size,
        &needed_buffer_size,
        buffer_ptrs.data()));
  }
  CHECK_EQ(out_len, len);
  SEXP eval_names = PROTECT(Rf_allocVector(STRSXP, len));
  for (int k = 0; k < len; ++k) {
    SET_STRING_ELT(eval_names, k, Rf_mkChar(buffer_ptrs[k]));
  }
  UNPROTECT(1);
  return eval_names;
  R_API_END();
}

523
// Lets LGBM_BoosterGetEval write the evaluation results for dataset index
// `data_idx` directly into the R double vector `out_result`, then checks that
// the number of results matches LGBM_BoosterGetEvalCounts.
// Returns R_NilValue.
SEXP LGBM_BoosterGetEval_R(SEXP handle,
  SEXP data_idx,
  SEXP out_result) {
  R_API_BEGIN();
  int expected_count;
  CHECK_CALL(LGBM_BoosterGetEvalCounts(R_ExternalPtrAddr(handle), &expected_count));
  double* result_slot = REAL(out_result);
  int written_count;
  CHECK_CALL(LGBM_BoosterGetEval(R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &written_count, result_slot));
  CHECK_EQ(written_count, expected_count);
  R_API_END();
}

536
// Writes the value reported by LGBM_BoosterGetNumPredict for dataset index
// `data_idx` into the first element of the R integer vector `out`. The
// 64-bit count is narrowed to int. Returns R_NilValue.
SEXP LGBM_BoosterGetNumPredict_R(SEXP handle,
  SEXP data_idx,
  SEXP out) {
  R_API_BEGIN();
  int64_t prediction_count;
  CHECK_CALL(LGBM_BoosterGetNumPredict(R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &prediction_count));
  int* out_slot = INTEGER(out);
  out_slot[0] = static_cast<int>(prediction_count);
  R_API_END();
}

546
// Lets LGBM_BoosterGetPredict write the predictions for dataset index
// `data_idx` directly into the R double vector `out_result`. The length
// reported by the C API is received but not used. Returns R_NilValue.
SEXP LGBM_BoosterGetPredict_R(SEXP handle,
  SEXP data_idx,
  SEXP out_result) {
  R_API_BEGIN();
  double* result_slot = REAL(out_result);
  int64_t written_count;
  CHECK_CALL(LGBM_BoosterGetPredict(R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &written_count, result_slot));
  R_API_END();
}

556
// Maps three R flags onto a C API prediction type. Each flag is coerced to
// an integer and counts as set when non-zero. When several flags are set the
// precedence is: predcontrib, then leafidx, then rawscore; with none set the
// result is C_API_PREDICT_NORMAL.
int GetPredictType(SEXP is_rawscore, SEXP is_leafidx, SEXP is_predcontrib) {
  // All three flags are always read, in this order, before deciding.
  const bool want_raw_score = Rf_asInteger(is_rawscore) != 0;
  const bool want_leaf_index = Rf_asInteger(is_leafidx) != 0;
  const bool want_contrib = Rf_asInteger(is_predcontrib) != 0;
  if (want_contrib) {
    return C_API_PREDICT_CONTRIB;
  }
  if (want_leaf_index) {
    return C_API_PREDICT_LEAF_INDEX;
  }
  if (want_raw_score) {
    return C_API_PREDICT_RAW_SCORE;
  }
  return C_API_PREDICT_NORMAL;
}

570
// Forwards a file-based prediction request to LGBM_BoosterPredictForFile.
// The prediction type is derived from the three is_* flags via
// GetPredictType; the string arguments are coerced with Rf_asChar and the
// numeric ones with Rf_asInteger. Returns R_NilValue.
SEXP LGBM_BoosterPredictForFile_R(SEXP handle,
  SEXP data_filename,
  SEXP data_has_header,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter,
  SEXP result_filename) {
  R_API_BEGIN();
  const int predict_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  BoosterHandle booster = R_ExternalPtrAddr(handle);
  const char* input_path = CHAR(Rf_asChar(data_filename));
  const int has_header = Rf_asInteger(data_has_header);
  const int first_iteration = Rf_asInteger(start_iteration);
  const int iteration_count = Rf_asInteger(num_iteration);
  const char* param_str = CHAR(Rf_asChar(parameter));
  const char* output_path = CHAR(Rf_asChar(result_filename));
  CHECK_CALL(LGBM_BoosterPredictForFile(booster, input_path,
    has_header, predict_type, first_iteration, iteration_count, param_str,
    output_path));
  R_API_END();
}

588
// Writes the value computed by LGBM_BoosterCalcNumPredict into the first
// element of the R integer vector `out_len`. The prediction type is derived
// from the three is_* flags via GetPredictType and the 64-bit result is
// narrowed to int. Returns R_NilValue.
SEXP LGBM_BoosterCalcNumPredict_R(SEXP handle,
  SEXP num_row,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP out_len) {
  R_API_BEGIN();
  const int predict_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  int64_t prediction_count = 0;
  BoosterHandle booster = R_ExternalPtrAddr(handle);
  const int row_count = Rf_asInteger(num_row);
  const int first_iteration = Rf_asInteger(start_iteration);
  const int iteration_count = Rf_asInteger(num_iteration);
  CHECK_CALL(LGBM_BoosterCalcNumPredict(booster, row_count,
    predict_type, first_iteration, iteration_count, &prediction_count));
  int* out_slot = INTEGER(out_len);
  out_slot[0] = static_cast<int>(prediction_count);
  R_API_END();
}

605
// Forwards a prediction request to LGBM_BoosterPredictForCSC.
// `indptr` and `indices` are read as R integer vectors and `data` as an R
// double vector; they are forwarded with the INT32 / FLOAT64 type tags.
// `num_indptr`, `nelem` and `num_row` are coerced to integers and widened to
// int64_t. Predictions are written directly into the R double vector
// `out_result`; the length reported by the C API is received but not used.
// Returns R_NilValue.
SEXP LGBM_BoosterPredictForCSC_R(SEXP handle,
  SEXP indptr,
  SEXP indices,
  SEXP data,
  SEXP num_indptr,
  SEXP nelem,
  SEXP num_row,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter,
  SEXP out_result) {
  R_API_BEGIN();
  const int predict_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);

  const int* indptr_values = INTEGER(indptr);
  const int32_t* index_values = reinterpret_cast<const int32_t*>(INTEGER(indices));
  const double* data_values = REAL(data);

  const int64_t indptr_count = static_cast<int64_t>(Rf_asInteger(num_indptr));
  const int64_t element_count = static_cast<int64_t>(Rf_asInteger(nelem));
  const int64_t row_count = static_cast<int64_t>(Rf_asInteger(num_row));
  double* result_slot = REAL(out_result);
  int64_t written_count;

  BoosterHandle booster = R_ExternalPtrAddr(handle);
  const int first_iteration = Rf_asInteger(start_iteration);
  const int iteration_count = Rf_asInteger(num_iteration);
  const char* param_str = CHAR(Rf_asChar(parameter));
  CHECK_CALL(LGBM_BoosterPredictForCSC(booster,
    indptr_values, C_API_DTYPE_INT32, index_values,
    data_values, C_API_DTYPE_FLOAT64, indptr_count, element_count,
    row_count, predict_type, first_iteration, iteration_count, param_str,
    &written_count, result_slot));
  R_API_END();
}

638
// Forwards a prediction request to LGBM_BoosterPredictForMat.
// `data` is read as an R double vector and forwarded with the FLOAT64 type
// tag and the COL_MAJOR layout flag; `num_row` / `num_col` are coerced to
// integers. Predictions are written directly into the R double vector
// `out_result`; the length reported by the C API is received but not used.
// Returns R_NilValue.
SEXP LGBM_BoosterPredictForMat_R(SEXP handle,
  SEXP data,
  SEXP num_row,
  SEXP num_col,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter,
  SEXP out_result) {
  R_API_BEGIN();
  const int predict_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);

  const int32_t row_count = static_cast<int32_t>(Rf_asInteger(num_row));
  const int32_t col_count = static_cast<int32_t>(Rf_asInteger(num_col));

  const double* matrix_values = REAL(data);
  double* result_slot = REAL(out_result);
  int64_t written_count;

  BoosterHandle booster = R_ExternalPtrAddr(handle);
  const int first_iteration = Rf_asInteger(start_iteration);
  const int iteration_count = Rf_asInteger(num_iteration);
  const char* param_str = CHAR(Rf_asChar(parameter));
  CHECK_CALL(LGBM_BoosterPredictForMat(booster,
    matrix_values, C_API_DTYPE_FLOAT64, row_count, col_count, COL_MAJOR,
    predict_type, first_iteration, iteration_count, param_str,
    &written_count, result_slot));

  R_API_END();
}

665
// Forwards to LGBM_BoosterSaveModel for the Booster behind `handle`, passing
// a literal 0 as the second argument followed by `num_iteration`,
// `feature_importance_type` (both coerced to integers) and `filename`
// (coerced to a string). Returns R_NilValue.
SEXP LGBM_BoosterSaveModel_R(SEXP handle,
  SEXP num_iteration,
  SEXP feature_importance_type,
  SEXP filename) {
  R_API_BEGIN();
  BoosterHandle booster = R_ExternalPtrAddr(handle);
  const int iteration_count = Rf_asInteger(num_iteration);
  const int importance_type = Rf_asInteger(feature_importance_type);
  const char* out_path = CHAR(Rf_asChar(filename));
  CHECK_CALL(LGBM_BoosterSaveModel(booster, 0, iteration_count, importance_type, out_path));
  R_API_END();
}

674
// Returns the model of the Booster behind `handle` as a length-one R
// character vector, produced by LGBM_BoosterSaveModelToString. A 1 MiB
// buffer is tried first; when the C API reports a larger length, the buffer
// is grown to that length and the call is repeated.
SEXP LGBM_BoosterSaveModelToString_R(SEXP handle,
  SEXP num_iteration,
  SEXP feature_importance_type) {
  R_API_BEGIN();
  int64_t needed_len = 0;
  const int64_t initial_len = 1024 * 1024;
  const int iteration_count = Rf_asInteger(num_iteration);
  const int importance_type = Rf_asInteger(feature_importance_type);
  std::vector<char> text_buffer(initial_len);
  CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), 0, iteration_count, importance_type,
    initial_len, &needed_len, text_buffer.data()));
  // if the model string was larger than the initial buffer, allocate a bigger buffer and try again
  if (needed_len > initial_len) {
    text_buffer.resize(needed_len);
    CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), 0, iteration_count, importance_type,
      needed_len, &needed_len, text_buffer.data()));
  }
  SEXP model_str = PROTECT(Rf_allocVector(STRSXP, 1));
  SET_STRING_ELT(model_str, 0, Rf_mkChar(text_buffer.data()));
  UNPROTECT(1);
  return model_str;
  R_API_END();
}

697
// Returns the dump of the Booster behind `handle` as a length-one R
// character vector, produced by LGBM_BoosterDumpModel. A 1 MiB buffer is
// tried first; when the C API reports a larger length, the buffer is grown
// to that length and the call is repeated.
SEXP LGBM_BoosterDumpModel_R(SEXP handle,
  SEXP num_iteration,
  SEXP feature_importance_type) {
  R_API_BEGIN();
  int64_t needed_len = 0;
  const int64_t initial_len = 1024 * 1024;
  const int iteration_count = Rf_asInteger(num_iteration);
  const int importance_type = Rf_asInteger(feature_importance_type);
  std::vector<char> text_buffer(initial_len);
  CHECK_CALL(LGBM_BoosterDumpModel(R_ExternalPtrAddr(handle), 0, iteration_count, importance_type,
    initial_len, &needed_len, text_buffer.data()));
  // if the model string was larger than the initial buffer, allocate a bigger buffer and try again
  if (needed_len > initial_len) {
    text_buffer.resize(needed_len);
    CHECK_CALL(LGBM_BoosterDumpModel(R_ExternalPtrAddr(handle), 0, iteration_count, importance_type,
      needed_len, &needed_len, text_buffer.data()));
  }
  SEXP model_str = PROTECT(Rf_allocVector(STRSXP, 1));
  SET_STRING_ELT(model_str, 0, Rf_mkChar(text_buffer.data()));
  UNPROTECT(1);
  return model_str;
  R_API_END();
}
719
720
721

// .Call() calls
// Registration table handed to R_registerRoutines: one entry per entry point
// giving its name, its address and its number of SEXP arguments. The table
// is terminated by an all-null entry.
static const R_CallMethodDef CallEntries[] = {
  {"LGBM_HandleIsNull_R", reinterpret_cast<DL_FUNC>(&LGBM_HandleIsNull_R), 1},
  {"LGBM_DatasetCreateFromFile_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetCreateFromFile_R), 3},
  {"LGBM_DatasetCreateFromCSC_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetCreateFromCSC_R), 8},
  {"LGBM_DatasetCreateFromMat_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetCreateFromMat_R), 5},
  {"LGBM_DatasetGetSubset_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetSubset_R), 4},
  {"LGBM_DatasetSetFeatureNames_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetSetFeatureNames_R), 2},
  {"LGBM_DatasetGetFeatureNames_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetFeatureNames_R), 1},
  {"LGBM_DatasetSaveBinary_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetSaveBinary_R), 2},
  {"LGBM_DatasetFree_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetFree_R), 1},
  {"LGBM_DatasetSetField_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetSetField_R), 4},
  {"LGBM_DatasetGetFieldSize_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetFieldSize_R), 3},
  {"LGBM_DatasetGetField_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetField_R), 3},
  {"LGBM_DatasetUpdateParamChecking_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetUpdateParamChecking_R), 2},
  {"LGBM_DatasetGetNumData_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetNumData_R), 2},
  {"LGBM_DatasetGetNumFeature_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetNumFeature_R), 2},
  {"LGBM_BoosterCreate_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterCreate_R), 2},
  {"LGBM_BoosterFree_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterFree_R), 1},
  {"LGBM_BoosterCreateFromModelfile_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterCreateFromModelfile_R), 1},
  {"LGBM_BoosterLoadModelFromString_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterLoadModelFromString_R), 1},
  {"LGBM_BoosterMerge_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterMerge_R), 2},
  {"LGBM_BoosterAddValidData_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterAddValidData_R), 2},
  {"LGBM_BoosterResetTrainingData_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterResetTrainingData_R), 2},
  {"LGBM_BoosterResetParameter_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterResetParameter_R), 2},
  {"LGBM_BoosterGetNumClasses_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetNumClasses_R), 2},
  {"LGBM_BoosterUpdateOneIter_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterUpdateOneIter_R), 1},
  {"LGBM_BoosterUpdateOneIterCustom_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterUpdateOneIterCustom_R), 4},
  {"LGBM_BoosterRollbackOneIter_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterRollbackOneIter_R), 1},
  {"LGBM_BoosterGetCurrentIteration_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetCurrentIteration_R), 2},
  {"LGBM_BoosterGetUpperBoundValue_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetUpperBoundValue_R), 2},
  {"LGBM_BoosterGetLowerBoundValue_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetLowerBoundValue_R), 2},
  {"LGBM_BoosterGetEvalNames_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetEvalNames_R), 1},
  {"LGBM_BoosterGetEval_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetEval_R), 3},
  {"LGBM_BoosterGetNumPredict_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetNumPredict_R), 3},
  {"LGBM_BoosterGetPredict_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetPredict_R), 3},
  {"LGBM_BoosterPredictForFile_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterPredictForFile_R), 10},
  {"LGBM_BoosterCalcNumPredict_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterCalcNumPredict_R), 8},
  {"LGBM_BoosterPredictForCSC_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterPredictForCSC_R), 14},
  {"LGBM_BoosterPredictForMat_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterPredictForMat_R), 11},
  {"LGBM_BoosterSaveModel_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterSaveModel_R), 4},
  {"LGBM_BoosterSaveModelToString_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterSaveModelToString_R), 3},
  {"LGBM_BoosterDumpModel_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterDumpModel_R), 3},
  {nullptr, nullptr, 0}
};

766
767
LIGHTGBM_C_EXPORT void R_init_lightgbm(DllInfo *dll);

768
769
770
771
void R_init_lightgbm(DllInfo *dll) {
  R_registerRoutines(dll, NULL, CallEntries, NULL, NULL);
  R_useDynamicSymbols(dll, FALSE);
}