c_api.cpp 39.3 KB
Newer Older
1
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
2
3
4

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/utils/threading.h>
Guolin Ke's avatar
Guolin Ke committed
6
#include <LightGBM/c_api.h>
Guolin Ke's avatar
Guolin Ke committed
7
#include <LightGBM/dataset_loader.h>
Guolin Ke's avatar
Guolin Ke committed
8
9
10
11
12
13
14
15
16
17
#include <LightGBM/dataset.h>
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/config.h>

#include <cstdio>
#include <vector>
#include <string>
#include <cstring>
Guolin Ke's avatar
Guolin Ke committed
18
#include <memory>
Guolin Ke's avatar
Guolin Ke committed
19
#include <stdexcept>
wxchan's avatar
wxchan committed
20
#include <mutex>
Guolin Ke's avatar
Guolin Ke committed
21
#include <functional>
Guolin Ke's avatar
Guolin Ke committed
22

Guolin Ke's avatar
Guolin Ke committed
23
#include "./application/predictor.hpp"
Guolin Ke's avatar
Guolin Ke committed
24
#include "./boosting/gbdt.h"
Guolin Ke's avatar
Guolin Ke committed
25

Guolin Ke's avatar
Guolin Ke committed
26
27
28
29
namespace LightGBM {

class Booster {
public:
Guolin Ke's avatar
Guolin Ke committed
30
31
  explicit Booster(const char* filename) {
    boosting_.reset(Boosting::CreateBoosting(filename));
Guolin Ke's avatar
Guolin Ke committed
32
33
  }

34
35
36
37
  Booster() {
    boosting_.reset(Boosting::CreateBoosting("gbdt", nullptr));
  }

Guolin Ke's avatar
Guolin Ke committed
38
  Booster(const Dataset* train_data,
wxchan's avatar
wxchan committed
39
40
41
    const char* parameters) {
    auto param = ConfigBase::Str2Map(parameters);
    config_.Set(param);
42
43
44
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
45
46
    // create boosting
    if (config_.io_config.input_model.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
47
      Log::Warning("continued train from model is not support for c_api, \
Guolin Ke's avatar
Guolin Ke committed
48
49
        please use continued train with input score");
    }
Guolin Ke's avatar
Guolin Ke committed
50

wxchan's avatar
wxchan committed
51
    boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
Guolin Ke's avatar
Guolin Ke committed
52

Guolin Ke's avatar
Guolin Ke committed
53
    // initialize the boosting
Guolin Ke's avatar
Guolin Ke committed
54
    boosting_->Init(&config_.boosting_config, nullptr, objective_fun_.get(),
Guolin Ke's avatar
Guolin Ke committed
55
      Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
56
57

    ResetTrainingData(train_data);
wxchan's avatar
wxchan committed
58
59
60
61
62
  }

  void MergeFrom(const Booster* other) {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->MergeFrom(other->boosting_.get());
Guolin Ke's avatar
Guolin Ke committed
63
64
65
  }

  ~Booster() {
Guolin Ke's avatar
Guolin Ke committed
66

Guolin Ke's avatar
Guolin Ke committed
67
  }
68

wxchan's avatar
wxchan committed
69
70
71
  void ResetTrainingData(const Dataset* train_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    train_data_ = train_data;
Guolin Ke's avatar
Guolin Ke committed
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
    // create objective function
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
      config_.objective_config));
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective function");
    }
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
    }

    // create training metric
    train_metric_.clear();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(
        Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(train_data_->metadata(), train_data_->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
    // reset the boosting
Guolin Ke's avatar
Guolin Ke committed
94
    boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
wxchan's avatar
wxchan committed
95
96
97
98
      objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
  }

  void ResetConfig(const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
99
    std::lock_guard<std::mutex> lock(mutex_);
wxchan's avatar
wxchan committed
100
101
102
103
104
105
106
    auto param = ConfigBase::Str2Map(parameters);
    if (param.count("num_class")) {
      Log::Fatal("cannot change num class during training");
    }
    if (param.count("boosting_type")) {
      Log::Fatal("cannot change boosting_type during training");
    }
Guolin Ke's avatar
Guolin Ke committed
107
108
109
    if (param.count("metric")) {
      Log::Fatal("cannot change metric during training");
    }
Guolin Ke's avatar
Guolin Ke committed
110
111

    config_.Set(param);
112
113
114
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
115
116
117
118
119
120
121
122
123
124
125
126

    if (param.count("objective")) {
      // create objective function
      objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
        config_.objective_config));
      if (objective_fun_ == nullptr) {
        Log::Warning("Using self-defined objective function");
      }
      // initialize the objective function
      if (objective_fun_ != nullptr) {
        objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
      }
wxchan's avatar
wxchan committed
127
    }
Guolin Ke's avatar
Guolin Ke committed
128
129
130

    boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
      objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
131

wxchan's avatar
wxchan committed
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
  }

  void AddValidData(const Dataset* valid_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    valid_metrics_.emplace_back();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
      Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
  }
Guolin Ke's avatar
Guolin Ke committed
147

