lightgbm_R.cpp 22.8 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2017 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
5
6

#include "lightgbm_R.h"
Guolin Ke's avatar
Guolin Ke committed
7

8
9
10
11
12
13
14
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>
#include <LightGBM/utils/text_reader.h>

#include <R_ext/Rdynload.h>

15
16
17
18
#define R_NO_REMAP
#define R_USE_C99_IN_CXX
#include <R_ext/Error.h>

19
20
21
22
23
24
25
#include <string>
#include <cstdio>
#include <cstring>
#include <memory>
#include <utility>
#include <vector>

Guolin Ke's avatar
Guolin Ke committed
26
27
28
29
30
/*! \brief Storage-order flag forwarded to the LGBM_*Mat C API calls in this file. */
#define COL_MAJOR (0)

/*!
 * \brief Opens the try block that wraps the body of every exported wrapper.
 *        Must be paired with R_API_END() in the same function.
 */
#define R_API_BEGIN() \
  try {

/*!
 * \brief Closes the try block opened by R_API_BEGIN().
 *        Any C++ exception is recorded through LGBM_SetLastError() and the
 *        enclosing function returns R_NilValue; the normal path also returns
 *        R_NilValue.
 */
#define R_API_END() } \
  catch (const std::exception& ex) { LGBM_SetLastError(ex.what()); return R_NilValue; } \
  catch (const std::string& ex) { LGBM_SetLastError(ex.c_str()); return R_NilValue; } \
  catch (...) { LGBM_SetLastError("unknown exception"); return R_NilValue; } \
  return R_NilValue;
Guolin Ke's avatar
Guolin Ke committed
35
36
37

/*!
 * \brief Evaluates a LightGBM C API call and raises an R error when it
 *        returns a non-zero status.
 *
 * The message from LGBM_GetLastError() is passed as an argument to a fixed
 * "%s" format so that '%' characters inside the message (for example from a
 * file name or parameter string) are never interpreted as conversion
 * specifiers. The body is wrapped in do { } while (0) so the macro behaves
 * as a single statement inside if/else chains.
 *
 * NOTE(review): Rf_error() is documented not to return, so the trailing
 * return only matters if that ever changes; it also means locals of the
 * calling function are presumably not destroyed on this path - confirm.
 */
#define CHECK_CALL(x) \
  do { \
    if ((x) != 0) { \
      Rf_error("%s", LGBM_GetLastError()); \
      return R_NilValue; \
    } \
  } while (0)

42
43
44
using LightGBM::Common::Join;
using LightGBM::Common::Split;
using LightGBM::Log;
Guolin Ke's avatar
Guolin Ke committed
45

Guolin Ke's avatar
Guolin Ke committed
46
/*!
 * \brief Copies a character buffer into an R-side destination when it fits.
 * \param dest Destination object whose character storage receives the bytes
 * \param src Source bytes
 * \param buf_len Capacity of dest, read with R_AS_INT
 * \param actual_len Receives str_len in its first integer slot (always written)
 * \param str_len Number of bytes to copy; callers in this file include the terminator
 * \return dest, whether or not anything was copied
 *
 * Lengths above INT32_MAX are rejected through Log::Fatal. When the
 * destination is too small nothing is copied, but actual_len still reports
 * the size that was needed.
 */
LGBM_SE EncodeChar(LGBM_SE dest, const char* src, LGBM_SE buf_len, LGBM_SE actual_len, size_t str_len) {
  if (str_len > INT32_MAX) {
    Log::Fatal("Don't support large string in R-package");
  }
  const int needed = static_cast<int>(str_len);
  R_INT_PTR(actual_len)[0] = needed;
  if (needed <= R_AS_INT(buf_len)) {
    std::memcpy(R_CHAR_PTR(dest), src, str_len);
  }
  return dest;
}

Guolin Ke's avatar
Guolin Ke committed
59
/*!
 * \brief Copies the message returned by LGBM_GetLastError() into err_msg.
 * \param buf_len Capacity of err_msg
 * \param actual_len Receives the message length including its terminator
 * \param err_msg Destination character buffer
 * \return The value returned by EncodeChar (err_msg)
 */
LGBM_SE LGBM_GetLastError_R(LGBM_SE buf_len, LGBM_SE actual_len, LGBM_SE err_msg) {
  const char* message = LGBM_GetLastError();
  // +1 so the terminating NUL travels with the text
  const size_t message_bytes = std::strlen(message) + 1;
  return EncodeChar(err_msg, message, buf_len, actual_len, message_bytes);
}

63
/*!
 * \brief Builds a Dataset through LGBM_DatasetCreateFromFile.
 * \param filename Character data holding the file path
 * \param parameters Character data holding the parameter string
 * \param reference Handle of a reference Dataset, read with R_GET_PTR
 * \param out Receives the new Dataset handle via R_SET_PTR
 * \return R_NilValue
 */
SEXP LGBM_DatasetCreateFromFile_R(LGBM_SE filename,
                                  LGBM_SE parameters,
                                  LGBM_SE reference,
                                  LGBM_SE out) {
  R_API_BEGIN();
  const char* path = R_CHAR_PTR(filename);
  const char* params = R_CHAR_PTR(parameters);
  DatasetHandle dataset = nullptr;
  CHECK_CALL(LGBM_DatasetCreateFromFile(path, params, R_GET_PTR(reference), &dataset));
  R_SET_PTR(out, dataset);
  R_API_END();
}

75
/*!
 * \brief Builds a Dataset through LGBM_DatasetCreateFromCSC.
 * \param indptr Integer pointer array (passed as C_API_DTYPE_INT32)
 * \param indices Integer index array
 * \param data Double value array (passed as C_API_DTYPE_FLOAT64)
 * \param num_indptr Length of indptr
 * \param nelem Number of stored elements
 * \param num_row Number of rows
 * \param parameters Character data holding the parameter string
 * \param reference Handle of a reference Dataset, read with R_GET_PTR
 * \param out Receives the new Dataset handle via R_SET_PTR
 * \return R_NilValue
 */
SEXP LGBM_DatasetCreateFromCSC_R(LGBM_SE indptr,
                                 LGBM_SE indices,
                                 LGBM_SE data,
                                 LGBM_SE num_indptr,
                                 LGBM_SE nelem,
                                 LGBM_SE num_row,
                                 LGBM_SE parameters,
                                 LGBM_SE reference,
                                 LGBM_SE out) {
  R_API_BEGIN();
  // The C API takes 64-bit counts; the R side supplies plain ints.
  const int64_t indptr_count = static_cast<int64_t>(R_AS_INT(num_indptr));
  const int64_t value_count = static_cast<int64_t>(R_AS_INT(nelem));
  const int64_t row_count = static_cast<int64_t>(R_AS_INT(num_row));

  const int* indptr_values = R_INT_PTR(indptr);
  const int* index_values = R_INT_PTR(indices);
  const double* element_values = R_REAL_PTR(data);

  DatasetHandle dataset = nullptr;
  CHECK_CALL(LGBM_DatasetCreateFromCSC(indptr_values, C_API_DTYPE_INT32, index_values,
                                       element_values, C_API_DTYPE_FLOAT64,
                                       indptr_count, value_count, row_count,
                                       R_CHAR_PTR(parameters), R_GET_PTR(reference), &dataset));
  R_SET_PTR(out, dataset);
  R_API_END();
}

100
/*!
 * \brief Builds a Dataset through LGBM_DatasetCreateFromMat.
 * \param data Double matrix values (passed as C_API_DTYPE_FLOAT64 with COL_MAJOR)
 * \param num_row Number of rows
 * \param num_col Number of columns
 * \param parameters Character data holding the parameter string
 * \param reference Handle of a reference Dataset, read with R_GET_PTR
 * \param out Receives the new Dataset handle via R_SET_PTR
 * \return R_NilValue
 */
SEXP LGBM_DatasetCreateFromMat_R(LGBM_SE data,
                                 LGBM_SE num_row,
                                 LGBM_SE num_col,
                                 LGBM_SE parameters,
                                 LGBM_SE reference,
                                 LGBM_SE out) {
  R_API_BEGIN();
  const int32_t row_count = static_cast<int32_t>(R_AS_INT(num_row));
  const int32_t col_count = static_cast<int32_t>(R_AS_INT(num_col));
  const double* matrix = R_REAL_PTR(data);

  DatasetHandle dataset = nullptr;
  CHECK_CALL(LGBM_DatasetCreateFromMat(matrix, C_API_DTYPE_FLOAT64, row_count, col_count, COL_MAJOR,
                                       R_CHAR_PTR(parameters), R_GET_PTR(reference), &dataset));
  R_SET_PTR(out, dataset);
  R_API_END();
}

117
/*!
 * \brief Creates a row subset of a Dataset through LGBM_DatasetGetSubset.
 * \param handle Source Dataset handle
 * \param used_row_indices One-based row indices to keep
 * \param len_used_row_indices Number of entries in used_row_indices
 * \param parameters Character data holding the parameter string
 * \param out Receives the subset Dataset handle via R_SET_PTR
 * \return R_NilValue
 */
SEXP LGBM_DatasetGetSubset_R(LGBM_SE handle,
                             LGBM_SE used_row_indices,
                             LGBM_SE len_used_row_indices,
                             LGBM_SE parameters,
                             LGBM_SE out) {
  R_API_BEGIN();
  const int count = R_AS_INT(len_used_row_indices);
  const int* one_based = R_INT_PTR(used_row_indices);
  std::vector<int> zero_based(count);
  // shift every index down by one: the input is one-based, the C API is zero-based
#pragma omp parallel for schedule(static, 512) if (count >= 1024)
  for (int k = 0; k < count; ++k) {
    zero_based[k] = one_based[k] - 1;
  }

  DatasetHandle subset = nullptr;
  CHECK_CALL(LGBM_DatasetGetSubset(R_GET_PTR(handle), zero_based.data(), count,
                                   R_CHAR_PTR(parameters), &subset));
  R_SET_PTR(out, subset);
  R_API_END();
}

138
139
/*!
 * \brief Assigns feature names to a Dataset through LGBM_DatasetSetFeatureNames.
 * \param handle Dataset handle
 * \param feature_names Character data with the names separated by tab characters
 * \return R_NilValue
 */
SEXP LGBM_DatasetSetFeatureNames_R(LGBM_SE handle,
                                   LGBM_SE feature_names) {
  R_API_BEGIN();
  // the split strings own the storage; the pointer array below only borrows it
  auto name_list = Split(R_CHAR_PTR(feature_names), '\t');
  std::vector<const char*> name_ptrs;
  name_ptrs.reserve(name_list.size());
  for (const auto& name : name_list) {
    name_ptrs.push_back(name.c_str());
  }
  const int name_count = static_cast<int>(name_list.size());
  CHECK_CALL(LGBM_DatasetSetFeatureNames(R_GET_PTR(handle), name_ptrs.data(), name_count));
  R_API_END();
}

152
/*!
 * \brief Fetches a Dataset's feature names as one tab-separated string.
 * \param handle Dataset handle
 * \param buf_len Capacity of feature_names
 * \param actual_len Receives the length of the joined string including its terminator
 * \param feature_names Destination character buffer (filled only when large enough)
 * \return R_NilValue
 *
 * Each name is first read into a 256-byte scratch buffer. If the C API
 * reports that a larger per-name buffer is required, the scratch buffers are
 * grown to that size and the names are fetched a second time, so long names
 * no longer trip the CHECK_GE below.
 */
SEXP LGBM_DatasetGetFeatureNames_R(LGBM_SE handle,
  LGBM_SE buf_len,
  LGBM_SE actual_len,
  LGBM_SE feature_names) {
  R_API_BEGIN();
  int len = 0;
  CHECK_CALL(LGBM_DatasetGetNumFeature(R_GET_PTR(handle), &len));
  size_t reserved_string_size = 256;
  std::vector<std::vector<char>> names(len);
  std::vector<char*> ptr_names(len);
  for (int i = 0; i < len; ++i) {
    names[i].resize(reserved_string_size);
    ptr_names[i] = names[i].data();
  }
  int out_len = 0;
  size_t required_string_size = 0;
  CHECK_CALL(
    LGBM_DatasetGetFeatureNames(
      R_GET_PTR(handle),
      len, &out_len,
      reserved_string_size, &required_string_size,
      ptr_names.data()));
  if (required_string_size > reserved_string_size) {
    // at least one name did not fit: grow every buffer and read the names again
    reserved_string_size = required_string_size;
    for (int i = 0; i < len; ++i) {
      names[i].resize(reserved_string_size);
      // resize may reallocate, so the borrowed pointers must be refreshed
      ptr_names[i] = names[i].data();
    }
    CHECK_CALL(
      LGBM_DatasetGetFeatureNames(
        R_GET_PTR(handle),
        len, &out_len,
        reserved_string_size, &required_string_size,
        ptr_names.data()));
  }
  CHECK_EQ(len, out_len);
  CHECK_GE(reserved_string_size, required_string_size);
  auto merge_str = Join<char*>(ptr_names, "\t");
  // +1 so the terminating NUL is copied as well
  EncodeChar(feature_names, merge_str.c_str(), buf_len, actual_len, merge_str.size() + 1);
  R_API_END();
}

181
182
/*!
 * \brief Writes a Dataset to disk through LGBM_DatasetSaveBinary.
 * \param handle Dataset handle
 * \param filename Character data holding the output path
 * \return R_NilValue
 */
SEXP LGBM_DatasetSaveBinary_R(LGBM_SE handle,
                              LGBM_SE filename) {
  R_API_BEGIN();
  const char* path = R_CHAR_PTR(filename);
  CHECK_CALL(LGBM_DatasetSaveBinary(R_GET_PTR(handle), path));
  R_API_END();
}

189
/*!
 * \brief Releases a Dataset through LGBM_DatasetFree and clears the stored pointer.
 * \param handle Dataset handle; a null pointer is ignored
 * \return R_NilValue
 */
SEXP LGBM_DatasetFree_R(LGBM_SE handle) {
  R_API_BEGIN();
  const bool has_dataset = (R_GET_PTR(handle) != nullptr);
  if (has_dataset) {
    CHECK_CALL(LGBM_DatasetFree(R_GET_PTR(handle)));
    // reset the stored pointer so a second call becomes a no-op
    R_SET_PTR(handle, nullptr);
  }
  R_API_END();
}

198
/*!
 * \brief Stores a named field on a Dataset through LGBM_DatasetSetField.
 * \param handle Dataset handle
 * \param field_name Character data naming the field
 * \param field_data Field values: integers for "group"/"query", doubles otherwise
 * \param num_element Number of values in field_data
 * \return R_NilValue
 *
 * "group" and "query" are sent as int32, "init_score" is sent as float64
 * without copying, and every other field is narrowed to float32 first.
 */
SEXP LGBM_DatasetSetField_R(LGBM_SE handle,
                            LGBM_SE field_name,
                            LGBM_SE field_data,
                            LGBM_SE num_element) {
  R_API_BEGIN();
  const int count = static_cast<int>(R_AS_INT(num_element));
  const char* name = R_CHAR_PTR(field_name);
  const bool is_group_field = (std::strcmp("group", name) == 0) || (std::strcmp("query", name) == 0);
  const bool is_init_score = (std::strcmp("init_score", name) == 0);

  if (is_group_field) {
    const int* source = R_INT_PTR(field_data);
    std::vector<int32_t> as_int32(count);
#pragma omp parallel for schedule(static, 512) if (count >= 1024)
    for (int k = 0; k < count; ++k) {
      as_int32[k] = static_cast<int32_t>(source[k]);
    }
    CHECK_CALL(LGBM_DatasetSetField(R_GET_PTR(handle), name, as_int32.data(), count, C_API_DTYPE_INT32));
  } else if (is_init_score) {
    CHECK_CALL(LGBM_DatasetSetField(R_GET_PTR(handle), name, R_REAL_PTR(field_data), count, C_API_DTYPE_FLOAT64));
  } else {
    const double* source = R_REAL_PTR(field_data);
    std::vector<float> as_float(count);
#pragma omp parallel for schedule(static, 512) if (count >= 1024)
    for (int k = 0; k < count; ++k) {
      as_float[k] = static_cast<float>(source[k]);
    }
    CHECK_CALL(LGBM_DatasetSetField(R_GET_PTR(handle), name, as_float.data(), count, C_API_DTYPE_FLOAT32));
  }
  R_API_END();
}

225
/*!
 * \brief Reads a named field from a Dataset into an R-side buffer.
 * \param handle Dataset handle
 * \param field_name Character data naming the field
 * \param field_data Destination: integer storage for "group"/"query", double storage otherwise
 * \return R_NilValue
 *
 * For "group"/"query" the values from the C API are treated as boundaries
 * and the differences between neighbours are written (one fewer value than
 * returned). "init_score" is read as double, everything else as float.
 */
SEXP LGBM_DatasetGetField_R(LGBM_SE handle,
                            LGBM_SE field_name,
                            LGBM_SE field_data) {
  R_API_BEGIN();
  const char* name = R_CHAR_PTR(field_name);
  int count = 0;
  int dtype = 0;
  const void* raw;
  CHECK_CALL(LGBM_DatasetGetField(R_GET_PTR(handle), name, &count, &raw, &dtype));

  const bool is_group_field = (std::strcmp("group", name) == 0) || (std::strcmp("query", name) == 0);
  if (is_group_field) {
    const int32_t* boundaries = reinterpret_cast<const int32_t*>(raw);
    int* sizes = R_INT_PTR(field_data);
    // turn consecutive boundaries into per-group sizes
#pragma omp parallel for schedule(static, 512) if (count >= 1024)
    for (int k = 0; k < count - 1; ++k) {
      sizes[k] = boundaries[k + 1] - boundaries[k];
    }
  } else if (std::strcmp("init_score", name) == 0) {
    const double* values = reinterpret_cast<const double*>(raw);
    double* target = R_REAL_PTR(field_data);
#pragma omp parallel for schedule(static, 512) if (count >= 1024)
    for (int k = 0; k < count; ++k) {
      target[k] = values[k];
    }
  } else {
    const float* values = reinterpret_cast<const float*>(raw);
    double* target = R_REAL_PTR(field_data);
#pragma omp parallel for schedule(static, 512) if (count >= 1024)
    for (int k = 0; k < count; ++k) {
      target[k] = values[k];
    }
  }
  R_API_END();
}

258
/*!
 * \brief Reports how many values LGBM_DatasetGetField_R would write for a field.
 * \param handle Dataset handle
 * \param field_name Character data naming the field
 * \param out Receives the element count in its first integer slot
 * \return R_NilValue
 *
 * "group"/"query" are stored as boundaries, so the reported size is one less
 * than the stored length. An empty field reports 0 rather than -1.
 */
SEXP LGBM_DatasetGetFieldSize_R(LGBM_SE handle,
  LGBM_SE field_name,
  LGBM_SE out) {
  R_API_BEGIN();
  const char* name = R_CHAR_PTR(field_name);
  int out_len = 0;
  int out_type = 0;
  const void* res = nullptr;
  CHECK_CALL(LGBM_DatasetGetField(R_GET_PTR(handle), name, &out_len, &res, &out_type));
  const bool is_group_field = !std::strcmp("group", name) || !std::strcmp("query", name);
  // only subtract when there is at least one boundary, so the size never goes negative
  if (is_group_field && out_len > 0) {
    out_len -= 1;
  }
  R_INT_PTR(out)[0] = static_cast<int>(out_len);
  R_API_END();
}

274
275
/*!
 * \brief Forwards two parameter strings to LGBM_DatasetUpdateParamChecking.
 * \param old_params Character data holding the existing parameters
 * \param new_params Character data holding the proposed parameters
 * \return R_NilValue
 */
SEXP LGBM_DatasetUpdateParamChecking_R(LGBM_SE old_params,
                                       LGBM_SE new_params) {
  R_API_BEGIN();
  const char* current = R_CHAR_PTR(old_params);
  const char* proposed = R_CHAR_PTR(new_params);
  CHECK_CALL(LGBM_DatasetUpdateParamChecking(current, proposed));
  R_API_END();
}

281
/*!
 * \brief Writes the value reported by LGBM_DatasetGetNumData into out.
 * \param handle Dataset handle
 * \param out Receives the count in its first integer slot
 * \return R_NilValue
 */
SEXP LGBM_DatasetGetNumData_R(LGBM_SE handle, LGBM_SE out) {
  R_API_BEGIN();
  int row_count;
  CHECK_CALL(LGBM_DatasetGetNumData(R_GET_PTR(handle), &row_count));
  int* result = R_INT_PTR(out);
  result[0] = static_cast<int>(row_count);
  R_API_END();
}

289
290
/*!
 * \brief Writes the value reported by LGBM_DatasetGetNumFeature into out.
 * \param handle Dataset handle
 * \param out Receives the count in its first integer slot
 * \return R_NilValue
 */
SEXP LGBM_DatasetGetNumFeature_R(LGBM_SE handle,
                                 LGBM_SE out) {
  R_API_BEGIN();
  int feature_count;
  CHECK_CALL(LGBM_DatasetGetNumFeature(R_GET_PTR(handle), &feature_count));
  int* result = R_INT_PTR(out);
  result[0] = static_cast<int>(feature_count);
  R_API_END();
}

// --- start Booster interfaces

300
/*!
 * \brief Releases a Booster through LGBM_BoosterFree and clears the stored pointer.
 * \param handle Booster handle; a null pointer is ignored
 * \return R_NilValue
 */
SEXP LGBM_BoosterFree_R(LGBM_SE handle) {
  R_API_BEGIN();
  const bool has_booster = (R_GET_PTR(handle) != nullptr);
  if (has_booster) {
    CHECK_CALL(LGBM_BoosterFree(R_GET_PTR(handle)));
    // reset the stored pointer so a second call becomes a no-op
    R_SET_PTR(handle, nullptr);
  }
  R_API_END();
}

309
/*!
 * \brief Creates a Booster through LGBM_BoosterCreate.
 * \param train_data Handle of the training Dataset
 * \param parameters Character data holding the parameter string
 * \param out Receives the new Booster handle via R_SET_PTR
 * \return R_NilValue
 */
SEXP LGBM_BoosterCreate_R(LGBM_SE train_data,
                          LGBM_SE parameters,
                          LGBM_SE out) {
  R_API_BEGIN();
  const char* params = R_CHAR_PTR(parameters);
  BoosterHandle booster = nullptr;
  CHECK_CALL(LGBM_BoosterCreate(R_GET_PTR(train_data), params, &booster));
  R_SET_PTR(out, booster);
  R_API_END();
}

319
320
/*!
 * \brief Creates a Booster through LGBM_BoosterCreateFromModelfile.
 * \param filename Character data holding the model file path
 * \param out Receives the new Booster handle via R_SET_PTR
 * \return R_NilValue
 *
 * The iteration count reported by the C API is not passed back to the caller.
 */
SEXP LGBM_BoosterCreateFromModelfile_R(LGBM_SE filename,
                                       LGBM_SE out) {
  R_API_BEGIN();
  const char* path = R_CHAR_PTR(filename);
  int iteration_count = 0;
  BoosterHandle booster = nullptr;
  CHECK_CALL(LGBM_BoosterCreateFromModelfile(path, &iteration_count, &booster));
  R_SET_PTR(out, booster);
  R_API_END();
}

329
330
/*!
 * \brief Creates a Booster through LGBM_BoosterLoadModelFromString.
 * \param model_str Character data holding the serialized model
 * \param out Receives the new Booster handle via R_SET_PTR
 * \return R_NilValue
 *
 * The iteration count reported by the C API is not passed back to the caller.
 */
SEXP LGBM_BoosterLoadModelFromString_R(LGBM_SE model_str,
                                       LGBM_SE out) {
  R_API_BEGIN();
  const char* model_text = R_CHAR_PTR(model_str);
  int iteration_count = 0;
  BoosterHandle booster = nullptr;
  CHECK_CALL(LGBM_BoosterLoadModelFromString(model_text, &iteration_count, &booster));
  R_SET_PTR(out, booster);
  R_API_END();
}

339
340
/*!
 * \brief Forwards two Booster handles to LGBM_BoosterMerge.
 * \param handle Booster handle passed as the first argument
 * \param other_handle Booster handle passed as the second argument
 * \return R_NilValue
 */
SEXP LGBM_BoosterMerge_R(LGBM_SE handle,
                         LGBM_SE other_handle) {
  R_API_BEGIN();
  BoosterHandle target = R_GET_PTR(handle);
  BoosterHandle source = R_GET_PTR(other_handle);
  CHECK_CALL(LGBM_BoosterMerge(target, source));
  R_API_END();
}

346
347
/*!
 * \brief Forwards a Booster and a Dataset handle to LGBM_BoosterAddValidData.
 * \param handle Booster handle
 * \param valid_data Dataset handle of the validation data
 * \return R_NilValue
 */
SEXP LGBM_BoosterAddValidData_R(LGBM_SE handle,
                                LGBM_SE valid_data) {
  R_API_BEGIN();
  BoosterHandle booster = R_GET_PTR(handle);
  DatasetHandle dataset = R_GET_PTR(valid_data);
  CHECK_CALL(LGBM_BoosterAddValidData(booster, dataset));
  R_API_END();
}

353
354
/*!
 * \brief Forwards a Booster and a Dataset handle to LGBM_BoosterResetTrainingData.
 * \param handle Booster handle
 * \param train_data Dataset handle of the new training data
 * \return R_NilValue
 */
SEXP LGBM_BoosterResetTrainingData_R(LGBM_SE handle,
                                     LGBM_SE train_data) {
  R_API_BEGIN();
  BoosterHandle booster = R_GET_PTR(handle);
  DatasetHandle dataset = R_GET_PTR(train_data);
  CHECK_CALL(LGBM_BoosterResetTrainingData(booster, dataset));
  R_API_END();
}

360
361
/*!
 * \brief Forwards a parameter string to LGBM_BoosterResetParameter.
 * \param handle Booster handle
 * \param parameters Character data holding the parameter string
 * \return R_NilValue
 */
SEXP LGBM_BoosterResetParameter_R(LGBM_SE handle,
                                  LGBM_SE parameters) {
  R_API_BEGIN();
  const char* params = R_CHAR_PTR(parameters);
  CHECK_CALL(LGBM_BoosterResetParameter(R_GET_PTR(handle), params));
  R_API_END();
}

367
368
/*!
 * \brief Writes the value reported by LGBM_BoosterGetNumClasses into out.
 * \param handle Booster handle
 * \param out Receives the count in its first integer slot
 * \return R_NilValue
 */
SEXP LGBM_BoosterGetNumClasses_R(LGBM_SE handle,
                                 LGBM_SE out) {
  R_API_BEGIN();
  int class_count;
  CHECK_CALL(LGBM_BoosterGetNumClasses(R_GET_PTR(handle), &class_count));
  int* result = R_INT_PTR(out);
  result[0] = static_cast<int>(class_count);
  R_API_END();
}

376
/*!
 * \brief Runs one update step through LGBM_BoosterUpdateOneIter.
 * \param handle Booster handle
 * \return R_NilValue
 *
 * The "finished" flag written by the C API is not passed back to the caller.
 */
SEXP LGBM_BoosterUpdateOneIter_R(LGBM_SE handle) {
  R_API_BEGIN();
  int finished_flag = 0;
  CHECK_CALL(LGBM_BoosterUpdateOneIter(R_GET_PTR(handle), &finished_flag));
  R_API_END();
}

383
/*!
 * \brief Runs one update step with caller-supplied gradients and hessians.
 * \param handle Booster handle
 * \param grad Double gradient values
 * \param hess Double hessian values
 * \param len Number of entries in grad and hess
 * \return R_NilValue
 *
 * Both arrays are narrowed to float before being handed to
 * LGBM_BoosterUpdateOneIterCustom. The "finished" flag written by the C API
 * is not passed back to the caller.
 */
SEXP LGBM_BoosterUpdateOneIterCustom_R(LGBM_SE handle,
                                       LGBM_SE grad,
                                       LGBM_SE hess,
                                       LGBM_SE len) {
  R_API_BEGIN();
  int finished_flag = 0;
  const int count = R_AS_INT(len);
  const double* grad_in = R_REAL_PTR(grad);
  const double* hess_in = R_REAL_PTR(hess);
  std::vector<float> grad_f32(count);
  std::vector<float> hess_f32(count);
#pragma omp parallel for schedule(static, 512) if (count >= 1024)
  for (int k = 0; k < count; ++k) {
    grad_f32[k] = static_cast<float>(grad_in[k]);
    hess_f32[k] = static_cast<float>(hess_in[k]);
  }
  CHECK_CALL(LGBM_BoosterUpdateOneIterCustom(R_GET_PTR(handle), grad_f32.data(), hess_f32.data(), &finished_flag));
  R_API_END();
}

400
/*!
 * \brief Forwards a Booster handle to LGBM_BoosterRollbackOneIter.
 * \param handle Booster handle
 * \return R_NilValue
 */
SEXP LGBM_BoosterRollbackOneIter_R(LGBM_SE handle) {
  R_API_BEGIN();
  BoosterHandle booster = R_GET_PTR(handle);
  CHECK_CALL(LGBM_BoosterRollbackOneIter(booster));
  R_API_END();
}

406
407
/*!
 * \brief Writes the value reported by LGBM_BoosterGetCurrentIteration into out.
 * \param handle Booster handle
 * \param out Receives the iteration number in its first integer slot
 * \return R_NilValue
 */
SEXP LGBM_BoosterGetCurrentIteration_R(LGBM_SE handle,
                                       LGBM_SE out) {
  R_API_BEGIN();
  int iteration;
  CHECK_CALL(LGBM_BoosterGetCurrentIteration(R_GET_PTR(handle), &iteration));
  int* result = R_INT_PTR(out);
  result[0] = static_cast<int>(iteration);
  R_API_END();
}

415
416
/*!
 * \brief Lets LGBM_BoosterGetUpperBoundValue write directly into out_result.
 * \param handle Booster handle
 * \param out_result Double storage that receives the value
 * \return R_NilValue
 */
SEXP LGBM_BoosterGetUpperBoundValue_R(LGBM_SE handle,
                                      LGBM_SE out_result) {
  R_API_BEGIN();
  double* upper_bound = R_REAL_PTR(out_result);
  BoosterHandle booster = R_GET_PTR(handle);
  CHECK_CALL(LGBM_BoosterGetUpperBoundValue(booster, upper_bound));
  R_API_END();
}

423
424
/*!
 * \brief Lets LGBM_BoosterGetLowerBoundValue write directly into out_result.
 * \param handle Booster handle
 * \param out_result Double storage that receives the value
 * \return R_NilValue
 */
SEXP LGBM_BoosterGetLowerBoundValue_R(LGBM_SE handle,
                                      LGBM_SE out_result) {
  R_API_BEGIN();
  double* lower_bound = R_REAL_PTR(out_result);
  BoosterHandle booster = R_GET_PTR(handle);
  CHECK_CALL(LGBM_BoosterGetLowerBoundValue(booster, lower_bound));
  R_API_END();
}

431
/*!
 * \brief Fetches a Booster's evaluation metric names as one tab-separated string.
 * \param handle Booster handle
 * \param buf_len Capacity of eval_names
 * \param actual_len Receives the length of the joined string including its terminator
 * \param eval_names Destination character buffer (filled only when large enough)
 * \return R_NilValue
 *
 * Each name is first read into a 128-byte scratch buffer. If the C API
 * reports that a larger per-name buffer is required, the scratch buffers are
 * grown to that size and the names are fetched a second time, so long names
 * no longer trip the CHECK_GE below.
 */
SEXP LGBM_BoosterGetEvalNames_R(LGBM_SE handle,
  LGBM_SE buf_len,
  LGBM_SE actual_len,
  LGBM_SE eval_names) {
  R_API_BEGIN();
  int len = 0;
  CHECK_CALL(LGBM_BoosterGetEvalCounts(R_GET_PTR(handle), &len));

  size_t reserved_string_size = 128;
  std::vector<std::vector<char>> names(len);
  std::vector<char*> ptr_names(len);
  for (int i = 0; i < len; ++i) {
    names[i].resize(reserved_string_size);
    ptr_names[i] = names[i].data();
  }

  int out_len = 0;
  size_t required_string_size = 0;
  CHECK_CALL(
    LGBM_BoosterGetEvalNames(
      R_GET_PTR(handle),
      len, &out_len,
      reserved_string_size, &required_string_size,
      ptr_names.data()));
  if (required_string_size > reserved_string_size) {
    // at least one name did not fit: grow every buffer and read the names again
    reserved_string_size = required_string_size;
    for (int i = 0; i < len; ++i) {
      names[i].resize(reserved_string_size);
      // resize may reallocate, so the borrowed pointers must be refreshed
      ptr_names[i] = names[i].data();
    }
    CHECK_CALL(
      LGBM_BoosterGetEvalNames(
        R_GET_PTR(handle),
        len, &out_len,
        reserved_string_size, &required_string_size,
        ptr_names.data()));
  }
  CHECK_EQ(out_len, len);
  CHECK_GE(reserved_string_size, required_string_size);
  auto merge_names = Join<char*>(ptr_names, "\t");
  // +1 so the terminating NUL is copied as well
  EncodeChar(eval_names, merge_names.c_str(), buf_len, actual_len, merge_names.size() + 1);
  R_API_END();
}

462
/*!
 * \brief Lets LGBM_BoosterGetEval write evaluation results into out_result.
 * \param handle Booster handle
 * \param data_idx Integer index selecting the dataset to evaluate
 * \param out_result Double storage that receives the results
 * \return R_NilValue
 *
 * The number of results written is checked against the count reported by
 * LGBM_BoosterGetEvalCounts.
 */
SEXP LGBM_BoosterGetEval_R(LGBM_SE handle,
                           LGBM_SE data_idx,
                           LGBM_SE out_result) {
  R_API_BEGIN();
  int expected_count;
  CHECK_CALL(LGBM_BoosterGetEvalCounts(R_GET_PTR(handle), &expected_count));

  double* results = R_REAL_PTR(out_result);
  const int dataset_index = R_AS_INT(data_idx);
  int written_count;
  CHECK_CALL(LGBM_BoosterGetEval(R_GET_PTR(handle), dataset_index, &written_count, results));
  CHECK_EQ(written_count, expected_count);
  R_API_END();
}

475
/*!
 * \brief Writes the value reported by LGBM_BoosterGetNumPredict into out.
 * \param handle Booster handle
 * \param data_idx Integer index selecting the dataset
 * \param out Receives the count in its first integer slot
 * \return R_NilValue
 *
 * The C API reports a 64-bit count, which is narrowed to int here.
 */
SEXP LGBM_BoosterGetNumPredict_R(LGBM_SE handle,
                                 LGBM_SE data_idx,
                                 LGBM_SE out) {
  R_API_BEGIN();
  const int dataset_index = R_AS_INT(data_idx);
  int64_t prediction_count;
  CHECK_CALL(LGBM_BoosterGetNumPredict(R_GET_PTR(handle), dataset_index, &prediction_count));
  int* result = R_INT_PTR(out);
  result[0] = static_cast<int>(prediction_count);
  R_API_END();
}

485
/*!
 * \brief Lets LGBM_BoosterGetPredict write predictions into out_result.
 * \param handle Booster handle
 * \param data_idx Integer index selecting the dataset
 * \param out_result Double storage that receives the predictions
 * \return R_NilValue
 *
 * The element count reported by the C API is not passed back to the caller.
 */
SEXP LGBM_BoosterGetPredict_R(LGBM_SE handle,
                              LGBM_SE data_idx,
                              LGBM_SE out_result) {
  R_API_BEGIN();
  const int dataset_index = R_AS_INT(data_idx);
  double* predictions = R_REAL_PTR(out_result);
  int64_t written_count;
  CHECK_CALL(LGBM_BoosterGetPredict(R_GET_PTR(handle), dataset_index, &written_count, predictions));
  R_API_END();
}

495
/*!
 * \brief Maps three integer flags to one C_API_PREDICT_* constant.
 * \param is_rawscore Non-zero selects C_API_PREDICT_RAW_SCORE
 * \param is_leafidx Non-zero selects C_API_PREDICT_LEAF_INDEX
 * \param is_predcontrib Non-zero selects C_API_PREDICT_CONTRIB
 * \return The selected constant, or C_API_PREDICT_NORMAL when no flag is set
 *
 * When several flags are set the precedence is contrib, then leaf index,
 * then raw score.
 */
int GetPredictType(LGBM_SE is_rawscore, LGBM_SE is_leafidx, LGBM_SE is_predcontrib) {
  if (R_AS_INT(is_predcontrib)) {
    return C_API_PREDICT_CONTRIB;
  }
  if (R_AS_INT(is_leafidx)) {
    return C_API_PREDICT_LEAF_INDEX;
  }
  if (R_AS_INT(is_rawscore)) {
    return C_API_PREDICT_RAW_SCORE;
  }
  return C_API_PREDICT_NORMAL;
}

509
/*!
 * \brief Runs file-to-file prediction through LGBM_BoosterPredictForFile.
 * \param handle Booster handle
 * \param data_filename Character data holding the input path
 * \param data_has_header Integer flag forwarded to the C API
 * \param is_rawscore Prediction-type flag, see GetPredictType
 * \param is_leafidx Prediction-type flag, see GetPredictType
 * \param is_predcontrib Prediction-type flag, see GetPredictType
 * \param start_iteration Integer forwarded to the C API
 * \param num_iteration Integer forwarded to the C API
 * \param parameter Character data holding the parameter string
 * \param result_filename Character data holding the output path
 * \return R_NilValue
 */
SEXP LGBM_BoosterPredictForFile_R(LGBM_SE handle,
                                  LGBM_SE data_filename,
                                  LGBM_SE data_has_header,
                                  LGBM_SE is_rawscore,
                                  LGBM_SE is_leafidx,
                                  LGBM_SE is_predcontrib,
                                  LGBM_SE start_iteration,
                                  LGBM_SE num_iteration,
                                  LGBM_SE parameter,
                                  LGBM_SE result_filename) {
  R_API_BEGIN();
  const int predict_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  const char* input_path = R_CHAR_PTR(data_filename);
  const char* output_path = R_CHAR_PTR(result_filename);
  CHECK_CALL(LGBM_BoosterPredictForFile(R_GET_PTR(handle), input_path,
                                        R_AS_INT(data_has_header), predict_type,
                                        R_AS_INT(start_iteration), R_AS_INT(num_iteration),
                                        R_CHAR_PTR(parameter), output_path));
  R_API_END();
}

527
/*!
 * \brief Writes the value reported by LGBM_BoosterCalcNumPredict into out_len.
 * \param handle Booster handle
 * \param num_row Integer row count forwarded to the C API
 * \param is_rawscore Prediction-type flag, see GetPredictType
 * \param is_leafidx Prediction-type flag, see GetPredictType
 * \param is_predcontrib Prediction-type flag, see GetPredictType
 * \param start_iteration Integer forwarded to the C API
 * \param num_iteration Integer forwarded to the C API
 * \param out_len Receives the count in its first integer slot
 * \return R_NilValue
 *
 * The C API reports a 64-bit count, which is narrowed to int here.
 */
SEXP LGBM_BoosterCalcNumPredict_R(LGBM_SE handle,
                                  LGBM_SE num_row,
                                  LGBM_SE is_rawscore,
                                  LGBM_SE is_leafidx,
                                  LGBM_SE is_predcontrib,
                                  LGBM_SE start_iteration,
                                  LGBM_SE num_iteration,
                                  LGBM_SE out_len) {
  R_API_BEGIN();
  const int predict_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  int64_t prediction_count = 0;
  CHECK_CALL(LGBM_BoosterCalcNumPredict(R_GET_PTR(handle), R_AS_INT(num_row), predict_type,
                                        R_AS_INT(start_iteration), R_AS_INT(num_iteration),
                                        &prediction_count));
  int* result = R_INT_PTR(out_len);
  result[0] = static_cast<int>(prediction_count);
  R_API_END();
}

544
/*!
 * \brief Runs prediction on CSC input through LGBM_BoosterPredictForCSC.
 * \param handle Booster handle
 * \param indptr Integer pointer array (passed as C_API_DTYPE_INT32)
 * \param indices Integer index array
 * \param data Double value array (passed as C_API_DTYPE_FLOAT64)
 * \param num_indptr Length of indptr
 * \param nelem Number of stored elements
 * \param num_row Number of rows
 * \param is_rawscore Prediction-type flag, see GetPredictType
 * \param is_leafidx Prediction-type flag, see GetPredictType
 * \param is_predcontrib Prediction-type flag, see GetPredictType
 * \param start_iteration Integer forwarded to the C API
 * \param num_iteration Integer forwarded to the C API
 * \param parameter Character data holding the parameter string
 * \param out_result Double storage that receives the predictions
 * \return R_NilValue
 *
 * The element count reported by the C API is not passed back to the caller.
 */
SEXP LGBM_BoosterPredictForCSC_R(LGBM_SE handle,
                                 LGBM_SE indptr,
                                 LGBM_SE indices,
                                 LGBM_SE data,
                                 LGBM_SE num_indptr,
                                 LGBM_SE nelem,
                                 LGBM_SE num_row,
                                 LGBM_SE is_rawscore,
                                 LGBM_SE is_leafidx,
                                 LGBM_SE is_predcontrib,
                                 LGBM_SE start_iteration,
                                 LGBM_SE num_iteration,
                                 LGBM_SE parameter,
                                 LGBM_SE out_result) {
  R_API_BEGIN();
  const int predict_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);

  // The C API takes 64-bit counts; the R side supplies plain ints.
  const int64_t indptr_count = R_AS_INT(num_indptr);
  const int64_t value_count = R_AS_INT(nelem);
  const int64_t row_count = R_AS_INT(num_row);

  const int* indptr_values = R_INT_PTR(indptr);
  const int* index_values = R_INT_PTR(indices);
  const double* element_values = R_REAL_PTR(data);

  double* predictions = R_REAL_PTR(out_result);
  int64_t written_count;
  CHECK_CALL(LGBM_BoosterPredictForCSC(R_GET_PTR(handle),
                                       indptr_values, C_API_DTYPE_INT32, index_values,
                                       element_values, C_API_DTYPE_FLOAT64,
                                       indptr_count, value_count, row_count, predict_type,
                                       R_AS_INT(start_iteration), R_AS_INT(num_iteration),
                                       R_CHAR_PTR(parameter), &written_count, predictions));
  R_API_END();
}

