c_api.cpp 39.3 KB
Newer Older
1
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
2
3
4

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/utils/threading.h>
Guolin Ke's avatar
Guolin Ke committed
6
#include <LightGBM/c_api.h>
Guolin Ke's avatar
Guolin Ke committed
7
#include <LightGBM/dataset_loader.h>
Guolin Ke's avatar
Guolin Ke committed
8
9
10
11
12
13
14
15
16
17
#include <LightGBM/dataset.h>
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/config.h>

#include <cstdio>
#include <vector>
#include <string>
#include <cstring>
Guolin Ke's avatar
Guolin Ke committed
18
#include <memory>
Guolin Ke's avatar
Guolin Ke committed
19
#include <stdexcept>
wxchan's avatar
wxchan committed
20
#include <mutex>
Guolin Ke's avatar
Guolin Ke committed
21
#include <functional>
Guolin Ke's avatar
Guolin Ke committed
22

Guolin Ke's avatar
Guolin Ke committed
23
#include "./application/predictor.hpp"
Guolin Ke's avatar
Guolin Ke committed
24
#include "./boosting/gbdt.h"
Guolin Ke's avatar
Guolin Ke committed
25

Guolin Ke's avatar
Guolin Ke committed
26
27
28
29
namespace LightGBM {

class Booster {
public:
Guolin Ke's avatar
Guolin Ke committed
30
31
  explicit Booster(const char* filename) {
    boosting_.reset(Boosting::CreateBoosting(filename));
Guolin Ke's avatar
Guolin Ke committed
32
33
  }

34
35
36
37
  Booster() {
    boosting_.reset(Boosting::CreateBoosting("gbdt", nullptr));
  }

Guolin Ke's avatar
Guolin Ke committed
38
  Booster(const Dataset* train_data,
wxchan's avatar
wxchan committed
39
40
41
    const char* parameters) {
    auto param = ConfigBase::Str2Map(parameters);
    config_.Set(param);
42
43
44
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
45
46
    // create boosting
    if (config_.io_config.input_model.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
47
      Log::Warning("continued train from model is not support for c_api, \
Guolin Ke's avatar
Guolin Ke committed
48
49
        please use continued train with input score");
    }
Guolin Ke's avatar
Guolin Ke committed
50

wxchan's avatar
wxchan committed
51
    boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
Guolin Ke's avatar
Guolin Ke committed
52

Guolin Ke's avatar
Guolin Ke committed
53
    // initialize the boosting
Guolin Ke's avatar
Guolin Ke committed
54
    boosting_->Init(&config_.boosting_config, nullptr, objective_fun_.get(),
Guolin Ke's avatar
Guolin Ke committed
55
      Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
56
57

    ResetTrainingData(train_data);
wxchan's avatar
wxchan committed
58
59
60
61
62
  }

  void MergeFrom(const Booster* other) {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->MergeFrom(other->boosting_.get());
Guolin Ke's avatar
Guolin Ke committed
63
64
65
  }

  ~Booster() {
Guolin Ke's avatar
Guolin Ke committed
66

Guolin Ke's avatar
Guolin Ke committed
67
  }
68

wxchan's avatar
wxchan committed
69
70
71
  void ResetTrainingData(const Dataset* train_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    train_data_ = train_data;
Guolin Ke's avatar
Guolin Ke committed
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
    // create objective function
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
      config_.objective_config));
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective function");
    }
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
    }

    // create training metric
    train_metric_.clear();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(
        Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(train_data_->metadata(), train_data_->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
    // reset the boosting
Guolin Ke's avatar
Guolin Ke committed
94
    boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
wxchan's avatar
wxchan committed
95
96
97
98
      objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
  }

  void ResetConfig(const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
99
    std::lock_guard<std::mutex> lock(mutex_);
wxchan's avatar
wxchan committed
100
101
102
103
104
105
106
    auto param = ConfigBase::Str2Map(parameters);
    if (param.count("num_class")) {
      Log::Fatal("cannot change num class during training");
    }
    if (param.count("boosting_type")) {
      Log::Fatal("cannot change boosting_type during training");
    }
Guolin Ke's avatar
Guolin Ke committed
107
108
109
    if (param.count("metric")) {
      Log::Fatal("cannot change metric during training");
    }
Guolin Ke's avatar
Guolin Ke committed
110
111

    config_.Set(param);
112
113
114
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
115
116
117
118
119
120
121
122
123
124
125
126

    if (param.count("objective")) {
      // create objective function
      objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
        config_.objective_config));
      if (objective_fun_ == nullptr) {
        Log::Warning("Using self-defined objective function");
      }
      // initialize the objective function
      if (objective_fun_ != nullptr) {
        objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
      }
wxchan's avatar
wxchan committed
127
    }
Guolin Ke's avatar
Guolin Ke committed
128
129
130

    boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
      objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
131

wxchan's avatar
wxchan committed
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
  }

  void AddValidData(const Dataset* valid_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    valid_metrics_.emplace_back();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
      Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
  }
Guolin Ke's avatar
Guolin Ke committed
147

148
  bool TrainOneIter() {
wxchan's avatar
wxchan committed
149
    std::lock_guard<std::mutex> lock(mutex_);
150
151
152
153
    return boosting_->TrainOneIter(nullptr, nullptr, false);
  }

  bool TrainOneIter(const float* gradients, const float* hessians) {
wxchan's avatar
wxchan committed
154
    std::lock_guard<std::mutex> lock(mutex_);
155
156
157
    return boosting_->TrainOneIter(gradients, hessians, false);
  }

