c_api.cpp 53.6 KB
Newer Older
1
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
2
3
4

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/utils/threading.h>
Guolin Ke's avatar
Guolin Ke committed
6
#include <LightGBM/c_api.h>
Guolin Ke's avatar
Guolin Ke committed
7
#include <LightGBM/dataset_loader.h>
Guolin Ke's avatar
Guolin Ke committed
8
9
10
11
12
13
14
15
16
17
#include <LightGBM/dataset.h>
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/config.h>

#include <cstdio>
#include <vector>
#include <string>
#include <cstring>
Guolin Ke's avatar
Guolin Ke committed
18
#include <memory>
Guolin Ke's avatar
Guolin Ke committed
19
#include <stdexcept>
wxchan's avatar
wxchan committed
20
#include <mutex>
Guolin Ke's avatar
Guolin Ke committed
21
#include <functional>
Guolin Ke's avatar
Guolin Ke committed
22

Guolin Ke's avatar
Guolin Ke committed
23
#include "./application/predictor.hpp"
Guolin Ke's avatar
Guolin Ke committed
24
#include "./boosting/gbdt.h"
Guolin Ke's avatar
Guolin Ke committed
25

Guolin Ke's avatar
Guolin Ke committed
26
27
28
29
namespace LightGBM {

class Booster {
public:
Guolin Ke's avatar
Guolin Ke committed
30
31
  explicit Booster(const char* filename) {
    boosting_.reset(Boosting::CreateBoosting(filename));
Guolin Ke's avatar
Guolin Ke committed
32
33
  }

34
35
36
37
  Booster() {
    boosting_.reset(Boosting::CreateBoosting("gbdt", nullptr));
  }

Guolin Ke's avatar
Guolin Ke committed
38
  Booster(const Dataset* train_data,
39
          const char* parameters) {
wxchan's avatar
wxchan committed
40
41
    auto param = ConfigBase::Str2Map(parameters);
    config_.Set(param);
42
43
44
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
45
46
    // create boosting
    if (config_.io_config.input_model.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
47
      Log::Warning("continued train from model is not support for c_api, \
Guolin Ke's avatar
Guolin Ke committed
48
49
        please use continued train with input score");
    }
Guolin Ke's avatar
Guolin Ke committed
50

wxchan's avatar
wxchan committed
51
    boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
Guolin Ke's avatar
Guolin Ke committed
52

Guolin Ke's avatar
Guolin Ke committed
53
    // initialize the boosting
Guolin Ke's avatar
Guolin Ke committed
54
    boosting_->Init(&config_.boosting_config, nullptr, objective_fun_.get(),
55
                    Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
56
57

    ResetTrainingData(train_data);
wxchan's avatar
wxchan committed
58
59
60
61
62
  }

  void MergeFrom(const Booster* other) {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->MergeFrom(other->boosting_.get());
Guolin Ke's avatar
Guolin Ke committed
63
64
65
  }

  ~Booster() {
Guolin Ke's avatar
Guolin Ke committed
66

Guolin Ke's avatar
Guolin Ke committed
67
  }
68

wxchan's avatar
wxchan committed
69
70
71
  void ResetTrainingData(const Dataset* train_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    train_data_ = train_data;
Guolin Ke's avatar
Guolin Ke committed
72
73
    // create objective function
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
74
                                                                    config_.objective_config));
Guolin Ke's avatar
Guolin Ke committed
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective function");
    }
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
    }

    // create training metric
    train_metric_.clear();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(
        Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(train_data_->metadata(), train_data_->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
    // reset the boosting
Guolin Ke's avatar
Guolin Ke committed
94
    boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
95
                                 objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
wxchan's avatar
wxchan committed
96
97
98
  }

  void ResetConfig(const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
99
    std::lock_guard<std::mutex> lock(mutex_);
wxchan's avatar
wxchan committed
100
101
102
103
104
105
106
    auto param = ConfigBase::Str2Map(parameters);
    if (param.count("num_class")) {
      Log::Fatal("cannot change num class during training");
    }
    if (param.count("boosting_type")) {
      Log::Fatal("cannot change boosting_type during training");
    }
Guolin Ke's avatar
Guolin Ke committed
107
108
109
    if (param.count("metric")) {
      Log::Fatal("cannot change metric during training");
    }
Guolin Ke's avatar
Guolin Ke committed
110
111

    config_.Set(param);
112
113
114
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
115
116
117
118

    if (param.count("objective")) {
      // create objective function
      objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
119
                                                                      config_.objective_config));
Guolin Ke's avatar
Guolin Ke committed
120
121
122
123
124
125
126
      if (objective_fun_ == nullptr) {
        Log::Warning("Using self-defined objective function");
      }
      // initialize the objective function
      if (objective_fun_ != nullptr) {
        objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
      }
wxchan's avatar
wxchan committed
127
    }
Guolin Ke's avatar
Guolin Ke committed
128
129

    boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
130
                                 objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
131

wxchan's avatar
wxchan committed
132
133
134
135
136
137
138
139
140
141
142
143
144
  }

  void AddValidData(const Dataset* valid_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    valid_metrics_.emplace_back();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
145
                               Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
wxchan's avatar
wxchan committed
146
  }
Guolin Ke's avatar
Guolin Ke committed
147

148
  bool TrainOneIter() {
wxchan's avatar
wxchan committed
149
    std::lock_guard<std::mutex> lock(mutex_);
150
151
152
153
    return boosting_->TrainOneIter(nullptr, nullptr, false);
  }

  bool TrainOneIter(const float* gradients, const float* hessians) {
wxchan's avatar
wxchan committed
154
    std::lock_guard<std::mutex> lock(mutex_);
155
156
157
    return boosting_->TrainOneIter(gradients, hessians, false);
  }

wxchan's avatar
wxchan committed
158
159
160
161
162
  void RollbackOneIter() {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->RollbackOneIter();
  }

Guolin Ke's avatar
Guolin Ke committed
163
  Predictor NewPredictor(int num_iteration, int predict_type) {
wxchan's avatar
wxchan committed
164
165
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->SetNumIterationForPred(num_iteration);
Guolin Ke's avatar
Guolin Ke committed
166
167
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
168
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
169
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
170
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
171
      is_raw_score = true;
Guolin Ke's avatar
Guolin Ke committed
172
173
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
174
    }
Guolin Ke's avatar
Guolin Ke committed
175
176
177
    // not threading safe now
    // boosting_->SetNumIterationForPred may be set by other thread during prediction. 
    return Predictor(boosting_.get(), is_raw_score, is_predict_leaf);
Guolin Ke's avatar
Guolin Ke committed
178
179
  }

Guolin Ke's avatar
Guolin Ke committed
180
  void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
181
182
183
184
185
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

  void SaveModelToFile(int num_iteration, const char* filename) {
    boosting_->SaveModelToFile(num_iteration, filename);
Guolin Ke's avatar
Guolin Ke committed
186
  }
187

188
189
190
191
192
193
194
195
  void LoadModelFromString(const char* model_str) {
    boosting_->LoadModelFromString(model_str);
  }

  std::string SaveModelToString(int num_iteration) {
    return boosting_->SaveModelToString(num_iteration);
  }