577
/*!
 * \brief Runs prediction on a dense matrix through LGBM_BoosterPredictForMat.
 * \param handle Booster handle
 * \param data Double matrix values (passed as C_API_DTYPE_FLOAT64 with COL_MAJOR)
 * \param num_row Number of rows
 * \param num_col Number of columns
 * \param is_rawscore Prediction-type flag, see GetPredictType
 * \param is_leafidx Prediction-type flag, see GetPredictType
 * \param is_predcontrib Prediction-type flag, see GetPredictType
 * \param start_iteration Integer forwarded to the C API
 * \param num_iteration Integer forwarded to the C API
 * \param parameter Character data holding the parameter string
 * \param out_result Double storage that receives the predictions
 * \return R_NilValue
 *
 * The element count reported by the C API is not passed back to the caller.
 */
SEXP LGBM_BoosterPredictForMat_R(LGBM_SE handle,
                                 LGBM_SE data,
                                 LGBM_SE num_row,
                                 LGBM_SE num_col,
                                 LGBM_SE is_rawscore,
                                 LGBM_SE is_leafidx,
                                 LGBM_SE is_predcontrib,
                                 LGBM_SE start_iteration,
                                 LGBM_SE num_iteration,
                                 LGBM_SE parameter,
                                 LGBM_SE out_result) {
  R_API_BEGIN();
  const int predict_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);

  const int32_t row_count = R_AS_INT(num_row);
  const int32_t col_count = R_AS_INT(num_col);
  const double* matrix = R_REAL_PTR(data);

  double* predictions = R_REAL_PTR(out_result);
  int64_t written_count;
  CHECK_CALL(LGBM_BoosterPredictForMat(R_GET_PTR(handle),
                                       matrix, C_API_DTYPE_FLOAT64, row_count, col_count, COL_MAJOR,
                                       predict_type, R_AS_INT(start_iteration), R_AS_INT(num_iteration),
                                       R_CHAR_PTR(parameter), &written_count, predictions));
  R_API_END();
}