wxchan's avatar
wxchan committed
158
159
160
161
162
  void RollbackOneIter() {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->RollbackOneIter();
  }

Guolin Ke's avatar
Guolin Ke committed
163
  Predictor NewPredictor(int num_iteration, int predict_type) {
wxchan's avatar
wxchan committed
164
165
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->SetNumIterationForPred(num_iteration);
Guolin Ke's avatar
Guolin Ke committed
166
167
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
168
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
169
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
170
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
171
      is_raw_score = true;
Guolin Ke's avatar
Guolin Ke committed
172
173
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
174
    }
Guolin Ke's avatar
Guolin Ke committed
175
176
177
    // not threading safe now
    // boosting_->SetNumIterationForPred may be set by other thread during prediction. 
    return Predictor(boosting_.get(), is_raw_score, is_predict_leaf);
Guolin Ke's avatar
Guolin Ke committed
178
179
  }

Guolin Ke's avatar
Guolin Ke committed
180
  void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
181
182
183
184
185
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

  void SaveModelToFile(int num_iteration, const char* filename) {
    boosting_->SaveModelToFile(num_iteration, filename);
Guolin Ke's avatar
Guolin Ke committed
186
  }
187

188
189
190
191
192
193
194
195
  void LoadModelFromString(const char* model_str) {
    boosting_->LoadModelFromString(model_str);
  }

  std::string SaveModelToString(int num_iteration) {
    return boosting_->SaveModelToString(num_iteration);
  }

196
197
  std::string DumpModel(int num_iteration) {
    return boosting_->DumpModel(num_iteration);
wxchan's avatar
wxchan committed
198
  }
199

Guolin Ke's avatar
Guolin Ke committed
200
201
202
203
204
205
206
207
208
  double GetLeafValue(int tree_idx, int leaf_idx) const {
    return dynamic_cast<GBDT*>(boosting_.get())->GetLeafValue(tree_idx, leaf_idx);
  }

  void SetLeafValue(int tree_idx, int leaf_idx, double val) {
    std::lock_guard<std::mutex> lock(mutex_);
    dynamic_cast<GBDT*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val);
  }

wxchan's avatar
wxchan committed
209
210
211
212
213
214
215
  int GetEvalCounts() const {
    int ret = 0;
    for (const auto& metric : train_metric_) {
      ret += static_cast<int>(metric->GetName().size());
    }
    return ret;
  }
216

wxchan's avatar
wxchan committed
217
218
219
220
221
222
223
224
225
226
227
228
  int GetEvalNames(char** out_strs) const {
    int idx = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
        std::strcpy(out_strs[idx], name.c_str());
        ++idx;
      }
    }
    return idx;
  }

  const Boosting* GetBoosting() const { return boosting_.get(); }
Guolin Ke's avatar
Guolin Ke committed
229

Guolin Ke's avatar
Guolin Ke committed
230
private:
231

wxchan's avatar
wxchan committed
232
  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
233
  std::unique_ptr<Boosting> boosting_;
Guolin Ke's avatar
Guolin Ke committed
234
235
236
  /*! \brief All configs */
  OverallConfig config_;
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
237
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
238
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
239
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
240
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
241
  std::unique_ptr<ObjectiveFunction> objective_fun_;
wxchan's avatar
wxchan committed
242
243
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
244
245
246
};

}
Guolin Ke's avatar
Guolin Ke committed
247
248
249

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
// some help functions used to convert data

// Forward declaration (definition is outside this view).
// Returns a callable that maps a row index to that row's values as a dense
// vector of doubles, reading from the dense matrix described by
// (data, num_row, num_col, data_type, is_row_major).
std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major);

// Forward declaration (definition is outside this view).
// Same inputs as above, but the callable returns the row as a list of
// (int, double) pairs — presumably (column index, value); confirm against the definition.
std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major);

// Forward declaration (definition is outside this view).
// Returns a callable that maps an index to a list of (int, double) pairs read
// from a CSR matrix given by its indptr / indices / data arrays.
std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
  const void* data, int data_type, int64_t nindptr, int64_t nelem);

/*!
 * \brief Iterator over the rows of a single column of a CSC matrix.
 *
 * Only the declaration is visible here; the constructor, Get() and
 * NextNonZero() are defined elsewhere in this file.
 */
class CSC_RowIterator {
public:
  /*! \brief Bind the iterator to column col_idx of the given CSC arrays. */
  CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
    const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx);

  ~CSC_RowIterator() = default;

  /*! \brief Value at row idx; rows may only be queried in ascending order. */
  double Get(int idx);

  /*! \brief Next non-zero (row, value) pair; a negative row index means no more data. */
  std::pair<int, double> NextNonZero();

private:
  int nonzero_idx_{0};
  int cur_idx_{-1};
  double cur_val_{0.0};
  bool is_end_{false};
  std::function<std::pair<int, double>(int idx)> iter_fun_;
};

// start of c_api functions

282
/*!
 * \brief Expose the last recorded error message to C callers.
 * \return The pointer returned by LastErrorMsg() (declared outside this view)
 */
LIGHTGBM_C_EXPORT const char* LGBM_GetLastError() {
  const char* message = LastErrorMsg();
  return message;
}

