c_api.cpp 50.9 KB
Newer Older
1
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
2
3
4

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/utils/threading.h>
Guolin Ke's avatar
Guolin Ke committed
6
#include <LightGBM/c_api.h>
Guolin Ke's avatar
Guolin Ke committed
7
#include <LightGBM/dataset_loader.h>
Guolin Ke's avatar
Guolin Ke committed
8
9
10
11
12
13
14
15
16
17
#include <LightGBM/dataset.h>
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/config.h>

#include <cstdio>
#include <vector>
#include <string>
#include <cstring>
Guolin Ke's avatar
Guolin Ke committed
18
#include <memory>
Guolin Ke's avatar
Guolin Ke committed
19
#include <stdexcept>
wxchan's avatar
wxchan committed
20
#include <mutex>
Guolin Ke's avatar
Guolin Ke committed
21
#include <functional>
Guolin Ke's avatar
Guolin Ke committed
22

Guolin Ke's avatar
Guolin Ke committed
23
#include "./application/predictor.hpp"
Guolin Ke's avatar
Guolin Ke committed
24
#include "./boosting/gbdt.h"
Guolin Ke's avatar
Guolin Ke committed
25

Guolin Ke's avatar
Guolin Ke committed
26
27
28
29
namespace LightGBM {

class Booster {
public:
Guolin Ke's avatar
Guolin Ke committed
30
31
  explicit Booster(const char* filename) {
    boosting_.reset(Boosting::CreateBoosting(filename));
Guolin Ke's avatar
Guolin Ke committed
32
33
  }

34
35
36
37
  Booster() {
    boosting_.reset(Boosting::CreateBoosting("gbdt", nullptr));
  }

Guolin Ke's avatar
Guolin Ke committed
38
  Booster(const Dataset* train_data,
39
          const char* parameters) {
wxchan's avatar
wxchan committed
40
41
    auto param = ConfigBase::Str2Map(parameters);
    config_.Set(param);
42
43
44
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
45
46
    // create boosting
    if (config_.io_config.input_model.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
47
      Log::Warning("continued train from model is not support for c_api, \
Guolin Ke's avatar
Guolin Ke committed
48
49
        please use continued train with input score");
    }
Guolin Ke's avatar
Guolin Ke committed
50

wxchan's avatar
wxchan committed
51
    boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
Guolin Ke's avatar
Guolin Ke committed
52

Guolin Ke's avatar
Guolin Ke committed
53
    // initialize the boosting
Guolin Ke's avatar
Guolin Ke committed
54
    boosting_->Init(&config_.boosting_config, nullptr, objective_fun_.get(),
55
                    Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
56
57

    ResetTrainingData(train_data);
wxchan's avatar
wxchan committed
58
59
60
61
62
  }

  void MergeFrom(const Booster* other) {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->MergeFrom(other->boosting_.get());
Guolin Ke's avatar
Guolin Ke committed
63
64
65
  }

  ~Booster() {
Guolin Ke's avatar
Guolin Ke committed
66

Guolin Ke's avatar
Guolin Ke committed
67
  }
68

wxchan's avatar
wxchan committed
69
70
71
  void ResetTrainingData(const Dataset* train_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    train_data_ = train_data;
Guolin Ke's avatar
Guolin Ke committed
72
73
    // create objective function
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
74
                                                                    config_.objective_config));
Guolin Ke's avatar
Guolin Ke committed
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective function");
    }
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
    }

    // create training metric
    train_metric_.clear();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(
        Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(train_data_->metadata(), train_data_->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
    // reset the boosting
Guolin Ke's avatar
Guolin Ke committed
94
    boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
95
                                 objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
wxchan's avatar
wxchan committed
96
97
98
  }

  void ResetConfig(const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
99
    std::lock_guard<std::mutex> lock(mutex_);
wxchan's avatar
wxchan committed
100
101
102
103
104
105
106
    auto param = ConfigBase::Str2Map(parameters);
    if (param.count("num_class")) {
      Log::Fatal("cannot change num class during training");
    }
    if (param.count("boosting_type")) {
      Log::Fatal("cannot change boosting_type during training");
    }
Guolin Ke's avatar
Guolin Ke committed
107
108
109
    if (param.count("metric")) {
      Log::Fatal("cannot change metric during training");
    }
Guolin Ke's avatar
Guolin Ke committed
110
111

    config_.Set(param);
112
113
114
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
115
116
117
118

    if (param.count("objective")) {
      // create objective function
      objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
119
                                                                      config_.objective_config));
Guolin Ke's avatar
Guolin Ke committed
120
121
122
123
124
125
126
      if (objective_fun_ == nullptr) {
        Log::Warning("Using self-defined objective function");
      }
      // initialize the objective function
      if (objective_fun_ != nullptr) {
        objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
      }
wxchan's avatar
wxchan committed
127
    }
Guolin Ke's avatar
Guolin Ke committed
128
129

    boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
130
                                 objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
131

wxchan's avatar
wxchan committed
132
133
134
135
136
137
138
139
140
141
142
143
144
  }

  void AddValidData(const Dataset* valid_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    valid_metrics_.emplace_back();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
145
                               Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
wxchan's avatar
wxchan committed
146
  }
Guolin Ke's avatar
Guolin Ke committed
147

148
  bool TrainOneIter() {
wxchan's avatar
wxchan committed
149
    std::lock_guard<std::mutex> lock(mutex_);
150
151
152
153
    return boosting_->TrainOneIter(nullptr, nullptr, false);
  }

  bool TrainOneIter(const float* gradients, const float* hessians) {
wxchan's avatar
wxchan committed
154
    std::lock_guard<std::mutex> lock(mutex_);
155
156
157
    return boosting_->TrainOneIter(gradients, hessians, false);
  }

wxchan's avatar
wxchan committed
158
159
160
161
162
  void RollbackOneIter() {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->RollbackOneIter();
  }

Guolin Ke's avatar
Guolin Ke committed
163
  Predictor NewPredictor(int num_iteration, int predict_type) {
wxchan's avatar
wxchan committed
164
165
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->SetNumIterationForPred(num_iteration);
Guolin Ke's avatar
Guolin Ke committed
166
167
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
168
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
169
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
170
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
171
      is_raw_score = true;
Guolin Ke's avatar
Guolin Ke committed
172
173
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
174
    }
Guolin Ke's avatar
Guolin Ke committed
175
176
177
    // not threading safe now
    // boosting_->SetNumIterationForPred may be set by other thread during prediction. 
    return Predictor(boosting_.get(), is_raw_score, is_predict_leaf);
Guolin Ke's avatar
Guolin Ke committed
178
179
  }

Guolin Ke's avatar
Guolin Ke committed
180
  void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
181
182
183
184
185
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

  void SaveModelToFile(int num_iteration, const char* filename) {
    boosting_->SaveModelToFile(num_iteration, filename);
Guolin Ke's avatar
Guolin Ke committed
186
  }
187

188
189
190
191
192
193
194
195
  void LoadModelFromString(const char* model_str) {
    boosting_->LoadModelFromString(model_str);
  }

  std::string SaveModelToString(int num_iteration) {
    return boosting_->SaveModelToString(num_iteration);
  }

196
197
  std::string DumpModel(int num_iteration) {
    return boosting_->DumpModel(num_iteration);
wxchan's avatar
wxchan committed
198
  }
199

