c_api.cpp 39.3 KB
Newer Older
1
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
2
3
4

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/utils/threading.h>
Guolin Ke's avatar
Guolin Ke committed
6
#include <LightGBM/c_api.h>
Guolin Ke's avatar
Guolin Ke committed
7
#include <LightGBM/dataset_loader.h>
Guolin Ke's avatar
Guolin Ke committed
8
9
10
11
12
13
14
15
16
17
#include <LightGBM/dataset.h>
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/config.h>

#include <cstdio>
#include <vector>
#include <string>
#include <cstring>
Guolin Ke's avatar
Guolin Ke committed
18
#include <memory>
Guolin Ke's avatar
Guolin Ke committed
19
#include <stdexcept>
wxchan's avatar
wxchan committed
20
#include <mutex>
Guolin Ke's avatar
Guolin Ke committed
21
#include <functional>
Guolin Ke's avatar
Guolin Ke committed
22

Guolin Ke's avatar
Guolin Ke committed
23
#include "./application/predictor.hpp"
Guolin Ke's avatar
Guolin Ke committed
24
#include "./boosting/gbdt.h"
Guolin Ke's avatar
Guolin Ke committed
25

Guolin Ke's avatar
Guolin Ke committed
26
27
28
29
namespace LightGBM {

class Booster {
public:
Guolin Ke's avatar
Guolin Ke committed
30
31
  explicit Booster(const char* filename) {
    boosting_.reset(Boosting::CreateBoosting(filename));
Guolin Ke's avatar
Guolin Ke committed
32
33
  }

34
35
36
37
  Booster() {
    boosting_.reset(Boosting::CreateBoosting("gbdt", nullptr));
  }

Guolin Ke's avatar
Guolin Ke committed
38
  Booster(const Dataset* train_data,
wxchan's avatar
wxchan committed
39
40
41
    const char* parameters) {
    auto param = ConfigBase::Str2Map(parameters);
    config_.Set(param);
42
43
44
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
45
46
    // create boosting
    if (config_.io_config.input_model.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
47
      Log::Warning("continued train from model is not support for c_api, \
Guolin Ke's avatar
Guolin Ke committed
48
49
        please use continued train with input score");
    }
Guolin Ke's avatar
Guolin Ke committed
50

wxchan's avatar
wxchan committed
51
    boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
Guolin Ke's avatar
Guolin Ke committed
52

Guolin Ke's avatar
Guolin Ke committed
53
    // initialize the boosting
Guolin Ke's avatar
Guolin Ke committed
54
    boosting_->Init(&config_.boosting_config, nullptr, objective_fun_.get(),
Guolin Ke's avatar
Guolin Ke committed
55
      Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
56
57

    ResetTrainingData(train_data);
wxchan's avatar
wxchan committed
58
59
60
61
62
  }

  void MergeFrom(const Booster* other) {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->MergeFrom(other->boosting_.get());
Guolin Ke's avatar
Guolin Ke committed
63
64
65
  }

  ~Booster() {
Guolin Ke's avatar
Guolin Ke committed
66

Guolin Ke's avatar
Guolin Ke committed
67
  }
68

wxchan's avatar
wxchan committed
69
70
71
  void ResetTrainingData(const Dataset* train_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    train_data_ = train_data;
Guolin Ke's avatar
Guolin Ke committed
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
    // create objective function
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
      config_.objective_config));
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective function");
    }
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
    }

    // create training metric
    train_metric_.clear();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(
        Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(train_data_->metadata(), train_data_->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
    // reset the boosting
Guolin Ke's avatar
Guolin Ke committed
94
    boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
wxchan's avatar
wxchan committed
95
96
97
98
      objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
  }

  void ResetConfig(const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
99
    std::lock_guard<std::mutex> lock(mutex_);
wxchan's avatar
wxchan committed
100
101
102
103
104
105
106
    auto param = ConfigBase::Str2Map(parameters);
    if (param.count("num_class")) {
      Log::Fatal("cannot change num class during training");
    }
    if (param.count("boosting_type")) {
      Log::Fatal("cannot change boosting_type during training");
    }
Guolin Ke's avatar
Guolin Ke committed
107
108
109
    if (param.count("metric")) {
      Log::Fatal("cannot change metric during training");
    }
Guolin Ke's avatar
Guolin Ke committed
110
111

    config_.Set(param);
112
113
114
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
115
116
117
118
119
120
121
122
123
124
125
126

    if (param.count("objective")) {
      // create objective function
      objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
        config_.objective_config));
      if (objective_fun_ == nullptr) {
        Log::Warning("Using self-defined objective function");
      }
      // initialize the objective function
      if (objective_fun_ != nullptr) {
        objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
      }
wxchan's avatar
wxchan committed
127
    }
Guolin Ke's avatar
Guolin Ke committed
128
129
130

    boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
      objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
131

wxchan's avatar
wxchan committed
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
  }

  void AddValidData(const Dataset* valid_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    valid_metrics_.emplace_back();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
      Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
  }
Guolin Ke's avatar
Guolin Ke committed
147

