lightgbm_R.cpp 23.8 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2017 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
5
6

#include "lightgbm_R.h"
Guolin Ke's avatar
Guolin Ke committed
7

8
9
10
11
12
13
14
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>
#include <LightGBM/utils/text_reader.h>

#include <R_ext/Rdynload.h>

15
16
17
18
#define R_NO_REMAP
#define R_USE_C99_IN_CXX
#include <R_ext/Error.h>

19
20
21
22
23
24
25
#include <string>
#include <cstdio>
#include <cstring>
#include <memory>
#include <utility>
#include <vector>

Guolin Ke's avatar
Guolin Ke committed
26
27
28
29
30
#define COL_MAJOR (0)

// R_API_BEGIN() / R_API_END() bracket the body of every exported wrapper.
// Together they form a try/catch: any C++ exception raised inside the body
// has its text recorded with LGBM_SetLastError() and the wrapper returns
// R_NilValue. When nothing is thrown the wrapper also returns R_NilValue,
// unless the body returned a value itself.
// Exceptions are caught by const reference.
#define R_API_BEGIN() \
  try {
#define R_API_END() } \
  catch(const std::exception& ex) { LGBM_SetLastError(ex.what()); return R_NilValue; } \
  catch(const std::string& ex) { LGBM_SetLastError(ex.c_str()); return R_NilValue; } \
  catch(...) { LGBM_SetLastError("unknown exception"); return R_NilValue; } \
  return R_NilValue;
Guolin Ke's avatar
Guolin Ke committed
35
36
37

// Evaluates a LightGBM C API call; a non-zero return code is reported to R
// through Rf_error() using the text from LGBM_GetLastError().
// The message is passed as an argument to a fixed "%s" format instead of
// being used as the format string itself, so any '%' characters it contains
// are printed literally rather than interpreted as conversion specifiers.
#define CHECK_CALL(x) \
  if ((x) != 0) { \
    Rf_error("%s", LGBM_GetLastError()); \
    return R_NilValue; \
  }

42
43
using LightGBM::Common::Split;
using LightGBM::Log;
Guolin Ke's avatar
Guolin Ke committed
44

45
46
47
48
49
50
/*!
 * \brief Return the text of LGBM_GetLastError() as a length-one R character vector.
 */
SEXP LGBM_GetLastError_R() {
  // keep the new vector protected while the CHARSXP is created and stored
  SEXP message = PROTECT(Rf_allocVector(STRSXP, 1));
  SET_STRING_ELT(message, 0, Rf_mkChar(LGBM_GetLastError()));
  UNPROTECT(1);
  return message;
}

53
54
/*!
 * \brief Create a Dataset through LGBM_DatasetCreateFromFile.
 * \param filename converted to a C string with Rf_asChar
 * \param parameters converted to a C string with Rf_asChar
 * \param reference handle wrapper, read with R_GET_PTR
 * \param out handle wrapper that receives the new dataset via R_SET_PTR
 * \return R_NilValue
 */
SEXP LGBM_DatasetCreateFromFile_R(SEXP filename,
  SEXP parameters,
  LGBM_SE reference,
  LGBM_SE out) {
  R_API_BEGIN();
  const char* file_path = CHAR(Rf_asChar(filename));
  const char* param_str = CHAR(Rf_asChar(parameters));
  DatasetHandle new_dataset = nullptr;
  CHECK_CALL(LGBM_DatasetCreateFromFile(file_path, param_str, R_GET_PTR(reference), &new_dataset));
  R_SET_PTR(out, new_dataset);
  R_API_END();
}

65
66
67
/*!
 * \brief Create a Dataset through LGBM_DatasetCreateFromCSC.
 *
 * indptr and indices are read as R integer vectors (passed as C_API_DTYPE_INT32),
 * data as an R double vector (passed as C_API_DTYPE_FLOAT64).
 * \param num_indptr,nelem,num_row scalar counts, read with Rf_asInteger and widened to int64_t
 * \param parameters converted to a C string with Rf_asChar
 * \param reference handle wrapper, read with R_GET_PTR
 * \param out handle wrapper that receives the new dataset via R_SET_PTR
 * \return R_NilValue
 */
SEXP LGBM_DatasetCreateFromCSC_R(SEXP indptr,
  SEXP indices,
  SEXP data,
  SEXP num_indptr,
  SEXP nelem,
  SEXP num_row,
  SEXP parameters,
  LGBM_SE reference,
  LGBM_SE out) {
  R_API_BEGIN();
  const int* col_ptr = INTEGER(indptr);
  const int* row_idx = INTEGER(indices);
  const double* values = REAL(data);

  const int64_t n_col_ptr = static_cast<int64_t>(Rf_asInteger(num_indptr));
  const int64_t n_values = static_cast<int64_t>(Rf_asInteger(nelem));
  const int64_t n_rows = static_cast<int64_t>(Rf_asInteger(num_row));

  DatasetHandle new_dataset = nullptr;
  CHECK_CALL(LGBM_DatasetCreateFromCSC(
    col_ptr, C_API_DTYPE_INT32, row_idx,
    values, C_API_DTYPE_FLOAT64,
    n_col_ptr, n_values, n_rows,
    CHAR(Rf_asChar(parameters)), R_GET_PTR(reference), &new_dataset));
  R_SET_PTR(out, new_dataset);
  R_API_END();
}

90
/*!
 * \brief Create a Dataset through LGBM_DatasetCreateFromMat.
 *
 * data is read as an R double vector and passed as C_API_DTYPE_FLOAT64 with
 * the COL_MAJOR layout flag.
 * \param num_row,num_col scalar dimensions, read with Rf_asInteger
 * \param parameters converted to a C string with Rf_asChar
 * \param reference handle wrapper, read with R_GET_PTR
 * \param out handle wrapper that receives the new dataset via R_SET_PTR
 * \return R_NilValue
 */
SEXP LGBM_DatasetCreateFromMat_R(SEXP data,
  SEXP num_row,
  SEXP num_col,
  SEXP parameters,
  LGBM_SE reference,
  LGBM_SE out) {
  R_API_BEGIN();
  const int32_t n_rows = static_cast<int32_t>(Rf_asInteger(num_row));
  const int32_t n_cols = static_cast<int32_t>(Rf_asInteger(num_col));
  double* matrix = REAL(data);
  DatasetHandle new_dataset = nullptr;
  CHECK_CALL(LGBM_DatasetCreateFromMat(
    matrix, C_API_DTYPE_FLOAT64, n_rows, n_cols, COL_MAJOR,
    CHAR(Rf_asChar(parameters)), R_GET_PTR(reference), &new_dataset));
  R_SET_PTR(out, new_dataset);
  R_API_END();
}

107
/*!
 * \brief Create a subset of a Dataset through LGBM_DatasetGetSubset.
 * \param handle handle wrapper of the source dataset, read with R_GET_PTR
 * \param used_row_indices R integer vector of one-based row indices
 * \param len_used_row_indices number of indices, read with Rf_asInteger
 * \param parameters converted to a C string with Rf_asChar
 * \param out handle wrapper that receives the subset via R_SET_PTR
 * \return R_NilValue
 */
SEXP LGBM_DatasetGetSubset_R(LGBM_SE handle,
  SEXP used_row_indices,
  SEXP len_used_row_indices,
  SEXP parameters,
  LGBM_SE out) {
  R_API_BEGIN();
  const int len = Rf_asInteger(len_used_row_indices);
  // Fetch the data pointer once, before the OpenMP region, so that the
  // parallel loop body touches plain memory only and makes no R API call.
  const int* one_based = INTEGER(used_row_indices);
  std::vector<int> idxvec(len);
  // convert from one-based to zero-based index
#pragma omp parallel for schedule(static, 512) if (len >= 1024)
  for (int i = 0; i < len; ++i) {
    idxvec[i] = one_based[i] - 1;
  }
  DatasetHandle res = nullptr;
  CHECK_CALL(LGBM_DatasetGetSubset(R_GET_PTR(handle),
    idxvec.data(), len, CHAR(Rf_asChar(parameters)),
    &res));
  R_SET_PTR(out, res);
  R_API_END();
}

128
/*!
 * \brief Set the feature names of a Dataset through LGBM_DatasetSetFeatureNames.
 * \param handle handle wrapper of the dataset, read with R_GET_PTR
 * \param feature_names single string holding all names separated by tab characters
 * \return R_NilValue
 */
SEXP LGBM_DatasetSetFeatureNames_R(LGBM_SE handle,
  SEXP feature_names) {
  R_API_BEGIN();
  // the split strings own the storage; the C API only receives pointers into them
  auto split_names = Split(CHAR(Rf_asChar(feature_names)), '\t');
  const int num_names = static_cast<int>(split_names.size());
  std::vector<const char*> name_ptrs;
  for (int idx = 0; idx < num_names; ++idx) {
    name_ptrs.push_back(split_names[idx].c_str());
  }
  CHECK_CALL(LGBM_DatasetSetFeatureNames(R_GET_PTR(handle), name_ptrs.data(), num_names));
  R_API_END();
}

142
143
/*!
 * \brief Fetch the feature names of a Dataset as an R character vector.
 *
 * Names are first requested into 256-byte buffers; if the C API reports a
 * larger required size, every buffer is enlarged and the request is repeated.
 * \param handle handle wrapper of the dataset, read with R_GET_PTR
 * \return character vector with one element per feature
 */
SEXP LGBM_DatasetGetFeatureNames_R(LGBM_SE handle) {
  SEXP feature_names;
  R_API_BEGIN();
  int len = 0;
  CHECK_CALL(LGBM_DatasetGetNumFeature(R_GET_PTR(handle), &len));

  const size_t initial_buf_size = 256;
  std::vector<std::vector<char>> buffers(len);
  std::vector<char*> buffer_ptrs(len);
  // (re)size every name buffer and refresh the raw pointers handed to the C API
  auto resize_buffers = [&buffers, &buffer_ptrs, len](size_t new_size) {
    for (int i = 0; i < len; ++i) {
      buffers[i].resize(new_size);
      buffer_ptrs[i] = buffers[i].data();
    }
  };
  resize_buffers(initial_buf_size);

  int out_len;
  size_t needed_buf_size;
  CHECK_CALL(
    LGBM_DatasetGetFeatureNames(
      R_GET_PTR(handle),
      len, &out_len,
      initial_buf_size, &needed_buf_size,
      buffer_ptrs.data()));

  // if any feature names were larger than allocated size,
  // allow for a larger size and try again
  if (needed_buf_size > initial_buf_size) {
    resize_buffers(needed_buf_size);
    CHECK_CALL(
      LGBM_DatasetGetFeatureNames(
        R_GET_PTR(handle),
        len, &out_len,
        needed_buf_size, &needed_buf_size,
        buffer_ptrs.data()));
  }
  CHECK_EQ(len, out_len);

  feature_names = PROTECT(Rf_allocVector(STRSXP, len));
  for (int i = 0; i < len; ++i) {
    SET_STRING_ELT(feature_names, i, Rf_mkChar(buffer_ptrs[i]));
  }
  UNPROTECT(1);
  return feature_names;
  R_API_END();
}

188
/*!
 * \brief Save a Dataset through LGBM_DatasetSaveBinary.
 * \param handle handle wrapper of the dataset, read with R_GET_PTR
 * \param filename converted to a C string with Rf_asChar
 * \return R_NilValue
 */
SEXP LGBM_DatasetSaveBinary_R(LGBM_SE handle,
  SEXP filename) {
  R_API_BEGIN();
  const char* file_path = CHAR(Rf_asChar(filename));
  CHECK_CALL(LGBM_DatasetSaveBinary(R_GET_PTR(handle), file_path));
  R_API_END();
}

196
/*!
 * \brief Free a Dataset through LGBM_DatasetFree and reset the wrapper to nullptr.
 *
 * Does nothing when the wrapped pointer is already nullptr.
 * \param handle handle wrapper of the dataset
 * \return R_NilValue
 */
SEXP LGBM_DatasetFree_R(LGBM_SE handle) {
  R_API_BEGIN();
  const bool has_dataset = R_GET_PTR(handle) != nullptr;
  if (has_dataset) {
    CHECK_CALL(LGBM_DatasetFree(R_GET_PTR(handle)));
    // clear the stored pointer after the dataset has been freed
    R_SET_PTR(handle, nullptr);
  }
  R_API_END();
}

205
/*!
 * \brief Set a named field of a Dataset through LGBM_DatasetSetField.
 *
 * The field name selects how field_data is read and forwarded:
 *  - "group" / "query": R integer vector, copied to int32 (C_API_DTYPE_INT32)
 *  - "init_score": R double vector, passed as-is (C_API_DTYPE_FLOAT64)
 *  - anything else: R double vector, converted to float (C_API_DTYPE_FLOAT32)
 * \param handle handle wrapper of the dataset, read with R_GET_PTR
 * \param field_name converted to a C string with Rf_asChar
 * \param field_data vector holding the field values
 * \param num_element number of elements, read with Rf_asInteger
 * \return R_NilValue
 */
SEXP LGBM_DatasetSetField_R(LGBM_SE handle,
  SEXP field_name,
  SEXP field_data,
  SEXP num_element) {
  R_API_BEGIN();
  const int len = static_cast<int>(Rf_asInteger(num_element));
  const char* name = CHAR(Rf_asChar(field_name));
  if (!strcmp("group", name) || !strcmp("query", name)) {
    // Data pointers are fetched before each OpenMP region so that the
    // parallel loop bodies make no R API calls.
    const int* src = INTEGER(field_data);
    std::vector<int32_t> vec(len);
#pragma omp parallel for schedule(static, 512) if (len >= 1024)
    for (int i = 0; i < len; ++i) {
      vec[i] = static_cast<int32_t>(src[i]);
    }
    CHECK_CALL(LGBM_DatasetSetField(R_GET_PTR(handle), name, vec.data(), len, C_API_DTYPE_INT32));
  } else if (!strcmp("init_score", name)) {
    CHECK_CALL(LGBM_DatasetSetField(R_GET_PTR(handle), name, REAL(field_data), len, C_API_DTYPE_FLOAT64));
  } else {
    const double* src = REAL(field_data);
    std::vector<float> vec(len);
#pragma omp parallel for schedule(static, 512) if (len >= 1024)
    for (int i = 0; i < len; ++i) {
      vec[i] = static_cast<float>(src[i]);
    }
    CHECK_CALL(LGBM_DatasetSetField(R_GET_PTR(handle), name, vec.data(), len, C_API_DTYPE_FLOAT32));
  }
  R_API_END();
}

232
/*!
 * \brief Copy a named field of a Dataset into a preallocated R vector.
 *
 * The field name selects how the result of LGBM_DatasetGetField is interpreted:
 *  - "group" / "query": int32 boundaries; differences of consecutive
 *    boundaries (out_len - 1 values) are written to field_data as integers
 *  - "init_score": doubles, copied to field_data as doubles
 *  - anything else: floats, widened and written to field_data as doubles
 * \param handle handle wrapper of the dataset, read with R_GET_PTR
 * \param field_name converted to a C string with Rf_asChar
 * \param field_data destination vector; presumably sized by the caller using
 *        LGBM_DatasetGetFieldSize_R -- no size check is made here
 * \return R_NilValue
 */
SEXP LGBM_DatasetGetField_R(LGBM_SE handle,
  SEXP field_name,
  SEXP field_data) {
  R_API_BEGIN();
  const char* name = CHAR(Rf_asChar(field_name));
  int out_len = 0;
  int out_type = 0;
  const void* res = nullptr;
  CHECK_CALL(LGBM_DatasetGetField(R_GET_PTR(handle), name, &out_len, &res, &out_type));

  // Destination pointers are fetched before each OpenMP region so that the
  // parallel loop bodies make no R API calls.
  if (!strcmp("group", name) || !strcmp("query", name)) {
    auto p_data = reinterpret_cast<const int32_t*>(res);
    int* dst = INTEGER(field_data);
    // convert from boundaries to size
#pragma omp parallel for schedule(static, 512) if (out_len >= 1024)
    for (int i = 0; i < out_len - 1; ++i) {
      dst[i] = p_data[i + 1] - p_data[i];
    }
  } else if (!strcmp("init_score", name)) {
    auto p_data = reinterpret_cast<const double*>(res);
    double* dst = REAL(field_data);
#pragma omp parallel for schedule(static, 512) if (out_len >= 1024)
    for (int i = 0; i < out_len; ++i) {
      dst[i] = p_data[i];
    }
  } else {
    auto p_data = reinterpret_cast<const float*>(res);
    double* dst = REAL(field_data);
#pragma omp parallel for schedule(static, 512) if (out_len >= 1024)
    for (int i = 0; i < out_len; ++i) {
      dst[i] = p_data[i];
    }
  }
  R_API_END();
}

265
/*!
 * \brief Write the element count of a named Dataset field into out[0].
 *
 * For "group" / "query" the count reported by LGBM_DatasetGetField is reduced by one.
 * \param handle handle wrapper of the dataset, read with R_GET_PTR
 * \param field_name converted to a C string with Rf_asChar
 * \param out R integer vector whose first element receives the size
 * \return R_NilValue
 */
SEXP LGBM_DatasetGetFieldSize_R(LGBM_SE handle,
  SEXP field_name,
  SEXP out) {
  R_API_BEGIN();
  const char* name = CHAR(Rf_asChar(field_name));
  int field_len = 0;
  int field_type = 0;
  const void* field_ptr;
  CHECK_CALL(LGBM_DatasetGetField(R_GET_PTR(handle), name, &field_len, &field_ptr, &field_type));
  const bool is_group_field = (strcmp("group", name) == 0) || (strcmp("query", name) == 0);
  if (is_group_field) {
    --field_len;
  }
  INTEGER(out)[0] = field_len;
  R_API_END();
}

281
282
/*!
 * \brief Forward two parameter strings to LGBM_DatasetUpdateParamChecking.
 * \param old_params converted to a C string with Rf_asChar
 * \param new_params converted to a C string with Rf_asChar
 * \return R_NilValue
 */
SEXP LGBM_DatasetUpdateParamChecking_R(SEXP old_params,
  SEXP new_params) {
  R_API_BEGIN();
  const char* previous = CHAR(Rf_asChar(old_params));
  const char* updated = CHAR(Rf_asChar(new_params));
  CHECK_CALL(LGBM_DatasetUpdateParamChecking(previous, updated));
  R_API_END();
}

288
/*!
 * \brief Write the value reported by LGBM_DatasetGetNumData into out[0].
 * \param handle handle wrapper of the dataset, read with R_GET_PTR
 * \param out R integer vector whose first element receives the count
 * \return R_NilValue
 */
SEXP LGBM_DatasetGetNumData_R(LGBM_SE handle, SEXP out) {
  int num_data;
  R_API_BEGIN();
  CHECK_CALL(LGBM_DatasetGetNumData(R_GET_PTR(handle), &num_data));
  int* out_slot = INTEGER(out);
  out_slot[0] = num_data;
  R_API_END();
}

296
/*!
 * \brief Write the value reported by LGBM_DatasetGetNumFeature into out[0].
 * \param handle handle wrapper of the dataset, read with R_GET_PTR
 * \param out R integer vector whose first element receives the count
 * \return R_NilValue
 */
SEXP LGBM_DatasetGetNumFeature_R(LGBM_SE handle,
  SEXP out) {
  int num_features;
  R_API_BEGIN();
  CHECK_CALL(LGBM_DatasetGetNumFeature(R_GET_PTR(handle), &num_features));
  int* out_slot = INTEGER(out);
  out_slot[0] = num_features;
  R_API_END();
}

// --- start Booster interfaces

307
/*!
 * \brief Free a Booster through LGBM_BoosterFree and reset the wrapper to nullptr.
 *
 * Does nothing when the wrapped pointer is already nullptr.
 * \param handle handle wrapper of the booster
 * \return R_NilValue
 */
SEXP LGBM_BoosterFree_R(LGBM_SE handle) {
  R_API_BEGIN();
  const bool has_booster = R_GET_PTR(handle) != nullptr;
  if (has_booster) {
    CHECK_CALL(LGBM_BoosterFree(R_GET_PTR(handle)));
    // clear the stored pointer after the booster has been freed
    R_SET_PTR(handle, nullptr);
  }
  R_API_END();
}

316
/*!
 * \brief Create a Booster through LGBM_BoosterCreate.
 * \param train_data handle wrapper of the training dataset, read with R_GET_PTR
 * \param parameters converted to a C string with Rf_asChar
 * \param out handle wrapper that receives the new booster via R_SET_PTR
 * \return R_NilValue
 */
SEXP LGBM_BoosterCreate_R(LGBM_SE train_data,
  SEXP parameters,
  LGBM_SE out) {
  R_API_BEGIN();
  const char* param_str = CHAR(Rf_asChar(parameters));
  BoosterHandle new_booster = nullptr;
  CHECK_CALL(LGBM_BoosterCreate(R_GET_PTR(train_data), param_str, &new_booster));
  R_SET_PTR(out, new_booster);
  R_API_END();
}

326
/*!
 * \brief Create a Booster through LGBM_BoosterCreateFromModelfile.
 *
 * The iteration count reported by the C API is received but not used.
 * \param filename converted to a C string with Rf_asChar
 * \param out handle wrapper that receives the new booster via R_SET_PTR
 * \return R_NilValue
 */
SEXP LGBM_BoosterCreateFromModelfile_R(SEXP filename,
  LGBM_SE out) {
  R_API_BEGIN();
  const char* model_path = CHAR(Rf_asChar(filename));
  int ignored_num_iterations = 0;
  BoosterHandle new_booster = nullptr;
  CHECK_CALL(LGBM_BoosterCreateFromModelfile(model_path, &ignored_num_iterations, &new_booster));
  R_SET_PTR(out, new_booster);
  R_API_END();
}

336
/*!
 * \brief Create a Booster through LGBM_BoosterLoadModelFromString.
 *
 * The iteration count reported by the C API is received but not used.
 * \param model_str converted to a C string with Rf_asChar
 * \param out handle wrapper that receives the new booster via R_SET_PTR
 * \return R_NilValue
 */
SEXP LGBM_BoosterLoadModelFromString_R(SEXP model_str,
  LGBM_SE out) {
  R_API_BEGIN();
  const char* model_text = CHAR(Rf_asChar(model_str));
  int ignored_num_iterations = 0;
  BoosterHandle new_booster = nullptr;
  CHECK_CALL(LGBM_BoosterLoadModelFromString(model_text, &ignored_num_iterations, &new_booster));
  R_SET_PTR(out, new_booster);
  R_API_END();
}

346
347
/*!
 * \brief Forward two booster handles to LGBM_BoosterMerge.
 * \param handle handle wrapper passed as the first argument
 * \param other_handle handle wrapper passed as the second argument
 * \return R_NilValue
 */
SEXP LGBM_BoosterMerge_R(LGBM_SE handle,
  LGBM_SE other_handle) {
  R_API_BEGIN();
  CHECK_CALL(LGBM_BoosterMerge(
    R_GET_PTR(handle),
    R_GET_PTR(other_handle)));
  R_API_END();
}

353
354
/*!
 * \brief Forward a booster and a dataset handle to LGBM_BoosterAddValidData.
 * \param handle handle wrapper of the booster
 * \param valid_data handle wrapper of the validation dataset
 * \return R_NilValue
 */
SEXP LGBM_BoosterAddValidData_R(LGBM_SE handle,
  LGBM_SE valid_data) {
  R_API_BEGIN();
  CHECK_CALL(LGBM_BoosterAddValidData(
    R_GET_PTR(handle),
    R_GET_PTR(valid_data)));
  R_API_END();
}

360
361
/*!
 * \brief Forward a booster and a dataset handle to LGBM_BoosterResetTrainingData.
 * \param handle handle wrapper of the booster
 * \param train_data handle wrapper of the training dataset
 * \return R_NilValue
 */
SEXP LGBM_BoosterResetTrainingData_R(LGBM_SE handle,
  LGBM_SE train_data) {
  R_API_BEGIN();
  CHECK_CALL(LGBM_BoosterResetTrainingData(
    R_GET_PTR(handle),
    R_GET_PTR(train_data)));
  R_API_END();
}

367
/*!
 * \brief Forward a parameter string to LGBM_BoosterResetParameter.
 * \param handle handle wrapper of the booster, read with R_GET_PTR
 * \param parameters converted to a C string with Rf_asChar
 * \return R_NilValue
 */
SEXP LGBM_BoosterResetParameter_R(LGBM_SE handle,
  SEXP parameters) {
  R_API_BEGIN();
  const char* param_str = CHAR(Rf_asChar(parameters));
  CHECK_CALL(LGBM_BoosterResetParameter(R_GET_PTR(handle), param_str));
  R_API_END();
}

374
/*!
 * \brief Write the value reported by LGBM_BoosterGetNumClasses into out[0].
 * \param handle handle wrapper of the booster, read with R_GET_PTR
 * \param out R integer vector whose first element receives the count
 * \return R_NilValue
 */
SEXP LGBM_BoosterGetNumClasses_R(LGBM_SE handle,
  SEXP out) {
  int class_count;
  R_API_BEGIN();
  CHECK_CALL(LGBM_BoosterGetNumClasses(R_GET_PTR(handle), &class_count));
  int* out_slot = INTEGER(out);
  out_slot[0] = class_count;
  R_API_END();
}

383
/*!
 * \brief Call LGBM_BoosterUpdateOneIter on a booster.
 *
 * The "finished" flag written by the C API is received but not returned to R.
 * \param handle handle wrapper of the booster, read with R_GET_PTR
 * \return R_NilValue
 */
SEXP LGBM_BoosterUpdateOneIter_R(LGBM_SE handle) {
  int finished_flag = 0;
  R_API_BEGIN();
  CHECK_CALL(LGBM_BoosterUpdateOneIter(R_GET_PTR(handle), &finished_flag));
  R_API_END();
}

390
/*!
 * \brief Call LGBM_BoosterUpdateOneIterCustom with caller-supplied gradients and hessians.
 *
 * grad and hess are read as R double vectors and converted to float before
 * being handed to the C API. The "finished" flag written by the C API is
 * received but not returned to R.
 * \param handle handle wrapper of the booster, read with R_GET_PTR
 * \param grad R double vector of gradients
 * \param hess R double vector of hessians
 * \param len number of elements to convert, read with Rf_asInteger
 * \return R_NilValue
 */
SEXP LGBM_BoosterUpdateOneIterCustom_R(LGBM_SE handle,
  SEXP grad,
  SEXP hess,
  SEXP len) {
  int is_finished = 0;
  R_API_BEGIN();
  const int int_len = Rf_asInteger(len);
  // Fetch the data pointers once, before the OpenMP region, so that the
  // parallel loop body touches plain memory only and makes no R API call.
  const double* grad_src = REAL(grad);
  const double* hess_src = REAL(hess);
  std::vector<float> tgrad(int_len), thess(int_len);
#pragma omp parallel for schedule(static, 512) if (int_len >= 1024)
  for (int j = 0; j < int_len; ++j) {
    tgrad[j] = static_cast<float>(grad_src[j]);
    thess[j] = static_cast<float>(hess_src[j]);
  }
  CHECK_CALL(LGBM_BoosterUpdateOneIterCustom(R_GET_PTR(handle), tgrad.data(), thess.data(), &is_finished));
  R_API_END();
}

407
/*!
 * \brief Call LGBM_BoosterRollbackOneIter on a booster.
 * \param handle handle wrapper of the booster, read with R_GET_PTR
 * \return R_NilValue
 */
SEXP LGBM_BoosterRollbackOneIter_R(LGBM_SE handle) {
  R_API_BEGIN();
  CHECK_CALL(
    LGBM_BoosterRollbackOneIter(R_GET_PTR(handle)));
  R_API_END();
}

413
/*!
 * \brief Write the value reported by LGBM_BoosterGetCurrentIteration into out[0].
 * \param handle handle wrapper of the booster, read with R_GET_PTR
 * \param out R integer vector whose first element receives the iteration
 * \return R_NilValue
 */
SEXP LGBM_BoosterGetCurrentIteration_R(LGBM_SE handle,
  SEXP out) {
  int current_iter;
  R_API_BEGIN();
  CHECK_CALL(LGBM_BoosterGetCurrentIteration(R_GET_PTR(handle), &current_iter));
  int* out_slot = INTEGER(out);
  out_slot[0] = current_iter;
  R_API_END();
}

422
/*!
 * \brief Let LGBM_BoosterGetUpperBoundValue write its result into out_result.
 * \param handle handle wrapper of the booster, read with R_GET_PTR
 * \param out_result R double vector whose storage is handed to the C API
 * \return R_NilValue
 */
SEXP LGBM_BoosterGetUpperBoundValue_R(LGBM_SE handle,
  SEXP out_result) {
  R_API_BEGIN();
  double* result_slot = REAL(out_result);
  CHECK_CALL(LGBM_BoosterGetUpperBoundValue(R_GET_PTR(handle), result_slot));
  R_API_END();
}

430
/*!
 * \brief Let LGBM_BoosterGetLowerBoundValue write its result into out_result.
 * \param handle handle wrapper of the booster, read with R_GET_PTR
 * \param out_result R double vector whose storage is handed to the C API
 * \return R_NilValue
 */
SEXP LGBM_BoosterGetLowerBoundValue_R(LGBM_SE handle,
  SEXP out_result) {
  R_API_BEGIN();
  double* result_slot = REAL(out_result);
  CHECK_CALL(LGBM_BoosterGetLowerBoundValue(R_GET_PTR(handle), result_slot));
  R_API_END();
}

438
439
/*!
 * \brief Fetch the evaluation names of a Booster as an R character vector.
 *
 * The number of names comes from LGBM_BoosterGetEvalCounts. Names are first
 * requested into 128-byte buffers; if the C API reports a larger required
 * size, every buffer is enlarged and the request is repeated.
 * \param handle handle wrapper of the booster, read with R_GET_PTR
 * \return character vector with one element per evaluation name
 */
SEXP LGBM_BoosterGetEvalNames_R(LGBM_SE handle) {
  SEXP eval_names;
  R_API_BEGIN();
  int len;
  CHECK_CALL(LGBM_BoosterGetEvalCounts(R_GET_PTR(handle), &len));

  const size_t initial_buf_size = 128;
  std::vector<std::vector<char>> buffers(len);
  std::vector<char*> buffer_ptrs(len);
  // (re)size every name buffer and refresh the raw pointers handed to the C API
  auto resize_buffers = [&buffers, &buffer_ptrs, len](size_t new_size) {
    for (int i = 0; i < len; ++i) {
      buffers[i].resize(new_size);
      buffer_ptrs[i] = buffers[i].data();
    }
  };
  resize_buffers(initial_buf_size);

  int out_len;
  size_t needed_buf_size;
  CHECK_CALL(
    LGBM_BoosterGetEvalNames(
      R_GET_PTR(handle),
      len, &out_len,
      initial_buf_size, &needed_buf_size,
      buffer_ptrs.data()));

  // if any eval names were larger than allocated size,
  // allow for a larger size and try again
  if (needed_buf_size > initial_buf_size) {
    resize_buffers(needed_buf_size);
    CHECK_CALL(
      LGBM_BoosterGetEvalNames(
        R_GET_PTR(handle),
        len, &out_len,
        needed_buf_size, &needed_buf_size,
        buffer_ptrs.data()));
  }
  CHECK_EQ(out_len, len);

  eval_names = PROTECT(Rf_allocVector(STRSXP, len));
  for (int i = 0; i < len; ++i) {
    SET_STRING_ELT(eval_names, i, Rf_mkChar(buffer_ptrs[i]));
  }
  UNPROTECT(1);
  return eval_names;
  R_API_END();
}

486
/*!
 * \brief Let LGBM_BoosterGetEval write evaluation results into out_result.
 *
 * The number of values written is checked against LGBM_BoosterGetEvalCounts.
 * \param handle handle wrapper of the booster, read with R_GET_PTR
 * \param data_idx dataset index, read with Rf_asInteger
 * \param out_result R double vector whose storage is handed to the C API
 * \return R_NilValue
 */
SEXP LGBM_BoosterGetEval_R(LGBM_SE handle,
  SEXP data_idx,
  SEXP out_result) {
  R_API_BEGIN();
  int len;
  CHECK_CALL(LGBM_BoosterGetEvalCounts(R_GET_PTR(handle), &len));
  const int dataset_index = Rf_asInteger(data_idx);
  double* result_buf = REAL(out_result);
  int out_len;
  CHECK_CALL(LGBM_BoosterGetEval(R_GET_PTR(handle), dataset_index, &out_len, result_buf));
  CHECK_EQ(out_len, len);
  R_API_END();
}

499
/*!
 * \brief Write the count reported by LGBM_BoosterGetNumPredict into out[0].
 * \param handle handle wrapper of the booster, read with R_GET_PTR
 * \param data_idx dataset index, read with Rf_asInteger
 * \param out R integer vector whose first element receives the count
 * \return R_NilValue
 */
SEXP LGBM_BoosterGetNumPredict_R(LGBM_SE handle,
  SEXP data_idx,
  SEXP out) {
  R_API_BEGIN();
  const int dataset_index = Rf_asInteger(data_idx);
  int64_t num_predictions;
  CHECK_CALL(LGBM_BoosterGetNumPredict(R_GET_PTR(handle), dataset_index, &num_predictions));
  // NOTE(review): the 64-bit count is narrowed to int without a range check
  INTEGER(out)[0] = static_cast<int>(num_predictions);
  R_API_END();
}

509
/*!
 * \brief Let LGBM_BoosterGetPredict write predictions into out_result.
 *
 * The length reported by the C API is received but not used.
 * \param handle handle wrapper of the booster, read with R_GET_PTR
 * \param data_idx dataset index, read with Rf_asInteger
 * \param out_result R double vector whose storage is handed to the C API
 * \return R_NilValue
 */
SEXP LGBM_BoosterGetPredict_R(LGBM_SE handle,
  SEXP data_idx,
  SEXP out_result) {
  R_API_BEGIN();
  const int dataset_index = Rf_asInteger(data_idx);
  double* result_buf = REAL(out_result);
  int64_t written_len;
  CHECK_CALL(LGBM_BoosterGetPredict(R_GET_PTR(handle), dataset_index, &written_len, result_buf));
  R_API_END();
}

519
/*!
 * \brief Map three flags onto a C_API_PREDICT_* constant.
 *
 * When several flags are non-zero the precedence is
 * is_predcontrib, then is_leafidx, then is_rawscore;
 * with no flag set the result is C_API_PREDICT_NORMAL.
 * \return the selected prediction type constant
 */
int GetPredictType(SEXP is_rawscore, SEXP is_leafidx, SEXP is_predcontrib) {
  // all three flags are always read, in this order
  const bool want_raw_score = Rf_asInteger(is_rawscore) != 0;
  const bool want_leaf_index = Rf_asInteger(is_leafidx) != 0;
  const bool want_contrib = Rf_asInteger(is_predcontrib) != 0;
  if (want_contrib) {
    return C_API_PREDICT_CONTRIB;
  }
  if (want_leaf_index) {
    return C_API_PREDICT_LEAF_INDEX;
  }
  if (want_raw_score) {
    return C_API_PREDICT_RAW_SCORE;
  }
  return C_API_PREDICT_NORMAL;
}

533
/*!
 * \brief Run LGBM_BoosterPredictForFile with arguments taken from R values.
 *
 * The prediction type is derived from the three flags by GetPredictType.
 * \param handle handle wrapper of the booster, read with R_GET_PTR
 * \param data_filename,parameter,result_filename converted to C strings with Rf_asChar
 * \param data_has_header,start_iteration,num_iteration read with Rf_asInteger
 * \return R_NilValue
 */
SEXP LGBM_BoosterPredictForFile_R(LGBM_SE handle,
  SEXP data_filename,
  SEXP data_has_header,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter,
  SEXP result_filename) {
  R_API_BEGIN();
  const int predict_kind = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  CHECK_CALL(LGBM_BoosterPredictForFile(
    R_GET_PTR(handle),
    CHAR(Rf_asChar(data_filename)),
    Rf_asInteger(data_has_header),
    predict_kind,
    Rf_asInteger(start_iteration),
    Rf_asInteger(num_iteration),
    CHAR(Rf_asChar(parameter)),
    CHAR(Rf_asChar(result_filename))));
  R_API_END();
}

551
/*!
 * \brief Write the count reported by LGBM_BoosterCalcNumPredict into out_len[0].
 *
 * The prediction type is derived from the three flags by GetPredictType.
 * \param handle handle wrapper of the booster, read with R_GET_PTR
 * \param num_row,start_iteration,num_iteration read with Rf_asInteger
 * \param out_len R integer vector whose first element receives the count
 * \return R_NilValue
 */
SEXP LGBM_BoosterCalcNumPredict_R(LGBM_SE handle,
  SEXP num_row,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP out_len) {
  R_API_BEGIN();
  const int predict_kind = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  int64_t num_predictions = 0;
  CHECK_CALL(LGBM_BoosterCalcNumPredict(
    R_GET_PTR(handle),
    Rf_asInteger(num_row),
    predict_kind,
    Rf_asInteger(start_iteration),
    Rf_asInteger(num_iteration),
    &num_predictions));
  // NOTE(review): the 64-bit count is narrowed to int without a range check
  INTEGER(out_len)[0] = static_cast<int>(num_predictions);
  R_API_END();
}

568
/*!
 * \brief Run LGBM_BoosterPredictForCSC and let it write into out_result.
 *
 * indptr and indices are read as R integer vectors (passed as C_API_DTYPE_INT32),
 * data as an R double vector (passed as C_API_DTYPE_FLOAT64). The prediction
 * type is derived from the three flags by GetPredictType. The output length
 * reported by the C API is received but not used.
 * \param handle handle wrapper of the booster, read with R_GET_PTR
 * \param num_indptr,nelem,num_row,start_iteration,num_iteration read with Rf_asInteger
 * \param parameter converted to a C string with Rf_asChar
 * \param out_result R double vector whose storage is handed to the C API
 * \return R_NilValue
 */
SEXP LGBM_BoosterPredictForCSC_R(LGBM_SE handle,
  SEXP indptr,
  SEXP indices,
  SEXP data,
  SEXP num_indptr,
  SEXP nelem,
  SEXP num_row,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter,
  SEXP out_result) {
  R_API_BEGIN();
  const int predict_kind = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);

  const int* col_ptr = INTEGER(indptr);
  const int* row_idx = INTEGER(indices);
  const double* values = REAL(data);

  const int64_t n_col_ptr = Rf_asInteger(num_indptr);
  const int64_t n_values = Rf_asInteger(nelem);
  const int64_t n_rows = Rf_asInteger(num_row);
  double* result_buf = REAL(out_result);
  int64_t written_len;
  CHECK_CALL(LGBM_BoosterPredictForCSC(
    R_GET_PTR(handle),
    col_ptr, C_API_DTYPE_INT32, row_idx,
    values, C_API_DTYPE_FLOAT64,
    n_col_ptr, n_values, n_rows,
    predict_kind,
    Rf_asInteger(start_iteration),
    Rf_asInteger(num_iteration),
    CHAR(Rf_asChar(parameter)),
    &written_len, result_buf));
  R_API_END();
}