604
/*!
 * \brief Writes a model file through LGBM_BoosterSaveModel.
 * \param handle Booster handle
 * \param num_iteration Integer forwarded to the C API
 * \param feature_importance_type Integer forwarded to the C API
 * \param filename Character data holding the output path
 * \return R_NilValue
 *
 * The C API's second argument is always passed as 0.
 */
SEXP LGBM_BoosterSaveModel_R(LGBM_SE handle,
                             LGBM_SE num_iteration,
                             LGBM_SE feature_importance_type,
                             LGBM_SE filename) {
  R_API_BEGIN();
  const char* path = R_CHAR_PTR(filename);
  CHECK_CALL(LGBM_BoosterSaveModel(R_GET_PTR(handle), 0,
                                   R_AS_INT(num_iteration),
                                   R_AS_INT(feature_importance_type),
                                   path));
  R_API_END();
}

613
/*!
 * \brief Serializes a model to text through LGBM_BoosterSaveModelToString.
 * \param handle Booster handle
 * \param num_iteration Integer forwarded to the C API
 * \param feature_importance_type Integer forwarded to the C API
 * \param buffer_len Capacity of out_str
 * \param actual_len Receives the length reported by the C API
 * \param out_str Destination character buffer (filled only when large enough)
 * \return R_NilValue
 *
 * The text is produced in a scratch buffer of buffer_len bytes and then
 * handed to EncodeChar. The C API's second argument is always passed as 0.
 */