196
197
  std::string DumpModel(int num_iteration) {
    return boosting_->DumpModel(num_iteration);
wxchan's avatar
wxchan committed
198
  }
199

Guolin Ke's avatar
Guolin Ke committed
200
201
202
203
204
205
206
207
208
  double GetLeafValue(int tree_idx, int leaf_idx) const {
    return dynamic_cast<GBDT*>(boosting_.get())->GetLeafValue(tree_idx, leaf_idx);
  }

  void SetLeafValue(int tree_idx, int leaf_idx, double val) {
    std::lock_guard<std::mutex> lock(mutex_);
    dynamic_cast<GBDT*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val);
  }

wxchan's avatar
wxchan committed
209
210
211
212
213
214
215
  int GetEvalCounts() const {
    int ret = 0;
    for (const auto& metric : train_metric_) {
      ret += static_cast<int>(metric->GetName().size());
    }
    return ret;
  }
216

wxchan's avatar
wxchan committed
217
218
219
220
221
222
223
224
225
226
227
  int GetEvalNames(char** out_strs) const {
    int idx = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
        std::strcpy(out_strs[idx], name.c_str());
        ++idx;
      }
    }
    return idx;
  }

wxchan's avatar
wxchan committed
228
229
230
231
232
233
234
235
236
  int GetFeatureNames(char** out_strs) const {
    int idx = 0;
    for (const auto& name : boosting_->FeatureNames()) {
      std::strcpy(out_strs[idx], name.c_str());
      ++idx;
    }
    return idx;
  }

wxchan's avatar
wxchan committed
237
  const Boosting* GetBoosting() const { return boosting_.get(); }
Guolin Ke's avatar
Guolin Ke committed
238

Guolin Ke's avatar
Guolin Ke committed
239
private:
240

wxchan's avatar
wxchan committed
241
  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
242
  std::unique_ptr<Boosting> boosting_;
Guolin Ke's avatar
Guolin Ke committed
243
244
245
  /*! \brief All configs */
  OverallConfig config_;
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
246
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
247
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
248
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
249
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
250
  std::unique_ptr<ObjectiveFunction> objective_fun_;
wxchan's avatar
wxchan committed
251
252
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
253
254
255
};

}
Guolin Ke's avatar
Guolin Ke committed
256
257
258

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
259
260
261
262
263
264
265
266
267
268
// Helpers that adapt caller-supplied raw buffers into per-row accessor
// functions. Only declarations appear here; the definitions are elsewhere
// (not visible in this chunk). "Matric" is part of the existing names.

// Dense matrix: the returned function yields the values of row `row_idx`.
std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data,
                           int num_row,
                           int num_col,
                           int data_type,
                           int is_row_major);

// Dense matrix: the returned function yields row `row_idx` as (column, value) pairs.
std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data,
                               int num_row,
                               int num_col,
                               int data_type,
                               int is_row_major);

// CSR matrix: the returned function yields row `idx` as (column, value) pairs.
std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr,
                   int indptr_type,
                   const int32_t* indices,
                   const void* data,
                   int data_type,
                   int64_t nindptr,
                   int64_t nelem);
Guolin Ke's avatar
Guolin Ke committed
270
271
272
273
274

/*!
 * \brief Iterates over the rows of one column of a CSC matrix.
 *        Member functions are defined out of line, elsewhere in this file.
 */
class CSC_RowIterator {
public:
  CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
                  const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx);
  ~CSC_RowIterator() = default;

  /*! \brief Value stored at row idx; rows may only be requested in ascending order. */
  double Get(int idx);

  /*! \brief Next non-zero (row, value) pair; a negative row index means no more data. */
  std::pair<int, double> NextNonZero();

private:
  int nonzero_idx_ = 0;
  int cur_idx_ = -1;
  double cur_val_ = 0.0;
  bool is_end_ = false;
  std::function<std::pair<int, double>(int idx)> iter_fun_;
};

// start of c_api functions

291
/*! \brief Return the error message currently recorded by LastErrorMsg(). */
LIGHTGBM_C_EXPORT const char* LGBM_GetLastError() {
  const char* last_message = LastErrorMsg();
  return last_message;
}