148
  bool TrainOneIter() {
wxchan's avatar
wxchan committed
149
    std::lock_guard<std::mutex> lock(mutex_);
150
151
152
153
    return boosting_->TrainOneIter(nullptr, nullptr, false);
  }

  bool TrainOneIter(const float* gradients, const float* hessians) {
wxchan's avatar
wxchan committed
154
    std::lock_guard<std::mutex> lock(mutex_);
155
156
157
    return boosting_->TrainOneIter(gradients, hessians, false);
  }

wxchan's avatar
wxchan committed
158
159
160
161
162
  void RollbackOneIter() {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->RollbackOneIter();
  }

Guolin Ke's avatar
Guolin Ke committed
163
  Predictor NewPredictor(int num_iteration, int predict_type) {
wxchan's avatar
wxchan committed
164
165
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->SetNumIterationForPred(num_iteration);
Guolin Ke's avatar
Guolin Ke committed
166
167
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
168
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
169
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
170
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
171
      is_raw_score = true;
Guolin Ke's avatar
Guolin Ke committed
172
173
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
174
    }
Guolin Ke's avatar
Guolin Ke committed
175
176
177
    // not threading safe now
    // boosting_->SetNumIterationForPred may be set by other thread during prediction. 
    return Predictor(boosting_.get(), is_raw_score, is_predict_leaf);
Guolin Ke's avatar
Guolin Ke committed
178
179
  }

Guolin Ke's avatar
Guolin Ke committed
180
  void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
181
182
183
184
185
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

  void SaveModelToFile(int num_iteration, const char* filename) {
    boosting_->SaveModelToFile(num_iteration, filename);
Guolin Ke's avatar
Guolin Ke committed
186
  }
187

188
189
190
191
192
193
194
195
  void LoadModelFromString(const char* model_str) {
    boosting_->LoadModelFromString(model_str);
  }

  std::string SaveModelToString(int num_iteration) {
    return boosting_->SaveModelToString(num_iteration);
  }

196
197
  std::string DumpModel(int num_iteration) {
    return boosting_->DumpModel(num_iteration);
wxchan's avatar
wxchan committed
198
  }
199

Guolin Ke's avatar
Guolin Ke committed
200
201
202
203
204
205
206
207
208
  double GetLeafValue(int tree_idx, int leaf_idx) const {
    return dynamic_cast<GBDT*>(boosting_.get())->GetLeafValue(tree_idx, leaf_idx);
  }

  void SetLeafValue(int tree_idx, int leaf_idx, double val) {
    std::lock_guard<std::mutex> lock(mutex_);
    dynamic_cast<GBDT*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val);
  }

wxchan's avatar
wxchan committed
209
210
211
212
213
214
215
  int GetEvalCounts() const {
    int ret = 0;
    for (const auto& metric : train_metric_) {
      ret += static_cast<int>(metric->GetName().size());
    }
    return ret;
  }
216

wxchan's avatar
wxchan committed
217
218
219
220
221
222
223
224
225
226
227
228
  int GetEvalNames(char** out_strs) const {
    int idx = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
        std::strcpy(out_strs[idx], name.c_str());
        ++idx;
      }
    }
    return idx;
  }

  const Boosting* GetBoosting() const { return boosting_.get(); }
Guolin Ke's avatar
Guolin Ke committed
229

Guolin Ke's avatar
Guolin Ke committed
230
private:
231

wxchan's avatar
wxchan committed
232
  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
233
  std::unique_ptr<Boosting> boosting_;
Guolin Ke's avatar
Guolin Ke committed
234
235
236
  /*! \brief All configs */
  OverallConfig config_;
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
237
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
238
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
239
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
240
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
241
  std::unique_ptr<ObjectiveFunction> objective_fun_;
wxchan's avatar
wxchan committed
242
243
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
244
245
246
};

}
Guolin Ke's avatar
Guolin Ke committed
247
248
249

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
// some help functions used to convert data

// Forward declarations only; the definitions are not part of this section.

// Returns a callable that maps a row index to that row as a vector<double>.
// NOTE(review): data_type presumably selects one of the C_API_DTYPE_* element
// types and is_row_major the memory layout - confirm against the definition.
std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major);

// Same inputs as above, but the callable returns the row as (int, double) pairs.
// NOTE(review): presumably (column index, value) for the non-zero entries - confirm.
std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major);

// Returns a callable that maps an index to a vector of (int, double) pairs
// built from CSR arrays; callers in this file use the pair as
// (feature column, value) and pass a row index.
std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
  const void* data, int data_type, int64_t nindptr, int64_t nelem);

/*!
 * \brief Iterator over the rows of one column of a CSC matrix.
 *        Only the interface is declared here; the constructor, Get and
 *        NextNonZero are defined outside this declaration.
 */
class CSC_RowIterator {
public:
  CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
    const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx);
  ~CSC_RowIterator() = default;
  /*! \brief Value at row idx; rows have to be requested in ascending order */
  double Get(int idx);
  /*! \brief Next non-zero (row, value) pair; a negative row means there is no more data */
  std::pair<int, double> NextNonZero();
private:
  int nonzero_idx_ = 0;
  int cur_idx_ = -1;
  double cur_val_ = 0.0;
  bool is_end_ = false;
  std::function<std::pair<int, double>(int idx)> iter_fun_;
};