SEXP LGBM_BoosterSaveModelToString_R(LGBM_SE handle,
                                     LGBM_SE num_iteration,
                                     LGBM_SE feature_importance_type,
                                     LGBM_SE buffer_len,
                                     LGBM_SE actual_len,
                                     LGBM_SE out_str) {
  R_API_BEGIN();
  std::vector<char> scratch(R_AS_INT(buffer_len));
  int64_t reported_len = 0;
  CHECK_CALL(LGBM_BoosterSaveModelToString(R_GET_PTR(handle), 0,
                                           R_AS_INT(num_iteration),
                                           R_AS_INT(feature_importance_type),
                                           R_AS_INT(buffer_len),
                                           &reported_len, scratch.data()));
  const size_t text_bytes = static_cast<size_t>(reported_len);
  EncodeChar(out_str, scratch.data(), buffer_len, actual_len, text_bytes);
  R_API_END();
}

627
/*!
 * \brief Dumps a model to text through LGBM_BoosterDumpModel.
 * \param handle Booster handle
 * \param num_iteration Integer forwarded to the C API
 * \param feature_importance_type Integer forwarded to the C API
 * \param buffer_len Capacity of out_str
 * \param actual_len Receives the length reported by the C API
 * \param out_str Destination character buffer (filled only when large enough)
 * \return R_NilValue
 *
 * The text is produced in a scratch buffer of buffer_len bytes and then
 * handed to EncodeChar. The C API's second argument is always passed as 0.
 */