295
/*!
 * \brief Load a dataset from a file.
 *        With a reference dataset, the loaded data is aligned with it via
 *        LoadFromFileAlignWithOtherDataset; otherwise it is loaded standalone.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromFile(const char* filename,
                                                 const char* parameters,
                                                 const DatasetHandle reference,
                                                 DatasetHandle* out) {
  API_BEGIN();
  auto kv_params = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(kv_params);
  DatasetLoader loader(io_config, nullptr, 1, filename);
  const Dataset* ref_dataset = reinterpret_cast<const Dataset*>(reference);
  if (ref_dataset != nullptr) {
    *out = loader.LoadFromFileAlignWithOtherDataset(filename, ref_dataset);
  } else {
    *out = loader.LoadFromFile(filename);
  }
  API_END();
}

313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
/*!
 * \brief Build an empty dataset of num_total_row rows whose bins are derived
 *        from a row-major dense sample; rows are expected to be pushed later.
 *        If the sample is the whole matrix, delegates to LGBM_DatasetCreateFromMat.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromSampledMat(const void* data,
                                                       int data_type,
                                                       int32_t num_sample_row,
                                                       int32_t ncol,
                                                       int32_t num_total_row,
                                                       const char* parameters,
                                                       DatasetHandle* out) {
  const bool sample_is_everything = (num_sample_row == num_total_row);
  if (sample_is_everything) {
    return LGBM_DatasetCreateFromMat(data, data_type, num_total_row, ncol, 1, parameters, nullptr, out);
  }
  API_BEGIN();
  auto kv_params = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(kv_params);
  auto row_at = RowFunctionFromDenseMatric(data, num_sample_row, ncol, data_type, 1);
  // Per-column non-zero values and the sample row each came from.
  std::vector<std::vector<double>> col_values(ncol);
  std::vector<std::vector<int>> col_rows(ncol);
  for (int r = 0; r < num_sample_row; ++r) {
    auto row = row_at(r);
    for (size_t col = 0; col < row.size(); ++col) {
      const double val = row[col];
      if (!(std::fabs(val) > kEpsilon)) { continue; }
      col_values[col].emplace_back(val);
      col_rows[col].emplace_back(r);
    }
  }
  DatasetLoader loader(io_config, nullptr, 1, nullptr);
  *out = loader.CostructFromSampleData(col_values, col_rows,
                                       num_sample_row,
                                       static_cast<data_size_t>(num_total_row));
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
347
/*!
 * \brief Build an empty dataset of num_total_row rows whose bins are derived
 *        from a CSR sample; rows are expected to be pushed later.
 *        If the sample is the whole matrix, delegates to LGBM_DatasetCreateFromCSR.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromSampledCSR(const void* indptr,
                                                       int indptr_type,
                                                       const int32_t* indices,
                                                       const void* data,
                                                       int data_type,
                                                       int64_t nindptr,
                                                       int64_t n_sample_elem,
                                                       int64_t num_col,
                                                       int64_t num_total_row,
                                                       const char* parameters,
                                                       DatasetHandle* out) {
  const bool sample_is_everything = (nindptr - 1 == num_total_row);
  if (sample_is_everything) {
    return LGBM_DatasetCreateFromCSR(indptr, indptr_type, indices, data, data_type,
                                     nindptr, n_sample_elem, num_col, parameters, nullptr, out);
  }
  API_BEGIN();
  auto kv_params = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(kv_params);
  auto row_at = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, n_sample_elem);
  int32_t num_sample_row = static_cast<int32_t>(nindptr - 1);
  // Per-column non-zero values and the sample row each came from.
  std::vector<std::vector<double>> col_values(num_col);
  std::vector<std::vector<int>> col_rows(num_col);
  for (int r = 0; r < num_sample_row; ++r) {
    auto row = row_at(r);
    for (const auto& entry : row) {
      const int col = entry.first;
      const double val = entry.second;
      // Grow on a column index beyond the buffers; the CHECK below then fails.
      if (static_cast<size_t>(col) >= col_values.size()) {
        col_values.resize(col + 1);
        col_rows.resize(col + 1);
      }
      if (!(std::fabs(val) > kEpsilon)) { continue; }
      col_values[col].emplace_back(val);
      col_rows[col].emplace_back(r);
    }
  }
  CHECK(num_col >= static_cast<int>(col_values.size()));
  DatasetLoader loader(io_config, nullptr, 1, nullptr);
  *out = loader.CostructFromSampleData(col_values, col_rows,
                                       num_sample_row,
                                       static_cast<data_size_t>(num_total_row));
  API_END();
}

/*!
 * \brief Create an empty dataset of num_total_row rows set up from a reference
 *        dataset via Dataset::CreateValid; rows are expected to be pushed later.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateByReference(const DatasetHandle reference,
                                                    int64_t num_total_row,
                                                    DatasetHandle* out) {
  API_BEGIN();
  const Dataset* ref_dataset = reinterpret_cast<const Dataset*>(reference);
  std::unique_ptr<Dataset> dataset(new Dataset(static_cast<data_size_t>(num_total_row)));
  dataset->CreateValid(ref_dataset);
  // Ownership passes to the caller through the handle.
  *out = dataset.release();
  API_END();
}

/*!
 * \brief Push a row-major dense batch of nrow rows into rows
 *        [start_row, start_row + nrow) of an existing dataset.
 *        The batch that reaches the dataset's last row triggers FinishLoad().
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetPushRows(DatasetHandle dataset,
                                           const void* data,
                                           int data_type,
                                           int32_t nrow,
                                           int32_t ncol,
                                           int32_t start_row) {
  API_BEGIN();
  auto p_dataset = reinterpret_cast<Dataset*>(dataset);
  // Validate the target range up front: PushOneRow writes at start_row + i, and
  // an error raised inside the OpenMP loop below could not propagate.
  // int64_t arithmetic avoids int32 overflow in start_row + nrow.
  const int64_t end_row = static_cast<int64_t>(start_row) + nrow;
  if (start_row < 0 || nrow < 0 || end_row > static_cast<int64_t>(p_dataset->num_data())) {
    Log::Fatal("cannot push rows outside of the dataset's row range");
  }
  auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, 1);
#pragma omp parallel for schedule(static)
  for (int i = 0; i < nrow; ++i) {
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    p_dataset->PushOneRow(tid, start_row + i, one_row);
  }
  if (end_row == static_cast<int64_t>(p_dataset->num_data())) {
    p_dataset->FinishLoad();
  }
  API_END();
}

/*!
 * \brief Push a CSR batch (nindptr - 1 rows) into rows
 *        [start_row, start_row + nindptr - 1) of an existing dataset.
 *        The batch that reaches the dataset's last row triggers FinishLoad().
 *        The unnamed int64_t parameter (column count) is not used.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset,
                                                const void* indptr,
                                                int indptr_type,
                                                const int32_t* indices,
                                                const void* data,
                                                int data_type,
                                                int64_t nindptr,
                                                int64_t nelem,
                                                int64_t,
                                                int64_t start_row) {
  API_BEGIN();
  auto p_dataset = reinterpret_cast<Dataset*>(dataset);
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  // Validate the target range up front: PushOneRow writes at start_row + i, and
  // an error raised inside the OpenMP loop below could not propagate.
  if (start_row < 0 || nrow < 0 ||
      start_row + nrow > static_cast<int64_t>(p_dataset->num_data())) {
    Log::Fatal("cannot push rows outside of the dataset's row range");
  }
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
#pragma omp parallel for schedule(static)
  for (int i = 0; i < nrow; ++i) {
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    p_dataset->PushOneRow(tid,
                          static_cast<data_size_t>(start_row + i), one_row);
  }
  if (start_row + nrow == static_cast<int64_t>(p_dataset->num_data())) {
    p_dataset->FinishLoad();
  }
  API_END();
}

451
/*!
 * \brief Build a dataset from a dense matrix and push every row into it.
 *        Without a reference, bins are derived from a random row sample of up
 *        to bin_construct_sample_cnt rows; with one, Dataset::CreateValid is used.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromMat(const void* data,
                                                int data_type,
                                                int32_t nrow,
                                                int32_t ncol,
                                                int is_row_major,
                                                const char* parameters,
                                                const DatasetHandle reference,
                                                DatasetHandle* out) {
  API_BEGIN();
  auto kv_params = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(kv_params);
  std::unique_ptr<Dataset> dataset;
  auto row_at = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);

  if (reference != nullptr) {
    dataset.reset(new Dataset(nrow));
    dataset->CreateValid(reinterpret_cast<const Dataset*>(reference));
  } else {
    Random sampler(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto picked_rows = sampler.Sample(nrow, sample_cnt);
    // Per-column non-zero values and the position (within the sample) of their row.
    std::vector<std::vector<double>> col_values(ncol);
    std::vector<std::vector<int>> col_rows(ncol);
    int sample_pos = 0;
    for (auto picked : picked_rows) {
      auto row = row_at(static_cast<int>(picked));
      for (size_t col = 0; col < row.size(); ++col) {
        const double val = row[col];
        if (std::fabs(val) > kEpsilon) {
          col_values[col].emplace_back(val);
          col_rows[col].emplace_back(sample_pos);
        }
      }
      ++sample_pos;
    }
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
    dataset.reset(loader.CostructFromSampleData(col_values, col_rows, sample_cnt, nrow));
  }

  // Every row goes into the dataset; each thread pushes with its own tid.
#pragma omp parallel for schedule(static)
  for (int r = 0; r < nrow; ++r) {
    const int tid = omp_get_thread_num();
    auto values = row_at(r);
    dataset->PushOneRow(tid, r, values);
  }
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

501
/*!
 * \brief Build a dataset from a CSR matrix and push every row into it.
 *        Without a reference, bins are derived from a random row sample of up
 *        to bin_construct_sample_cnt rows; with one, Dataset::CreateValid is used.
 * \param num_col Number of columns; a column index >= num_col fails the CHECK.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSR(const void* indptr,
                                                int indptr_type,
                                                const int32_t* indices,
                                                const void* data,
                                                int data_type,
                                                int64_t nindptr,
                                                int64_t nelem,
                                                int64_t num_col,
                                                const char* parameters,
                                                const DatasetHandle reference,
                                                DatasetHandle* out) {
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
  std::unique_ptr<Dataset> ret;
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  // nindptr row pointers describe nindptr - 1 rows.
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference == nullptr) {
    // sample data first
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    // Pre-size to num_col, as LGBM_DatasetCreateFromSampledCSR does. Growing
    // lazily made the number of columns handed to CostructFromSampleData depend
    // on which rows were sampled: columns after the last sampled non-zero were
    // missing entirely.
    std::vector<std::vector<double>> sample_values(num_col);
    std::vector<std::vector<int>> sample_idx(num_col);
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      auto row = get_row_fun(static_cast<int>(idx));
      for (std::pair<int, double>& inner_data : row) {
        // Still grow on out-of-range column indices so the CHECK below fails.
        if (static_cast<size_t>(inner_data.first) >= sample_values.size()) {
          sample_values.resize(inner_data.first + 1);
          sample_idx.resize(inner_data.first + 1);
        }
        if (std::fabs(inner_data.second) > kEpsilon) {
          sample_values[inner_data.first].emplace_back(inner_data.second);
          sample_idx[inner_data.first].emplace_back(static_cast<int>(i));
        }
      }
    }
    CHECK(static_cast<int64_t>(sample_values.size()) <= num_col);
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
    ret.reset(loader.CostructFromSampleData(sample_values, sample_idx, sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow));
    ret->CreateValid(
      reinterpret_cast<const Dataset*>(reference));
  }

#pragma omp parallel for schedule(static)
  for (int i = 0; i < nrow; ++i) {
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
  }
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

560
/*!
 * \brief Build a dataset from a CSC matrix and push every column into it.
 *        Without a reference, bins are derived from a random row sample of up
 *        to bin_construct_sample_cnt rows; with one, Dataset::CreateValid is used.
 * \param num_row Number of rows; non-zeros at row indices >= num_row are not pushed.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr,
                                                int col_ptr_type,
                                                const int32_t* indices,
                                                const void* data,
                                                int data_type,
                                                int64_t ncol_ptr,
                                                int64_t nelem,
                                                int64_t num_row,
                                                const char* parameters,
                                                const DatasetHandle reference,
                                                DatasetHandle* out) {
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
  std::unique_ptr<Dataset> ret;
  int32_t nrow = static_cast<int32_t>(num_row);
  if (reference == nullptr) {
    // sample data first
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    // Bound the loop by what Sample actually returned (as the Mat/CSR paths do)
    // instead of assuming it holds exactly sample_cnt indices.
    const int num_sampled = static_cast<int>(sample_indices.size());
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
    std::vector<std::vector<int>> sample_idx(ncol_ptr - 1);
    // Each iteration only touches sample_values[i] / sample_idx[i].
#pragma omp parallel for schedule(static)
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
      CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
      // CSC_RowIterator::Get only supports ascending row indices; this relies on
      // Random::Sample returning sorted indices. NOTE(review): confirm.
      for (int j = 0; j < num_sampled; ++j) {
        auto val = col_it.Get(sample_indices[j]);
        if (std::fabs(val) > kEpsilon) {
          sample_values[i].emplace_back(val);
          sample_idx[i].emplace_back(j);
        }
      }
    }
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
    ret.reset(loader.CostructFromSampleData(sample_values, sample_idx, sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow));
    ret->CreateValid(
      reinterpret_cast<const Dataset*>(reference));
  }

#pragma omp parallel for schedule(static)
  for (int i = 0; i < ncol_ptr - 1; ++i) {
    const int tid = omp_get_thread_num();
    int feature_idx = ret->InnerFeatureIndex(i);
    // A negative inner index presumably marks a column the dataset does not use.
    if (feature_idx < 0) { continue; }
    int group = ret->Feature2Group(feature_idx);
    int sub_feature = ret->Feture2SubFeature(feature_idx);
    CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
    while (true) {
      auto pair = col_it.NextNonZero();
      const int row_idx = pair.first;
      // Negative: no more data in this column. >= nrow: the matrix has rows
      // beyond num_row; stop rather than write past the dataset (an exception
      // cannot escape this OpenMP region). Previously that row was still pushed.
      if (row_idx < 0 || row_idx >= nrow) { break; }
      ret->PushOneData(tid, row_idx, group, sub_feature, pair.second);
    }
  }
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

625
/*!
 * \brief Create a new dataset holding only the selected rows of an existing one.
 * \param handle Source dataset.
 * \param used_row_indices Row indices to copy into the subset.
 * \param num_used_row_indices Number of entries in used_row_indices.
 * \param parameters Parameter string, parsed into an IOConfig.
 * \param out Receives the new dataset handle.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetGetSubset(
  const DatasetHandle handle,
  const int32_t* used_row_indices,
  int32_t num_used_row_indices,
  const char* parameters,
  DatasetHandle* out) {
  API_BEGIN();
  // NOTE(review): the parsed config is not consumed below; parsing is kept so
  // any validation done by Set() still runs.
  auto param_map = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param_map);
  const Dataset* source = reinterpret_cast<const Dataset*>(handle);
  std::unique_ptr<Dataset> subset(new Dataset(num_used_row_indices));
  subset->CopyFeatureMapperFrom(source);
  subset->CopySubset(source, used_row_indices, num_used_row_indices, true);
  // Ownership passes to the caller only after every step above succeeded.
  *out = subset.release();
  API_END();
}

643
/*!
 * \brief Replace a dataset's feature names with the given C strings.
 * \param handle Target dataset.
 * \param feature_names Array of num_feature_names NUL-terminated strings.
 * \param num_feature_names Number of names; non-positive values yield an empty list.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetSetFeatureNames(
  DatasetHandle handle,
  const char** feature_names,
  int num_feature_names) {
  API_BEGIN();
  Dataset* target = reinterpret_cast<Dataset*>(handle);
  std::vector<std::string> names;
  int idx = 0;
  while (idx < num_feature_names) {
    names.emplace_back(feature_names[idx]);
    ++idx;
  }
  target->set_feature_names(names);
  API_END();
}

657
/*!
 * \brief Copy a dataset's feature names into caller-provided buffers.
 * \param handle Source dataset.
 * \param feature_names Array of writable buffers, one per feature name.
 * \param num_feature_names Receives the number of names written.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetGetFeatureNames(
  DatasetHandle handle,
  char** feature_names,
  int* num_feature_names) {
  API_BEGIN();
  auto names = reinterpret_cast<Dataset*>(handle)->feature_names();
  const int count = static_cast<int>(names.size());
  *num_feature_names = count;
  for (int k = 0; k < count; ++k) {
    // No bounds check: each buffer must be large enough for its name plus NUL.
    std::strcpy(feature_names[k], names[k].c_str());
  }
  API_END();
}

671
/*!
 * \brief Destroy a dataset previously returned through a DatasetHandle.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetFree(DatasetHandle handle) {
  API_BEGIN();
  Dataset* dataset = reinterpret_cast<Dataset*>(handle);
  delete dataset;
  API_END();
}

677
/*!
 * \brief Write the dataset to `filename` via Dataset::SaveBinaryFile.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetSaveBinary(DatasetHandle handle,
                                             const char* filename) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->SaveBinaryFile(filename);
  API_END();
}

685
/*!
 * \brief Set a named field on a dataset from a typed buffer.
 * \param handle Target dataset.
 * \param field_name Name of the field to set.
 * \param field_data Buffer of num_element values whose element type is given by `type`.
 * \param num_element Number of elements in field_data.
 * \param type C_API_DTYPE_FLOAT32, C_API_DTYPE_INT32 or C_API_DTYPE_FLOAT64.
 * \throws std::runtime_error if `type` is unsupported or the matching setter
 *         reports failure for this field.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle,
                                           const char* field_name,
                                           const void* field_data,
                                           int num_element,
                                           int type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  const auto len = static_cast<int32_t>(num_element);
  bool is_success = false;
  if (type == C_API_DTYPE_FLOAT32) {
    is_success = dataset->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), len);
  } else if (type == C_API_DTYPE_INT32) {
    is_success = dataset->SetIntField(field_name, reinterpret_cast<const int*>(field_data), len);
  } else if (type == C_API_DTYPE_FLOAT64) {
    is_success = dataset->SetDoubleField(field_name, reinterpret_cast<const double*>(field_data), len);
  }
  // An unsupported dtype leaves is_success false and lands here as well.
  if (!is_success) { throw std::runtime_error("Input data type error or field not found"); }
  API_END();
}

704
/*!
 * \brief Look up a named field on a dataset and expose its buffer.
 * \param handle Source dataset.
 * \param field_name Name of the field.
 * \param out_len Receives the element count (forced to 0 when *out_ptr is null).
 * \param out_ptr Receives a pointer to the field data.
 * \param out_type Receives the C_API_DTYPE_* of the field.
 * \throws std::runtime_error("Field not found") if no typed getter accepts the name.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetGetField(DatasetHandle handle,
                                           const char* field_name,
                                           int* out_len,
                                           const void** out_ptr,
                                           int* out_type) {
  API_BEGIN();
  Dataset* dataset = reinterpret_cast<Dataset*>(handle);
  // Try each typed getter in order: float, int, then double.
  if (dataset->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (dataset->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else if (dataset->GetDoubleField(field_name, out_len, reinterpret_cast<const double**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT64;
  } else {
    throw std::runtime_error("Field not found");
  }
  // A null data pointer is reported with length 0.
  if (*out_ptr == nullptr) { *out_len = 0; }
  API_END();
}

727
/*!
 * \brief Report Dataset::num_data() through `out`.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumData(DatasetHandle handle,
                                             int* out) {
  API_BEGIN();
  *out = reinterpret_cast<Dataset*>(handle)->num_data();
  API_END();
}

735
/*!
 * \brief Report Dataset::num_total_features() through `out`.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumFeature(DatasetHandle handle,
                                                int* out) {
  API_BEGIN();
  *out = reinterpret_cast<Dataset*>(handle)->num_total_features();
  API_END();
}
742
743
744

// ---- start of booster

745
/*!
 * \brief Construct a Booster over a training dataset with the given parameter string.
 * \param out Receives the new booster handle on success.
 */
