c_api.cpp 39.2 KB
Newer Older
1
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
2
3
4

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/utils/threading.h>
Guolin Ke's avatar
Guolin Ke committed
6
#include <LightGBM/c_api.h>
Guolin Ke's avatar
Guolin Ke committed
7
#include <LightGBM/dataset_loader.h>
Guolin Ke's avatar
Guolin Ke committed
8
9
10
11
12
13
14
15
16
17
#include <LightGBM/dataset.h>
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/config.h>

#include <cstdio>
#include <vector>
#include <string>
#include <cstring>
Guolin Ke's avatar
Guolin Ke committed
18
#include <memory>
Guolin Ke's avatar
Guolin Ke committed
19
#include <stdexcept>
wxchan's avatar
wxchan committed
20
#include <mutex>
Guolin Ke's avatar
Guolin Ke committed
21
#include <functional>
Guolin Ke's avatar
Guolin Ke committed
22

Guolin Ke's avatar
Guolin Ke committed
23
#include "./application/predictor.hpp"
Guolin Ke's avatar
Guolin Ke committed
24
#include "./boosting/gbdt.h"
Guolin Ke's avatar
Guolin Ke committed
25

Guolin Ke's avatar
Guolin Ke committed
26
27
28
29
namespace LightGBM {

class Booster {
public:
Guolin Ke's avatar
Guolin Ke committed
30
31
  explicit Booster(const char* filename) {
    boosting_.reset(Boosting::CreateBoosting(filename));
Guolin Ke's avatar
Guolin Ke committed
32
33
  }

34
35
36
37
  Booster() {
    boosting_.reset(Boosting::CreateBoosting("gbdt", nullptr));
  }

Guolin Ke's avatar
Guolin Ke committed
38
  Booster(const Dataset* train_data,
wxchan's avatar
wxchan committed
39
40
41
    const char* parameters) {
    auto param = ConfigBase::Str2Map(parameters);
    config_.Set(param);
42
43
44
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
45
46
    // create boosting
    if (config_.io_config.input_model.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
47
      Log::Warning("continued train from model is not support for c_api, \
Guolin Ke's avatar
Guolin Ke committed
48
49
        please use continued train with input score");
    }
Guolin Ke's avatar
Guolin Ke committed
50

wxchan's avatar
wxchan committed
51
    boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
Guolin Ke's avatar
Guolin Ke committed
52

Guolin Ke's avatar
Guolin Ke committed
53
    // initialize the boosting
Guolin Ke's avatar
Guolin Ke committed
54
    boosting_->Init(&config_.boosting_config, nullptr, objective_fun_.get(),
Guolin Ke's avatar
Guolin Ke committed
55
      Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
56
57

    ResetTrainingData(train_data);
wxchan's avatar
wxchan committed
58
59
60
61
62
  }

  void MergeFrom(const Booster* other) {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->MergeFrom(other->boosting_.get());
Guolin Ke's avatar
Guolin Ke committed
63
64
65
  }

  ~Booster() {
Guolin Ke's avatar
Guolin Ke committed
66

Guolin Ke's avatar
Guolin Ke committed
67
  }
68

wxchan's avatar
wxchan committed
69
70
71
  void ResetTrainingData(const Dataset* train_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    train_data_ = train_data;
Guolin Ke's avatar
Guolin Ke committed
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
    // create objective function
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
      config_.objective_config));
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective function");
    }
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
    }

    // create training metric
    train_metric_.clear();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(
        Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(train_data_->metadata(), train_data_->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
    // reset the boosting
Guolin Ke's avatar
Guolin Ke committed
94
    boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
wxchan's avatar
wxchan committed
95
96
97
98
      objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
  }

  void ResetConfig(const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
99
    std::lock_guard<std::mutex> lock(mutex_);
wxchan's avatar
wxchan committed
100
101
102
103
104
105
106
    auto param = ConfigBase::Str2Map(parameters);
    if (param.count("num_class")) {
      Log::Fatal("cannot change num class during training");
    }
    if (param.count("boosting_type")) {
      Log::Fatal("cannot change boosting_type during training");
    }
Guolin Ke's avatar
Guolin Ke committed
107
108
109
    if (param.count("metric")) {
      Log::Fatal("cannot change metric during training");
    }
Guolin Ke's avatar
Guolin Ke committed
110
111

    config_.Set(param);
112
113
114
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
115
116
117
118
119
120
121
122
123
124
125
126

    if (param.count("objective")) {
      // create objective function
      objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
        config_.objective_config));
      if (objective_fun_ == nullptr) {
        Log::Warning("Using self-defined objective function");
      }
      // initialize the objective function
      if (objective_fun_ != nullptr) {
        objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
      }
wxchan's avatar
wxchan committed
127
    }
Guolin Ke's avatar
Guolin Ke committed
128
129
130

    boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
      objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
131

wxchan's avatar
wxchan committed
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
  }

  void AddValidData(const Dataset* valid_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    valid_metrics_.emplace_back();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
      Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
  }
Guolin Ke's avatar
Guolin Ke committed
147

148
  bool TrainOneIter() {
wxchan's avatar
wxchan committed
149
    std::lock_guard<std::mutex> lock(mutex_);
150
151
152
153
    return boosting_->TrainOneIter(nullptr, nullptr, false);
  }

  bool TrainOneIter(const float* gradients, const float* hessians) {
wxchan's avatar
wxchan committed
154
    std::lock_guard<std::mutex> lock(mutex_);
155
156
157
    return boosting_->TrainOneIter(gradients, hessians, false);
  }

