c_api.cpp 39.9 KB
Newer Older
1
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
2
3
4

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/utils/threading.h>
Guolin Ke's avatar
Guolin Ke committed
6
#include <LightGBM/c_api.h>
Guolin Ke's avatar
Guolin Ke committed
7
#include <LightGBM/dataset_loader.h>
Guolin Ke's avatar
Guolin Ke committed
8
9
10
11
12
13
14
15
16
17
#include <LightGBM/dataset.h>
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/config.h>

#include <cstdio>
#include <vector>
#include <string>
#include <cstring>
Guolin Ke's avatar
Guolin Ke committed
18
#include <memory>
Guolin Ke's avatar
Guolin Ke committed
19
#include <stdexcept>
wxchan's avatar
wxchan committed
20
#include <mutex>
Guolin Ke's avatar
Guolin Ke committed
21
#include <functional>
Guolin Ke's avatar
Guolin Ke committed
22

Guolin Ke's avatar
Guolin Ke committed
23
#include "./application/predictor.hpp"
Guolin Ke's avatar
Guolin Ke committed
24
#include "./boosting/gbdt.h"
Guolin Ke's avatar
Guolin Ke committed
25

Guolin Ke's avatar
Guolin Ke committed
26
27
28
29
namespace LightGBM {

class Booster {
public:
Guolin Ke's avatar
Guolin Ke committed
30
31
  explicit Booster(const char* filename) {
    boosting_.reset(Boosting::CreateBoosting(filename));
Guolin Ke's avatar
Guolin Ke committed
32
33
  }

34
35
36
37
  Booster() {
    boosting_.reset(Boosting::CreateBoosting("gbdt", nullptr));
  }

Guolin Ke's avatar
Guolin Ke committed
38
  Booster(const Dataset* train_data,
wxchan's avatar
wxchan committed
39
40
41
    const char* parameters) {
    auto param = ConfigBase::Str2Map(parameters);
    config_.Set(param);
42
43
44
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
45
46
    // create boosting
    if (config_.io_config.input_model.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
47
      Log::Warning("continued train from model is not support for c_api, \
Guolin Ke's avatar
Guolin Ke committed
48
49
        please use continued train with input score");
    }
Guolin Ke's avatar
Guolin Ke committed
50

wxchan's avatar
wxchan committed
51
    boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
Guolin Ke's avatar
Guolin Ke committed
52

Guolin Ke's avatar
Guolin Ke committed
53
    // initialize the boosting
Guolin Ke's avatar
Guolin Ke committed
54
    boosting_->Init(&config_.boosting_config, nullptr, objective_fun_.get(),
Guolin Ke's avatar
Guolin Ke committed
55
      Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
56
57

    ResetTrainingData(train_data);
wxchan's avatar
wxchan committed
58
59
60
61
62
  }

  void MergeFrom(const Booster* other) {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->MergeFrom(other->boosting_.get());
Guolin Ke's avatar
Guolin Ke committed
63
64
65
  }

  ~Booster() {
Guolin Ke's avatar
Guolin Ke committed
66

Guolin Ke's avatar
Guolin Ke committed
67
  }
68

wxchan's avatar
wxchan committed
69
70
71
  void ResetTrainingData(const Dataset* train_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    train_data_ = train_data;
Guolin Ke's avatar
Guolin Ke committed
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
    // create objective function
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
      config_.objective_config));
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective function");
    }
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
    }

    // create training metric
    train_metric_.clear();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(
        Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(train_data_->metadata(), train_data_->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
    // reset the boosting
Guolin Ke's avatar
Guolin Ke committed
94
    boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
wxchan's avatar
wxchan committed
95
96
97
98
      objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
  }

  void ResetConfig(const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
99
    std::lock_guard<std::mutex> lock(mutex_);
wxchan's avatar
wxchan committed
100
101
102
103
104
105
106
    auto param = ConfigBase::Str2Map(parameters);
    if (param.count("num_class")) {
      Log::Fatal("cannot change num class during training");
    }
    if (param.count("boosting_type")) {
      Log::Fatal("cannot change boosting_type during training");
    }
Guolin Ke's avatar
Guolin Ke committed
107
108
109
    if (param.count("metric")) {
      Log::Fatal("cannot change metric during training");
    }
Guolin Ke's avatar
Guolin Ke committed
110
111

    config_.Set(param);
112
113
114
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
115
116
117
118
119
120
121
122
123
124
125
126

    if (param.count("objective")) {
      // create objective function
      objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
        config_.objective_config));
      if (objective_fun_ == nullptr) {
        Log::Warning("Using self-defined objective function");
      }
      // initialize the objective function
      if (objective_fun_ != nullptr) {
        objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
      }
wxchan's avatar
wxchan committed
127
    }
Guolin Ke's avatar
Guolin Ke committed
128
129
130

    boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
      objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
131

wxchan's avatar
wxchan committed
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
  }

  void AddValidData(const Dataset* valid_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    valid_metrics_.emplace_back();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
      Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
  }
Guolin Ke's avatar
Guolin Ke committed
147

148
  bool TrainOneIter() {
wxchan's avatar
wxchan committed
149
    std::lock_guard<std::mutex> lock(mutex_);
150
151
152
153
    return boosting_->TrainOneIter(nullptr, nullptr, false);
  }

  bool TrainOneIter(const float* gradients, const float* hessians) {
wxchan's avatar
wxchan committed
154
    std::lock_guard<std::mutex> lock(mutex_);
155
156
157
    return boosting_->TrainOneIter(gradients, hessians, false);
  }

wxchan's avatar
wxchan committed
158
159
160
161
162
  void RollbackOneIter() {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->RollbackOneIter();
  }

