lightgbm_R.cpp 22.9 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2017 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
5
6

#include "lightgbm_R.h"
Guolin Ke's avatar
Guolin Ke committed
7

8
9
10
11
12
13
14
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>
#include <LightGBM/utils/text_reader.h>

#include <R_ext/Rdynload.h>

15
16
17
18
#define R_NO_REMAP
#define R_USE_C99_IN_CXX
#include <R_ext/Error.h>

19
20
21
22
23
24
25
#include <string>
#include <cstdio>
#include <cstring>
#include <memory>
#include <utility>
#include <vector>

Guolin Ke's avatar
Guolin Ke committed
26
27
28
29
30
// Layout flag forwarded to the dense-matrix C API calls in this file
// (LGBM_DatasetCreateFromMat / LGBM_BoosterPredictForMat).
#define COL_MAJOR (0)

// R_API_BEGIN() / R_API_END() bracket the body of each exported wrapper.
// Any C++ exception raised in between is recorded with LGBM_SetLastError()
// and the wrapper returns R_NilValue; the non-exceptional path also ends by
// returning R_NilValue.
#define R_API_BEGIN() \
  try {
#define R_API_END() \
  } catch (const std::exception& ex) { \
    LGBM_SetLastError(ex.what()); \
    return R_NilValue; \
  } catch (const std::string& ex) { \
    LGBM_SetLastError(ex.c_str()); \
    return R_NilValue; \
  } catch (...) { \
    LGBM_SetLastError("unknown exception"); \
    return R_NilValue; \
  } \
  return R_NilValue;
Guolin Ke's avatar
Guolin Ke committed
35
36
37

// Evaluates a LightGBM C API call and raises an R error carrying the library's
// last error message when the call returns a non-zero status.
//
// The message is passed through an explicit "%s" format: Rf_error() takes a
// printf-style format string, so handing it the raw message would misinterpret
// any '%' contained in the error text.
//
// NOTE(review): Rf_error() is documented as not returning; the trailing
// return only keeps the macro well-formed for the compiler.
#define CHECK_CALL(x) \
  if ((x) != 0) { \
    Rf_error("%s", LGBM_GetLastError()); \
    return R_NilValue; \
  }

42
43
44
using LightGBM::Common::Join;
using LightGBM::Common::Split;
using LightGBM::Log;
Guolin Ke's avatar
Guolin Ke committed
45

46
/*!
 * \brief Copy a string into an R-provided character buffer.
 * \param dest Destination buffer
 * \param src Source bytes to copy
 * \param buf_len R integer holding the capacity of dest
 * \param actual_len Receives str_len, whether or not the copy happens
 * \param str_len Number of bytes to copy from src
 * \return dest
 *
 * Strings longer than INT32_MAX are rejected with Log::Fatal. When the
 * destination is too small nothing is copied; the caller can read the
 * required size from actual_len.
 */
LGBM_SE EncodeChar(LGBM_SE dest, const char* src, SEXP buf_len, LGBM_SE actual_len, size_t str_len) {
  if (str_len > INT32_MAX) {
    Log::Fatal("Don't support large string in R-package");
  }
  const int needed = static_cast<int>(str_len);
  R_INT_PTR(actual_len)[0] = needed;
  if (Rf_asInteger(buf_len) >= needed) {
    std::memcpy(R_CHAR_PTR(dest), src, str_len);
  }
  return dest;
}

59
60
61
62
63
64
/*!
 * \brief Return LightGBM's last error message as a length-one R character vector.
 */
SEXP LGBM_GetLastError_R() {
  // The vector stays protected while Rf_mkChar() allocates the string element.
  SEXP msg = PROTECT(Rf_allocVector(STRSXP, 1));
  SET_STRING_ELT(msg, 0, Rf_mkChar(LGBM_GetLastError()));
  UNPROTECT(1);
  return msg;
}

67
/*!
 * \brief Create a Dataset via LGBM_DatasetCreateFromFile and store its handle in out.
 * \param filename Path of the data file
 * \param parameters Parameter string forwarded to the C API
 * \param reference Handle of a reference Dataset, forwarded as-is
 * \param out Receives the new Dataset handle
 */
SEXP LGBM_DatasetCreateFromFile_R(LGBM_SE filename,
  LGBM_SE parameters,
  LGBM_SE reference,
  LGBM_SE out) {
  R_API_BEGIN();
  DatasetHandle dataset = nullptr;
  CHECK_CALL(LGBM_DatasetCreateFromFile(
    R_CHAR_PTR(filename),
    R_CHAR_PTR(parameters),
    R_GET_PTR(reference),
    &dataset));
  R_SET_PTR(out, dataset);
  R_API_END();
}

79
/*!
 * \brief Create a Dataset from CSC sparse-matrix arrays via LGBM_DatasetCreateFromCSC.
 * \param indptr Column pointer array (read as int32)
 * \param indices Row index array (read as int32)
 * \param data Non-zero values (read as double)
 * \param num_indptr Length of indptr
 * \param nelem Number of entries in data
 * \param num_row Number of rows
 * \param parameters Parameter string forwarded to the C API
 * \param reference Handle of a reference Dataset, forwarded as-is
 * \param out Receives the new Dataset handle
 */
SEXP LGBM_DatasetCreateFromCSC_R(LGBM_SE indptr,
  LGBM_SE indices,
  LGBM_SE data,
  SEXP num_indptr,
  SEXP nelem,
  SEXP num_row,
  LGBM_SE parameters,
  LGBM_SE reference,
  LGBM_SE out) {
  R_API_BEGIN();
  const int* col_ptr = R_INT_PTR(indptr);
  const int* row_idx = R_INT_PTR(indices);
  const double* values = R_REAL_PTR(data);

  // R integers are 32-bit; widen them to the 64-bit sizes the C API expects.
  const int64_t n_col_ptr = static_cast<int64_t>(Rf_asInteger(num_indptr));
  const int64_t n_values = static_cast<int64_t>(Rf_asInteger(nelem));
  const int64_t n_rows = static_cast<int64_t>(Rf_asInteger(num_row));

  DatasetHandle dataset = nullptr;
  CHECK_CALL(LGBM_DatasetCreateFromCSC(
    col_ptr, C_API_DTYPE_INT32, row_idx,
    values, C_API_DTYPE_FLOAT64, n_col_ptr, n_values,
    n_rows, R_CHAR_PTR(parameters), R_GET_PTR(reference), &dataset));
  R_SET_PTR(out, dataset);
  R_API_END();
}