wxchan's avatar
wxchan committed
158
159
160
161
162
  void RollbackOneIter() {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->RollbackOneIter();
  }

Guolin Ke's avatar
Guolin Ke committed
163
  Predictor NewPredictor(int num_iteration, int predict_type) {
wxchan's avatar
wxchan committed
164
165
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->SetNumIterationForPred(num_iteration);
Guolin Ke's avatar
Guolin Ke committed
166
167
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
168
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
169
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
170
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
171
      is_raw_score = true;
Guolin Ke's avatar
Guolin Ke committed
172
173
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
174
    }
Guolin Ke's avatar
Guolin Ke committed
175
176
177
    // not threading safe now
    // boosting_->SetNumIterationForPred may be set by other thread during prediction. 
    return Predictor(boosting_.get(), is_raw_score, is_predict_leaf);
Guolin Ke's avatar
Guolin Ke committed
178
179
  }

Guolin Ke's avatar
Guolin Ke committed
180
  void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
181
182
183
184
185
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

  void SaveModelToFile(int num_iteration, const char* filename) {
    boosting_->SaveModelToFile(num_iteration, filename);
Guolin Ke's avatar
Guolin Ke committed
186
  }
187

188
189
190
191
192
193
194
195
  void LoadModelFromString(const char* model_str) {
    boosting_->LoadModelFromString(model_str);
  }

  std::string SaveModelToString(int num_iteration) {
    return boosting_->SaveModelToString(num_iteration);
  }

196
197
  std::string DumpModel(int num_iteration) {
    return boosting_->DumpModel(num_iteration);
wxchan's avatar
wxchan committed
198
  }
199

Guolin Ke's avatar
Guolin Ke committed
200
201
202
203
204
205
206
207
208
  double GetLeafValue(int tree_idx, int leaf_idx) const {
    return dynamic_cast<GBDT*>(boosting_.get())->GetLeafValue(tree_idx, leaf_idx);
  }

  void SetLeafValue(int tree_idx, int leaf_idx, double val) {
    std::lock_guard<std::mutex> lock(mutex_);
    dynamic_cast<GBDT*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val);
  }

wxchan's avatar
wxchan committed
209
210
211
212
213
214
215
  int GetEvalCounts() const {
    int ret = 0;
    for (const auto& metric : train_metric_) {
      ret += static_cast<int>(metric->GetName().size());
    }
    return ret;
  }
216

wxchan's avatar
wxchan committed
217
218
219
220
221
222
223
224
225
226
227
228
  int GetEvalNames(char** out_strs) const {
    int idx = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
        std::strcpy(out_strs[idx], name.c_str());
        ++idx;
      }
    }
    return idx;
  }

  const Boosting* GetBoosting() const { return boosting_.get(); }
Guolin Ke's avatar
Guolin Ke committed
229

Guolin Ke's avatar
Guolin Ke committed
230
private:
231

wxchan's avatar
wxchan committed
232
  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
233
  std::unique_ptr<Boosting> boosting_;
Guolin Ke's avatar
Guolin Ke committed
234
235
236
  /*! \brief All configs */
  OverallConfig config_;
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
237
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
238
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
239
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
240
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
241
  std::unique_ptr<ObjectiveFunction> objective_fun_;
wxchan's avatar
wxchan committed
242
243
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
244
245
246
};

}
Guolin Ke's avatar
Guolin Ke committed
247
248
249

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
// some help functions used to convert data

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major);

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major);

std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
  const void* data, int data_type, int64_t nindptr, int64_t nelem);

// Iterates over the rows of a single column of a CSC matrix.
// No user-declared destructor (Rule of Zero), so the implicit move
// constructor/assignment stay available and iter_fun_ is moved, not copied.
class CSC_RowIterator {
public:
  CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
    const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx);
  // Return the value at row `idx`; rows may only be requested in ascending order.
  double Get(int idx);
  // Return the next non-zero (row index, value) pair; a negative index means
  // there is no more data.
  std::pair<int, double> NextNonZero();
private:
  int nonzero_idx_ = 0;
  int cur_idx_ = -1;
  double cur_val_ = 0.0;
  bool is_end_ = false;
  std::function<std::pair<int, double>(int idx)> iter_fun_;
};

// start of c_api functions

282
LIGHTGBM_C_EXPORT const char* LGBM_GetLastError() {
wxchan's avatar
wxchan committed
283
  return LastErrorMsg();
Guolin Ke's avatar
Guolin Ke committed
284
285
}