// start of c_api functions

282
/*! \brief Return the message produced by LastErrorMsg() */
LIGHTGBM_C_EXPORT const char* LGBM_GetLastError() {
  const char* message = LastErrorMsg();
  return message;
}

286
/*!
 * \brief Load a Dataset from a file.
 *        Without a reference the file is loaded on its own; with a reference
 *        it is loaded aligned with that dataset.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromFile(const char* filename,
  const char* parameters,
  const DatasetHandle reference,
  DatasetHandle* out) {
  API_BEGIN();
  auto parsed_params = ConfigBase::Str2Map(parameters);
  IOConfig load_config;
  load_config.Set(parsed_params);
  DatasetLoader file_loader(load_config, nullptr, 1, filename);
  if (reference != nullptr) {
    const Dataset* aligned_with = reinterpret_cast<const Dataset*>(reference);
    *out = file_loader.LoadFromFileAlignWithOtherDataset(filename, aligned_with);
  } else {
    *out = file_loader.LoadFromFile(filename);
  }
  API_END();
}

304
/*!
 * \brief Create a Dataset from a dense matrix.
 *
 * Without a reference, at most io_config.bin_construct_sample_cnt rows are
 * sampled and handed to DatasetLoader::CostructFromSampleData; with a
 * reference, the feature mappers are copied from it. Afterwards every row is
 * pushed into the dataset in an OpenMP parallel loop.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromMat(const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  const char* parameters,
  const DatasetHandle reference,
  DatasetHandle* out) {
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
  std::unique_ptr<Dataset> ret;
  auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  if (reference == nullptr) {
    // sample data first
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(ncol);
    for (const auto idx : sample_indices) {
      const auto row = get_row_fun(static_cast<int>(idx));
      for (size_t j = 0; j < row.size(); ++j) {
        // only values that are not (near) zero are kept as samples; same
        // threshold constant as the CSC path
        if (std::fabs(row[j]) > kEpsilon) {
          sample_values[j].push_back(row[j]);
        }
      }
    }
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
    ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow));
    ret->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(reference),
      io_config.is_enable_sparse);
  }

#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
  }
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

353
/*!
 * \brief Create a Dataset from a CSR matrix.
 *
 * The number of rows is nindptr - 1. Without a reference, sampled rows are
 * used to collect per-feature sample values for
 * DatasetLoader::CostructFromSampleData; with a reference, the feature
 * mappers are copied from it. Afterwards every row is pushed into the
 * dataset in an OpenMP parallel loop.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSR(const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t num_col,
  const char* parameters,
  const DatasetHandle reference,
  DatasetHandle* out) {
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
  std::unique_ptr<Dataset> ret;
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference == nullptr) {
    // sample data first
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values;
    for (const auto idx : sample_indices) {
      auto row = get_row_fun(static_cast<int>(idx));
      for (const std::pair<int, double>& inner_data : row) {
        // a negative column would wrap around to a huge index below
        CHECK(inner_data.first >= 0);
        const size_t col = static_cast<size_t>(inner_data.first);
        if (col >= sample_values.size()) {
          // if need expand feature set
          sample_values.resize(col + 1);
        }
        // same zero threshold constant as the CSC path
        if (std::fabs(inner_data.second) > kEpsilon) {
          // edit the feature value
          sample_values[col].push_back(inner_data.second);
        }
      }
    }
    CHECK(num_col >= static_cast<int>(sample_values.size()));
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
    ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow));
    ret->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(reference),
      io_config.is_enable_sparse);
  }

#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nindptr - 1; ++i) {
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
  }
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

415
/*!
 * \brief Create a Dataset from a CSC matrix.
 *
 * The number of columns is ncol_ptr - 1. Without a reference, the sampled
 * rows of every column are collected for
 * DatasetLoader::CostructFromSampleData; with a reference, the feature
 * mappers are copied from it. Afterwards the non-zero entries of every used
 * column are pushed into the matching feature, one column per loop iteration.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr,
  int col_ptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t ncol_ptr,
  int64_t nelem,
  int64_t num_row,
  const char* parameters,
  const DatasetHandle reference,
  DatasetHandle* out) {
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
  std::unique_ptr<Dataset> ret;
  int32_t nrow = static_cast<int32_t>(num_row);
  if (reference == nullptr) {
    // sample data first
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
#pragma omp parallel for schedule(guided)
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
      CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
      // walk the indices that were actually returned, not sample_cnt, so a
      // shorter result can never be indexed out of range
      for (const auto sampled_row : sample_indices) {
        auto val = col_it.Get(sampled_row);
        if (std::fabs(val) > kEpsilon) {
          sample_values[i].push_back(val);
        }
      }
    }
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
    ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow));
    ret->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(reference),
      io_config.is_enable_sparse);
  }

#pragma omp parallel for schedule(guided)
  for (int i = 0; i < ncol_ptr - 1; ++i) {
    const int tid = omp_get_thread_num();
    int feature_idx = ret->GetInnerFeatureIndex(i);
    // columns without an inner feature index are not stored
    if (feature_idx < 0) { continue; }
    CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
    int row_idx = 0;
    while (row_idx < nrow) {
      auto pair = col_it.NextNonZero();
      row_idx = pair.first;
      // no more data, or a row outside the dataset: stop before pushing it
      if (row_idx < 0 || row_idx >= nrow) { break; }
      ret->FeatureAt(feature_idx)->PushData(tid, row_idx, pair.second);
    }
  }
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

477
/*!
 * \brief Create a new Dataset from the given rows of an existing one.
 *        The result comes from Dataset::Subset and is finished with FinishLoad.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetGetSubset(
  const DatasetHandle handle,
  const int32_t* used_row_indices,
  int32_t num_used_row_indices,
  const char* parameters,
  DatasetHandle* out) {
  API_BEGIN();
  auto parsed_params = ConfigBase::Str2Map(parameters);
  IOConfig subset_config;
  subset_config.Set(parsed_params);
  const Dataset* source = reinterpret_cast<const Dataset*>(handle);
  std::unique_ptr<Dataset> subset(
    source->Subset(used_row_indices,
      num_used_row_indices,
      subset_config.is_enable_sparse));
  subset->FinishLoad();
  *out = subset.release();
  API_END();
}

497
/*! \brief Copy num_feature_names C strings into a vector and set them as the dataset's feature names */
LIGHTGBM_C_EXPORT int LGBM_DatasetSetFeatureNames(
  DatasetHandle handle,
  const char** feature_names,
  int num_feature_names) {
  API_BEGIN();
  std::vector<std::string> names;
  for (int name_idx = 0; name_idx < num_feature_names; ++name_idx) {
    names.push_back(std::string(feature_names[name_idx]));
  }
  Dataset* target = reinterpret_cast<Dataset*>(handle);
  target->set_feature_names(names);
  API_END();
}

