c_api.cpp 38.2 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
#include <omp.h>

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/utils/threading.h>
Guolin Ke's avatar
Guolin Ke committed
6
#include <LightGBM/c_api.h>
Guolin Ke's avatar
Guolin Ke committed
7
#include <LightGBM/dataset_loader.h>
Guolin Ke's avatar
Guolin Ke committed
8
9
10
11
12
13
14
15
16
17
#include <LightGBM/dataset.h>
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/config.h>

#include <cstdio>
#include <vector>
#include <string>
#include <cstring>
Guolin Ke's avatar
Guolin Ke committed
18
#include <memory>
Guolin Ke's avatar
Guolin Ke committed
19
#include <stdexcept>
wxchan's avatar
wxchan committed
20
#include <mutex>
Guolin Ke's avatar
Guolin Ke committed
21
#include <functional>
Guolin Ke's avatar
Guolin Ke committed
22

Guolin Ke's avatar
Guolin Ke committed
23
#include "./application/predictor.hpp"
Guolin Ke's avatar
Guolin Ke committed
24
#include "./boosting/gbdt.h"
Guolin Ke's avatar
Guolin Ke committed
25

Guolin Ke's avatar
Guolin Ke committed
26
27
28
29
namespace LightGBM {

class Booster {
public:
Guolin Ke's avatar
Guolin Ke committed
30
31
  explicit Booster(const char* filename) {
    boosting_.reset(Boosting::CreateBoosting(filename));
Guolin Ke's avatar
Guolin Ke committed
32
33
  }

Guolin Ke's avatar
Guolin Ke committed
34
  Booster(const Dataset* train_data,
wxchan's avatar
wxchan committed
35
36
37
    const char* parameters) {
    auto param = ConfigBase::Str2Map(parameters);
    config_.Set(param);
38
39
40
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
41
42
    // create boosting
    if (config_.io_config.input_model.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
43
      Log::Warning("continued train from model is not support for c_api, \
Guolin Ke's avatar
Guolin Ke committed
44
45
        please use continued train with input score");
    }
Guolin Ke's avatar
Guolin Ke committed
46

wxchan's avatar
wxchan committed
47
    boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
Guolin Ke's avatar
Guolin Ke committed
48

Guolin Ke's avatar
Guolin Ke committed
49
    // initialize the boosting
Guolin Ke's avatar
Guolin Ke committed
50
    boosting_->Init(&config_.boosting_config, nullptr, objective_fun_.get(),
Guolin Ke's avatar
Guolin Ke committed
51
      Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
52
53

    ResetTrainingData(train_data);
wxchan's avatar
wxchan committed
54
55
56
57
58
  }

  void MergeFrom(const Booster* other) {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->MergeFrom(other->boosting_.get());
Guolin Ke's avatar
Guolin Ke committed
59
60
61
  }

  ~Booster() {
Guolin Ke's avatar
Guolin Ke committed
62

Guolin Ke's avatar
Guolin Ke committed
63
  }
64

wxchan's avatar
wxchan committed
65
66
67
  void ResetTrainingData(const Dataset* train_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    train_data_ = train_data;
Guolin Ke's avatar
Guolin Ke committed
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
    // create objective function
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
      config_.objective_config));
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective function");
    }
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
    }

    // create training metric
    train_metric_.clear();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(
        Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(train_data_->metadata(), train_data_->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
    // reset the boosting
Guolin Ke's avatar
Guolin Ke committed
90
    boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
wxchan's avatar
wxchan committed
91
92
93
94
      objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
  }

  void ResetConfig(const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
95
    std::lock_guard<std::mutex> lock(mutex_);
wxchan's avatar
wxchan committed
96
97
98
99
100
101
102
    auto param = ConfigBase::Str2Map(parameters);
    if (param.count("num_class")) {
      Log::Fatal("cannot change num class during training");
    }
    if (param.count("boosting_type")) {
      Log::Fatal("cannot change boosting_type during training");
    }
Guolin Ke's avatar
Guolin Ke committed
103
104
105
    if (param.count("metric")) {
      Log::Fatal("cannot change metric during training");
    }
Guolin Ke's avatar
Guolin Ke committed
106
107

    config_.Set(param);
108
109
110
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
111
112
113
114
115
116
117
118
119
120
121
122

    if (param.count("objective")) {
      // create objective function
      objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
        config_.objective_config));
      if (objective_fun_ == nullptr) {
        Log::Warning("Using self-defined objective function");
      }
      // initialize the objective function
      if (objective_fun_ != nullptr) {
        objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
      }
wxchan's avatar
wxchan committed
123
    }
Guolin Ke's avatar
Guolin Ke committed
124
125
126

    boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
      objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
127

wxchan's avatar
wxchan committed
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
  }

  void AddValidData(const Dataset* valid_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    valid_metrics_.emplace_back();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
      Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
  }
Guolin Ke's avatar
Guolin Ke committed
143

144
  bool TrainOneIter() {
wxchan's avatar
wxchan committed
145
    std::lock_guard<std::mutex> lock(mutex_);
146
147
148
149
    return boosting_->TrainOneIter(nullptr, nullptr, false);
  }

  bool TrainOneIter(const float* gradients, const float* hessians) {
wxchan's avatar
wxchan committed
150
    std::lock_guard<std::mutex> lock(mutex_);
151
152
153
    return boosting_->TrainOneIter(gradients, hessians, false);
  }

wxchan's avatar
wxchan committed
154
155
156
157
158
  void RollbackOneIter() {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->RollbackOneIter();
  }