104
/*!
 * \brief Create a Dataset from a dense double matrix via LGBM_DatasetCreateFromMat.
 * \param data Matrix values (read as double, passed with the COL_MAJOR layout flag)
 * \param num_row Number of rows
 * \param num_col Number of columns
 * \param parameters Parameter string forwarded to the C API
 * \param reference Handle of a reference Dataset, forwarded as-is
 * \param out Receives the new Dataset handle
 */
SEXP LGBM_DatasetCreateFromMat_R(LGBM_SE data,
  SEXP num_row,
  SEXP num_col,
  LGBM_SE parameters,
  LGBM_SE reference,
  LGBM_SE out) {
  R_API_BEGIN();
  const int32_t n_rows = static_cast<int32_t>(Rf_asInteger(num_row));
  const int32_t n_cols = static_cast<int32_t>(Rf_asInteger(num_col));
  double* matrix = R_REAL_PTR(data);
  DatasetHandle dataset = nullptr;
  CHECK_CALL(LGBM_DatasetCreateFromMat(
    matrix, C_API_DTYPE_FLOAT64, n_rows, n_cols, COL_MAJOR,
    R_CHAR_PTR(parameters), R_GET_PTR(reference), &dataset));
  R_SET_PTR(out, dataset);
  R_API_END();
}

121
/*!
 * \brief Build a Dataset containing a subset of rows via LGBM_DatasetGetSubset.
 * \param handle Source Dataset handle
 * \param used_row_indices One-based row indices to keep
 * \param len_used_row_indices Number of entries in used_row_indices
 * \param parameters Parameter string forwarded to the C API
 * \param out Receives the handle of the subset Dataset
 */
SEXP LGBM_DatasetGetSubset_R(LGBM_SE handle,
  LGBM_SE used_row_indices,
  SEXP len_used_row_indices,
  LGBM_SE parameters,
  LGBM_SE out) {
  R_API_BEGIN();
  const int num_used = Rf_asInteger(len_used_row_indices);
  // The incoming indices are one-based; the C API is given zero-based ones.
  std::vector<int> zero_based(num_used);
#pragma omp parallel for schedule(static, 512) if (num_used >= 1024)
  for (int row = 0; row < num_used; ++row) {
    zero_based[row] = R_INT_PTR(used_row_indices)[row] - 1;
  }
  DatasetHandle subset = nullptr;
  CHECK_CALL(LGBM_DatasetGetSubset(
    R_GET_PTR(handle),
    zero_based.data(), num_used, R_CHAR_PTR(parameters),
    &subset));
  R_SET_PTR(out, subset);
  R_API_END();
}

142
143
/*!
 * \brief Set the feature names of a Dataset via LGBM_DatasetSetFeatureNames.
 * \param handle Dataset handle
 * \param feature_names All names in one string, separated by tab characters
 */
SEXP LGBM_DatasetSetFeatureNames_R(LGBM_SE handle,
  LGBM_SE feature_names) {
  R_API_BEGIN();
  const auto names = Split(R_CHAR_PTR(feature_names), '\t');
  // The C API takes an array of C strings; these pointers borrow from `names`,
  // which outlives the call below.
  std::vector<const char*> name_ptrs;
  name_ptrs.reserve(names.size());
  for (const auto& name : names) {
    name_ptrs.push_back(name.c_str());
  }
  CHECK_CALL(LGBM_DatasetSetFeatureNames(
    R_GET_PTR(handle),
    name_ptrs.data(), static_cast<int>(names.size())));
  R_API_END();
}

156
/*!
 * \brief Fetch the feature names of a Dataset as one tab-separated string.
 * \param handle Dataset handle
 * \param buf_len Capacity of feature_names
 * \param actual_len Receives the size needed for the joined string (including
 *        the terminating NUL), see EncodeChar
 * \param feature_names Destination buffer for the joined names
 *
 * Names are first requested into 256-byte per-name buffers. If the C API
 * reports that a larger per-name buffer is required, the buffers are grown to
 * that size and the names are requested again instead of aborting.
 */
SEXP LGBM_DatasetGetFeatureNames_R(LGBM_SE handle,
  SEXP buf_len,
  LGBM_SE actual_len,
  LGBM_SE feature_names) {
  R_API_BEGIN();
  int len = 0;
  CHECK_CALL(LGBM_DatasetGetNumFeature(R_GET_PTR(handle), &len));
  const size_t reserved_string_size = 256;
  std::vector<std::vector<char>> names(len);
  std::vector<char*> ptr_names(len);
  for (int i = 0; i < len; ++i) {
    names[i].resize(reserved_string_size);
    ptr_names[i] = names[i].data();
  }
  int out_len = 0;
  size_t required_string_size = 0;
  CHECK_CALL(
    LGBM_DatasetGetFeatureNames(
      R_GET_PTR(handle),
      len, &out_len,
      reserved_string_size, &required_string_size,
      ptr_names.data()));
  CHECK_EQ(len, out_len);
  if (required_string_size > reserved_string_size) {
    // At least one name did not fit: enlarge every buffer to the reported
    // size and fetch the names a second time.
    const size_t enlarged_string_size = required_string_size;
    for (int i = 0; i < len; ++i) {
      names[i].resize(enlarged_string_size);
      ptr_names[i] = names[i].data();
    }
    CHECK_CALL(
      LGBM_DatasetGetFeatureNames(
        R_GET_PTR(handle),
        len, &out_len,
        enlarged_string_size, &required_string_size,
        ptr_names.data()));
    CHECK_EQ(len, out_len);
    CHECK_GE(enlarged_string_size, required_string_size);
  }
  auto merge_str = Join<char*>(ptr_names, "\t");
  // +1 so the terminating NUL of the joined string is copied as well.
  EncodeChar(feature_names, merge_str.c_str(), buf_len, actual_len, merge_str.size() + 1);
  R_API_END();
}

185
186
/*!
 * \brief Save a Dataset to a file via LGBM_DatasetSaveBinary.
 * \param handle Dataset handle
 * \param filename Output file path
 */