601
/*!
 * \brief Run LGBM_BoosterPredictForMat and let it write into out_result.
 *
 * data is read as an R double vector and passed as C_API_DTYPE_FLOAT64 with
 * the COL_MAJOR layout flag. The prediction type is derived from the three
 * flags by GetPredictType. The output length reported by the C API is
 * received but not used.
 * \param handle handle wrapper of the booster, read with R_GET_PTR
 * \param num_row,num_col,start_iteration,num_iteration read with Rf_asInteger
 * \param parameter converted to a C string with Rf_asChar
 * \param out_result R double vector whose storage is handed to the C API
 * \return R_NilValue
 */
SEXP LGBM_BoosterPredictForMat_R(LGBM_SE handle,
  SEXP data,
  SEXP num_row,
  SEXP num_col,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  SEXP parameter,
  SEXP out_result) {
  R_API_BEGIN();
  const int predict_kind = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);

  const int32_t n_rows = Rf_asInteger(num_row);
  const int32_t n_cols = Rf_asInteger(num_col);

  const double* matrix = REAL(data);
  double* result_buf = REAL(out_result);
  int64_t written_len;
  CHECK_CALL(LGBM_BoosterPredictForMat(
    R_GET_PTR(handle),
    matrix, C_API_DTYPE_FLOAT64, n_rows, n_cols, COL_MAJOR,
    predict_kind,
    Rf_asInteger(start_iteration),
    Rf_asInteger(num_iteration),
    CHAR(Rf_asChar(parameter)),
    &written_len, result_buf));

  R_API_END();
}