286
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromFile(const char* filename,
Guolin Ke's avatar
Guolin Ke committed
287
  const char* parameters,
288
  const DatasetHandle reference,
Guolin Ke's avatar
typo  
Guolin Ke committed
289
  DatasetHandle* out) {
290
  API_BEGIN();
wxchan's avatar
wxchan committed
291
292
293
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
294
  DatasetLoader loader(io_config, nullptr, 1, filename);
Guolin Ke's avatar
Guolin Ke committed
295
  if (reference == nullptr) {
Guolin Ke's avatar
Guolin Ke committed
296
    *out = loader.LoadFromFile(filename);
Guolin Ke's avatar
Guolin Ke committed
297
  } else {
Guolin Ke's avatar
Guolin Ke committed
298
    *out = loader.LoadFromFileAlignWithOtherDataset(filename,
299
      reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
300
  }
301
  API_END();
Guolin Ke's avatar
Guolin Ke committed
302
303
}

304
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromMat(const void* data,
305
  int data_type,
Guolin Ke's avatar
Guolin Ke committed
306
307
308
309
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  const char* parameters,
310
  const DatasetHandle reference,
Guolin Ke's avatar
typo  
Guolin Ke committed
311
  DatasetHandle* out) {
312
  API_BEGIN();
wxchan's avatar
wxchan committed
313
314
315
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
Guolin Ke's avatar
Guolin Ke committed
316
  std::unique_ptr<Dataset> ret;
317
  auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
Guolin Ke's avatar
Guolin Ke committed
318
319
  if (reference == nullptr) {
    // sample data first
wxchan's avatar
wxchan committed
320
321
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
322
    auto sample_indices = rand.Sample(nrow, sample_cnt);
323
    std::vector<std::vector<double>> sample_values(ncol);
Guolin Ke's avatar
Guolin Ke committed
324
    for (size_t i = 0; i < sample_indices.size(); ++i) {
Guolin Ke's avatar
Guolin Ke committed
325
      auto idx = sample_indices[i];
326
      auto row = get_row_fun(static_cast<int>(idx));
Guolin Ke's avatar
Guolin Ke committed
327
      for (size_t j = 0; j < row.size(); ++j) {
Guolin Ke's avatar
Guolin Ke committed
328
329
330
        if (std::fabs(row[j]) > 1e-15) {
          sample_values[j].push_back(row[j]);
        }
Guolin Ke's avatar
Guolin Ke committed
331
332
      }
    }
333
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
334
    ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
Guolin Ke's avatar
Guolin Ke committed
335
  } else {
336
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
337
    ret->CopyFeatureMapperFrom(
338
      reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
339
340
341
342
343
  }

#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    const int tid = omp_get_thread_num();
344
    auto one_row = get_row_fun(i);
Guolin Ke's avatar
Guolin Ke committed
345
346
347
    ret->PushOneRow(tid, i, one_row);
  }
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
348
  *out = ret.release();
349
  API_END();
350
351
}

352
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSR(const void* indptr,
353
  int indptr_type,
354
355
  const int32_t* indices,
  const void* data,
356
357
358
359
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t num_col,
360
  const char* parameters,
361
  const DatasetHandle reference,
Guolin Ke's avatar
typo  
Guolin Ke committed
362
  DatasetHandle* out) {
363
  API_BEGIN();
wxchan's avatar
wxchan committed
364
365
366
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
Guolin Ke's avatar
Guolin Ke committed
367
  std::unique_ptr<Dataset> ret;
368
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
369
370
371
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference == nullptr) {
    // sample data first
wxchan's avatar
wxchan committed
372
373
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
374
375
376
377
378
379
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values;
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      auto row = get_row_fun(static_cast<int>(idx));
      for (std::pair<int, double>& inner_data : row) {
380
381
382
383
384
        if (static_cast<size_t>(inner_data.first) >= sample_values.size()) {
          // if need expand feature set
          size_t need_size = inner_data.first - sample_values.size() + 1;
          for (size_t j = 0; j < need_size; ++j) {
            sample_values.emplace_back();
385
          }
386
387
        }
        if (std::fabs(inner_data.second) > 1e-15) {
Guolin Ke's avatar
Guolin Ke committed
388
389
          // edit the feature value
          sample_values[inner_data.first].push_back(inner_data.second);
390
391
392
        }
      }
    }
393
    CHECK(num_col >= static_cast<int>(sample_values.size()));
394
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
395
    ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
396
  } else {
397
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
398
    ret->CopyFeatureMapperFrom(
399
      reinterpret_cast<const Dataset*>(reference));
400
401
402
403
404
405
406
407
408
  }

#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nindptr - 1; ++i) {
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
  }
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
409
  *out = ret.release();
410
  API_END();
411
412
}

413
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr,
414
  int col_ptr_type,
Guolin Ke's avatar
Guolin Ke committed
415
416
  const int32_t* indices,
  const void* data,
417
418
419
420
  int data_type,
  int64_t ncol_ptr,
  int64_t nelem,
  int64_t num_row,
Guolin Ke's avatar
Guolin Ke committed
421
  const char* parameters,
422
  const DatasetHandle reference,
Guolin Ke's avatar
typo  
Guolin Ke committed
423
  DatasetHandle* out) {
424
  API_BEGIN();
wxchan's avatar
wxchan committed
425
426
427
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
Guolin Ke's avatar
Guolin Ke committed
428
  std::unique_ptr<Dataset> ret;
Guolin Ke's avatar
Guolin Ke committed
429
430
431
  int32_t nrow = static_cast<int32_t>(num_row);
  if (reference == nullptr) {
    // sample data first
wxchan's avatar
wxchan committed
432
433
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
434
435
436
437
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
#pragma omp parallel for schedule(guided)
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
438
439
440
441
442
443
444
      CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
      for (int j = 0; j < sample_cnt; j++) {
        auto val = col_it.Get(sample_indices[j]);
        if (std::fabs(val) > kEpsilon) {
          sample_values[i].push_back(val);
        }
      }
Guolin Ke's avatar
Guolin Ke committed
445
    }
446
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
447
    ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
Guolin Ke's avatar
Guolin Ke committed
448
  } else {
449
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
450
    ret->CopyFeatureMapperFrom(
451
      reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
452
453
454
455
456
  }

#pragma omp parallel for schedule(guided)
  for (int i = 0; i < ncol_ptr - 1; ++i) {
    const int tid = omp_get_thread_num();
Guolin Ke's avatar
Guolin Ke committed
457
458
459
460
461
462
463
464
465
466
467
    int feature_idx = ret->GetInnerFeatureIndex(i);
    if (feature_idx < 0) { continue; }
    CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
    int row_idx = 0;
    while (row_idx < nrow) {
      auto pair = col_it.NextNonZero();
      row_idx = pair.first;
      // no more data
      if (row_idx < 0) { break; }
      ret->FeatureAt(feature_idx)->PushData(tid, row_idx, pair.second);
    }
Guolin Ke's avatar
Guolin Ke committed
468
469
  }
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
470
  *out = ret.release();
471
  API_END();
Guolin Ke's avatar
Guolin Ke committed
472
473
}