LIGHTGBM_C_EXPORT int LGBM_BoosterCreate(const DatasetHandle train_data,
                                         const char* parameters,
                                         BoosterHandle* out) {
  API_BEGIN();
  std::unique_ptr<Booster> booster(
    new Booster(reinterpret_cast<const Dataset*>(train_data), parameters));
  *out = booster.release();
  API_END();
}

755
/*!
 * \brief Construct a Booster from a model file.
 * \param out_num_iterations Receives the loaded model's current iteration count.
 * \param out Receives the new booster handle on success.
 */
LIGHTGBM_C_EXPORT int LGBM_BoosterCreateFromModelfile(
  const char* filename,
  int* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
  std::unique_ptr<Booster> booster(new Booster(filename));
  *out_num_iterations = booster->GetBoosting()->GetCurrentIteration();
  *out = booster.release();
  API_END();
}

766
767
768
769
770
771
772
773
774
775
776
777
/*!
 * \brief Construct a default Booster and load a model from an in-memory string.
 * \param out_num_iterations Receives the loaded model's current iteration count.
 * \param out Receives the new booster handle on success.
 */
LIGHTGBM_C_EXPORT int LGBM_BoosterLoadModelFromString(
  const char* model_str,
  int* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
  std::unique_ptr<Booster> booster(new Booster());
  booster->LoadModelFromString(model_str);
  *out_num_iterations = booster->GetBoosting()->GetCurrentIteration();
  *out = booster.release();
  API_END();
}