628
/*!
 * \brief Save a Booster through LGBM_BoosterSaveModel.
 *
 * The second argument of the C API call is fixed at 0.
 * \param handle handle wrapper of the booster, read with R_GET_PTR
 * \param num_iteration read with Rf_asInteger
 * \param feature_importance_type read with Rf_asInteger
 * \param filename converted to a C string with Rf_asChar
 * \return R_NilValue
 */
SEXP LGBM_BoosterSaveModel_R(LGBM_SE handle,
  SEXP num_iteration,
  SEXP feature_importance_type,
  SEXP filename) {
  R_API_BEGIN();
  CHECK_CALL(LGBM_BoosterSaveModel(
    R_GET_PTR(handle),
    0,
    Rf_asInteger(num_iteration),
    Rf_asInteger(feature_importance_type),
    CHAR(Rf_asChar(filename))));
  R_API_END();
}

637
/*!
 * \brief Serialize a Booster with LGBM_BoosterSaveModelToString and return it as an R string.
 *
 * A 1 MiB buffer is tried first; if the C API reports a larger length, the
 * buffer is grown to that length and the call is repeated. The second
 * argument of the C API call is fixed at 0.
 * \param handle handle wrapper of the booster, read with R_GET_PTR
 * \param num_iteration read with Rf_asInteger
 * \param feature_importance_type read with Rf_asInteger
 * \return length-one character vector holding the model text
 */