Guolin Ke's avatar
Guolin Ke committed
200
201
202
203
204
205
206
207
208
  double GetLeafValue(int tree_idx, int leaf_idx) const {
    return dynamic_cast<GBDT*>(boosting_.get())->GetLeafValue(tree_idx, leaf_idx);
  }

  void SetLeafValue(int tree_idx, int leaf_idx, double val) {
    std::lock_guard<std::mutex> lock(mutex_);
    dynamic_cast<GBDT*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val);
  }

wxchan's avatar
wxchan committed
209
210
211
212
213
214
215
  int GetEvalCounts() const {
    int ret = 0;
    for (const auto& metric : train_metric_) {
      ret += static_cast<int>(metric->GetName().size());
    }
    return ret;
  }
216

wxchan's avatar
wxchan committed
217
218
219
220
221
222
223
224
225
226
227
  int GetEvalNames(char** out_strs) const {
    int idx = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
        std::strcpy(out_strs[idx], name.c_str());
        ++idx;
      }
    }
    return idx;
  }

wxchan's avatar
wxchan committed
228
229
230
231
232
233
234
235
236
  int GetFeatureNames(char** out_strs) const {
    int idx = 0;
    for (const auto& name : boosting_->FeatureNames()) {
      std::strcpy(out_strs[idx], name.c_str());
      ++idx;
    }
    return idx;
  }

wxchan's avatar
wxchan committed
237
  const Boosting* GetBoosting() const { return boosting_.get(); }
Guolin Ke's avatar
Guolin Ke committed
238

Guolin Ke's avatar
Guolin Ke committed
239
private:
240

wxchan's avatar
wxchan committed
241
  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
242
  std::unique_ptr<Boosting> boosting_;
Guolin Ke's avatar
Guolin Ke committed
243
244
245
  /*! \brief All configs */
  OverallConfig config_;
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
246
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
247
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
248
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
249
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
250
  std::unique_ptr<ObjectiveFunction> objective_fun_;
wxchan's avatar
wxchan committed
251
252
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
253
254
255
};

}
Guolin Ke's avatar
Guolin Ke committed
256
257
258

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
259
260
261
262
263
264
265
266
267
268
// some help functions used to convert data

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major);

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major);

std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
269
                   const void* data, int data_type, int64_t nindptr, int64_t nelem);
Guolin Ke's avatar
Guolin Ke committed
270
271
272
273
274

/*!
 * \brief Row iterator over one column of a CSC matrix.
 *        Only declared here; the constructor and accessors are defined elsewhere.
 */
class CSC_RowIterator {
public:
  CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
                  const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx);
  ~CSC_RowIterator() = default;

  /*! \brief Value at row idx; rows can only be queried in ascending order */
  double Get(int idx);

  /*! \brief Next non-zero (row, value) pair; a negative row means no more data */
  std::pair<int, double> NextNonZero();

private:
  int nonzero_idx_ = 0;
  int cur_idx_ = -1;
  double cur_val_ = 0.0;
  bool is_end_ = false;
  std::function<std::pair<int, double>(int idx)> iter_fun_;
};

// start of c_api functions

291
/*! \brief Expose the string returned by LastErrorMsg() to C callers */
LIGHTGBM_C_EXPORT const char* LGBM_GetLastError() {
  const char* last_error = LastErrorMsg();
  return last_error;
}