511
/*!
 * \brief Copy the dataset's feature names into caller-provided buffers.
 * \param feature_names One buffer per feature name; each has to be large
 *        enough, the copy is an unbounded std::strcpy
 * \param num_feature_names Receives the number of names written
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetGetFeatureNames(
  DatasetHandle handle,
  char** feature_names,
  int* num_feature_names) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  // bind by reference: the names are only read, so the vector is not copied
  const auto& inside_feature_name = dataset->feature_names();
  *num_feature_names = static_cast<int>(inside_feature_name.size());
  for (int i = 0; i < *num_feature_names; ++i) {
    std::strcpy(feature_names[i], inside_feature_name[i].c_str());
  }
  API_END();
}

525
/*! \brief Delete the Dataset behind the handle */
LIGHTGBM_C_EXPORT int LGBM_DatasetFree(DatasetHandle handle) {
  API_BEGIN();
  Dataset* doomed = reinterpret_cast<Dataset*>(handle);
  delete doomed;
  API_END();
}

531
/*! \brief Forward to Dataset::SaveBinaryFile with the given filename */
LIGHTGBM_C_EXPORT int LGBM_DatasetSaveBinary(DatasetHandle handle,
  const char* filename) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->SaveBinaryFile(filename);
  API_END();
}

539
/*!
 * \brief Set a named field of the dataset from a typed array.
 * \param field_data Array of num_element values, interpreted according to type
 * \param type C_API_DTYPE_FLOAT32, C_API_DTYPE_INT32 or C_API_DTYPE_FLOAT64
 *
 * Throws std::runtime_error inside the API_BEGIN/API_END section when the
 * type is none of the above or the matching setter reports failure.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle,
  const char* field_name,
  const void* field_data,
  int num_element,
  int type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  bool is_success = false;
  if (type == C_API_DTYPE_FLOAT32) {
    is_success = dataset->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), static_cast<int32_t>(num_element));
  } else if (type == C_API_DTYPE_INT32) {
    is_success = dataset->SetIntField(field_name, reinterpret_cast<const int*>(field_data), static_cast<int32_t>(num_element));
  } else if (type == C_API_DTYPE_FLOAT64) {
    is_success = dataset->SetDoubleField(field_name, reinterpret_cast<const double*>(field_data), static_cast<int32_t>(num_element));
  }
  if (!is_success) { throw std::runtime_error("Input data type error or field not found"); }
  API_END();
}

558
/*!
 * \brief Look up a named field of the dataset.
 *
 * The float, int and double getters are tried in that order; the first one
 * that succeeds determines out_type. Throws std::runtime_error inside the
 * API_BEGIN/API_END section when none succeeds. A null result pointer is
 * reported with length 0.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetGetField(DatasetHandle handle,
  const char* field_name,
  int* out_len,
  const void** out_ptr,
  int* out_type) {
  API_BEGIN();
  Dataset* source = reinterpret_cast<Dataset*>(handle);
  if (source->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (source->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else if (source->GetDoubleField(field_name, out_len, reinterpret_cast<const double**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT64;
  } else {
    throw std::runtime_error("Field not found");
  }
  if (*out_ptr == nullptr) { *out_len = 0; }
  API_END();
}

581
/*! \brief Write Dataset::num_data() of the dataset to out */
LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumData(DatasetHandle handle,
  int* out) {
  API_BEGIN();
  *out = reinterpret_cast<Dataset*>(handle)->num_data();
  API_END();
}