Guolin Ke's avatar
Guolin Ke committed
159
  Predictor NewPredictor(int num_iteration, int predict_type) {
wxchan's avatar
wxchan committed
160
161
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->SetNumIterationForPred(num_iteration);
Guolin Ke's avatar
Guolin Ke committed
162
163
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
164
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
165
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
166
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
167
      is_raw_score = true;
Guolin Ke's avatar
Guolin Ke committed
168
169
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
170
    }
Guolin Ke's avatar
Guolin Ke committed
171
172
173
    // not threading safe now
    // boosting_->SetNumIterationForPred may be set by other thread during prediction. 
    return Predictor(boosting_.get(), is_raw_score, is_predict_leaf);
Guolin Ke's avatar
Guolin Ke committed
174
175
  }

Guolin Ke's avatar
Guolin Ke committed
176
  void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
177
178
179
180
181
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

  void SaveModelToFile(int num_iteration, const char* filename) {
    boosting_->SaveModelToFile(num_iteration, filename);
Guolin Ke's avatar
Guolin Ke committed
182
  }
183

184
185
  std::string DumpModel(int num_iteration) {
    return boosting_->DumpModel(num_iteration);
wxchan's avatar
wxchan committed
186
  }
187

Guolin Ke's avatar
Guolin Ke committed
188
189
190
191
192
193
194
195
196
  double GetLeafValue(int tree_idx, int leaf_idx) const {
    return dynamic_cast<GBDT*>(boosting_.get())->GetLeafValue(tree_idx, leaf_idx);
  }

  void SetLeafValue(int tree_idx, int leaf_idx, double val) {
    std::lock_guard<std::mutex> lock(mutex_);
    dynamic_cast<GBDT*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val);
  }

wxchan's avatar
wxchan committed
197
198
199
200
201
202
203
  int GetEvalCounts() const {
    int ret = 0;
    for (const auto& metric : train_metric_) {
      ret += static_cast<int>(metric->GetName().size());
    }
    return ret;
  }
204

wxchan's avatar
wxchan committed
205
206
207
208
209
210
211
212
213
214
215
216
  int GetEvalNames(char** out_strs) const {
    int idx = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
        std::strcpy(out_strs[idx], name.c_str());
        ++idx;
      }
    }
    return idx;
  }

  const Boosting* GetBoosting() const { return boosting_.get(); }
Guolin Ke's avatar
Guolin Ke committed
217

Guolin Ke's avatar
Guolin Ke committed
218
private:
219

wxchan's avatar
wxchan committed
220
  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
221
  std::unique_ptr<Boosting> boosting_;
Guolin Ke's avatar
Guolin Ke committed
222
223
224
  /*! \brief All configs */
  OverallConfig config_;
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
225
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
226
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
227
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
228
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
229
  std::unique_ptr<ObjectiveFunction> objective_fun_;
wxchan's avatar
wxchan committed
230
231
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
232
233
234
};

}
Guolin Ke's avatar
Guolin Ke committed
235
236
237

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
// helper functions used to convert input data

// Builds a callable that maps a row index to that row's values as a
// std::vector<double>, reading from a dense matrix buffer.
// NOTE(review): data_type is presumably one of the C_API_DTYPE_* codes and
// is_row_major presumably selects the memory layout - confirm at the definition.
std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major);

// Same inputs as RowFunctionFromDenseMatric, but the returned callable yields
// the row as (int, double) pairs.
// NOTE(review): the int is presumably the column index - confirm at the definition.
std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major);

// Builds a callable that maps an index to (int, double) pairs read from CSR
// style arrays (indptr / indices / data).
// NOTE(review): indptr_type and data_type are presumably C_API_DTYPE_* codes - confirm.
std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
  const void* data, int data_type, int64_t nindptr, int64_t nelem);

// Row iterator over one column of a CSC matrix.
// The constructor and both accessors are only declared here; their
// definitions live outside this part of the file.
class CSC_RowIterator {
public:
  // col_idx selects the column to iterate; the remaining arguments describe
  // the CSC arrays (column pointers, row indices, values) and their types.
  CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
    const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx);
  ~CSC_RowIterator() {}
  // Return the value at row idx; rows may only be requested in ascending order.
  double Get(int idx);
  // Return the next non-zero (row index, value) pair; a negative index means
  // there is no more data.
  std::pair<int, double> NextNonZero();
private:
  // Iteration state; exact semantics are defined by the out-of-view
  // implementation. NOTE(review): presumably nonzero_idx_ is the position of
  // the next non-zero entry and cur_idx_/cur_val_ cache the last one read.
  int nonzero_idx_ = 0;
  int cur_idx_ = -1;
  double cur_val_ = 0.0f;
  bool is_end_ = false;
  // Callable mapping an int index to an (int, double) pair.
  std::function<std::pair<int, double>(int idx)> iter_fun_;
};

// start of c_api functions

270
/*!
 * \brief Return the error message currently reported by LastErrorMsg().
 * \return C string obtained from LastErrorMsg()
 */
LIGHTGBM_C_EXPORT const char* LGBM_GetLastError() {
  const char* last_message = LastErrorMsg();
  return last_message;
}