148
  bool TrainOneIter() {
wxchan's avatar
wxchan committed
149
    std::lock_guard<std::mutex> lock(mutex_);
150
151
152
153
    return boosting_->TrainOneIter(nullptr, nullptr, false);
  }

  bool TrainOneIter(const float* gradients, const float* hessians) {
wxchan's avatar
wxchan committed
154
    std::lock_guard<std::mutex> lock(mutex_);
155
156
157
    return boosting_->TrainOneIter(gradients, hessians, false);
  }

wxchan's avatar
wxchan committed
158
159
160
161
162
  void RollbackOneIter() {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->RollbackOneIter();
  }

Guolin Ke's avatar
Guolin Ke committed
163
  Predictor NewPredictor(int num_iteration, int predict_type) {
wxchan's avatar
wxchan committed
164
165
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->SetNumIterationForPred(num_iteration);
Guolin Ke's avatar
Guolin Ke committed
166
167
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
168
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
169
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
170
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
171
      is_raw_score = true;
Guolin Ke's avatar
Guolin Ke committed
172
173
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
174
    }
Guolin Ke's avatar
Guolin Ke committed
175
176
177
    // not threading safe now
    // boosting_->SetNumIterationForPred may be set by other thread during prediction. 
    return Predictor(boosting_.get(), is_raw_score, is_predict_leaf);
Guolin Ke's avatar
Guolin Ke committed
178
179
  }

Guolin Ke's avatar
Guolin Ke committed
180
  void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
181
182
183
184
185
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

  void SaveModelToFile(int num_iteration, const char* filename) {
    boosting_->SaveModelToFile(num_iteration, filename);
Guolin Ke's avatar
Guolin Ke committed
186
  }
187

188
189
190
191
192
193
194
195
  void LoadModelFromString(const char* model_str) {
    boosting_->LoadModelFromString(model_str);
  }

  std::string SaveModelToString(int num_iteration) {
    return boosting_->SaveModelToString(num_iteration);
  }

196
197
  std::string DumpModel(int num_iteration) {
    return boosting_->DumpModel(num_iteration);
wxchan's avatar
wxchan committed
198
  }
199

Guolin Ke's avatar
Guolin Ke committed
200
201
202
203
204
205
206
207
208
  double GetLeafValue(int tree_idx, int leaf_idx) const {
    return dynamic_cast<GBDT*>(boosting_.get())->GetLeafValue(tree_idx, leaf_idx);
  }

  void SetLeafValue(int tree_idx, int leaf_idx, double val) {
    std::lock_guard<std::mutex> lock(mutex_);
    dynamic_cast<GBDT*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val);
  }

wxchan's avatar
wxchan committed
209
210
211
212
213
214
215
  int GetEvalCounts() const {
    int ret = 0;
    for (const auto& metric : train_metric_) {
      ret += static_cast<int>(metric->GetName().size());
    }
    return ret;
  }
216

wxchan's avatar
wxchan committed
217
218
219
220
221
222
223
224
225
226
227
228
  int GetEvalNames(char** out_strs) const {
    int idx = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
        std::strcpy(out_strs[idx], name.c_str());
        ++idx;
      }
    }
    return idx;
  }

  const Boosting* GetBoosting() const { return boosting_.get(); }
Guolin Ke's avatar
Guolin Ke committed
229

Guolin Ke's avatar
Guolin Ke committed
230
private:
231

wxchan's avatar
wxchan committed
232
  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
233
  std::unique_ptr<Boosting> boosting_;
Guolin Ke's avatar
Guolin Ke committed
234
235
236
  /*! \brief All configs */
  OverallConfig config_;
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
237
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
238
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
239
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
240
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
241
  std::unique_ptr<ObjectiveFunction> objective_fun_;
wxchan's avatar
wxchan committed
242
243
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
244
245
246
};

}
Guolin Ke's avatar
Guolin Ke committed
247
248
249

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
// Helper functions used to convert raw C API input buffers into row accessors.
// Only the declarations are visible here; the definitions are presumably
// further down in this file.

// Returns a functor that materializes row `row_idx` of a dense matrix as a
// vector of doubles. NOTE(review): data_type / is_row_major are forwarded
// unchanged from LGBM_DatasetCreateFromMat; presumably the element-type code
// and the memory layout flag.
std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major);

// Same input as RowFunctionFromDenseMatric, but the functor yields (int, double)
// pairs per row. NOTE(review): presumably (column index, value) pairs; confirm
// against the definition.
std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major);

// Returns a functor yielding the (column index, value) pairs of CSR row `idx`
// (LGBM_DatasetCreateFromCSR uses `.first` as the column index).
std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
  const void* data, int data_type, int64_t nindptr, int64_t nelem);

// Iterator over the rows of a single column (col_idx) of a CSC matrix.
class CSC_RowIterator {
public:
  CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
    const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx);
  ~CSC_RowIterator() {}
  // Return the value at row idx. Successive calls must use rows in ascending
  // order.
  double Get(int idx);
  // Return the next non-zero (row index, value) pair; a negative row index
  // means the column has no more data.
  std::pair<int, double> NextNonZero();
private:
  // NOTE(review): iteration state; the exact meaning of each field is defined
  // by the out-of-line method bodies, which are not visible here.
  int nonzero_idx_ = 0;
  int cur_idx_ = -1;
  double cur_val_ = 0.0f;
  bool is_end_ = false;
  std::function<std::pair<int, double>(int idx)> iter_fun_;
};