474
LIGHTGBM_C_EXPORT int LGBM_DatasetGetSubset(
475
  const DatasetHandle handle,
wxchan's avatar
wxchan committed
476
477
478
  const int32_t* used_row_indices,
  int32_t num_used_row_indices,
  const char* parameters,
Guolin Ke's avatar
typo  
Guolin Ke committed
479
  DatasetHandle* out) {
wxchan's avatar
wxchan committed
480
481
482
483
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
484
  auto full_dataset = reinterpret_cast<const Dataset*>(handle);
Guolin Ke's avatar
Guolin Ke committed
485
  auto ret = std::unique_ptr<Dataset>(new Dataset(num_used_row_indices));
486
  ret->CopyFeatureMapperFrom(full_dataset);
Guolin Ke's avatar
Guolin Ke committed
487
  ret->CopySubset(full_dataset, used_row_indices, num_used_row_indices, true);
wxchan's avatar
wxchan committed
488
489
490
491
  *out = ret.release();
  API_END();
}

492
LIGHTGBM_C_EXPORT int LGBM_DatasetSetFeatureNames(
Guolin Ke's avatar
typo  
Guolin Ke committed
493
  DatasetHandle handle,
Guolin Ke's avatar
Guolin Ke committed
494
  const char** feature_names,
Guolin Ke's avatar
Guolin Ke committed
495
  int num_feature_names) {
Guolin Ke's avatar
Guolin Ke committed
496
497
498
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  std::vector<std::string> feature_names_str;
Guolin Ke's avatar
Guolin Ke committed
499
  for (int i = 0; i < num_feature_names; ++i) {
Guolin Ke's avatar
Guolin Ke committed
500
501
502
503
504
505
    feature_names_str.emplace_back(feature_names[i]);
  }
  dataset->set_feature_names(feature_names_str);
  API_END();
}

506
LIGHTGBM_C_EXPORT int LGBM_DatasetGetFeatureNames(
507
508
  DatasetHandle handle,
  char** feature_names,
Guolin Ke's avatar
Guolin Ke committed
509
  int* num_feature_names) {
510
511
512
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  auto inside_feature_name = dataset->feature_names();
Guolin Ke's avatar
Guolin Ke committed
513
514
  *num_feature_names = static_cast<int>(inside_feature_name.size());
  for (int i = 0; i < *num_feature_names; ++i) {
515
516
517
518
519
    std::strcpy(feature_names[i], inside_feature_name[i].c_str());
  }
  API_END();
}

520
LIGHTGBM_C_EXPORT int LGBM_DatasetFree(DatasetHandle handle) {
521
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
522
  delete reinterpret_cast<Dataset*>(handle);
523
  API_END();
524
525
}

526
LIGHTGBM_C_EXPORT int LGBM_DatasetSaveBinary(DatasetHandle handle,
527
  const char* filename) {
528
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
529
530
  auto dataset = reinterpret_cast<Dataset*>(handle);
  dataset->SaveBinaryFile(filename);
531
  API_END();
532
533
}

534
LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle,
535
536
  const char* field_name,
  const void* field_data,
Guolin Ke's avatar
Guolin Ke committed
537
  int num_element,
538
  int type) {
539
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
540
  auto dataset = reinterpret_cast<Dataset*>(handle);
541
  bool is_success = false;
Guolin Ke's avatar
Guolin Ke committed
542
  if (type == C_API_DTYPE_FLOAT32) {
Guolin Ke's avatar
Guolin Ke committed
543
    is_success = dataset->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), static_cast<int32_t>(num_element));
Guolin Ke's avatar
Guolin Ke committed
544
  } else if (type == C_API_DTYPE_INT32) {
Guolin Ke's avatar
Guolin Ke committed
545
    is_success = dataset->SetIntField(field_name, reinterpret_cast<const int*>(field_data), static_cast<int32_t>(num_element));
Guolin Ke's avatar
Guolin Ke committed
546
547
  } else if (type == C_API_DTYPE_FLOAT64) {
    is_success = dataset->SetDoubleField(field_name, reinterpret_cast<const double*>(field_data), static_cast<int32_t>(num_element));
548
  }
549
550
  if (!is_success) { throw std::runtime_error("Input data type erorr or field not found"); }
  API_END();
551
552
}