SEXP LGBM_BoosterDumpModel_R(LGBM_SE handle,
                             LGBM_SE num_iteration,
                             LGBM_SE feature_importance_type,
                             LGBM_SE buffer_len,
                             LGBM_SE actual_len,
                             LGBM_SE out_str) {
  R_API_BEGIN();
  std::vector<char> scratch(R_AS_INT(buffer_len));
  int64_t reported_len = 0;
  CHECK_CALL(LGBM_BoosterDumpModel(R_GET_PTR(handle), 0,
                                   R_AS_INT(num_iteration),
                                   R_AS_INT(feature_importance_type),
                                   R_AS_INT(buffer_len),
                                   &reported_len, scratch.data()));
  const size_t text_bytes = static_cast<size_t>(reported_len);
  EncodeChar(out_str, scratch.data(), buffer_len, actual_len, text_bytes);
  R_API_END();
}
640
641
642
643

// .Call() calls
// Registration table handed to R_registerRoutines below: one row per exported
// wrapper giving its name, its address and the number of arguments it takes.
// The argument count must match the wrapper's parameter list; the all-null row
// terminates the table.
static const R_CallMethodDef CallEntries[] = {
  {"LGBM_GetLastError_R", reinterpret_cast<DL_FUNC>(&LGBM_GetLastError_R), 3},
  {"LGBM_DatasetCreateFromFile_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetCreateFromFile_R), 4},
  {"LGBM_DatasetCreateFromCSC_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetCreateFromCSC_R), 9},
  {"LGBM_DatasetCreateFromMat_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetCreateFromMat_R), 6},
  {"LGBM_DatasetGetSubset_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetSubset_R), 5},
  {"LGBM_DatasetSetFeatureNames_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetSetFeatureNames_R), 2},
  {"LGBM_DatasetGetFeatureNames_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetFeatureNames_R), 4},
  {"LGBM_DatasetSaveBinary_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetSaveBinary_R), 2},
  {"LGBM_DatasetFree_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetFree_R), 1},
  {"LGBM_DatasetSetField_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetSetField_R), 4},
  {"LGBM_DatasetGetFieldSize_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetFieldSize_R), 3},
  {"LGBM_DatasetGetField_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetField_R), 3},
  {"LGBM_DatasetUpdateParamChecking_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetUpdateParamChecking_R), 2},
  {"LGBM_DatasetGetNumData_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetNumData_R), 2},
  {"LGBM_DatasetGetNumFeature_R", reinterpret_cast<DL_FUNC>(&LGBM_DatasetGetNumFeature_R), 2},
  {"LGBM_BoosterCreate_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterCreate_R), 3},
  {"LGBM_BoosterFree_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterFree_R), 1},
  {"LGBM_BoosterCreateFromModelfile_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterCreateFromModelfile_R), 2},
  {"LGBM_BoosterLoadModelFromString_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterLoadModelFromString_R), 2},
  {"LGBM_BoosterMerge_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterMerge_R), 2},
  {"LGBM_BoosterAddValidData_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterAddValidData_R), 2},
  {"LGBM_BoosterResetTrainingData_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterResetTrainingData_R), 2},
  {"LGBM_BoosterResetParameter_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterResetParameter_R), 2},
  {"LGBM_BoosterGetNumClasses_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetNumClasses_R), 2},
  {"LGBM_BoosterUpdateOneIter_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterUpdateOneIter_R), 1},
  {"LGBM_BoosterUpdateOneIterCustom_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterUpdateOneIterCustom_R), 4},
  {"LGBM_BoosterRollbackOneIter_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterRollbackOneIter_R), 1},
  {"LGBM_BoosterGetCurrentIteration_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetCurrentIteration_R), 2},
  {"LGBM_BoosterGetUpperBoundValue_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetUpperBoundValue_R), 2},
  {"LGBM_BoosterGetLowerBoundValue_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetLowerBoundValue_R), 2},
  {"LGBM_BoosterGetEvalNames_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetEvalNames_R), 4},
  {"LGBM_BoosterGetEval_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetEval_R), 3},
  {"LGBM_BoosterGetNumPredict_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetNumPredict_R), 3},
  {"LGBM_BoosterGetPredict_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterGetPredict_R), 3},
  {"LGBM_BoosterPredictForFile_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterPredictForFile_R), 10},
  {"LGBM_BoosterCalcNumPredict_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterCalcNumPredict_R), 8},
  {"LGBM_BoosterPredictForCSC_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterPredictForCSC_R), 14},
  {"LGBM_BoosterPredictForMat_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterPredictForMat_R), 11},
  {"LGBM_BoosterSaveModel_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterSaveModel_R), 4},
  {"LGBM_BoosterSaveModelToString_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterSaveModelToString_R), 6},
  {"LGBM_BoosterDumpModel_R", reinterpret_cast<DL_FUNC>(&LGBM_BoosterDumpModel_R), 6},
  {nullptr, nullptr, 0}
};

687
688
LIGHTGBM_C_EXPORT void R_init_lightgbm(DllInfo *dll);

689
690
691
692
void R_init_lightgbm(DllInfo *dll) {
  R_registerRoutines(dll, NULL, CallEntries, NULL, NULL);
  R_useDynamicSymbols(dll, FALSE);
}