// start of c_api functions

282
/*!
* \brief Return the text currently held by LastErrorMsg(), i.e. the last error
*        recorded by the C API error handling.
*/
LIGHTGBM_C_EXPORT const char* LGBM_GetLastError() {
  const char* last_message = LastErrorMsg();
  return last_message;
}

286
/*!
* \brief Load a dataset from a file.
* \param reference When non-null, the new dataset is aligned with this dataset
*        (LoadFromFileAlignWithOtherDataset); otherwise it is loaded standalone.
* \param out Receives the new Dataset handle
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromFile(const char* filename,
  const char* parameters,
  const DatasetHandle reference,
  DatasetHandle* out) {
  API_BEGIN();
  auto settings = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(settings);
  DatasetLoader loader(io_config, nullptr, 1, filename);
  if (reference != nullptr) {
    const Dataset* aligned_with = reinterpret_cast<const Dataset*>(reference);
    *out = loader.LoadFromFileAlignWithOtherDataset(filename, aligned_with);
  } else {
    *out = loader.LoadFromFile(filename);
  }
  API_END();
}

304
/*!
* \brief Build a dataset from a dense matrix.
*
* Without a reference, bin mappers are constructed from a random sample of at
* most bin_construct_sample_cnt rows (only values with |v| > 1e-15 are kept).
* With a reference, the reference's feature mappers are copied instead. Every
* row is then pushed in parallel.
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromMat(const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  const char* parameters,
  const DatasetHandle reference,
  DatasetHandle* out) {
  API_BEGIN();
  auto settings = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(settings);
  auto row_getter = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  std::unique_ptr<Dataset> dataset;
  if (reference != nullptr) {
    // share the feature space of the reference dataset
    dataset.reset(new Dataset(nrow));
    dataset->CopyFeatureMapperFrom(reinterpret_cast<const Dataset*>(reference),
      io_config.is_enable_sparse);
  } else {
    Random sampler(io_config.data_random_seed);
    const int num_sampled = static_cast<int>(
      io_config.bin_construct_sample_cnt > nrow ? nrow : io_config.bin_construct_sample_cnt);
    auto picked_rows = sampler.Sample(nrow, num_sampled);
    // one bucket of sampled non-zero values per column
    std::vector<std::vector<double>> sampled_columns(ncol);
    for (const auto row_id : picked_rows) {
      const auto values = row_getter(static_cast<int>(row_id));
      size_t col = 0;
      for (const double v : values) {
        if (std::fabs(v) > 1e-15) {
          sampled_columns[col].push_back(v);
        }
        ++col;
      }
    }
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
    dataset.reset(loader.CostructFromSampleData(sampled_columns, num_sampled, nrow));
  }

#pragma omp parallel for schedule(guided)
  for (int row_id = 0; row_id < nrow; ++row_id) {
    auto values = row_getter(row_id);
    dataset->PushOneRow(omp_get_thread_num(), row_id, values);
  }
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

353
/*!
* \brief Build a dataset from a CSR matrix with nindptr - 1 rows and num_col columns.
*
* Without a reference, bin mappers are constructed from a random sample of rows
* (only values with |v| > 1e-15 are kept). With a reference, its feature mappers
* are copied. Every row is then pushed in parallel.
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSR(const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t num_col,
  const char* parameters,
  const DatasetHandle reference,
  DatasetHandle* out) {
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
  std::unique_ptr<Dataset> ret;
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  // a CSR matrix stores one more row pointer than it has rows
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference == nullptr) {
    // sample data first
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    CHECK(num_col >= 0);
    // One bucket per declared column (as the dense and CSC paths do), so the
    // dataset gets exactly num_col features even when the sampled rows never
    // touch the trailing columns.
    std::vector<std::vector<double>> sample_values(static_cast<size_t>(num_col));
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      auto row = get_row_fun(static_cast<int>(idx));
      for (const std::pair<int, double>& inner_data : row) {
        // validate before indexing: a bad column index must not reach sample_values
        CHECK(inner_data.first >= 0 && inner_data.first < num_col);
        if (std::fabs(inner_data.second) > 1e-15) {
          sample_values[inner_data.first].push_back(inner_data.second);
        }
      }
    }
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
    ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow));
    ret->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(reference),
      io_config.is_enable_sparse);
  }

#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
  }
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

415
/*!
* \brief Build a dataset from a CSC matrix with ncol_ptr - 1 columns and num_row rows.
*
* Without a reference, bin mappers are constructed from a random row sample,
* read column by column (only values with |v| > kEpsilon are kept). With a
* reference, its feature mappers are copied. Each column's non-zeros are then
* pushed in parallel.
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr,
  int col_ptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t ncol_ptr,
  int64_t nelem,
  int64_t num_row,
  const char* parameters,
  const DatasetHandle reference,
  DatasetHandle* out) {
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
  std::unique_ptr<Dataset> ret;
  int32_t nrow = static_cast<int32_t>(num_row);
  // a CSC matrix stores one more column pointer than it has columns
  const int ncol = static_cast<int>(ncol_ptr - 1);
  if (reference == nullptr) {
    // sample data first
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
#pragma omp parallel for schedule(guided)
    for (int i = 0; i < ncol; ++i) {
      CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
      // Bounded by the indices actually returned (not sample_cnt) so a shorter
      // sample is never read past its end. Get() requires ascending row indices.
      for (size_t j = 0; j < sample_indices.size(); ++j) {
        auto val = col_it.Get(sample_indices[j]);
        if (std::fabs(val) > kEpsilon) {
          sample_values[i].push_back(val);
        }
      }
    }
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
    ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow));
    ret->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(reference),
      io_config.is_enable_sparse);
  }

#pragma omp parallel for schedule(guided)
  for (int i = 0; i < ncol; ++i) {
    const int tid = omp_get_thread_num();
    int feature_idx = ret->GetInnerFeatureIndex(i);
    // columns without an inner feature (negative index) are skipped
    if (feature_idx < 0) { continue; }
    CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
    // Stop at the end marker (negative row index) or at the first row index
    // outside [0, nrow). The range is checked before pushing, so malformed
    // input never reaches PushData with an out-of-range row.
    for (auto pair = col_it.NextNonZero();
         pair.first >= 0 && pair.first < nrow;
         pair = col_it.NextNonZero()) {
      ret->FeatureAt(feature_idx)->PushData(tid, pair.first, pair.second);
    }
  }
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

477
/*!
* \brief Create a new dataset holding the rows used_row_indices of handle.
*        The subset reuses the source's feature mappers.
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetGetSubset(
  const DatasetHandle handle,
  const int32_t* used_row_indices,
  int32_t num_used_row_indices,
  const char* parameters,
  DatasetHandle* out) {
  API_BEGIN();
  auto settings = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(settings);
  const Dataset* source = reinterpret_cast<const Dataset*>(handle);
  std::unique_ptr<Dataset> subset(new Dataset(num_used_row_indices));
  subset->CopyFeatureMapperFrom(source, io_config.is_enable_sparse);
  subset->CopySubset(source, used_row_indices, num_used_row_indices, true);
  *out = subset.release();
  API_END();
}

495
/*!
* \brief Replace the dataset's feature names with the first num_feature_names
*        C strings of feature_names (non-positive counts yield an empty list).
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetSetFeatureNames(
  DatasetHandle handle,
  const char** feature_names,
  int num_feature_names) {
  API_BEGIN();
  std::vector<std::string> names;
  int pos = 0;
  while (pos < num_feature_names) {
    names.emplace_back(feature_names[pos]);
    ++pos;
  }
  reinterpret_cast<Dataset*>(handle)->set_feature_names(names);
  API_END();
}

509
/*!
* \brief Copy the dataset's feature names into caller-provided buffers.
* \param feature_names Output buffers; NOTE(review): strcpy trusts the caller to
*        supply enough buffers, each large enough for its name
* \param num_feature_names Receives the number of names written
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetGetFeatureNames(
  DatasetHandle handle,
  char** feature_names,
  int* num_feature_names) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  // Bind by const reference. This avoids copying the whole name vector when
  // feature_names() returns a reference, and lifetime extension keeps it valid
  // if it returns by value.
  const auto& inside_feature_name = dataset->feature_names();
  const int num_names = static_cast<int>(inside_feature_name.size());
  *num_feature_names = num_names;
  for (int i = 0; i < num_names; ++i) {
    std::strcpy(feature_names[i], inside_feature_name[i].c_str());
  }
  API_END();
}

523
/*! \brief Destroy a Dataset previously returned through a DatasetHandle. */
LIGHTGBM_C_EXPORT int LGBM_DatasetFree(DatasetHandle handle) {
  API_BEGIN();
  Dataset* dataset = reinterpret_cast<Dataset*>(handle);
  delete dataset;
  API_END();
}