553
LIGHTGBM_C_EXPORT int LGBM_DatasetGetField(DatasetHandle handle,
554
  const char* field_name,
Guolin Ke's avatar
Guolin Ke committed
555
  int* out_len,
556
557
  const void** out_ptr,
  int* out_type) {
558
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
559
  auto dataset = reinterpret_cast<Dataset*>(handle);
560
  bool is_success = false;
Guolin Ke's avatar
Guolin Ke committed
561
  if (dataset->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
Guolin Ke's avatar
Guolin Ke committed
562
    *out_type = C_API_DTYPE_FLOAT32;
563
    is_success = true;
Guolin Ke's avatar
Guolin Ke committed
564
  } else if (dataset->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
Guolin Ke's avatar
Guolin Ke committed
565
    *out_type = C_API_DTYPE_INT32;
566
    is_success = true;
Guolin Ke's avatar
Guolin Ke committed
567
568
569
  } else if (dataset->GetDoubleField(field_name, out_len, reinterpret_cast<const double**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT64;
    is_success = true;
570
  }
571
  if (!is_success) { throw std::runtime_error("Field not found"); }
wxchan's avatar
wxchan committed
572
  if (*out_ptr == nullptr) { *out_len = 0; }
573
  API_END();
574
575
}

576
LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumData(DatasetHandle handle,
Guolin Ke's avatar
Guolin Ke committed
577
  int* out) {
578
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
579
580
  auto dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_data();
581
  API_END();
582
583
}

584
LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumFeature(DatasetHandle handle,
Guolin Ke's avatar
Guolin Ke committed
585
  int* out) {
586
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
587
588
  auto dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_total_features();
589
  API_END();
Guolin Ke's avatar
Guolin Ke committed
590
}
591
592
593

// ---- start of booster

594
LIGHTGBM_C_EXPORT int LGBM_BoosterCreate(const DatasetHandle train_data,
  const char* parameters,
  BoosterHandle* out) {
  API_BEGIN();
  // Build a booster bound to the training dataset; ownership passes to the caller,
  // who later destroys it with LGBM_BoosterFree.
  std::unique_ptr<Booster> booster(
    new Booster(reinterpret_cast<const Dataset*>(train_data), parameters));
  *out = booster.release();
  API_END();
}

604
LIGHTGBM_C_EXPORT int LGBM_BoosterCreateFromModelfile(
  const char* filename,
  int* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
  // Load a booster from a model file and report how many iterations it contains.
  // The unique_ptr frees the booster if querying the iteration count throws.
  std::unique_ptr<Booster> booster(new Booster(filename));
  *out_num_iterations = booster->GetBoosting()->GetCurrentIteration();
  *out = booster.release();
  API_END();
}

615
616
617
618
619
620
621
622
623
624
625
626
LIGHTGBM_C_EXPORT int LGBM_BoosterLoadModelFromString(
  const char* model_str,
  int* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
  // Start from an empty booster, populate it from the serialized model text, and
  // report its iteration count. The unique_ptr frees it if either step throws.
  std::unique_ptr<Booster> booster(new Booster());
  booster->LoadModelFromString(model_str);
  *out_num_iterations = booster->GetBoosting()->GetCurrentIteration();
  *out = booster.release();
  API_END();
}

627
LIGHTGBM_C_EXPORT int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  // Destroy a booster previously handed out by one of the LGBM_Booster* constructors.
  Booster* booster = reinterpret_cast<Booster*>(handle);
  delete booster;
  API_END();
}

633
LIGHTGBM_C_EXPORT int LGBM_BoosterMerge(BoosterHandle handle,
  BoosterHandle other_handle) {
  API_BEGIN();
  // Merge the model held by `other_handle` into the booster behind `handle`.
  auto* target = reinterpret_cast<Booster*>(handle);
  auto* source = reinterpret_cast<Booster*>(other_handle);
  target->MergeFrom(source);
  API_END();
}

642
LIGHTGBM_C_EXPORT int LGBM_BoosterAddValidData(BoosterHandle handle,
  const DatasetHandle valid_data) {
  API_BEGIN();
  // Register an additional validation dataset with the booster.
  auto* booster = reinterpret_cast<Booster*>(handle);
  booster->AddValidData(reinterpret_cast<const Dataset*>(valid_data));
  API_END();
}

651
LIGHTGBM_C_EXPORT int LGBM_BoosterResetTrainingData(BoosterHandle handle,
  const DatasetHandle train_data) {
  API_BEGIN();
  // Point the booster at a different training dataset.
  auto* booster = reinterpret_cast<Booster*>(handle);
  booster->ResetTrainingData(reinterpret_cast<const Dataset*>(train_data));
  API_END();
}

660
LIGHTGBM_C_EXPORT int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  // Forward the parameter string to the booster's ResetConfig.
  reinterpret_cast<Booster*>(handle)->ResetConfig(parameters);
  API_END();
}

667
LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  // Number of classes reported by the underlying boosting object.
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->NumberOfClasses();
  API_END();
}

674
LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  // Run one boosting iteration; *is_finished is 1 when TrainOneIter reports true, else 0.
  auto* booster = reinterpret_cast<Booster*>(handle);
  *is_finished = booster->TrainOneIter() ? 1 : 0;
  API_END();
}

685
LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
  const float* grad,
  const float* hess,
  int* is_finished) {
  API_BEGIN();
  // One boosting iteration driven by caller-supplied gradients and hessians;
  // *is_finished is 1 when TrainOneIter reports true, else 0.
  auto* booster = reinterpret_cast<Booster*>(handle);
  *is_finished = booster->TrainOneIter(grad, hess) ? 1 : 0;
  API_END();
}

699
LIGHTGBM_C_EXPORT int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  // Undo the most recent boosting iteration.
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

706
LIGHTGBM_C_EXPORT int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int* out_iteration) {
  API_BEGIN();
  // Current iteration index as tracked by the boosting object.
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_iteration = boosting->GetCurrentIteration();
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
712

713
LIGHTGBM_C_EXPORT int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  // Number of evaluation entries reported by Booster::GetEvalCounts.
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalCounts();
  API_END();
}

720
LIGHTGBM_C_EXPORT int LGBM_BoosterGetEvalNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  // Booster::GetEvalNames fills out_strs and returns how many names it wrote.
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalNames(out_strs);
  API_END();
}

727
LIGHTGBM_C_EXPORT int LGBM_BoosterGetEval(BoosterHandle handle,
  int data_idx,
  int* out_len,
  double* out_results) {
  API_BEGIN();
  // Copy the evaluation results for dataset `data_idx` into out_results.
  // NOTE(review): out_results must hold at least *out_len doubles; no bound is checked here.
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  const auto eval_values = boosting->GetEvalAt(data_idx);
  *out_len = static_cast<int>(eval_values.size());
  double* dst = out_results;
  for (const auto& value : eval_values) {
    *dst++ = static_cast<double>(value);
  }
  API_END();
}