SEXP LGBM_DatasetSaveBinary_R(LGBM_SE handle,
  LGBM_SE filename) {
  R_API_BEGIN();
  const char* path = R_CHAR_PTR(filename);
  CHECK_CALL(LGBM_DatasetSaveBinary(R_GET_PTR(handle), path));
  R_API_END();
}

193
/*!
 * \brief Free a Dataset and reset the stored handle to null.
 * \param handle Dataset handle; nothing is done when it is already null
 */
SEXP LGBM_DatasetFree_R(LGBM_SE handle) {
  R_API_BEGIN();
  DatasetHandle dataset = R_GET_PTR(handle);
  if (dataset != nullptr) {
    CHECK_CALL(LGBM_DatasetFree(dataset));
    // Clear the stored pointer so a second call finds a null handle.
    R_SET_PTR(handle, nullptr);
  }
  R_API_END();
}

202
/*!
 * \brief Set a named field of a Dataset via LGBM_DatasetSetField.
 * \param handle Dataset handle
 * \param field_name Name of the field
 * \param field_data Field values; read as int for "group"/"query", as double otherwise
 * \param num_element Number of values in field_data
 *
 * "group"/"query" is sent as int32, "init_score" is sent as float64 straight
 * from the R buffer, and every other field is narrowed to float32 first.
 */
SEXP LGBM_DatasetSetField_R(LGBM_SE handle,
  LGBM_SE field_name,
  LGBM_SE field_data,
  SEXP num_element) {
  R_API_BEGIN();
  const int count = static_cast<int>(Rf_asInteger(num_element));
  const char* field = R_CHAR_PTR(field_name);
  const bool is_group_field = (strcmp("group", field) == 0) || (strcmp("query", field) == 0);
  if (is_group_field) {
    std::vector<int32_t> as_int32(count);
#pragma omp parallel for schedule(static, 512) if (count >= 1024)
    for (int k = 0; k < count; ++k) {
      as_int32[k] = static_cast<int32_t>(R_INT_PTR(field_data)[k]);
    }
    CHECK_CALL(LGBM_DatasetSetField(R_GET_PTR(handle), field, as_int32.data(), count, C_API_DTYPE_INT32));
  } else if (strcmp("init_score", field) == 0) {
    CHECK_CALL(LGBM_DatasetSetField(R_GET_PTR(handle), field, R_REAL_PTR(field_data), count, C_API_DTYPE_FLOAT64));
  } else {
    std::vector<float> as_float(count);
#pragma omp parallel for schedule(static, 512) if (count >= 1024)
    for (int k = 0; k < count; ++k) {
      as_float[k] = static_cast<float>(R_REAL_PTR(field_data)[k]);
    }
    CHECK_CALL(LGBM_DatasetSetField(R_GET_PTR(handle), field, as_float.data(), count, C_API_DTYPE_FLOAT32));
  }
  R_API_END();
}

229
/*!
 * \brief Copy a named field of a Dataset into an R buffer.
 * \param handle Dataset handle
 * \param field_name Name of the field
 * \param field_data Destination; written as int for "group"/"query", as double otherwise
 *
 * For "group"/"query" the values returned by the C API are treated as
 * boundaries and the differences between neighbours (the sizes) are written,
 * giving one element fewer than the C API returned.
 */
SEXP LGBM_DatasetGetField_R(LGBM_SE handle,
  LGBM_SE field_name,
  LGBM_SE field_data) {
  R_API_BEGIN();
  const char* field = R_CHAR_PTR(field_name);
  int n_values = 0;
  int value_type = 0;
  const void* raw;
  CHECK_CALL(LGBM_DatasetGetField(R_GET_PTR(handle), field, &n_values, &raw, &value_type));

  const bool is_group_field = (strcmp("group", field) == 0) || (strcmp("query", field) == 0);
  if (is_group_field) {
    const int32_t* boundaries = reinterpret_cast<const int32_t*>(raw);
    // convert from boundaries to size
#pragma omp parallel for schedule(static, 512) if (n_values >= 1024)
    for (int k = 0; k < n_values - 1; ++k) {
      R_INT_PTR(field_data)[k] = boundaries[k + 1] - boundaries[k];
    }
  } else if (strcmp("init_score", field) == 0) {
    const double* scores = reinterpret_cast<const double*>(raw);
#pragma omp parallel for schedule(static, 512) if (n_values >= 1024)
    for (int k = 0; k < n_values; ++k) {
      R_REAL_PTR(field_data)[k] = scores[k];
    }
  } else {
    // Remaining fields are read as float and widened to double for R.
    const float* values = reinterpret_cast<const float*>(raw);
#pragma omp parallel for schedule(static, 512) if (n_values >= 1024)
    for (int k = 0; k < n_values; ++k) {
      R_REAL_PTR(field_data)[k] = values[k];
    }
  }
  R_API_END();
}

262
/*!
 * \brief Report how many elements LGBM_DatasetGetField_R would write for a field.
 * \param handle Dataset handle
 * \param field_name Name of the field
 * \param out Receives the element count
 *
 * For "group"/"query" the count is one less than the number of values the
 * C API returns, matching the boundaries-to-sizes conversion done when the
 * field is read. The subtraction is skipped when the C API returns no values,
 * so the reported size is never negative.
 */
SEXP LGBM_DatasetGetFieldSize_R(LGBM_SE handle,
  LGBM_SE field_name,
  LGBM_SE out) {
  R_API_BEGIN();
  const char* name = R_CHAR_PTR(field_name);
  int out_len = 0;
  int out_type = 0;
  const void* res = nullptr;
  CHECK_CALL(LGBM_DatasetGetField(R_GET_PTR(handle), name, &out_len, &res, &out_type));
  const bool is_group_field = !strcmp("group", name) || !strcmp("query", name);
  if (is_group_field && out_len > 0) {
    out_len -= 1;
  }
  R_INT_PTR(out)[0] = static_cast<int>(out_len);
  R_API_END();
}

278
279
/*!
 * \brief Forward two parameter strings to LGBM_DatasetUpdateParamChecking.
 * \param old_params Current parameter string
 * \param new_params Proposed parameter string
 */