529
/*! \brief Write the dataset to filename via Dataset::SaveBinaryFile. */
LIGHTGBM_C_EXPORT int LGBM_DatasetSaveBinary(DatasetHandle handle,
  const char* filename) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->SaveBinaryFile(filename);
  API_END();
}

537
/*!
* \brief Set a metadata field of the dataset from a typed buffer.
* \param type C_API_DTYPE_FLOAT32, C_API_DTYPE_INT32 or C_API_DTYPE_FLOAT64
* \throws std::runtime_error when the type is unsupported or the dataset rejects
*         the field (reported through API_END)
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle,
  const char* field_name,
  const void* field_data,
  int num_element,
  int type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  const int32_t len = static_cast<int32_t>(num_element);
  bool is_success = false;
  if (type == C_API_DTYPE_FLOAT32) {
    is_success = dataset->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), len);
  } else if (type == C_API_DTYPE_INT32) {
    is_success = dataset->SetIntField(field_name, reinterpret_cast<const int*>(field_data), len);
  } else if (type == C_API_DTYPE_FLOAT64) {
    is_success = dataset->SetDoubleField(field_name, reinterpret_cast<const double*>(field_data), len);
  }
  if (!is_success) { throw std::runtime_error("Input data type error or field not found"); }
  API_END();
}

556
/*!
* \brief Look up a metadata field, trying float, then int, then double storage.
*        On success *out_type names the element type; a null *out_ptr is
*        reported with *out_len = 0.
* \throws std::runtime_error when no typed getter knows field_name
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetGetField(DatasetHandle handle,
  const char* field_name,
  int* out_len,
  const void** out_ptr,
  int* out_type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  if (dataset->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (dataset->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else if (dataset->GetDoubleField(field_name, out_len, reinterpret_cast<const double**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT64;
  } else {
    throw std::runtime_error("Field not found");
  }
  if (*out_ptr == nullptr) { *out_len = 0; }
  API_END();
}

579
/*! \brief Report Dataset::num_data() through *out. */
LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumData(DatasetHandle handle,
  int* out) {
  API_BEGIN();
  *out = reinterpret_cast<Dataset*>(handle)->num_data();
  API_END();
}