Guolin Ke's avatar
Guolin Ke committed
163
  Predictor NewPredictor(int num_iteration, int predict_type) {
wxchan's avatar
wxchan committed
164
165
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->SetNumIterationForPred(num_iteration);
Guolin Ke's avatar
Guolin Ke committed
166
167
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
168
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
169
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
170
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
171
      is_raw_score = true;
Guolin Ke's avatar
Guolin Ke committed
172
173
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
174
    }
Guolin Ke's avatar
Guolin Ke committed
175
176
177
    // not threading safe now
    // boosting_->SetNumIterationForPred may be set by other thread during prediction. 
    return Predictor(boosting_.get(), is_raw_score, is_predict_leaf);
Guolin Ke's avatar
Guolin Ke committed
178
179
  }

Guolin Ke's avatar
Guolin Ke committed
180
  void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
181
182
183
184
185
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

  void SaveModelToFile(int num_iteration, const char* filename) {
    boosting_->SaveModelToFile(num_iteration, filename);
Guolin Ke's avatar
Guolin Ke committed
186
  }
187

188
189
190
191
192
193
194
195
  void LoadModelFromString(const char* model_str) {
    boosting_->LoadModelFromString(model_str);
  }

  std::string SaveModelToString(int num_iteration) {
    return boosting_->SaveModelToString(num_iteration);
  }

196
197
  std::string DumpModel(int num_iteration) {
    return boosting_->DumpModel(num_iteration);
wxchan's avatar
wxchan committed
198
  }
199

Guolin Ke's avatar
Guolin Ke committed
200
201
202
203
204
205
206
207
208
  double GetLeafValue(int tree_idx, int leaf_idx) const {
    return dynamic_cast<GBDT*>(boosting_.get())->GetLeafValue(tree_idx, leaf_idx);
  }

  void SetLeafValue(int tree_idx, int leaf_idx, double val) {
    std::lock_guard<std::mutex> lock(mutex_);
    dynamic_cast<GBDT*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val);
  }

wxchan's avatar
wxchan committed
209
210
211
212
213
214
215
  int GetEvalCounts() const {
    int ret = 0;
    for (const auto& metric : train_metric_) {
      ret += static_cast<int>(metric->GetName().size());
    }
    return ret;
  }
216

wxchan's avatar
wxchan committed
217
218
219
220
221
222
223
224
225
226
227
  int GetEvalNames(char** out_strs) const {
    int idx = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
        std::strcpy(out_strs[idx], name.c_str());
        ++idx;
      }
    }
    return idx;
  }

wxchan's avatar
wxchan committed
228
229
230
231
232
233
234
235
236
  int GetFeatureNames(char** out_strs) const {
    int idx = 0;
    for (const auto& name : boosting_->FeatureNames()) {
      std::strcpy(out_strs[idx], name.c_str());
      ++idx;
    }
    return idx;
  }

wxchan's avatar
wxchan committed
237
  const Boosting* GetBoosting() const { return boosting_.get(); }
Guolin Ke's avatar
Guolin Ke committed
238

Guolin Ke's avatar
Guolin Ke committed
239
private:
240

wxchan's avatar
wxchan committed
241
  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
242
  std::unique_ptr<Boosting> boosting_;
Guolin Ke's avatar
Guolin Ke committed
243
244
245
  /*! \brief All configs */
  OverallConfig config_;
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
246
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
247
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
248
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
249
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
250
  std::unique_ptr<ObjectiveFunction> objective_fun_;
wxchan's avatar
wxchan committed
251
252
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
253
254
255
};

}
Guolin Ke's avatar
Guolin Ke committed
256
257
258

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
// some helper functions used to convert data

// Forward declaration (definition not in this part of the file).
// Returns a callable that maps a row index to that row's values as doubles.
// data_type / is_row_major presumably describe the element type and memory
// layout of `data` -- confirm against the definition.
std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major);

// Forward declaration (definition not in this part of the file).
// Same inputs as above, but the returned callable yields (int, double) pairs;
// presumably (column index, value) -- confirm against the definition.
std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major);

// Forward declaration (definition not in this part of the file).
// Returns a callable that maps an index to a list of (int, double) pairs read
// from CSR arrays; callers in this file use pair.first as the column index.
std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
  const void* data, int data_type, int64_t nindptr, int64_t nelem);

// Row iterator over one column of a CSC matrix.
// Only the declaration lives here; the member functions are defined elsewhere.
class CSC_RowIterator {
public:
  // col_idx selects the column to iterate; the remaining arguments describe the
  // CSC arrays (presumably column pointers, row indices and values -- confirm
  // against the definition).
  CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
    const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx);
  ~CSC_RowIterator() {}
  // Returns the value at row idx; rows may only be requested in ascending order.
  double Get(int idx);
  // Returns the next non-zero pair; a pair whose index is < 0 means no more data.
  std::pair<int, double> NextNonZero();
private:
  // Iteration state; exact meaning is fixed by the out-of-view definitions
  // (presumably: position within the column, last row index/value read, and an
  // end-of-column flag).
  int nonzero_idx_ = 0;
  int cur_idx_ = -1;
  double cur_val_ = 0.0f;
  bool is_end_ = false;
  // Callable that maps a position to an (int, double) pair for this column.
  std::function<std::pair<int, double>(int idx)> iter_fun_;
};

// start of c_api functions

291
/*!
 * \brief Get the message of the last error.
 * \return Pointer returned by LastErrorMsg()
 */
LIGHTGBM_C_EXPORT const char* LGBM_GetLastError() {
  const char* last_error = LastErrorMsg();
  return last_error;
}