589
/*! \brief Write Dataset::num_total_features() of the dataset to out */
LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumFeature(DatasetHandle handle,
  int* out) {
  API_BEGIN();
  *out = reinterpret_cast<Dataset*>(handle)->num_total_features();
  API_END();
}
596
597
598

// ---- start of booster

599
// Constructs a Booster from a training dataset and a parameter string and
// hands the raw pointer to the caller through *out.
// The object is held in a unique_ptr until construction has finished, so it
// is released to the caller only after the constructor returned normally.
LIGHTGBM_C_EXPORT int LGBM_BoosterCreate(const DatasetHandle train_data,
  const char* parameters,
  BoosterHandle* out) {
  API_BEGIN();
  std::unique_ptr<Booster> booster(
    new Booster(reinterpret_cast<const Dataset*>(train_data), parameters));
  *out = booster.release();
  API_END();
}

609
// Constructs a Booster from the model file `filename`.
// *out_num_iterations receives GetCurrentIteration() of the loaded boosting
// object; *out receives ownership of the new Booster.
LIGHTGBM_C_EXPORT int LGBM_BoosterCreateFromModelfile(
  const char* filename,
  int* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
  std::unique_ptr<Booster> loaded(new Booster(filename));
  *out_num_iterations = loaded->GetBoosting()->GetCurrentIteration();
  *out = loaded.release();
  API_END();
}

620
621
622
623
624
625
626
627
628
629
630
631
// Creates a default-constructed Booster and loads the model text `model_str`
// into it. *out_num_iterations receives GetCurrentIteration() after loading;
// *out receives ownership of the new Booster.
LIGHTGBM_C_EXPORT int LGBM_BoosterLoadModelFromString(
  const char* model_str,
  int* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
  std::unique_ptr<Booster> loaded(new Booster());
  loaded->LoadModelFromString(model_str);
  *out_num_iterations = loaded->GetBoosting()->GetCurrentIteration();
  *out = loaded.release();
  API_END();
}

632
// Destroys the Booster behind `handle`. The handle must not be used afterwards.
LIGHTGBM_C_EXPORT int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  Booster* doomed = reinterpret_cast<Booster*>(handle);
  delete doomed;
  API_END();
}

638
// Calls Booster::MergeFrom on the booster behind `handle`, passing the
// booster behind `other_handle` as the source.
LIGHTGBM_C_EXPORT int LGBM_BoosterMerge(BoosterHandle handle,
  BoosterHandle other_handle) {
  API_BEGIN();
  Booster* target = reinterpret_cast<Booster*>(handle);
  target->MergeFrom(reinterpret_cast<Booster*>(other_handle));
  API_END();
}

647
// Passes the dataset behind `valid_data` to Booster::AddValidData on the
// booster behind `handle`.
LIGHTGBM_C_EXPORT int LGBM_BoosterAddValidData(BoosterHandle handle,
  const DatasetHandle valid_data) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->AddValidData(reinterpret_cast<const Dataset*>(valid_data));
  API_END();
}

656
// Passes the dataset behind `train_data` to Booster::ResetTrainingData on the
// booster behind `handle`.
LIGHTGBM_C_EXPORT int LGBM_BoosterResetTrainingData(BoosterHandle handle,
  const DatasetHandle train_data) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->ResetTrainingData(reinterpret_cast<const Dataset*>(train_data));
  API_END();
}

665
// Forwards the parameter string to Booster::ResetConfig on the booster
// behind `handle`.
LIGHTGBM_C_EXPORT int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ResetConfig(parameters);
  API_END();
}

672
// Stores NumberOfClasses() of the underlying boosting object into *out_len.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetBoosting()->NumberOfClasses();
  API_END();
}

679
// Runs Booster::TrainOneIter() once and reports its result as an int flag:
// *is_finished is 1 when TrainOneIter() returned true, otherwise 0.
LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  *is_finished = booster->TrainOneIter() ? 1 : 0;
  API_END();
}

690
// Runs Booster::TrainOneIter(grad, hess) once with caller-supplied gradient
// and hessian buffers. *is_finished is 1 when the call returned true,
// otherwise 0.
LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
  const float* grad,
  const float* hess,
  int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  *is_finished = booster->TrainOneIter(grad, hess) ? 1 : 0;
  API_END();
}

704
// Calls Booster::RollbackOneIter() on the booster behind `handle`.
LIGHTGBM_C_EXPORT int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

711
// Stores GetCurrentIteration() of the underlying boosting object into
// *out_iteration.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int* out_iteration) {
  API_BEGIN();
  *out_iteration = reinterpret_cast<Booster*>(handle)->GetBoosting()->GetCurrentIteration();
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
717

718
// Stores the value of Booster::GetEvalCounts() into *out_len.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalCounts();
  API_END();
}

725
// Passes `out_strs` to Booster::GetEvalNames and stores its return value in
// *out_len. The required size of the caller-provided buffers is not
// established here — see Booster::GetEvalNames.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetEvalNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalNames(out_strs);
  API_END();
}