587
/*! \brief Report Dataset::num_total_features() through *out. */
LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumFeature(DatasetHandle handle,
  int* out) {
  API_BEGIN();
  *out = reinterpret_cast<Dataset*>(handle)->num_total_features();
  API_END();
}
594
595
596

// ---- start of booster

597
// Create a Booster bound to train_data and configured from the parameter string.
// Ownership of the new Booster passes to the caller through *out; it is held
// in a unique_ptr until then so a throwing constructor leaks nothing.
LIGHTGBM_C_EXPORT int LGBM_BoosterCreate(const DatasetHandle train_data,
  const char* parameters,
  BoosterHandle* out) {
  API_BEGIN();
  std::unique_ptr<Booster> booster(
    new Booster(reinterpret_cast<const Dataset*>(train_data), parameters));
  *out = booster.release();
  API_END();
}

607
// Create a Booster from a saved model file.
// *out_num_iterations receives the loaded model's current iteration count;
// *out receives ownership of the new Booster.
LIGHTGBM_C_EXPORT int LGBM_BoosterCreateFromModelfile(
  const char* filename,
  int* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
  std::unique_ptr<Booster> booster(new Booster(filename));
  *out_num_iterations = booster->GetBoosting()->GetCurrentIteration();
  *out = booster.release();
  API_END();
}

618
619
620
621
622
623
624
625
626
627
628
629
// Create an empty Booster and populate it from an in-memory model string.
// *out_num_iterations receives the loaded model's current iteration count;
// *out receives ownership of the new Booster.
LIGHTGBM_C_EXPORT int LGBM_BoosterLoadModelFromString(
  const char* model_str,
  int* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
  std::unique_ptr<Booster> booster(new Booster());
  booster->LoadModelFromString(model_str);
  *out_num_iterations = booster->GetBoosting()->GetCurrentIteration();
  *out = booster.release();
  API_END();
}

630
// Destroy a Booster previously handed out by one of the LGBM_Booster*
// creation entry points (which allocate with `new Booster`).
LIGHTGBM_C_EXPORT int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  delete booster;
  API_END();
}

636
// Merge the model held by other_handle into the Booster at handle
// (delegates to Booster::MergeFrom).
LIGHTGBM_C_EXPORT int LGBM_BoosterMerge(BoosterHandle handle,
  BoosterHandle other_handle) {
  API_BEGIN();
  auto* target = reinterpret_cast<Booster*>(handle);
  auto* source = reinterpret_cast<Booster*>(other_handle);
  target->MergeFrom(source);
  API_END();
}

645
// Register valid_data as an additional validation dataset on the Booster.
LIGHTGBM_C_EXPORT int LGBM_BoosterAddValidData(BoosterHandle handle,
  const DatasetHandle valid_data) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  booster->AddValidData(reinterpret_cast<const Dataset*>(valid_data));
  API_END();
}

654
// Replace the Booster's training dataset with train_data.
LIGHTGBM_C_EXPORT int LGBM_BoosterResetTrainingData(BoosterHandle handle,
  const DatasetHandle train_data) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  booster->ResetTrainingData(reinterpret_cast<const Dataset*>(train_data));
  API_END();
}

663
// Re-apply a parameter string to an existing Booster (Booster::ResetConfig).
LIGHTGBM_C_EXPORT int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ResetConfig(parameters);
  API_END();
}

670
// Write the underlying boosting model's NumberOfClasses() into *out_len.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->NumberOfClasses();
  API_END();
}

677
// Run one training iteration.
// *is_finished is 1 when Booster::TrainOneIter() reports true, otherwise 0.
LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  const bool finished = booster->TrainOneIter();
  *is_finished = finished ? 1 : 0;
  API_END();
}

688
// Run one training iteration using caller-provided grad / hess arrays
// (forwarded unchanged to Booster::TrainOneIter(grad, hess)).
// *is_finished is 1 when that call reports true, otherwise 0.
LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
  const float* grad,
  const float* hess,
  int* is_finished) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  const bool finished = booster->TrainOneIter(grad, hess);
  *is_finished = finished ? 1 : 0;
  API_END();
}

702
// Undo the most recent training iteration (Booster::RollbackOneIter).
LIGHTGBM_C_EXPORT int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

709
// Write the boosting model's GetCurrentIteration() into *out_iteration.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int* out_iteration) {
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_iteration = boosting->GetCurrentIteration();
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
715