SEXP LGBM_DatasetUpdateParamChecking_R(LGBM_SE old_params,
  LGBM_SE new_params) {
  R_API_BEGIN();
  const char* previous = R_CHAR_PTR(old_params);
  const char* proposed = R_CHAR_PTR(new_params);
  CHECK_CALL(LGBM_DatasetUpdateParamChecking(previous, proposed));
  R_API_END();
}

285
/*!
 * \brief Write the value reported by LGBM_DatasetGetNumData into out.
 * \param handle Dataset handle
 * \param out Receives the count as an R integer
 */
SEXP LGBM_DatasetGetNumData_R(LGBM_SE handle, LGBM_SE out) {
  R_API_BEGIN();
  int num_data = 0;
  CHECK_CALL(LGBM_DatasetGetNumData(R_GET_PTR(handle), &num_data));
  R_INT_PTR(out)[0] = num_data;
  R_API_END();
}

293
294
/*!
 * \brief Write the value reported by LGBM_DatasetGetNumFeature into out.
 * \param handle Dataset handle
 * \param out Receives the count as an R integer
 */
SEXP LGBM_DatasetGetNumFeature_R(LGBM_SE handle,
  LGBM_SE out) {
  R_API_BEGIN();
  int num_feature = 0;
  CHECK_CALL(LGBM_DatasetGetNumFeature(R_GET_PTR(handle), &num_feature));
  R_INT_PTR(out)[0] = num_feature;
  R_API_END();
}

// --- start Booster interfaces

304
/*!
 * \brief Free a Booster and reset the stored handle to null.
 * \param handle Booster handle; nothing is done when it is already null
 */
SEXP LGBM_BoosterFree_R(LGBM_SE handle) {
  R_API_BEGIN();
  BoosterHandle booster = R_GET_PTR(handle);
  if (booster != nullptr) {
    CHECK_CALL(LGBM_BoosterFree(booster));
    // Clear the stored pointer so a second call finds a null handle.
    R_SET_PTR(handle, nullptr);
  }
  R_API_END();
}

313
/*!
 * \brief Create a Booster via LGBM_BoosterCreate and store its handle in out.
 * \param train_data Handle of the training Dataset
 * \param parameters Parameter string forwarded to the C API
 * \param out Receives the new Booster handle
 */
SEXP LGBM_BoosterCreate_R(LGBM_SE train_data,
  LGBM_SE parameters,
  LGBM_SE out) {
  R_API_BEGIN();
  BoosterHandle booster = nullptr;
  CHECK_CALL(LGBM_BoosterCreate(
    R_GET_PTR(train_data),
    R_CHAR_PTR(parameters),
    &booster));
  R_SET_PTR(out, booster);
  R_API_END();
}

323
324
/*!
 * \brief Create a Booster from a model file via LGBM_BoosterCreateFromModelfile.
 * \param filename Path of the model file
 * \param out Receives the new Booster handle
 */
SEXP LGBM_BoosterCreateFromModelfile_R(LGBM_SE filename,
  LGBM_SE out) {
  R_API_BEGIN();
  BoosterHandle booster = nullptr;
  // The C API also reports an iteration count; it is not passed back to R.
  int num_iterations = 0;
  CHECK_CALL(LGBM_BoosterCreateFromModelfile(R_CHAR_PTR(filename), &num_iterations, &booster));
  R_SET_PTR(out, booster);
  R_API_END();
}

333
334
/*!
 * \brief Create a Booster from an in-memory model string via LGBM_BoosterLoadModelFromString.
 * \param model_str Model text
 * \param out Receives the new Booster handle
 */
SEXP LGBM_BoosterLoadModelFromString_R(LGBM_SE model_str,
  LGBM_SE out) {
  R_API_BEGIN();
  BoosterHandle booster = nullptr;
  // The C API also reports an iteration count; it is not passed back to R.
  int num_iterations = 0;
  CHECK_CALL(LGBM_BoosterLoadModelFromString(R_CHAR_PTR(model_str), &num_iterations, &booster));
  R_SET_PTR(out, booster);
  R_API_END();
}

343
344
/*!
 * \brief Forward two Booster handles to LGBM_BoosterMerge.
 * \param handle First Booster handle
 * \param other_handle Second Booster handle
 */
SEXP LGBM_BoosterMerge_R(LGBM_SE handle,
  LGBM_SE other_handle) {
  R_API_BEGIN();
  CHECK_CALL(LGBM_BoosterMerge(
    R_GET_PTR(handle),
    R_GET_PTR(other_handle)));
  R_API_END();
}

350
351
/*!
 * \brief Forward a Booster and a Dataset handle to LGBM_BoosterAddValidData.
 * \param handle Booster handle
 * \param valid_data Handle of the validation Dataset
 */
SEXP LGBM_BoosterAddValidData_R(LGBM_SE handle,
  LGBM_SE valid_data) {
  R_API_BEGIN();
  CHECK_CALL(LGBM_BoosterAddValidData(
    R_GET_PTR(handle),
    R_GET_PTR(valid_data)));
  R_API_END();
}

357
358
/*!
 * \brief Forward a Booster and a Dataset handle to LGBM_BoosterResetTrainingData.
 * \param handle Booster handle
 * \param train_data Handle of the new training Dataset
 */
SEXP LGBM_BoosterResetTrainingData_R(LGBM_SE handle,
  LGBM_SE train_data) {
  R_API_BEGIN();
  CHECK_CALL(LGBM_BoosterResetTrainingData(
    R_GET_PTR(handle),
    R_GET_PTR(train_data)));
  R_API_END();
}

364
365
/*!
 * \brief Forward a parameter string to LGBM_BoosterResetParameter.
 * \param handle Booster handle
 * \param parameters Parameter string forwarded to the C API
 */
SEXP LGBM_BoosterResetParameter_R(LGBM_SE handle,
  LGBM_SE parameters) {
  R_API_BEGIN();
  const char* params = R_CHAR_PTR(parameters);
  CHECK_CALL(LGBM_BoosterResetParameter(R_GET_PTR(handle), params));
  R_API_END();
}

371
372
/*!
 * \brief Write the value reported by LGBM_BoosterGetNumClasses into out.
 * \param handle Booster handle
 * \param out Receives the count as an R integer
 */