295
/*!
 * \brief Load a dataset from a file.
 *        With a reference dataset the file is loaded through
 *        LoadFromFileAlignWithOtherDataset, otherwise through LoadFromFile.
 * \param filename Path handed to the DatasetLoader
 * \param parameters Parameter string used to fill an IOConfig
 * \param reference Optional existing dataset handle, may be nullptr
 * \param out Receives the handle of the new dataset
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromFile(const char* filename,
                                                 const char* parameters,
                                                 const DatasetHandle reference,
                                                 DatasetHandle* out) {
  API_BEGIN();
  auto param_map = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param_map);
  DatasetLoader loader(io_config, nullptr, 1, filename);
  if (reference != nullptr) {
    const Dataset* ref_dataset = reinterpret_cast<const Dataset*>(reference);
    *out = loader.LoadFromFileAlignWithOtherDataset(filename, ref_dataset);
  } else {
    *out = loader.LoadFromFile(filename);
  }
  API_END();
}

313

314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
/*!
 * \brief Construct a dataset from per-column sampled values.
 *        All sample arguments are forwarded unchanged to
 *        DatasetLoader::CostructFromSampleData.
 * \param parameters Parameter string used to fill an IOConfig
 * \param out Receives the handle of the new dataset
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromSampledColumn(double** sample_data,
                                                          int** sample_indices,
                                                          int32_t ncol,
                                                          const int* num_per_col,
                                                          int32_t num_sample_row,
                                                          int32_t num_total_row,
                                                          const char* parameters,
                                                          DatasetHandle* out) {
  API_BEGIN();
  auto param_map = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param_map);
  const data_size_t total_rows = static_cast<data_size_t>(num_total_row);
  DatasetLoader loader(io_config, nullptr, 1, nullptr);
  *out = loader.CostructFromSampleData(sample_data, sample_indices, ncol, num_per_col,
                                       num_sample_row, total_rows);
  API_END();
}

333

Guolin Ke's avatar
Guolin Ke committed
334
/*!
 * \brief Allocate a dataset of num_total_row rows and call CreateValid on it
 *        with the reference dataset.
 * \param reference Handle of an existing dataset
 * \param num_total_row Row count of the new dataset
 * \param out Receives the handle of the new dataset
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateByReference(const DatasetHandle reference,
                                                    int64_t num_total_row,
                                                    DatasetHandle* out) {
  API_BEGIN();
  // held by unique_ptr so the dataset is freed if CreateValid throws
  std::unique_ptr<Dataset> new_dataset(new Dataset(static_cast<data_size_t>(num_total_row)));
  const Dataset* ref_dataset = reinterpret_cast<const Dataset*>(reference);
  new_dataset->CreateValid(ref_dataset);
  *out = new_dataset.release();
  API_END();
}

/*!
 * \brief Push a dense block of rows into an existing dataset.
 *        The block is read row-major (is_row_major is fixed to 1) and its rows
 *        are written at positions start_row .. start_row + nrow - 1.
 *        FinishLoad is called once the pushed block ends exactly at num_data().
 * \param dataset Handle of the dataset to fill
 * \param data Dense matrix, interpreted according to data_type
 * \param nrow Number of rows in data
 * \param ncol Number of columns in data
 * \param start_row Row index in the dataset of the first pushed row
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetPushRows(DatasetHandle dataset,
                                           const void* data,
                                           int data_type,
                                           int32_t nrow,
                                           int32_t ncol,
                                           int32_t start_row) {
  API_BEGIN();
  Dataset* target = reinterpret_cast<Dataset*>(dataset);
  auto row_reader = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, 1);
  OMP_INIT_EX();
#pragma omp parallel for schedule(static)
  for (int i = 0; i < nrow; ++i) {
    OMP_LOOP_EX_BEGIN();
    const int tid = omp_get_thread_num();
    auto row_values = row_reader(i);
    target->PushOneRow(tid, start_row + i, row_values);
    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();
  const bool reached_end = (start_row + nrow == target->num_data());
  if (reached_end) {
    target->FinishLoad();
  }
  API_END();
}

/*!
 * \brief Push a CSR block of rows into an existing dataset.
 *        The block holds nindptr - 1 rows, written at positions starting at
 *        start_row. FinishLoad is called once the pushed block ends exactly
 *        at num_data(). The unnamed int64_t parameter is not used.
 * \param dataset Handle of the dataset to fill
 * \param indptr Row pointer array, interpreted according to indptr_type
 * \param indices Column indices of the stored values
 * \param data Stored values, interpreted according to data_type
 * \param nindptr Length of indptr
 * \param nelem Number of stored values
 * \param start_row Row index in the dataset of the first pushed row
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset,
                                                const void* indptr,
                                                int indptr_type,
                                                const int32_t* indices,
                                                const void* data,
                                                int data_type,
                                                int64_t nindptr,
                                                int64_t nelem,
                                                int64_t,
                                                int64_t start_row) {
  API_BEGIN();
  Dataset* target = reinterpret_cast<Dataset*>(dataset);
  auto row_reader = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  const int32_t nrow = static_cast<int32_t>(nindptr - 1);
  OMP_INIT_EX();
#pragma omp parallel for schedule(static)
  for (int i = 0; i < nrow; ++i) {
    OMP_LOOP_EX_BEGIN();
    const int tid = omp_get_thread_num();
    auto row_pairs = row_reader(i);
    target->PushOneRow(tid,
                       static_cast<data_size_t>(start_row + i), row_pairs);
    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();
  const bool reached_end = (start_row + nrow == static_cast<int64_t>(target->num_data()));
  if (reached_end) {
    target->FinishLoad();
  }
  API_END();
}

401
/*!
 * \brief Create a dataset from a dense matrix.
 *        Without a reference, rows are sampled to construct the dataset via
 *        DatasetLoader::CostructFromSampleData; with a reference, the dataset
 *        is created through CreateValid. Afterwards every row is pushed and
 *        FinishLoad is called.
 * \param data Dense matrix, interpreted according to data_type
 * \param nrow Number of rows
 * \param ncol Number of columns
 * \param is_row_major Forwarded to RowFunctionFromDenseMatric
 * \param parameters Parameter string used to fill an IOConfig
 * \param reference Optional existing dataset handle, may be nullptr
 * \param out Receives the handle of the new dataset
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromMat(const void* data,
                                                int data_type,
                                                int32_t nrow,
                                                int32_t ncol,
                                                int is_row_major,
                                                const char* parameters,
                                                const DatasetHandle reference,
                                                DatasetHandle* out) {
  API_BEGIN();
  auto param_map = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param_map);
  std::unique_ptr<Dataset> dataset;
  auto row_reader = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  if (reference != nullptr) {
    dataset.reset(new Dataset(nrow));
    dataset->CreateValid(reinterpret_cast<const Dataset*>(reference));
  } else {
    // sample data first
    Random rand(io_config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sampled_rows = rand.Sample(nrow, sample_cnt);
    sample_cnt = static_cast<int>(sampled_rows.size());
    // per column: the non-zero sampled values and the sample position of each
    std::vector<std::vector<double>> col_values(ncol);
    std::vector<std::vector<int>> col_positions(ncol);
    for (size_t pos = 0; pos < sampled_rows.size(); ++pos) {
      auto row_idx = sampled_rows[pos];
      auto row = row_reader(static_cast<int>(row_idx));
      for (size_t col = 0; col < row.size(); ++col) {
        if (std::fabs(row[col]) > kEpsilon) {
          col_values[col].emplace_back(row[col]);
          col_positions[col].emplace_back(static_cast<int>(pos));
        }
      }
    }
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
    dataset.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(col_values).data(),
                                                Common::Vector2Ptr<int>(col_positions).data(),
                                                static_cast<int>(col_values.size()),
                                                Common::VectorSize<double>(col_values).data(),
                                                sample_cnt, nrow));
  }
  OMP_INIT_EX();
#pragma omp parallel for schedule(static)
  for (int i = 0; i < nrow; ++i) {
    OMP_LOOP_EX_BEGIN();
    const int tid = omp_get_thread_num();
    auto row_values = row_reader(i);
    dataset->PushOneRow(tid, i, row_values);
    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

459
/*!
 * \brief Create a dataset from a CSR matrix.
 *        Without a reference, rows are sampled to construct the dataset via
 *        DatasetLoader::CostructFromSampleData; with a reference, the dataset
 *        is created through CreateValid. Afterwards every row is pushed and
 *        FinishLoad is called.
 * \param indptr Row pointer array, interpreted according to indptr_type
 * \param indices Column indices of the stored values
 * \param data Stored values, interpreted according to data_type
 * \param nindptr Length of indptr; the matrix has nindptr - 1 rows
 * \param nelem Number of stored values
 * \param num_col Number of columns; checked against the largest sampled column index
 * \param parameters Parameter string used to fill an IOConfig
 * \param reference Optional existing dataset handle, may be nullptr
 * \param out Receives the handle of the new dataset
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSR(const void* indptr,
                                                int indptr_type,
                                                const int32_t* indices,
                                                const void* data,
                                                int data_type,
                                                int64_t nindptr,
                                                int64_t nelem,
                                                int64_t num_col,
                                                const char* parameters,
                                                const DatasetHandle reference,
                                                DatasetHandle* out) {
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
  std::unique_ptr<Dataset> ret;
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  const int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference == nullptr) {
    // sample data first
    Random rand(io_config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    sample_cnt = static_cast<int>(sample_indices.size());
    // per column: the non-zero sampled values and the sample position of each;
    // both grow on demand to the largest column index seen
    std::vector<std::vector<double>> sample_values;
    std::vector<std::vector<int>> sample_idx;
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      auto row = get_row_fun(static_cast<int>(idx));
      for (const std::pair<int, double>& inner_data : row) {
        if (static_cast<size_t>(inner_data.first) >= sample_values.size()) {
          sample_values.resize(inner_data.first + 1);
          sample_idx.resize(inner_data.first + 1);
        }
        if (std::fabs(inner_data.second) > kEpsilon) {
          sample_values[inner_data.first].emplace_back(inner_data.second);
          sample_idx[inner_data.first].emplace_back(static_cast<int>(i));
        }
      }
    }
    CHECK(num_col >= static_cast<int>(sample_values.size()));
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow));
    ret->CreateValid(
      reinterpret_cast<const Dataset*>(reference));
  }
  OMP_INIT_EX();
#pragma omp parallel for schedule(static)
  // bound by nrow (same value the dataset was sized with) rather than the
  // int64_t expression nindptr - 1, matching LGBM_DatasetPushRowsByCSR
  for (int i = 0; i < nrow; ++i) {
    OMP_LOOP_EX_BEGIN();
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

526
/*!
 * \brief Create a dataset from a CSC matrix.
 *        Without a reference, rows are sampled column by column to construct
 *        the dataset via DatasetLoader::CostructFromSampleData; with a
 *        reference, the dataset is created through CreateValid. Afterwards
 *        the non-zero entries of every used column are pushed and FinishLoad
 *        is called.
 * \param col_ptr Column pointer array, interpreted according to col_ptr_type
 * \param indices Row indices of the stored values
 * \param data Stored values, interpreted according to data_type
 * \param ncol_ptr Length of col_ptr; the matrix has ncol_ptr - 1 columns
 * \param nelem Number of stored values
 * \param num_row Number of rows
 * \param parameters Parameter string used to fill an IOConfig
 * \param reference Optional existing dataset handle, may be nullptr
 * \param out Receives the handle of the new dataset
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr,
                                                int col_ptr_type,
                                                const int32_t* indices,
                                                const void* data,
                                                int data_type,
                                                int64_t ncol_ptr,
                                                int64_t nelem,
                                                int64_t num_row,
                                                const char* parameters,
                                                const DatasetHandle reference,
                                                DatasetHandle* out) {
  API_BEGIN();
  auto param_map = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param_map);
  std::unique_ptr<Dataset> dataset;
  int32_t nrow = static_cast<int32_t>(num_row);
  if (reference != nullptr) {
    dataset.reset(new Dataset(nrow));
    dataset->CreateValid(reinterpret_cast<const Dataset*>(reference));
  } else {
    // sample data first
    Random rand(io_config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sampled_rows = rand.Sample(nrow, sample_cnt);
    sample_cnt = static_cast<int>(sampled_rows.size());
    // per column: the non-zero sampled values and the sample position of each
    std::vector<std::vector<double>> col_values(ncol_ptr - 1);
    std::vector<std::vector<int>> col_positions(ncol_ptr - 1);
    OMP_INIT_EX();
  #pragma omp parallel for schedule(static)
    for (int col = 0; col < static_cast<int>(col_values.size()); ++col) {
      OMP_LOOP_EX_BEGIN();
      CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, col);
      for (int pos = 0; pos < sample_cnt; pos++) {
        auto value = col_it.Get(sampled_rows[pos]);
        if (std::fabs(value) > kEpsilon) {
          col_values[col].emplace_back(value);
          col_positions[col].emplace_back(pos);
        }
      }
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
    dataset.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(col_values).data(),
                                                Common::Vector2Ptr<int>(col_positions).data(),
                                                static_cast<int>(col_values.size()),
                                                Common::VectorSize<double>(col_values).data(),
                                                sample_cnt, nrow));
  }
  OMP_INIT_EX();
#pragma omp parallel for schedule(static)
  for (int i = 0; i < ncol_ptr - 1; ++i) {
    OMP_LOOP_EX_BEGIN();
    const int tid = omp_get_thread_num();
    const int feature_idx = dataset->InnerFeatureIndex(i);
    // columns without an inner feature index are skipped
    if (feature_idx < 0) { continue; }
    const int group = dataset->Feature2Group(feature_idx);
    const int sub_feature = dataset->Feture2SubFeature(feature_idx);
    CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
    int row_idx = 0;
    while (row_idx < nrow) {
      auto entry = col_it.NextNonZero();
      row_idx = entry.first;
      // no more data
      if (row_idx < 0) { break; }
      dataset->PushOneData(tid, row_idx, group, sub_feature, entry.second);
    }
    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

603
// Creates a new Dataset that holds only the selected rows of an existing one
// and hands ownership of it to the caller through *out.
//   handle               - source Dataset the rows are copied from
//   used_row_indices     - indices of the rows to copy
//   num_used_row_indices - number of entries in used_row_indices
//   parameters           - parameter string, parsed into an IOConfig
//   out                  - receives the newly allocated Dataset
LIGHTGBM_C_EXPORT int LGBM_DatasetGetSubset(
  const DatasetHandle handle,
  const int32_t* used_row_indices,
  int32_t num_used_row_indices,
  const char* parameters,
  DatasetHandle* out) {
  API_BEGIN();
  const Dataset* source = reinterpret_cast<const Dataset*>(handle);
  // The resulting IOConfig is not read again in this function; it is still
  // populated so that whatever Set() does with the parameters keeps happening.
  auto param_map = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param_map);
  std::unique_ptr<Dataset> subset(new Dataset(num_used_row_indices));
  subset->CopyFeatureMapperFrom(source);
  subset->CopySubset(source, used_row_indices, num_used_row_indices, true);
  // Ownership moves to the caller only after every step above succeeded.
  *out = subset.release();
  API_END();
}

621
// Replaces the feature names stored in a Dataset.
//   handle            - Dataset to update
//   feature_names     - array of C strings, one per feature
//   num_feature_names - number of entries read from feature_names
LIGHTGBM_C_EXPORT int LGBM_DatasetSetFeatureNames(
  DatasetHandle handle,
  const char** feature_names,
  int num_feature_names) {
  API_BEGIN();
  // Copy the C strings into std::string objects before handing them over.
  std::vector<std::string> names;
  int idx = 0;
  while (idx < num_feature_names) {
    names.emplace_back(feature_names[idx]);
    ++idx;
  }
  Dataset* target = reinterpret_cast<Dataset*>(handle);
  target->set_feature_names(names);
  API_END();
}

635
// Copies the feature names of a Dataset into caller-provided buffers.
//   handle            - Dataset to read from
//   feature_names     - array of writable buffers, one per feature; each name
//                       is written as a NUL-terminated string
//   num_feature_names - receives the number of names written
// NOTE: names are copied with strcpy and no length is checked here, so the
// caller must supply enough buffers, each large enough for its name plus the
// terminating NUL.
LIGHTGBM_C_EXPORT int LGBM_DatasetGetFeatureNames(
  DatasetHandle handle,
  char** feature_names,
  int* num_feature_names) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  // Bind by const reference so the list of names is not copied just to be
  // read; if the accessor returns by value the temporary's lifetime is
  // extended to this scope, so this is safe either way.
  const auto& inside_feature_name = dataset->feature_names();
  *num_feature_names = static_cast<int>(inside_feature_name.size());
  for (int i = 0; i < *num_feature_names; ++i) {
    std::strcpy(feature_names[i], inside_feature_name[i].c_str());
  }
  API_END();
}

649
// Destroys the Dataset behind the given handle.
//   handle - Dataset to delete; the handle must not be used afterwards
LIGHTGBM_C_EXPORT int LGBM_DatasetFree(DatasetHandle handle) {
  API_BEGIN();
  Dataset* doomed = reinterpret_cast<Dataset*>(handle);
  delete doomed;
  API_END();
}

655
// Asks a Dataset to write itself out via Dataset::SaveBinaryFile.
//   handle   - Dataset to save
//   filename - path passed through to SaveBinaryFile
LIGHTGBM_C_EXPORT int LGBM_DatasetSaveBinary(DatasetHandle handle,
                                             const char* filename) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->SaveBinaryFile(filename);
  API_END();
}

663
// Stores a named array field on a Dataset.
//   handle      - Dataset to update
//   field_name  - name of the field to set
//   field_data  - pointer to num_element values of the type named by `type`
//   num_element - number of values in field_data
//   type        - one of C_API_DTYPE_FLOAT32, C_API_DTYPE_INT32,
//                 C_API_DTYPE_FLOAT64; selects which typed setter is used
// Throws std::runtime_error when `type` is none of the above or when the
// selected setter reports failure.
LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle,
                                           const char* field_name,
                                           const void* field_data,
                                           int num_element,
                                           int type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  bool is_success = false;
  // Dispatch on the declared element type; an unknown type leaves
  // is_success false and falls through to the error below.
  if (type == C_API_DTYPE_FLOAT32) {
    is_success = dataset->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), static_cast<int32_t>(num_element));
  } else if (type == C_API_DTYPE_INT32) {
    is_success = dataset->SetIntField(field_name, reinterpret_cast<const int*>(field_data), static_cast<int32_t>(num_element));
  } else if (type == C_API_DTYPE_FLOAT64) {
    is_success = dataset->SetDoubleField(field_name, reinterpret_cast<const double*>(field_data), static_cast<int32_t>(num_element));
  }
  if (!is_success) { throw std::runtime_error("Input data type error or field not found"); }
  API_END();
}

682
// Looks up a named field on a Dataset and returns a pointer to its data.
//   handle     - Dataset to query
//   field_name - name of the field to fetch
//   out_len    - receives the number of elements (forced to 0 when the
//                returned pointer is null)
//   out_ptr    - receives the pointer to the field's data
//   out_type   - receives the C_API_DTYPE_* constant of the getter that
//                accepted the field
// The float, int and double getters are tried in that order; the first one
// that succeeds wins. Throws std::runtime_error when none accepts the name.
LIGHTGBM_C_EXPORT int LGBM_DatasetGetField(DatasetHandle handle,
                                           const char* field_name,
                                           int* out_len,
                                           const void** out_ptr,
                                           int* out_type) {
  API_BEGIN();
  Dataset* source = reinterpret_cast<Dataset*>(handle);
  if (source->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (source->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else if (source->GetDoubleField(field_name, out_len, reinterpret_cast<const double**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT64;
  } else {
    throw std::runtime_error("Field not found");
  }
  // A field that exists but has no storage is reported as empty.
  if (*out_ptr == nullptr) {
    *out_len = 0;
  }
  API_END();
}

705
// Reports Dataset::num_data() for the given handle.
//   handle - Dataset to query
//   out    - receives the value returned by num_data()
LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumData(DatasetHandle handle,
                                             int* out) {
  API_BEGIN();
  *out = reinterpret_cast<Dataset*>(handle)->num_data();
  API_END();
}

713
// Reports Dataset::num_total_features() for the given handle.
//   handle - Dataset to query
//   out    - receives the value returned by num_total_features()
LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumFeature(DatasetHandle handle,
                                                int* out) {
  API_BEGIN();
  *out = reinterpret_cast<Dataset*>(handle)->num_total_features();
  API_END();
}
720
721
722

// ---- start of booster

723
// Constructs a Booster from a training Dataset and a parameter string and
// hands ownership of it to the caller through *out.
//   train_data - Dataset used to construct the Booster
//   parameters - parameter string forwarded to the Booster constructor
//   out        - receives the newly allocated Booster
LIGHTGBM_C_EXPORT int LGBM_BoosterCreate(const DatasetHandle train_data,
                                         const char* parameters,
                                         BoosterHandle* out) {
  API_BEGIN();
  // Held in a unique_ptr so a throwing constructor cannot leak.
  std::unique_ptr<Booster> booster(
    new Booster(reinterpret_cast<const Dataset*>(train_data), parameters));
  *out = booster.release();
  API_END();
}

733
// Constructs a Booster from a model file and hands ownership of it to the
// caller through *out.
//   filename           - path forwarded to the Booster(filename) constructor
//   out_num_iterations - receives GetCurrentIteration() of the loaded model
//   out                - receives the newly allocated Booster
LIGHTGBM_C_EXPORT int LGBM_BoosterCreateFromModelfile(
  const char* filename,
  int* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
  std::unique_ptr<Booster> loaded(new Booster(filename));
  const auto* boosting = loaded->GetBoosting();
  *out_num_iterations = boosting->GetCurrentIteration();
  *out = loaded.release();
  API_END();
}

744
745
746
747
748
749
750
751
752
753
754
755
// Constructs a default Booster, loads a model into it from an in-memory
// string, and hands ownership of it to the caller through *out.
//   model_str          - model text forwarded to Booster::LoadModelFromString
//   out_num_iterations - receives GetCurrentIteration() of the loaded model
//   out                - receives the newly allocated Booster
LIGHTGBM_C_EXPORT int LGBM_BoosterLoadModelFromString(
  const char* model_str,
  int* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
  std::unique_ptr<Booster> loaded(new Booster());
  loaded->LoadModelFromString(model_str);
  const auto* boosting = loaded->GetBoosting();
  *out_num_iterations = boosting->GetCurrentIteration();
  *out = loaded.release();
  API_END();
}

756
// Destroys the Booster behind the given handle.
//   handle - Booster to delete; the handle must not be used afterwards
LIGHTGBM_C_EXPORT int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  Booster* doomed = reinterpret_cast<Booster*>(handle);
  delete doomed;
  API_END();
}

762
// Merges another Booster into this one via Booster::MergeFrom.
//   handle       - Booster that receives the merge
//   other_handle - Booster passed to MergeFrom
LIGHTGBM_C_EXPORT int LGBM_BoosterMerge(BoosterHandle handle,
                                        BoosterHandle other_handle) {
  API_BEGIN();
  Booster* source = reinterpret_cast<Booster*>(other_handle);
  Booster* destination = reinterpret_cast<Booster*>(handle);
  destination->MergeFrom(source);
  API_END();
}

771
// Registers a validation Dataset with a Booster via Booster::AddValidData.
//   handle     - Booster to update
//   valid_data - Dataset passed to AddValidData
LIGHTGBM_C_EXPORT int LGBM_BoosterAddValidData(BoosterHandle handle,
                                               const DatasetHandle valid_data) {
  API_BEGIN();
  const Dataset* validation_set = reinterpret_cast<const Dataset*>(valid_data);
  reinterpret_cast<Booster*>(handle)->AddValidData(validation_set);
  API_END();
}

780
// Points a Booster at a different training Dataset via
// Booster::ResetTrainingData.
//   handle     - Booster to update
//   train_data - Dataset passed to ResetTrainingData
LIGHTGBM_C_EXPORT int LGBM_BoosterResetTrainingData(BoosterHandle handle,
                                                    const DatasetHandle train_data) {
  API_BEGIN();
  const Dataset* training_set = reinterpret_cast<const Dataset*>(train_data);
  reinterpret_cast<Booster*>(handle)->ResetTrainingData(training_set);
  API_END();
}

789
// Forwards a parameter string to Booster::ResetConfig.
//   handle     - Booster to update
//   parameters - parameter string passed to ResetConfig
LIGHTGBM_C_EXPORT int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ResetConfig(parameters);
  API_END();
}

796
// Reports NumberOfClasses() of the Booster's underlying boosting object.
//   handle  - Booster to query
//   out_len - receives the value returned by NumberOfClasses()
LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  const auto* boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->NumberOfClasses();
  API_END();
}

803
// Runs one iteration through Booster::TrainOneIter().
//   handle      - Booster to update
//   is_finished - set to 1 when TrainOneIter() returns true, otherwise 0
LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  const bool finished = reinterpret_cast<Booster*>(handle)->TrainOneIter();
  *is_finished = finished ? 1 : 0;
  API_END();
}

814
// Runs one iteration through Booster::TrainOneIter(grad, hess) using
// caller-supplied arrays.
//   handle      - Booster to update
//   grad        - array forwarded as the first argument of TrainOneIter
//   hess        - array forwarded as the second argument of TrainOneIter
//   is_finished - set to 1 when TrainOneIter returns true, otherwise 0
LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
                                                      const float* grad,
                                                      const float* hess,
                                                      int* is_finished) {
  API_BEGIN();
  const bool finished = reinterpret_cast<Booster*>(handle)->TrainOneIter(grad, hess);
  *is_finished = finished ? 1 : 0;
  API_END();
}

828
// Calls Booster::RollbackOneIter() on the given Booster.
//   handle - Booster to roll back
LIGHTGBM_C_EXPORT int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

835
// Reports GetCurrentIteration() of the Booster's underlying boosting object.
//   handle        - Booster to query
//   out_iteration - receives the value returned by GetCurrentIteration()
LIGHTGBM_C_EXPORT int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int* out_iteration) {
  API_BEGIN();
  const auto* boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_iteration = boosting->GetCurrentIteration();
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
841

842
// Reports Booster::GetEvalCounts() for the given handle.
//   handle  - Booster to query
//   out_len - receives the value returned by GetEvalCounts()
LIGHTGBM_C_EXPORT int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalCounts();
  API_END();
}

849
// Fetches evaluation names through Booster::GetEvalNames.
//   handle   - Booster to query
//   out_len  - receives the value returned by GetEvalNames
//   out_strs - caller-provided string buffers handed to GetEvalNames
LIGHTGBM_C_EXPORT int LGBM_BoosterGetEvalNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalNames(out_strs);
  API_END();
}

wxchan's avatar
wxchan committed
856
857
858
859
860
861
862
863
864
865
866
867
868
869
// Fetches feature names through Booster::GetFeatureNames.
//   handle   - Booster to query
//   out_len  - receives the value returned by GetFeatureNames
//   out_strs - caller-provided string buffers handed to GetFeatureNames
LIGHTGBM_C_EXPORT int LGBM_BoosterGetFeatureNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetFeatureNames(out_strs);
  API_END();
}

// Reports the feature count of a Booster's model.
//   handle  - Booster to query
//   out_len - receives MaxFeatureIdx() + 1
LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumFeature(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  const auto* boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  // The count is one more than the largest feature index.
  *out_len = boosting->MaxFeatureIdx() + 1;
  API_END();
}

870
// Copies the evaluation results for one data index into a caller buffer.
//   handle      - Booster to query
//   data_idx    - index forwarded to GetEvalAt
//   out_len     - receives the number of values written
//   out_results - caller-provided array; must have room for every value
//                 (no bound is checked here)
LIGHTGBM_C_EXPORT int LGBM_BoosterGetEval(BoosterHandle handle,
                                          int data_idx,
                                          int* out_len,
                                          double* out_results) {
  API_BEGIN();
  const auto* boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  auto scores = boosting->GetEvalAt(data_idx);
  *out_len = static_cast<int>(scores.size());
  double* dst = out_results;
  for (const auto& score : scores) {
    *dst = static_cast<double>(score);
    ++dst;
  }
  API_END();
}

885
// Reports GetNumPredictAt(data_idx) of the Booster's boosting object.
//   handle   - Booster to query
//   data_idx - index forwarded to GetNumPredictAt
//   out_len  - receives the value returned by GetNumPredictAt
LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumPredict(BoosterHandle handle,
                                                int data_idx,
                                                int64_t* out_len) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  *out_len = ref_booster->GetBoosting()->GetNumPredictAt(data_idx);
  API_END();
}

894
// Fetches predictions for one data index through Booster::GetPredictAt.
//   handle     - Booster to query
//   data_idx   - index forwarded to GetPredictAt
//   out_len    - length output handed to GetPredictAt
//   out_result - caller-provided result buffer handed to GetPredictAt
LIGHTGBM_C_EXPORT int LGBM_BoosterGetPredict(BoosterHandle handle,
                                             int data_idx,
                                             int64_t* out_len,
                                             double* out_result) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->GetPredictAt(data_idx, out_result, out_len);
  API_END();
}

904
// Runs prediction from one file into another using a freshly built predictor.
//   handle          - Booster used to create the predictor
//   data_filename   - input path forwarded to Predictor::Predict
//   data_has_header - treated as true when greater than 0
//   predict_type    - prediction type forwarded to NewPredictor
//   num_iteration   - iteration count forwarded to NewPredictor
//   result_filename - output path forwarded to Predictor::Predict
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForFile(BoosterHandle handle,
                                                 const char* data_filename,
                                                 int data_has_header,
                                                 int predict_type,
                                                 int num_iteration,
                                                 const char* result_filename) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const bool has_header = (data_has_header > 0);
  auto file_predictor = ref_booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
  file_predictor.Predict(data_filename, result_filename, has_header);
  API_END();
}

918
919
int64_t GetNumPredOneRow(const Booster* ref_booster, int predict_type, int64_t num_iteration) {
  int64_t num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
Guolin Ke's avatar
Guolin Ke committed
920
921
922
923
924
925
926
927
928
929
930
  if (predict_type == C_API_PREDICT_LEAF_INDEX) {
    int64_t max_iteration = ref_booster->GetBoosting()->GetCurrentIteration();
    if (num_iteration > 0) {
      num_preb_in_one_row *= static_cast<int>(std::min(max_iteration, num_iteration));
    } else {
      num_preb_in_one_row *= max_iteration;
    }
  }
  return num_preb_in_one_row;
}

931
// Computes the total number of output values for a batch of rows.
//   handle        - Booster to query
//   num_row       - number of input rows
//   predict_type  - prediction type forwarded to GetNumPredOneRow
//   num_iteration - iteration limit forwarded to GetNumPredOneRow
//   out_len       - receives num_row times the per-row output count
LIGHTGBM_C_EXPORT int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
                                                 int num_row,
                                                 int predict_type,
                                                 int num_iteration,
                                                 int64_t* out_len) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const int64_t per_row = GetNumPredOneRow(ref_booster, predict_type, num_iteration);
  *out_len = static_cast<int64_t>(num_row * per_row);
  API_END();
}

942
// Predicts for every row of a CSR matrix and writes the results to a caller
// buffer.
//   handle        - Booster used to create the predictor
//   indptr        - CSR row-pointer array (element type given by indptr_type)
//   indptr_type   - type tag for indptr, forwarded to RowFunctionFromCSR
//   indices       - CSR column indices
//   data          - CSR values (element type given by data_type)
//   data_type     - type tag for data, forwarded to RowFunctionFromCSR
//   nindptr       - number of entries in indptr; nindptr - 1 rows are predicted
//   nelem         - number of stored elements, forwarded to RowFunctionFromCSR
//   (unnamed)     - accepted but not used by this function
//   predict_type  - prediction type forwarded to NewPredictor
//   num_iteration - iteration limit forwarded to NewPredictor
//   out_len       - receives rows * per-row output count
//   out_result    - caller-provided buffer; results for row i start at
//                   out_result[i * per-row output count]
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle,
                                                const void* indptr,
                                                int indptr_type,
                                                const int32_t* indices,
                                                const void* data,
                                                int data_type,
                                                int64_t nindptr,
                                                int64_t nelem,
                                                int64_t,
                                                int predict_type,
                                                int num_iteration,
                                                int64_t* out_len,
                                                double* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  auto predictor = ref_booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
  auto fetch_row = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  const int64_t preds_per_row = GetNumPredOneRow(ref_booster, predict_type, num_iteration);
  const int row_count = static_cast<int>(nindptr - 1);
  OMP_INIT_EX();
#pragma omp parallel for schedule(static)
  for (int row = 0; row < row_count; ++row) {
    OMP_LOOP_EX_BEGIN();
    auto sparse_row = fetch_row(row);
    auto scores = predictor.GetPredictFunction()(sparse_row);
    // Each row owns a disjoint slice of the output buffer.
    double* row_out = out_result + row * preds_per_row;
    const int score_count = static_cast<int>(scores.size());
    for (int k = 0; k < score_count; ++k) {
      row_out[k] = static_cast<double>(scores[k]);
    }
    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();
  *out_len = row_count * preds_per_row;
  API_END();
}
976

977
// Predicts for every row of a CSC matrix and writes the results to a caller
// buffer.
//   handle        - Booster used to create the predictor
//   col_ptr       - CSC column-pointer array (element type given by col_ptr_type)
//   col_ptr_type  - type tag for col_ptr, forwarded to CSC_RowIterator
//   indices       - CSC row indices
//   data          - CSC values (element type given by data_type)
//   data_type     - type tag for data, forwarded to CSC_RowIterator
//   ncol_ptr      - number of entries in col_ptr; ncol_ptr - 1 columns are read
//   nelem         - number of stored elements, forwarded to CSC_RowIterator
//   num_row       - number of rows to predict
//   predict_type  - prediction type forwarded to NewPredictor
//   num_iteration - iteration limit forwarded to NewPredictor
//   out_len       - receives num_row * per-row output count
//   out_result    - caller-provided buffer; results for row i start at
//                   out_result[i * per-row output count]
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle,
                                                const void* col_ptr,
                                                int col_ptr_type,
                                                const int32_t* indices,
                                                const void* data,
                                                int data_type,
                                                int64_t ncol_ptr,
                                                int64_t nelem,
                                                int64_t num_row,
                                                int predict_type,
                                                int num_iteration,
                                                int64_t* out_len,
                                                double* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  auto predictor = ref_booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
  const int64_t preds_per_row = GetNumPredOneRow(ref_booster, predict_type, num_iteration);
  const int num_col = static_cast<int>(ncol_ptr - 1);

  // Threading::For hands each invocation a [start, end) row range
  // (presumably one per worker thread -- defined outside this chunk).
  Threading::For<data_size_t>(0, static_cast<data_size_t>(num_row),
                          [&predictor, &out_result, preds_per_row, num_col, col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem]
  (int, data_size_t start, data_size_t end) {
    // Every invocation builds its own set of per-column iterators.
    std::vector<CSC_RowIterator> column_iters;
    for (int col = 0; col < num_col; ++col) {
      column_iters.emplace_back(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, col);
    }
    // Reused across rows; holds (column, value) pairs for the current row.
    std::vector<std::pair<int, double>> sparse_row;
    for (int64_t row = start; row < end; ++row) {
      sparse_row.clear();
      for (int col = 0; col < num_col; ++col) {
        auto value = column_iters[col].Get(static_cast<int>(row));
        // Values within kEpsilon of zero are left out of the row.
        if (std::fabs(value) > kEpsilon) {
          sparse_row.emplace_back(col, value);
        }
      }
      auto scores = predictor.GetPredictFunction()(sparse_row);
      double* row_out = out_result + row * preds_per_row;
      const int score_count = static_cast<int>(scores.size());
      for (int k = 0; k < score_count; ++k) {
        row_out[k] = static_cast<double>(scores[k]);
      }
    }
  });
  *out_len = num_row * preds_per_row;
  API_END();
}

1022
// Predicts for every row of a dense matrix and writes the results to a
// caller buffer.
//   handle        - Booster used to create the predictor
//   data          - dense matrix values (element type given by data_type)
//   data_type     - type tag for data
//   nrow          - number of rows to predict
//   ncol          - number of columns
//   is_row_major  - layout flag forwarded to RowPairFunctionFromDenseMatric
//   predict_type  - prediction type forwarded to NewPredictor
//   num_iteration - iteration limit forwarded to NewPredictor
//   out_len       - receives nrow * per-row output count
//   out_result    - caller-provided buffer; results for row i start at
//                   out_result[i * per-row output count]
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMat(BoosterHandle handle,
                                                const void* data,
                                                int data_type,
                                                int32_t nrow,
                                                int32_t ncol,
                                                int is_row_major,
                                                int predict_type,
                                                int num_iteration,
                                                int64_t* out_len,
                                                double* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  auto predictor = ref_booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
  auto fetch_row = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  const int64_t preds_per_row = GetNumPredOneRow(ref_booster, predict_type, num_iteration);
  OMP_INIT_EX();
#pragma omp parallel for schedule(static)
  for (int row = 0; row < nrow; ++row) {
    OMP_LOOP_EX_BEGIN();
    auto row_pairs = fetch_row(row);
    auto scores = predictor.GetPredictFunction()(row_pairs);
    // Each row owns a disjoint slice of the output buffer.
    double* row_out = out_result + row * preds_per_row;
    const int score_count = static_cast<int>(scores.size());
    for (int k = 0; k < score_count; ++k) {
      row_out[k] = static_cast<double>(scores[k]);
    }
    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();
  *out_len = nrow * preds_per_row;
  API_END();
}
1052

1053
// Writes a Booster's model out through Booster::SaveModelToFile.
//   handle        - Booster to save
//   num_iteration - iteration count forwarded to SaveModelToFile
//   filename      - path forwarded to SaveModelToFile
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle,
                                            int num_iteration,
                                            const char* filename) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->SaveModelToFile(num_iteration, filename);
  API_END();
}

1062
// Serializes a Booster's model to a caller-provided character buffer.
//   handle        - Booster to serialize
//   num_iteration - iteration count forwarded to Booster::SaveModelToString
//   buffer_len    - capacity of out_str in bytes
//   out_len       - receives the size needed, including the terminating NUL
//   out_str       - receives the model text, but only when it fits
// When buffer_len is too small nothing is written to out_str; the caller can
// compare *out_len with buffer_len to detect this.
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle,
                                                    int num_iteration,
                                                    int buffer_len,
                                                    int* out_len,
                                                    char* out_str) {
  API_BEGIN();
  std::string model_text = reinterpret_cast<Booster*>(handle)->SaveModelToString(num_iteration);
  // +1 accounts for the terminating NUL that strcpy writes.
  const int required_len = static_cast<int>(model_text.size()) + 1;
  *out_len = required_len;
  if (required_len <= buffer_len) {
    std::strcpy(out_str, model_text.c_str());
  }
  API_END();
}

1077
// Dumps a Booster's model (via Booster::DumpModel) to a caller-provided
// character buffer.
//   handle        - Booster to dump
//   num_iteration - iteration count forwarded to Booster::DumpModel
//   buffer_len    - capacity of out_str in bytes
//   out_len       - receives the size needed, including the terminating NUL
//   out_str       - receives the dumped text, but only when it fits
// When buffer_len is too small nothing is written to out_str; the caller can
// compare *out_len with buffer_len to detect this.
LIGHTGBM_C_EXPORT int LGBM_BoosterDumpModel(BoosterHandle handle,
                                            int num_iteration,
                                            int buffer_len,
                                            int* out_len,
                                            char* out_str) {
  API_BEGIN();
  std::string dumped = reinterpret_cast<Booster*>(handle)->DumpModel(num_iteration);
  // +1 accounts for the terminating NUL that strcpy writes.
  const int required_len = static_cast<int>(dumped.size()) + 1;
  *out_len = required_len;
  if (required_len <= buffer_len) {
    std::strcpy(out_str, dumped.c_str());
  }
  API_END();
}
1091

1092
// Reads one leaf value through Booster::GetLeafValue.
//   handle   - Booster to query
//   tree_idx - tree index forwarded to GetLeafValue
//   leaf_idx - leaf index forwarded to GetLeafValue
//   out_val  - receives the leaf value, converted to double
LIGHTGBM_C_EXPORT int LGBM_BoosterGetLeafValue(BoosterHandle handle,
                                               int tree_idx,
                                               int leaf_idx,
                                               double* out_val) {
  API_BEGIN();
  const auto leaf_value = reinterpret_cast<Booster*>(handle)->GetLeafValue(tree_idx, leaf_idx);
  *out_val = static_cast<double>(leaf_value);
  API_END();
}

1102
// Overwrites one leaf value through Booster::SetLeafValue.
//   handle   - Booster to update
//   tree_idx - tree index forwarded to SetLeafValue
//   leaf_idx - leaf index forwarded to SetLeafValue
//   val      - new value forwarded to SetLeafValue
LIGHTGBM_C_EXPORT int LGBM_BoosterSetLeafValue(BoosterHandle handle,
                                               int tree_idx,
                                               int leaf_idx,
                                               double val) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->SetLeafValue(tree_idx, leaf_idx, val);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1112
// ---- start of some helper functions
1113
1114
1115

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1116
  if (data_type == C_API_DTYPE_FLOAT32) {
1117
1118
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
1119
      return [data_ptr, num_col, num_row] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1120
        std::vector<double> ret(num_col);
1121
1122
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1123
          ret[i] = static_cast<double>(*(tmp_ptr + i));
1124
1125
1126
1127
        }
        return ret;
      };
    } else {
1128
      return [data_ptr, num_col, num_row] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1129
        std::vector<double> ret(num_col);
1130
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1131
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
1132
1133
1134
1135
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1136
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1137
1138
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
1139
      return [data_ptr, num_col, num_row] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1140
        std::vector<double> ret(num_col);
1141
1142
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1143
          ret[i] = static_cast<double>(*(tmp_ptr + i));
1144
1145
1146
1147
        }
        return ret;
      };
    } else {
1148
      return [data_ptr, num_col, num_row] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1149
        std::vector<double> ret(num_col);
1150
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1151
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
1152
1153
1154
1155
1156
        }
        return ret;
      };
    }
  }
Guolin Ke's avatar
Guolin Ke committed
1157
  throw std::runtime_error("unknown data type in RowFunctionFromDenseMatric");
1158
1159
1160
1161
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1162
1163
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
1164
    return [inner_function] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1165
1166
1167
1168
1169
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
        if (std::fabs(raw_values[i]) > 1e-15) {
          ret.emplace_back(i, raw_values[i]);
1170
        }
Guolin Ke's avatar
Guolin Ke committed
1171
1172
1173
      }
      return ret;
    };
1174
  }
Guolin Ke's avatar
Guolin Ke committed
1175
  return nullptr;
1176
1177
1178
1179
}

std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
1180
  if (data_type == C_API_DTYPE_FLOAT32) {
1181
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1182
    if (indptr_type == C_API_DTYPE_INT32) {
1183
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
1184
      return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
1185
1186
1187
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1188
        for (int64_t i = start; i < end; ++i) {
1189
1190
1191
1192
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1193
    } else if (indptr_type == C_API_DTYPE_INT64) {
1194
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
1195
      return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
1196
1197
1198
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1199
        for (int64_t i = start; i < end; ++i) {
1200
1201
1202
1203
1204
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1205
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1206
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1207
    if (indptr_type == C_API_DTYPE_INT32) {
1208
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
1209
      return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
1210
1211
1212
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1213
        for (int64_t i = start; i < end; ++i) {
1214
1215
1216
1217
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1218
    } else if (indptr_type == C_API_DTYPE_INT64) {
1219
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
1220
      return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
1221
1222
1223
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1224
        for (int64_t i = start; i < end; ++i) {
1225
1226
1227
1228
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1229
1230
    }
  }
Guolin Ke's avatar
Guolin Ke committed
1231
  throw std::runtime_error("unknown data type in RowFunctionFromCSR");
1232
1233
}

Guolin Ke's avatar
Guolin Ke committed
1234
1235
1236
std::function<std::pair<int, double>(int idx)>
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx) {
  CHECK(col_idx < ncol_ptr && col_idx >= 0);
Guolin Ke's avatar
Guolin Ke committed
1237
  if (data_type == C_API_DTYPE_FLOAT32) {
1238
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1239
    if (col_ptr_type == C_API_DTYPE_INT32) {
1240
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1241
1242
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1243
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1244
1245
1246
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1247
        }
Guolin Ke's avatar
Guolin Ke committed
1248
1249
1250
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1251
      };
Guolin Ke's avatar
Guolin Ke committed
1252
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1253
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1254
1255
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1256
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1257
1258
1259
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1260
        }
Guolin Ke's avatar
Guolin Ke committed
1261
1262
1263
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1264
      };
Guolin Ke's avatar
Guolin Ke committed
1265
    }
Guolin Ke's avatar
Guolin Ke committed
1266
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1267
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1268
    if (col_ptr_type == C_API_DTYPE_INT32) {
1269
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1270
1271
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1272
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1273
1274
1275
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1276
        }
Guolin Ke's avatar
Guolin Ke committed
1277
1278
1279
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1280
      };
Guolin Ke's avatar
Guolin Ke committed
1281
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1282
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1283
1284
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1285
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1286
1287
1288
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1289
        }
Guolin Ke's avatar
Guolin Ke committed
1290
1291
1292
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1293
      };
Guolin Ke's avatar
Guolin Ke committed
1294
1295
1296
    }
  }
  throw std::runtime_error("unknown data type in CSC matrix");
1297
1298
}

Guolin Ke's avatar
Guolin Ke committed
1299
/*!
 * \brief Construct an iterator over one column of a CSC matrix.
 *
 * All arguments are forwarded unchanged to IterateFunctionFromCSC, whose
 * result becomes the element reader used by Get() and NextNonZero().
 */