716
// Write Booster::GetEvalCounts() into *out_len.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  const auto* booster = reinterpret_cast<Booster*>(handle);
  *out_len = booster->GetEvalCounts();
  API_END();
}

723
// Fill out_strs via Booster::GetEvalNames and store the count it returns in *out_len.
// NOTE(review): out_strs presumably must be pre-allocated by the caller
// (sized with LGBM_BoosterGetEvalCounts) — confirm against Booster::GetEvalNames.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetEvalNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  const int written = booster->GetEvalNames(out_strs);
  *out_len = written;
  API_END();
}

730
// Copy the evaluation results for dataset data_idx into out_results.
// *out_len receives the number of values written; out_results must have room
// for that many doubles.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetEval(BoosterHandle handle,
  int data_idx,
  int* out_len,
  double* out_results) {
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  const auto scores = boosting->GetEvalAt(data_idx);
  *out_len = static_cast<int>(scores.size());
  double* dst = out_results;
  for (const auto& score : scores) {
    *dst++ = static_cast<double>(score);
  }
  API_END();
}

745
// Write the boosting model's GetNumPredictAt(data_idx) into *out_len.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumPredict(BoosterHandle handle,
  int data_idx,
  int64_t* out_len) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  *out_len = booster->GetBoosting()->GetNumPredictAt(data_idx);
  API_END();
}

754
// Forward to Booster::GetPredictAt(data_idx, out_result, out_len).
// NOTE(review): presumably writes the stored predictions for dataset
// data_idx into out_result and their count into *out_len — confirm in Booster.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetPredict(BoosterHandle handle,
  int data_idx,
  int64_t* out_len,
  double* out_result) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  booster->GetPredictAt(data_idx, out_result, out_len);
  API_END();
}

764
// Predict every row of data_filename and write the results to result_filename.
// data_has_header is treated as true when it is > 0.
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForFile(BoosterHandle handle,
  const char* data_filename,
  int data_has_header,
  int predict_type,
  int num_iteration,
  const char* result_filename) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  auto predictor = booster->NewPredictor(num_iteration, predict_type);
  const bool has_header = (data_has_header > 0);
  predictor.Predict(data_filename, result_filename, has_header);
  API_END();
}

778
779
int64_t GetNumPredOneRow(const Booster* ref_booster, int predict_type, int64_t num_iteration) {
  int64_t num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
Guolin Ke's avatar
Guolin Ke committed
780
781
782
783
784
785
786
787
788
789
790
  if (predict_type == C_API_PREDICT_LEAF_INDEX) {
    int64_t max_iteration = ref_booster->GetBoosting()->GetCurrentIteration();
    if (num_iteration > 0) {
      num_preb_in_one_row *= static_cast<int>(std::min(max_iteration, num_iteration));
    } else {
      num_preb_in_one_row *= max_iteration;
    }
  }
  return num_preb_in_one_row;
}

791
// Compute how many doubles a prediction over num_row rows will produce:
// num_row times the per-row count from GetNumPredOneRow. The product is
// formed in int64_t.
LIGHTGBM_C_EXPORT int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
  int num_row,
  int predict_type,
  int num_iteration,
  int64_t* out_len) {
  API_BEGIN();
  const Booster* booster = reinterpret_cast<Booster*>(handle);
  const int64_t per_row = GetNumPredOneRow(booster, predict_type, num_iteration);
  *out_len = per_row * num_row;
  API_END();
}

802
// Predict every row of a CSR matrix into out_result.
// Row r's outputs are written starting at out_result[r * per_row], where
// per_row = GetNumPredOneRow(...). The row count is nindptr - 1.
// *out_len receives the total number of output slots (rows * per_row).
// The unnamed int64_t parameter (column count) is not used here.
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle,
  const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t,
  int predict_type,
  int num_iteration,
  int64_t* out_len,
  double* out_result) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto predictor = booster->NewPredictor(num_iteration, predict_type);
  auto read_row = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  const int64_t per_row = GetNumPredOneRow(booster, predict_type, num_iteration);
  const int row_count = static_cast<int>(nindptr - 1);
#pragma omp parallel for schedule(guided)
  for (int row = 0; row < row_count; ++row) {
    auto features = read_row(row);
    const auto scores = predictor.GetPredictFunction()(features);
    const int64_t base = row * per_row;
    for (int64_t k = 0; k < static_cast<int64_t>(scores.size()); ++k) {
      out_result[base + k] = static_cast<double>(scores[k]);
    }
  }
  *out_len = row_count * per_row;
  API_END();
}
832

