"vscode:/vscode.git/clone" did not exist on "8e5079efa1ecca0fde1824d8020d51e02735a541"
lightgbm_R.cpp 17.6 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
#include <vector>
#include <string>
#include <utility>
#include <cstring>
#include <cstdio>
#include <sstream>
7
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
8
9
10
11
12
13
#include <cstdint>
#include <memory>

#include <LightGBM/utils/text_reader.h>
#include <LightGBM/utils/common.h>

14
#include <LightGBM/lightgbm_R.h>
Guolin Ke's avatar
Guolin Ke committed
15
16
17
18
19
20
21
22
23
24
25
26
27
28

// Layout flag handed to the matrix entry points of the C API; the wrappers in
// this file always pass this value (named for column-major data).
#define COL_MAJOR (0)

// Opens the try-block that wraps the body of an exported *_R function.
// Must be paired with R_API_END().
#define R_API_BEGIN() \
  try {

// Closes the try-block opened by R_API_BEGIN().
// Any exception is converted into call_state[0] = -1 and its message is stored
// through LGBM_SetLastError(); the enclosing function returns `call_state` on
// every path. Requires a variable named `call_state` in scope.
// Exceptions are caught by const reference since they are only read.
#define R_API_END() } \
  catch(const std::exception& ex) { R_INT_PTR(call_state)[0] = -1; LGBM_SetLastError(ex.what()); return call_state;} \
  catch(const std::string& ex) { R_INT_PTR(call_state)[0] = -1; LGBM_SetLastError(ex.c_str()); return call_state; } \
  catch(...) { R_INT_PTR(call_state)[0] = -1; LGBM_SetLastError("unknown exception"); return call_state;} \
  return call_state;

// Evaluates the C API call `x`; when it returns non-zero, writes -1 into
// call_state[0] and returns `call_state` from the enclosing function.
// Requires a variable named `call_state` in scope.
#define CHECK_CALL(x) \
  if ((x) != 0) { \
    R_INT_PTR(call_state)[0] = -1;\
    return call_state;\
  }

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
35
// Copies the NUL-terminated string `src` into the character buffer of `dest`.
//   dest       - destination object, written through R_CHAR_PTR
//   src        - C string to copy
//   buf_len    - capacity of `dest`, read with R_AS_INT
//   actual_len - its first int receives strlen(src), even when nothing is copied
// Nothing is copied when the capacity is smaller than the string length, and
// no terminating NUL is written. Strings longer than INT32_MAX are rejected
// through Log::Fatal. Returns `dest`.
LGBM_SE EncodeChar(LGBM_SE dest, const char* src, LGBM_SE buf_len, LGBM_SE actual_len) {
  const size_t n_chars = std::strlen(src);
  if (n_chars > INT32_MAX) {
    Log::Fatal("Don't support large string in R-package.");
  }
  const int n_chars_int = static_cast<int>(n_chars);
  R_INT_PTR(actual_len)[0] = n_chars_int;
  if (R_AS_INT(buf_len) >= n_chars_int) {
    std::memcpy(R_CHAR_PTR(dest), src, n_chars);
  }
  return dest;
}

Guolin Ke's avatar
Guolin Ke committed
47
// Copies the message returned by LGBM_GetLastError() into `err_msg`.
//   buf_len    - capacity of `err_msg`
//   actual_len - receives the length of the message
// Returns `err_msg` (the value returned by EncodeChar).
LGBM_SE LGBM_GetLastError_R(LGBM_SE buf_len, LGBM_SE actual_len, LGBM_SE err_msg) {
  const char* last_error = LGBM_GetLastError();
  return EncodeChar(err_msg, last_error, buf_len, actual_len);
}

Guolin Ke's avatar
Guolin Ke committed
51
52
53
54
55
// Creates a dataset through LGBM_DatasetCreateFromFile.
//   filename   - string object with the data file path
//   parameters - string object with the parameter string
//   reference  - pointer object of a reference dataset (read with R_GET_PTR)
//   out        - pointer object that receives the new dataset handle
//   call_state - first int is set to -1 on failure; returned to the caller
LGBM_SE LGBM_DatasetCreateFromFile_R(LGBM_SE filename,
  LGBM_SE parameters,
  LGBM_SE reference,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  auto path = R_CHAR_PTR(filename);
  auto params = R_CHAR_PTR(parameters);
  DatasetHandle new_dataset = nullptr;
  CHECK_CALL(LGBM_DatasetCreateFromFile(path, params, R_GET_PTR(reference), &new_dataset));
  R_SET_PTR(out, new_dataset);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
65
66
67
68
69
70
71
72
73
74
// Creates a dataset from sparse CSC input through LGBM_DatasetCreateFromCSC.
//   indptr, indices - int arrays, passed as C_API_DTYPE_INT32
//   data            - double array, passed as C_API_DTYPE_FLOAT64
//   num_indptr, nelem, num_row - sizes, read as int and widened to int64_t
//   parameters      - string object with the parameter string
//   reference       - pointer object of a reference dataset
//   out             - pointer object that receives the new dataset handle
//   call_state      - first int is set to -1 on failure; returned to the caller
LGBM_SE LGBM_DatasetCreateFromCSC_R(LGBM_SE indptr,
  LGBM_SE indices,
  LGBM_SE data,
  LGBM_SE num_indptr,
  LGBM_SE nelem,
  LGBM_SE num_row,
  LGBM_SE parameters,
  LGBM_SE reference,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const int64_t indptr_count = static_cast<int64_t>(R_AS_INT(num_indptr));
  const int64_t value_count = static_cast<int64_t>(R_AS_INT(nelem));
  const int64_t row_count = static_cast<int64_t>(R_AS_INT(num_row));

  const int* indptr_values = R_INT_PTR(indptr);
  const int* index_values = R_INT_PTR(indices);
  const double* element_values = R_REAL_PTR(data);

  DatasetHandle new_dataset = nullptr;
  CHECK_CALL(LGBM_DatasetCreateFromCSC(indptr_values, C_API_DTYPE_INT32, index_values,
    element_values, C_API_DTYPE_FLOAT64, indptr_count, value_count,
    row_count, R_CHAR_PTR(parameters), R_GET_PTR(reference), &new_dataset));
  R_SET_PTR(out, new_dataset);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
91
92
93
94
95
96
97
// Creates a dataset from a dense double matrix through
// LGBM_DatasetCreateFromMat, using the COL_MAJOR layout flag.
//   data             - double array holding the matrix
//   num_row, num_col - matrix dimensions
//   parameters       - string object with the parameter string
//   reference        - pointer object of a reference dataset
//   out              - pointer object that receives the new dataset handle
//   call_state       - first int is set to -1 on failure; returned to the caller
LGBM_SE LGBM_DatasetCreateFromMat_R(LGBM_SE data,
  LGBM_SE num_row,
  LGBM_SE num_col,
  LGBM_SE parameters,
  LGBM_SE reference,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  double* matrix = R_REAL_PTR(data);
  const int32_t row_count = static_cast<int32_t>(R_AS_INT(num_row));
  const int32_t col_count = static_cast<int32_t>(R_AS_INT(num_col));
  DatasetHandle new_dataset = nullptr;
  CHECK_CALL(LGBM_DatasetCreateFromMat(matrix, C_API_DTYPE_FLOAT64, row_count, col_count, COL_MAJOR,
    R_CHAR_PTR(parameters), R_GET_PTR(reference), &new_dataset));
  R_SET_PTR(out, new_dataset);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
110
111
112
113
114
115
// Builds a subset of a dataset through LGBM_DatasetGetSubset.
//   handle               - pointer object of the source dataset
//   used_row_indices     - int array of one-based row indices
//   len_used_row_indices - number of entries in used_row_indices
//   parameters           - string object with the parameter string
//   out                  - pointer object that receives the subset handle
//   call_state           - first int is set to -1 on failure; returned to the caller
LGBM_SE LGBM_DatasetGetSubset_R(LGBM_SE handle,
  LGBM_SE used_row_indices,
  LGBM_SE len_used_row_indices,
  LGBM_SE parameters,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const int num_used = R_AS_INT(len_used_row_indices);
  std::vector<int> zero_based(num_used);
  const int* one_based = R_INT_PTR(used_row_indices);
  // shift every index down by one: one-based in, zero-based out
#pragma omp parallel for schedule(static)
  for (int k = 0; k < num_used; ++k) {
    zero_based[k] = one_based[k] - 1;
  }
  DatasetHandle subset = nullptr;
  CHECK_CALL(LGBM_DatasetGetSubset(R_GET_PTR(handle),
    zero_based.data(), num_used, R_CHAR_PTR(parameters),
    &subset));
  R_SET_PTR(out, subset);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
133
134
135
// Sets the feature names of a dataset through LGBM_DatasetSetFeatureNames.
//   handle        - pointer object of the dataset
//   feature_names - string object holding all names joined by '\t'
//   call_state    - first int is set to -1 on failure; returned to the caller
LGBM_SE LGBM_DatasetSetFeatureNames_R(LGBM_SE handle,
  LGBM_SE feature_names,
  LGBM_SE call_state) {
  R_API_BEGIN();
  // the split strings own the storage; name_ptrs only borrows their c_str()
  auto split_names = Common::Split(R_CHAR_PTR(feature_names), '\t');
  const int num_names = static_cast<int>(split_names.size());
  std::vector<const char*> name_ptrs;
  name_ptrs.reserve(split_names.size());
  for (const auto& one_name : split_names) {
    name_ptrs.push_back(one_name.c_str());
  }
  CHECK_CALL(LGBM_DatasetSetFeatureNames(R_GET_PTR(handle),
    name_ptrs.data(), num_names));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
148
149
150
151
152
// Fetches the feature names of a dataset and writes them, joined by '\t',
// into `feature_names` via EncodeChar.
//   handle        - pointer object of the dataset
//   buf_len       - capacity of `feature_names`
//   actual_len    - receives the length of the joined string
//   feature_names - destination character buffer
//   call_state    - first int is set to -1 on failure; returned to the caller
// Each name is received into a fixed 256-byte buffer.
LGBM_SE LGBM_DatasetGetFeatureNames_R(LGBM_SE handle,
  LGBM_SE buf_len,
  LGBM_SE actual_len,
  LGBM_SE feature_names,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int len = 0;
  CHECK_CALL(LGBM_DatasetGetNumFeature(R_GET_PTR(handle), &len));
  // one zero-filled 256-byte buffer per feature, exposed to the C API as char*
  std::vector<std::vector<char>> name_storage(len, std::vector<char>(256));
  std::vector<char*> name_ptrs(len);
  for (int k = 0; k < len; ++k) {
    name_ptrs[k] = name_storage[k].data();
  }
  int out_len = 0;
  CHECK_CALL(LGBM_DatasetGetFeatureNames(R_GET_PTR(handle),
    name_ptrs.data(), &out_len));
  CHECK(len == out_len);
  const auto joined = Common::Join<char*>(name_ptrs, "\t");
  EncodeChar(feature_names, joined.c_str(), buf_len, actual_len);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
172
173
174
// Forwards a dataset handle and a file name to LGBM_DatasetSaveBinary.
//   handle     - pointer object of the dataset
//   filename   - string object with the output path
//   call_state - first int is set to -1 on failure; returned to the caller
LGBM_SE LGBM_DatasetSaveBinary_R(LGBM_SE handle,
  LGBM_SE filename,
  LGBM_SE call_state) {
  R_API_BEGIN();
  auto dataset = R_GET_PTR(handle);
  auto target_path = R_CHAR_PTR(filename);
  CHECK_CALL(LGBM_DatasetSaveBinary(dataset, target_path));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
181
182
// Releases a dataset through LGBM_DatasetFree and clears the stored pointer.
// Does nothing when the pointer held by `handle` is already null.
//   handle     - pointer object of the dataset; reset to nullptr after freeing
//   call_state - first int is set to -1 on failure; returned to the caller
LGBM_SE LGBM_DatasetFree_R(LGBM_SE handle,
  LGBM_SE call_state) {
  R_API_BEGIN();
  auto dataset = R_GET_PTR(handle);
  if (dataset != nullptr) {
    CHECK_CALL(LGBM_DatasetFree(dataset));
    R_SET_PTR(handle, nullptr);
  }
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
191
192
193
194
195
// Stores a field in a dataset through LGBM_DatasetSetField.
// The element type depends on the field name:
//   "group" / "query" - int input, passed as C_API_DTYPE_INT32
//   "init_score"      - double input, passed unchanged as C_API_DTYPE_FLOAT64
//   anything else     - double input, narrowed to float (C_API_DTYPE_FLOAT32)
//   handle      - pointer object of the dataset
//   field_name  - string object with the field name
//   field_data  - int or double array with the values
//   num_element - number of values in field_data
//   call_state  - first int is set to -1 on failure; returned to the caller
LGBM_SE LGBM_DatasetSetField_R(LGBM_SE handle,
  LGBM_SE field_name,
  LGBM_SE field_data,
  LGBM_SE num_element,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const char* name = R_CHAR_PTR(field_name);
  const int count = static_cast<int>(R_AS_INT(num_element));
  const bool is_group_field =
    std::strcmp("group", name) == 0 || std::strcmp("query", name) == 0;
  if (is_group_field) {
    const int* source = R_INT_PTR(field_data);
    std::vector<int32_t> as_int32(count);
#pragma omp parallel for schedule(static)
    for (int k = 0; k < count; ++k) {
      as_int32[k] = static_cast<int32_t>(source[k]);
    }
    CHECK_CALL(LGBM_DatasetSetField(R_GET_PTR(handle), name, as_int32.data(), count, C_API_DTYPE_INT32));
  } else if (std::strcmp("init_score", name) == 0) {
    // init_score stays in double precision, so no conversion buffer is needed
    CHECK_CALL(LGBM_DatasetSetField(R_GET_PTR(handle), name, R_REAL_PTR(field_data), count, C_API_DTYPE_FLOAT64));
  } else {
    const double* source = R_REAL_PTR(field_data);
    std::vector<float> as_float(count);
#pragma omp parallel for schedule(static)
    for (int k = 0; k < count; ++k) {
      as_float[k] = static_cast<float>(source[k]);
    }
    CHECK_CALL(LGBM_DatasetSetField(R_GET_PTR(handle), name, as_float.data(), count, C_API_DTYPE_FLOAT32));
  }
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
219
220
221
222
// Reads a field from a dataset (LGBM_DatasetGetField) into `field_data`.
// The interpretation depends on the field name:
//   "group" / "query" - the result is read as int32 boundaries; the
//                       differences of neighbouring entries are written as
//                       ints (out_len - 1 values)
//   "init_score"      - the result is read as double and copied
//   anything else     - the result is read as float and widened to double
//   handle     - pointer object of the dataset
//   field_name - string object with the field name
//   field_data - int or double array that receives the values
//   call_state - first int is set to -1 on failure; returned to the caller
LGBM_SE LGBM_DatasetGetField_R(LGBM_SE handle,
  LGBM_SE field_name,
  LGBM_SE field_data,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const char* name = R_CHAR_PTR(field_name);
  int num_values = 0;
  int value_type = 0;
  const void* raw_values = nullptr;
  CHECK_CALL(LGBM_DatasetGetField(R_GET_PTR(handle), name, &num_values, &raw_values, &value_type));

  const bool is_group_field =
    std::strcmp("group", name) == 0 || std::strcmp("query", name) == 0;
  if (is_group_field) {
    const int32_t* boundaries = reinterpret_cast<const int32_t*>(raw_values);
    int* sizes = R_INT_PTR(field_data);
    const int num_sizes = num_values - 1;
    // turn cumulative boundaries into per-group sizes
#pragma omp parallel for schedule(static)
    for (int k = 0; k < num_sizes; ++k) {
      sizes[k] = boundaries[k + 1] - boundaries[k];
    }
  } else if (std::strcmp("init_score", name) == 0) {
    const double* source = reinterpret_cast<const double*>(raw_values);
    double* target = R_REAL_PTR(field_data);
#pragma omp parallel for schedule(static)
    for (int k = 0; k < num_values; ++k) {
      target[k] = source[k];
    }
  } else {
    const float* source = reinterpret_cast<const float*>(raw_values);
    double* target = R_REAL_PTR(field_data);
#pragma omp parallel for schedule(static)
    for (int k = 0; k < num_values; ++k) {
      target[k] = source[k];
    }
  }
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
254
255
256
257
// Writes the number of values of a dataset field into `out`.
// For "group" / "query" the length reported by LGBM_DatasetGetField is
// reduced by one, matching the number of values LGBM_DatasetGetField_R writes.
//   handle     - pointer object of the dataset
//   field_name - string object with the field name
//   out        - int array; its first element receives the size
//   call_state - first int is set to -1 on failure; returned to the caller
LGBM_SE LGBM_DatasetGetFieldSize_R(LGBM_SE handle,
  LGBM_SE field_name,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const char* name = R_CHAR_PTR(field_name);
  int num_values = 0;
  int value_type = 0;
  const void* raw_values = nullptr;
  CHECK_CALL(LGBM_DatasetGetField(R_GET_PTR(handle), name, &num_values, &raw_values, &value_type));
  const bool is_group_field =
    std::strcmp("group", name) == 0 || std::strcmp("query", name) == 0;
  const int reported = is_group_field ? num_values - 1 : num_values;
  R_INT_PTR(out)[0] = static_cast<int>(reported);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
272
273
// Writes the value reported by LGBM_DatasetGetNumData into `out`.
//   handle     - pointer object of the dataset
//   out        - int array; its first element receives the count
//   call_state - first int is set to -1 on failure; returned to the caller
LGBM_SE LGBM_DatasetGetNumData_R(LGBM_SE handle, LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int num_data = 0;
  CHECK_CALL(LGBM_DatasetGetNumData(R_GET_PTR(handle), &num_data));
  R_INT_PTR(out)[0] = static_cast<int>(num_data);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
281
282
283
// Writes the value reported by LGBM_DatasetGetNumFeature into `out`.
//   handle     - pointer object of the dataset
//   out        - int array; its first element receives the count
//   call_state - first int is set to -1 on failure; returned to the caller
LGBM_SE LGBM_DatasetGetNumFeature_R(LGBM_SE handle,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int num_feature = 0;
  CHECK_CALL(LGBM_DatasetGetNumFeature(R_GET_PTR(handle), &num_feature));
  R_INT_PTR(out)[0] = static_cast<int>(num_feature);
  R_API_END();
}

// --- start Booster interfaces

Guolin Ke's avatar
Guolin Ke committed
293
294
// Releases a booster through LGBM_BoosterFree and clears the stored pointer.
// Does nothing when the pointer held by `handle` is already null.
//   handle     - pointer object of the booster; reset to nullptr after freeing
//   call_state - first int is set to -1 on failure; returned to the caller
LGBM_SE LGBM_BoosterFree_R(LGBM_SE handle,
  LGBM_SE call_state) {
  R_API_BEGIN();
  auto booster = R_GET_PTR(handle);
  if (booster != nullptr) {
    CHECK_CALL(LGBM_BoosterFree(booster));
    R_SET_PTR(handle, nullptr);
  }
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
303
304
305
306
// Creates a booster through LGBM_BoosterCreate.
//   train_data - pointer object of the training dataset
//   parameters - string object with the parameter string
//   out        - pointer object that receives the new booster handle
//   call_state - first int is set to -1 on failure; returned to the caller
LGBM_SE LGBM_BoosterCreate_R(LGBM_SE train_data,
  LGBM_SE parameters,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  auto dataset = R_GET_PTR(train_data);
  auto params = R_CHAR_PTR(parameters);
  BoosterHandle new_booster = nullptr;
  CHECK_CALL(LGBM_BoosterCreate(dataset, params, &new_booster));
  R_SET_PTR(out, new_booster);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
314
315
316
// Creates a booster through LGBM_BoosterCreateFromModelfile.
// The iteration count reported by the C API is received but not passed on.
//   filename   - string object with the model file path
//   out        - pointer object that receives the new booster handle
//   call_state - first int is set to -1 on failure; returned to the caller
LGBM_SE LGBM_BoosterCreateFromModelfile_R(LGBM_SE filename,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  BoosterHandle new_booster = nullptr;
  int ignored_num_iterations = 0;
  auto model_path = R_CHAR_PTR(filename);
  CHECK_CALL(LGBM_BoosterCreateFromModelfile(model_path, &ignored_num_iterations, &new_booster));
  R_SET_PTR(out, new_booster);
  R_API_END();
}

326
327
328
329
330
331
// Creates a booster through LGBM_BoosterLoadModelFromString.
// The iteration count reported by the C API is received but not passed on.
//   model_str  - string object holding the model text
//   out        - pointer object that receives the new booster handle
//   call_state - first int is set to -1 on failure; returned to the caller
LGBM_SE LGBM_BoosterLoadModelFromString_R(LGBM_SE model_str,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  BoosterHandle new_booster = nullptr;
  int ignored_num_iterations = 0;
  auto model_text = R_CHAR_PTR(model_str);
  CHECK_CALL(LGBM_BoosterLoadModelFromString(model_text, &ignored_num_iterations, &new_booster));
  R_SET_PTR(out, new_booster);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
338
339
340
// Forwards two booster handles to LGBM_BoosterMerge.
//   handle       - pointer object of the first booster
//   other_handle - pointer object of the second booster
//   call_state   - first int is set to -1 on failure; returned to the caller
LGBM_SE LGBM_BoosterMerge_R(LGBM_SE handle,
  LGBM_SE other_handle,
  LGBM_SE call_state) {
  R_API_BEGIN();
  auto target = R_GET_PTR(handle);
  auto source = R_GET_PTR(other_handle);
  CHECK_CALL(LGBM_BoosterMerge(target, source));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
346
347
348
// Forwards a booster handle and a dataset handle to LGBM_BoosterAddValidData.
//   handle     - pointer object of the booster
//   valid_data - pointer object of the validation dataset
//   call_state - first int is set to -1 on failure; returned to the caller
LGBM_SE LGBM_BoosterAddValidData_R(LGBM_SE handle,
  LGBM_SE valid_data,
  LGBM_SE call_state) {
  R_API_BEGIN();
  auto booster = R_GET_PTR(handle);
  auto dataset = R_GET_PTR(valid_data);
  CHECK_CALL(LGBM_BoosterAddValidData(booster, dataset));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
354
355
356
// Forwards a booster handle and a dataset handle to
// LGBM_BoosterResetTrainingData.
//   handle     - pointer object of the booster
//   train_data - pointer object of the new training dataset
//   call_state - first int is set to -1 on failure; returned to the caller
LGBM_SE LGBM_BoosterResetTrainingData_R(LGBM_SE handle,
  LGBM_SE train_data,
  LGBM_SE call_state) {
  R_API_BEGIN();
  auto booster = R_GET_PTR(handle);
  auto dataset = R_GET_PTR(train_data);
  CHECK_CALL(LGBM_BoosterResetTrainingData(booster, dataset));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
362
363
364
// Forwards a booster handle and a parameter string to
// LGBM_BoosterResetParameter.
//   handle     - pointer object of the booster
//   parameters - string object with the parameter string
//   call_state - first int is set to -1 on failure; returned to the caller
LGBM_SE LGBM_BoosterResetParameter_R(LGBM_SE handle,
  LGBM_SE parameters,
  LGBM_SE call_state) {
  R_API_BEGIN();
  auto booster = R_GET_PTR(handle);
  auto params = R_CHAR_PTR(parameters);
  CHECK_CALL(LGBM_BoosterResetParameter(booster, params));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
370
371
372
// Writes the value reported by LGBM_BoosterGetNumClasses into `out`.
//   handle     - pointer object of the booster
//   out        - int array; its first element receives the count
//   call_state - first int is set to -1 on failure; returned to the caller
LGBM_SE LGBM_BoosterGetNumClasses_R(LGBM_SE handle,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int class_count = 0;
  CHECK_CALL(LGBM_BoosterGetNumClasses(R_GET_PTR(handle), &class_count));
  R_INT_PTR(out)[0] = static_cast<int>(class_count);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
380
381
// Calls LGBM_BoosterUpdateOneIter on the booster.
// The "finished" flag reported by the C API is received but not passed on.
//   handle     - pointer object of the booster
//   call_state - first int is set to -1 on failure; returned to the caller
LGBM_SE LGBM_BoosterUpdateOneIter_R(LGBM_SE handle,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int finished_flag = 0;
  auto booster = R_GET_PTR(handle);
  CHECK_CALL(LGBM_BoosterUpdateOneIter(booster, &finished_flag));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
388
389
390
391
392
// Calls LGBM_BoosterUpdateOneIterCustom with caller-supplied gradients and
// hessians. Both double arrays are narrowed to float before the call.
// The "finished" flag reported by the C API is received but not passed on.
//   handle     - pointer object of the booster
//   grad, hess - double arrays with `len` values each
//   len        - number of values to convert
//   call_state - first int is set to -1 on failure; returned to the caller
LGBM_SE LGBM_BoosterUpdateOneIterCustom_R(LGBM_SE handle,
  LGBM_SE grad,
  LGBM_SE hess,
  LGBM_SE len,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const int count = R_AS_INT(len);
  std::vector<float> grad_f32(count);
  std::vector<float> hess_f32(count);
  const double* grad_src = R_REAL_PTR(grad);
  const double* hess_src = R_REAL_PTR(hess);
#pragma omp parallel for schedule(static)
  for (int k = 0; k < count; ++k) {
    grad_f32[k] = static_cast<float>(grad_src[k]);
    hess_f32[k] = static_cast<float>(hess_src[k]);
  }
  int finished_flag = 0;
  CHECK_CALL(LGBM_BoosterUpdateOneIterCustom(R_GET_PTR(handle), grad_f32.data(), hess_f32.data(), &finished_flag));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
406
407
// Calls LGBM_BoosterRollbackOneIter on the booster.
//   handle     - pointer object of the booster
//   call_state - first int is set to -1 on failure; returned to the caller
LGBM_SE LGBM_BoosterRollbackOneIter_R(LGBM_SE handle,
  LGBM_SE call_state) {
  R_API_BEGIN();
  auto booster = R_GET_PTR(handle);
  CHECK_CALL(LGBM_BoosterRollbackOneIter(booster));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
413
414
415
// Writes the value reported by LGBM_BoosterGetCurrentIteration into `out`.
//   handle     - pointer object of the booster
//   out        - int array; its first element receives the iteration
//   call_state - first int is set to -1 on failure; returned to the caller
LGBM_SE LGBM_BoosterGetCurrentIteration_R(LGBM_SE handle,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int current_iteration = 0;
  CHECK_CALL(LGBM_BoosterGetCurrentIteration(R_GET_PTR(handle), &current_iteration));
  R_INT_PTR(out)[0] = static_cast<int>(current_iteration);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
424
425
426
427
428
// Fetches the evaluation names of a booster and writes them, joined by '\t',
// into `eval_names` via EncodeChar.
//   handle     - pointer object of the booster
//   buf_len    - capacity of `eval_names`
//   actual_len - receives the length of the joined string
//   eval_names - destination character buffer
//   call_state - first int is set to -1 on failure; returned to the caller
// Each name is received into a fixed 128-byte buffer.
LGBM_SE LGBM_BoosterGetEvalNames_R(LGBM_SE handle,
  LGBM_SE buf_len,
  LGBM_SE actual_len,
  LGBM_SE eval_names,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int len = 0;
  CHECK_CALL(LGBM_BoosterGetEvalCounts(R_GET_PTR(handle), &len));
  // one zero-filled 128-byte buffer per name, exposed to the C API as char*
  std::vector<std::vector<char>> name_storage(len, std::vector<char>(128));
  std::vector<char*> name_ptrs(len);
  for (int k = 0; k < len; ++k) {
    name_ptrs[k] = name_storage[k].data();
  }
  int out_len = 0;
  CHECK_CALL(LGBM_BoosterGetEvalNames(R_GET_PTR(handle), &out_len, name_ptrs.data()));
  CHECK(out_len == len);
  const auto joined = Common::Join<char*>(name_ptrs, "\t");
  EncodeChar(eval_names, joined.c_str(), buf_len, actual_len);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
447
448
449
450
// Writes the evaluation results of one dataset into `out_result` through
// LGBM_BoosterGetEval and checks that the number of results equals the value
// reported by LGBM_BoosterGetEvalCounts.
//   handle     - pointer object of the booster
//   data_idx   - dataset index forwarded to the C API
//   out_result - double array that receives the results
//   call_state - first int is set to -1 on failure; returned to the caller
LGBM_SE LGBM_BoosterGetEval_R(LGBM_SE handle,
  LGBM_SE data_idx,
  LGBM_SE out_result,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int len = 0;
  CHECK_CALL(LGBM_BoosterGetEvalCounts(R_GET_PTR(handle), &len));
  int out_len = 0;
  double* results = R_REAL_PTR(out_result);
  CHECK_CALL(LGBM_BoosterGetEval(R_GET_PTR(handle), R_AS_INT(data_idx), &out_len, results));
  CHECK(out_len == len);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
461
462
463
464
// Writes the count reported by LGBM_BoosterGetNumPredict into `out`.
//   handle     - pointer object of the booster
//   data_idx   - dataset index forwarded to the C API
//   out        - int array; its first element receives the count
//   call_state - first int is set to -1 on failure; returned to the caller
// The C API reports a 64-bit count while `out` holds an int, so a count above
// INT32_MAX is rejected through Log::Fatal (as EncodeChar does for strings)
// instead of being silently truncated.
LGBM_SE LGBM_BoosterGetNumPredict_R(LGBM_SE handle,
  LGBM_SE data_idx,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int64_t len = 0;
  CHECK_CALL(LGBM_BoosterGetNumPredict(R_GET_PTR(handle), R_AS_INT(data_idx), &len));
  if (len > INT32_MAX) {
    Log::Fatal("Don't support large prediction result in R-package.");
  }
  R_INT_PTR(out)[0] = static_cast<int>(len);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
472
473
474
475
// Writes the predictions of one dataset into `out_result` through
// LGBM_BoosterGetPredict. The length reported by the C API is received but
// not passed on.
//   handle     - pointer object of the booster
//   data_idx   - dataset index forwarded to the C API
//   out_result - double array that receives the predictions
//   call_state - first int is set to -1 on failure; returned to the caller
LGBM_SE LGBM_BoosterGetPredict_R(LGBM_SE handle,
  LGBM_SE data_idx,
  LGBM_SE out_result,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int64_t ignored_len = 0;
  double* predictions = R_REAL_PTR(out_result);
  CHECK_CALL(LGBM_BoosterGetPredict(R_GET_PTR(handle), R_AS_INT(data_idx), &ignored_len, predictions));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
483
// Maps the two flags to a C API prediction type.
//   is_leafidx non-zero  -> C_API_PREDICT_LEAF_INDEX (wins when both are set)
//   is_rawscore non-zero -> C_API_PREDICT_RAW_SCORE
//   otherwise            -> C_API_PREDICT_NORMAL
int GetPredictType(LGBM_SE is_rawscore, LGBM_SE is_leafidx) {
  if (R_AS_INT(is_leafidx)) {
    return C_API_PREDICT_LEAF_INDEX;
  }
  if (R_AS_INT(is_rawscore)) {
    return C_API_PREDICT_RAW_SCORE;
  }
  return C_API_PREDICT_NORMAL;
}

Guolin Ke's avatar
Guolin Ke committed
494
495
496
497
498
499
// Runs a file-to-file prediction through LGBM_BoosterPredictForFile.
//   handle          - pointer object of the booster
//   data_filename   - string object with the input file path
//   data_has_header - int flag forwarded to the C API
//   is_rawscore, is_leafidx - flags combined by GetPredictType
//   num_iteration   - int forwarded to the C API
//   parameter       - string object with the parameter string
//   result_filename - string object with the output file path
//   call_state      - first int is set to -1 on failure; returned to the caller
LGBM_SE LGBM_BoosterPredictForFile_R(LGBM_SE handle,
  LGBM_SE data_filename,
  LGBM_SE data_has_header,
  LGBM_SE is_rawscore,
  LGBM_SE is_leafidx,
  LGBM_SE num_iteration,
  LGBM_SE parameter,
  LGBM_SE result_filename,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const int predict_type = GetPredictType(is_rawscore, is_leafidx);
  auto input_path = R_CHAR_PTR(data_filename);
  auto output_path = R_CHAR_PTR(result_filename);
  CHECK_CALL(LGBM_BoosterPredictForFile(R_GET_PTR(handle), input_path,
    R_AS_INT(data_has_header), predict_type, R_AS_INT(num_iteration), R_CHAR_PTR(parameter),
    output_path));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
511
512
513
514
515
516
517
// Writes the count reported by LGBM_BoosterCalcNumPredict into `out_len`.
//   handle        - pointer object of the booster
//   num_row       - number of rows forwarded to the C API
//   is_rawscore, is_leafidx - flags combined by GetPredictType
//   num_iteration - int forwarded to the C API
//   out_len       - int array; its first element receives the count
//   call_state    - first int is set to -1 on failure; returned to the caller
// The C API reports a 64-bit count while `out_len` holds an int, so a count
// above INT32_MAX is rejected through Log::Fatal (as EncodeChar does for
// strings) instead of being silently truncated.
LGBM_SE LGBM_BoosterCalcNumPredict_R(LGBM_SE handle,
  LGBM_SE num_row,
  LGBM_SE is_rawscore,
  LGBM_SE is_leafidx,
  LGBM_SE num_iteration,
  LGBM_SE out_len,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int pred_type = GetPredictType(is_rawscore, is_leafidx);
  int64_t len = 0;
  CHECK_CALL(LGBM_BoosterCalcNumPredict(R_GET_PTR(handle), R_AS_INT(num_row),
    pred_type, R_AS_INT(num_iteration), &len));
  if (len > INT32_MAX) {
    Log::Fatal("Don't support large prediction result in R-package.");
  }
  R_INT_PTR(out_len)[0] = static_cast<int>(len);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
527
528
529
530
531
532
533
534
535
536
// Predicts for sparse CSC input through LGBM_BoosterPredictForCSC.
// The result length reported by the C API is received but not passed on.
//   handle          - pointer object of the booster
//   indptr, indices - int arrays, passed as C_API_DTYPE_INT32
//   data            - double array, passed as C_API_DTYPE_FLOAT64
//   num_indptr, nelem, num_row - sizes, read as int and widened to int64_t
//   is_rawscore, is_leafidx - flags combined by GetPredictType
//   num_iteration   - int forwarded to the C API
//   parameter       - string object with the parameter string
//   out_result      - double array that receives the predictions
//   call_state      - first int is set to -1 on failure; returned to the caller
LGBM_SE LGBM_BoosterPredictForCSC_R(LGBM_SE handle,
  LGBM_SE indptr,
  LGBM_SE indices,
  LGBM_SE data,
  LGBM_SE num_indptr,
  LGBM_SE nelem,
  LGBM_SE num_row,
  LGBM_SE is_rawscore,
  LGBM_SE is_leafidx,
  LGBM_SE num_iteration,
  LGBM_SE parameter,
  LGBM_SE out_result,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const int predict_type = GetPredictType(is_rawscore, is_leafidx);

  const int64_t indptr_count = static_cast<int64_t>(R_AS_INT(num_indptr));
  const int64_t value_count = static_cast<int64_t>(R_AS_INT(nelem));
  const int64_t row_count = static_cast<int64_t>(R_AS_INT(num_row));

  const int* indptr_values = R_INT_PTR(indptr);
  const int* index_values = R_INT_PTR(indices);
  const double* element_values = R_REAL_PTR(data);

  double* predictions = R_REAL_PTR(out_result);
  int64_t ignored_len = 0;
  CHECK_CALL(LGBM_BoosterPredictForCSC(R_GET_PTR(handle),
    indptr_values, C_API_DTYPE_INT32, index_values,
    element_values, C_API_DTYPE_FLOAT64, indptr_count, value_count,
    row_count, predict_type, R_AS_INT(num_iteration), R_CHAR_PTR(parameter), &ignored_len, predictions));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
560
561
562
563
564
565
566
// Predicts for a dense double matrix through LGBM_BoosterPredictForMat,
// using the COL_MAJOR layout flag. The result length reported by the C API
// is received but not passed on.
//   handle           - pointer object of the booster
//   data             - double array holding the matrix
//   num_row, num_col - matrix dimensions
//   is_rawscore, is_leafidx - flags combined by GetPredictType
//   num_iteration    - int forwarded to the C API
//   parameter        - string object with the parameter string
//   out_result       - double array that receives the predictions
//   call_state       - first int is set to -1 on failure; returned to the caller
LGBM_SE LGBM_BoosterPredictForMat_R(LGBM_SE handle,
  LGBM_SE data,
  LGBM_SE num_row,
  LGBM_SE num_col,
  LGBM_SE is_rawscore,
  LGBM_SE is_leafidx,
  LGBM_SE num_iteration,
  LGBM_SE parameter,
  LGBM_SE out_result,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const int predict_type = GetPredictType(is_rawscore, is_leafidx);

  const int32_t row_count = static_cast<int32_t>(R_AS_INT(num_row));
  const int32_t col_count = static_cast<int32_t>(R_AS_INT(num_col));

  double* matrix = R_REAL_PTR(data);
  double* predictions = R_REAL_PTR(out_result);
  int64_t ignored_len = 0;
  CHECK_CALL(LGBM_BoosterPredictForMat(R_GET_PTR(handle),
    matrix, C_API_DTYPE_FLOAT64, row_count, col_count, COL_MAJOR,
    predict_type, R_AS_INT(num_iteration), R_CHAR_PTR(parameter), &ignored_len, predictions));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
587
588
589
590
// Forwards a booster handle, an iteration count and a file name to
// LGBM_BoosterSaveModel.
//   handle        - pointer object of the booster
//   num_iteration - int forwarded to the C API
//   filename      - string object with the output path
//   call_state    - first int is set to -1 on failure; returned to the caller
LGBM_SE LGBM_BoosterSaveModel_R(LGBM_SE handle,
  LGBM_SE num_iteration,
  LGBM_SE filename,
  LGBM_SE call_state) {
  R_API_BEGIN();
  auto booster = R_GET_PTR(handle);
  auto target_path = R_CHAR_PTR(filename);
  CHECK_CALL(LGBM_BoosterSaveModel(booster, R_AS_INT(num_iteration), target_path));
  R_API_END();
}

596
597
598
599
600
601
602
// Serializes a booster to text through LGBM_BoosterSaveModelToString and
// copies the text into `out_str`.
//   handle        - pointer object of the booster
//   num_iteration - int forwarded to the C API
//   buffer_len    - capacity of `out_str`; also the size of the scratch buffer
//   actual_len    - receives the text length, or the size the C API asked for
//                   when the provided capacity was not sufficient
//   out_str       - destination character buffer
//   call_state    - first int is set to -1 on failure; returned to the caller
LGBM_SE LGBM_BoosterSaveModelToString_R(LGBM_SE handle,
  LGBM_SE num_iteration,
  LGBM_SE buffer_len,
  LGBM_SE actual_len,
  LGBM_SE out_str,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const int capacity = R_AS_INT(buffer_len);
  int64_t out_len = 0;
  std::vector<char> inner_char_buf(capacity);
  CHECK_CALL(LGBM_BoosterSaveModelToString(R_GET_PTR(handle), R_AS_INT(num_iteration), capacity, &out_len, inner_char_buf.data()));
  if (inner_char_buf.empty() || out_len > static_cast<int64_t>(capacity)) {
    // The scratch buffer could not hold the model, so it must not be read
    // with strlen (it is empty or still zero-filled, which would report a
    // length of 0). Report the size requested by the C API instead.
    // NOTE(review): presumably the caller retries with a buffer of that size
    // and the C API leaves the buffer untouched when it is too small - confirm.
    if (out_len > INT32_MAX) {
      Log::Fatal("Don't support large string in R-package.");
    }
    R_INT_PTR(actual_len)[0] = static_cast<int>(out_len);
  } else {
    EncodeChar(out_str, inner_char_buf.data(), buffer_len, actual_len);
  }
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
610
611
612
613
614
615
// Dumps a booster to text through LGBM_BoosterDumpModel and copies the text
// into `out_str`.
//   handle        - pointer object of the booster
//   num_iteration - int forwarded to the C API
//   buffer_len    - capacity of `out_str`; also the size of the scratch buffer
//   actual_len    - receives the text length, or the size the C API asked for
//                   when the provided capacity was not sufficient
//   out_str       - destination character buffer
//   call_state    - first int is set to -1 on failure; returned to the caller
LGBM_SE LGBM_BoosterDumpModel_R(LGBM_SE handle,
  LGBM_SE num_iteration,
  LGBM_SE buffer_len,
  LGBM_SE actual_len,
  LGBM_SE out_str,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const int capacity = R_AS_INT(buffer_len);
  int64_t out_len = 0;
  std::vector<char> inner_char_buf(capacity);
  CHECK_CALL(LGBM_BoosterDumpModel(R_GET_PTR(handle), R_AS_INT(num_iteration), capacity, &out_len, inner_char_buf.data()));
  if (inner_char_buf.empty() || out_len > static_cast<int64_t>(capacity)) {
    // The scratch buffer could not hold the dump, so it must not be read
    // with strlen (it is empty or still zero-filled, which would report a
    // length of 0). Report the size requested by the C API instead.
    // NOTE(review): presumably the caller retries with a buffer of that size
    // and the C API leaves the buffer untouched when it is too small - confirm.
    if (out_len > INT32_MAX) {
      Log::Fatal("Don't support large string in R-package.");
    }
    R_INT_PTR(actual_len)[0] = static_cast<int>(out_len);
  } else {
    EncodeChar(out_str, inner_char_buf.data(), buffer_len, actual_len);
  }
  R_API_END();
}