SEXP LGBM_BoosterGetNumClasses_R(LGBM_SE handle,
  LGBM_SE out) {
  R_API_BEGIN();
  int class_count = 0;
  CHECK_CALL(LGBM_BoosterGetNumClasses(R_GET_PTR(handle), &class_count));
  R_INT_PTR(out)[0] = class_count;
  R_API_END();
}

380
/*!
 * \brief Call LGBM_BoosterUpdateOneIter for the given Booster.
 * \param handle Booster handle
 */
SEXP LGBM_BoosterUpdateOneIter_R(LGBM_SE handle) {
  R_API_BEGIN();
  // The C API's "finished" flag is received but not passed back to R.
  int finished = 0;
  CHECK_CALL(LGBM_BoosterUpdateOneIter(R_GET_PTR(handle), &finished));
  R_API_END();
}

387
/*!
 * \brief Call LGBM_BoosterUpdateOneIterCustom with caller-supplied gradients and hessians.
 * \param handle Booster handle
 * \param grad Gradient values (read as double)
 * \param hess Hessian values (read as double)
 * \param len Number of values in grad and in hess
 */
SEXP LGBM_BoosterUpdateOneIterCustom_R(LGBM_SE handle,
  LGBM_SE grad,
  LGBM_SE hess,
  SEXP len) {
  R_API_BEGIN();
  // The C API's "finished" flag is received but not passed back to R.
  int finished = 0;
  const int count = Rf_asInteger(len);
  // The R doubles are narrowed to float before being handed to the C API.
  std::vector<float> grad_f32(count);
  std::vector<float> hess_f32(count);
#pragma omp parallel for schedule(static, 512) if (count >= 1024)
  for (int k = 0; k < count; ++k) {
    grad_f32[k] = static_cast<float>(R_REAL_PTR(grad)[k]);
    hess_f32[k] = static_cast<float>(R_REAL_PTR(hess)[k]);
  }
  CHECK_CALL(LGBM_BoosterUpdateOneIterCustom(R_GET_PTR(handle), grad_f32.data(), hess_f32.data(), &finished));
  R_API_END();
}

404
/*!
 * \brief Call LGBM_BoosterRollbackOneIter for the given Booster.
 * \param handle Booster handle
 */
SEXP LGBM_BoosterRollbackOneIter_R(LGBM_SE handle) {
  R_API_BEGIN();
  BoosterHandle booster = R_GET_PTR(handle);
  CHECK_CALL(LGBM_BoosterRollbackOneIter(booster));
  R_API_END();
}

410
411
/*!
 * \brief Write the value reported by LGBM_BoosterGetCurrentIteration into out.
 * \param handle Booster handle
 * \param out Receives the iteration number as an R integer
 */
SEXP LGBM_BoosterGetCurrentIteration_R(LGBM_SE handle,
  LGBM_SE out) {
  R_API_BEGIN();
  int iteration = 0;
  CHECK_CALL(LGBM_BoosterGetCurrentIteration(R_GET_PTR(handle), &iteration));
  R_INT_PTR(out)[0] = iteration;
  R_API_END();
}

419
420
/*!
 * \brief Let LGBM_BoosterGetUpperBoundValue write its result into out_result.
 * \param handle Booster handle
 * \param out_result Double buffer the C API writes into
 */
SEXP LGBM_BoosterGetUpperBoundValue_R(LGBM_SE handle,
  LGBM_SE out_result) {
  R_API_BEGIN();
  double* upper_bound = R_REAL_PTR(out_result);
  CHECK_CALL(LGBM_BoosterGetUpperBoundValue(R_GET_PTR(handle), upper_bound));
  R_API_END();
}

427
428
/*!
 * \brief Let LGBM_BoosterGetLowerBoundValue write its result into out_result.
 * \param handle Booster handle
 * \param out_result Double buffer the C API writes into
 */
SEXP LGBM_BoosterGetLowerBoundValue_R(LGBM_SE handle,
  LGBM_SE out_result) {
  R_API_BEGIN();
  double* lower_bound = R_REAL_PTR(out_result);
  CHECK_CALL(LGBM_BoosterGetLowerBoundValue(R_GET_PTR(handle), lower_bound));
  R_API_END();
}

435
/*!
 * \brief Fetch the evaluation metric names of a Booster as one tab-separated string.
 * \param handle Booster handle
 * \param buf_len Capacity of eval_names
 * \param actual_len Receives the size needed for the joined string (including
 *        the terminating NUL), see EncodeChar
 * \param eval_names Destination buffer for the joined names
 *
 * Names are first requested into 128-byte per-name buffers. If the C API
 * reports that a larger per-name buffer is required, the buffers are grown to
 * that size and the names are requested again instead of aborting.
 */
SEXP LGBM_BoosterGetEvalNames_R(LGBM_SE handle,
  SEXP buf_len,
  LGBM_SE actual_len,
  LGBM_SE eval_names) {
  R_API_BEGIN();
  int len = 0;
  CHECK_CALL(LGBM_BoosterGetEvalCounts(R_GET_PTR(handle), &len));

  const size_t reserved_string_size = 128;
  std::vector<std::vector<char>> names(len);
  std::vector<char*> ptr_names(len);
  for (int i = 0; i < len; ++i) {
    names[i].resize(reserved_string_size);
    ptr_names[i] = names[i].data();
  }

  int out_len = 0;
  size_t required_string_size = 0;
  CHECK_CALL(
    LGBM_BoosterGetEvalNames(
      R_GET_PTR(handle),
      len, &out_len,
      reserved_string_size, &required_string_size,
      ptr_names.data()));
  CHECK_EQ(out_len, len);
  if (required_string_size > reserved_string_size) {
    // At least one name did not fit: enlarge every buffer to the reported
    // size and fetch the names a second time.
    const size_t enlarged_string_size = required_string_size;
    for (int i = 0; i < len; ++i) {
      names[i].resize(enlarged_string_size);
      ptr_names[i] = names[i].data();
    }
    CHECK_CALL(
      LGBM_BoosterGetEvalNames(
        R_GET_PTR(handle),
        len, &out_len,
        enlarged_string_size, &required_string_size,
        ptr_names.data()));
    CHECK_EQ(out_len, len);
    CHECK_GE(enlarged_string_size, required_string_size);
  }
  auto merge_names = Join<char*>(ptr_names, "\t");
  // +1 so the terminating NUL of the joined string is copied as well.
  EncodeChar(eval_names, merge_names.c_str(), buf_len, actual_len, merge_names.size() + 1);
  R_API_END();
}