833
// Predict every row of a CSC matrix into out_result.
// col_ptr holds one offset per column plus a terminating offset, so there are
// ncol_ptr - 1 columns. Rows are processed in [start, end) chunks by
// Threading::For. Each chunk rebuilds dense-with-zeros-dropped rows from
// per-column iterators (values with |v| <= kEpsilon are skipped).
// Row i's outputs are written starting at out_result[i * num_preb_in_one_row].
// *out_len receives num_row * num_preb_in_one_row.
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle,
  const void* col_ptr,
  int col_ptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t ncol_ptr,
  int64_t nelem,
  int64_t num_row,
  int predict_type,
  int num_iteration,
  int64_t* out_len,
  double* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  auto predictor = ref_booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
  int64_t num_preb_in_one_row = GetNumPredOneRow(ref_booster, predict_type, num_iteration);
  int ncol = static_cast<int>(ncol_ptr - 1);

  // The chunk bounds are declared int64_t, the index type requested from
  // Threading::For, instead of data_size_t, so they are not implicitly
  // converted (and possibly narrowed) before being used as row indices.
  Threading::For<int64_t>(0, num_row,
    [&predictor, out_result, num_preb_in_one_row, ncol, col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem]
  (int, int64_t start, int64_t end) {
    // CSC_RowIterator keeps a forward-only cursor, so each chunk needs its own set.
    std::vector<CSC_RowIterator> iterators;
    iterators.reserve(ncol);
    for (int j = 0; j < ncol; ++j) {
      iterators.emplace_back(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, j);
    }
    std::vector<std::pair<int, double>> one_row;
    for (int64_t i = start; i < end; ++i) {
      one_row.clear();
      for (int j = 0; j < ncol; ++j) {
        auto val = iterators[j].Get(static_cast<int>(i));
        if (std::fabs(val) > kEpsilon) {
          one_row.emplace_back(j, val);
        }
      }
      auto predicton_result = predictor.GetPredictFunction()(one_row);
      for (int j = 0; j < static_cast<int>(predicton_result.size()); ++j) {
        out_result[i * num_preb_in_one_row + j] = static_cast<double>(predicton_result[j]);
      }
    }
  });
  *out_len = num_row * num_preb_in_one_row;
  API_END();
}

878
// Predict every row of a dense nrow x ncol matrix into out_result.
// Rows are read through RowPairFunctionFromDenseMatric, honouring
// is_row_major and data_type. Row r's outputs are written starting at
// out_result[r * per_row]. *out_len receives nrow * per_row.
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMat(BoosterHandle handle,
  const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  int predict_type,
  int num_iteration,
  int64_t* out_len,
  double* out_result) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto predictor = booster->NewPredictor(num_iteration, predict_type);
  auto read_row = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  const int64_t per_row = GetNumPredOneRow(booster, predict_type, num_iteration);
#pragma omp parallel for schedule(guided)
  for (int row = 0; row < nrow; ++row) {
    auto features = read_row(row);
    const auto scores = predictor.GetPredictFunction()(features);
    const int64_t base = row * per_row;
    for (int64_t k = 0; k < static_cast<int64_t>(scores.size()); ++k) {
      out_result[base + k] = static_cast<double>(scores[k]);
    }
  }
  *out_len = nrow * per_row;
  API_END();
}
904

905
// Save the model (up to num_iteration, as interpreted by Booster::SaveModelToFile)
// to filename.
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle,
  int num_iteration,
  const char* filename) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  booster->SaveModelToFile(num_iteration, filename);
  API_END();
}

914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
// Serialise the model into out_str.
// *out_len always receives the required size including the terminating NUL.
// The text is copied only when that size fits in buffer_len, so callers may
// call once to size the buffer and again to fill it.
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle,
  int num_iteration,
  int buffer_len,
  int* out_len,
  char* out_str) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  const std::string text = booster->SaveModelToString(num_iteration);
  const int required = static_cast<int>(text.size()) + 1;
  *out_len = required;
  if (required <= buffer_len) {
    std::strcpy(out_str, text.c_str());
  }
  API_END();
}

929
// Write the model dump produced by Booster::DumpModel into out_str.
// *out_len always receives the required size including the terminating NUL.
// The text is copied only when that size fits in buffer_len.
LIGHTGBM_C_EXPORT int LGBM_BoosterDumpModel(BoosterHandle handle,
  int num_iteration,
  int buffer_len,
  int* out_len,
  char* out_str) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  const std::string text = booster->DumpModel(num_iteration);
  const int required = static_cast<int>(text.size()) + 1;
  *out_len = required;
  if (required <= buffer_len) {
    std::strcpy(out_str, text.c_str());
  }
  API_END();
}
943

944
// Read the output value of leaf leaf_idx in tree tree_idx into *out_val.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetLeafValue(BoosterHandle handle,
  int tree_idx,
  int leaf_idx,
  double* out_val) {
  API_BEGIN();
  const auto* booster = reinterpret_cast<Booster*>(handle);
  const auto leaf_value = booster->GetLeafValue(tree_idx, leaf_idx);
  *out_val = static_cast<double>(leaf_value);
  API_END();
}

954
// Overwrite the output value of leaf leaf_idx in tree tree_idx with val.
LIGHTGBM_C_EXPORT int LGBM_BoosterSetLeafValue(BoosterHandle handle,
  int tree_idx,
  int leaf_idx,
  double val) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->SetLeafValue(tree_idx, leaf_idx, val);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
964
// ---- start of some help functions
965
966
967

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
968
  if (data_type == C_API_DTYPE_FLOAT32) {
969
970
971
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
972
        std::vector<double> ret(num_col);
973
974
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
975
          ret[i] = static_cast<double>(*(tmp_ptr + i));
976
977
978
979
980
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
981
        std::vector<double> ret(num_col);
982
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
983
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
984
985
986
987
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
988
  } else if (data_type == C_API_DTYPE_FLOAT64) {
989
990
991
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
992
        std::vector<double> ret(num_col);
993
994
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
995
          ret[i] = static_cast<double>(*(tmp_ptr + i));
996
997
998
999
1000
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1001
        std::vector<double> ret(num_col);
1002
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1003
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
1004
1005
1006
1007
1008
        }
        return ret;
      };
    }
  }