732
// Fetches the evaluation results for data index `data_idx` via
// Boosting::GetEvalAt, writes their count to *out_len and copies each value
// (converted to double) into out_results. The caller must supply a buffer
// large enough for all returned values.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetEval(BoosterHandle handle,
  int data_idx,
  int* out_len,
  double* out_results) {
  API_BEGIN();
  const auto metrics =
    reinterpret_cast<Booster*>(handle)->GetBoosting()->GetEvalAt(data_idx);
  *out_len = static_cast<int>(metrics.size());
  double* dst = out_results;
  for (const auto& value : metrics) {
    *dst = static_cast<double>(value);
    ++dst;
  }
  API_END();
}

747
// Stores Boosting::GetNumPredictAt(data_idx) into *out_len.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumPredict(BoosterHandle handle,
  int data_idx,
  int64_t* out_len) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  *out_len = booster->GetBoosting()->GetNumPredictAt(data_idx);
  API_END();
}

756
// Forwards to Booster::GetPredictAt, which fills out_result and out_len for
// the data set selected by `data_idx`.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetPredict(BoosterHandle handle,
  int data_idx,
  int64_t* out_len,
  double* out_result) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->GetPredictAt(data_idx, out_result, out_len);
  API_END();
}

766
// Builds a predictor for the requested iteration count / predict type and
// runs it over `data_filename`, writing to `result_filename`.
// `data_has_header` is treated as true for any value greater than zero.
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForFile(BoosterHandle handle,
  const char* data_filename,
  int data_has_header,
  int predict_type,
  int num_iteration,
  const char* result_filename) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto predictor = booster->NewPredictor(num_iteration, predict_type);
  const bool has_header = data_has_header > 0;
  predictor.Predict(data_filename, result_filename, has_header);
  API_END();
}

780
781
int64_t GetNumPredOneRow(const Booster* ref_booster, int predict_type, int64_t num_iteration) {
  int64_t num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
Guolin Ke's avatar
Guolin Ke committed
782
783
784
785
786
787
788
789
790
791
792
  if (predict_type == C_API_PREDICT_LEAF_INDEX) {
    int64_t max_iteration = ref_booster->GetBoosting()->GetCurrentIteration();
    if (num_iteration > 0) {
      num_preb_in_one_row *= static_cast<int>(std::min(max_iteration, num_iteration));
    } else {
      num_preb_in_one_row *= max_iteration;
    }
  }
  return num_preb_in_one_row;
}

793
// Computes the total number of prediction values for `num_row` rows:
// num_row * GetNumPredOneRow(...), evaluated in 64-bit arithmetic.
LIGHTGBM_C_EXPORT int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
  int num_row,
  int predict_type,
  int num_iteration,
  int64_t* out_len) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const int64_t per_row = GetNumPredOneRow(booster, predict_type, num_iteration);
  *out_len = static_cast<int64_t>(num_row) * per_row;
  API_END();
}

804
// Predicts every row of a CSR matrix.
// The row count is nindptr - 1. Row i writes its predictions starting at
// out_result[i * GetNumPredOneRow(...)], and *out_len receives
// nrow * GetNumPredOneRow(...). Rows are processed by an OpenMP parallel loop.
// The unnamed int64_t parameter is not used.
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle,
  const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t,
  int predict_type,
  int num_iteration,
  int64_t* out_len,
  double* out_result) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto predictor = booster->NewPredictor(num_iteration, predict_type);
  auto read_row = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  const int64_t stride = GetNumPredOneRow(booster, predict_type, num_iteration);
  const int nrow = static_cast<int>(nindptr - 1);
#pragma omp parallel for schedule(guided)
  for (int row = 0; row < nrow; ++row) {
    auto features = read_row(row);
    auto scores = predictor.GetPredictFunction()(features);
    // Each row owns a disjoint slice of the output buffer.
    double* row_out = out_result + row * stride;
    const int num_scores = static_cast<int>(scores.size());
    for (int k = 0; k < num_scores; ++k) {
      row_out[k] = static_cast<double>(scores[k]);
    }
  }
  *out_len = nrow * stride;
  API_END();
}
834

835
// Predicts every row of a CSC matrix.
// The column count is ncol_ptr - 1. Rows [0, num_row) are split into chunks
// by Threading::For; each chunk builds its own set of per-column iterators
// and walks its rows in increasing order, collecting the entries whose
// absolute value exceeds kEpsilon into a sparse row that is handed to the
// predictor. Row i writes its predictions starting at
// out_result[i * GetNumPredOneRow(...)]; *out_len receives
// num_row * GetNumPredOneRow(...).
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle,
  const void* col_ptr,
  int col_ptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t ncol_ptr,
  int64_t nelem,
  int64_t num_row,
  int predict_type,
  int num_iteration,
  int64_t* out_len,
  double* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  auto predictor = ref_booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
  int64_t num_preb_in_one_row = GetNumPredOneRow(ref_booster, predict_type, num_iteration);
  int ncol = static_cast<int>(ncol_ptr - 1);

  // The chunk bounds are declared as int64_t to match the index type that
  // Threading::For is instantiated with; declaring them as data_size_t
  // (presumably narrower than int64_t — confirm against its typedef) would
  // silently narrow the bounds. out_result is a plain pointer, so it is
  // captured by value.
  Threading::For<int64_t>(0, num_row,
    [&predictor, out_result, num_preb_in_one_row, ncol, col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem]
  (int, int64_t start, int64_t end) {
    // One iterator per column, private to this chunk.
    std::vector<CSC_RowIterator> iterators;
    for (int j = 0; j < ncol; ++j) {
      iterators.emplace_back(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, j);
    }
    std::vector<std::pair<int, double>> one_row;
    for (int64_t i = start; i < end; ++i) {
      one_row.clear();
      for (int j = 0; j < ncol; ++j) {
        // NOTE: CSC_RowIterator::Get takes an int row index, so row indices
        // beyond the int range are still not supported here.
        auto val = iterators[j].Get(static_cast<int>(i));
        if (std::fabs(val) > kEpsilon) {
          one_row.emplace_back(j, val);
        }
      }
      auto predicton_result = predictor.GetPredictFunction()(one_row);
      for (int j = 0; j < static_cast<int>(predicton_result.size()); ++j) {
        out_result[i * num_preb_in_one_row + j] = static_cast<double>(predicton_result[j]);
      }
    }
  });
  *out_len = num_row * num_preb_in_one_row;
  API_END();
}