742
LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumPredict(BoosterHandle handle,
  int data_idx,
  int64_t* out_len) {
  API_BEGIN();
  // Size of the stored prediction buffer for dataset `data_idx`.
  *out_len = reinterpret_cast<Booster*>(handle)->GetBoosting()->GetNumPredictAt(data_idx);
  API_END();
}

751
LIGHTGBM_C_EXPORT int LGBM_BoosterGetPredict(BoosterHandle handle,
  int data_idx,
  int64_t* out_len,
  double* out_result) {
  API_BEGIN();
  // Booster::GetPredictAt writes the predictions for dataset `data_idx` and their count.
  auto* booster = reinterpret_cast<Booster*>(handle);
  booster->GetPredictAt(data_idx, out_result, out_len);
  API_END();
}

761
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForFile(BoosterHandle handle,
  const char* data_filename,
  int data_has_header,
  int predict_type,
  int num_iteration,
  const char* result_filename) {
  API_BEGIN();
  // Predict every row of `data_filename` and write the results to `result_filename`.
  // Any positive `data_has_header` means the input file starts with a header line.
  auto* booster = reinterpret_cast<Booster*>(handle);
  auto predictor = booster->NewPredictor(num_iteration, predict_type);
  const bool has_header = (data_has_header > 0);
  predictor.Predict(data_filename, result_filename, has_header);
  API_END();
}

775
776
int64_t GetNumPredOneRow(const Booster* ref_booster, int predict_type, int64_t num_iteration) {
  int64_t num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
Guolin Ke's avatar
Guolin Ke committed
777
778
779
780
781
782
783
784
785
786
787
  if (predict_type == C_API_PREDICT_LEAF_INDEX) {
    int64_t max_iteration = ref_booster->GetBoosting()->GetCurrentIteration();
    if (num_iteration > 0) {
      num_preb_in_one_row *= static_cast<int>(std::min(max_iteration, num_iteration));
    } else {
      num_preb_in_one_row *= max_iteration;
    }
  }
  return num_preb_in_one_row;
}

788
LIGHTGBM_C_EXPORT int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
  int num_row,
  int predict_type,
  int num_iteration,
  int64_t* out_len) {
  API_BEGIN();
  // Total output size = rows * outputs-per-row, computed in 64 bits.
  const int64_t per_row =
    GetNumPredOneRow(reinterpret_cast<Booster*>(handle), predict_type, num_iteration);
  *out_len = per_row * num_row;
  API_END();
}

799
// Predict every row of a CSR matrix in parallel.
// Row i's outputs are written to out_result[i * per_row, (i + 1) * per_row), where
// per_row = GetNumPredOneRow(...); *out_len receives the total number of values.
// The unnamed int64_t parameter (number of columns) is not needed for CSR input.
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle,
  const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t,
  int predict_type,
  int num_iteration,
  int64_t* out_len,
  double* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  auto predictor = ref_booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  int64_t num_preb_in_one_row = GetNumPredOneRow(ref_booster, predict_type, num_iteration);
  int nrow = static_cast<int>(nindptr - 1);
  // An exception must not leave an OpenMP parallel region (OpenMP forbids it; in practice
  // the process terminates before API_END can report it). So each iteration catches locally,
  // the first message is kept under a mutex, and it is rethrown once the loop has joined.
  std::mutex error_mutex;
  bool has_error = false;
  std::string error_msg;
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    try {
      auto one_row = get_row_fun(i);
      auto predicton_result = predictor.GetPredictFunction()(one_row);
      for (int j = 0; j < static_cast<int>(predicton_result.size()); ++j) {
        out_result[i * num_preb_in_one_row + j] = static_cast<double>(predicton_result[j]);
      }
    } catch (const std::exception& ex) {
      std::lock_guard<std::mutex> lock(error_mutex);
      if (!has_error) { has_error = true; error_msg = ex.what(); }
    } catch (...) {
      std::lock_guard<std::mutex> lock(error_mutex);
      if (!has_error) { has_error = true; error_msg = "unknown exception"; }
    }
  }
  if (has_error) {
    throw std::runtime_error(error_msg);
  }
  *out_len = nrow * num_preb_in_one_row;
  API_END();
}
829