SEXP LGBM_BoosterSaveModelToString_R(LGBM_SE handle,
  SEXP num_iteration,
  SEXP feature_importance_type) {
  SEXP model_str;
  R_API_BEGIN();
  int64_t out_len = 0;
  int64_t buf_len = 1024 * 1024;
  // Kept as int (what Rf_asInteger yields), as in LGBM_BoosterSaveModel_R,
  // instead of widening to int64_t and narrowing again at the call.
  const int num_iter = Rf_asInteger(num_iteration);
  const int importance_type = Rf_asInteger(feature_importance_type);
  std::vector<char> inner_char_buf(buf_len);
  CHECK_CALL(LGBM_BoosterSaveModelToString(R_GET_PTR(handle), 0, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data()));
  // if the model string was larger than the initial buffer, allocate a bigger buffer and try again
  if (out_len > buf_len) {
    inner_char_buf.resize(out_len);
    CHECK_CALL(LGBM_BoosterSaveModelToString(R_GET_PTR(handle), 0, num_iter, importance_type, out_len, &out_len, inner_char_buf.data()));
  }
  model_str = PROTECT(Rf_allocVector(STRSXP, 1));
  SET_STRING_ELT(model_str, 0, Rf_mkChar(inner_char_buf.data()));
  UNPROTECT(1);
  return model_str;
  R_API_END();
}