286
/*!
 * \brief Load a dataset from a file.
 * \param filename Path of the data file
 * \param parameters Parameter string parsed with ConfigBase::Str2Map
 * \param reference Optional existing dataset to align with; may be null
 * \param out Receives the handle of the loaded dataset
 * \return Status produced by the API_BEGIN()/API_END() macros (defined outside this view)
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromFile(const char* filename,
  const char* parameters,
  const DatasetHandle reference,
  DatasetHandle* out) {
  API_BEGIN();
  auto param_map = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param_map);
  DatasetLoader loader(io_config, nullptr, 1, filename);
  if (reference != nullptr) {
    // align the new dataset with the reference one
    const Dataset* aligned_with = reinterpret_cast<const Dataset*>(reference);
    *out = loader.LoadFromFileAlignWithOtherDataset(filename, aligned_with);
  } else {
    *out = loader.LoadFromFile(filename);
  }
  API_END();
}

304
/*!
 * \brief Build a dataset from a dense matrix.
 *
 * Without a reference dataset, up to bin_construct_sample_cnt rows are sampled
 * and handed to DatasetLoader::CostructFromSampleData; with one, the feature
 * mappers are copied from it. Every row is then pushed in an OpenMP loop.
 * \param data Pointer to the matrix values
 * \param data_type Element type code, forwarded to RowFunctionFromDenseMatric
 * \param nrow Number of rows
 * \param ncol Number of columns
 * \param is_row_major Layout flag, forwarded to RowFunctionFromDenseMatric
 * \param parameters Parameter string parsed with ConfigBase::Str2Map
 * \param reference Optional dataset whose feature mappers are reused; may be null
 * \param out Receives the handle of the new dataset
 * \return Status produced by the API_BEGIN()/API_END() macros (defined outside this view)
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromMat(const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  const char* parameters,
  const DatasetHandle reference,
  DatasetHandle* out) {
  API_BEGIN();
  auto param_map = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param_map);
  std::unique_ptr<Dataset> dataset;
  auto row_reader = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  if (reference != nullptr) {
    dataset.reset(new Dataset(nrow));
    dataset->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(reference),
      io_config.is_enable_sparse);
  } else {
    // sample rows first so the loader can construct the bins
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(ncol);
    for (const auto sampled_row : sample_indices) {
      const auto row_values = row_reader(static_cast<int>(sampled_row));
      for (size_t col = 0; col < row_values.size(); ++col) {
        // keep only entries that are not (numerically) zero
        if (std::fabs(row_values[col]) > 1e-15) {
          sample_values[col].push_back(row_values[col]);
        }
      }
    }
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
    dataset.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    const int tid = omp_get_thread_num();
    auto one_row = row_reader(i);
    dataset->PushOneRow(tid, i, one_row);
  }
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

353
/*!
 * \brief Build a dataset from a CSR matrix.
 *
 * The row count is nindptr - 1. Without a reference dataset, sampled rows are
 * collected per feature and handed to DatasetLoader::CostructFromSampleData;
 * with one, the feature mappers are copied from it. Every row is then pushed
 * in an OpenMP loop.
 * \param indptr Row pointer array
 * \param indptr_type Element type code of indptr
 * \param indices Column index array
 * \param data Value array
 * \param data_type Element type code of data
 * \param nindptr Length of indptr
 * \param nelem Number of stored elements
 * \param num_col Number of columns; checked to be at least the number of features seen while sampling
 * \param parameters Parameter string parsed with ConfigBase::Str2Map
 * \param reference Optional dataset whose feature mappers are reused; may be null
 * \param out Receives the handle of the new dataset
 * \return Status produced by the API_BEGIN()/API_END() macros (defined outside this view)
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSR(const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t num_col,
  const char* parameters,
  const DatasetHandle reference,
  DatasetHandle* out) {
  API_BEGIN();
  auto param_map = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param_map);
  std::unique_ptr<Dataset> dataset;
  auto row_reader = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference != nullptr) {
    dataset.reset(new Dataset(nrow));
    dataset->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(reference),
      io_config.is_enable_sparse);
  } else {
    // sample rows first so the loader can construct the bins
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values;
    for (const auto sampled_row : sample_indices) {
      auto row_entries = row_reader(static_cast<int>(sampled_row));
      for (const auto& entry : row_entries) {
        // grow the per-feature buckets until this feature index is addressable
        while (sample_values.size() <= static_cast<size_t>(entry.first)) {
          sample_values.emplace_back();
        }
        // keep only entries that are not (numerically) zero
        if (std::fabs(entry.second) > 1e-15) {
          sample_values[entry.first].push_back(entry.second);
        }
      }
    }
    CHECK(num_col >= static_cast<int>(sample_values.size()));
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
    dataset.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nindptr - 1; ++i) {
    const int tid = omp_get_thread_num();
    auto one_row = row_reader(i);
    dataset->PushOneRow(tid, i, one_row);
  }
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

415
/*!
 * \brief Build a dataset from a CSC matrix.
 *
 * The column count is ncol_ptr - 1. Without a reference dataset, each column
 * is sampled at the chosen row indices and the result is handed to
 * DatasetLoader::CostructFromSampleData; with one, the feature mappers are
 * copied from it. Columns are then pushed feature by feature in an OpenMP loop.
 * \param col_ptr Column pointer array
 * \param col_ptr_type Element type code of col_ptr
 * \param indices Row index array
 * \param data Value array
 * \param data_type Element type code of data
 * \param ncol_ptr Length of col_ptr
 * \param nelem Number of stored elements
 * \param num_row Number of rows
 * \param parameters Parameter string parsed with ConfigBase::Str2Map
 * \param reference Optional dataset whose feature mappers are reused; may be null
 * \param out Receives the handle of the new dataset
 * \return Status produced by the API_BEGIN()/API_END() macros (defined outside this view)
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr,
  int col_ptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t ncol_ptr,
  int64_t nelem,
  int64_t num_row,
  const char* parameters,
  const DatasetHandle reference,
  DatasetHandle* out) {
  API_BEGIN();
  auto param_map = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param_map);
  std::unique_ptr<Dataset> dataset;
  int32_t nrow = static_cast<int32_t>(num_row);
  if (reference != nullptr) {
    dataset.reset(new Dataset(nrow));
    dataset->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(reference),
      io_config.is_enable_sparse);
  } else {
    // sample rows first so the loader can construct the bins
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
#pragma omp parallel for schedule(guided)
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
      CSC_RowIterator column(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
      for (int j = 0; j < sample_cnt; j++) {
        const auto value = column.Get(sample_indices[j]);
        if (std::fabs(value) > kEpsilon) {
          sample_values[i].push_back(value);
        }
      }
    }
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
    dataset.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

#pragma omp parallel for schedule(guided)
  for (int i = 0; i < ncol_ptr - 1; ++i) {
    const int tid = omp_get_thread_num();
    const int feature_idx = dataset->GetInnerFeatureIndex(i);
    // a negative inner index means this column is skipped
    if (feature_idx >= 0) {
      CSC_RowIterator column(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
      for (int row_idx = 0; row_idx < nrow;) {
        const auto entry = column.NextNonZero();
        row_idx = entry.first;
        // no more data
        if (row_idx < 0) { break; }
        dataset->FeatureAt(feature_idx)->PushData(tid, row_idx, entry.second);
      }
    }
  }
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

477
/*!
 * \brief Create a new dataset holding a subset of the rows of an existing one.
 * \param handle Handle of the full dataset
 * \param used_row_indices Indices of the rows to keep
 * \param num_used_row_indices Length of used_row_indices
 * \param parameters Parameter string parsed with ConfigBase::Str2Map
 * \param out Receives the handle of the subset dataset
 * \return Status produced by the API_BEGIN()/API_END() macros (defined outside this view)
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetGetSubset(
  const DatasetHandle handle,
  const int32_t* used_row_indices,
  int32_t num_used_row_indices,
  const char* parameters,
  DatasetHandle* out) {
  API_BEGIN();
  auto param_map = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param_map);
  const Dataset* source = reinterpret_cast<const Dataset*>(handle);
  std::unique_ptr<Dataset> subset(
    source->Subset(used_row_indices,
      num_used_row_indices,
      io_config.is_enable_sparse,
      true));
  subset->FinishLoad();
  *out = subset.release();
  API_END();
}

498
/*!
 * \brief Replace the feature names stored in a dataset.
 * \param handle Handle of the dataset
 * \param feature_names Array of C strings, one per feature
 * \param num_feature_names Number of entries in feature_names
 * \return Status produced by the API_BEGIN()/API_END() macros (defined outside this view)
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetSetFeatureNames(
  DatasetHandle handle,
  const char** feature_names,
  int num_feature_names) {
  API_BEGIN();
  std::vector<std::string> names;
  for (int idx = 0; idx < num_feature_names; ++idx) {
    names.push_back(std::string(feature_names[idx]));
  }
  Dataset* target = reinterpret_cast<Dataset*>(handle);
  target->set_feature_names(names);
  API_END();
}

512
/*!
 * \brief Copy the feature names of a dataset into caller-provided buffers.
 * \param handle Handle of the dataset
 * \param feature_names One writable buffer per feature; each must be large
 *                      enough for its name, the copy is an unbounded std::strcpy
 * \param num_feature_names Receives the number of names written
 * \return Status produced by the API_BEGIN()/API_END() macros (defined outside this view)
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetGetFeatureNames(
  DatasetHandle handle,
  char** feature_names,
  int* num_feature_names) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  // Bind by reference: the names are only read, so there is no need to
  // deep-copy the whole vector of strings first.
  const auto& inside_feature_name = dataset->feature_names();
  *num_feature_names = static_cast<int>(inside_feature_name.size());
  for (int i = 0; i < *num_feature_names; ++i) {
    std::strcpy(feature_names[i], inside_feature_name[i].c_str());
  }
  API_END();
}

526
/*!
 * \brief Destroy a dataset previously handed out through a DatasetHandle.
 * \param handle Handle of the dataset to delete
 * \return Status produced by the API_BEGIN()/API_END() macros (defined outside this view)
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetFree(DatasetHandle handle) {
  API_BEGIN();
  Dataset* doomed = reinterpret_cast<Dataset*>(handle);
  delete doomed;
  API_END();
}

532
/*!
 * \brief Write a dataset to a binary file via Dataset::SaveBinaryFile.
 * \param handle Handle of the dataset
 * \param filename Destination path
 * \return Status produced by the API_BEGIN()/API_END() macros (defined outside this view)
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetSaveBinary(DatasetHandle handle,
  const char* filename) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->SaveBinaryFile(filename);
  API_END();
}

540
/*!
 * \brief Set a named field of a dataset from a typed array.
 * \param handle Handle of the dataset
 * \param field_name Name of the field
 * \param field_data Pointer to the values; interpreted according to type
 * \param num_element Number of values in field_data
 * \param type One of C_API_DTYPE_FLOAT32, C_API_DTYPE_INT32, C_API_DTYPE_FLOAT64
 * \return Status produced by the API_BEGIN()/API_END() macros (defined outside this view)
 * \note Throws std::runtime_error when the type is not one of the above or the
 *       matching setter reports failure.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle,
  const char* field_name,
  const void* field_data,
  int num_element,
  int type) {
  API_BEGIN();
  Dataset* target = reinterpret_cast<Dataset*>(handle);
  const int32_t count = static_cast<int32_t>(num_element);
  bool stored = false;
  if (type == C_API_DTYPE_FLOAT32) {
    stored = target->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), count);
  } else if (type == C_API_DTYPE_INT32) {
    stored = target->SetIntField(field_name, reinterpret_cast<const int*>(field_data), count);
  } else if (type == C_API_DTYPE_FLOAT64) {
    stored = target->SetDoubleField(field_name, reinterpret_cast<const double*>(field_data), count);
  }
  if (!stored) {
    throw std::runtime_error("Input data type erorr or field not found");
  }
  API_END();
}

559
/*!
 * \brief Look up a named field of a dataset.
 *
 * The float, int and double getters are tried in that order; the first one
 * that succeeds determines out_type.
 * \param handle Handle of the dataset
 * \param field_name Name of the field
 * \param out_len Receives the number of elements (forced to 0 when the data pointer is null)
 * \param out_ptr Receives a pointer to the field data
 * \param out_type Receives the C_API_DTYPE_* code of the field
 * \return Status produced by the API_BEGIN()/API_END() macros (defined outside this view)
 * \note Throws std::runtime_error("Field not found") when no getter succeeds.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetGetField(DatasetHandle handle,
  const char* field_name,
  int* out_len,
  const void** out_ptr,
  int* out_type) {
  API_BEGIN();
  Dataset* source = reinterpret_cast<Dataset*>(handle);
  if (source->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (source->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else if (source->GetDoubleField(field_name, out_len, reinterpret_cast<const double**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT64;
  } else {
    throw std::runtime_error("Field not found");
  }
  // a field that exists but holds no data is reported with length zero
  if (*out_ptr == nullptr) {
    *out_len = 0;
  }
  API_END();
}

582
/*!
 * \brief Report Dataset::num_data() for a dataset.
 * \param handle Handle of the dataset
 * \param out Receives the value
 * \return Status produced by the API_BEGIN()/API_END() macros (defined outside this view)
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumData(DatasetHandle handle,
  int* out) {
  API_BEGIN();
  *out = reinterpret_cast<Dataset*>(handle)->num_data();
  API_END();
}

590
/*!
 * \brief Report Dataset::num_total_features() for a dataset.
 * \param handle Handle of the dataset
 * \param out Receives the value
 * \return Status produced by the API_BEGIN()/API_END() macros (defined outside this view)
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumFeature(DatasetHandle handle,
  int* out) {
  API_BEGIN();
  *out = reinterpret_cast<Dataset*>(handle)->num_total_features();
  API_END();
}
597
598
599

// ---- start of booster

600
// Builds a Booster from a training dataset handle and a parameter string,
// then hands ownership of the new object to the caller through *out.
// The Booster is held in a unique_ptr until it is released into *out, so it
// is destroyed automatically if anything throws before that point.
LIGHTGBM_C_EXPORT int LGBM_BoosterCreate(const DatasetHandle train_data,
  const char* parameters,
  BoosterHandle* out) {
  API_BEGIN();
  std::unique_ptr<Booster> booster(
    new Booster(reinterpret_cast<const Dataset*>(train_data), parameters));
  *out = booster.release();
  API_END();
}

610
// Builds a Booster from the model file `filename`.
// *out_num_iterations receives GetCurrentIteration() of the loaded boosting
// object; *out receives ownership of the new Booster.
LIGHTGBM_C_EXPORT int LGBM_BoosterCreateFromModelfile(
  const char* filename,
  int* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
  std::unique_ptr<Booster> booster(new Booster(filename));
  *out_num_iterations = booster->GetBoosting()->GetCurrentIteration();
  *out = booster.release();
  API_END();
}

621
622
623
624
625
626
627
628
629
630
631
632
// Creates a default-constructed Booster and loads the model text `model_str`
// into it. *out_num_iterations receives GetCurrentIteration() of the loaded
// boosting object; *out receives ownership of the new Booster.
LIGHTGBM_C_EXPORT int LGBM_BoosterLoadModelFromString(
  const char* model_str,
  int* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
  std::unique_ptr<Booster> booster(new Booster());
  booster->LoadModelFromString(model_str);
  *out_num_iterations = booster->GetBoosting()->GetCurrentIteration();
  *out = booster.release();
  API_END();
}

633
// Destroys the Booster behind `handle`. The handle must not be used afterwards.
LIGHTGBM_C_EXPORT int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  delete booster;
  API_END();
}

639
// Forwards to Booster::MergeFrom: the booster behind `handle` is the
// destination and the booster behind `other_handle` is the source.
LIGHTGBM_C_EXPORT int LGBM_BoosterMerge(BoosterHandle handle,
  BoosterHandle other_handle) {
  API_BEGIN();
  Booster* target = reinterpret_cast<Booster*>(handle);
  Booster* source = reinterpret_cast<Booster*>(other_handle);
  target->MergeFrom(source);
  API_END();
}

648
// Forwards the dataset behind `valid_data` to Booster::AddValidData on the
// booster behind `handle`.
LIGHTGBM_C_EXPORT int LGBM_BoosterAddValidData(BoosterHandle handle,
  const DatasetHandle valid_data) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->AddValidData(reinterpret_cast<const Dataset*>(valid_data));
  API_END();
}

657
// Forwards the dataset behind `train_data` to Booster::ResetTrainingData on
// the booster behind `handle`.
LIGHTGBM_C_EXPORT int LGBM_BoosterResetTrainingData(BoosterHandle handle,
  const DatasetHandle train_data) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->ResetTrainingData(reinterpret_cast<const Dataset*>(train_data));
  API_END();
}

666
// Forwards the parameter string to Booster::ResetConfig on the booster
// behind `handle`.
LIGHTGBM_C_EXPORT int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ResetConfig(parameters);
  API_END();
}

673
// Stores NumberOfClasses() of the booster's boosting object in *out_len.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetBoosting()->NumberOfClasses();
  API_END();
}

680
// Calls Booster::TrainOneIter() once.
// *is_finished is set to 1 when TrainOneIter() returns true, otherwise 0.
LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  *is_finished = booster->TrainOneIter() ? 1 : 0;
  API_END();
}

691
// Calls Booster::TrainOneIter(grad, hess) once with caller-supplied
// gradient and hessian arrays.
// *is_finished is set to 1 when TrainOneIter returns true, otherwise 0.
LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
  const float* grad,
  const float* hess,
  int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  *is_finished = booster->TrainOneIter(grad, hess) ? 1 : 0;
  API_END();
}

705
// Forwards to Booster::RollbackOneIter() on the booster behind `handle`.
LIGHTGBM_C_EXPORT int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

712
// Stores GetCurrentIteration() of the booster's boosting object in *out_iteration.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int* out_iteration) {
  API_BEGIN();
  *out_iteration = reinterpret_cast<Booster*>(handle)->GetBoosting()->GetCurrentIteration();
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
718

719
// Stores the result of Booster::GetEvalCounts() in *out_len.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalCounts();
  API_END();
}

726
// Passes the caller's `out_strs` array to Booster::GetEvalNames and stores
// that call's return value in *out_len.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetEvalNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalNames(out_strs);
  API_END();
}

733
// Copies the evaluation results for dataset index `data_idx` into the
// caller-provided `out_results` array and stores their count in *out_len.
// The values come from GetEvalAt(data_idx) on the booster's boosting object.
// `out_results` must be large enough for every value; its size is not
// checked here.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetEval(BoosterHandle handle,
  int data_idx,
  int* out_len,
  double* out_results) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto boosting = booster->GetBoosting();
  auto scores = boosting->GetEvalAt(data_idx);
  *out_len = static_cast<int>(scores.size());
  double* dst = out_results;
  for (const auto& score : scores) {
    *dst = static_cast<double>(score);
    ++dst;
  }
  API_END();
}

748
// Stores GetNumPredictAt(data_idx) of the booster's boosting object in *out_len.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumPredict(BoosterHandle handle,
  int data_idx,
  int64_t* out_len) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  *out_len = booster->GetBoosting()->GetNumPredictAt(data_idx);
  API_END();
}

757
// Forwards to Booster::GetPredictAt for dataset index `data_idx`, passing
// the caller's `out_result` buffer and `out_len` pointer straight through.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetPredict(BoosterHandle handle,
  int data_idx,
  int64_t* out_len,
  double* out_result) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->GetPredictAt(data_idx, out_result, out_len);
  API_END();
}

767
// Runs prediction from one file to another: a predictor is created via
// Booster::NewPredictor(num_iteration, predict_type) and its Predict() is
// called with the input and result file names.
// `data_has_header` is treated as true when it is greater than zero.
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForFile(BoosterHandle handle,
  const char* data_filename,
  int data_has_header,
  int predict_type,
  int num_iteration,
  const char* result_filename) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto predictor = booster->NewPredictor(num_iteration, predict_type);
  const bool has_header = data_has_header > 0;
  predictor.Predict(data_filename, result_filename, has_header);
  API_END();
}

781
782
int64_t GetNumPredOneRow(const Booster* ref_booster, int predict_type, int64_t num_iteration) {
  int64_t num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
Guolin Ke's avatar
Guolin Ke committed
783
784
785
786
787
788
789
790
791
792
793
  if (predict_type == C_API_PREDICT_LEAF_INDEX) {
    int64_t max_iteration = ref_booster->GetBoosting()->GetCurrentIteration();
    if (num_iteration > 0) {
      num_preb_in_one_row *= static_cast<int>(std::min(max_iteration, num_iteration));
    } else {
      num_preb_in_one_row *= max_iteration;
    }
  }
  return num_preb_in_one_row;
}

794
// Stores in *out_len the total number of output values for `num_row` rows:
// num_row times the per-row count from GetNumPredOneRow.
LIGHTGBM_C_EXPORT int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
  int num_row,
  int predict_type,
  int num_iteration,
  int64_t* out_len) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const int64_t per_row = GetNumPredOneRow(booster, predict_type, num_iteration);
  *out_len = static_cast<int64_t>(num_row * per_row);
  API_END();
}

805
// Predicts every row of a CSR matrix and writes the results to `out_result`.
// Row i's values are stored contiguously starting at offset
// i * GetNumPredOneRow(...). The number of rows is nindptr - 1 and
// *out_len receives rows * per-row count. The unnamed int64_t parameter is
// accepted but not used. Rows are processed by an OpenMP parallel loop.
// NOTE(review): an exception thrown inside the parallel loop would not
// reach the enclosing API_BEGIN/API_END handling - confirm how this is
// meant to be dealt with.
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle,
  const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t,
  int predict_type,
  int num_iteration,
  int64_t* out_len,
  double* out_result) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto predictor = booster->NewPredictor(num_iteration, predict_type);
  auto row_at = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  const int64_t row_width = GetNumPredOneRow(booster, predict_type, num_iteration);
  const int nrow = static_cast<int>(nindptr - 1);
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    auto features = row_at(i);
    auto scores = predictor.GetPredictFunction()(features);
    const int num_scores = static_cast<int>(scores.size());
    for (int k = 0; k < num_scores; ++k) {
      out_result[i * row_width + k] = static_cast<double>(scores[k]);
    }
  }
  *out_len = nrow * row_width;
  API_END();
}
835

836
// Predicts every row of a CSC matrix and writes the results to `out_result`.
// The row range [0, num_row) is split by Threading::For. Each chunk builds
// its own CSC_RowIterator per column (ncol = ncol_ptr - 1), assembles each
// row as (column, value) pairs for values whose magnitude exceeds kEpsilon,
// predicts it, and stores the result at offset
// row * GetNumPredOneRow(...). *out_len receives num_row * per-row count.
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle,
  const void* col_ptr,
  int col_ptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t ncol_ptr,
  int64_t nelem,
  int64_t num_row,
  int predict_type,
  int num_iteration,
  int64_t* out_len,
  double* out_result) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto predictor = booster->NewPredictor(num_iteration, predict_type);
  const int64_t row_width = GetNumPredOneRow(booster, predict_type, num_iteration);
  const int ncol = static_cast<int>(ncol_ptr - 1);

  Threading::For<int64_t>(0, num_row,
    [&predictor, &out_result, row_width, ncol, col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem]
  (int, data_size_t start, data_size_t end) {
    // One iterator per column, private to this chunk of rows.
    std::vector<CSC_RowIterator> column_iters;
    for (int c = 0; c < ncol; ++c) {
      column_iters.emplace_back(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, c);
    }
    std::vector<std::pair<int, double>> features;
    for (int64_t row = start; row < end; ++row) {
      features.clear();
      for (int c = 0; c < ncol; ++c) {
        auto value = column_iters[c].Get(static_cast<int>(row));
        if (std::fabs(value) > kEpsilon) {
          features.emplace_back(c, value);
        }
      }
      auto scores = predictor.GetPredictFunction()(features);
      const int num_scores = static_cast<int>(scores.size());
      for (int k = 0; k < num_scores; ++k) {
        out_result[row * row_width + k] = static_cast<double>(scores[k]);
      }
    }
  });
  *out_len = num_row * row_width;
  API_END();
}

881
// Predicts every row of a dense nrow x ncol matrix and writes the results
// to `out_result`. Row i's values are stored contiguously starting at
// offset i * GetNumPredOneRow(...); *out_len receives nrow * per-row count.
// Rows are processed by an OpenMP parallel loop.
// NOTE(review): an exception thrown inside the parallel loop would not
// reach the enclosing API_BEGIN/API_END handling - confirm how this is
// meant to be dealt with.
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMat(BoosterHandle handle,
  const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  int predict_type,
  int num_iteration,
  int64_t* out_len,
  double* out_result) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto predictor = booster->NewPredictor(num_iteration, predict_type);
  auto row_at = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  const int64_t row_width = GetNumPredOneRow(booster, predict_type, num_iteration);
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    auto features = row_at(i);
    auto scores = predictor.GetPredictFunction()(features);
    const int num_scores = static_cast<int>(scores.size());
    for (int k = 0; k < num_scores; ++k) {
      out_result[i * row_width + k] = static_cast<double>(scores[k]);
    }
  }
  *out_len = nrow * row_width;
  API_END();
}
907

908
// Forwards to Booster::SaveModelToFile(num_iteration, filename).
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle,
  int num_iteration,
  const char* filename) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->SaveModelToFile(num_iteration, filename);
  API_END();
}

917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
// Serializes the model with Booster::SaveModelToString(num_iteration).
// *out_len always receives the required buffer size (text length plus the
// terminating NUL). The text is copied into `out_str` only when that size
// fits in `buffer_len`; otherwise `out_str` is left untouched so the caller
// can retry with a larger buffer.
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle,
  int num_iteration,
  int buffer_len,
  int* out_len,
  char* out_str) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const std::string text = booster->SaveModelToString(num_iteration);
  const int required = static_cast<int>(text.size()) + 1;
  *out_len = required;
  if (required <= buffer_len) {
    std::strcpy(out_str, text.c_str());
  }
  API_END();
}

932
// Dumps the model with Booster::DumpModel(num_iteration).
// *out_len always receives the required buffer size (text length plus the
// terminating NUL). The text is copied into `out_str` only when that size
// fits in `buffer_len`; otherwise `out_str` is left untouched so the caller
// can retry with a larger buffer.
LIGHTGBM_C_EXPORT int LGBM_BoosterDumpModel(BoosterHandle handle,
  int num_iteration,
  int buffer_len,
  int* out_len,
  char* out_str) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const std::string text = booster->DumpModel(num_iteration);
  const int required = static_cast<int>(text.size()) + 1;
  *out_len = required;
  if (required <= buffer_len) {
    std::strcpy(out_str, text.c_str());
  }
  API_END();
}
946

947
// Stores the result of Booster::GetLeafValue(tree_idx, leaf_idx), converted
// to double, in *out_val.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetLeafValue(BoosterHandle handle,
  int tree_idx,
  int leaf_idx,
  double* out_val) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const double leaf_value = static_cast<double>(booster->GetLeafValue(tree_idx, leaf_idx));
  *out_val = leaf_value;
  API_END();
}

957
// Forwards to Booster::SetLeafValue(tree_idx, leaf_idx, val).
LIGHTGBM_C_EXPORT int LGBM_BoosterSetLeafValue(BoosterHandle handle,
  int tree_idx,
  int leaf_idx,
  double val) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->SetLeafValue(tree_idx, leaf_idx, val);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
967
// ---- start of some helper functions
968
969
970

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
971
  if (data_type == C_API_DTYPE_FLOAT32) {
972
973
974
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
975
        std::vector<double> ret(num_col);
976
977
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
978
          ret[i] = static_cast<double>(*(tmp_ptr + i));
979
980
981
982
983
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
984
        std::vector<double> ret(num_col);
985
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
986
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
987
988
989
990
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
991
  } else if (data_type == C_API_DTYPE_FLOAT64) {
992
993
994
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
995
        std::vector<double> ret(num_col);
996
997
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
998
          ret[i] = static_cast<double>(*(tmp_ptr + i));
999
1000
1001
1002
1003
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1004
        std::vector<double> ret(num_col);
1005
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1006
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
1007
1008
1009
1010
1011
        }
        return ret;
      };
    }
  }
Guolin Ke's avatar
Guolin Ke committed
1012
  throw std::runtime_error("unknown data type in RowFunctionFromDenseMatric");
1013
1014
1015
1016
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1017
1018
1019
1020
1021
1022
1023
1024
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
    return [inner_function](int row_idx) {
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
        if (std::fabs(raw_values[i]) > 1e-15) {
          ret.emplace_back(i, raw_values[i]);
1025
        }
Guolin Ke's avatar
Guolin Ke committed
1026
1027
1028
      }
      return ret;
    };
1029
  }
Guolin Ke's avatar
Guolin Ke committed
1030
  return nullptr;
1031
1032
1033
1034
}

std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
1035
  if (data_type == C_API_DTYPE_FLOAT32) {
1036
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1037
    if (indptr_type == C_API_DTYPE_INT32) {
1038
1039
1040
1041
1042
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1043
        for (int64_t i = start; i < end; ++i) {
1044
1045
1046
1047
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1048
    } else if (indptr_type == C_API_DTYPE_INT64) {
1049
1050
1051
1052
1053
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1054
        for (int64_t i = start; i < end; ++i) {
1055
1056
1057
1058
1059
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1060
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1061
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1062
    if (indptr_type == C_API_DTYPE_INT32) {
1063
1064
1065
1066
1067
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1068
        for (int64_t i = start; i < end; ++i) {
1069
1070
1071
1072
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1073
    } else if (indptr_type == C_API_DTYPE_INT64) {
1074
1075
1076
1077
1078
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1079
        for (int64_t i = start; i < end; ++i) {
1080
1081
1082
1083
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1084
1085
    }
  }
Guolin Ke's avatar
Guolin Ke committed
1086
  throw std::runtime_error("unknown data type in RowFunctionFromCSR");
1087
1088
}

Guolin Ke's avatar
Guolin Ke committed
1089
1090
1091
std::function<std::pair<int, double>(int idx)>
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx) {
  CHECK(col_idx < ncol_ptr && col_idx >= 0);
Guolin Ke's avatar
Guolin Ke committed
1092
  if (data_type == C_API_DTYPE_FLOAT32) {
1093
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1094
    if (col_ptr_type == C_API_DTYPE_INT32) {
1095
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1096
1097
1098
1099
1100
1101
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1102
        }
Guolin Ke's avatar
Guolin Ke committed
1103
1104
1105
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1106
      };
Guolin Ke's avatar
Guolin Ke committed
1107
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1108
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1109
1110
1111
1112
1113
1114
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1115
        }
Guolin Ke's avatar
Guolin Ke committed
1116
1117
1118
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1119
      };
Guolin Ke's avatar
Guolin Ke committed
1120
    }
Guolin Ke's avatar
Guolin Ke committed
1121
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1122
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1123
    if (col_ptr_type == C_API_DTYPE_INT32) {
1124
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1125
1126
1127
1128
1129
1130
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1131
        }
Guolin Ke's avatar
Guolin Ke committed
1132
1133
1134
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1135
      };
Guolin Ke's avatar
Guolin Ke committed
1136
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1137
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1138
1139
1140
1141
1142
1143
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1144
        }
Guolin Ke's avatar
Guolin Ke committed
1145
1146
1147
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1148
      };
Guolin Ke's avatar
Guolin Ke committed
1149
1150
1151
    }
  }
  throw std::runtime_error("unknown data type in CSC matrix");
1152
1153
}

Guolin Ke's avatar
Guolin Ke committed
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
// Binds this iterator to column `col_idx` of a CSC matrix by storing the
// column reader produced by IterateFunctionFromCSC.
CSC_RowIterator::CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
  const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx)
  : iter_fun_(IterateFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, col_idx)) {
}

// Returns the column's value at row `idx`, or zero when no entry is stored
// for that row. The iterator only moves forward: stored entries are read
// until the current row index reaches or passes `idx`, or the column is
// exhausted (the reader returns a negative row index).
double CSC_RowIterator::Get(int idx) {
  while (!is_end_ && cur_idx_ < idx) {
    const auto entry = iter_fun_(nonzero_idx_);
    if (entry.first < 0) {
      // No more stored entries; the loop condition ends the scan.
      is_end_ = true;
    } else {
      cur_idx_ = entry.first;
      cur_val_ = entry.second;
      ++nonzero_idx_;
    }
  }
  return (cur_idx_ == idx) ? cur_val_ : 0.0;
}

// Returns the next stored entry of the column as (row index, value) and
// advances the internal position. Once the reader reports the end (negative
// row index) the iterator is marked finished; every later call returns
// (-1, 0.0) without touching the reader.
std::pair<int, double> CSC_RowIterator::NextNonZero() {
  if (is_end_) {
    return std::make_pair(-1, 0.0);
  }
  auto entry = iter_fun_(nonzero_idx_);
  ++nonzero_idx_;
  if (entry.first < 0) {
    is_end_ = true;
  }
  return entry;
}