778
/*!
 * \brief Destroy a booster previously returned through a BoosterHandle.
 */
LIGHTGBM_C_EXPORT int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  delete booster;
  API_END();
}

784
/*!
 * \brief Call Booster::MergeFrom on `handle` with `other_handle` as the source.
 */
LIGHTGBM_C_EXPORT int LGBM_BoosterMerge(BoosterHandle handle,
                                        BoosterHandle other_handle) {
  API_BEGIN();
  Booster* target = reinterpret_cast<Booster*>(handle);
  Booster* source = reinterpret_cast<Booster*>(other_handle);
  target->MergeFrom(source);
  API_END();
}

793
/*!
 * \brief Register a validation dataset with the booster.
 */
LIGHTGBM_C_EXPORT int LGBM_BoosterAddValidData(BoosterHandle handle,
                                               const DatasetHandle valid_data) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->AddValidData(
    reinterpret_cast<const Dataset*>(valid_data));
  API_END();
}

802
/*!
 * \brief Point the booster at a different training dataset.
 */
LIGHTGBM_C_EXPORT int LGBM_BoosterResetTrainingData(BoosterHandle handle,
                                                    const DatasetHandle train_data) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ResetTrainingData(
    reinterpret_cast<const Dataset*>(train_data));
  API_END();
}

811
/*!
 * \brief Pass a new parameter string to Booster::ResetConfig.
 */
LIGHTGBM_C_EXPORT int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->ResetConfig(parameters);
  API_END();
}

818
/*!
 * \brief Report the underlying Boosting's NumberOfClasses() through out_len.
 */
LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetBoosting()->NumberOfClasses();
  API_END();
}

825
/*!
 * \brief Run one training iteration.
 * \param is_finished Set to 1 when TrainOneIter() returns true, otherwise 0.
 */
LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  *is_finished = booster->TrainOneIter() ? 1 : 0;
  API_END();
}

836
/*!
 * \brief Run one training iteration with caller-supplied gradients and hessians.
 * \param is_finished Set to 1 when TrainOneIter(grad, hess) returns true, otherwise 0.
 */
LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
                                                      const float* grad,
                                                      const float* hess,
                                                      int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  *is_finished = booster->TrainOneIter(grad, hess) ? 1 : 0;
  API_END();
}

850
/*!
 * \brief Delegate to Booster::RollbackOneIter.
 */
LIGHTGBM_C_EXPORT int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

857
/*!
 * \brief Report the underlying Boosting's GetCurrentIteration() through out_iteration.
 */
LIGHTGBM_C_EXPORT int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int* out_iteration) {
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_iteration = boosting->GetCurrentIteration();
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
863

864
/*!
 * \brief Report Booster::GetEvalCounts() through out_len.
 */
LIGHTGBM_C_EXPORT int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalCounts();
  API_END();
}