660
/*!
 * \brief Dump a Booster with LGBM_BoosterDumpModel and return the text as an R string.
 *
 * A 1 MiB buffer is tried first; if the C API reports a larger length, the
 * buffer is grown to that length and the call is repeated. The second
 * argument of the C API call is fixed at 0.
 * \param handle handle wrapper of the booster, read with R_GET_PTR
 * \param num_iteration read with Rf_asInteger
 * \param feature_importance_type read with Rf_asInteger
 * \return length-one character vector holding the dumped model
 */
SEXP LGBM_BoosterDumpModel_R(LGBM_SE handle,
  SEXP num_iteration,
  SEXP feature_importance_type) {
  SEXP model_str;
  R_API_BEGIN();
  int64_t out_len = 0;
  int64_t buf_len = 1024 * 1024;
  // Kept as int (what Rf_asInteger yields), as in LGBM_BoosterSaveModel_R,
  // instead of widening to int64_t and narrowing again at the call.
  const int num_iter = Rf_asInteger(num_iteration);
  const int importance_type = Rf_asInteger(feature_importance_type);
  std::vector<char> inner_char_buf(buf_len);
  CHECK_CALL(LGBM_BoosterDumpModel(R_GET_PTR(handle), 0, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data()));
  // if the model string was larger than the initial buffer, allocate a bigger buffer and try again
  if (out_len > buf_len) {
    inner_char_buf.resize(out_len);
    CHECK_CALL(LGBM_BoosterDumpModel(R_GET_PTR(handle), 0, num_iter, importance_type, out_len, &out_len, inner_char_buf.data()));
  }
  model_str = PROTECT(Rf_allocVector(STRSXP, 1));
  SET_STRING_ELT(model_str, 0, Rf_mkChar(inner_char_buf.data()));
  UNPROTECT(1);
  return model_str;
  R_API_END();
}
682
683
684