295
/*!
 * \brief Load a dataset from a file.
 * \param filename Path of the data file
 * \param parameters Parameter string parsed with ConfigBase::Str2Map
 * \param reference When not null, the new dataset is loaded aligned with this one
 * \param out Receives the handle of the loaded dataset
 * \return Status produced by the API_BEGIN / API_END macros
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromFile(const char* filename,
  const char* parameters,
  const DatasetHandle reference,
  DatasetHandle* out) {
  API_BEGIN();
  auto config_map = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(config_map);
  DatasetLoader loader(io_config, nullptr, 1, filename);
  if (reference != nullptr) {
    const Dataset* aligned_with = reinterpret_cast<const Dataset*>(reference);
    *out = loader.LoadFromFileAlignWithOtherDataset(filename, aligned_with);
  } else {
    *out = loader.LoadFromFile(filename);
  }
  API_END();
}

313
/*!
 * \brief Create a dataset from a dense matrix.
 *
 * Without a reference dataset, a sample of rows is drawn and its non-zero
 * values per column are handed to DatasetLoader::CostructFromSampleData.
 * With a reference, the feature mappers are copied from it. Afterwards every
 * row is pushed into the dataset in an OpenMP parallel loop.
 * \param data Matrix data, interpreted by RowFunctionFromDenseMatric
 * \param data_type Element type code, forwarded to RowFunctionFromDenseMatric
 * \param nrow Number of rows
 * \param ncol Number of columns
 * \param is_row_major Layout flag, forwarded to RowFunctionFromDenseMatric
 * \param parameters Parameter string parsed with ConfigBase::Str2Map
 * \param reference Optional dataset to copy feature mappers from
 * \param out Receives the handle of the new dataset
 * \return Status produced by the API_BEGIN / API_END macros
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromMat(const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  const char* parameters,
  const DatasetHandle reference,
  DatasetHandle* out) {
  API_BEGIN();
  auto config_map = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(config_map);
  std::unique_ptr<Dataset> dataset;
  auto read_row = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  if (reference != nullptr) {
    // reuse the bin mappers of the reference dataset
    dataset.reset(new Dataset(nrow));
    dataset->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(reference));
  } else {
    // sample data first
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(ncol);
    for (auto sampled_row : sample_indices) {
      auto row = read_row(static_cast<int>(sampled_row));
      for (size_t col = 0; col < row.size(); ++col) {
        // only values that are not (almost) zero are collected
        if (std::fabs(row[col]) > 1e-15) {
          sample_values[col].push_back(row[col]);
        }
      }
    }
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
    dataset.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

#pragma omp parallel for schedule(guided)
  for (int row_idx = 0; row_idx < nrow; ++row_idx) {
    const int tid = omp_get_thread_num();
    auto one_row = read_row(row_idx);
    dataset->PushOneRow(tid, row_idx, one_row);
  }
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

361
/*!
 * \brief Create a dataset from a CSR matrix.
 *
 * Without a reference dataset, a sample of rows is drawn and its non-zero
 * values per column are handed to DatasetLoader::CostructFromSampleData.
 * With a reference, the feature mappers are copied from it. Afterwards every
 * row is pushed into the dataset in an OpenMP parallel loop.
 * \param indptr Row pointer array, interpreted by RowFunctionFromCSR
 * \param indptr_type Type code of indptr
 * \param indices Column indices
 * \param data Values
 * \param data_type Type code of data
 * \param nindptr Length of indptr; the number of rows is nindptr - 1
 * \param nelem Number of stored elements
 * \param num_col Number of columns; must not be smaller than the largest
 *        column index seen in the sample plus one
 * \param parameters Parameter string parsed with ConfigBase::Str2Map
 * \param reference Optional dataset to copy feature mappers from
 * \param out Receives the handle of the new dataset
 * \return Status produced by the API_BEGIN / API_END macros
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSR(const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t num_col,
  const char* parameters,
  const DatasetHandle reference,
  DatasetHandle* out) {
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
  std::unique_ptr<Dataset> ret;
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  const int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference == nullptr) {
    // sample data first
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values;
    for (auto idx : sample_indices) {
      auto row = get_row_fun(static_cast<int>(idx));
      for (const std::pair<int, double>& inner_data : row) {
        // a negative column index would otherwise index out of range below
        CHECK(inner_data.first >= 0);
        const size_t col = static_cast<size_t>(inner_data.first);
        if (col >= sample_values.size()) {
          // if need expand feature set
          sample_values.resize(col + 1);
        }
        if (std::fabs(inner_data.second) > 1e-15) {
          // edit the feature value
          sample_values[col].push_back(inner_data.second);
        }
      }
    }
    CHECK(num_col >= static_cast<int>(sample_values.size()));
    // Columns beyond the last one seen in the sample still belong to the
    // matrix: give them (empty) sample lists so the dataset is built with
    // num_col columns, as the dense and CSC constructors do.
    sample_values.resize(static_cast<size_t>(num_col));
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
    ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow));
    ret->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(reference));
  }

#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
  }
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

422
/*!
 * \brief Create a dataset from a CSC matrix.
 *
 * Without a reference dataset, a sample of rows is drawn and, column by
 * column, the sampled values above kEpsilon in magnitude are handed to
 * DatasetLoader::CostructFromSampleData. With a reference, the feature mappers
 * are copied from it. Afterwards every used column is pushed into its feature
 * in an OpenMP parallel loop.
 * \param col_ptr Column pointer array, interpreted by CSC_RowIterator
 * \param col_ptr_type Type code of col_ptr
 * \param indices Row indices
 * \param data Values
 * \param data_type Type code of data
 * \param ncol_ptr Length of col_ptr; the number of columns is ncol_ptr - 1
 * \param nelem Number of stored elements
 * \param num_row Number of rows
 * \param parameters Parameter string parsed with ConfigBase::Str2Map
 * \param reference Optional dataset to copy feature mappers from
 * \param out Receives the handle of the new dataset
 * \return Status produced by the API_BEGIN / API_END macros
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr,
  int col_ptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t ncol_ptr,
  int64_t nelem,
  int64_t num_row,
  const char* parameters,
  const DatasetHandle reference,
  DatasetHandle* out) {
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
  std::unique_ptr<Dataset> ret;
  const int32_t nrow = static_cast<int32_t>(num_row);
  if (reference == nullptr) {
    // sample data first
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
#pragma omp parallel for schedule(guided)
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
      CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
      // Walk the indices that were actually drawn rather than counting up to
      // sample_cnt, so a shorter result from Sample cannot be over-read. This
      // matches the dense and CSR constructors.
      for (auto sampled_row : sample_indices) {
        auto val = col_it.Get(sampled_row);
        if (std::fabs(val) > kEpsilon) {
          sample_values[i].push_back(val);
        }
      }
    }
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
    ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow));
    ret->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(reference));
  }

#pragma omp parallel for schedule(guided)
  for (int i = 0; i < ncol_ptr - 1; ++i) {
    const int tid = omp_get_thread_num();
    int feature_idx = ret->GetInnerFeatureIndex(i);
    // columns without an inner feature index are not stored
    if (feature_idx < 0) { continue; }
    CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
    int row_idx = 0;
    while (row_idx < nrow) {
      auto pair = col_it.NextNonZero();
      row_idx = pair.first;
      // no more data
      if (row_idx < 0) { break; }
      ret->FeatureAt(feature_idx)->PushData(tid, row_idx, pair.second);
    }
  }
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

483
/*!
 * \brief Create a new dataset holding a subset of the rows of an existing one.
 * \param handle Handle of the full dataset
 * \param used_row_indices Indices of the rows to copy
 * \param num_used_row_indices Number of entries in used_row_indices
 * \param parameters Parameter string parsed with ConfigBase::Str2Map
 * \param out Receives the handle of the subset dataset
 * \return Status produced by the API_BEGIN / API_END macros
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetGetSubset(
  const DatasetHandle handle,
  const int32_t* used_row_indices,
  int32_t num_used_row_indices,
  const char* parameters,
  DatasetHandle* out) {
  API_BEGIN();
  auto config_map = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(config_map);
  const Dataset* source = reinterpret_cast<const Dataset*>(handle);
  std::unique_ptr<Dataset> subset(new Dataset(num_used_row_indices));
  // the subset shares the feature mappers of the full dataset
  subset->CopyFeatureMapperFrom(source);
  subset->CopySubset(source, used_row_indices, num_used_row_indices, true);
  *out = subset.release();
  API_END();
}

501
/*!
 * \brief Set the feature names of a dataset.
 * \param handle Handle of the dataset
 * \param feature_names Array of C strings
 * \param num_feature_names Number of entries in feature_names
 * \return Status produced by the API_BEGIN / API_END macros
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetSetFeatureNames(
  DatasetHandle handle,
  const char** feature_names,
  int num_feature_names) {
  API_BEGIN();
  Dataset* target = reinterpret_cast<Dataset*>(handle);
  std::vector<std::string> names;
  int pos = 0;
  while (pos < num_feature_names) {
    names.emplace_back(feature_names[pos]);
    ++pos;
  }
  target->set_feature_names(names);
  API_END();
}

515
/*!
 * \brief Copy the feature names of a dataset into caller-provided buffers.
 * \param handle Handle of the dataset
 * \param feature_names One writable buffer per name; names are copied with
 *        strcpy, no buffer size is checked here
 * \param num_feature_names Receives the number of names
 * \return Status produced by the API_BEGIN / API_END macros
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetGetFeatureNames(
  DatasetHandle handle,
  char** feature_names,
  int* num_feature_names) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  // Bind by const reference: avoids copying the whole name list when the
  // accessor returns a reference, and is still valid if it returns by value.
  const auto& inside_feature_name = dataset->feature_names();
  *num_feature_names = static_cast<int>(inside_feature_name.size());
  for (int i = 0; i < *num_feature_names; ++i) {
    std::strcpy(feature_names[i], inside_feature_name[i].c_str());
  }
  API_END();
}

529
/*!
 * \brief Destroy a dataset.
 * \param handle Handle of the dataset to delete
 * \return Status produced by the API_BEGIN / API_END macros
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetFree(DatasetHandle handle) {
  API_BEGIN();
  Dataset* doomed = reinterpret_cast<Dataset*>(handle);
  delete doomed;
  API_END();
}

535
/*!
 * \brief Save a dataset to a binary file.
 * \param handle Handle of the dataset
 * \param filename Forwarded to Dataset::SaveBinaryFile
 * \return Status produced by the API_BEGIN / API_END macros
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetSaveBinary(DatasetHandle handle,
  const char* filename) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->SaveBinaryFile(filename);
  API_END();
}

543
/*!
 * \brief Set a named field of a dataset.
 * \param handle Handle of the dataset
 * \param field_name Name of the field
 * \param field_data Pointer to the values; interpreted according to type
 * \param num_element Number of values in field_data
 * \param type C_API_DTYPE_FLOAT32, C_API_DTYPE_INT32 or C_API_DTYPE_FLOAT64
 * \return Status produced by the API_BEGIN / API_END macros; a std::runtime_error
 *         is raised inside when the type is not one of the above or the dataset
 *         setter reports failure
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle,
  const char* field_name,
  const void* field_data,
  int num_element,
  int type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  bool is_success = false;
  if (type == C_API_DTYPE_FLOAT32) {
    is_success = dataset->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), static_cast<int32_t>(num_element));
  } else if (type == C_API_DTYPE_INT32) {
    is_success = dataset->SetIntField(field_name, reinterpret_cast<const int*>(field_data), static_cast<int32_t>(num_element));
  } else if (type == C_API_DTYPE_FLOAT64) {
    is_success = dataset->SetDoubleField(field_name, reinterpret_cast<const double*>(field_data), static_cast<int32_t>(num_element));
  }
  // message previously read "erorr"
  if (!is_success) { throw std::runtime_error("Input data type error or field not found"); }
  API_END();
}

562
/*!
 * \brief Look up a named field of a dataset.
 *
 * The float, int and double getters are tried in that order; the first one
 * that succeeds determines out_type.
 * \param handle Handle of the dataset
 * \param field_name Name of the field
 * \param out_len Receives the number of elements (0 when the pointer is null)
 * \param out_ptr Receives the pointer to the field data
 * \param out_type Receives C_API_DTYPE_FLOAT32, C_API_DTYPE_INT32 or
 *        C_API_DTYPE_FLOAT64
 * \return Status produced by the API_BEGIN / API_END macros; a std::runtime_error
 *         is raised inside when no getter succeeds
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetGetField(DatasetHandle handle,
  const char* field_name,
  int* out_len,
  const void** out_ptr,
  int* out_type) {
  API_BEGIN();
  Dataset* source = reinterpret_cast<Dataset*>(handle);
  if (source->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (source->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else if (source->GetDoubleField(field_name, out_len, reinterpret_cast<const double**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT64;
  } else {
    throw std::runtime_error("Field not found");
  }
  // a field that exists but holds no data is reported with length zero
  if (*out_ptr == nullptr) {
    *out_len = 0;
  }
  API_END();
}

585
/*!
 * \brief Get the number of data rows of a dataset.
 * \param handle Handle of the dataset
 * \param out Receives Dataset::num_data()
 * \return Status produced by the API_BEGIN / API_END macros
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumData(DatasetHandle handle,
  int* out) {
  API_BEGIN();
  *out = reinterpret_cast<Dataset*>(handle)->num_data();
  API_END();
}

593
/*!
 * \brief Get the number of features of a dataset.
 * \param handle Handle of the dataset
 * \param out Receives Dataset::num_total_features()
 * \return Status produced by the API_BEGIN / API_END macros
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumFeature(DatasetHandle handle,
  int* out) {
  API_BEGIN();
  *out = reinterpret_cast<Dataset*>(handle)->num_total_features();
  API_END();
}
600
601
602

// ---- start of booster

603
// Constructs a Booster from a training dataset handle and a parameter string,
// then hands ownership of it to the caller through *out.
LIGHTGBM_C_EXPORT int LGBM_BoosterCreate(const DatasetHandle train_data,
  const char* parameters,
  BoosterHandle* out) {
  API_BEGIN();
  // Held by unique_ptr until release() so the object is freed if anything
  // throws before ownership is transferred.
  std::unique_ptr<Booster> booster(
    new Booster(reinterpret_cast<const Dataset*>(train_data), parameters));
  *out = booster.release();
  API_END();
}

613
// Constructs a Booster from the model file `filename`, reports its current
// iteration count through *out_num_iterations, and hands ownership of the
// booster to the caller through *out.
LIGHTGBM_C_EXPORT int LGBM_BoosterCreateFromModelfile(
  const char* filename,
  int* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
  std::unique_ptr<Booster> booster(new Booster(filename));
  *out_num_iterations = booster->GetBoosting()->GetCurrentIteration();
  // Ownership leaves the unique_ptr only after the iteration count was read.
  *out = booster.release();
  API_END();
}

624
625
626
627
628
629
630
631
632
633
634
635
// Default-constructs a Booster, loads the model text `model_str` into it,
// reports the current iteration count through *out_num_iterations, and hands
// ownership of the booster to the caller through *out.
LIGHTGBM_C_EXPORT int LGBM_BoosterLoadModelFromString(
  const char* model_str,
  int* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
  std::unique_ptr<Booster> booster(new Booster());
  booster->LoadModelFromString(model_str);
  *out_num_iterations = booster->GetBoosting()->GetCurrentIteration();
  *out = booster.release();
  API_END();
}

636
// Deletes the Booster object behind `handle`.
LIGHTGBM_C_EXPORT int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  delete booster;
  API_END();
}

642
// Calls MergeFrom on the booster behind `handle`, passing the booster behind
// `other_handle` as the source.
LIGHTGBM_C_EXPORT int LGBM_BoosterMerge(BoosterHandle handle,
  BoosterHandle other_handle) {
  API_BEGIN();
  Booster* target = reinterpret_cast<Booster*>(handle);
  Booster* source = reinterpret_cast<Booster*>(other_handle);
  target->MergeFrom(source);
  API_END();
}

651
// Passes the dataset behind `valid_data` to AddValidData on the booster
// behind `handle`.
LIGHTGBM_C_EXPORT int LGBM_BoosterAddValidData(BoosterHandle handle,
  const DatasetHandle valid_data) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->AddValidData(reinterpret_cast<const Dataset*>(valid_data));
  API_END();
}

660
// Passes the dataset behind `train_data` to ResetTrainingData on the booster
// behind `handle`.
LIGHTGBM_C_EXPORT int LGBM_BoosterResetTrainingData(BoosterHandle handle,
  const DatasetHandle train_data) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->ResetTrainingData(reinterpret_cast<const Dataset*>(train_data));
  API_END();
}

669
// Forwards the parameter string to ResetConfig on the booster behind `handle`.
LIGHTGBM_C_EXPORT int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ResetConfig(parameters);
  API_END();
}

676
// Stores the underlying boosting object's NumberOfClasses() in *out_len.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto boosting = booster->GetBoosting();
  *out_len = boosting->NumberOfClasses();
  API_END();
}

683
// Runs one training iteration via Booster::TrainOneIter() and writes 1 to
// *is_finished when that call returns true, 0 otherwise.
LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const bool finished = booster->TrainOneIter();
  *is_finished = finished ? 1 : 0;
  API_END();
}

694
// Runs one training iteration with caller-supplied `grad` / `hess` arrays via
// Booster::TrainOneIter(grad, hess) and writes 1 to *is_finished when that
// call returns true, 0 otherwise.
LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
  const float* grad,
  const float* hess,
  int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const bool finished = booster->TrainOneIter(grad, hess);
  *is_finished = finished ? 1 : 0;
  API_END();
}

708
// Calls RollbackOneIter() on the booster behind `handle`.
LIGHTGBM_C_EXPORT int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

715
// Stores the underlying boosting object's GetCurrentIteration() in *out_iteration.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int* out_iteration) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto boosting = booster->GetBoosting();
  *out_iteration = boosting->GetCurrentIteration();
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
721

722
// Stores the result of Booster::GetEvalCounts() in *out_len.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalCounts();
  API_END();
}

729
// Passes `out_strs` to Booster::GetEvalNames() and stores that call's return
// value in *out_len.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetEvalNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalNames(out_strs);
  API_END();
}

wxchan's avatar
wxchan committed
736
737
738
739
740
741
742
743
744
745
746
747
748
749
// Passes `out_strs` to Booster::GetFeatureNames() and stores that call's
// return value in *out_len.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetFeatureNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetFeatureNames(out_strs);
  API_END();
}

// Stores MaxFeatureIdx() + 1 of the underlying boosting object in *out_len.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumFeature(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto boosting = booster->GetBoosting();
  *out_len = boosting->MaxFeatureIdx() + 1;
  API_END();
}

750
// Fetches the evaluation results for `data_idx` via GetEvalAt(), writes how
// many there are to *out_len and copies each one (converted to double) into
// out_results. The caller's buffer size is not checked here.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetEval(BoosterHandle handle,
  int data_idx,
  int* out_len,
  double* out_results) {
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  auto metrics = boosting->GetEvalAt(data_idx);
  const size_t count = metrics.size();
  *out_len = static_cast<int>(count);
  for (size_t k = 0; k < count; ++k) {
    out_results[k] = static_cast<double>(metrics[k]);
  }
  API_END();
}

765
// Stores the boosting object's GetNumPredictAt(data_idx) in *out_len.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumPredict(BoosterHandle handle,
  int data_idx,
  int64_t* out_len) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  *out_len = booster->GetBoosting()->GetNumPredictAt(data_idx);
  API_END();
}

774
// Forwards to Booster::GetPredictAt(data_idx, out_result, out_len); both
// output arguments are filled by that call.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetPredict(BoosterHandle handle,
  int data_idx,
  int64_t* out_len,
  double* out_result) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->GetPredictAt(data_idx, out_result, out_len);
  API_END();
}

784
// Builds a predictor for (num_iteration, predict_type) and runs it on the
// file `data_filename`, writing results to `result_filename`.
// `data_has_header` is treated as true when it is greater than zero.
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForFile(BoosterHandle handle,
  const char* data_filename,
  int data_has_header,
  int predict_type,
  int num_iteration,
  const char* result_filename) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto predictor = booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
  const bool has_header = (data_has_header > 0);
  predictor.Predict(data_filename, result_filename, has_header);
  API_END();
}

798
799
int64_t GetNumPredOneRow(const Booster* ref_booster, int predict_type, int64_t num_iteration) {
  int64_t num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
Guolin Ke's avatar
Guolin Ke committed
800
801
802
803
804
805
806
807
808
809
810
  if (predict_type == C_API_PREDICT_LEAF_INDEX) {
    int64_t max_iteration = ref_booster->GetBoosting()->GetCurrentIteration();
    if (num_iteration > 0) {
      num_preb_in_one_row *= static_cast<int>(std::min(max_iteration, num_iteration));
    } else {
      num_preb_in_one_row *= max_iteration;
    }
  }
  return num_preb_in_one_row;
}

811
// Writes to *out_len the total number of prediction values for `num_row`
// rows: num_row times the per-row count from GetNumPredOneRow().
LIGHTGBM_C_EXPORT int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
  int num_row,
  int predict_type,
  int num_iteration,
  int64_t* out_len) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const int64_t preds_per_row = GetNumPredOneRow(booster, predict_type, num_iteration);
  *out_len = static_cast<int64_t>(num_row * preds_per_row);
  API_END();
}

822
// Predicts every row of a CSR matrix.
// The row count is nindptr - 1. Row i's results are written starting at
// out_result[i * per-row-count], where the per-row count comes from
// GetNumPredOneRow(); *out_len receives rows * per-row-count.
// The tenth parameter is unnamed and unused. The caller's buffer size is not
// checked here.
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle,
  const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t,
  int predict_type,
  int num_iteration,
  int64_t* out_len,
  double* out_result) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto predictor = booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
  auto fetch_row = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  const int64_t preds_per_row = GetNumPredOneRow(booster, predict_type, num_iteration);
  const int num_rows = static_cast<int>(nindptr - 1);
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < num_rows; ++i) {
    auto sparse_row = fetch_row(i);
    auto scores = predictor.GetPredictFunction()(sparse_row);
    // Each row owns a disjoint slice of out_result.
    double* row_out = out_result + i * preds_per_row;
    const int num_scores = static_cast<int>(scores.size());
    for (int j = 0; j < num_scores; ++j) {
      row_out[j] = static_cast<double>(scores[j]);
    }
  }
  *out_len = num_rows * preds_per_row;
  API_END();
}
852

853
// Predicts every row of a CSC matrix.
// The column count is ncol_ptr - 1. Rows [0, num_row) are processed through
// Threading::For; each invocation of the callback builds its own set of
// per-column CSC_RowIterator objects, assembles each row from the entries
// whose absolute value exceeds kEpsilon, and writes that row's results
// starting at out_result[row * per-row-count]. *out_len receives
// num_row * per-row-count. The caller's buffer size is not checked here.
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle,
  const void* col_ptr,
  int col_ptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t ncol_ptr,
  int64_t nelem,
  int64_t num_row,
  int predict_type,
  int num_iteration,
  int64_t* out_len,
  double* out_result) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto predictor = booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
  const int64_t preds_per_row = GetNumPredOneRow(booster, predict_type, num_iteration);
  const int num_cols = static_cast<int>(ncol_ptr - 1);

  Threading::For<int64_t>(0, num_row,
    [&predictor, &out_result, preds_per_row, num_cols, col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem]
  (int, data_size_t start, data_size_t end) {
    // One iterator per column, private to this callback invocation.
    std::vector<CSC_RowIterator> column_iters;
    for (int c = 0; c < num_cols; ++c) {
      column_iters.emplace_back(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, c);
    }
    std::vector<std::pair<int, double>> sparse_row;
    for (int64_t r = start; r < end; ++r) {
      sparse_row.clear();
      for (int c = 0; c < num_cols; ++c) {
        auto value = column_iters[c].Get(static_cast<int>(r));
        if (std::fabs(value) > kEpsilon) {
          sparse_row.emplace_back(c, value);
        }
      }
      auto scores = predictor.GetPredictFunction()(sparse_row);
      double* row_out = out_result + r * preds_per_row;
      const int num_scores = static_cast<int>(scores.size());
      for (int j = 0; j < num_scores; ++j) {
        row_out[j] = static_cast<double>(scores[j]);
      }
    }
  });
  *out_len = num_row * preds_per_row;
  API_END();
}

898
// Predicts every row of a dense nrow x ncol matrix.
// Rows are obtained through RowPairFunctionFromDenseMatric(); row i's results
// are written starting at out_result[i * per-row-count], and *out_len
// receives nrow * per-row-count. The caller's buffer size is not checked here.
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMat(BoosterHandle handle,
  const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  int predict_type,
  int num_iteration,
  int64_t* out_len,
  double* out_result) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto predictor = booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
  auto fetch_row = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  const int64_t preds_per_row = GetNumPredOneRow(booster, predict_type, num_iteration);
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    auto sparse_row = fetch_row(i);
    auto scores = predictor.GetPredictFunction()(sparse_row);
    // Each row owns a disjoint slice of out_result.
    double* row_out = out_result + i * preds_per_row;
    const int num_scores = static_cast<int>(scores.size());
    for (int j = 0; j < num_scores; ++j) {
      row_out[j] = static_cast<double>(scores[j]);
    }
  }
  *out_len = nrow * preds_per_row;
  API_END();
}
924

925
// Forwards to Booster::SaveModelToFile(num_iteration, filename).
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle,
  int num_iteration,
  const char* filename) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->SaveModelToFile(num_iteration, filename);
  API_END();
}

934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
// Serialises the model via Booster::SaveModelToString(num_iteration).
// *out_len always receives the required buffer size (string length plus the
// terminating NUL). The text is copied into out_str only when that size fits
// within buffer_len; otherwise out_str is left untouched.
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle,
  int num_iteration,
  int buffer_len,
  int* out_len,
  char* out_str) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const std::string text = booster->SaveModelToString(num_iteration);
  const int required = static_cast<int>(text.size()) + 1;
  *out_len = required;
  if (required <= buffer_len) {
    std::strcpy(out_str, text.c_str());
  }
  API_END();
}

949
// Dumps the model via Booster::DumpModel(num_iteration).
// *out_len always receives the required buffer size (string length plus the
// terminating NUL). The text is copied into out_str only when that size fits
// within buffer_len; otherwise out_str is left untouched.
LIGHTGBM_C_EXPORT int LGBM_BoosterDumpModel(BoosterHandle handle,
  int num_iteration,
  int buffer_len,
  int* out_len,
  char* out_str) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const std::string text = booster->DumpModel(num_iteration);
  const int required = static_cast<int>(text.size()) + 1;
  *out_len = required;
  if (required <= buffer_len) {
    std::strcpy(out_str, text.c_str());
  }
  API_END();
}
963

964
// Stores Booster::GetLeafValue(tree_idx, leaf_idx), converted to double, in *out_val.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetLeafValue(BoosterHandle handle,
  int tree_idx,
  int leaf_idx,
  double* out_val) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto leaf_value = booster->GetLeafValue(tree_idx, leaf_idx);
  *out_val = static_cast<double>(leaf_value);
  API_END();
}

974
// Forwards to Booster::SetLeafValue(tree_idx, leaf_idx, val).
LIGHTGBM_C_EXPORT int LGBM_BoosterSetLeafValue(BoosterHandle handle,
  int tree_idx,
  int leaf_idx,
  double val) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->SetLeafValue(tree_idx, leaf_idx, val);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
984
// ---- start of some helper functions
985
986
987

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
988
  if (data_type == C_API_DTYPE_FLOAT32) {
989
990
991
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
992
        std::vector<double> ret(num_col);
993
994
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
995
          ret[i] = static_cast<double>(*(tmp_ptr + i));
996
997
998
999
1000
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1001
        std::vector<double> ret(num_col);
1002
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1003
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
1004
1005
1006
1007
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1008
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1009
1010
1011
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1012
        std::vector<double> ret(num_col);
1013
1014
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1015
          ret[i] = static_cast<double>(*(tmp_ptr + i));
1016
1017
1018
1019
1020
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1021
        std::vector<double> ret(num_col);
1022
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1023
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
1024
1025
1026
1027
1028
        }
        return ret;
      };
    }
  }
Guolin Ke's avatar
Guolin Ke committed
1029
  throw std::runtime_error("unknown data type in RowFunctionFromDenseMatric");
1030
1031
1032
1033
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1034
1035
1036
1037
1038
1039
1040
1041
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
    return [inner_function](int row_idx) {
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
        if (std::fabs(raw_values[i]) > 1e-15) {
          ret.emplace_back(i, raw_values[i]);
1042
        }
Guolin Ke's avatar
Guolin Ke committed
1043
1044
1045
      }
      return ret;
    };
1046
  }
Guolin Ke's avatar
Guolin Ke committed
1047
  return nullptr;
1048
1049
1050
1051
}

std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
1052
  if (data_type == C_API_DTYPE_FLOAT32) {
1053
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1054
    if (indptr_type == C_API_DTYPE_INT32) {
1055
1056
1057
1058
1059
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1060
        for (int64_t i = start; i < end; ++i) {
1061
1062
1063
1064
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1065
    } else if (indptr_type == C_API_DTYPE_INT64) {
1066
1067
1068
1069
1070
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1071
        for (int64_t i = start; i < end; ++i) {
1072
1073
1074
1075
1076
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1077
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1078
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1079
    if (indptr_type == C_API_DTYPE_INT32) {
1080
1081
1082
1083
1084
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1085
        for (int64_t i = start; i < end; ++i) {
1086
1087
1088
1089
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1090
    } else if (indptr_type == C_API_DTYPE_INT64) {
1091
1092
1093
1094
1095
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1096
        for (int64_t i = start; i < end; ++i) {
1097
1098
1099
1100
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1101
1102
    }
  }
Guolin Ke's avatar
Guolin Ke committed
1103
  throw std::runtime_error("unknown data type in RowFunctionFromCSR");
1104
1105
}

Guolin Ke's avatar
Guolin Ke committed
1106
1107
1108
std::function<std::pair<int, double>(int idx)>
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx) {
  CHECK(col_idx < ncol_ptr && col_idx >= 0);
Guolin Ke's avatar
Guolin Ke committed
1109
  if (data_type == C_API_DTYPE_FLOAT32) {
1110
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1111
    if (col_ptr_type == C_API_DTYPE_INT32) {
1112
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1113
1114
1115
1116
1117
1118
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1119
        }
Guolin Ke's avatar
Guolin Ke committed
1120
1121
1122
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1123
      };
Guolin Ke's avatar
Guolin Ke committed
1124
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1125
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1126
1127
1128
1129
1130
1131
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1132
        }
Guolin Ke's avatar
Guolin Ke committed
1133
1134
1135
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1136
      };
Guolin Ke's avatar
Guolin Ke committed
1137
    }
Guolin Ke's avatar
Guolin Ke committed
1138
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1139
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1140
    if (col_ptr_type == C_API_DTYPE_INT32) {
1141
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1142
1143
1144
1145
1146
1147
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1148
        }
Guolin Ke's avatar
Guolin Ke committed
1149
1150
1151
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1152
      };
Guolin Ke's avatar
Guolin Ke committed
1153
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1154
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1155
1156
1157
1158
1159
1160
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1161
        }
Guolin Ke's avatar
Guolin Ke committed
1162
1163
1164
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1165
      };
Guolin Ke's avatar
Guolin Ke committed
1166
1167
1168
    }
  }
  throw std::runtime_error("unknown data type in CSC matrix");
1169
1170
}

Guolin Ke's avatar
Guolin Ke committed
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
// Binds this iterator to column `col_idx` of the given CSC matrix by storing
// the entry accessor produced by IterateFunctionFromCSC().
CSC_RowIterator::CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
  const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx) {
  auto column_accessor =
    IterateFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, col_idx);
  iter_fun_ = column_accessor;
}

// Returns the column's value at row `idx`, or zero when no entry is stored
// for that row. The iterator only moves forward: stored entries are consumed
// until the current row index reaches or passes `idx`, or the column ends.
double CSC_RowIterator::Get(int idx) {
  while (!is_end_ && cur_idx_ < idx) {
    const auto entry = iter_fun_(nonzero_idx_);
    if (entry.first < 0) {
      // Negative row index marks the end of the column.
      is_end_ = true;
      break;
    }
    cur_idx_ = entry.first;
    cur_val_ = entry.second;
    ++nonzero_idx_;
  }
  return (cur_idx_ == idx) ? cur_val_ : 0.0f;
}

// Returns the next stored (row index, value) entry of the column and advances
// the position. Once the end has been reached, every call returns (-1, 0.0).
std::pair<int, double> CSC_RowIterator::NextNonZero() {
  if (is_end_) {
    return std::make_pair(-1, 0.0);
  }
  auto entry = iter_fun_(nonzero_idx_);
  ++nonzero_idx_;
  if (entry.first < 0) {
    // Negative row index marks the end of the column.
    is_end_ = true;
  }
  return entry;
}