Guolin Ke's avatar
Guolin Ke committed
1009
  throw std::runtime_error("unknown data type in RowFunctionFromDenseMatric");
1010
1011
1012
1013
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1014
1015
1016
1017
1018
1019
1020
1021
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
    return [inner_function](int row_idx) {
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
        if (std::fabs(raw_values[i]) > 1e-15) {
          ret.emplace_back(i, raw_values[i]);
1022
        }
Guolin Ke's avatar
Guolin Ke committed
1023
1024
1025
      }
      return ret;
    };
1026
  }
Guolin Ke's avatar
Guolin Ke committed
1027
  return nullptr;
1028
1029
1030
1031
}

std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
1032
  if (data_type == C_API_DTYPE_FLOAT32) {
1033
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1034
    if (indptr_type == C_API_DTYPE_INT32) {
1035
1036
1037
1038
1039
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1040
        for (int64_t i = start; i < end; ++i) {
1041
1042
1043
1044
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1045
    } else if (indptr_type == C_API_DTYPE_INT64) {
1046
1047
1048
1049
1050
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1051
        for (int64_t i = start; i < end; ++i) {
1052
1053
1054
1055
1056
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1057
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1058
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1059
    if (indptr_type == C_API_DTYPE_INT32) {
1060
1061
1062
1063
1064
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1065
        for (int64_t i = start; i < end; ++i) {
1066
1067
1068
1069
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1070
    } else if (indptr_type == C_API_DTYPE_INT64) {
1071
1072
1073
1074
1075
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1076
        for (int64_t i = start; i < end; ++i) {
1077
1078
1079
1080
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1081
1082
    }
  }
Guolin Ke's avatar
Guolin Ke committed
1083
  throw std::runtime_error("unknown data type in RowFunctionFromCSR");
1084
1085
}

Guolin Ke's avatar
Guolin Ke committed
1086
1087
1088
std::function<std::pair<int, double>(int idx)>
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx) {
  CHECK(col_idx < ncol_ptr && col_idx >= 0);
Guolin Ke's avatar
Guolin Ke committed
1089
  if (data_type == C_API_DTYPE_FLOAT32) {
1090
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1091
    if (col_ptr_type == C_API_DTYPE_INT32) {
1092
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1093
1094
1095
1096
1097
1098
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1099
        }
Guolin Ke's avatar
Guolin Ke committed
1100
1101
1102
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1103
      };
Guolin Ke's avatar
Guolin Ke committed
1104
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1105
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1106
1107
1108
1109
1110
1111
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1112
        }
Guolin Ke's avatar
Guolin Ke committed
1113
1114
1115
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1116
      };
Guolin Ke's avatar
Guolin Ke committed
1117
    }
Guolin Ke's avatar
Guolin Ke committed
1118
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1119
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1120
    if (col_ptr_type == C_API_DTYPE_INT32) {
1121
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1122
1123
1124
1125
1126
1127
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1128
        }
Guolin Ke's avatar
Guolin Ke committed
1129
1130
1131
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1132
      };
Guolin Ke's avatar
Guolin Ke committed
1133
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1134
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1135
1136
1137
1138
1139
1140
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1141
        }
Guolin Ke's avatar
Guolin Ke committed
1142
1143
1144
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1145
      };
Guolin Ke's avatar
Guolin Ke committed
1146
1147
1148
    }
  }
  throw std::runtime_error("unknown data type in CSC matrix");
1149
1150
}

Guolin Ke's avatar
Guolin Ke committed
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
// Bind the iterator to column col_idx of a CSC matrix; all arguments are
// forwarded to IterateFunctionFromCSC, which builds the underlying cursor.
CSC_RowIterator::CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
  const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx)
  : iter_fun_(IterateFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, col_idx)) {
}

// Return the column's value at row idx, or 0 if that row has no stored entry.
// The cursor only moves forward: it advances through stored entries while
// their row index is below idx. Asking for a row before the current position
// therefore yields 0 unless it equals the current row.
double CSC_RowIterator::Get(int idx) {
  while (!is_end_ && cur_idx_ < idx) {
    const auto next = iter_fun_(nonzero_idx_);
    if (next.first < 0) {
      // Column exhausted.
      is_end_ = true;
      break;
    }
    cur_idx_ = next.first;
    cur_val_ = next.second;
    ++nonzero_idx_;
  }
  return (cur_idx_ == idx) ? cur_val_ : 0.0f;
}

// Return the next stored (row index, value) pair of the column.
// Once the column is exhausted, a pair with a negative index is returned
// (-1, 0.0 on every subsequent call).
std::pair<int, double> CSC_RowIterator::NextNonZero() {
  if (is_end_) {
    return std::make_pair(-1, 0.0);
  }
  const auto next = iter_fun_(nonzero_idx_);
  ++nonzero_idx_;
  if (next.first < 0) {
    is_end_ = true;
  }
  return next;
}