830
// Predict every row of a CSC matrix with `num_row` rows.
// Rows are split into chunks by Threading::For. Each chunk builds its own per-column
// CSC_RowIterator set and reassembles rows as (column, value) pairs, keeping entries
// with |value| > kEpsilon. Row i's outputs go to out_result[i * per_row ...].
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle,
  const void* col_ptr,
  int col_ptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t ncol_ptr,
  int64_t nelem,
  int64_t num_row,
  int predict_type,
  int num_iteration,
  int64_t* out_len,
  double* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  auto predictor = ref_booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
  int64_t num_preb_in_one_row = GetNumPredOneRow(ref_booster, predict_type, num_iteration);
  int ncol = static_cast<int>(ncol_ptr - 1);
  // NOTE(review): Threading::For presumably runs the callback on worker threads. An
  // exception escaping a worker cannot reach API_END, so the first failure is recorded
  // under a mutex and rethrown after Threading::For returns.
  std::mutex error_mutex;
  bool has_error = false;
  std::string error_msg;
  auto record_error = [&error_mutex, &has_error, &error_msg](const std::string& msg) {
    std::lock_guard<std::mutex> lock(error_mutex);
    if (!has_error) {
      has_error = true;
      error_msg = msg;
    }
  };
  // Chunk bounds are int64_t to match Threading::For<int64_t>; the previous data_size_t
  // parameters silently narrowed them.
  Threading::For<int64_t>(0, num_row,
    [&predictor, &record_error, out_result, num_preb_in_one_row, ncol, col_ptr, col_ptr_type,
     indices, data, data_type, ncol_ptr, nelem]
  (int, int64_t start, int64_t end) {
    try {
      std::vector<CSC_RowIterator> iterators;
      iterators.reserve(ncol);
      for (int j = 0; j < ncol; ++j) {
        iterators.emplace_back(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, j);
      }
      std::vector<std::pair<int, double>> one_row;
      for (int64_t i = start; i < end; ++i) {
        one_row.clear();
        for (int j = 0; j < ncol; ++j) {
          auto val = iterators[j].Get(static_cast<int>(i));
          if (std::fabs(val) > kEpsilon) {
            one_row.emplace_back(j, val);
          }
        }
        auto predicton_result = predictor.GetPredictFunction()(one_row);
        for (int j = 0; j < static_cast<int>(predicton_result.size()); ++j) {
          out_result[i * num_preb_in_one_row + j] = static_cast<double>(predicton_result[j]);
        }
      }
    } catch (const std::exception& ex) {
      record_error(ex.what());
    } catch (...) {
      record_error("unknown exception");
    }
  });
  if (has_error) {
    throw std::runtime_error(error_msg);
  }
  *out_len = num_row * num_preb_in_one_row;
  API_END();
}

875
// Predict every row of a dense nrow x ncol matrix in parallel.
// Zero-ish cells are dropped by RowPairFunctionFromDenseMatric before prediction.
// Row i's outputs go to out_result[i * per_row ...]; *out_len receives the total count.
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMat(BoosterHandle handle,
  const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  int predict_type,
  int num_iteration,
  int64_t* out_len,
  double* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  auto predictor = ref_booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
  auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  int64_t num_preb_in_one_row = GetNumPredOneRow(ref_booster, predict_type, num_iteration);
  // An exception must not leave an OpenMP parallel region (OpenMP forbids it; in practice
  // the process terminates before API_END can report it). So each iteration catches locally,
  // the first message is kept under a mutex, and it is rethrown once the loop has joined.
  std::mutex error_mutex;
  bool has_error = false;
  std::string error_msg;
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    try {
      auto one_row = get_row_fun(i);
      auto predicton_result = predictor.GetPredictFunction()(one_row);
      for (int j = 0; j < static_cast<int>(predicton_result.size()); ++j) {
        out_result[i * num_preb_in_one_row + j] = static_cast<double>(predicton_result[j]);
      }
    } catch (const std::exception& ex) {
      std::lock_guard<std::mutex> lock(error_mutex);
      if (!has_error) { has_error = true; error_msg = ex.what(); }
    } catch (...) {
      std::lock_guard<std::mutex> lock(error_mutex);
      if (!has_error) { has_error = true; error_msg = "unknown exception"; }
    }
  }
  if (has_error) {
    throw std::runtime_error(error_msg);
  }
  *out_len = nrow * num_preb_in_one_row;
  API_END();
}
901

902
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle,
  int num_iteration,
  const char* filename) {
  API_BEGIN();
  // Write the model, limited by num_iteration as interpreted by Booster::SaveModelToFile.
  reinterpret_cast<Booster*>(handle)->SaveModelToFile(num_iteration, filename);
  API_END();
}

911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle,
  int num_iteration,
  int buffer_len,
  int* out_len,
  char* out_str) {
  API_BEGIN();
  const std::string model = reinterpret_cast<Booster*>(handle)->SaveModelToString(num_iteration);
  // The required size (text plus terminating NUL) is always reported, presumably so a
  // caller with a too-small buffer can retry. The text is copied only when it fits.
  const int required_len = static_cast<int>(model.size()) + 1;
  *out_len = required_len;
  if (required_len <= buffer_len) {
    std::strcpy(out_str, model.c_str());
  }
  API_END();
}

926
LIGHTGBM_C_EXPORT int LGBM_BoosterDumpModel(BoosterHandle handle,
  int num_iteration,
  int buffer_len,
  int* out_len,
  char* out_str) {
  API_BEGIN();
  const std::string dump = reinterpret_cast<Booster*>(handle)->DumpModel(num_iteration);
  // Report the needed size (including the terminating NUL) unconditionally; copy the
  // dump into out_str only when buffer_len is large enough to hold it.
  const int required_len = static_cast<int>(dump.size()) + 1;
  *out_len = required_len;
  if (required_len <= buffer_len) {
    std::strcpy(out_str, dump.c_str());
  }
  API_END();
}
940

941
LIGHTGBM_C_EXPORT int LGBM_BoosterGetLeafValue(BoosterHandle handle,
  int tree_idx,
  int leaf_idx,
  double* out_val) {
  API_BEGIN();
  // Read the output value of leaf `leaf_idx` in tree `tree_idx`.
  auto* booster = reinterpret_cast<Booster*>(handle);
  const double leaf_value = static_cast<double>(booster->GetLeafValue(tree_idx, leaf_idx));
  *out_val = leaf_value;
  API_END();
}

