lightgbm_R.cpp 19 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2017 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
#include <LightGBM/lightgbm_R.h>

7
#include <LightGBM/utils/common.h>
Guolin Ke's avatar
Guolin Ke committed
8
9
10
11
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>
#include <LightGBM/utils/text_reader.h>

Guolin Ke's avatar
Guolin Ke committed
12
13
#include <string>
#include <cstdio>
14
#include <cstring>
Guolin Ke's avatar
Guolin Ke committed
15
#include <memory>
16
17
#include <utility>
#include <vector>
Guolin Ke's avatar
Guolin Ke committed
18
19
20
21
22
23
24
25
26
27
28
29
30

// Layout flag handed to the C API's row-major parameter: 0 means the matrix
// buffer is column-major, which is how R stores matrices.
#define COL_MAJOR (0)

// Opens the guarded region of an R entry point. Must be closed by
// R_API_END() in the same function.
#define R_API_BEGIN() \
  try {

// Closes the region opened by R_API_BEGIN(). Any escaping exception marks the
// call as failed (call_state[0] = -1), records its message with
// LGBM_SetLastError, and returns call_state. Otherwise control falls through
// to `return call_state;`. Requires a `call_state` variable in scope.
// Exceptions are caught by const reference (no copies, no slicing).
#define R_API_END() } \
  catch(const std::exception& ex) { R_INT_PTR(call_state)[0] = -1; LGBM_SetLastError(ex.what()); return call_state;} \
  catch(const std::string& ex) { R_INT_PTR(call_state)[0] = -1; LGBM_SetLastError(ex.c_str()); return call_state; } \
  catch(...) { R_INT_PTR(call_state)[0] = -1; LGBM_SetLastError("unknown exception"); return call_state;} \
  return call_state;

// Propagates a non-zero status from a LightGBM C API call: marks call_state
// as failed and returns it immediately. The do/while(0) wrapper makes the
// macro a single statement, so `if (a) CHECK_CALL(f()); else ...` is safe.
#define CHECK_CALL(x) \
  do { \
    if ((x) != 0) { \
      R_INT_PTR(call_state)[0] = -1; \
      return call_state; \
    } \
  } while (0)

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
37
/*!
 * \brief Copy `str_len` bytes of `src` into the R buffer `dest` when it fits.
 * \param dest R buffer that receives the bytes
 * \param src bytes to copy (most callers here pass strlen()+1 so the NUL is copied too)
 * \param buf_len R integer holding the capacity of `dest`
 * \param actual_len R integer that always receives `str_len`, so a caller whose
 *        buffer was too small can retry with the reported size
 * \param str_len number of bytes to copy; must not exceed INT32_MAX
 * \return `dest`, whether or not anything was copied
 */
LGBM_SE EncodeChar(LGBM_SE dest, const char* src, LGBM_SE buf_len, LGBM_SE actual_len, size_t str_len) {
  if (str_len > INT32_MAX) {
    Log::Fatal("Don't support large string in R-package");
  }
  const int needed = static_cast<int>(str_len);
  // Report the required size first, even if no copy happens below.
  R_INT_PTR(actual_len)[0] = needed;
  if (R_AS_INT(buf_len) >= needed) {
    std::memcpy(R_CHAR_PTR(dest), src, str_len);
  }
  return dest;
}

Guolin Ke's avatar
Guolin Ke committed
48
/*!
 * \brief Copy the last LightGBM error message (including its NUL) into `err_msg`.
 *        `actual_len` receives the required size; see EncodeChar.
 */
LGBM_SE LGBM_GetLastError_R(LGBM_SE buf_len, LGBM_SE actual_len, LGBM_SE err_msg) {
  const char* last_error = LGBM_GetLastError();
  const size_t bytes_with_nul = std::strlen(last_error) + 1;
  return EncodeChar(err_msg, last_error, buf_len, actual_len, bytes_with_nul);
}

Guolin Ke's avatar
Guolin Ke committed
52
53
54
55
56
/*!
 * \brief Build a Dataset from a file and store the new handle in `out`.
 * \param reference handle forwarded as the reference dataset
 */
LGBM_SE LGBM_DatasetCreateFromFile_R(LGBM_SE filename,
  LGBM_SE parameters,
  LGBM_SE reference,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const char* path = R_CHAR_PTR(filename);
  const char* params = R_CHAR_PTR(parameters);
  DatasetHandle dataset = nullptr;
  CHECK_CALL(LGBM_DatasetCreateFromFile(path, params, R_GET_PTR(reference), &dataset));
  R_SET_PTR(out, dataset);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
65
66
67
68
69
70
71
72
73
74
/*!
 * \brief Build a Dataset from a CSC sparse matrix (int32 indices, float64
 *        values) and store the new handle in `out`.
 */
LGBM_SE LGBM_DatasetCreateFromCSC_R(LGBM_SE indptr,
  LGBM_SE indices,
  LGBM_SE data,
  LGBM_SE num_indptr,
  LGBM_SE nelem,
  LGBM_SE num_row,
  LGBM_SE parameters,
  LGBM_SE reference,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const int* col_ptr = R_INT_PTR(indptr);
  const int* row_idx = R_INT_PTR(indices);
  const double* values = R_REAL_PTR(data);
  // The C API takes 64-bit counts; R supplies 32-bit integers.
  const auto n_indptr = static_cast<int64_t>(R_AS_INT(num_indptr));
  const auto n_elem = static_cast<int64_t>(R_AS_INT(nelem));
  const auto n_row = static_cast<int64_t>(R_AS_INT(num_row));
  DatasetHandle dataset = nullptr;
  CHECK_CALL(LGBM_DatasetCreateFromCSC(col_ptr, C_API_DTYPE_INT32, row_idx,
    values, C_API_DTYPE_FLOAT64, n_indptr, n_elem, n_row,
    R_CHAR_PTR(parameters), R_GET_PTR(reference), &dataset));
  R_SET_PTR(out, dataset);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
91
92
93
94
95
96
97
/*!
 * \brief Build a Dataset from a dense column-major float64 matrix and store
 *        the new handle in `out`.
 */
LGBM_SE LGBM_DatasetCreateFromMat_R(LGBM_SE data,
  LGBM_SE num_row,
  LGBM_SE num_col,
  LGBM_SE parameters,
  LGBM_SE reference,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const double* matrix = R_REAL_PTR(data);
  const auto n_row = static_cast<int32_t>(R_AS_INT(num_row));
  const auto n_col = static_cast<int32_t>(R_AS_INT(num_col));
  DatasetHandle dataset = nullptr;
  CHECK_CALL(LGBM_DatasetCreateFromMat(matrix, C_API_DTYPE_FLOAT64, n_row, n_col, COL_MAJOR,
    R_CHAR_PTR(parameters), R_GET_PTR(reference), &dataset));
  R_SET_PTR(out, dataset);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
109
110
111
112
113
114
/*!
 * \brief Create a subset of `handle` from R's 1-based row indices and store
 *        the new Dataset handle in `out`.
 */
LGBM_SE LGBM_DatasetGetSubset_R(LGBM_SE handle,
  LGBM_SE used_row_indices,
  LGBM_SE len_used_row_indices,
  LGBM_SE parameters,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const int n_used = R_AS_INT(len_used_row_indices);
  std::vector<int> zero_based(n_used);
  const int* one_based = R_INT_PTR(used_row_indices);
  // R indices start at 1; the C API expects 0-based rows.
#pragma omp parallel for schedule(static, 512) if (n_used >= 1024)
  for (int k = 0; k < n_used; ++k) {
    zero_based[k] = one_based[k] - 1;
  }
  DatasetHandle subset = nullptr;
  CHECK_CALL(LGBM_DatasetGetSubset(R_GET_PTR(handle), zero_based.data(), n_used,
    R_CHAR_PTR(parameters), &subset));
  R_SET_PTR(out, subset);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
131
132
133
/*!
 * \brief Set feature names from a single tab-separated string.
 */
LGBM_SE LGBM_DatasetSetFeatureNames_R(LGBM_SE handle,
  LGBM_SE feature_names,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const auto names = Common::Split(R_CHAR_PTR(feature_names), '\t');
  // The C API wants an array of C strings; these borrow from `names`.
  std::vector<const char*> name_ptrs;
  name_ptrs.reserve(names.size());
  for (const auto& name : names) {
    name_ptrs.push_back(name.c_str());
  }
  CHECK_CALL(LGBM_DatasetSetFeatureNames(R_GET_PTR(handle),
    name_ptrs.data(), static_cast<int>(name_ptrs.size())));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
146
147
148
149
150
/*!
 * \brief Return all feature names joined by tabs into `feature_names`.
 *        `actual_len` receives the required size; see EncodeChar.
 */
LGBM_SE LGBM_DatasetGetFeatureNames_R(LGBM_SE handle,
  LGBM_SE buf_len,
  LGBM_SE actual_len,
  LGBM_SE feature_names,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int num_features = 0;
  CHECK_CALL(LGBM_DatasetGetNumFeature(R_GET_PTR(handle), &num_features));
  // One zero-filled 256-byte scratch buffer per feature name.
  // NOTE(review): nothing here bounds a name to 256 bytes; confirm the C API
  // cannot write a longer name into these buffers.
  std::vector<std::vector<char>> storage(num_features, std::vector<char>(256));
  std::vector<char*> name_ptrs(num_features);
  for (int k = 0; k < num_features; ++k) {
    name_ptrs[k] = storage[k].data();
  }
  int returned = 0;
  CHECK_CALL(LGBM_DatasetGetFeatureNames(R_GET_PTR(handle),
    name_ptrs.data(), &returned));
  CHECK(num_features == returned);
  const auto joined = Common::Join<char*>(name_ptrs, "\t");
  EncodeChar(feature_names, joined.c_str(), buf_len, actual_len, joined.size() + 1);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
169
170
171
/*!
 * \brief Save the Dataset behind `handle` to `filename` in LightGBM's binary format.
 */
LGBM_SE LGBM_DatasetSaveBinary_R(LGBM_SE handle,
  LGBM_SE filename,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const char* path = R_CHAR_PTR(filename);
  CHECK_CALL(LGBM_DatasetSaveBinary(R_GET_PTR(handle), path));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
178
179
/*!
 * \brief Free the Dataset behind `handle`, if any, and clear the stored
 *        pointer so later calls are no-ops.
 */
LGBM_SE LGBM_DatasetFree_R(LGBM_SE handle,
  LGBM_SE call_state) {
  R_API_BEGIN();
  DatasetHandle dataset = R_GET_PTR(handle);
  if (dataset != nullptr) {
    CHECK_CALL(LGBM_DatasetFree(dataset));
    R_SET_PTR(handle, nullptr);
  }
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
188
189
190
191
192
/*!
 * \brief Set a Dataset field from R data.
 *
 * How the R data is passed depends on the field name:
 *  - "group"/"query": read as R integers and passed as int32.
 *  - "init_score": R doubles passed through as float64.
 *  - anything else: R doubles narrowed to float32.
 */
LGBM_SE LGBM_DatasetSetField_R(LGBM_SE handle,
  LGBM_SE field_name,
  LGBM_SE field_data,
  LGBM_SE num_element,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const int n = static_cast<int>(R_AS_INT(num_element));
  const char* name = R_CHAR_PTR(field_name);
  const bool is_group = std::strcmp("group", name) == 0 || std::strcmp("query", name) == 0;
  if (is_group) {
    std::vector<int32_t> buffer(n);
    const int* src = R_INT_PTR(field_data);
#pragma omp parallel for schedule(static, 512) if (n >= 1024)
    for (int k = 0; k < n; ++k) {
      buffer[k] = static_cast<int32_t>(src[k]);
    }
    CHECK_CALL(LGBM_DatasetSetField(R_GET_PTR(handle), name, buffer.data(), n, C_API_DTYPE_INT32));
  } else if (std::strcmp("init_score", name) == 0) {
    CHECK_CALL(LGBM_DatasetSetField(R_GET_PTR(handle), name, R_REAL_PTR(field_data), n, C_API_DTYPE_FLOAT64));
  } else {
    std::vector<float> buffer(n);
    const double* src = R_REAL_PTR(field_data);
#pragma omp parallel for schedule(static, 512) if (n >= 1024)
    for (int k = 0; k < n; ++k) {
      buffer[k] = static_cast<float>(src[k]);
    }
    CHECK_CALL(LGBM_DatasetSetField(R_GET_PTR(handle), name, buffer.data(), n, C_API_DTYPE_FLOAT32));
  }
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
216
217
218
219
/*!
 * \brief Copy a Dataset field into the R vector `field_data`.
 *
 * "group"/"query" come back from the C API as int32 boundaries and are turned
 * into per-group sizes (adjacent differences). "init_score" is read as
 * float64. Any other field is read as float32. The destination must already
 * be sized by the caller.
 */
LGBM_SE LGBM_DatasetGetField_R(LGBM_SE handle,
  LGBM_SE field_name,
  LGBM_SE field_data,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const char* name = R_CHAR_PTR(field_name);
  int n_stored = 0;
  int dtype = 0;
  const void* raw = nullptr;
  CHECK_CALL(LGBM_DatasetGetField(R_GET_PTR(handle), name, &n_stored, &raw, &dtype));

  if (std::strcmp("group", name) == 0 || std::strcmp("query", name) == 0) {
    const int32_t* bounds = reinterpret_cast<const int32_t*>(raw);
    int* sizes = R_INT_PTR(field_data);
#pragma omp parallel for schedule(static, 512) if (n_stored >= 1024)
    for (int k = 0; k < n_stored - 1; ++k) {
      sizes[k] = bounds[k + 1] - bounds[k];
    }
  } else if (std::strcmp("init_score", name) == 0) {
    const double* src = reinterpret_cast<const double*>(raw);
    double* dst = R_REAL_PTR(field_data);
#pragma omp parallel for schedule(static, 512) if (n_stored >= 1024)
    for (int k = 0; k < n_stored; ++k) {
      dst[k] = src[k];
    }
  } else {
    const float* src = reinterpret_cast<const float*>(raw);
    double* dst = R_REAL_PTR(field_data);
#pragma omp parallel for schedule(static, 512) if (n_stored >= 1024)
    for (int k = 0; k < n_stored; ++k) {
      dst[k] = src[k];
    }
  }
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
250
251
252
253
/*!
 * \brief Store in out[0] the number of elements LGBM_DatasetGetField_R will
 *        write for this field. Boundary fields ("group"/"query") report one
 *        fewer than the C API's length.
 */
LGBM_SE LGBM_DatasetGetFieldSize_R(LGBM_SE handle,
  LGBM_SE field_name,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const char* name = R_CHAR_PTR(field_name);
  int n_stored = 0;
  int dtype = 0;
  const void* unused_data = nullptr;
  CHECK_CALL(LGBM_DatasetGetField(R_GET_PTR(handle), name, &n_stored, &unused_data, &dtype));
  const bool is_boundary_field = std::strcmp("group", name) == 0 || std::strcmp("query", name) == 0;
  R_INT_PTR(out)[0] = is_boundary_field ? n_stored - 1 : n_stored;
  R_API_END();
}

267
268
/*!
 * \brief Forward an old/new parameter-string pair to
 *        LGBM_DatasetUpdateParamChecking; failures are reported via call_state.
 */
LGBM_SE LGBM_DatasetUpdateParamChecking_R(LGBM_SE old_params,
  LGBM_SE new_params,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const char* previous = R_CHAR_PTR(old_params);
  const char* updated = R_CHAR_PTR(new_params);
  CHECK_CALL(LGBM_DatasetUpdateParamChecking(previous, updated));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
275
276
/*!
 * \brief Store the Dataset's number of data rows in out[0].
 */
LGBM_SE LGBM_DatasetGetNumData_R(LGBM_SE handle, LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int num_data = 0;
  CHECK_CALL(LGBM_DatasetGetNumData(R_GET_PTR(handle), &num_data));
  R_INT_PTR(out)[0] = num_data;
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
284
285
286
/*!
 * \brief Store the Dataset's number of features in out[0].
 */
LGBM_SE LGBM_DatasetGetNumFeature_R(LGBM_SE handle,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int num_features = 0;
  CHECK_CALL(LGBM_DatasetGetNumFeature(R_GET_PTR(handle), &num_features));
  R_INT_PTR(out)[0] = num_features;
  R_API_END();
}

// --- start Booster interfaces

Guolin Ke's avatar
Guolin Ke committed
296
297
/*!
 * \brief Free the Booster behind `handle`, if any, and clear the stored
 *        pointer so later calls are no-ops.
 */
LGBM_SE LGBM_BoosterFree_R(LGBM_SE handle,
  LGBM_SE call_state) {
  R_API_BEGIN();
  BoosterHandle booster = R_GET_PTR(handle);
  if (booster != nullptr) {
    CHECK_CALL(LGBM_BoosterFree(booster));
    R_SET_PTR(handle, nullptr);
  }
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
306
307
308
309
/*!
 * \brief Create a Booster over `train_data` and store its handle in `out`.
 */
LGBM_SE LGBM_BoosterCreate_R(LGBM_SE train_data,
  LGBM_SE parameters,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const char* params = R_CHAR_PTR(parameters);
  BoosterHandle booster = nullptr;
  CHECK_CALL(LGBM_BoosterCreate(R_GET_PTR(train_data), params, &booster));
  R_SET_PTR(out, booster);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
317
318
319
/*!
 * \brief Load a Booster from a model file and store its handle in `out`.
 *        The iteration count reported by the C API is discarded.
 */
LGBM_SE LGBM_BoosterCreateFromModelfile_R(LGBM_SE filename,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int loaded_iterations = 0;
  BoosterHandle booster = nullptr;
  CHECK_CALL(LGBM_BoosterCreateFromModelfile(R_CHAR_PTR(filename), &loaded_iterations, &booster));
  R_SET_PTR(out, booster);
  R_API_END();
}

328
329
330
331
332
/*!
 * \brief Load a Booster from an in-memory model string and store its handle
 *        in `out`. The iteration count reported by the C API is discarded.
 */
LGBM_SE LGBM_BoosterLoadModelFromString_R(LGBM_SE model_str,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int loaded_iterations = 0;
  BoosterHandle booster = nullptr;
  CHECK_CALL(LGBM_BoosterLoadModelFromString(R_CHAR_PTR(model_str), &loaded_iterations, &booster));
  R_SET_PTR(out, booster);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
339
340
341
/*!
 * \brief Forward to LGBM_BoosterMerge with `handle` and `other_handle`.
 */
LGBM_SE LGBM_BoosterMerge_R(LGBM_SE handle,
  LGBM_SE other_handle,
  LGBM_SE call_state) {
  R_API_BEGIN();
  BoosterHandle target = R_GET_PTR(handle);
  BoosterHandle source = R_GET_PTR(other_handle);
  CHECK_CALL(LGBM_BoosterMerge(target, source));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
347
348
349
/*!
 * \brief Register `valid_data` as a validation Dataset on the Booster.
 */
LGBM_SE LGBM_BoosterAddValidData_R(LGBM_SE handle,
  LGBM_SE valid_data,
  LGBM_SE call_state) {
  R_API_BEGIN();
  BoosterHandle booster = R_GET_PTR(handle);
  DatasetHandle validation = R_GET_PTR(valid_data);
  CHECK_CALL(LGBM_BoosterAddValidData(booster, validation));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
355
356
357
/*!
 * \brief Replace the Booster's training Dataset with `train_data`.
 */
LGBM_SE LGBM_BoosterResetTrainingData_R(LGBM_SE handle,
  LGBM_SE train_data,
  LGBM_SE call_state) {
  R_API_BEGIN();
  BoosterHandle booster = R_GET_PTR(handle);
  DatasetHandle training = R_GET_PTR(train_data);
  CHECK_CALL(LGBM_BoosterResetTrainingData(booster, training));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
363
364
365
/*!
 * \brief Apply a new parameter string to the Booster.
 */
LGBM_SE LGBM_BoosterResetParameter_R(LGBM_SE handle,
  LGBM_SE parameters,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const char* params = R_CHAR_PTR(parameters);
  CHECK_CALL(LGBM_BoosterResetParameter(R_GET_PTR(handle), params));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
371
372
373
/*!
 * \brief Store the Booster's number of classes in out[0].
 */
LGBM_SE LGBM_BoosterGetNumClasses_R(LGBM_SE handle,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int classes = 0;
  CHECK_CALL(LGBM_BoosterGetNumClasses(R_GET_PTR(handle), &classes));
  R_INT_PTR(out)[0] = classes;
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
381
382
/*!
 * \brief Run one boosting iteration. The C API's is_finished flag is not
 *        reported back to R.
 */
LGBM_SE LGBM_BoosterUpdateOneIter_R(LGBM_SE handle,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int finished_flag = 0;
  CHECK_CALL(LGBM_BoosterUpdateOneIter(R_GET_PTR(handle), &finished_flag));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
389
390
391
392
393
/*!
 * \brief Run one boosting iteration with caller-supplied gradients/hessians.
 *
 * The first `len` R doubles of `grad` and `hess` are narrowed to float
 * before being passed to the C API. The is_finished flag is not reported back.
 */
LGBM_SE LGBM_BoosterUpdateOneIterCustom_R(LGBM_SE handle,
  LGBM_SE grad,
  LGBM_SE hess,
  LGBM_SE len,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const int n = R_AS_INT(len);
  std::vector<float> grad_f(n);
  std::vector<float> hess_f(n);
  const double* grad_src = R_REAL_PTR(grad);
  const double* hess_src = R_REAL_PTR(hess);
#pragma omp parallel for schedule(static, 512) if (n >= 1024)
  for (int k = 0; k < n; ++k) {
    grad_f[k] = static_cast<float>(grad_src[k]);
    hess_f[k] = static_cast<float>(hess_src[k]);
  }
  int finished_flag = 0;
  CHECK_CALL(LGBM_BoosterUpdateOneIterCustom(R_GET_PTR(handle), grad_f.data(), hess_f.data(), &finished_flag));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
407
408
/*!
 * \brief Forward to LGBM_BoosterRollbackOneIter for this Booster.
 */
LGBM_SE LGBM_BoosterRollbackOneIter_R(LGBM_SE handle,
  LGBM_SE call_state) {
  R_API_BEGIN();
  BoosterHandle booster = R_GET_PTR(handle);
  CHECK_CALL(LGBM_BoosterRollbackOneIter(booster));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
414
415
416
/*!
 * \brief Store the Booster's current iteration in out[0].
 */
LGBM_SE LGBM_BoosterGetCurrentIteration_R(LGBM_SE handle,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int current = 0;
  CHECK_CALL(LGBM_BoosterGetCurrentIteration(R_GET_PTR(handle), &current));
  R_INT_PTR(out)[0] = current;
  R_API_END();
}

424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
/*!
 * \brief Write the Booster's upper bound value into out_result[0].
 */
LGBM_SE LGBM_BoosterGetUpperBoundValue_R(LGBM_SE handle,
  LGBM_SE out_result,
  LGBM_SE call_state) {
  R_API_BEGIN();
  CHECK_CALL(LGBM_BoosterGetUpperBoundValue(R_GET_PTR(handle), R_REAL_PTR(out_result)));
  R_API_END();
}

/*!
 * \brief Write the Booster's lower bound value into out_result[0].
 */
LGBM_SE LGBM_BoosterGetLowerBoundValue_R(LGBM_SE handle,
  LGBM_SE out_result,
  LGBM_SE call_state) {
  R_API_BEGIN();
  CHECK_CALL(LGBM_BoosterGetLowerBoundValue(R_GET_PTR(handle), R_REAL_PTR(out_result)));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
442
443
444
445
446
/*!
 * \brief Return the Booster's evaluation metric names joined by tabs into
 *        `eval_names`. `actual_len` receives the required size; see EncodeChar.
 */
LGBM_SE LGBM_BoosterGetEvalNames_R(LGBM_SE handle,
  LGBM_SE buf_len,
  LGBM_SE actual_len,
  LGBM_SE eval_names,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int num_evals = 0;
  CHECK_CALL(LGBM_BoosterGetEvalCounts(R_GET_PTR(handle), &num_evals));
  // One zero-filled 128-byte scratch buffer per metric name.
  // NOTE(review): nothing here bounds a name to 128 bytes; confirm the C API
  // cannot write a longer name into these buffers.
  std::vector<std::vector<char>> storage(num_evals, std::vector<char>(128));
  std::vector<char*> name_ptrs(num_evals);
  for (int k = 0; k < num_evals; ++k) {
    name_ptrs[k] = storage[k].data();
  }
  int returned = 0;
  CHECK_CALL(LGBM_BoosterGetEvalNames(R_GET_PTR(handle), &returned, name_ptrs.data()));
  CHECK(returned == num_evals);
  const auto joined = Common::Join<char*>(name_ptrs, "\t");
  EncodeChar(eval_names, joined.c_str(), buf_len, actual_len, joined.size() + 1);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
464
465
466
467
/*!
 * \brief Write the evaluation results for dataset `data_idx` into
 *        `out_result`, which must hold at least as many doubles as there are
 *        eval metrics.
 */
LGBM_SE LGBM_BoosterGetEval_R(LGBM_SE handle,
  LGBM_SE data_idx,
  LGBM_SE out_result,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int expected = 0;
  CHECK_CALL(LGBM_BoosterGetEvalCounts(R_GET_PTR(handle), &expected));
  double* results = R_REAL_PTR(out_result);
  int returned = 0;
  CHECK_CALL(LGBM_BoosterGetEval(R_GET_PTR(handle), R_AS_INT(data_idx), &returned, results));
  // The number of values written must match the metric count.
  CHECK(returned == expected);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
478
479
480
481
/*!
 * \brief Store in out[0] the number of predictions held for dataset `data_idx`.
 *
 * The C API reports an int64_t. The result is written as a 32-bit R integer,
 * matching LGBM_BoosterCalcNumPredict_R. The previous R_INT64_PTR write put
 * 8 bytes into what is assumed to be an R integer slot, and reads back
 * incorrectly on big-endian hosts. Counts that cannot be represented as an R
 * integer raise an error instead of being truncated.
 */
LGBM_SE LGBM_BoosterGetNumPredict_R(LGBM_SE handle,
  LGBM_SE data_idx,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int64_t len = 0;
  CHECK_CALL(LGBM_BoosterGetNumPredict(R_GET_PTR(handle), R_AS_INT(data_idx), &len));
  if (len > INT32_MAX) {
    Log::Fatal("Number of predictions exceeds the maximum supported in R-package");
  }
  R_INT_PTR(out)[0] = static_cast<int>(len);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
489
490
491
492
/*!
 * \brief Copy the stored predictions for dataset `data_idx` into
 *        `out_result`. The length reported by the C API is not returned.
 */
LGBM_SE LGBM_BoosterGetPredict_R(LGBM_SE handle,
  LGBM_SE data_idx,
  LGBM_SE out_result,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int64_t n_written = 0;
  CHECK_CALL(LGBM_BoosterGetPredict(R_GET_PTR(handle), R_AS_INT(data_idx),
    &n_written, R_REAL_PTR(out_result)));
  R_API_END();
}

500
/*!
 * \brief Map R's prediction flags to a C_API_PREDICT_* constant.
 *
 * When several flags are set, precedence is:
 * contrib > leaf index > raw score > normal.
 */
int GetPredictType(LGBM_SE is_rawscore, LGBM_SE is_leafidx, LGBM_SE is_predcontrib) {
  if (R_AS_INT(is_predcontrib)) {
    return C_API_PREDICT_CONTRIB;
  }
  if (R_AS_INT(is_leafidx)) {
    return C_API_PREDICT_LEAF_INDEX;
  }
  if (R_AS_INT(is_rawscore)) {
    return C_API_PREDICT_RAW_SCORE;
  }
  return C_API_PREDICT_NORMAL;
}

Guolin Ke's avatar
Guolin Ke committed
514
515
516
517
518
/*!
 * \brief Predict for the rows in `data_filename` and write the results to
 *        `result_filename`. The prediction type is derived by GetPredictType.
 */
LGBM_SE LGBM_BoosterPredictForFile_R(LGBM_SE handle,
  LGBM_SE data_filename,
  LGBM_SE data_has_header,
  LGBM_SE is_rawscore,
  LGBM_SE is_leafidx,
  LGBM_SE is_predcontrib,
  LGBM_SE num_iteration,
  LGBM_SE parameter,
  LGBM_SE result_filename,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const int predict_kind = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  CHECK_CALL(LGBM_BoosterPredictForFile(R_GET_PTR(handle),
    R_CHAR_PTR(data_filename),
    R_AS_INT(data_has_header),
    predict_kind,
    R_AS_INT(num_iteration),
    R_CHAR_PTR(parameter),
    R_CHAR_PTR(result_filename)));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
532
533
534
535
/*!
 * \brief Store in out_len[0] the number of prediction values that `num_row`
 *        rows would produce for the requested prediction type.
 *
 * The C API reports an int64_t, but the result is written as a 32-bit R
 * integer. Counts above INT32_MAX now raise an error (as EncodeChar does for
 * oversized strings) instead of being silently truncated. A truncated count
 * would presumably lead the caller to allocate a too-small prediction buffer.
 */
LGBM_SE LGBM_BoosterCalcNumPredict_R(LGBM_SE handle,
  LGBM_SE num_row,
  LGBM_SE is_rawscore,
  LGBM_SE is_leafidx,
  LGBM_SE is_predcontrib,
  LGBM_SE num_iteration,
  LGBM_SE out_len,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  int64_t len = 0;
  CHECK_CALL(LGBM_BoosterCalcNumPredict(R_GET_PTR(handle), R_AS_INT(num_row),
    pred_type, R_AS_INT(num_iteration), &len));
  if (len > INT32_MAX) {
    Log::Fatal("Number of predictions exceeds the maximum supported in R-package");
  }
  R_INT_PTR(out_len)[0] = static_cast<int>(len);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
549
550
551
552
553
554
555
556
557
/*!
 * \brief Predict for a CSC sparse matrix (int32 indices, float64 values),
 *        writing the results into the caller-sized `out_result`.
 */
LGBM_SE LGBM_BoosterPredictForCSC_R(LGBM_SE handle,
  LGBM_SE indptr,
  LGBM_SE indices,
  LGBM_SE data,
  LGBM_SE num_indptr,
  LGBM_SE nelem,
  LGBM_SE num_row,
  LGBM_SE is_rawscore,
  LGBM_SE is_leafidx,
  LGBM_SE is_predcontrib,
  LGBM_SE num_iteration,
  LGBM_SE parameter,
  LGBM_SE out_result,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const int predict_kind = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  const int* col_ptr = R_INT_PTR(indptr);
  const int* row_idx = R_INT_PTR(indices);
  const double* values = R_REAL_PTR(data);
  const auto n_indptr = static_cast<int64_t>(R_AS_INT(num_indptr));
  const auto n_elem = static_cast<int64_t>(R_AS_INT(nelem));
  const auto n_row = static_cast<int64_t>(R_AS_INT(num_row));
  double* predictions = R_REAL_PTR(out_result);
  int64_t n_written = 0;
  CHECK_CALL(LGBM_BoosterPredictForCSC(R_GET_PTR(handle),
    col_ptr, C_API_DTYPE_INT32, row_idx,
    values, C_API_DTYPE_FLOAT64, n_indptr, n_elem, n_row,
    predict_kind, R_AS_INT(num_iteration), R_CHAR_PTR(parameter),
    &n_written, predictions));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
582
583
584
585
586
587
/*!
 * \brief Predict for a dense column-major float64 matrix, writing the results
 *        into the caller-sized `out_result`.
 */
LGBM_SE LGBM_BoosterPredictForMat_R(LGBM_SE handle,
  LGBM_SE data,
  LGBM_SE num_row,
  LGBM_SE num_col,
  LGBM_SE is_rawscore,
  LGBM_SE is_leafidx,
  LGBM_SE is_predcontrib,
  LGBM_SE num_iteration,
  LGBM_SE parameter,
  LGBM_SE out_result,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const int predict_kind = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  const int32_t n_row = R_AS_INT(num_row);
  const int32_t n_col = R_AS_INT(num_col);
  const double* matrix = R_REAL_PTR(data);
  double* predictions = R_REAL_PTR(out_result);
  int64_t n_written = 0;
  CHECK_CALL(LGBM_BoosterPredictForMat(R_GET_PTR(handle),
    matrix, C_API_DTYPE_FLOAT64, n_row, n_col, COL_MAJOR,
    predict_kind, R_AS_INT(num_iteration), R_CHAR_PTR(parameter),
    &n_written, predictions));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
609
610
611
612
/*!
 * \brief Save the Booster's model to `filename`.
 *        NOTE(review): the literal 0 is presumably the start iteration; confirm against c_api.h.
 */
LGBM_SE LGBM_BoosterSaveModel_R(LGBM_SE handle,
  LGBM_SE num_iteration,
  LGBM_SE filename,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const int iterations = R_AS_INT(num_iteration);
  const char* path = R_CHAR_PTR(filename);
  CHECK_CALL(LGBM_BoosterSaveModel(R_GET_PTR(handle), 0, iterations, path));
  R_API_END();
}

618
619
620
621
622
623
624
/*!
 * \brief Serialize the Booster's model into `out_str`.
 *
 * The model is written into a scratch buffer of `buffer_len` bytes, then
 * handed to EncodeChar. EncodeChar copies it only if it fits and always
 * reports the C API's length through `actual_len`.
 * NOTE(review): the literal 0 is presumably the start iteration; confirm against c_api.h.
 */
LGBM_SE LGBM_BoosterSaveModelToString_R(LGBM_SE handle,
  LGBM_SE num_iteration,
  LGBM_SE buffer_len,
  LGBM_SE actual_len,
  LGBM_SE out_str,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const int capacity = R_AS_INT(buffer_len);
  int64_t model_len = 0;
  std::vector<char> scratch(capacity);
  CHECK_CALL(LGBM_BoosterSaveModelToString(R_GET_PTR(handle), 0, R_AS_INT(num_iteration),
    capacity, &model_len, scratch.data()));
  EncodeChar(out_str, scratch.data(), buffer_len, actual_len, static_cast<size_t>(model_len));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
632
633
634
635
636
637
/*!
 * \brief Dump the Booster's model into `out_str`.
 *
 * The dump is written into a scratch buffer of `buffer_len` bytes, then
 * handed to EncodeChar. EncodeChar copies it only if it fits and always
 * reports the C API's length through `actual_len`.
 * NOTE(review): the literal 0 is presumably the start iteration; confirm against c_api.h.
 */
LGBM_SE LGBM_BoosterDumpModel_R(LGBM_SE handle,
  LGBM_SE num_iteration,
  LGBM_SE buffer_len,
  LGBM_SE actual_len,
  LGBM_SE out_str,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const int capacity = R_AS_INT(buffer_len);
  int64_t dump_len = 0;
  std::vector<char> scratch(capacity);
  CHECK_CALL(LGBM_BoosterDumpModel(R_GET_PTR(handle), 0, R_AS_INT(num_iteration),
    capacity, &dump_len, scratch.data()));
  EncodeChar(out_str, scratch.data(), buffer_len, actual_len, static_cast<size_t>(dump_len));
  R_API_END();
}