871
/*!
 * \brief Forward out_strs to Booster::GetEvalNames and store its return value in out_len.
 */
LIGHTGBM_C_EXPORT int LGBM_BoosterGetEvalNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const auto count = booster->GetEvalNames(out_strs);
  *out_len = count;
  API_END();
}

wxchan's avatar
wxchan committed
878
879
880
881
882
883
884
885
886
887
888
889
890
891
/*!
 * \brief Forward out_strs to Booster::GetFeatureNames and store its return value in out_len.
 */
LIGHTGBM_C_EXPORT int LGBM_BoosterGetFeatureNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const auto count = booster->GetFeatureNames(out_strs);
  *out_len = count;
  API_END();
}

/*!
 * \brief Report the feature count as MaxFeatureIdx() + 1 through out_len.
 */
LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumFeature(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->MaxFeatureIdx() + 1;
  API_END();
}

892
/*!
 * \brief Copy the evaluation results for dataset `data_idx` into out_results.
 * \param out_len Receives the number of results written.
 * \param out_results Caller buffer; must hold at least *out_len doubles.
 */
LIGHTGBM_C_EXPORT int LGBM_BoosterGetEval(BoosterHandle handle,
                                          int data_idx,
                                          int* out_len,
                                          double* out_results) {
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  auto scores = boosting->GetEvalAt(data_idx);
  *out_len = static_cast<int>(scores.size());
  double* dst = out_results;
  for (const auto& score : scores) {
    *dst++ = static_cast<double>(score);
  }
  API_END();
}

907
/*!
 * \brief Report Boosting::GetNumPredictAt(data_idx) through out_len.
 */
LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumPredict(BoosterHandle handle,
                                                int data_idx,
                                                int64_t* out_len) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  *out_len = booster->GetBoosting()->GetNumPredictAt(data_idx);
  API_END();
}

916
/*!
 * \brief Delegate to Booster::GetPredictAt(data_idx, out_result, out_len).
 */
LIGHTGBM_C_EXPORT int LGBM_BoosterGetPredict(BoosterHandle handle,
                                             int data_idx,
                                             int64_t* out_len,
                                             double* out_result) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->GetPredictAt(data_idx, out_result, out_len);
  API_END();
}

926
/*!
 * \brief Predict every row of `data_filename` and write results to `result_filename`.
 * \param data_has_header Any value > 0 is treated as "file has a header".
 */
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForFile(BoosterHandle handle,
                                                 const char* data_filename,
                                                 int data_has_header,
                                                 int predict_type,
                                                 int num_iteration,
                                                 const char* result_filename) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto predictor = booster->NewPredictor(num_iteration, predict_type);
  const bool has_header = data_has_header > 0;
  predictor.Predict(data_filename, result_filename, has_header);
  API_END();
}

940
941
int64_t GetNumPredOneRow(const Booster* ref_booster, int predict_type, int64_t num_iteration) {
  int64_t num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
Guolin Ke's avatar
Guolin Ke committed
942
943
944
945
946
947
948
949
950
951
952
  if (predict_type == C_API_PREDICT_LEAF_INDEX) {
    int64_t max_iteration = ref_booster->GetBoosting()->GetCurrentIteration();
    if (num_iteration > 0) {
      num_preb_in_one_row *= static_cast<int>(std::min(max_iteration, num_iteration));
    } else {
      num_preb_in_one_row *= max_iteration;
    }
  }
  return num_preb_in_one_row;
}

953
/*!
 * \brief Compute the output length for predicting `num_row` rows.
 * \param out_len Receives num_row times the per-row prediction width.
 */
LIGHTGBM_C_EXPORT int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
                                                 int num_row,
                                                 int predict_type,
                                                 int num_iteration,
                                                 int64_t* out_len) {
  API_BEGIN();
  const Booster* booster = reinterpret_cast<Booster*>(handle);
  const int64_t per_row = GetNumPredOneRow(booster, predict_type, num_iteration);
  *out_len = static_cast<int64_t>(num_row) * per_row;
  API_END();
}

964
/*!
 * \brief Predict every row of a CSR matrix into out_result.
 * \param out_result Caller buffer laid out as one fixed-width slot per row
 *        (width from GetNumPredOneRow).
 * \param out_len Receives (nindptr - 1) times the per-row width.
 */
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle,
                                                const void* indptr,
                                                int indptr_type,
                                                const int32_t* indices,
                                                const void* data,
                                                int data_type,
                                                int64_t nindptr,
                                                int64_t nelem,
                                                int64_t /* unused */,
                                                int predict_type,
                                                int num_iteration,
                                                int64_t* out_len,
                                                double* out_result) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto predictor = booster->NewPredictor(num_iteration, predict_type);
  auto fetch_row = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  const int64_t stride = GetNumPredOneRow(booster, predict_type, num_iteration);
  const int num_rows = static_cast<int>(nindptr - 1);
#pragma omp parallel for schedule(static)
  for (int row = 0; row < num_rows; ++row) {
    auto features = fetch_row(row);
    auto preds = predictor.GetPredictFunction()(features);
    // Each row writes only its own slot, so rows can be filled in parallel.
    double* dst = out_result + row * stride;
    for (int k = 0; k < static_cast<int>(preds.size()); ++k) {
      dst[k] = static_cast<double>(preds[k]);
    }
  }
  *out_len = num_rows * stride;
  API_END();
}
994

995
/*!
 * \brief Predict every row of a CSC matrix into out_result.
 *
 * Rows are split into chunks by Threading::For. Each chunk builds one
 * CSC_RowIterator per column and assembles each row as (column, value) pairs,
 * skipping values with |v| <= kEpsilon.
 * \param out_result Caller buffer laid out as one fixed-width slot per row.
 * \param out_len Receives num_row times the per-row width.
 */
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle,
                                                const void* col_ptr,
                                                int col_ptr_type,
                                                const int32_t* indices,
                                                const void* data,
                                                int data_type,
                                                int64_t ncol_ptr,
                                                int64_t nelem,
                                                int64_t num_row,
                                                int predict_type,
                                                int num_iteration,
                                                int64_t* out_len,
                                                double* out_result) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto predictor = booster->NewPredictor(num_iteration, predict_type);
  const int64_t stride = GetNumPredOneRow(booster, predict_type, num_iteration);
  const int ncol = static_cast<int>(ncol_ptr - 1);

  Threading::For<int64_t>(0, num_row,
                          [&predictor, &out_result, stride, ncol, col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem]
  (int, data_size_t start, data_size_t end) {
    // Per-chunk column iterators; Get() is called with increasing row indices.
    std::vector<CSC_RowIterator> col_iters;
    for (int col = 0; col < ncol; ++col) {
      col_iters.emplace_back(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, col);
    }
    // Reused across rows to avoid reallocating for every row.
    std::vector<std::pair<int, double>> sparse_row;
    for (int64_t row = start; row < end; ++row) {
      sparse_row.clear();
      for (int col = 0; col < ncol; ++col) {
        auto value = col_iters[col].Get(static_cast<int>(row));
        if (std::fabs(value) > kEpsilon) {
          sparse_row.emplace_back(col, value);
        }
      }
      auto preds = predictor.GetPredictFunction()(sparse_row);
      double* dst = out_result + row * stride;
      for (int k = 0; k < static_cast<int>(preds.size()); ++k) {
        dst[k] = static_cast<double>(preds[k]);
      }
    }
  });
  *out_len = num_row * stride;
  API_END();
}