951
LIGHTGBM_C_EXPORT int LGBM_BoosterSetLeafValue(BoosterHandle handle,
  int tree_idx,
  int leaf_idx,
  double val) {
  API_BEGIN();
  // Overwrite the output value of leaf `leaf_idx` in tree `tree_idx`.
  reinterpret_cast<Booster*>(handle)->SetLeafValue(tree_idx, leaf_idx, val);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
961
// ---- start of helper functions
962
963
964

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
965
  if (data_type == C_API_DTYPE_FLOAT32) {
966
967
968
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
969
        std::vector<double> ret(num_col);
970
971
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
972
          ret[i] = static_cast<double>(*(tmp_ptr + i));
973
974
975
976
977
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
978
        std::vector<double> ret(num_col);
979
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
980
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
981
982
983
984
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
985
  } else if (data_type == C_API_DTYPE_FLOAT64) {
986
987
988
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
989
        std::vector<double> ret(num_col);
990
991
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
992
          ret[i] = static_cast<double>(*(tmp_ptr + i));
993
994
995
996
997
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
998
        std::vector<double> ret(num_col);
999
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1000
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
1001
1002
1003
1004
1005
        }
        return ret;
      };
    }
  }
Guolin Ke's avatar
Guolin Ke committed
1006
  throw std::runtime_error("unknown data type in RowFunctionFromDenseMatric");
1007
1008
1009
1010
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1011
1012
1013
1014
1015
1016
1017
1018
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
    return [inner_function](int row_idx) {
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
        if (std::fabs(raw_values[i]) > 1e-15) {
          ret.emplace_back(i, raw_values[i]);
1019
        }
Guolin Ke's avatar
Guolin Ke committed
1020
1021
1022
      }
      return ret;
    };
1023
  }
Guolin Ke's avatar
Guolin Ke committed
1024
  return nullptr;
1025
1026
1027
1028
}

std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
1029
  if (data_type == C_API_DTYPE_FLOAT32) {
1030
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1031
    if (indptr_type == C_API_DTYPE_INT32) {
1032
1033
1034
1035
1036
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1037
        for (int64_t i = start; i < end; ++i) {
1038
1039
1040
1041
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1042
    } else if (indptr_type == C_API_DTYPE_INT64) {
1043
1044
1045
1046
1047
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1048
        for (int64_t i = start; i < end; ++i) {
1049
1050
1051
1052
1053
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1054
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1055
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1056
    if (indptr_type == C_API_DTYPE_INT32) {
1057
1058
1059
1060
1061
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1062
        for (int64_t i = start; i < end; ++i) {
1063
1064
1065
1066
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1067
    } else if (indptr_type == C_API_DTYPE_INT64) {
1068
1069
1070
1071
1072
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1073
        for (int64_t i = start; i < end; ++i) {
1074
1075
1076
1077
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1078
1079
    }
  }
Guolin Ke's avatar
Guolin Ke committed
1080
  throw std::runtime_error("unknown data type in RowFunctionFromCSR");
1081
1082
}

Guolin Ke's avatar
Guolin Ke committed
1083
1084
1085
std::function<std::pair<int, double>(int idx)>
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx) {
  CHECK(col_idx < ncol_ptr && col_idx >= 0);
Guolin Ke's avatar
Guolin Ke committed
1086
  if (data_type == C_API_DTYPE_FLOAT32) {
1087
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1088
    if (col_ptr_type == C_API_DTYPE_INT32) {
1089
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1090
1091
1092
1093
1094
1095
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1096
        }
Guolin Ke's avatar
Guolin Ke committed
1097
1098
1099
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1100
      };
Guolin Ke's avatar
Guolin Ke committed
1101
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1102
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1103
1104
1105
1106
1107
1108
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1109
        }
Guolin Ke's avatar
Guolin Ke committed
1110
1111
1112
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1113
      };
Guolin Ke's avatar
Guolin Ke committed
1114
    }
Guolin Ke's avatar
Guolin Ke committed
1115
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1116
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1117
    if (col_ptr_type == C_API_DTYPE_INT32) {
1118
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1119
1120
1121
1122
1123
1124
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1125
        }
Guolin Ke's avatar
Guolin Ke committed
1126
1127
1128
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1129
      };
Guolin Ke's avatar
Guolin Ke committed
1130
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1131
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1132
1133
1134
1135
1136
1137
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1138
        }
Guolin Ke's avatar
Guolin Ke committed
1139
1140
1141
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1142
      };
Guolin Ke's avatar
Guolin Ke committed
1143
1144
1145
    }
  }
  throw std::runtime_error("unknown data type in CSC matrix");
1146
1147
}

Guolin Ke's avatar
Guolin Ke committed
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
// Binds the iterator to column `col_idx` of the given CSC matrix.
CSC_RowIterator::CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
  const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx)
  : iter_fun_(IterateFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type,
                                     ncol_ptr, nelem, col_idx)) {
}

// Value of this column at row `idx`, or 0 when the row holds no stored entry.
// Advances forward only, so callers must query rows in non-decreasing order.
double CSC_RowIterator::Get(int idx) {
  while (!is_end_ && cur_idx_ < idx) {
    const auto entry = iter_fun_(nonzero_idx_);
    if (entry.first < 0) {
      // Column exhausted: no further stored entries.
      is_end_ = true;
      break;
    }
    cur_idx_ = entry.first;
    cur_val_ = entry.second;
    ++nonzero_idx_;
  }
  return (idx == cur_idx_) ? cur_val_ : 0.0f;
}

// Returns the next stored (row index, value) pair of the column, or (-1, 0.0) once
// the column is exhausted.
std::pair<int, double> CSC_RowIterator::NextNonZero() {
  if (is_end_) {
    return std::make_pair(-1, 0.0);
  }
  auto entry = iter_fun_(nonzero_idx_);
  ++nonzero_idx_;
  if (entry.first < 0) {
    is_end_ = true;
  }
  return entry;
}