466
/*!
 * \brief Let LGBM_BoosterGetEval write evaluation results into out_result.
 * \param handle Booster handle
 * \param data_idx Index forwarded to the C API as an integer
 * \param out_result Double buffer the C API writes into
 *
 * The number of results written is checked against LGBM_BoosterGetEvalCounts.
 */
SEXP LGBM_BoosterGetEval_R(LGBM_SE handle,
  SEXP data_idx,
  LGBM_SE out_result) {
  R_API_BEGIN();
  int expected_count = 0;
  CHECK_CALL(LGBM_BoosterGetEvalCounts(R_GET_PTR(handle), &expected_count));
  double* results = R_REAL_PTR(out_result);
  int written_count = 0;
  CHECK_CALL(LGBM_BoosterGetEval(R_GET_PTR(handle), Rf_asInteger(data_idx), &written_count, results));
  CHECK_EQ(written_count, expected_count);
  R_API_END();
}

479
/*!
 * \brief Write the value reported by LGBM_BoosterGetNumPredict into out.
 * \param handle Booster handle
 * \param data_idx Index forwarded to the C API as an integer
 * \param out Receives the count, narrowed from int64 to an R integer
 */
SEXP LGBM_BoosterGetNumPredict_R(LGBM_SE handle,
  SEXP data_idx,
  LGBM_SE out) {
  R_API_BEGIN();
  int64_t num_predict;
  CHECK_CALL(LGBM_BoosterGetNumPredict(R_GET_PTR(handle), Rf_asInteger(data_idx), &num_predict));
  R_INT_PTR(out)[0] = static_cast<int>(num_predict);
  R_API_END();
}

489
/*!
 * \brief Let LGBM_BoosterGetPredict write predictions into out_result.
 * \param handle Booster handle
 * \param data_idx Index forwarded to the C API as an integer
 * \param out_result Double buffer the C API writes into
 */
SEXP LGBM_BoosterGetPredict_R(LGBM_SE handle,
  SEXP data_idx,
  LGBM_SE out_result) {
  R_API_BEGIN();
  double* predictions = R_REAL_PTR(out_result);
  // Length reported by the C API; it is not passed back to R.
  int64_t written;
  CHECK_CALL(LGBM_BoosterGetPredict(R_GET_PTR(handle), Rf_asInteger(data_idx), &written, predictions));
  R_API_END();
}

499
/*!
 * \brief Translate three R flags into a C_API_PREDICT_* constant.
 * \param is_rawscore Non-zero selects C_API_PREDICT_RAW_SCORE
 * \param is_leafidx Non-zero selects C_API_PREDICT_LEAF_INDEX
 * \param is_predcontrib Non-zero selects C_API_PREDICT_CONTRIB
 * \return The selected constant, or C_API_PREDICT_NORMAL when no flag is set
 *
 * When several flags are set, is_predcontrib wins over is_leafidx, which
 * wins over is_rawscore.
 */
int GetPredictType(SEXP is_rawscore, SEXP is_leafidx, SEXP is_predcontrib) {
  // All three flags are always read, in argument order.
  const bool want_raw_score = Rf_asInteger(is_rawscore) != 0;
  const bool want_leaf_index = Rf_asInteger(is_leafidx) != 0;
  const bool want_contrib = Rf_asInteger(is_predcontrib) != 0;
  if (want_contrib) {
    return C_API_PREDICT_CONTRIB;
  }
  if (want_leaf_index) {
    return C_API_PREDICT_LEAF_INDEX;
  }
  if (want_raw_score) {
    return C_API_PREDICT_RAW_SCORE;
  }
  return C_API_PREDICT_NORMAL;
}

513
/*!
 * \brief Run LGBM_BoosterPredictForFile on a data file and write results to another file.
 * \param handle Booster handle
 * \param data_filename Path of the input data file
 * \param data_has_header Forwarded to the C API as an integer
 * \param is_rawscore Prediction-type flag, see GetPredictType
 * \param is_leafidx Prediction-type flag, see GetPredictType
 * \param is_predcontrib Prediction-type flag, see GetPredictType
 * \param start_iteration Forwarded to the C API as an integer
 * \param num_iteration Forwarded to the C API as an integer
 * \param parameter Parameter string forwarded to the C API
 * \param result_filename Path of the output file
 */
SEXP LGBM_BoosterPredictForFile_R(LGBM_SE handle,
  LGBM_SE data_filename,
  SEXP data_has_header,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  LGBM_SE parameter,
  LGBM_SE result_filename) {
  R_API_BEGIN();
  const int predict_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  const int has_header = Rf_asInteger(data_has_header);
  const int first_iter = Rf_asInteger(start_iteration);
  const int iter_count = Rf_asInteger(num_iteration);
  CHECK_CALL(LGBM_BoosterPredictForFile(
    R_GET_PTR(handle), R_CHAR_PTR(data_filename),
    has_header, predict_type, first_iter, iter_count,
    R_CHAR_PTR(parameter), R_CHAR_PTR(result_filename)));
  R_API_END();
}

531
/*!
 * \brief Write the value reported by LGBM_BoosterCalcNumPredict into out_len.
 * \param handle Booster handle
 * \param num_row Number of rows, forwarded as an integer
 * \param is_rawscore Prediction-type flag, see GetPredictType
 * \param is_leafidx Prediction-type flag, see GetPredictType
 * \param is_predcontrib Prediction-type flag, see GetPredictType
 * \param start_iteration Forwarded to the C API as an integer
 * \param num_iteration Forwarded to the C API as an integer
 * \param out_len Receives the result, narrowed from int64 to an R integer
 */
SEXP LGBM_BoosterCalcNumPredict_R(LGBM_SE handle,
  SEXP num_row,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  LGBM_SE out_len) {
  R_API_BEGIN();
  const int predict_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  const int n_rows = Rf_asInteger(num_row);
  const int first_iter = Rf_asInteger(start_iteration);
  const int iter_count = Rf_asInteger(num_iteration);
  int64_t num_predict = 0;
  CHECK_CALL(LGBM_BoosterCalcNumPredict(
    R_GET_PTR(handle), n_rows,
    predict_type, first_iter, iter_count, &num_predict));
  R_INT_PTR(out_len)[0] = static_cast<int>(num_predict);
  R_API_END();
}