1040
/*!
 * \brief Predict every row of a dense matrix into out_result.
 * \param is_row_major Layout flag forwarded to RowPairFunctionFromDenseMatric.
 * \param out_result Caller buffer laid out as one fixed-width slot per row.
 * \param out_len Receives nrow times the per-row width.
 */
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMat(BoosterHandle handle,
                                                const void* data,
                                                int data_type,
                                                int32_t nrow,
                                                int32_t ncol,
                                                int is_row_major,
                                                int predict_type,
                                                int num_iteration,
                                                int64_t* out_len,
                                                double* out_result) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto predictor = booster->NewPredictor(num_iteration, predict_type);
  auto fetch_row = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  const int64_t stride = GetNumPredOneRow(booster, predict_type, num_iteration);
#pragma omp parallel for schedule(static)
  for (int row = 0; row < nrow; ++row) {
    auto features = fetch_row(row);
    auto preds = predictor.GetPredictFunction()(features);
    // Each row writes only its own slot, so rows can be filled in parallel.
    double* dst = out_result + row * stride;
    for (int k = 0; k < static_cast<int>(preds.size()); ++k) {
      dst[k] = static_cast<double>(preds[k]);
    }
  }
  *out_len = nrow * stride;
  API_END();
}
1066

1067
/*!
 * \brief Delegate to Booster::SaveModelToFile(num_iteration, filename).
 */
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle,
                                            int num_iteration,
                                            const char* filename) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->SaveModelToFile(num_iteration, filename);
  API_END();
}

1076
/*!
 * \brief Serialize the model into a caller buffer.
 * \param buffer_len Capacity of out_str in bytes.
 * \param out_len Always receives the required size including the terminating NUL.
 * \param out_str Written only when the required size fits in buffer_len.
 */
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle,
                                                    int num_iteration,
                                                    int buffer_len,
                                                    int* out_len,
                                                    char* out_str) {
  API_BEGIN();
  const std::string model = reinterpret_cast<Booster*>(handle)->SaveModelToString(num_iteration);
  const int required = static_cast<int>(model.size()) + 1;
  *out_len = required;
  // Callers can retry with a larger buffer after reading *out_len.
  if (required <= buffer_len) {
    std::strcpy(out_str, model.c_str());
  }
  API_END();
}

1091
/*!
 * \brief Write Booster::DumpModel output into a caller buffer.
 * \param buffer_len Capacity of out_str in bytes.
 * \param out_len Always receives the required size including the terminating NUL.
 * \param out_str Written only when the required size fits in buffer_len.
 */
LIGHTGBM_C_EXPORT int LGBM_BoosterDumpModel(BoosterHandle handle,
                                            int num_iteration,
                                            int buffer_len,
                                            int* out_len,
                                            char* out_str) {
  API_BEGIN();
  const std::string dump = reinterpret_cast<Booster*>(handle)->DumpModel(num_iteration);
  const int required = static_cast<int>(dump.size()) + 1;
  *out_len = required;
  // Callers can retry with a larger buffer after reading *out_len.
  if (required <= buffer_len) {
    std::strcpy(out_str, dump.c_str());
  }
  API_END();
}
1105

1106
/*!
 * \brief Read the value of leaf `leaf_idx` in tree `tree_idx` into out_val.
 */
LIGHTGBM_C_EXPORT int LGBM_BoosterGetLeafValue(BoosterHandle handle,
                                               int tree_idx,
                                               int leaf_idx,
                                               double* out_val) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const auto leaf_value = booster->GetLeafValue(tree_idx, leaf_idx);
  *out_val = static_cast<double>(leaf_value);
  API_END();
}

1116
/*!
 * \brief Overwrite the value of leaf `leaf_idx` in tree `tree_idx`.
 */
LIGHTGBM_C_EXPORT int LGBM_BoosterSetLeafValue(BoosterHandle handle,
                                               int tree_idx,
                                               int leaf_idx,
                                               double val) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->SetLeafValue(tree_idx, leaf_idx, val);
  API_END();
}

1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175

/*!
 * \brief Allocate a typed array with new[] and hand it out as an ArrayHandle.
 * \param len Number of elements.
 * \param type C_API_DTYPE_FLOAT32, C_API_DTYPE_FLOAT64, C_API_DTYPE_INT32 or C_API_DTYPE_INT64.
 * \param out Receives the allocation. Release it with LGBM_FreeArray using the same type.
 * \throws std::runtime_error for an unsupported type.
 */
LIGHTGBM_C_EXPORT int LGBM_AllocateArray(int64_t len, int type, ArrayHandle* out) {
  API_BEGIN();
  if (type == C_API_DTYPE_FLOAT32) {
    *out = new float[len];
  } else if (type == C_API_DTYPE_FLOAT64) {
    *out = new double[len];
  } else if (type == C_API_DTYPE_INT32) {
    *out = new int32_t[len];
  } else if (type == C_API_DTYPE_INT64) {
    *out = new int64_t[len];
  } else {
    // Reject unknown types instead of leaving *out unset while reporting success.
    throw std::runtime_error("Unknown array data type");
  }
  API_END();
}

template<typename T>
void Copy(T* dst, const T* src, int64_t len) {
  // Element-wise forward copy of `len` items from src to dst.
  // A non-positive len copies nothing.
  if (len <= 0) { return; }
  for (int64_t remaining = len; remaining > 0; --remaining) {
    *dst++ = *src++;
  }
}

/*!
 * \brief Copy `len` elements from `src` into `arr` starting at element `start_idx`.
 * \param type Element type of both arr and src (a C_API_DTYPE_* constant).
 * \throws std::runtime_error for an unsupported type.
 */
LIGHTGBM_C_EXPORT int LGBM_CopyToArray(ArrayHandle arr, int type, int64_t start_idx, const void* src, int64_t len) {
  API_BEGIN();
  if (type == C_API_DTYPE_FLOAT32) {
    Copy<float>(static_cast<float*>(arr) + start_idx, static_cast<const float*>(src), len);
  } else if (type == C_API_DTYPE_FLOAT64) {
    Copy<double>(static_cast<double*>(arr) + start_idx, static_cast<const double*>(src), len);
  } else if (type == C_API_DTYPE_INT32) {
    Copy<int32_t>(static_cast<int32_t*>(arr) + start_idx, static_cast<const int32_t*>(src), len);
  } else if (type == C_API_DTYPE_INT64) {
    Copy<int64_t>(static_cast<int64_t*>(arr) + start_idx, static_cast<const int64_t*>(src), len);
  } else {
    // Reject unknown types instead of silently copying nothing and reporting success.
    throw std::runtime_error("Unknown array data type");
  }
  API_END();
}

/*!
 * \brief Release an array created by LGBM_AllocateArray.
 * \param type Must match the type used at allocation (a C_API_DTYPE_* constant).
 * \throws std::runtime_error for an unsupported type.
 */