274
/*!
 * \brief Load a dataset from a file.
 * \param filename Path of the data file
 * \param parameters Parameter string parsed by ConfigBase::Str2Map into an IOConfig
 * \param reference Optional existing dataset; when non-null the new dataset
 *        is loaded aligned with it
 * \param out Receives the handle of the loaded dataset
 * \return Value produced by the API_BEGIN / API_END macros (defined elsewhere)
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromFile(const char* filename,
  const char* parameters,
  const DatasetHandle reference,
  DatasetHandle* out) {
  API_BEGIN();
  IOConfig io_config;
  auto parsed_params = ConfigBase::Str2Map(parameters);
  io_config.Set(parsed_params);
  DatasetLoader loader(io_config, nullptr, 1, filename);
  if (reference != nullptr) {
    // reuse the layout of the reference dataset
    const Dataset* align_with = reinterpret_cast<const Dataset*>(reference);
    *out = loader.LoadFromFileAlignWithOtherDataset(filename, align_with);
  } else {
    *out = loader.LoadFromFile(filename);
  }
  API_END();
}

292
/*!
 * \brief Create a dataset from a dense matrix.
 *
 * Without a reference dataset, rows are sampled first and the dataset is
 * constructed from the sampled non-zero values; with a reference, its feature
 * mappers are copied. All rows are then pushed in parallel.
 * \param data Matrix buffer
 * \param data_type Element type code understood by RowFunctionFromDenseMatric
 * \param nrow Number of rows
 * \param ncol Number of columns
 * \param is_row_major Layout flag forwarded to RowFunctionFromDenseMatric
 * \param parameters Parameter string parsed into an IOConfig
 * \param reference Optional dataset whose feature mappers are reused
 * \param out Receives the handle of the new dataset
 * \return Value produced by the API_BEGIN / API_END macros (defined elsewhere)
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromMat(const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  const char* parameters,
  const DatasetHandle reference,
  DatasetHandle* out) {
  API_BEGIN();
  auto parsed_params = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(parsed_params);
  std::unique_ptr<Dataset> dataset;
  auto row_at = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  if (reference != nullptr) {
    // reuse the feature mappers of the reference dataset
    dataset.reset(new Dataset(nrow));
    dataset->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(reference),
      io_config.is_enable_sparse);
  } else {
    // sample data first
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(ncol);
    for (const auto sampled_row : sample_indices) {
      auto row = row_at(static_cast<int>(sampled_row));
      for (size_t col = 0; col < row.size(); ++col) {
        // values this close to zero are not recorded in the sample
        if (std::fabs(row[col]) > 1e-15) {
          sample_values[col].push_back(row[col]);
        }
      }
    }
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
    dataset.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    const int thread_id = omp_get_thread_num();
    auto current_row = row_at(i);
    dataset->PushOneRow(thread_id, i, current_row);
  }
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

341
/*!
 * \brief Create a dataset from CSR formatted data.
 *
 * The row count is nindptr - 1. Without a reference dataset, rows are sampled
 * and the dataset is constructed from the sampled non-zero values; with a
 * reference, its feature mappers are copied. All rows are then pushed in
 * parallel.
 * \param indptr Row pointer array
 * \param indptr_type Type code of indptr, forwarded to RowFunctionFromCSR
 * \param indices Column indices
 * \param data Values
 * \param data_type Type code of data, forwarded to RowFunctionFromCSR
 * \param nindptr Number of entries in indptr
 * \param nelem Number of entries in data
 * \param num_col Number of columns; checked against the columns seen while sampling
 * \param parameters Parameter string parsed into an IOConfig
 * \param reference Optional dataset whose feature mappers are reused
 * \param out Receives the handle of the new dataset
 * \return Value produced by the API_BEGIN / API_END macros (defined elsewhere)
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSR(const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t num_col,
  const char* parameters,
  const DatasetHandle reference,
  DatasetHandle* out) {
  API_BEGIN();
  auto parsed_params = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(parsed_params);
  std::unique_ptr<Dataset> dataset;
  auto row_at = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference != nullptr) {
    // reuse the feature mappers of the reference dataset
    dataset.reset(new Dataset(nrow));
    dataset->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(reference),
      io_config.is_enable_sparse);
  } else {
    // sample data first
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values;
    for (const auto sampled_row : sample_indices) {
      auto row = row_at(static_cast<int>(sampled_row));
      for (const auto& entry : row) {
        const size_t feature = static_cast<size_t>(entry.first);
        if (feature >= sample_values.size()) {
          // grow the per-feature storage so that this feature index is valid
          sample_values.resize(feature + 1);
        }
        if (std::fabs(entry.second) > 1e-15) {
          // record the feature value
          sample_values[entry.first].push_back(entry.second);
        }
      }
    }
    CHECK(num_col >= static_cast<int>(sample_values.size()));
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
    dataset.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nindptr - 1; ++i) {
    const int thread_id = omp_get_thread_num();
    auto current_row = row_at(i);
    dataset->PushOneRow(thread_id, i, current_row);
  }
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

403
/*!
 * \brief Create a dataset from CSC formatted data.
 *
 * The column count is ncol_ptr - 1. Without a reference dataset, rows are
 * sampled and the dataset is constructed from the sampled non-zero values of
 * every column; with a reference, its feature mappers are copied. Afterwards
 * the non-zero entries of every used column are pushed in parallel.
 * \param col_ptr Column pointer array
 * \param col_ptr_type Type code of col_ptr, forwarded to CSC_RowIterator
 * \param indices Row indices
 * \param data Values
 * \param data_type Type code of data, forwarded to CSC_RowIterator
 * \param ncol_ptr Number of entries in col_ptr
 * \param nelem Number of entries in data
 * \param num_row Number of rows
 * \param parameters Parameter string parsed into an IOConfig
 * \param reference Optional dataset whose feature mappers are reused
 * \param out Receives the handle of the new dataset
 * \return Value produced by the API_BEGIN / API_END macros (defined elsewhere)
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr,
  int col_ptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t ncol_ptr,
  int64_t nelem,
  int64_t num_row,
  const char* parameters,
  const DatasetHandle reference,
  DatasetHandle* out) {
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
  std::unique_ptr<Dataset> ret;
  int32_t nrow = static_cast<int32_t>(num_row);
  if (reference == nullptr) {
    // sample data first
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    // Bound the inner loop by the number of indices actually returned (as the
    // dense and CSR loaders do) so sample_indices is never read out of range.
    const int num_sampled = static_cast<int>(sample_indices.size());
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
#pragma omp parallel for schedule(guided)
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
      CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
      // CSC_RowIterator::Get only supports ascending access.
      // NOTE(review): this relies on Random::Sample returning ascending indices - confirm.
      for (int j = 0; j < num_sampled; j++) {
        auto val = col_it.Get(sample_indices[j]);
        if (std::fabs(val) > kEpsilon) {
          sample_values[i].push_back(val);
        }
      }
    }
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
    ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow));
    ret->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(reference),
      io_config.is_enable_sparse);
  }

#pragma omp parallel for schedule(guided)
  for (int i = 0; i < ncol_ptr - 1; ++i) {
    const int tid = omp_get_thread_num();
    int feature_idx = ret->GetInnerFeatureIndex(i);
    // a negative inner index means the column is not used by the dataset
    if (feature_idx < 0) { continue; }
    CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
    for (;;) {
      auto pair = col_it.NextNonZero();
      const int row_idx = pair.first;
      // Stop when there is no more data (negative index) or the row index lies
      // outside the dataset; checking before the push keeps a malformed index
      // from being written.
      if (row_idx < 0 || row_idx >= nrow) { break; }
      ret->FeatureAt(feature_idx)->PushData(tid, row_idx, pair.second);
    }
  }
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

465
/*!
 * \brief Create a new dataset from selected rows of an existing one.
 * \param handle Handle of the full dataset
 * \param used_row_indices Row indices to keep
 * \param num_used_row_indices Number of entries in used_row_indices
 * \param parameters Parameter string parsed into an IOConfig
 * \param out Receives the handle of the subset dataset
 * \return Value produced by the API_BEGIN / API_END macros (defined elsewhere)
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetGetSubset(
  const DatasetHandle handle,
  const int32_t* used_row_indices,
  int32_t num_used_row_indices,
  const char* parameters,
  DatasetHandle* out) {
  API_BEGIN();
  auto parsed_params = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(parsed_params);
  const Dataset* source = reinterpret_cast<const Dataset*>(handle);
  // hold the subset in a unique_ptr until it has been finalized
  std::unique_ptr<Dataset> subset(
    source->Subset(used_row_indices,
      num_used_row_indices,
      io_config.is_enable_sparse));
  subset->FinishLoad();
  *out = subset.release();
  API_END();
}

485
/*!
 * \brief Set the feature names of a dataset.
 * \param handle Handle of the dataset
 * \param feature_names Array of NUL terminated names
 * \param num_feature_names Number of entries in feature_names
 * \return Value produced by the API_BEGIN / API_END macros (defined elsewhere)
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetSetFeatureNames(
  DatasetHandle handle,
  const char** feature_names,
  int num_feature_names) {
  API_BEGIN();
  // copy the C strings into owned std::string objects
  std::vector<std::string> names;
  int next = 0;
  while (next < num_feature_names) {
    names.emplace_back(feature_names[next]);
    ++next;
  }
  Dataset* target = reinterpret_cast<Dataset*>(handle);
  target->set_feature_names(names);
  API_END();
}

499
/*!
 * \brief Copy the feature names of a dataset into caller provided buffers.
 * \param handle Handle of the dataset
 * \param feature_names One writable buffer per feature. Names are copied with
 *        strcpy, so each buffer must be large enough for its name plus the
 *        terminating NUL; no size is checked here.
 * \param num_feature_names Receives the number of names written
 * \return Value produced by the API_BEGIN / API_END macros (defined elsewhere)
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetGetFeatureNames(
  DatasetHandle handle,
  char** feature_names,
  int* num_feature_names) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  // Bind by const reference: avoids copying the whole name list, and is still
  // valid (lifetime extension) should feature_names() return by value.
  const auto& inside_feature_name = dataset->feature_names();
  *num_feature_names = static_cast<int>(inside_feature_name.size());
  for (int i = 0; i < *num_feature_names; ++i) {
    std::strcpy(feature_names[i], inside_feature_name[i].c_str());
  }
  API_END();
}

513
/*!
 * \brief Destroy a dataset.
 * \param handle Handle of the dataset to delete
 * \return Value produced by the API_BEGIN / API_END macros (defined elsewhere)
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetFree(DatasetHandle handle) {
  API_BEGIN();
  Dataset* dataset = reinterpret_cast<Dataset*>(handle);
  delete dataset;
  API_END();
}

519
/*!
 * \brief Write a dataset to a binary file via Dataset::SaveBinaryFile.
 * \param handle Handle of the dataset
 * \param filename Destination file name
 * \return Value produced by the API_BEGIN / API_END macros (defined elsewhere)
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetSaveBinary(DatasetHandle handle,
  const char* filename) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->SaveBinaryFile(filename);
  API_END();
}

527
/*!
 * \brief Set a named field of a dataset from a typed array.
 * \param handle Handle of the dataset
 * \param field_name Name of the field
 * \param field_data Pointer to the values; interpreted according to type
 * \param num_element Number of values in field_data
 * \param type C_API_DTYPE_FLOAT32, C_API_DTYPE_INT32 or C_API_DTYPE_FLOAT64
 * \return Value produced by the API_BEGIN / API_END macros (defined elsewhere)
 * \note Throws std::runtime_error when the type is not one of the three
 *       codes above or when the dataset's setter reports failure.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle,
  const char* field_name,
  const void* field_data,
  int num_element,
  int type) {
  API_BEGIN();
  Dataset* target = reinterpret_cast<Dataset*>(handle);
  const int32_t count = static_cast<int32_t>(num_element);
  bool stored;
  if (type == C_API_DTYPE_FLOAT32) {
    stored = target->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), count);
  } else if (type == C_API_DTYPE_INT32) {
    stored = target->SetIntField(field_name, reinterpret_cast<const int*>(field_data), count);
  } else if (type == C_API_DTYPE_FLOAT64) {
    stored = target->SetDoubleField(field_name, reinterpret_cast<const double*>(field_data), count);
  } else {
    // unknown type code
    stored = false;
  }
  if (!stored) {
    throw std::runtime_error("Input data type erorr or field not found");
  }
  API_END();
}

546
/*!
 * \brief Look up a named field of a dataset.
 *
 * The float, int and double getters are tried in that order; the first one
 * that succeeds determines out_type.
 * \param handle Handle of the dataset
 * \param field_name Name of the field
 * \param out_len Receives the number of elements (forced to 0 when the
 *        returned pointer is null)
 * \param out_ptr Receives a pointer to the field data
 * \param out_type Receives C_API_DTYPE_FLOAT32, C_API_DTYPE_INT32 or
 *        C_API_DTYPE_FLOAT64
 * \return Value produced by the API_BEGIN / API_END macros (defined elsewhere)
 * \note Throws std::runtime_error("Field not found") when no getter succeeds.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetGetField(DatasetHandle handle,
  const char* field_name,
  int* out_len,
  const void** out_ptr,
  int* out_type) {
  API_BEGIN();
  Dataset* source = reinterpret_cast<Dataset*>(handle);
  if (source->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (source->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else if (source->GetDoubleField(field_name, out_len, reinterpret_cast<const double**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT64;
  } else {
    throw std::runtime_error("Field not found");
  }
  // a field that exists but holds no data is reported with length 0
  if (*out_ptr == nullptr) {
    *out_len = 0;
  }
  API_END();
}

569
/*!
 * \brief Report the value of Dataset::num_data() for a dataset.
 * \param handle Handle of the dataset
 * \param out Receives the number of data rows
 * \return Value produced by the API_BEGIN / API_END macros (defined elsewhere)
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumData(DatasetHandle handle,
  int* out) {
  API_BEGIN();
  *out = reinterpret_cast<Dataset*>(handle)->num_data();
  API_END();
}

577
/*!
 * \brief Report the value of Dataset::num_total_features() for a dataset.
 * \param handle Handle of the dataset
 * \param out Receives the total number of features
 * \return Value produced by the API_BEGIN / API_END macros (defined elsewhere)
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumFeature(DatasetHandle handle,
  int* out) {
  API_BEGIN();
  *out = reinterpret_cast<Dataset*>(handle)->num_total_features();
  API_END();
}
584
585
586

// ---- start of booster

587
// Constructs a Booster from a training dataset and a parameter string and
// hands ownership of it to the caller through *out.
// The Booster is held in a unique_ptr until construction has finished, so it
// is only released into *out once the constructor has returned normally.
LIGHTGBM_C_EXPORT int LGBM_BoosterCreate(const DatasetHandle train_data,
  const char* parameters,
  BoosterHandle* out) {
  API_BEGIN();
  std::unique_ptr<Booster> new_booster(
    new Booster(reinterpret_cast<const Dataset*>(train_data), parameters));
  *out = new_booster.release();
  API_END();
}

597
// Constructs a Booster from a model file.
// *out_num_iterations receives GetCurrentIteration() of the loaded boosting
// object, and ownership of the new Booster is handed to the caller via *out.
LIGHTGBM_C_EXPORT int LGBM_BoosterCreateFromModelfile(
  const char* filename,
  int* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
  std::unique_ptr<Booster> loaded_booster(new Booster(filename));
  // query the iteration count while the unique_ptr still owns the object
  *out_num_iterations = loaded_booster->GetBoosting()->GetCurrentIteration();
  *out = loaded_booster.release();
  API_END();
}

608
// Destroys the Booster behind `handle` with `delete`.
LIGHTGBM_C_EXPORT int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  Booster* booster_to_free = reinterpret_cast<Booster*>(handle);
  delete booster_to_free;
  API_END();
}

614
// Forwards to Booster::MergeFrom: the booster behind `handle` is the
// receiver, the booster behind `other_handle` is the argument.
LIGHTGBM_C_EXPORT int LGBM_BoosterMerge(BoosterHandle handle,
  BoosterHandle other_handle) {
  API_BEGIN();
  Booster* source_booster = reinterpret_cast<Booster*>(other_handle);
  reinterpret_cast<Booster*>(handle)->MergeFrom(source_booster);
  API_END();
}

623
// Forwards the dataset behind `valid_data` to Booster::AddValidData on the
// booster behind `handle`.
LIGHTGBM_C_EXPORT int LGBM_BoosterAddValidData(BoosterHandle handle,
  const DatasetHandle valid_data) {
  API_BEGIN();
  const Dataset* p_valid_data = reinterpret_cast<const Dataset*>(valid_data);
  reinterpret_cast<Booster*>(handle)->AddValidData(p_valid_data);
  API_END();
}

632
// Forwards the dataset behind `train_data` to Booster::ResetTrainingData on
// the booster behind `handle`.
LIGHTGBM_C_EXPORT int LGBM_BoosterResetTrainingData(BoosterHandle handle,
  const DatasetHandle train_data) {
  API_BEGIN();
  const Dataset* p_train_data = reinterpret_cast<const Dataset*>(train_data);
  reinterpret_cast<Booster*>(handle)->ResetTrainingData(p_train_data);
  API_END();
}

641
// Passes the parameter string on to Booster::ResetConfig of the booster
// behind `handle`.
LIGHTGBM_C_EXPORT int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ResetConfig(parameters);
  API_END();
}

648
// Writes NumberOfClasses() of the booster's boosting object into *out_len.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->NumberOfClasses();
  API_END();
}

655
// Runs Booster::TrainOneIter() once and stores its boolean result in
// *is_finished as 1 (true) or 0 (false).
LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  Booster* p_booster = reinterpret_cast<Booster*>(handle);
  *is_finished = p_booster->TrainOneIter() ? 1 : 0;
  API_END();
}

666
// Runs Booster::TrainOneIter(grad, hess) once with caller-supplied gradient
// and hessian arrays and stores its boolean result in *is_finished as 1 or 0.
LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
  const float* grad,
  const float* hess,
  int* is_finished) {
  API_BEGIN();
  Booster* p_booster = reinterpret_cast<Booster*>(handle);
  *is_finished = p_booster->TrainOneIter(grad, hess) ? 1 : 0;
  API_END();
}

680
// Calls Booster::RollbackOneIter() on the booster behind `handle`.
LIGHTGBM_C_EXPORT int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

687
// Writes GetCurrentIteration() of the booster's boosting object into *out_iteration.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int* out_iteration) {
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_iteration = boosting->GetCurrentIteration();
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
693

694
// Writes the value returned by Booster::GetEvalCounts() into *out_len.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalCounts();
  API_END();
}

701
// Passes `out_strs` to Booster::GetEvalNames and stores that call's return
// value in *out_len.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetEvalNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalNames(out_strs);
  API_END();
}

708
// Copies the evaluation results returned by GetEvalAt(data_idx) into
// out_results (converted to double) and reports how many there are in *out_len.
// The caller must provide an out_results buffer large enough for all of them;
// no size is passed in, so nothing is checked here.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetEval(BoosterHandle handle,
  int data_idx,
  int* out_len,
  double* out_results) {
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  auto eval_values = boosting->GetEvalAt(data_idx);
  const size_t num_values = eval_values.size();
  *out_len = static_cast<int>(num_values);
  for (size_t k = 0; k != num_values; ++k) {
    out_results[k] = static_cast<double>(eval_values[k]);
  }
  API_END();
}

723
// Writes GetNumPredictAt(data_idx) of the booster's boosting object into *out_len.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumPredict(BoosterHandle handle,
  int data_idx,
  int64_t* out_len) {
  API_BEGIN();
  Booster* p_booster = reinterpret_cast<Booster*>(handle);
  *out_len = p_booster->GetBoosting()->GetNumPredictAt(data_idx);
  API_END();
}

732
// Delegates to Booster::GetPredictAt, which receives the dataset index, the
// output buffer and the output-length pointer (in that order).
LIGHTGBM_C_EXPORT int LGBM_BoosterGetPredict(BoosterHandle handle,
  int data_idx,
  int64_t* out_len,
  double* out_result) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->GetPredictAt(data_idx, out_result, out_len);
  API_END();
}

742
// Builds a predictor for the requested iteration count and prediction type
// and runs it over `data_filename`, passing `result_filename` as the output.
// `data_has_header` is an int flag: any value greater than zero means true.
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForFile(BoosterHandle handle,
  const char* data_filename,
  int data_has_header,
  int predict_type,
  int num_iteration,
  const char* result_filename) {
  API_BEGIN();
  Booster* p_booster = reinterpret_cast<Booster*>(handle);
  auto file_predictor = p_booster->NewPredictor(num_iteration, predict_type);
  const bool has_header = data_has_header > 0;
  file_predictor.Predict(data_filename, result_filename, has_header);
  API_END();
}

756
757
int64_t GetNumPredOneRow(const Booster* ref_booster, int predict_type, int64_t num_iteration) {
  int64_t num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
Guolin Ke's avatar
Guolin Ke committed
758
759
760
761
762
763
764
765
766
767
768
  if (predict_type == C_API_PREDICT_LEAF_INDEX) {
    int64_t max_iteration = ref_booster->GetBoosting()->GetCurrentIteration();
    if (num_iteration > 0) {
      num_preb_in_one_row *= static_cast<int>(std::min(max_iteration, num_iteration));
    } else {
      num_preb_in_one_row *= max_iteration;
    }
  }
  return num_preb_in_one_row;
}

769
// Computes the total number of prediction values for `num_row` rows:
// num_row multiplied by the per-row count from GetNumPredOneRow.
LIGHTGBM_C_EXPORT int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
  int num_row,
  int predict_type,
  int num_iteration,
  int64_t* out_len) {
  API_BEGIN();
  const Booster* p_booster = reinterpret_cast<Booster*>(handle);
  const int64_t values_per_row = GetNumPredOneRow(p_booster, predict_type, num_iteration);
  // int * int64_t is evaluated in 64 bits
  *out_len = num_row * values_per_row;
  API_END();
}

780
// Predicts every row of a CSR matrix.
// Row r's values are written to out_result starting at offset
// r * GetNumPredOneRow(...); *out_len receives (nindptr - 1) times that
// per-row count. The ninth parameter is unnamed and ignored.
// Rows are processed by an OpenMP parallel loop.
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle,
  const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t,
  int predict_type,
  int num_iteration,
  int64_t* out_len,
  double* out_result) {
  API_BEGIN();
  Booster* p_booster = reinterpret_cast<Booster*>(handle);
  auto predictor = p_booster->NewPredictor(num_iteration, predict_type);
  auto read_row = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  const int64_t values_per_row = GetNumPredOneRow(p_booster, predict_type, num_iteration);
  // the row-pointer array has one more entry than there are rows
  const int nrow = static_cast<int>(nindptr - 1);
#pragma omp parallel for schedule(guided)
  for (int row = 0; row < nrow; ++row) {
    auto sparse_row = read_row(row);
    auto row_prediction = predictor.GetPredictFunction()(sparse_row);
    const int num_values = static_cast<int>(row_prediction.size());
    for (int k = 0; k < num_values; ++k) {
      out_result[row * values_per_row + k] = static_cast<double>(row_prediction[k]);
    }
  }
  *out_len = nrow * values_per_row;
  API_END();
}
810

811
// Predicts every row of a CSC matrix.
// The row range [0, num_row) is handed to Threading::For. Each invocation of
// the worker builds one CSC_RowIterator per column, then for every row in its
// sub-range gathers the entries whose magnitude exceeds kEpsilon into a
// sparse (column, value) row, predicts it, and writes the values to
// out_result starting at offset row * GetNumPredOneRow(...).
// *out_len receives num_row times the per-row count.
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle,
  const void* col_ptr,
  int col_ptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t ncol_ptr,
  int64_t nelem,
  int64_t num_row,
  int predict_type,
  int num_iteration,
  int64_t* out_len,
  double* out_result) {
  API_BEGIN();
  Booster* p_booster = reinterpret_cast<Booster*>(handle);
  auto predictor = p_booster->NewPredictor(num_iteration, predict_type);
  const int64_t values_per_row = GetNumPredOneRow(p_booster, predict_type, num_iteration);
  // the column-pointer array has one more entry than there are columns
  const int num_columns = static_cast<int>(ncol_ptr - 1);

  // NOTE(review): the worker takes its bounds as data_size_t although the loop
  // is instantiated for int64_t - confirm data_size_t is wide enough for num_row.
  Threading::For<int64_t>(0, num_row,
    [&predictor, &out_result, values_per_row, num_columns, col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem]
  (int, data_size_t start, data_size_t end) {
    // one cursor per column, private to this invocation
    std::vector<CSC_RowIterator> column_cursors;
    for (int col = 0; col < num_columns; ++col) {
      column_cursors.emplace_back(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, col);
    }
    // reused for every row to avoid reallocating
    std::vector<std::pair<int, double>> sparse_row;
    for (int64_t row = start; row < end; ++row) {
      sparse_row.clear();
      for (int col = 0; col < num_columns; ++col) {
        const auto value = column_cursors[col].Get(static_cast<int>(row));
        if (std::fabs(value) > kEpsilon) {
          sparse_row.emplace_back(col, value);
        }
      }
      auto row_prediction = predictor.GetPredictFunction()(sparse_row);
      const int num_values = static_cast<int>(row_prediction.size());
      for (int k = 0; k < num_values; ++k) {
        out_result[row * values_per_row + k] = static_cast<double>(row_prediction[k]);
      }
    }
  });
  *out_len = num_row * values_per_row;
  API_END();
}

856
// Predicts every row of a dense matrix.
// Rows are read through RowPairFunctionFromDenseMatric and predicted in an
// OpenMP parallel loop. Row r's values are written to out_result starting at
// offset r * GetNumPredOneRow(...); *out_len receives nrow times that count.
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMat(BoosterHandle handle,
  const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  int predict_type,
  int num_iteration,
  int64_t* out_len,
  double* out_result) {
  API_BEGIN();
  Booster* p_booster = reinterpret_cast<Booster*>(handle);
  auto predictor = p_booster->NewPredictor(num_iteration, predict_type);
  auto read_row = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  const int64_t values_per_row = GetNumPredOneRow(p_booster, predict_type, num_iteration);
#pragma omp parallel for schedule(guided)
  for (int row = 0; row < nrow; ++row) {
    auto sparse_row = read_row(row);
    auto row_prediction = predictor.GetPredictFunction()(sparse_row);
    const int num_values = static_cast<int>(row_prediction.size());
    for (int k = 0; k < num_values; ++k) {
      out_result[row * values_per_row + k] = static_cast<double>(row_prediction[k]);
    }
  }
  *out_len = nrow * values_per_row;
  API_END();
}
882

883
// Forwards `num_iteration` and `filename` to Booster::SaveModelToFile.
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle,
  int num_iteration,
  const char* filename) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->SaveModelToFile(num_iteration, filename);
  API_END();
}

892
// Obtains the string returned by Booster::DumpModel and copies it to out_str.
// *out_len always receives the number of bytes needed, terminating NUL
// included. The copy only happens when `buffer_len` is at least that large;
// otherwise out_str is left untouched, so the caller can compare *out_len
// with buffer_len to detect a too-small buffer.
LIGHTGBM_C_EXPORT int LGBM_BoosterDumpModel(BoosterHandle handle,
  int num_iteration,
  int buffer_len,
  int* out_len,
  char* out_str) {
  API_BEGIN();
  std::string model_text = reinterpret_cast<Booster*>(handle)->DumpModel(num_iteration);
  const int required_len = static_cast<int>(model_text.size()) + 1;
  *out_len = required_len;
  if (buffer_len >= required_len) {
    std::strcpy(out_str, model_text.c_str());
  }
  API_END();
}
906

907
// Writes Booster::GetLeafValue(tree_idx, leaf_idx), converted to double, into *out_val.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetLeafValue(BoosterHandle handle,
  int tree_idx,
  int leaf_idx,
  double* out_val) {
  API_BEGIN();
  Booster* p_booster = reinterpret_cast<Booster*>(handle);
  const auto leaf_value = p_booster->GetLeafValue(tree_idx, leaf_idx);
  *out_val = static_cast<double>(leaf_value);
  API_END();
}

917
// Forwards (tree_idx, leaf_idx, val) to Booster::SetLeafValue.
LIGHTGBM_C_EXPORT int LGBM_BoosterSetLeafValue(BoosterHandle handle,
  int tree_idx,
  int leaf_idx,
  double val) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->SetLeafValue(tree_idx, leaf_idx, val);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
927
// ---- start of some helper functions
928
929
930

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
931
  if (data_type == C_API_DTYPE_FLOAT32) {
932
933
934
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
935
        std::vector<double> ret(num_col);
936
937
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
938
          ret[i] = static_cast<double>(*(tmp_ptr + i));
939
940
941
942
943
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
944
        std::vector<double> ret(num_col);
945
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
946
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
947
948
949
950
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
951
  } else if (data_type == C_API_DTYPE_FLOAT64) {
952
953
954
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
955
        std::vector<double> ret(num_col);
956
957
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
958
          ret[i] = static_cast<double>(*(tmp_ptr + i));
959
960
961
962
963
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
964
        std::vector<double> ret(num_col);
965
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
966
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
967
968
969
970
971
        }
        return ret;
      };
    }
  }
Guolin Ke's avatar
Guolin Ke committed
972
  throw std::runtime_error("unknown data type in RowFunctionFromDenseMatric");
973
974
975
976
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
977
978
979
980
981
982
983
984
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
    return [inner_function](int row_idx) {
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
        if (std::fabs(raw_values[i]) > 1e-15) {
          ret.emplace_back(i, raw_values[i]);
985
        }
Guolin Ke's avatar
Guolin Ke committed
986
987
988
      }
      return ret;
    };
989
  }
Guolin Ke's avatar
Guolin Ke committed
990
  return nullptr;
991
992
993
994
}

std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
995
  if (data_type == C_API_DTYPE_FLOAT32) {
996
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
997
    if (indptr_type == C_API_DTYPE_INT32) {
998
999
1000
1001
1002
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1003
        for (int64_t i = start; i < end; ++i) {
1004
1005
1006
1007
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1008
    } else if (indptr_type == C_API_DTYPE_INT64) {
1009
1010
1011
1012
1013
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1014
        for (int64_t i = start; i < end; ++i) {
1015
1016
1017
1018
1019
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1020
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1021
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1022
    if (indptr_type == C_API_DTYPE_INT32) {
1023
1024
1025
1026
1027
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1028
        for (int64_t i = start; i < end; ++i) {
1029
1030
1031
1032
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1033
    } else if (indptr_type == C_API_DTYPE_INT64) {
1034
1035
1036
1037
1038
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1039
        for (int64_t i = start; i < end; ++i) {
1040
1041
1042
1043
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1044
1045
    }
  }
Guolin Ke's avatar
Guolin Ke committed
1046
  throw std::runtime_error("unknown data type in RowFunctionFromCSR");
1047
1048
}

Guolin Ke's avatar
Guolin Ke committed
1049
1050
1051
std::function<std::pair<int, double>(int idx)>
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx) {
  CHECK(col_idx < ncol_ptr && col_idx >= 0);
Guolin Ke's avatar
Guolin Ke committed
1052
  if (data_type == C_API_DTYPE_FLOAT32) {
1053
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1054
    if (col_ptr_type == C_API_DTYPE_INT32) {
1055
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1056
1057
1058
1059
1060
1061
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1062
        }
Guolin Ke's avatar
Guolin Ke committed
1063
1064
1065
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1066
      };
Guolin Ke's avatar
Guolin Ke committed
1067
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1068
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1069
1070
1071
1072
1073
1074
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1075
        }
Guolin Ke's avatar
Guolin Ke committed
1076
1077
1078
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1079
      };
Guolin Ke's avatar
Guolin Ke committed
1080
    }
Guolin Ke's avatar
Guolin Ke committed
1081
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1082
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1083
    if (col_ptr_type == C_API_DTYPE_INT32) {
1084
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1085
1086
1087
1088
1089
1090
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1091
        }
Guolin Ke's avatar
Guolin Ke committed
1092
1093
1094
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1095
      };
Guolin Ke's avatar
Guolin Ke committed
1096
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1097
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1098
1099
1100
1101
1102
1103
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1104
        }
Guolin Ke's avatar
Guolin Ke committed
1105
1106
1107
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1108
      };
Guolin Ke's avatar
Guolin Ke committed
1109
1110
1111
    }
  }
  throw std::runtime_error("unknown data type in CSC matrix");
1112
1113
}

Guolin Ke's avatar
Guolin Ke committed
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
// Binds the iterator to column `col_idx` of a CSC matrix by storing the
// accessor produced by IterateFunctionFromCSC in iter_fun_.
CSC_RowIterator::CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
  const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx)
  : iter_fun_(IterateFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, col_idx)) {
}

// Returns the column's value at row `idx`, or 0 when no element is stored there.
// The iterator only moves forward: stored elements are consumed until the
// cached row index reaches or passes `idx`, or the column is exhausted.
// A lookup behind the cached position therefore yields 0.
double CSC_RowIterator::Get(int idx) {
  while (!is_end_ && cur_idx_ < idx) {
    const auto next_entry = iter_fun_(nonzero_idx_);
    if (next_entry.first < 0) {
      // a negative row index marks the end of the column
      is_end_ = true;
      break;
    }
    cur_idx_ = next_entry.first;
    cur_val_ = next_entry.second;
    ++nonzero_idx_;
  }
  return (cur_idx_ == idx) ? cur_val_ : 0.0;
}

// Returns the next stored (row index, value) pair of the column and advances
// the position. Once the accessor has reported a negative row index the
// iterator is marked finished, and every later call returns (-1, 0.0)
// without touching the accessor again.
std::pair<int, double> CSC_RowIterator::NextNonZero() {
  if (is_end_) {
    return std::make_pair(-1, 0.0);
  }
  auto entry = iter_fun_(nonzero_idx_);
  ++nonzero_idx_;
  if (entry.first < 0) {
    is_end_ = true;
  }
  return entry;
}