548
/*!
 * \brief Predict for CSC sparse-matrix input via LGBM_BoosterPredictForCSC.
 * \param handle Booster handle
 * \param indptr Column pointer array (read as int32)
 * \param indices Row index array (read as int32)
 * \param data Non-zero values (read as double)
 * \param num_indptr Length of indptr
 * \param nelem Number of entries in data
 * \param num_row Number of rows
 * \param is_rawscore Prediction-type flag, see GetPredictType
 * \param is_leafidx Prediction-type flag, see GetPredictType
 * \param is_predcontrib Prediction-type flag, see GetPredictType
 * \param start_iteration Forwarded to the C API as an integer
 * \param num_iteration Forwarded to the C API as an integer
 * \param parameter Parameter string forwarded to the C API
 * \param out_result Double buffer the C API writes predictions into
 */
SEXP LGBM_BoosterPredictForCSC_R(LGBM_SE handle,
  LGBM_SE indptr,
  LGBM_SE indices,
  LGBM_SE data,
  SEXP num_indptr,
  SEXP nelem,
  SEXP num_row,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  LGBM_SE parameter,
  LGBM_SE out_result) {
  R_API_BEGIN();
  const int predict_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);

  const int* col_ptr = R_INT_PTR(indptr);
  const int* row_idx = R_INT_PTR(indices);
  const double* values = R_REAL_PTR(data);

  // R integers are 32-bit; widen them to the 64-bit sizes the C API expects.
  const int64_t n_col_ptr = Rf_asInteger(num_indptr);
  const int64_t n_values = Rf_asInteger(nelem);
  const int64_t n_rows = Rf_asInteger(num_row);

  double* predictions = R_REAL_PTR(out_result);
  const int first_iter = Rf_asInteger(start_iteration);
  const int iter_count = Rf_asInteger(num_iteration);
  // Length reported by the C API; it is not passed back to R.
  int64_t written;
  CHECK_CALL(LGBM_BoosterPredictForCSC(
    R_GET_PTR(handle),
    col_ptr, C_API_DTYPE_INT32, row_idx,
    values, C_API_DTYPE_FLOAT64, n_col_ptr, n_values,
    n_rows, predict_type, first_iter, iter_count,
    R_CHAR_PTR(parameter), &written, predictions));
  R_API_END();
}

581
/*!
 * \brief Predict for a dense double matrix via LGBM_BoosterPredictForMat.
 * \param handle Booster handle
 * \param data Matrix values (read as double, passed with the COL_MAJOR layout flag)
 * \param num_row Number of rows
 * \param num_col Number of columns
 * \param is_rawscore Prediction-type flag, see GetPredictType
 * \param is_leafidx Prediction-type flag, see GetPredictType
 * \param is_predcontrib Prediction-type flag, see GetPredictType
 * \param start_iteration Forwarded to the C API as an integer
 * \param num_iteration Forwarded to the C API as an integer
 * \param parameter Parameter string forwarded to the C API
 * \param out_result Double buffer the C API writes predictions into
 */
SEXP LGBM_BoosterPredictForMat_R(LGBM_SE handle,
  LGBM_SE data,
  SEXP num_row,
  SEXP num_col,
  SEXP is_rawscore,
  SEXP is_leafidx,
  SEXP is_predcontrib,
  SEXP start_iteration,
  SEXP num_iteration,
  LGBM_SE parameter,
  LGBM_SE out_result) {
  R_API_BEGIN();
  const int predict_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);

  const int32_t n_rows = Rf_asInteger(num_row);
  const int32_t n_cols = Rf_asInteger(num_col);

  const double* matrix = R_REAL_PTR(data);
  double* predictions = R_REAL_PTR(out_result);
  const int first_iter = Rf_asInteger(start_iteration);
  const int iter_count = Rf_asInteger(num_iteration);
  // Length reported by the C API; it is not passed back to R.
  int64_t written;
  CHECK_CALL(LGBM_BoosterPredictForMat(
    R_GET_PTR(handle),
    matrix, C_API_DTYPE_FLOAT64, n_rows, n_cols, COL_MAJOR,
    predict_type, first_iter, iter_count,
    R_CHAR_PTR(parameter), &written, predictions));

  R_API_END();
}

608
/*!
 * \brief Save a Booster's model to a file via LGBM_BoosterSaveModel.
 * \param handle Booster handle
 * \param num_iteration Forwarded to the C API as an integer
 * \param feature_importance_type Forwarded to the C API as an integer
 * \param filename Output file path
 */
SEXP LGBM_BoosterSaveModel_R(LGBM_SE handle,
  SEXP num_iteration,
  SEXP feature_importance_type,
  LGBM_SE filename) {
  R_API_BEGIN();
  const int iter_count = Rf_asInteger(num_iteration);
  const int importance_type = Rf_asInteger(feature_importance_type);
  // The second C API argument is fixed at 0 by this wrapper.
  CHECK_CALL(LGBM_BoosterSaveModel(
    R_GET_PTR(handle), 0, iter_count, importance_type, R_CHAR_PTR(filename)));
  R_API_END();
}

617
/*!
 * \brief Serialise a Booster's model into an R character buffer via LGBM_BoosterSaveModelToString.
 * \param handle Booster handle
 * \param num_iteration Forwarded to the C API as an integer
 * \param feature_importance_type Forwarded to the C API as an integer
 * \param buffer_len Capacity of out_str
 * \param actual_len Receives the length reported by the C API, see EncodeChar
 * \param out_str Destination buffer
 *
 * The model is written to a temporary buffer of buffer_len bytes and then
 * handed to EncodeChar, which copies it only when it fits in out_str.
 */
