lightgbm_R.cpp 19 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2017 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
#include <LightGBM/lightgbm_R.h>

7
#include <LightGBM/utils/common.h>
Guolin Ke's avatar
Guolin Ke committed
8
9
10
11
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>
#include <LightGBM/utils/text_reader.h>

Guolin Ke's avatar
Guolin Ke committed
12
13
#include <string>
#include <cstdio>
14
#include <cstring>
Guolin Ke's avatar
Guolin Ke committed
15
#include <memory>
16
17
#include <utility>
#include <vector>
Guolin Ke's avatar
Guolin Ke committed
18
19
20
21
22
23
24
25
26
27
28
29
30

// Layout flag handed to the matrix-based C API calls in this file (it is the
// argument that follows the column count).
// NOTE(review): presumably 0 selects column-major storage, matching how R lays
// out matrices -- confirm against the C API header.
#define COL_MAJOR (0)

// Opens the try block that wraps the body of every exported wrapper in this file.
#define R_API_BEGIN() \
  try {
// Closes the try block opened by R_API_BEGIN(). Any exception escaping the body
// is turned into an error state: call_state[0] is set to -1, the message is
// recorded through LGBM_SetLastError and call_state is returned. The handlers
// only read the exception objects, so they bind by const reference.
// When nothing is thrown, call_state is returned as-is.
#define R_API_END() } \
  catch(const std::exception& ex) { R_INT_PTR(call_state)[0] = -1; LGBM_SetLastError(ex.what()); return call_state;} \
  catch(const std::string& ex) { R_INT_PTR(call_state)[0] = -1; LGBM_SetLastError(ex.c_str()); return call_state; } \
  catch(...) { R_INT_PTR(call_state)[0] = -1; LGBM_SetLastError("unknown exception"); return call_state;} \
  return call_state;

// Evaluates the C API call `x`. On a non-zero return code the call is marked as
// failed (call_state[0] = -1) and call_state is returned from the enclosing
// function. Wrapped in do { } while (0) so that the macro expands to exactly one
// statement and stays correct inside an unbraced if/else; callers terminate it
// with a semicolon.
#define CHECK_CALL(x) \
  do { \
    if ((x) != 0) { \
      R_INT_PTR(call_state)[0] = -1; \
      return call_state; \
    } \
  } while (0)

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
37
/*!
 * \brief Copy `str_len` bytes from `src` into the buffer behind `dest`.
 * \param dest Destination buffer (accessed through R_CHAR_PTR)
 * \param src Source bytes
 * \param buf_len Capacity of `dest` (read through R_AS_INT)
 * \param actual_len Receives `str_len` in element 0, even when nothing is copied
 * \param str_len Number of bytes to copy; values above INT32_MAX are rejected via Log::Fatal
 * \return `dest`; it is left untouched when `buf_len` is smaller than `str_len`
 */
LGBM_SE EncodeChar(LGBM_SE dest, const char* src, LGBM_SE buf_len, LGBM_SE actual_len, size_t str_len) {
  if (str_len > INT32_MAX) {
    Log::Fatal("Don't support large string in R-package");
  }
  const int required = static_cast<int>(str_len);
  // always report the needed size so the caller can retry with a larger buffer
  R_INT_PTR(actual_len)[0] = required;
  const bool fits = R_AS_INT(buf_len) >= required;
  if (fits) {
    std::memcpy(R_CHAR_PTR(dest), src, str_len);
  }
  return dest;
}

Guolin Ke's avatar
Guolin Ke committed
50
/*!
 * \brief Copy the last error message (including its terminating NUL) into `err_msg`.
 * \param buf_len Capacity of `err_msg`
 * \param actual_len Receives the number of bytes required for the message
 * \param err_msg Destination buffer
 * \return `err_msg`, as returned by EncodeChar
 */
LGBM_SE LGBM_GetLastError_R(LGBM_SE buf_len, LGBM_SE actual_len, LGBM_SE err_msg) {
  // Fetch the message once so the pointer and the measured length are
  // guaranteed to describe the same string.
  const char* message = LGBM_GetLastError();
  return EncodeChar(err_msg, message, buf_len, actual_len, std::strlen(message) + 1);
}

Guolin Ke's avatar
Guolin Ke committed
54
55
56
57
58
/*!
 * \brief Forward to LGBM_DatasetCreateFromFile and store the new handle in `out`.
 * \param filename Path string (read through R_CHAR_PTR)
 * \param parameters Parameter string (read through R_CHAR_PTR)
 * \param reference Handle forwarded as the reference argument
 * \param out Receives the created dataset handle
 * \param call_state Element 0 is set to -1 when the call fails or throws
 * \return `call_state`
 */
LGBM_SE LGBM_DatasetCreateFromFile_R(LGBM_SE filename,
  LGBM_SE parameters,
  LGBM_SE reference,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const char* path = R_CHAR_PTR(filename);
  const char* params = R_CHAR_PTR(parameters);
  DatasetHandle dataset = nullptr;
  CHECK_CALL(LGBM_DatasetCreateFromFile(path, params, R_GET_PTR(reference), &dataset));
  R_SET_PTR(out, dataset);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
67
68
69
70
71
72
73
74
75
76
/*!
 * \brief Forward sparse input to LGBM_DatasetCreateFromCSC and store the new handle in `out`.
 * \param indptr Read as an int array (passed with C_API_DTYPE_INT32)
 * \param indices Read as an int array
 * \param data Read as a double array (passed with C_API_DTYPE_FLOAT64)
 * \param num_indptr Length of `indptr`, read through R_AS_INT
 * \param nelem Number of elements, read through R_AS_INT
 * \param num_row Number of rows, read through R_AS_INT
 * \param parameters Parameter string
 * \param reference Handle forwarded as the reference argument
 * \param out Receives the created dataset handle
 * \param call_state Element 0 is set to -1 when the call fails or throws
 * \return `call_state`
 */
LGBM_SE LGBM_DatasetCreateFromCSC_R(LGBM_SE indptr,
  LGBM_SE indices,
  LGBM_SE data,
  LGBM_SE num_indptr,
  LGBM_SE nelem,
  LGBM_SE num_row,
  LGBM_SE parameters,
  LGBM_SE reference,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  // sizes arrive as R integers and are widened to the 64-bit values handed to the C API
  const int64_t col_ptr_count = static_cast<int64_t>(R_AS_INT(num_indptr));
  const int64_t value_count = static_cast<int64_t>(R_AS_INT(nelem));
  const int64_t row_count = static_cast<int64_t>(R_AS_INT(num_row));

  const int* col_ptr = R_INT_PTR(indptr);
  const int* row_idx = R_INT_PTR(indices);
  const double* values = R_REAL_PTR(data);

  DatasetHandle dataset = nullptr;
  CHECK_CALL(LGBM_DatasetCreateFromCSC(col_ptr, C_API_DTYPE_INT32, row_idx,
    values, C_API_DTYPE_FLOAT64, col_ptr_count, value_count,
    row_count, R_CHAR_PTR(parameters), R_GET_PTR(reference), &dataset));
  R_SET_PTR(out, dataset);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
93
94
95
96
97
98
99
/*!
 * \brief Forward a dense double matrix to LGBM_DatasetCreateFromMat and store the new handle in `out`.
 * \param data Read as a double array (passed with C_API_DTYPE_FLOAT64 and COL_MAJOR)
 * \param num_row Number of rows, read through R_AS_INT
 * \param num_col Number of columns, read through R_AS_INT
 * \param parameters Parameter string
 * \param reference Handle forwarded as the reference argument
 * \param out Receives the created dataset handle
 * \param call_state Element 0 is set to -1 when the call fails or throws
 * \return `call_state`
 */
LGBM_SE LGBM_DatasetCreateFromMat_R(LGBM_SE data,
  LGBM_SE num_row,
  LGBM_SE num_col,
  LGBM_SE parameters,
  LGBM_SE reference,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  double* matrix = R_REAL_PTR(data);
  const int32_t row_count = static_cast<int32_t>(R_AS_INT(num_row));
  const int32_t col_count = static_cast<int32_t>(R_AS_INT(num_col));
  DatasetHandle dataset = nullptr;
  CHECK_CALL(LGBM_DatasetCreateFromMat(matrix, C_API_DTYPE_FLOAT64, row_count, col_count, COL_MAJOR,
    R_CHAR_PTR(parameters), R_GET_PTR(reference), &dataset));
  R_SET_PTR(out, dataset);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
111
112
113
114
115
116
/*!
 * \brief Forward to LGBM_DatasetGetSubset after shifting the row indices down by one.
 * \param handle Source dataset handle
 * \param used_row_indices Read as an int array of one-based row indices
 * \param len_used_row_indices Number of indices, read through R_AS_INT
 * \param parameters Parameter string
 * \param out Receives the subset dataset handle
 * \param call_state Element 0 is set to -1 when the call fails or throws
 * \return `call_state`
 */
LGBM_SE LGBM_DatasetGetSubset_R(LGBM_SE handle,
  LGBM_SE used_row_indices,
  LGBM_SE len_used_row_indices,
  LGBM_SE parameters,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const int num_used = R_AS_INT(len_used_row_indices);
  const int* one_based = R_INT_PTR(used_row_indices);
  std::vector<int> zero_based(num_used);
  // the incoming indices start at 1; the values passed on start at 0
#pragma omp parallel for schedule(static, 512) if (num_used >= 1024)
  for (int row = 0; row < num_used; ++row) {
    zero_based[row] = one_based[row] - 1;
  }
  DatasetHandle subset = nullptr;
  CHECK_CALL(LGBM_DatasetGetSubset(R_GET_PTR(handle),
    zero_based.data(), num_used, R_CHAR_PTR(parameters),
    &subset));
  R_SET_PTR(out, subset);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
133
134
135
/*!
 * \brief Split a tab-separated name string and forward the pieces to LGBM_DatasetSetFeatureNames.
 * \param handle Dataset handle
 * \param feature_names Single string holding the names separated by '\t'
 * \param call_state Element 0 is set to -1 when the call fails or throws
 * \return `call_state`
 */
LGBM_SE LGBM_DatasetSetFeatureNames_R(LGBM_SE handle,
  LGBM_SE feature_names,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const auto names = Common::Split(R_CHAR_PTR(feature_names), '\t');
  // the C API wants an array of C strings; `names` keeps the storage alive
  std::vector<const char*> name_ptrs;
  name_ptrs.reserve(names.size());
  for (const auto& name : names) {
    name_ptrs.push_back(name.c_str());
  }
  const int num_names = static_cast<int>(names.size());
  CHECK_CALL(LGBM_DatasetSetFeatureNames(R_GET_PTR(handle),
    name_ptrs.data(), num_names));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
148
149
150
151
152
/*!
 * \brief Fetch all feature names and return them joined by '\t' in `feature_names`.
 * \param handle Dataset handle
 * \param buf_len Capacity of `feature_names`
 * \param actual_len Receives the number of bytes required (joined string plus NUL)
 * \param feature_names Destination buffer, filled by EncodeChar
 * \param call_state Element 0 is set to -1 when a call fails or throws
 * \return `call_state`
 * \note Each name is received into a 256-byte scratch buffer.
 *       NOTE(review): whether the C API bounds its writes to that size is not
 *       visible here -- confirm.
 */
LGBM_SE LGBM_DatasetGetFeatureNames_R(LGBM_SE handle,
  LGBM_SE buf_len,
  LGBM_SE actual_len,
  LGBM_SE feature_names,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int len = 0;
  CHECK_CALL(LGBM_DatasetGetNumFeature(R_GET_PTR(handle), &len));
  // one fixed-size scratch buffer per feature, plus the raw pointer array the C API fills
  std::vector<std::vector<char>> scratch(len);
  std::vector<char*> scratch_ptrs(len);
  for (int feature = 0; feature < len; ++feature) {
    auto& buffer = scratch[feature];
    buffer.resize(256);
    scratch_ptrs[feature] = buffer.data();
  }
  int out_len = 0;
  CHECK_CALL(LGBM_DatasetGetFeatureNames(R_GET_PTR(handle),
    scratch_ptrs.data(), &out_len));
  CHECK_EQ(len, out_len);
  const auto joined = Common::Join<char*>(scratch_ptrs, "\t");
  EncodeChar(feature_names, joined.c_str(), buf_len, actual_len, joined.size() + 1);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
171
172
173
/*!
 * \brief Forward to LGBM_DatasetSaveBinary.
 * \param handle Dataset handle
 * \param filename Target path string
 * \param call_state Element 0 is set to -1 when the call fails or throws
 * \return `call_state`
 */
LGBM_SE LGBM_DatasetSaveBinary_R(LGBM_SE handle,
  LGBM_SE filename,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const char* path = R_CHAR_PTR(filename);
  CHECK_CALL(LGBM_DatasetSaveBinary(R_GET_PTR(handle), path));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
180
181
/*!
 * \brief Release the dataset behind `handle` and reset the stored pointer to null.
 * \param handle Dataset handle; nothing is done when it already holds a null pointer
 * \param call_state Element 0 is set to -1 when the call fails or throws
 * \return `call_state`
 */
LGBM_SE LGBM_DatasetFree_R(LGBM_SE handle,
  LGBM_SE call_state) {
  R_API_BEGIN();
  if (R_GET_PTR(handle) == nullptr) {
    // already released (or never set): nothing to do
    return call_state;
  }
  CHECK_CALL(LGBM_DatasetFree(R_GET_PTR(handle)));
  R_SET_PTR(handle, nullptr);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
190
191
192
193
194
/*!
 * \brief Forward a field to LGBM_DatasetSetField, converting its element type by field name.
 *
 * "group" / "query" are read as ints and sent as INT32, "init_score" is sent
 * directly as FLOAT64, and every other field is read as doubles and narrowed
 * to FLOAT32.
 *
 * \param handle Dataset handle
 * \param field_name Name of the field
 * \param field_data Field values (int or double array, depending on the name)
 * \param num_element Number of values, read through R_AS_INT
 * \param call_state Element 0 is set to -1 when the call fails or throws
 * \return `call_state`
 */
LGBM_SE LGBM_DatasetSetField_R(LGBM_SE handle,
  LGBM_SE field_name,
  LGBM_SE field_data,
  LGBM_SE num_element,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const int count = static_cast<int>(R_AS_INT(num_element));
  const char* name = R_CHAR_PTR(field_name);
  const bool is_group_field = std::strcmp("group", name) == 0 || std::strcmp("query", name) == 0;
  const bool is_init_score = std::strcmp("init_score", name) == 0;

  if (is_group_field) {
    const int* source = R_INT_PTR(field_data);
    std::vector<int32_t> as_int32(count);
#pragma omp parallel for schedule(static, 512) if (count >= 1024)
    for (int k = 0; k < count; ++k) {
      as_int32[k] = static_cast<int32_t>(source[k]);
    }
    CHECK_CALL(LGBM_DatasetSetField(R_GET_PTR(handle), name, as_int32.data(), count, C_API_DTYPE_INT32));
  } else if (is_init_score) {
    // already double precision: handed over without a copy
    CHECK_CALL(LGBM_DatasetSetField(R_GET_PTR(handle), name, R_REAL_PTR(field_data), count, C_API_DTYPE_FLOAT64));
  } else {
    const double* source = R_REAL_PTR(field_data);
    std::vector<float> as_float(count);
#pragma omp parallel for schedule(static, 512) if (count >= 1024)
    for (int k = 0; k < count; ++k) {
      as_float[k] = static_cast<float>(source[k]);
    }
    CHECK_CALL(LGBM_DatasetSetField(R_GET_PTR(handle), name, as_float.data(), count, C_API_DTYPE_FLOAT32));
  }
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
218
219
220
221
/*!
 * \brief Read a field through LGBM_DatasetGetField and copy it into `field_data`.
 *
 * "group" / "query" data is treated as int32 boundaries and written as the
 * differences of consecutive entries (one element fewer than returned).
 * "init_score" is copied as doubles. Every other field is treated as float
 * and widened to double.
 *
 * \param handle Dataset handle
 * \param field_name Name of the field
 * \param field_data Destination array (int for group/query, double otherwise)
 * \param call_state Element 0 is set to -1 when the call fails or throws
 * \return `call_state`
 */
LGBM_SE LGBM_DatasetGetField_R(LGBM_SE handle,
  LGBM_SE field_name,
  LGBM_SE field_data,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const char* name = R_CHAR_PTR(field_name);
  int count = 0;
  int dtype = 0;
  const void* raw = nullptr;
  CHECK_CALL(LGBM_DatasetGetField(R_GET_PTR(handle), name, &count, &raw, &dtype));

  const bool is_group_field = std::strcmp("group", name) == 0 || std::strcmp("query", name) == 0;
  const bool is_init_score = std::strcmp("init_score", name) == 0;

  if (is_group_field) {
    const auto* boundaries = reinterpret_cast<const int32_t*>(raw);
    int* sizes = R_INT_PTR(field_data);
    // turn cumulative boundaries into per-group sizes
#pragma omp parallel for schedule(static, 512) if (count >= 1024)
    for (int k = 0; k < count - 1; ++k) {
      sizes[k] = boundaries[k + 1] - boundaries[k];
    }
  } else if (is_init_score) {
    const auto* scores = reinterpret_cast<const double*>(raw);
    double* target = R_REAL_PTR(field_data);
#pragma omp parallel for schedule(static, 512) if (count >= 1024)
    for (int k = 0; k < count; ++k) {
      target[k] = scores[k];
    }
  } else {
    const auto* values = reinterpret_cast<const float*>(raw);
    double* target = R_REAL_PTR(field_data);
#pragma omp parallel for schedule(static, 512) if (count >= 1024)
    for (int k = 0; k < count; ++k) {
      target[k] = values[k];
    }
  }
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
252
253
254
255
/*!
 * \brief Report the element count of a field as obtained from LGBM_DatasetGetField.
 * \param handle Dataset handle
 * \param field_name Name of the field
 * \param out Receives the count in element 0; one less than reported for "group" / "query"
 * \param call_state Element 0 is set to -1 when the call fails or throws
 * \return `call_state`
 */
LGBM_SE LGBM_DatasetGetFieldSize_R(LGBM_SE handle,
  LGBM_SE field_name,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const char* name = R_CHAR_PTR(field_name);
  int count = 0;
  int dtype = 0;
  const void* raw = nullptr;
  CHECK_CALL(LGBM_DatasetGetField(R_GET_PTR(handle), name, &count, &raw, &dtype));
  const bool is_group_field = std::strcmp("group", name) == 0 || std::strcmp("query", name) == 0;
  // group/query results are consumed as differences, hence one element fewer
  const int reported = is_group_field ? count - 1 : count;
  R_INT_PTR(out)[0] = static_cast<int>(reported);
  R_API_END();
}

269
270
/*!
 * \brief Forward to LGBM_DatasetUpdateParamChecking.
 * \param old_params Parameter string passed as the first argument
 * \param new_params Parameter string passed as the second argument
 * \param call_state Element 0 is set to -1 when the call fails or throws
 * \return `call_state`
 */
LGBM_SE LGBM_DatasetUpdateParamChecking_R(LGBM_SE old_params,
  LGBM_SE new_params,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const char* previous = R_CHAR_PTR(old_params);
  const char* updated = R_CHAR_PTR(new_params);
  CHECK_CALL(LGBM_DatasetUpdateParamChecking(previous, updated));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
277
278
/*!
 * \brief Forward to LGBM_DatasetGetNumData and store the result in `out[0]`.
 * \param handle Dataset handle
 * \param out Receives the value in element 0
 * \param call_state Element 0 is set to -1 when the call fails or throws
 * \return `call_state`
 */
LGBM_SE LGBM_DatasetGetNumData_R(LGBM_SE handle, LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int num_data;
  CHECK_CALL(LGBM_DatasetGetNumData(R_GET_PTR(handle), &num_data));
  R_INT_PTR(out)[0] = static_cast<int>(num_data);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
286
287
288
/*!
 * \brief Forward to LGBM_DatasetGetNumFeature and store the result in `out[0]`.
 * \param handle Dataset handle
 * \param out Receives the value in element 0
 * \param call_state Element 0 is set to -1 when the call fails or throws
 * \return `call_state`
 */
LGBM_SE LGBM_DatasetGetNumFeature_R(LGBM_SE handle,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int num_features;
  CHECK_CALL(LGBM_DatasetGetNumFeature(R_GET_PTR(handle), &num_features));
  R_INT_PTR(out)[0] = static_cast<int>(num_features);
  R_API_END();
}

// --- start Booster interfaces

Guolin Ke's avatar
Guolin Ke committed
298
299
/*!
 * \brief Release the booster behind `handle` and reset the stored pointer to null.
 * \param handle Booster handle; nothing is done when it already holds a null pointer
 * \param call_state Element 0 is set to -1 when the call fails or throws
 * \return `call_state`
 */
LGBM_SE LGBM_BoosterFree_R(LGBM_SE handle,
  LGBM_SE call_state) {
  R_API_BEGIN();
  if (R_GET_PTR(handle) == nullptr) {
    // already released (or never set): nothing to do
    return call_state;
  }
  CHECK_CALL(LGBM_BoosterFree(R_GET_PTR(handle)));
  R_SET_PTR(handle, nullptr);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
308
309
310
311
/*!
 * \brief Forward to LGBM_BoosterCreate and store the new handle in `out`.
 * \param train_data Dataset handle passed as training data
 * \param parameters Parameter string
 * \param out Receives the created booster handle
 * \param call_state Element 0 is set to -1 when the call fails or throws
 * \return `call_state`
 */
LGBM_SE LGBM_BoosterCreate_R(LGBM_SE train_data,
  LGBM_SE parameters,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const char* params = R_CHAR_PTR(parameters);
  BoosterHandle booster = nullptr;
  CHECK_CALL(LGBM_BoosterCreate(R_GET_PTR(train_data), params, &booster));
  R_SET_PTR(out, booster);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
319
320
321
/*!
 * \brief Forward to LGBM_BoosterCreateFromModelfile and store the new handle in `out`.
 * \param filename Path string of the model file
 * \param out Receives the created booster handle
 * \param call_state Element 0 is set to -1 when the call fails or throws
 * \return `call_state`
 * \note The iteration count reported by the C API is received but not passed on.
 */
LGBM_SE LGBM_BoosterCreateFromModelfile_R(LGBM_SE filename,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  BoosterHandle booster = nullptr;
  int ignored_num_iterations = 0;
  CHECK_CALL(LGBM_BoosterCreateFromModelfile(R_CHAR_PTR(filename), &ignored_num_iterations, &booster));
  R_SET_PTR(out, booster);
  R_API_END();
}

330
331
332
333
334
/*!
 * \brief Forward to LGBM_BoosterLoadModelFromString and store the new handle in `out`.
 * \param model_str Model text
 * \param out Receives the created booster handle
 * \param call_state Element 0 is set to -1 when the call fails or throws
 * \return `call_state`
 * \note The iteration count reported by the C API is received but not passed on.
 */
LGBM_SE LGBM_BoosterLoadModelFromString_R(LGBM_SE model_str,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  BoosterHandle booster = nullptr;
  int ignored_num_iterations = 0;
  CHECK_CALL(LGBM_BoosterLoadModelFromString(R_CHAR_PTR(model_str), &ignored_num_iterations, &booster));
  R_SET_PTR(out, booster);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
341
342
343
/*!
 * \brief Forward to LGBM_BoosterMerge.
 * \param handle Booster handle passed as the first argument
 * \param other_handle Booster handle passed as the second argument
 * \param call_state Element 0 is set to -1 when the call fails or throws
 * \return `call_state`
 */
LGBM_SE LGBM_BoosterMerge_R(LGBM_SE handle,
  LGBM_SE other_handle,
  LGBM_SE call_state) {
  R_API_BEGIN();
  CHECK_CALL(LGBM_BoosterMerge(
    R_GET_PTR(handle),
    R_GET_PTR(other_handle)));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
349
350
351
/*!
 * \brief Forward to LGBM_BoosterAddValidData.
 * \param handle Booster handle
 * \param valid_data Dataset handle passed as validation data
 * \param call_state Element 0 is set to -1 when the call fails or throws
 * \return `call_state`
 */
LGBM_SE LGBM_BoosterAddValidData_R(LGBM_SE handle,
  LGBM_SE valid_data,
  LGBM_SE call_state) {
  R_API_BEGIN();
  CHECK_CALL(LGBM_BoosterAddValidData(
    R_GET_PTR(handle),
    R_GET_PTR(valid_data)));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
357
358
359
/*!
 * \brief Forward to LGBM_BoosterResetTrainingData.
 * \param handle Booster handle
 * \param train_data Dataset handle passed as the new training data
 * \param call_state Element 0 is set to -1 when the call fails or throws
 * \return `call_state`
 */
LGBM_SE LGBM_BoosterResetTrainingData_R(LGBM_SE handle,
  LGBM_SE train_data,
  LGBM_SE call_state) {
  R_API_BEGIN();
  CHECK_CALL(LGBM_BoosterResetTrainingData(
    R_GET_PTR(handle),
    R_GET_PTR(train_data)));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
365
366
367
/*!
 * \brief Forward to LGBM_BoosterResetParameter.
 * \param handle Booster handle
 * \param parameters Parameter string
 * \param call_state Element 0 is set to -1 when the call fails or throws
 * \return `call_state`
 */
LGBM_SE LGBM_BoosterResetParameter_R(LGBM_SE handle,
  LGBM_SE parameters,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const char* params = R_CHAR_PTR(parameters);
  CHECK_CALL(LGBM_BoosterResetParameter(R_GET_PTR(handle), params));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
373
374
375
/*!
 * \brief Forward to LGBM_BoosterGetNumClasses and store the result in `out[0]`.
 * \param handle Booster handle
 * \param out Receives the value in element 0
 * \param call_state Element 0 is set to -1 when the call fails or throws
 * \return `call_state`
 */
LGBM_SE LGBM_BoosterGetNumClasses_R(LGBM_SE handle,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int class_count;
  CHECK_CALL(LGBM_BoosterGetNumClasses(R_GET_PTR(handle), &class_count));
  R_INT_PTR(out)[0] = static_cast<int>(class_count);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
383
384
/*!
 * \brief Forward to LGBM_BoosterUpdateOneIter.
 * \param handle Booster handle
 * \param call_state Element 0 is set to -1 when the call fails or throws
 * \return `call_state`
 * \note The "finished" flag written by the C API is received but not passed on.
 */
LGBM_SE LGBM_BoosterUpdateOneIter_R(LGBM_SE handle,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int finished_flag = 0;
  CHECK_CALL(LGBM_BoosterUpdateOneIter(R_GET_PTR(handle), &finished_flag));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
391
392
393
394
395
/*!
 * \brief Narrow gradients and hessians to float and forward them to LGBM_BoosterUpdateOneIterCustom.
 * \param handle Booster handle
 * \param grad Read as a double array
 * \param hess Read as a double array
 * \param len Number of elements in `grad` and `hess`, read through R_AS_INT
 * \param call_state Element 0 is set to -1 when the call fails or throws
 * \return `call_state`
 * \note The "finished" flag written by the C API is received but not passed on.
 */
LGBM_SE LGBM_BoosterUpdateOneIterCustom_R(LGBM_SE handle,
  LGBM_SE grad,
  LGBM_SE hess,
  LGBM_SE len,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const int count = R_AS_INT(len);
  std::vector<float> grad_f32(count);
  std::vector<float> hess_f32(count);
  const double* grad_src = R_REAL_PTR(grad);
  const double* hess_src = R_REAL_PTR(hess);
  // the inputs are double precision; the C API call below is given float arrays
#pragma omp parallel for schedule(static, 512) if (count >= 1024)
  for (int k = 0; k < count; ++k) {
    grad_f32[k] = static_cast<float>(grad_src[k]);
    hess_f32[k] = static_cast<float>(hess_src[k]);
  }
  int finished_flag = 0;
  CHECK_CALL(LGBM_BoosterUpdateOneIterCustom(R_GET_PTR(handle), grad_f32.data(), hess_f32.data(), &finished_flag));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
409
410
/*!
 * \brief Forward to LGBM_BoosterRollbackOneIter.
 * \param handle Booster handle
 * \param call_state Element 0 is set to -1 when the call fails or throws
 * \return `call_state`
 */
LGBM_SE LGBM_BoosterRollbackOneIter_R(LGBM_SE handle,
  LGBM_SE call_state) {
  R_API_BEGIN();
  CHECK_CALL(LGBM_BoosterRollbackOneIter(
    R_GET_PTR(handle)));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
416
417
418
/*!
 * \brief Forward to LGBM_BoosterGetCurrentIteration and store the result in `out[0]`.
 * \param handle Booster handle
 * \param out Receives the value in element 0
 * \param call_state Element 0 is set to -1 when the call fails or throws
 * \return `call_state`
 */
LGBM_SE LGBM_BoosterGetCurrentIteration_R(LGBM_SE handle,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int iteration;
  CHECK_CALL(LGBM_BoosterGetCurrentIteration(R_GET_PTR(handle), &iteration));
  R_INT_PTR(out)[0] = static_cast<int>(iteration);
  R_API_END();
}

426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
/*!
 * \brief Forward to LGBM_BoosterGetUpperBoundValue, which writes into `out_result`.
 * \param handle Booster handle
 * \param out_result Double storage handed to the C API as its output pointer
 * \param call_state Element 0 is set to -1 when the call fails or throws
 * \return `call_state`
 */
LGBM_SE LGBM_BoosterGetUpperBoundValue_R(LGBM_SE handle,
  LGBM_SE out_result,
  LGBM_SE call_state) {
  R_API_BEGIN();
  double* upper_bound = R_REAL_PTR(out_result);
  CHECK_CALL(LGBM_BoosterGetUpperBoundValue(R_GET_PTR(handle), upper_bound));
  R_API_END();
}

/*!
 * \brief Forward to LGBM_BoosterGetLowerBoundValue, which writes into `out_result`.
 * \param handle Booster handle
 * \param out_result Double storage handed to the C API as its output pointer
 * \param call_state Element 0 is set to -1 when the call fails or throws
 * \return `call_state`
 */
LGBM_SE LGBM_BoosterGetLowerBoundValue_R(LGBM_SE handle,
  LGBM_SE out_result,
  LGBM_SE call_state) {
  R_API_BEGIN();
  double* lower_bound = R_REAL_PTR(out_result);
  CHECK_CALL(LGBM_BoosterGetLowerBoundValue(R_GET_PTR(handle), lower_bound));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
444
445
446
447
448
/*!
 * \brief Fetch all evaluation names and return them joined by '\t' in `eval_names`.
 * \param handle Booster handle
 * \param buf_len Capacity of `eval_names`
 * \param actual_len Receives the number of bytes required (joined string plus NUL)
 * \param eval_names Destination buffer, filled by EncodeChar
 * \param call_state Element 0 is set to -1 when a call fails or throws
 * \return `call_state`
 * \note Each name is received into a 128-byte scratch buffer.
 *       NOTE(review): whether the C API bounds its writes to that size is not
 *       visible here -- confirm.
 */
LGBM_SE LGBM_BoosterGetEvalNames_R(LGBM_SE handle,
  LGBM_SE buf_len,
  LGBM_SE actual_len,
  LGBM_SE eval_names,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int len;
  CHECK_CALL(LGBM_BoosterGetEvalCounts(R_GET_PTR(handle), &len));
  // one fixed-size scratch buffer per metric, plus the raw pointer array the C API fills
  std::vector<std::vector<char>> scratch(len);
  std::vector<char*> scratch_ptrs(len);
  for (int metric = 0; metric < len; ++metric) {
    auto& buffer = scratch[metric];
    buffer.resize(128);
    scratch_ptrs[metric] = buffer.data();
  }
  int out_len;
  CHECK_CALL(LGBM_BoosterGetEvalNames(R_GET_PTR(handle), &out_len, scratch_ptrs.data()));
  CHECK_EQ(out_len, len);
  const auto joined = Common::Join<char*>(scratch_ptrs, "\t");
  EncodeChar(eval_names, joined.c_str(), buf_len, actual_len, joined.size() + 1);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
466
467
468
469
/*!
 * \brief Forward to LGBM_BoosterGetEval, which writes the results into `out_result`.
 * \param handle Booster handle
 * \param data_idx Data index, read through R_AS_INT
 * \param out_result Double array handed to the C API as its output buffer
 * \param call_state Element 0 is set to -1 when a call fails or throws
 * \return `call_state`
 * \note The number of written values is checked against LGBM_BoosterGetEvalCounts.
 */
LGBM_SE LGBM_BoosterGetEval_R(LGBM_SE handle,
  LGBM_SE data_idx,
  LGBM_SE out_result,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int len;
  CHECK_CALL(LGBM_BoosterGetEvalCounts(R_GET_PTR(handle), &len));
  int out_len;
  double* results = R_REAL_PTR(out_result);
  const int which_data = R_AS_INT(data_idx);
  CHECK_CALL(LGBM_BoosterGetEval(R_GET_PTR(handle), which_data, &out_len, results));
  CHECK_EQ(out_len, len);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
480
481
482
483
/*!
 * \brief Forward to LGBM_BoosterGetNumPredict and store the 64-bit result in `out[0]`.
 * \param handle Booster handle
 * \param data_idx Data index, read through R_AS_INT
 * \param out Receives the value in element 0 (written through R_INT64_PTR)
 * \param call_state Element 0 is set to -1 when the call fails or throws
 * \return `call_state`
 */
LGBM_SE LGBM_BoosterGetNumPredict_R(LGBM_SE handle,
  LGBM_SE data_idx,
  LGBM_SE out,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const int which_data = R_AS_INT(data_idx);
  int64_t num_predict;
  CHECK_CALL(LGBM_BoosterGetNumPredict(R_GET_PTR(handle), which_data, &num_predict));
  R_INT64_PTR(out)[0] = num_predict;
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
491
492
493
494
/*!
 * \brief Forward to LGBM_BoosterGetPredict, which writes the predictions into `out_result`.
 * \param handle Booster handle
 * \param data_idx Data index, read through R_AS_INT
 * \param out_result Double array handed to the C API as its output buffer
 * \param call_state Element 0 is set to -1 when the call fails or throws
 * \return `call_state`
 * \note The length reported by the C API is received but not passed on.
 */
LGBM_SE LGBM_BoosterGetPredict_R(LGBM_SE handle,
  LGBM_SE data_idx,
  LGBM_SE out_result,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int64_t ignored_len;
  double* predictions = R_REAL_PTR(out_result);
  const int which_data = R_AS_INT(data_idx);
  CHECK_CALL(LGBM_BoosterGetPredict(R_GET_PTR(handle), which_data, &ignored_len, predictions));
  R_API_END();
}

502
/*!
 * \brief Map three flags to a C_API_PREDICT_* constant.
 *
 * When several flags are set the precedence is: is_predcontrib, then
 * is_leafidx, then is_rawscore. With no flag set the result is
 * C_API_PREDICT_NORMAL.
 *
 * \param is_rawscore Non-zero selects C_API_PREDICT_RAW_SCORE
 * \param is_leafidx Non-zero selects C_API_PREDICT_LEAF_INDEX
 * \param is_predcontrib Non-zero selects C_API_PREDICT_CONTRIB
 * \return The selected prediction type
 */
int GetPredictType(LGBM_SE is_rawscore, LGBM_SE is_leafidx, LGBM_SE is_predcontrib) {
  // checked from highest to lowest precedence
  if (R_AS_INT(is_predcontrib)) {
    return C_API_PREDICT_CONTRIB;
  }
  if (R_AS_INT(is_leafidx)) {
    return C_API_PREDICT_LEAF_INDEX;
  }
  if (R_AS_INT(is_rawscore)) {
    return C_API_PREDICT_RAW_SCORE;
  }
  return C_API_PREDICT_NORMAL;
}

Guolin Ke's avatar
Guolin Ke committed
516
517
518
519
520
/*!
 * \brief Forward to LGBM_BoosterPredictForFile.
 * \param handle Booster handle
 * \param data_filename Path string of the input data
 * \param data_has_header Header flag, read through R_AS_INT
 * \param is_rawscore Flag forwarded to GetPredictType
 * \param is_leafidx Flag forwarded to GetPredictType
 * \param is_predcontrib Flag forwarded to GetPredictType
 * \param num_iteration Iteration count, read through R_AS_INT
 * \param parameter Parameter string
 * \param result_filename Path string of the output file
 * \param call_state Element 0 is set to -1 when the call fails or throws
 * \return `call_state`
 */
LGBM_SE LGBM_BoosterPredictForFile_R(LGBM_SE handle,
  LGBM_SE data_filename,
  LGBM_SE data_has_header,
  LGBM_SE is_rawscore,
  LGBM_SE is_leafidx,
  LGBM_SE is_predcontrib,
  LGBM_SE num_iteration,
  LGBM_SE parameter,
  LGBM_SE result_filename,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const int predict_kind = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  const char* input_path = R_CHAR_PTR(data_filename);
  const char* output_path = R_CHAR_PTR(result_filename);
  CHECK_CALL(LGBM_BoosterPredictForFile(R_GET_PTR(handle), input_path,
    R_AS_INT(data_has_header), predict_kind, R_AS_INT(num_iteration), R_CHAR_PTR(parameter),
    output_path));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
534
535
536
537
/*!
 * \brief Forward to LGBM_BoosterCalcNumPredict and store the result in `out_len[0]`.
 * \param handle Booster handle
 * \param num_row Number of rows, read through R_AS_INT
 * \param is_rawscore Flag forwarded to GetPredictType
 * \param is_leafidx Flag forwarded to GetPredictType
 * \param is_predcontrib Flag forwarded to GetPredictType
 * \param num_iteration Iteration count, read through R_AS_INT
 * \param out_len Receives the value in element 0 as a 32-bit int
 * \param call_state Element 0 is set to -1 when the call fails or throws
 * \return `call_state`
 * \note The C API reports a 64-bit count while `out_len` holds an int. A count
 *       above INT32_MAX is reported as an error through Log::Fatal instead of
 *       being silently truncated.
 */
LGBM_SE LGBM_BoosterCalcNumPredict_R(LGBM_SE handle,
  LGBM_SE num_row,
  LGBM_SE is_rawscore,
  LGBM_SE is_leafidx,
  LGBM_SE is_predcontrib,
  LGBM_SE num_iteration,
  LGBM_SE out_len,
  LGBM_SE call_state) {
  R_API_BEGIN();
  int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
  int64_t len = 0;
  CHECK_CALL(LGBM_BoosterCalcNumPredict(R_GET_PTR(handle), R_AS_INT(num_row),
    pred_type, R_AS_INT(num_iteration), &len));
  // a truncated count would make the caller size its result buffer too small
  if (len > INT32_MAX) {
    Log::Fatal("Don't support large prediction result in R-package");
  }
  R_INT_PTR(out_len)[0] = static_cast<int>(len);
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
551
552
553
554
555
556
557
558
559
/*!
 * \brief Forward sparse input to LGBM_BoosterPredictForCSC, which writes into `out_result`.
 * \param handle Booster handle
 * \param indptr Read as an int array (passed with C_API_DTYPE_INT32)
 * \param indices Read as an int array
 * \param data Read as a double array (passed with C_API_DTYPE_FLOAT64)
 * \param num_indptr Length of `indptr`, read through R_AS_INT
 * \param nelem Number of elements, read through R_AS_INT
 * \param num_row Number of rows, read through R_AS_INT
 * \param is_rawscore Flag forwarded to GetPredictType
 * \param is_leafidx Flag forwarded to GetPredictType
 * \param is_predcontrib Flag forwarded to GetPredictType
 * \param num_iteration Iteration count, read through R_AS_INT
 * \param parameter Parameter string
 * \param out_result Double array handed to the C API as its output buffer
 * \param call_state Element 0 is set to -1 when the call fails or throws
 * \return `call_state`
 * \note The length reported by the C API is received but not passed on.
 */
LGBM_SE LGBM_BoosterPredictForCSC_R(LGBM_SE handle,
  LGBM_SE indptr,
  LGBM_SE indices,
  LGBM_SE data,
  LGBM_SE num_indptr,
  LGBM_SE nelem,
  LGBM_SE num_row,
  LGBM_SE is_rawscore,
  LGBM_SE is_leafidx,
  LGBM_SE is_predcontrib,
  LGBM_SE num_iteration,
  LGBM_SE parameter,
  LGBM_SE out_result,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const int predict_kind = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);

  // sizes arrive as R integers and are widened to the 64-bit values handed to the C API
  const int64_t col_ptr_count = R_AS_INT(num_indptr);
  const int64_t value_count = R_AS_INT(nelem);
  const int64_t row_count = R_AS_INT(num_row);

  const int* col_ptr = R_INT_PTR(indptr);
  const int* row_idx = R_INT_PTR(indices);
  const double* values = R_REAL_PTR(data);

  double* predictions = R_REAL_PTR(out_result);
  int64_t ignored_len;
  CHECK_CALL(LGBM_BoosterPredictForCSC(R_GET_PTR(handle),
    col_ptr, C_API_DTYPE_INT32, row_idx,
    values, C_API_DTYPE_FLOAT64, col_ptr_count, value_count,
    row_count, predict_kind, R_AS_INT(num_iteration), R_CHAR_PTR(parameter), &ignored_len, predictions));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
584
585
586
587
588
589
/*!
 * \brief Forward a dense double matrix to LGBM_BoosterPredictForMat, which writes into `out_result`.
 * \param handle Booster handle
 * \param data Read as a double array (passed with C_API_DTYPE_FLOAT64 and COL_MAJOR)
 * \param num_row Number of rows, read through R_AS_INT
 * \param num_col Number of columns, read through R_AS_INT
 * \param is_rawscore Flag forwarded to GetPredictType
 * \param is_leafidx Flag forwarded to GetPredictType
 * \param is_predcontrib Flag forwarded to GetPredictType
 * \param num_iteration Iteration count, read through R_AS_INT
 * \param parameter Parameter string
 * \param out_result Double array handed to the C API as its output buffer
 * \param call_state Element 0 is set to -1 when the call fails or throws
 * \return `call_state`
 * \note The length reported by the C API is received but not passed on.
 */
LGBM_SE LGBM_BoosterPredictForMat_R(LGBM_SE handle,
  LGBM_SE data,
  LGBM_SE num_row,
  LGBM_SE num_col,
  LGBM_SE is_rawscore,
  LGBM_SE is_leafidx,
  LGBM_SE is_predcontrib,
  LGBM_SE num_iteration,
  LGBM_SE parameter,
  LGBM_SE out_result,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const int predict_kind = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);

  const int32_t row_count = R_AS_INT(num_row);
  const int32_t col_count = R_AS_INT(num_col);

  double* matrix = R_REAL_PTR(data);
  double* predictions = R_REAL_PTR(out_result);
  int64_t ignored_len;
  CHECK_CALL(LGBM_BoosterPredictForMat(R_GET_PTR(handle),
    matrix, C_API_DTYPE_FLOAT64, row_count, col_count, COL_MAJOR,
    predict_kind, R_AS_INT(num_iteration), R_CHAR_PTR(parameter), &ignored_len, predictions));

  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
611
612
613
614
/*!
 * \brief Forward to LGBM_BoosterSaveModel.
 * \param handle Booster handle
 * \param num_iteration Iteration count, read through R_AS_INT
 * \param filename Target path string
 * \param call_state Element 0 is set to -1 when the call fails or throws
 * \return `call_state`
 * \note NOTE(review): the literal 0 passed as the second argument is presumably
 *       the start iteration -- confirm against the C API header.
 */
LGBM_SE LGBM_BoosterSaveModel_R(LGBM_SE handle,
  LGBM_SE num_iteration,
  LGBM_SE filename,
  LGBM_SE call_state) {
  R_API_BEGIN();
  const int iterations = R_AS_INT(num_iteration);
  const char* path = R_CHAR_PTR(filename);
  CHECK_CALL(LGBM_BoosterSaveModel(R_GET_PTR(handle), 0, iterations, path));
  R_API_END();
}

620
621
622
623
624
625
626
/*!
 * \brief Serialize the model through LGBM_BoosterSaveModelToString and copy it into `out_str`.
 * \param handle Booster handle
 * \param num_iteration Iteration count, read through R_AS_INT
 * \param buffer_len Capacity of `out_str`, read through R_AS_INT
 * \param actual_len Receives the length reported by the C API
 * \param out_str Destination buffer; only written when `buffer_len` is large enough (see EncodeChar)
 * \param call_state Element 0 is set to -1 when the call fails or throws
 * \return `call_state`
 * \note NOTE(review): the literal 0 passed as the second argument is presumably
 *       the start iteration -- confirm against the C API header.
 */
LGBM_SE LGBM_BoosterSaveModelToString_R(LGBM_SE handle,
  LGBM_SE num_iteration,
  LGBM_SE buffer_len,
  LGBM_SE actual_len,
  LGBM_SE out_str,
  LGBM_SE call_state) {
  R_API_BEGIN();
  // the C API writes into a temporary buffer of the caller-provided capacity
  std::vector<char> scratch(R_AS_INT(buffer_len));
  int64_t reported_len = 0;
  CHECK_CALL(LGBM_BoosterSaveModelToString(R_GET_PTR(handle), 0, R_AS_INT(num_iteration),
    R_AS_INT(buffer_len), &reported_len, scratch.data()));
  EncodeChar(out_str, scratch.data(), buffer_len, actual_len, static_cast<size_t>(reported_len));
  R_API_END();
}

Guolin Ke's avatar
Guolin Ke committed
634
635
636
637
638
639
/*!
 * \brief Dump the model through LGBM_BoosterDumpModel and copy the text into `out_str`.
 * \param handle Booster handle
 * \param num_iteration Iteration count, read through R_AS_INT
 * \param buffer_len Capacity of `out_str`, read through R_AS_INT
 * \param actual_len Receives the length reported by the C API
 * \param out_str Destination buffer; only written when `buffer_len` is large enough (see EncodeChar)
 * \param call_state Element 0 is set to -1 when the call fails or throws
 * \return `call_state`
 * \note NOTE(review): the literal 0 passed as the second argument is presumably
 *       the start iteration -- confirm against the C API header.
 */
LGBM_SE LGBM_BoosterDumpModel_R(LGBM_SE handle,
  LGBM_SE num_iteration,
  LGBM_SE buffer_len,
  LGBM_SE actual_len,
  LGBM_SE out_str,
  LGBM_SE call_state) {
  R_API_BEGIN();
  // the C API writes into a temporary buffer of the caller-provided capacity
  std::vector<char> scratch(R_AS_INT(buffer_len));
  int64_t reported_len = 0;
  CHECK_CALL(LGBM_BoosterDumpModel(R_GET_PTR(handle), 0, R_AS_INT(num_iteration),
    R_AS_INT(buffer_len), &reported_len, scratch.data()));
  EncodeChar(out_str, scratch.data(), buffer_len, actual_len, static_cast<size_t>(reported_len));
  R_API_END();
}