880
// Predicts every row of a dense matrix.
// Rows are read through RowPairFunctionFromDenseMatric and processed by an
// OpenMP parallel loop. Row i writes its predictions starting at
// out_result[i * GetNumPredOneRow(...)]; *out_len receives
// nrow * GetNumPredOneRow(...).
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMat(BoosterHandle handle,
  const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  int predict_type,
  int num_iteration,
  int64_t* out_len,
  double* out_result) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto predictor = booster->NewPredictor(num_iteration, predict_type);
  auto read_row = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  const int64_t stride = GetNumPredOneRow(booster, predict_type, num_iteration);
#pragma omp parallel for schedule(guided)
  for (int row = 0; row < nrow; ++row) {
    auto features = read_row(row);
    auto scores = predictor.GetPredictFunction()(features);
    // Each row owns a disjoint slice of the output buffer.
    double* row_out = out_result + row * stride;
    const int num_scores = static_cast<int>(scores.size());
    for (int k = 0; k < num_scores; ++k) {
      row_out[k] = static_cast<double>(scores[k]);
    }
  }
  *out_len = nrow * stride;
  API_END();
}
906

907
// Forwards to Booster::SaveModelToFile(num_iteration, filename).
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle,
  int num_iteration,
  const char* filename) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->SaveModelToFile(num_iteration, filename);
  API_END();
}

916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
// Serialises the model with Booster::SaveModelToString.
// *out_len always receives the number of bytes needed, including the
// terminating NUL. The text is copied into out_str only when that size fits
// into buffer_len; otherwise out_str is left untouched so the caller can
// retry with a larger buffer.
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle,
  int num_iteration,
  int buffer_len,
  int* out_len,
  char* out_str) {
  API_BEGIN();
  const std::string serialized =
    reinterpret_cast<Booster*>(handle)->SaveModelToString(num_iteration);
  const int required = static_cast<int>(serialized.size()) + 1;
  *out_len = required;
  if (required <= buffer_len) {
    std::strcpy(out_str, serialized.c_str());
  }
  API_END();
}

931
// Dumps the model with Booster::DumpModel.
// *out_len always receives the number of bytes needed, including the
// terminating NUL. The text is copied into out_str only when that size fits
// into buffer_len; otherwise out_str is left untouched so the caller can
// retry with a larger buffer.
LIGHTGBM_C_EXPORT int LGBM_BoosterDumpModel(BoosterHandle handle,
  int num_iteration,
  int buffer_len,
  int* out_len,
  char* out_str) {
  API_BEGIN();
  const std::string dumped =
    reinterpret_cast<Booster*>(handle)->DumpModel(num_iteration);
  const int required = static_cast<int>(dumped.size()) + 1;
  *out_len = required;
  if (required <= buffer_len) {
    std::strcpy(out_str, dumped.c_str());
  }
  API_END();
}
945

946
// Stores Booster::GetLeafValue(tree_idx, leaf_idx), converted to double,
// into *out_val.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetLeafValue(BoosterHandle handle,
  int tree_idx,
  int leaf_idx,
  double* out_val) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const auto leaf_value = booster->GetLeafValue(tree_idx, leaf_idx);
  *out_val = static_cast<double>(leaf_value);
  API_END();
}

956
// Forwards to Booster::SetLeafValue(tree_idx, leaf_idx, val).
LIGHTGBM_C_EXPORT int LGBM_BoosterSetLeafValue(BoosterHandle handle,
  int tree_idx,
  int leaf_idx,
  double val) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->SetLeafValue(tree_idx, leaf_idx, val);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
966
// ---- start of some helper functions
967
968
969

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
970
  if (data_type == C_API_DTYPE_FLOAT32) {
971
972
973
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
974
        std::vector<double> ret(num_col);
975
976
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
977
          ret[i] = static_cast<double>(*(tmp_ptr + i));
978
979
980
981
982
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
983
        std::vector<double> ret(num_col);
984
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
985
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
986
987
988
989
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
990
  } else if (data_type == C_API_DTYPE_FLOAT64) {
991
992
993
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
994
        std::vector<double> ret(num_col);
995
996
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
997
          ret[i] = static_cast<double>(*(tmp_ptr + i));
998
999
1000
1001
1002
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1003
        std::vector<double> ret(num_col);
1004
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1005
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
1006
1007
1008
1009
1010
        }
        return ret;
      };
    }
  }