CSC_RowIterator::CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
                                 const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx)
  : iter_fun_(IterateFunctionFromCSC(col_ptr, col_ptr_type, indices, data,
                                     data_type, ncol_ptr, nelem, col_idx)) {
}

/*!
 * \brief Value stored at position idx of the column, or zero if none is stored.
 *
 * Stored elements are consumed one by one until the current element's index
 * is no longer smaller than idx or the column is exhausted; the iterator
 * never moves backwards.
 *
 * \param idx Index to look up
 * \return The current element's value if its index equals idx, otherwise 0.0
 */
double CSC_RowIterator::Get(int idx) {
  while (!is_end_ && cur_idx_ < idx) {
    const auto entry = iter_fun_(nonzero_idx_);
    if (entry.first < 0) {
      // a negative index marks the end of the column
      is_end_ = true;
    } else {
      cur_idx_ = entry.first;
      cur_val_ = entry.second;
      ++nonzero_idx_;
    }
  }
  return (cur_idx_ == idx) ? cur_val_ : 0.0;
}

/*!
 * \brief Return the next stored element of the column and advance.
 * \return (index, value) of the next stored element; a pair whose index is
 *         negative once the column is exhausted
 */
std::pair<int, double> CSC_RowIterator::NextNonZero() {
  if (is_end_) {
    return std::make_pair(-1, 0.0);
  }
  auto entry = iter_fun_(nonzero_idx_);
  ++nonzero_idx_;
  // a negative index marks the end of the column; remember it for later calls
  is_end_ = (entry.first < 0);
  return entry;
}