LIGHTGBM_C_EXPORT int LGBM_FreeArray(ArrayHandle arr, int type) {
  API_BEGIN();
  if (type == C_API_DTYPE_FLOAT32) {
    delete[] static_cast<float*>(arr);
  } else if (type == C_API_DTYPE_FLOAT64) {
    delete[] static_cast<double*>(arr);
  } else if (type == C_API_DTYPE_INT32) {
    delete[] static_cast<int32_t*>(arr);
  } else if (type == C_API_DTYPE_INT64) {
    delete[] static_cast<int64_t*>(arr);
  } else {
    // Previously an unknown type silently leaked the array and reported success.
    throw std::runtime_error("Unknown array data type");
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1176
// ---- start of some helper functions
1177
1178
1179

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1180
  if (data_type == C_API_DTYPE_FLOAT32) {
1181
1182
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
1183
      return [data_ptr, num_col, num_row] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1184
        std::vector<double> ret(num_col);
1185
1186
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1187
          ret[i] = static_cast<double>(*(tmp_ptr + i));
1188
1189
1190
1191
        }
        return ret;
      };
    } else {
1192
      return [data_ptr, num_col, num_row] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1193
        std::vector<double> ret(num_col);
1194
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1195
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
1196
1197
1198
1199
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1200
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1201
1202
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
1203
      return [data_ptr, num_col, num_row] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1204
        std::vector<double> ret(num_col);
1205
1206
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1207
          ret[i] = static_cast<double>(*(tmp_ptr + i));
1208
1209
1210
1211
        }
        return ret;
      };
    } else {
1212
      return [data_ptr, num_col, num_row] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1213
        std::vector<double> ret(num_col);
1214
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1215
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
1216
1217
1218
1219
1220
        }
        return ret;
      };
    }
  }
Guolin Ke's avatar
Guolin Ke committed
1221
  throw std::runtime_error("unknown data type in RowFunctionFromDenseMatric");
1222
1223
1224
1225
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1226
1227
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
1228
    return [inner_function] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1229
1230
1231
1232
1233
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
        if (std::fabs(raw_values[i]) > 1e-15) {
          ret.emplace_back(i, raw_values[i]);
1234
        }
Guolin Ke's avatar
Guolin Ke committed
1235
1236
1237
      }
      return ret;
    };
1238
  }
Guolin Ke's avatar
Guolin Ke committed
1239
  return nullptr;
1240
1241
1242
1243
}

std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
1244
  if (data_type == C_API_DTYPE_FLOAT32) {
1245
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1246
    if (indptr_type == C_API_DTYPE_INT32) {
1247
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
1248
      return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
1249
1250
1251
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1252
        for (int64_t i = start; i < end; ++i) {
1253
1254
1255
1256
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1257
    } else if (indptr_type == C_API_DTYPE_INT64) {
1258
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
1259
      return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
1260
1261
1262
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1263
        for (int64_t i = start; i < end; ++i) {
1264
1265
1266
1267
1268
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1269
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1270
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1271
    if (indptr_type == C_API_DTYPE_INT32) {
1272
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
1273
      return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
1274
1275
1276
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1277
        for (int64_t i = start; i < end; ++i) {
1278
1279
1280
1281
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1282
    } else if (indptr_type == C_API_DTYPE_INT64) {
1283
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
1284
      return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
1285
1286
1287
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1288
        for (int64_t i = start; i < end; ++i) {
1289
1290
1291
1292
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1293
1294
    }
  }
Guolin Ke's avatar
Guolin Ke committed
1295
  throw std::runtime_error("unknown data type in RowFunctionFromCSR");
1296
1297
}

Guolin Ke's avatar
Guolin Ke committed
1298
1299
1300
std::function<std::pair<int, double>(int idx)>
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx) {
  CHECK(col_idx < ncol_ptr && col_idx >= 0);
Guolin Ke's avatar
Guolin Ke committed
1301
  if (data_type == C_API_DTYPE_FLOAT32) {
1302
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1303
    if (col_ptr_type == C_API_DTYPE_INT32) {
1304
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1305
1306
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1307
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1308
1309
1310
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1311
        }
Guolin Ke's avatar
Guolin Ke committed
1312
1313
1314
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1315
      };
Guolin Ke's avatar
Guolin Ke committed
1316
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1317
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1318
1319
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1320
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1321
1322
1323
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1324
        }
Guolin Ke's avatar
Guolin Ke committed
1325
1326
1327
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1328
      };
Guolin Ke's avatar
Guolin Ke committed
1329
    }
Guolin Ke's avatar
Guolin Ke committed
1330
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1331
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1332
    if (col_ptr_type == C_API_DTYPE_INT32) {
1333
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1334
1335
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1336
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1337
1338
1339
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1340
        }
Guolin Ke's avatar
Guolin Ke committed
1341
1342
1343
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1344
      };
Guolin Ke's avatar
Guolin Ke committed
1345
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1346
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1347
1348
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1349
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1350
1351
1352
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1353
        }
Guolin Ke's avatar
Guolin Ke committed
1354
1355
1356
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1357
      };
Guolin Ke's avatar
Guolin Ke committed
1358
1359
1360
    }
  }
  throw std::runtime_error("unknown data type in CSC matrix");
1361
1362
}

Guolin Ke's avatar
Guolin Ke committed
1363
// Creates an iterator over the stored entries of column `col_idx` of a CSC matrix.
// All type dispatch and bounds checking is delegated to IterateFunctionFromCSC. The
// resulting accessor is kept in iter_fun_ and consumed by Get() and NextNonZero().
CSC_RowIterator::CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
                                 const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx)
  : iter_fun_(IterateFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type,
                                     ncol_ptr, nelem, col_idx)) {
}

// Returns the value at row `idx` of this column, or 0 if no entry is stored there.
// The iterator only moves forward: it consumes stored entries until it reaches a row
// >= idx or runs out. Calls are therefore expected with non-decreasing idx.
// NOTE(review): assumes cur_idx_ starts below any valid row (e.g. -1) — confirm in the header.
double CSC_RowIterator::Get(int idx) {
  while (!is_end_ && cur_idx_ < idx) {
    const auto next = iter_fun_(nonzero_idx_);
    if (next.first < 0) {
      // A negative index is the accessor's end-of-column sentinel.
      is_end_ = true;
      break;
    }
    cur_idx_ = next.first;
    cur_val_ = next.second;
    ++nonzero_idx_;
  }
  return (cur_idx_ == idx) ? cur_val_ : 0.0f;
}

// Returns the next stored (row index, value) pair of this column, or (-1, 0.0) once
// the column is exhausted. This advances nonzero_idx_ but not cur_idx_/cur_val_, so
// its position is not shared with Get().
std::pair<int, double> CSC_RowIterator::NextNonZero() {
  if (is_end_) {
    return std::make_pair(-1, 0.0);
  }
  auto entry = iter_fun_(nonzero_idx_);
  ++nonzero_idx_;
  if (entry.first < 0) {
    // The sentinel is passed through to the caller; later calls short-circuit above.
    is_end_ = true;
  }
  return entry;
}