"src/vscode:/vscode.git/clone" did not exist on "d0bec9e946de41a5cdfc57128fbd4f948cdea449"
lightgbm_R.cpp 17.7 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
#include <vector>
#include <string>
#include <utility>
#include <cstring>
#include <cstdio>
#include <sstream>
7
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
8
9
10
11
12
13
#include <cstdint>
#include <memory>

#include <LightGBM/utils/text_reader.h>
#include <LightGBM/utils/common.h>

14
#include <LightGBM/lightgbm_R.h>
Guolin Ke's avatar
Guolin Ke committed
15
16
17
18
19
20
21
22
23
24
25
26
27
28

// Layout flag passed to the dense-matrix C API calls in this file (0 = not
// row-major). R stores matrices column-major.
#define COL_MAJOR (0)

// Opens the guarded region of an R entry point. Must be paired with R_API_END().
#define R_API_BEGIN() \
  try {

// Closes the region opened by R_API_BEGIN(). Any exception escaping the region
// sets call_state[0] to -1 and records the message with LGBM_SetLastError.
// Requires a local/parameter named `call_state` in the enclosing function.
// Every path returns call_state.
#define R_API_END() } \
  catch(const std::exception& ex) { R_INT_PTR(call_state)[0] = -1; LGBM_SetLastError(ex.what()); return call_state; } \
  catch(const std::string& ex) { R_INT_PTR(call_state)[0] = -1; LGBM_SetLastError(ex.c_str()); return call_state; } \
  catch(...) { R_INT_PTR(call_state)[0] = -1; LGBM_SetLastError("unknown exception"); return call_state; } \
  return call_state;

// Evaluates a LightGBM C API call. On a non-zero status it marks call_state as
// failed (-1) and returns it immediately. The do/while(0) wrapper makes the
// expansion a single statement, so it composes safely with if/else.
#define CHECK_CALL(x) \
  do { \
    if ((x) != 0) { \
      R_INT_PTR(call_state)[0] = -1; \
      return call_state; \
    } \
  } while (0)

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
35
// Copies `str_len` bytes of `src` into the R character buffer `dest`.
// The required size is always written to `actual_len` first. The copy itself
// only happens when `buf_len` is large enough. Sizes above INT32_MAX are
// rejected with Log::Fatal, because lengths are exchanged with R as int.
// Returns `dest` in every case.
LGBM_SE EncodeChar(LGBM_SE dest, const char* src, LGBM_SE buf_len, LGBM_SE actual_len, size_t str_len) {
  const bool exceeds_int = str_len > INT32_MAX;
  if (exceeds_int) {
    Log::Fatal("Don't support large string in R-package.");
  }
  const int required = static_cast<int>(str_len);
  R_INT_PTR(actual_len)[0] = required;
  const bool fits = R_AS_INT(buf_len) >= required;
  if (fits) {
    char* target = R_CHAR_PTR(dest);
    std::memcpy(target, src, str_len);
  }
  return dest;
}

Guolin Ke's avatar
Guolin Ke committed
46
// Copies the last error message reported by LGBM_GetLastError, including its
// terminating NUL, into `err_msg`. The required size is written to `actual_len`.
LGBM_SE LGBM_GetLastError_R(LGBM_SE buf_len, LGBM_SE actual_len, LGBM_SE err_msg) {
  const char* const message = LGBM_GetLastError();
  const size_t bytes_with_nul = std::strlen(message) + 1;
  return EncodeChar(err_msg, message, buf_len, actual_len, bytes_with_nul);
}

Guolin Ke's avatar
Guolin Ke committed
50
51
52
53
54
// Creates a Dataset from `filename` using the `parameters` string.
// `reference` is forwarded as-is to the C API. The new handle is stored in
// `out` via R_SET_PTR. Errors are reported through `call_state`.
LGBM_SE LGBM_DatasetCreateFromFile_R(LGBM_SE filename,
  LGBM_SE parameters,
  LGBM_SE reference,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const char* path = R_CHAR_PTR(filename);
  const char* params = R_CHAR_PTR(parameters);
  DatasetHandle created = nullptr;
  CHECK_CALL(LGBM_DatasetCreateFromFile(path, params, R_GET_PTR(reference), &created));
  R_SET_PTR(out, created);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
64
65
66
67
68
69
70
71
72
73
// Creates a Dataset from a CSC sparse matrix supplied by R. The matrix consists
// of int32 `indptr` / `indices` arrays and float64 `data` values. The counts
// `num_indptr`, `nelem` and `num_row` are widened to int64 for the C API.
// The new handle is stored in `out`.
LGBM_SE LGBM_DatasetCreateFromCSC_R(LGBM_SE indptr,
  LGBM_SE indices,
  LGBM_SE data,
  LGBM_SE num_indptr,
  LGBM_SE nelem,
  LGBM_SE num_row,
  LGBM_SE parameters,
  LGBM_SE reference,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const int* indptr_values = R_INT_PTR(indptr);
  const int* index_values = R_INT_PTR(indices);
  const double* cell_values = R_REAL_PTR(data);

  const auto indptr_count = static_cast<int64_t>(R_AS_INT(num_indptr));
  const auto element_count = static_cast<int64_t>(R_AS_INT(nelem));
  const auto row_count = static_cast<int64_t>(R_AS_INT(num_row));

  DatasetHandle created = nullptr;
  CHECK_CALL(LGBM_DatasetCreateFromCSC(indptr_values, C_API_DTYPE_INT32, index_values,
    cell_values, C_API_DTYPE_FLOAT64, indptr_count, element_count,
    row_count, R_CHAR_PTR(parameters), R_GET_PTR(reference), &created));
  R_SET_PTR(out, created);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
90
91
92
93
94
95
96
// Creates a Dataset from a dense float64 matrix of `num_row` x `num_col`
// values, passed to the C API with the COL_MAJOR layout flag. The new handle
// is stored in `out`.
LGBM_SE LGBM_DatasetCreateFromMat_R(LGBM_SE data,
  LGBM_SE num_row,
  LGBM_SE num_col,
  LGBM_SE parameters,
  LGBM_SE reference,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const auto rows = static_cast<int32_t>(R_AS_INT(num_row));
  const auto cols = static_cast<int32_t>(R_AS_INT(num_col));
  double* matrix = R_REAL_PTR(data);
  DatasetHandle created = nullptr;
  CHECK_CALL(LGBM_DatasetCreateFromMat(matrix, C_API_DTYPE_FLOAT64, rows, cols, COL_MAJOR,
    R_CHAR_PTR(parameters), R_GET_PTR(reference), &created));
  R_SET_PTR(out, created);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
109
110
111
112
113
114
// Builds a new Dataset containing only the rows listed in `used_row_indices`
// and stores its handle in `out`. R supplies one-based indices; they are
// converted to zero-based before calling the C API.
LGBM_SE LGBM_DatasetGetSubset_R(LGBM_SE handle,
  LGBM_SE used_row_indices,
  LGBM_SE len_used_row_indices,
  LGBM_SE parameters,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const int len = R_AS_INT(len_used_row_indices);
  // Resolve the R data pointer once, on the calling thread. It is loop-invariant,
  // and resolving it here keeps R_INT_PTR (defined elsewhere) out of OpenMP
  // worker threads in case it ever maps onto the non-thread-safe R API.
  const int* one_based = R_INT_PTR(used_row_indices);
  std::vector<int> idxvec(len);
  // convert from one-based to zero-based index
#pragma omp parallel for schedule(static)
  for (int i = 0; i < len; ++i) {
    idxvec[i] = one_based[i] - 1;
  }
  DatasetHandle res = nullptr;
  CHECK_CALL(LGBM_DatasetGetSubset(R_GET_PTR(handle),
    idxvec.data(), len, R_CHAR_PTR(parameters),
    &res));
  R_SET_PTR(out, res);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
132
133
134
// Splits the tab-separated `feature_names` string and assigns the resulting
// names to the dataset behind `handle`.
LGBM_SE LGBM_DatasetSetFeatureNames_R(LGBM_SE handle,
  LGBM_SE feature_names,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const auto names = Common::Split(R_CHAR_PTR(feature_names), '\t');
  // The C API takes an array of C strings; these point into `names`, which
  // outlives the call.
  std::vector<const char*> name_ptrs;
  name_ptrs.reserve(names.size());
  for (const auto& name : names) {
    name_ptrs.push_back(name.c_str());
  }
  const int count = static_cast<int>(names.size());
  CHECK_CALL(LGBM_DatasetSetFeatureNames(R_GET_PTR(handle),
    name_ptrs.data(), count));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
147
148
149
150
151
// Fetches all feature names of the dataset behind `handle`, joins them with
// tabs, and copies the result (with its NUL) into `feature_names` through
// EncodeChar. The required size is written to `actual_len`.
LGBM_SE LGBM_DatasetGetFeatureNames_R(LGBM_SE handle,
  LGBM_SE buf_len,
  LGBM_SE actual_len,
  LGBM_SE feature_names,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int num_features = 0;
  CHECK_CALL(LGBM_DatasetGetNumFeature(R_GET_PTR(handle), &num_features));
  // One fixed-size scratch buffer per feature for the C API to fill.
  // NOTE(review): no buffer size is passed to the C API below, so names of
  // 256+ bytes may overrun — confirm the API's length limit.
  const size_t kNameBufSize = 256;
  std::vector<std::vector<char>> storage(num_features, std::vector<char>(kNameBufSize));
  std::vector<char*> name_ptrs;
  name_ptrs.reserve(storage.size());
  for (auto& buffer : storage) {
    name_ptrs.push_back(buffer.data());
  }
  int returned = 0;
  CHECK_CALL(LGBM_DatasetGetFeatureNames(R_GET_PTR(handle),
    name_ptrs.data(), &returned));
  CHECK(num_features == returned);
  const auto joined = Common::Join<char*>(name_ptrs, "\t");
  EncodeChar(feature_names, joined.c_str(), buf_len, actual_len, joined.size() + 1);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
171
172
173
// Saves the dataset behind `handle` to `filename` with LGBM_DatasetSaveBinary.
LGBM_SE LGBM_DatasetSaveBinary_R(LGBM_SE handle,
  LGBM_SE filename,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const char* target_path = R_CHAR_PTR(filename);
  CHECK_CALL(LGBM_DatasetSaveBinary(R_GET_PTR(handle), target_path));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
180
181
// Frees the dataset behind `handle` and clears the stored pointer, so a repeat
// call is a no-op. A null handle is ignored.
LGBM_SE LGBM_DatasetFree_R(LGBM_SE handle,
  LGBM_SE call_state) {
  R_API_BEGIN();
  auto dataset = R_GET_PTR(handle);
  if (dataset != nullptr) {
    CHECK_CALL(LGBM_DatasetFree(dataset));
    R_SET_PTR(handle, nullptr);
  }
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
190
191
192
193
194
// Sets a metadata field on the dataset behind `handle`. The element type
// depends on the field:
//   "group" / "query" -> R integers passed as int32
//   "init_score"      -> R doubles passed through as float64
//   anything else     -> R doubles narrowed to float32
// `num_element` is the number of entries read from `field_data`.
LGBM_SE LGBM_DatasetSetField_R(LGBM_SE handle,
  LGBM_SE field_name,
  LGBM_SE field_data,
  LGBM_SE num_element,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const int len = static_cast<int>(R_AS_INT(num_element));
  const char* name = R_CHAR_PTR(field_name);
  // The R data pointers below are resolved once, outside the OpenMP loops.
  // They are loop-invariant, and this keeps the R accessor macros off worker
  // threads.
  if (!std::strcmp("group", name) || !std::strcmp("query", name)) {
    const int* src = R_INT_PTR(field_data);
    std::vector<int32_t> vec(len);
#pragma omp parallel for schedule(static)
    for (int i = 0; i < len; ++i) {
      vec[i] = static_cast<int32_t>(src[i]);
    }
    CHECK_CALL(LGBM_DatasetSetField(R_GET_PTR(handle), name, vec.data(), len, C_API_DTYPE_INT32));
  } else if (!std::strcmp("init_score", name)) {
    CHECK_CALL(LGBM_DatasetSetField(R_GET_PTR(handle), name, R_REAL_PTR(field_data), len, C_API_DTYPE_FLOAT64));
  } else {
    const double* src = R_REAL_PTR(field_data);
    std::vector<float> vec(len);
#pragma omp parallel for schedule(static)
    for (int i = 0; i < len; ++i) {
      vec[i] = static_cast<float>(src[i]);
    }
    CHECK_CALL(LGBM_DatasetSetField(R_GET_PTR(handle), name, vec.data(), len, C_API_DTYPE_FLOAT32));
  }
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
218
219
220
221
// Copies a metadata field of the dataset behind `handle` into the
// caller-allocated R vector `field_data`:
//   "group" / "query" -> int32 boundaries from the C API become per-group
//                        sizes (out_len boundaries yield out_len - 1 sizes)
//   "init_score"      -> float64 values copied as-is
//   anything else     -> float32 values widened to double
// Assumes `field_data` was sized by the caller (e.g. from the field-size entry
// point) — TODO confirm at call sites.
LGBM_SE LGBM_DatasetGetField_R(LGBM_SE handle,
  LGBM_SE field_name,
  LGBM_SE field_data,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const char* name = R_CHAR_PTR(field_name);
  int out_len = 0;
  int out_type = 0;
  const void* res = nullptr;
  CHECK_CALL(LGBM_DatasetGetField(R_GET_PTR(handle), name, &out_len, &res, &out_type));

  // Destination pointers are resolved once, outside the OpenMP loops. They are
  // loop-invariant, and this keeps the R accessor macros off worker threads.
  if (!std::strcmp("group", name) || !std::strcmp("query", name)) {
    auto p_data = reinterpret_cast<const int32_t*>(res);
    int* dst = R_INT_PTR(field_data);
    // convert from boundaries to size
#pragma omp parallel for schedule(static)
    for (int i = 0; i < out_len - 1; ++i) {
      dst[i] = p_data[i + 1] - p_data[i];
    }
  } else if (!std::strcmp("init_score", name)) {
    auto p_data = reinterpret_cast<const double*>(res);
    double* dst = R_REAL_PTR(field_data);
#pragma omp parallel for schedule(static)
    for (int i = 0; i < out_len; ++i) {
      dst[i] = p_data[i];
    }
  } else {
    auto p_data = reinterpret_cast<const float*>(res);
    double* dst = R_REAL_PTR(field_data);
#pragma omp parallel for schedule(static)
    for (int i = 0; i < out_len; ++i) {
      dst[i] = p_data[i];
    }
  }
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
253
254
255
256
// Writes into `out` the number of R-side elements for the named dataset field.
// For "group" / "query" the C API length counts boundaries, so the reported
// size is one less than that length.
LGBM_SE LGBM_DatasetGetFieldSize_R(LGBM_SE handle,
  LGBM_SE field_name,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const char* name = R_CHAR_PTR(field_name);
  int field_len = 0;
  int field_type = 0;
  const void* field_ptr;
  CHECK_CALL(LGBM_DatasetGetField(R_GET_PTR(handle), name, &field_len, &field_ptr, &field_type));
  const bool is_grouping = std::strcmp("group", name) == 0 || std::strcmp("query", name) == 0;
  R_INT_PTR(out)[0] = is_grouping ? field_len - 1 : field_len;
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
271
272
// Writes the number of rows of the dataset behind `handle` into `out`.
LGBM_SE LGBM_DatasetGetNumData_R(LGBM_SE handle, LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int row_count = 0;
  CHECK_CALL(LGBM_DatasetGetNumData(R_GET_PTR(handle), &row_count));
  R_INT_PTR(out)[0] = row_count;
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
280
281
282
// Writes the number of features of the dataset behind `handle` into `out`.
LGBM_SE LGBM_DatasetGetNumFeature_R(LGBM_SE handle,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int feature_count = 0;
  CHECK_CALL(LGBM_DatasetGetNumFeature(R_GET_PTR(handle), &feature_count));
  R_INT_PTR(out)[0] = feature_count;
  R_API_END();
}

// --- start Booster interfaces

Guolin Ke's avatar
Guolin Ke committed
292
293
// Frees the booster behind `handle` and clears the stored pointer, so a repeat
// call is a no-op. A null handle is ignored.
LGBM_SE LGBM_BoosterFree_R(LGBM_SE handle,
  LGBM_SE call_state) {
  R_API_BEGIN();
  auto booster = R_GET_PTR(handle);
  if (booster != nullptr) {
    CHECK_CALL(LGBM_BoosterFree(booster));
    R_SET_PTR(handle, nullptr);
  }
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
302
303
304
305
// Creates a booster over the `train_data` dataset using the `parameters`
// string and stores the new handle in `out`.
LGBM_SE LGBM_BoosterCreate_R(LGBM_SE train_data,
  LGBM_SE parameters,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const char* params = R_CHAR_PTR(parameters);
  BoosterHandle booster = nullptr;
  CHECK_CALL(LGBM_BoosterCreate(R_GET_PTR(train_data), params, &booster));
  R_SET_PTR(out, booster);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
313
314
315
// Loads a booster from the model file `filename` and stores its handle in
// `out`. The iteration count reported by the C API is not returned to R.
LGBM_SE LGBM_BoosterCreateFromModelfile_R(LGBM_SE filename,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const char* model_path = R_CHAR_PTR(filename);
  int loaded_iterations = 0;
  BoosterHandle booster = nullptr;
  CHECK_CALL(LGBM_BoosterCreateFromModelfile(model_path, &loaded_iterations, &booster));
  R_SET_PTR(out, booster);
  R_API_END();
}

325
326
327
328
329
330
// Builds a booster from the serialized model text `model_str` and stores its
// handle in `out`. The iteration count reported by the C API is not returned.
LGBM_SE LGBM_BoosterLoadModelFromString_R(LGBM_SE model_str,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const char* model_text = R_CHAR_PTR(model_str);
  int loaded_iterations = 0;
  BoosterHandle booster = nullptr;
  CHECK_CALL(LGBM_BoosterLoadModelFromString(model_text, &loaded_iterations, &booster));
  R_SET_PTR(out, booster);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
337
338
339
// Calls LGBM_BoosterMerge with the booster behind `handle` and the one behind
// `other_handle`.
LGBM_SE LGBM_BoosterMerge_R(LGBM_SE handle,
  LGBM_SE other_handle,
  LGBM_SE call_state) {
  R_API_BEGIN();
  auto target = R_GET_PTR(handle);
  auto source = R_GET_PTR(other_handle);
  CHECK_CALL(LGBM_BoosterMerge(target, source));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
345
346
347
// Registers the dataset behind `valid_data` with the booster behind `handle`
// via LGBM_BoosterAddValidData.
LGBM_SE LGBM_BoosterAddValidData_R(LGBM_SE handle,
  LGBM_SE valid_data,
  LGBM_SE call_state) {
  R_API_BEGIN();
  auto booster = R_GET_PTR(handle);
  auto validation_set = R_GET_PTR(valid_data);
  CHECK_CALL(LGBM_BoosterAddValidData(booster, validation_set));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
353
354
355
// Replaces the training dataset of the booster behind `handle` with
// `train_data` via LGBM_BoosterResetTrainingData.
LGBM_SE LGBM_BoosterResetTrainingData_R(LGBM_SE handle,
  LGBM_SE train_data,
  LGBM_SE call_state) {
  R_API_BEGIN();
  auto booster = R_GET_PTR(handle);
  auto training_set = R_GET_PTR(train_data);
  CHECK_CALL(LGBM_BoosterResetTrainingData(booster, training_set));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
361
362
363
// Applies the `parameters` string to the booster behind `handle` via
// LGBM_BoosterResetParameter.
LGBM_SE LGBM_BoosterResetParameter_R(LGBM_SE handle,
  LGBM_SE parameters,
  LGBM_SE call_state) {
  R_API_BEGIN();
  auto booster = R_GET_PTR(handle);
  const char* params = R_CHAR_PTR(parameters);
  CHECK_CALL(LGBM_BoosterResetParameter(booster, params));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
369
370
371
// Writes the class count reported for the booster behind `handle` into `out`.
LGBM_SE LGBM_BoosterGetNumClasses_R(LGBM_SE handle,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int class_count = 0;
  CHECK_CALL(LGBM_BoosterGetNumClasses(R_GET_PTR(handle), &class_count));
  R_INT_PTR(out)[0] = class_count;
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
379
380
// Runs one boosting iteration on the booster behind `handle`. The "finished"
// flag produced by the C API is not forwarded to R.
LGBM_SE LGBM_BoosterUpdateOneIter_R(LGBM_SE handle,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int finished_flag = 0;
  CHECK_CALL(LGBM_BoosterUpdateOneIter(R_GET_PTR(handle), &finished_flag));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
387
388
389
390
391
// Runs one boosting iteration using caller-supplied gradients and hessians.
// The first `len` R doubles of `grad` and `hess` are narrowed to float before
// being handed to the C API. The "finished" flag is not forwarded to R.
// NOTE(review): the C API call receives no length, so `len` presumably must
// match what the booster expects — confirm.
LGBM_SE LGBM_BoosterUpdateOneIterCustom_R(LGBM_SE handle,
  LGBM_SE grad,
  LGBM_SE hess,
  LGBM_SE len,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int is_finished = 0;
  const int int_len = R_AS_INT(len);
  // Source pointers are resolved once, outside the OpenMP loop. They are
  // loop-invariant, and this keeps the R accessor macros off worker threads.
  const double* grad_in = R_REAL_PTR(grad);
  const double* hess_in = R_REAL_PTR(hess);
  std::vector<float> tgrad(int_len), thess(int_len);
#pragma omp parallel for schedule(static)
  for (int j = 0; j < int_len; ++j) {
    tgrad[j] = static_cast<float>(grad_in[j]);
    thess[j] = static_cast<float>(hess_in[j]);
  }
  CHECK_CALL(LGBM_BoosterUpdateOneIterCustom(R_GET_PTR(handle), tgrad.data(), thess.data(), &is_finished));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
405
406
// Calls LGBM_BoosterRollbackOneIter on the booster behind `handle`.
LGBM_SE LGBM_BoosterRollbackOneIter_R(LGBM_SE handle,
  LGBM_SE call_state) {
  R_API_BEGIN();
  auto booster = R_GET_PTR(handle);
  CHECK_CALL(LGBM_BoosterRollbackOneIter(booster));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
412
413
414
// Writes the current iteration reported for the booster behind `handle` into
// `out`.
LGBM_SE LGBM_BoosterGetCurrentIteration_R(LGBM_SE handle,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int iteration = 0;
  CHECK_CALL(LGBM_BoosterGetCurrentIteration(R_GET_PTR(handle), &iteration));
  R_INT_PTR(out)[0] = iteration;
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
423
424
425
426
427
// Fetches the evaluation metric names of the booster behind `handle`, joins
// them with tabs, and copies the result (with its NUL) into `eval_names`
// through EncodeChar. The required size is written to `actual_len`.
LGBM_SE LGBM_BoosterGetEvalNames_R(LGBM_SE handle,
  LGBM_SE buf_len,
  LGBM_SE actual_len,
  LGBM_SE eval_names,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int eval_count;
  CHECK_CALL(LGBM_BoosterGetEvalCounts(R_GET_PTR(handle), &eval_count));
  // One fixed-size scratch buffer per metric name for the C API to fill.
  // NOTE(review): no buffer size is passed to the C API below, so names of
  // 128+ bytes may overrun — confirm the API's length limit.
  const size_t kEvalNameBufSize = 128;
  std::vector<std::vector<char>> storage(eval_count, std::vector<char>(kEvalNameBufSize));
  std::vector<char*> name_ptrs;
  name_ptrs.reserve(storage.size());
  for (auto& buffer : storage) {
    name_ptrs.push_back(buffer.data());
  }
  int returned;
  CHECK_CALL(LGBM_BoosterGetEvalNames(R_GET_PTR(handle), &returned, name_ptrs.data()));
  CHECK(returned == eval_count);
  const auto joined = Common::Join<char*>(name_ptrs, "\t");
  EncodeChar(eval_names, joined.c_str(), buf_len, actual_len, joined.size() + 1);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
446
447
448
449
// Writes the evaluation results for dataset `data_idx` of the booster behind
// `handle` into `out_result`. Checks that the C API returned exactly as many
// values as it reports eval counts. Assumes `out_result` has room for that
// many doubles — TODO confirm at call sites.
LGBM_SE LGBM_BoosterGetEval_R(LGBM_SE handle,
  LGBM_SE data_idx,
  LGBM_SE out_result,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int expected_count;
  CHECK_CALL(LGBM_BoosterGetEvalCounts(R_GET_PTR(handle), &expected_count));
  double* results = R_REAL_PTR(out_result);
  const int dataset_index = R_AS_INT(data_idx);
  int written_count;
  CHECK_CALL(LGBM_BoosterGetEval(R_GET_PTR(handle), dataset_index, &written_count, results));
  CHECK(written_count == expected_count);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
460
461
462
463
// Writes into `out` the number of prediction values the booster behind
// `handle` holds for dataset `data_idx`. The C API reports an int64 count,
// while R receives an int. Counts above INT32_MAX are rejected through
// Log::Fatal (caught by R_API_END and reported via `call_state`) instead of
// being silently truncated, which would make the caller allocate a buffer that
// is too small.
LGBM_SE LGBM_BoosterGetNumPredict_R(LGBM_SE handle,
  LGBM_SE data_idx,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int64_t len = 0;
  CHECK_CALL(LGBM_BoosterGetNumPredict(R_GET_PTR(handle), R_AS_INT(data_idx), &len));
  if (len > INT32_MAX) {
    Log::Fatal("Don't support large prediction size in R-package.");
  }
  R_INT_PTR(out)[0] = static_cast<int>(len);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
471
472
473
474
// Copies the booster's stored predictions for dataset `data_idx` into
// `out_result`. The length reported by the C API is not returned to R.
LGBM_SE LGBM_BoosterGetPredict_R(LGBM_SE handle,
  LGBM_SE data_idx,
  LGBM_SE out_result,
  LGBM_SE call_state) {
  R_API_BEGIN();
  double* predictions = R_REAL_PTR(out_result);
  const int dataset_index = R_AS_INT(data_idx);
  int64_t written;
  CHECK_CALL(LGBM_BoosterGetPredict(R_GET_PTR(handle), dataset_index, &written, predictions));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
482
// Maps the R-side prediction flags onto a C API prediction type. When both
// flags are set, leaf-index prediction wins over raw-score prediction. With
// neither set, normal prediction is used.
int GetPredictType(LGBM_SE is_rawscore, LGBM_SE is_leafidx) {
  if (R_AS_INT(is_leafidx)) {
    return C_API_PREDICT_LEAF_INDEX;
  }
  return R_AS_INT(is_rawscore) ? C_API_PREDICT_RAW_SCORE : C_API_PREDICT_NORMAL;
}

Guolin Ke's avatar
Guolin Ke committed
493
494
495
496
497
498
// Predicts for the data in `data_filename` and writes the results to
// `result_filename`. The header flag, prediction type (from the raw-score and
// leaf-index flags), iteration limit and `parameter` string are forwarded to
// LGBM_BoosterPredictForFile.
LGBM_SE LGBM_BoosterPredictForFile_R(LGBM_SE handle,
  LGBM_SE data_filename,
  LGBM_SE data_has_header,
  LGBM_SE is_rawscore,
  LGBM_SE is_leafidx,
  LGBM_SE num_iteration,
  LGBM_SE parameter,
  LGBM_SE result_filename,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const int predict_kind = GetPredictType(is_rawscore, is_leafidx);
  const int has_header = R_AS_INT(data_has_header);
  const int iteration_limit = R_AS_INT(num_iteration);
  CHECK_CALL(LGBM_BoosterPredictForFile(R_GET_PTR(handle), R_CHAR_PTR(data_filename),
    has_header, predict_kind, iteration_limit, R_CHAR_PTR(parameter),
    R_CHAR_PTR(result_filename)));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
510
511
512
513
514
515
516
// Writes into `out_len` the number of output values a prediction over
// `num_row` rows would produce, given the prediction flags and iteration
// limit. The C API reports an int64, while R receives an int. Sizes above
// INT32_MAX are rejected through Log::Fatal (reported via `call_state`)
// instead of being truncated, which would undersize the result buffer the
// caller allocates.
LGBM_SE LGBM_BoosterCalcNumPredict_R(LGBM_SE handle,
  LGBM_SE num_row,
  LGBM_SE is_rawscore,
  LGBM_SE is_leafidx,
  LGBM_SE num_iteration,
  LGBM_SE out_len,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const int pred_type = GetPredictType(is_rawscore, is_leafidx);
  int64_t len = 0;
  CHECK_CALL(LGBM_BoosterCalcNumPredict(R_GET_PTR(handle), R_AS_INT(num_row),
    pred_type, R_AS_INT(num_iteration), &len));
  if (len > INT32_MAX) {
    Log::Fatal("Don't support large prediction size in R-package.");
  }
  R_INT_PTR(out_len)[0] = static_cast<int>(len);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
526
527
528
529
530
531
532
533
534
535
// Predicts for a CSC sparse matrix supplied by R. The matrix consists of int32
// `indptr` / `indices` arrays and float64 `data` values. Results are written
// into `out_result`. The length reported by the C API is not returned to R.
LGBM_SE LGBM_BoosterPredictForCSC_R(LGBM_SE handle,
  LGBM_SE indptr,
  LGBM_SE indices,
  LGBM_SE data,
  LGBM_SE num_indptr,
  LGBM_SE nelem,
  LGBM_SE num_row,
  LGBM_SE is_rawscore,
  LGBM_SE is_leafidx,
  LGBM_SE num_iteration,
  LGBM_SE parameter,
  LGBM_SE out_result,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const int predict_kind = GetPredictType(is_rawscore, is_leafidx);

  const int* indptr_values = R_INT_PTR(indptr);
  const int* index_values = R_INT_PTR(indices);
  const double* cell_values = R_REAL_PTR(data);

  const auto indptr_count = static_cast<int64_t>(R_AS_INT(num_indptr));
  const auto element_count = static_cast<int64_t>(R_AS_INT(nelem));
  const auto row_count = static_cast<int64_t>(R_AS_INT(num_row));
  double* predictions = R_REAL_PTR(out_result);
  int64_t written;
  CHECK_CALL(LGBM_BoosterPredictForCSC(R_GET_PTR(handle),
    indptr_values, C_API_DTYPE_INT32, index_values,
    cell_values, C_API_DTYPE_FLOAT64, indptr_count, element_count,
    row_count, predict_kind, R_AS_INT(num_iteration), R_CHAR_PTR(parameter), &written, predictions));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
559
560
561
562
563
564
565
// Predicts for a dense float64 matrix of `num_row` x `num_col` values, passed
// with the COL_MAJOR layout flag. Results are written into `out_result`. The
// length reported by the C API is not returned to R.
LGBM_SE LGBM_BoosterPredictForMat_R(LGBM_SE handle,
  LGBM_SE data,
  LGBM_SE num_row,
  LGBM_SE num_col,
  LGBM_SE is_rawscore,
  LGBM_SE is_leafidx,
  LGBM_SE num_iteration,
  LGBM_SE parameter,
  LGBM_SE out_result,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const int predict_kind = GetPredictType(is_rawscore, is_leafidx);

  const int32_t rows = R_AS_INT(num_row);
  const int32_t cols = R_AS_INT(num_col);

  double* matrix = R_REAL_PTR(data);
  double* predictions = R_REAL_PTR(out_result);
  int64_t written;
  CHECK_CALL(LGBM_BoosterPredictForMat(R_GET_PTR(handle),
    matrix, C_API_DTYPE_FLOAT64, rows, cols, COL_MAJOR,
    predict_kind, R_AS_INT(num_iteration), R_CHAR_PTR(parameter), &written, predictions));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
586
587
588
589
// Calls LGBM_BoosterSaveModel to write the booster behind `handle` to
// `filename`, forwarding the `num_iteration` limit.
LGBM_SE LGBM_BoosterSaveModel_R(LGBM_SE handle,
  LGBM_SE num_iteration,
  LGBM_SE filename,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const int iteration_limit = R_AS_INT(num_iteration);
  const char* target_path = R_CHAR_PTR(filename);
  CHECK_CALL(LGBM_BoosterSaveModel(R_GET_PTR(handle), iteration_limit, target_path));
  R_API_END();
}

595
596
597
598
599
600
601
// Serializes the booster behind `handle` into a scratch buffer of `buffer_len`
// bytes, then hands the result to EncodeChar. EncodeChar writes the size
// reported by the C API to `actual_len` and copies into `out_str` only when
// that size fits in `buffer_len`.
LGBM_SE LGBM_BoosterSaveModelToString_R(LGBM_SE handle,
  LGBM_SE num_iteration,
  LGBM_SE buffer_len,
  LGBM_SE actual_len,
  LGBM_SE out_str,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const int capacity = R_AS_INT(buffer_len);
  int64_t needed = 0;
  std::vector<char> scratch(capacity);
  CHECK_CALL(LGBM_BoosterSaveModelToString(R_GET_PTR(handle), R_AS_INT(num_iteration), capacity, &needed, scratch.data()));
  EncodeChar(out_str, scratch.data(), buffer_len, actual_len, static_cast<size_t>(needed));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
609
610
611
612
613
614
// Dumps the booster behind `handle` via LGBM_BoosterDumpModel into a scratch
// buffer of `buffer_len` bytes, then hands the result to EncodeChar. EncodeChar
// writes the reported size to `actual_len` and copies into `out_str` only when
// it fits.
LGBM_SE LGBM_BoosterDumpModel_R(LGBM_SE handle,
  LGBM_SE num_iteration,
  LGBM_SE buffer_len,
  LGBM_SE actual_len,
  LGBM_SE out_str,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const int capacity = R_AS_INT(buffer_len);
  int64_t needed = 0;
  std::vector<char> scratch(capacity);
  CHECK_CALL(LGBM_BoosterDumpModel(R_GET_PTR(handle), R_AS_INT(num_iteration), capacity, &needed, scratch.data()));
  EncodeChar(out_str, scratch.data(), buffer_len, actual_len, static_cast<size_t>(needed));
  R_API_END();
}