Guolin Ke's avatar
Guolin Ke committed
1011
  throw std::runtime_error("unknown data type in RowFunctionFromDenseMatric");
1012
1013
1014
1015
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1016
1017
1018
1019
1020
1021
1022
1023
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
    return [inner_function](int row_idx) {
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
        if (std::fabs(raw_values[i]) > 1e-15) {
          ret.emplace_back(i, raw_values[i]);
1024
        }
Guolin Ke's avatar
Guolin Ke committed
1025
1026
1027
      }
      return ret;
    };
1028
  }
Guolin Ke's avatar
Guolin Ke committed
1029
  return nullptr;
1030
1031
1032
1033
}

std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
1034
  if (data_type == C_API_DTYPE_FLOAT32) {
1035
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1036
    if (indptr_type == C_API_DTYPE_INT32) {
1037
1038
1039
1040
1041
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1042
        for (int64_t i = start; i < end; ++i) {
1043
1044
1045
1046
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1047
    } else if (indptr_type == C_API_DTYPE_INT64) {
1048
1049
1050
1051
1052
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1053
        for (int64_t i = start; i < end; ++i) {
1054
1055
1056
1057
1058
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1059
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1060
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1061
    if (indptr_type == C_API_DTYPE_INT32) {
1062
1063
1064
1065
1066
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1067
        for (int64_t i = start; i < end; ++i) {
1068
1069
1070
1071
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1072
    } else if (indptr_type == C_API_DTYPE_INT64) {
1073
1074
1075
1076
1077
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1078
        for (int64_t i = start; i < end; ++i) {
1079
1080
1081
1082
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1083
1084
    }
  }
Guolin Ke's avatar
Guolin Ke committed
1085
  throw std::runtime_error("unknown data type in RowFunctionFromCSR");
1086
1087
}

Guolin Ke's avatar
Guolin Ke committed
1088
1089
1090
std::function<std::pair<int, double>(int idx)>
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx) {
  CHECK(col_idx < ncol_ptr && col_idx >= 0);
Guolin Ke's avatar
Guolin Ke committed
1091
  if (data_type == C_API_DTYPE_FLOAT32) {
1092
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1093
    if (col_ptr_type == C_API_DTYPE_INT32) {
1094
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1095
1096
1097
1098
1099
1100
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1101
        }
Guolin Ke's avatar
Guolin Ke committed
1102
1103
1104
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1105
      };
Guolin Ke's avatar
Guolin Ke committed
1106
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1107
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1108
1109
1110
1111
1112
1113
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1114
        }
Guolin Ke's avatar
Guolin Ke committed
1115
1116
1117
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1118
      };
Guolin Ke's avatar
Guolin Ke committed
1119
    }
Guolin Ke's avatar
Guolin Ke committed
1120
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1121
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1122
    if (col_ptr_type == C_API_DTYPE_INT32) {
1123
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1124
1125
1126
1127
1128
1129
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1130
        }
Guolin Ke's avatar
Guolin Ke committed
1131
1132
1133
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1134
      };
Guolin Ke's avatar
Guolin Ke committed
1135
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1136
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1137
1138
1139
1140
1141
1142
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1143
        }
Guolin Ke's avatar
Guolin Ke committed
1144
1145
1146
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1147
      };
Guolin Ke's avatar
Guolin Ke committed
1148
1149
1150
    }
  }
  throw std::runtime_error("unknown data type in CSC matrix");
1151
1152
}

Guolin Ke's avatar
Guolin Ke committed
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
// Binds this iterator to column `col_idx` of the CSC matrix by storing the
// entry-enumeration function produced by IterateFunctionFromCSC.
CSC_RowIterator::CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
  const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx)
  : iter_fun_(IterateFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, col_idx)) {
}

// Returns the column's value at row `idx`, or 0 when no entry is stored there.
// The iterator only moves forward: stored entries are consumed until the
// current entry's row index is >= idx or the column is exhausted, so callers
// are expected to query rows in non-decreasing order.
double CSC_RowIterator::Get(int idx) {
  while (!is_end_ && cur_idx_ < idx) {
    const auto entry = iter_fun_(nonzero_idx_);
    if (entry.first < 0) {
      // Negative row index marks the end of the column.
      is_end_ = true;
      break;
    }
    cur_idx_ = entry.first;
    cur_val_ = entry.second;
    ++nonzero_idx_;
  }
  return (idx == cur_idx_) ? cur_val_ : 0.0;
}

// Returns the next stored entry of the column as (row index, value) and
// advances the position. Once the column is exhausted, every call returns
// (-1, 0.0).
std::pair<int, double> CSC_RowIterator::NextNonZero() {
  if (is_end_) {
    return std::make_pair(-1, 0.0);
  }
  auto entry = iter_fun_(nonzero_idx_);
  ++nonzero_idx_;
  if (entry.first < 0) {
    // Negative row index marks the end of the column.
    is_end_ = true;
  }
  return entry;
}