SEXP LGBM_BoosterSaveModelToString_R(LGBM_SE handle,
  SEXP num_iteration,
  SEXP feature_importance_type,
  SEXP buffer_len,
  LGBM_SE actual_len,
  LGBM_SE out_str) {
  R_API_BEGIN();
  int64_t model_len = 0;
  const int64_t capacity = static_cast<int64_t>(Rf_asInteger(buffer_len));
  std::vector<char> scratch(capacity);
  const int iter_count = Rf_asInteger(num_iteration);
  const int importance_type = Rf_asInteger(feature_importance_type);
  // The second C API argument is fixed at 0 by this wrapper.
  CHECK_CALL(LGBM_BoosterSaveModelToString(
    R_GET_PTR(handle), 0, iter_count, importance_type,
    capacity, &model_len, scratch.data()));
  EncodeChar(out_str, scratch.data(), buffer_len, actual_len, static_cast<size_t>(model_len));
  R_API_END();
}

632
/*!
 * \brief Dump a Booster's model into an R character buffer via LGBM_BoosterDumpModel.
 * \param handle Booster handle
 * \param num_iteration Forwarded to the C API as an integer
 * \param feature_importance_type Forwarded to the C API as an integer
 * \param buffer_len Capacity of out_str
 * \param actual_len Receives the length reported by the C API, see EncodeChar
 * \param out_str Destination buffer
 *
 * The dump is written to a temporary buffer of buffer_len bytes and then
 * handed to EncodeChar, which copies it only when it fits in out_str.
 */
SEXP LGBM_BoosterDumpModel_R(LGBM_SE handle,
  SEXP num_iteration,
  SEXP feature_importance_type,
  SEXP buffer_len,
  LGBM_SE actual_len,
  LGBM_SE out_str) {
  R_API_BEGIN();
  int64_t dump_len = 0;
  const int64_t capacity = static_cast<int64_t>(Rf_asInteger(buffer_len));
  std::vector<char> scratch(capacity);
  const int iter_count = Rf_asInteger(num_iteration);
  const int importance_type = Rf_asInteger(feature_importance_type);
  // The second C API argument is fixed at 0 by this wrapper.
  CHECK_CALL(LGBM_BoosterDumpModel(
    R_GET_PTR(handle), 0, iter_count, importance_type,
    capacity, &dump_len, scratch.data()));
  EncodeChar(out_str, scratch.data(), buffer_len, actual_len, static_cast<size_t>(dump_len));
  R_API_END();
}
646
647
648

// Registration table for the .Call() entry points defined in this file.
// Each row is {name, function pointer, number of arguments}; the table ends
// with an all-null sentinel row.
static const R_CallMethodDef CallEntries[] = {
  {"LGBM_GetLastError_R", reinterpret_cast<DL_FUNC>(&LGBM_GetLastError_R), 0},
  {"LGBM_DatasetCreateFromFile_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetCreateFromFile_R), 4},
  {"LGBM_DatasetCreateFromCSC_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetCreateFromCSC_R), 9},
  {"LGBM_DatasetCreateFromMat_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetCreateFromMat_R), 6},
  {"LGBM_DatasetGetSubset_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetSubset_R), 5},
  {"LGBM_DatasetSetFeatureNames_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetSetFeatureNames_R), 2},
  {"LGBM_DatasetGetFeatureNames_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetFeatureNames_R), 4},
  {"LGBM_DatasetSaveBinary_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetSaveBinary_R), 2},
  {"LGBM_DatasetFree_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetFree_R), 1},
  {"LGBM_DatasetSetField_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetSetField_R), 4},
  {"LGBM_DatasetGetFieldSize_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetFieldSize_R), 3},
  {"LGBM_DatasetGetField_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetField_R), 3},
  {"LGBM_DatasetUpdateParamChecking_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetUpdateParamChecking_R), 2},
  {"LGBM_DatasetGetNumData_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetNumData_R), 2},
  {"LGBM_DatasetGetNumFeature_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetNumFeature_R), 2},
  {"LGBM_BoosterCreate_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterCreate_R), 3},
  {"LGBM_BoosterFree_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterFree_R), 1},
  {"LGBM_BoosterCreateFromModelfile_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterCreateFromModelfile_R), 2},
  {"LGBM_BoosterLoadModelFromString_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterLoadModelFromString_R), 2},
  {"LGBM_BoosterMerge_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterMerge_R), 2},
  {"LGBM_BoosterAddValidData_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterAddValidData_R), 2},
  {"LGBM_BoosterResetTrainingData_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterResetTrainingData_R), 2},
  {"LGBM_BoosterResetParameter_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterResetParameter_R), 2},
  {"LGBM_BoosterGetNumClasses_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetNumClasses_R), 2},
  {"LGBM_BoosterUpdateOneIter_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterUpdateOneIter_R), 1},
  {"LGBM_BoosterUpdateOneIterCustom_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterUpdateOneIterCustom_R), 4},
  {"LGBM_BoosterRollbackOneIter_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterRollbackOneIter_R), 1},
  {"LGBM_BoosterGetCurrentIteration_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetCurrentIteration_R), 2},
  {"LGBM_BoosterGetUpperBoundValue_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetUpperBoundValue_R), 2},
  {"LGBM_BoosterGetLowerBoundValue_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetLowerBoundValue_R), 2},
  {"LGBM_BoosterGetEvalNames_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetEvalNames_R), 4},
  {"LGBM_BoosterGetEval_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetEval_R), 3},
  {"LGBM_BoosterGetNumPredict_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetNumPredict_R), 3},
  {"LGBM_BoosterGetPredict_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetPredict_R), 3},
  {"LGBM_BoosterPredictForFile_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterPredictForFile_R), 10},
  {"LGBM_BoosterCalcNumPredict_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterCalcNumPredict_R), 8},
  {"LGBM_BoosterPredictForCSC_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterPredictForCSC_R), 14},
  {"LGBM_BoosterPredictForMat_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterPredictForMat_R), 11},
  {"LGBM_BoosterSaveModel_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterSaveModel_R), 4},
  {"LGBM_BoosterSaveModelToString_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterSaveModelToString_R), 6},
  {"LGBM_BoosterDumpModel_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterDumpModel_R), 6},
  {nullptr, nullptr, 0}
};

693
694
LIGHTGBM_C_EXPORT void R_init_lightgbm(DllInfo *dll);

695
696
697
698
void R_init_lightgbm(DllInfo *dll) {
  R_registerRoutines(dll, NULL, CallEntries, NULL, NULL);
  R_useDynamicSymbols(dll, FALSE);
}