// .Call() calls
// Registration table handed to R_registerRoutines: one row per exported
// wrapper, giving its name, its address and its number of arguments.
// The table is terminated by an all-null row.
static const R_CallMethodDef CallEntries[] = {
  {"LGBM_GetLastError_R", reinterpret_cast<DL_FUNC>(&LGBM_GetLastError_R), 0},
  {"LGBM_DatasetCreateFromFile_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetCreateFromFile_R), 4},
  {"LGBM_DatasetCreateFromCSC_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetCreateFromCSC_R), 9},
  {"LGBM_DatasetCreateFromMat_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetCreateFromMat_R), 6},
  {"LGBM_DatasetGetSubset_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetSubset_R), 5},
  {"LGBM_DatasetSetFeatureNames_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetSetFeatureNames_R), 2},
  {"LGBM_DatasetGetFeatureNames_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetFeatureNames_R), 1},
  {"LGBM_DatasetSaveBinary_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetSaveBinary_R), 2},
  {"LGBM_DatasetFree_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetFree_R), 1},
  {"LGBM_DatasetSetField_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetSetField_R), 4},
  {"LGBM_DatasetGetFieldSize_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetFieldSize_R), 3},
  {"LGBM_DatasetGetField_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetField_R), 3},
  {"LGBM_DatasetUpdateParamChecking_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetUpdateParamChecking_R), 2},
  {"LGBM_DatasetGetNumData_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetNumData_R), 2},
  {"LGBM_DatasetGetNumFeature_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetNumFeature_R), 2},
  {"LGBM_BoosterCreate_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterCreate_R), 3},
  {"LGBM_BoosterFree_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterFree_R), 1},
  {"LGBM_BoosterCreateFromModelfile_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterCreateFromModelfile_R), 2},
  {"LGBM_BoosterLoadModelFromString_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterLoadModelFromString_R), 2},
  {"LGBM_BoosterMerge_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterMerge_R), 2},
  {"LGBM_BoosterAddValidData_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterAddValidData_R), 2},
  {"LGBM_BoosterResetTrainingData_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterResetTrainingData_R), 2},
  {"LGBM_BoosterResetParameter_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterResetParameter_R), 2},
  {"LGBM_BoosterGetNumClasses_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetNumClasses_R), 2},
  {"LGBM_BoosterUpdateOneIter_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterUpdateOneIter_R), 1},
  {"LGBM_BoosterUpdateOneIterCustom_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterUpdateOneIterCustom_R), 4},
  {"LGBM_BoosterRollbackOneIter_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterRollbackOneIter_R), 1},
  {"LGBM_BoosterGetCurrentIteration_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetCurrentIteration_R), 2},
  {"LGBM_BoosterGetUpperBoundValue_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetUpperBoundValue_R), 2},
  {"LGBM_BoosterGetLowerBoundValue_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetLowerBoundValue_R), 2},
  {"LGBM_BoosterGetEvalNames_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetEvalNames_R), 1},
  {"LGBM_BoosterGetEval_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetEval_R), 3},
  {"LGBM_BoosterGetNumPredict_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetNumPredict_R), 3},
  {"LGBM_BoosterGetPredict_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetPredict_R), 3},
  {"LGBM_BoosterPredictForFile_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterPredictForFile_R), 10},
  {"LGBM_BoosterCalcNumPredict_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterCalcNumPredict_R), 8},
  {"LGBM_BoosterPredictForCSC_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterPredictForCSC_R), 14},
  {"LGBM_BoosterPredictForMat_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterPredictForMat_R), 11},
  {"LGBM_BoosterSaveModel_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterSaveModel_R), 4},
  {"LGBM_BoosterSaveModelToString_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterSaveModelToString_R), 3},
  {"LGBM_BoosterDumpModel_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterDumpModel_R), 3},
  {nullptr, nullptr, 0}
};

729
730
LIGHTGBM_C_EXPORT void R_init_lightgbm(DllInfo *dll);

731
732
733
734
void R_init_lightgbm(DllInfo *dll) {
  R_registerRoutines(dll, NULL, CallEntries, NULL, NULL);
  R_useDynamicSymbols(dll, FALSE);
}