#include <LightGBM/utils/openmp_wrapper.h>

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/random.h>
#include <LightGBM/utils/threading.h>
#include <LightGBM/c_api.h>
#include <LightGBM/dataset_loader.h>
#include <LightGBM/dataset.h>
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/config.h>

#include <cmath>
#include <cstdio>
#include <cstring>
#include <functional>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "./application/predictor.hpp"
#include "./boosting/gbdt.h"

namespace LightGBM {

class Booster {
public:
Guolin Ke's avatar
Guolin Ke committed
30
31
  explicit Booster(const char* filename) {
    boosting_.reset(Boosting::CreateBoosting(filename));
Guolin Ke's avatar
Guolin Ke committed
32
33
  }

34
35
36
37
  Booster() {
    boosting_.reset(Boosting::CreateBoosting("gbdt", nullptr));
  }

Guolin Ke's avatar
Guolin Ke committed
38
  Booster(const Dataset* train_data,
wxchan's avatar
wxchan committed
39
40
41
    const char* parameters) {
    auto param = ConfigBase::Str2Map(parameters);
    config_.Set(param);
42
43
44
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
45
46
    // create boosting
    if (config_.io_config.input_model.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
47
      Log::Warning("continued train from model is not support for c_api, \
Guolin Ke's avatar
Guolin Ke committed
48
49
        please use continued train with input score");
    }
Guolin Ke's avatar
Guolin Ke committed
50

wxchan's avatar
wxchan committed
51
    boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
Guolin Ke's avatar
Guolin Ke committed
52

Guolin Ke's avatar
Guolin Ke committed
53
    // initialize the boosting
Guolin Ke's avatar
Guolin Ke committed
54
    boosting_->Init(&config_.boosting_config, nullptr, objective_fun_.get(),
Guolin Ke's avatar
Guolin Ke committed
55
      Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
56
57

    ResetTrainingData(train_data);
wxchan's avatar
wxchan committed
58
59
60
61
62
  }

  void MergeFrom(const Booster* other) {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->MergeFrom(other->boosting_.get());
Guolin Ke's avatar
Guolin Ke committed
63
64
65
  }

  ~Booster() {
Guolin Ke's avatar
Guolin Ke committed
66

Guolin Ke's avatar
Guolin Ke committed
67
  }
68

wxchan's avatar
wxchan committed
69
70
71
  void ResetTrainingData(const Dataset* train_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    train_data_ = train_data;
Guolin Ke's avatar
Guolin Ke committed
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
    // create objective function
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
      config_.objective_config));
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective function");
    }
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
    }

    // create training metric
    train_metric_.clear();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(
        Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(train_data_->metadata(), train_data_->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
    // reset the boosting
Guolin Ke's avatar
Guolin Ke committed
94
    boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
wxchan's avatar
wxchan committed
95
96
97
98
      objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
  }

  void ResetConfig(const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
99
    std::lock_guard<std::mutex> lock(mutex_);
wxchan's avatar
wxchan committed
100
101
102
103
104
105
106
    auto param = ConfigBase::Str2Map(parameters);
    if (param.count("num_class")) {
      Log::Fatal("cannot change num class during training");
    }
    if (param.count("boosting_type")) {
      Log::Fatal("cannot change boosting_type during training");
    }
Guolin Ke's avatar
Guolin Ke committed
107
108
109
    if (param.count("metric")) {
      Log::Fatal("cannot change metric during training");
    }
Guolin Ke's avatar
Guolin Ke committed
110
111

    config_.Set(param);
112
113
114
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
115
116
117
118
119
120
121
122
123
124
125
126

    if (param.count("objective")) {
      // create objective function
      objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
        config_.objective_config));
      if (objective_fun_ == nullptr) {
        Log::Warning("Using self-defined objective function");
      }
      // initialize the objective function
      if (objective_fun_ != nullptr) {
        objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
      }
wxchan's avatar
wxchan committed
127
    }
Guolin Ke's avatar
Guolin Ke committed
128
129
130

    boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
      objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
131

wxchan's avatar
wxchan committed
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
  }

  void AddValidData(const Dataset* valid_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    valid_metrics_.emplace_back();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
      Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
  }
Guolin Ke's avatar
Guolin Ke committed
147

148
  bool TrainOneIter() {
wxchan's avatar
wxchan committed
149
    std::lock_guard<std::mutex> lock(mutex_);
150
151
152
153
    return boosting_->TrainOneIter(nullptr, nullptr, false);
  }

  bool TrainOneIter(const float* gradients, const float* hessians) {
wxchan's avatar
wxchan committed
154
    std::lock_guard<std::mutex> lock(mutex_);
155
156
157
    return boosting_->TrainOneIter(gradients, hessians, false);
  }

wxchan's avatar
wxchan committed
158
159
160
161
162
  void RollbackOneIter() {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->RollbackOneIter();
  }

Guolin Ke's avatar
Guolin Ke committed
163
  Predictor NewPredictor(int num_iteration, int predict_type) {
wxchan's avatar
wxchan committed
164
165
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->SetNumIterationForPred(num_iteration);
Guolin Ke's avatar
Guolin Ke committed
166
167
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
168
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
169
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
170
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
171
      is_raw_score = true;
Guolin Ke's avatar
Guolin Ke committed
172
173
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
174
    }
Guolin Ke's avatar
Guolin Ke committed
175
176
177
    // not threading safe now
    // boosting_->SetNumIterationForPred may be set by other thread during prediction. 
    return Predictor(boosting_.get(), is_raw_score, is_predict_leaf);
Guolin Ke's avatar
Guolin Ke committed
178
179
  }

Guolin Ke's avatar
Guolin Ke committed
180
  void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
181
182
183
184
185
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

  void SaveModelToFile(int num_iteration, const char* filename) {
    boosting_->SaveModelToFile(num_iteration, filename);
Guolin Ke's avatar
Guolin Ke committed
186
  }
187

188
189
190
191
192
193
194
195
  void LoadModelFromString(const char* model_str) {
    boosting_->LoadModelFromString(model_str);
  }

  std::string SaveModelToString(int num_iteration) {
    return boosting_->SaveModelToString(num_iteration);
  }

196
197
  std::string DumpModel(int num_iteration) {
    return boosting_->DumpModel(num_iteration);
wxchan's avatar
wxchan committed
198
  }
199

Guolin Ke's avatar
Guolin Ke committed
200
201
202
203
204
205
206
207
208
  double GetLeafValue(int tree_idx, int leaf_idx) const {
    return dynamic_cast<GBDT*>(boosting_.get())->GetLeafValue(tree_idx, leaf_idx);
  }

  void SetLeafValue(int tree_idx, int leaf_idx, double val) {
    std::lock_guard<std::mutex> lock(mutex_);
    dynamic_cast<GBDT*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val);
  }

wxchan's avatar
wxchan committed
209
210
211
212
213
214
215
  int GetEvalCounts() const {
    int ret = 0;
    for (const auto& metric : train_metric_) {
      ret += static_cast<int>(metric->GetName().size());
    }
    return ret;
  }
216

wxchan's avatar
wxchan committed
217
218
219
220
221
222
223
224
225
226
227
  int GetEvalNames(char** out_strs) const {
    int idx = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
        std::strcpy(out_strs[idx], name.c_str());
        ++idx;
      }
    }
    return idx;
  }

wxchan's avatar
wxchan committed
228
229
230
231
232
233
234
235
236
  int GetFeatureNames(char** out_strs) const {
    int idx = 0;
    for (const auto& name : boosting_->FeatureNames()) {
      std::strcpy(out_strs[idx], name.c_str());
      ++idx;
    }
    return idx;
  }

wxchan's avatar
wxchan committed
237
  const Boosting* GetBoosting() const { return boosting_.get(); }
Guolin Ke's avatar
Guolin Ke committed
238

Guolin Ke's avatar
Guolin Ke committed
239
private:
240

wxchan's avatar
wxchan committed
241
  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
242
  std::unique_ptr<Boosting> boosting_;
Guolin Ke's avatar
Guolin Ke committed
243
244
245
  /*! \brief All configs */
  OverallConfig config_;
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
246
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
247
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
248
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
249
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
250
  std::unique_ptr<ObjectiveFunction> objective_fun_;
wxchan's avatar
wxchan committed
251
252
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
253
254
255
};

}
Guolin Ke's avatar
Guolin Ke committed
256
257
258

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
// Forward declarations of the helpers that adapt raw C-API buffers into
// per-row accessors. Their definitions are not in this view (presumably
// later in this file).

/*! \brief Dense matrix -> accessor returning row `row_idx` as dense doubles. */
auto RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major)
  -> std::function<std::vector<double>(int row_idx)>;

/*! \brief Dense matrix -> accessor returning row `row_idx` as (column, value) pairs. */
auto RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major)
  -> std::function<std::vector<std::pair<int, double>>(int row_idx)>;

/*! \brief CSR matrix -> accessor returning row `idx` as (column, value) pairs. */
auto RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
  const void* data, int data_type, int64_t nindptr, int64_t nelem)
  -> std::function<std::vector<std::pair<int, double>>(int idx)>;

/*!
* \brief Walks the rows of one column of a CSC matrix.
*        The constructor, Get and NextNonZero are defined out of line
*        (not visible here).
*/
class CSC_RowIterator {
public:
  CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
    const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx);
  ~CSC_RowIterator() = default;
  /*! \brief Value stored at row idx; rows must be requested in ascending order. */
  double Get(int idx);
  /*! \brief Next (row, value) non-zero entry; a negative row index means the column is exhausted. */
  std::pair<int, double> NextNonZero();

private:
  int nonzero_idx_ = 0;
  int cur_idx_ = -1;
  double cur_val_ = 0.0;
  bool is_end_ = false;
  std::function<std::pair<int, double>(int idx)> iter_fun_;
};

// start of c_api functions

/*!
* \brief Message describing the most recent C-API failure.
* \return The pointer obtained from LastErrorMsg().
*/
LIGHTGBM_C_EXPORT const char* LGBM_GetLastError() {
  const char* message = LastErrorMsg();
  return message;
}

/*!
* \brief Load a dataset from a file.
* \param filename Data file path.
* \param parameters Parameter string parsed into an IOConfig.
* \param reference Optional dataset to align bin mappers with; nullptr for none.
* \param out Receives the new dataset handle.
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromFile(const char* filename,
  const char* parameters,
  const DatasetHandle reference,
  DatasetHandle* out) {
  API_BEGIN();
  auto config_map = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(config_map);
  DatasetLoader loader(io_config, nullptr, 1, filename);
  const Dataset* ref_dataset = reinterpret_cast<const Dataset*>(reference);
  *out = (ref_dataset != nullptr)
    ? loader.LoadFromFileAlignWithOtherDataset(filename, ref_dataset)
    : loader.LoadFromFile(filename);
  API_END();
}

/*!
* \brief Build an (initially empty) dataset's bin mappers from a CSR row sample.
*
* If the sample already contains every row, the full dataset is built directly
* via LGBM_DatasetCreateFromCSR. Otherwise, only the bin mappers are
* constructed for num_total_row rows.
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromSampledCSR(const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t n_sample_elem,
  int64_t num_col,
  int64_t num_total_row,
  const char* parameters,
  DatasetHandle* out) {
  // Guard clause: a "sample" that covers all rows is just the full matrix.
  if (nindptr - 1 == num_total_row) {
    return LGBM_DatasetCreateFromCSR(indptr, indptr_type, indices, data,
      data_type, nindptr, n_sample_elem, num_col, parameters, nullptr, out);
  }
  API_BEGIN();
  auto config_map = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(config_map);
  auto row_at = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, n_sample_elem);
  const int32_t num_sample_row = static_cast<int32_t>(nindptr - 1);
  // Per column: non-zero sampled values and the sample row each came from.
  std::vector<std::vector<double>> sample_values(num_col);
  std::vector<std::vector<int>> sample_idx(num_col);
  for (int row = 0; row < num_sample_row; ++row) {
    for (const auto& entry : row_at(row)) {
      const int col = entry.first;
      if (static_cast<size_t>(col) >= sample_values.size()) {
        sample_values.resize(col + 1);
        sample_idx.resize(col + 1);
      }
      if (std::fabs(entry.second) > kEpsilon) {
        sample_values[col].emplace_back(entry.second);
        sample_idx[col].emplace_back(row);
      }
    }
  }
  CHECK(num_col >= static_cast<int>(sample_values.size()));
  DatasetLoader loader(io_config, nullptr, 1, nullptr);
  *out = loader.CostructFromSampleData(sample_values, sample_idx,
    num_sample_row,
    static_cast<data_size_t>(num_total_row));
  API_END();
}

/*!
* \brief Create an empty dataset of num_total_row rows that reuses reference's bin mappers.
*        Rows are added later with the LGBM_DatasetPushRows* functions.
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateByReference(const DatasetHandle reference,
  int64_t num_total_row,
  DatasetHandle* out) {
  API_BEGIN();
  const data_size_t rows = static_cast<data_size_t>(num_total_row);
  std::unique_ptr<Dataset> dataset(new Dataset(rows));
  dataset->CreateValid(reinterpret_cast<const Dataset*>(reference));
  *out = dataset.release();
  API_END();
}

/*!
* \brief Push a dense row-major block of rows starting at start_row.
*        FinishLoad() is called once the block ends exactly at num_data().
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetPushRows(DatasetHandle dataset,
  const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int32_t start_row) {
  API_BEGIN();
  Dataset* target = reinterpret_cast<Dataset*>(dataset);
  // the pushed block is always interpreted as row-major
  const int row_major = 1;
  auto row_at = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, row_major);
#pragma omp parallel for schedule(static)
  for (int i = 0; i < nrow; ++i) {
    const int thread_id = omp_get_thread_num();
    auto dense_row = row_at(i);
    target->PushOneRow(thread_id, start_row + i, dense_row);
  }
  const bool is_last_block = (start_row + nrow == target->num_data());
  if (is_last_block) {
    target->FinishLoad();
  }
  API_END();
}

/*!
* \brief Push CSR rows starting at start_row; the column-count argument is unused.
*        FinishLoad() is called once the rows end exactly at num_data().
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset,
  const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t,
  int64_t start_row) {
  API_BEGIN();
  Dataset* target = reinterpret_cast<Dataset*>(dataset);
  auto row_at = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  const int32_t nrow = static_cast<int32_t>(nindptr - 1);
#pragma omp parallel for schedule(static)
  for (int i = 0; i < nrow; ++i) {
    const int thread_id = omp_get_thread_num();
    auto sparse_row = row_at(i);
    const data_size_t global_row = static_cast<data_size_t>(start_row + i);
    target->PushOneRow(thread_id, global_row, sparse_row);
  }
  const int64_t rows_after_push = start_row + nrow;
  if (rows_after_push == static_cast<int64_t>(target->num_data())) {
    target->FinishLoad();
  }
  API_END();
}

/*!
* \brief Build a dataset from a dense matrix.
*
* With a reference, the reference's bin mappers are reused. Without one, up to
* bin_construct_sample_cnt rows are sampled to build them. Every row is then
* pushed in parallel and the dataset is finalized.
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromMat(const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  const char* parameters,
  const DatasetHandle reference,
  DatasetHandle* out) {
  API_BEGIN();
  auto config_map = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(config_map);
  std::unique_ptr<Dataset> ret;
  auto row_at = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  if (reference != nullptr) {
    // reuse the bin mappers of the reference dataset
    ret.reset(new Dataset(nrow));
    ret->CreateValid(reinterpret_cast<const Dataset*>(reference));
  } else {
    // sample rows and collect per-column non-zero values to build bin mappers
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(ncol);
    std::vector<std::vector<int>> sample_idx(ncol);
    for (size_t pos = 0; pos < sample_indices.size(); ++pos) {
      const auto dense_row = row_at(static_cast<int>(sample_indices[pos]));
      for (size_t col = 0; col < dense_row.size(); ++col) {
        const double value = dense_row[col];
        if (std::fabs(value) > kEpsilon) {
          sample_values[col].emplace_back(value);
          sample_idx[col].emplace_back(static_cast<int>(pos));
        }
      }
    }
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
    ret.reset(loader.CostructFromSampleData(sample_values, sample_idx, sample_cnt, nrow));
  }

  // push every row; PushOneRow receives the OpenMP thread id
#pragma omp parallel for schedule(static)
  for (int i = 0; i < nrow; ++i) {
    const int thread_id = omp_get_thread_num();
    auto dense_row = row_at(i);
    ret->PushOneRow(thread_id, i, dense_row);
  }
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

/*!
* \brief Build a dataset from a CSR matrix.
*
* With a reference, the reference's bin mappers are reused. Without one, up to
* bin_construct_sample_cnt rows are sampled to build them. Every row is then
* pushed in parallel and the dataset is finalized.
* \param num_col Number of columns of the matrix; column indices must be < num_col.
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSR(const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t num_col,
  const char* parameters,
  const DatasetHandle reference,
  DatasetHandle* out) {
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
  std::unique_ptr<Dataset> ret;
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference == nullptr) {
    // sample data first
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    // Pre-size to num_col (as LGBM_DatasetCreateFromSampledCSR does). The size of
    // sample_values is the column count CostructFromSampleData sees. Starting
    // empty silently dropped trailing columns that had no non-zero value among
    // the sampled rows.
    std::vector<std::vector<double>> sample_values(num_col);
    std::vector<std::vector<int>> sample_idx(num_col);
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      auto row = get_row_fun(static_cast<int>(idx));
      for (std::pair<int, double>& inner_data : row) {
        // grow on out-of-range column indices so the CHECK below reports them
        if (static_cast<size_t>(inner_data.first) >= sample_values.size()) {
          sample_values.resize(inner_data.first + 1);
          sample_idx.resize(inner_data.first + 1);
        }
        if (std::fabs(inner_data.second) > kEpsilon) {
          sample_values[inner_data.first].emplace_back(inner_data.second);
          sample_idx[inner_data.first].emplace_back(static_cast<int>(i));
        }
      }
    }
    CHECK(num_col >= static_cast<int>(sample_values.size()));
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
    ret.reset(loader.CostructFromSampleData(sample_values, sample_idx, sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow));
    ret->CreateValid(
      reinterpret_cast<const Dataset*>(reference));
  }

#pragma omp parallel for schedule(static)
  for (int i = 0; i < nrow; ++i) {
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
  }
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

/*!
* \brief Build a dataset from a CSC matrix.
*
* With a reference, the reference's bin mappers are reused. Without one, up to
* bin_construct_sample_cnt rows are sampled per column to build them. Each
* column's non-zero entries are then pushed in parallel.
* \param num_row Number of rows; entries with a row index >= num_row are not pushed.
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr,
  int col_ptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t ncol_ptr,
  int64_t nelem,
  int64_t num_row,
  const char* parameters,
  const DatasetHandle reference,
  DatasetHandle* out) {
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
  std::unique_ptr<Dataset> ret;
  int32_t nrow = static_cast<int32_t>(num_row);
  if (reference == nullptr) {
    // sample data first
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
    std::vector<std::vector<int>> sample_idx(ncol_ptr - 1);
    // each iteration only writes its own column's vectors
#pragma omp parallel for schedule(static)
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
      CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
      for (int j = 0; j < sample_cnt; j++) {
        // Get() requires ascending row indices; assumes rand.Sample returns them sorted
        auto val = col_it.Get(sample_indices[j]);
        if (std::fabs(val) > kEpsilon) {
          sample_values[i].emplace_back(val);
          sample_idx[i].emplace_back(j);
        }
      }
    }
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
    ret.reset(loader.CostructFromSampleData(sample_values, sample_idx, sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow));
    ret->CreateValid(
      reinterpret_cast<const Dataset*>(reference));
  }

#pragma omp parallel for schedule(static)
  for (int i = 0; i < ncol_ptr - 1; ++i) {
    const int tid = omp_get_thread_num();
    int feature_idx = ret->InnerFeatureIndex(i);
    if (feature_idx < 0) { continue; }
    int group = ret->Feature2Group(feature_idx);
    int sub_feature = ret->Feture2SubFeature(feature_idx);
    CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
    for (;;) {
      auto pair = col_it.NextNonZero();
      const int row_idx = pair.first;
      // Negative index: column exhausted. Index >= nrow: outside the dataset.
      // The previous loop pushed that entry before its `row_idx < nrow`
      // condition stopped it, writing past the dataset's rows.
      if (row_idx < 0 || row_idx >= nrow) { break; }
      ret->PushOneData(tid, row_idx, group, sub_feature, pair.second);
    }
  }
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

/*!
* \brief Create a dataset holding only the selected rows of handle.
* \param used_row_indices Row indices to copy (num_used_row_indices of them).
* \param parameters Parsed into an IOConfig, which is not otherwise used here.
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetGetSubset(
  const DatasetHandle handle,
  const int32_t* used_row_indices,
  int32_t num_used_row_indices,
  const char* parameters,
  DatasetHandle* out) {
  API_BEGIN();
  auto config_map = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(config_map);
  const Dataset* source = reinterpret_cast<const Dataset*>(handle);
  std::unique_ptr<Dataset> subset(new Dataset(num_used_row_indices));
  subset->CopyFeatureMapperFrom(source);
  subset->CopySubset(source, used_row_indices, num_used_row_indices, true);
  *out = subset.release();
  API_END();
}

/*!
* \brief Replace the dataset's feature names with the first num_feature_names C strings.
*        A non-positive count sets an empty name list.
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetSetFeatureNames(
  DatasetHandle handle,
  const char** feature_names,
  int num_feature_names) {
  API_BEGIN();
  Dataset* dataset = reinterpret_cast<Dataset*>(handle);
  const int count = num_feature_names > 0 ? num_feature_names : 0;
  std::vector<std::string> names(feature_names, feature_names + count);
  dataset->set_feature_names(names);
  API_END();
}

/*!
* \brief Copy the dataset's feature names into caller-provided buffers.
* \param feature_names Caller buffers, one per feature. NOTE: strcpy does no
*        bounds checking, so each buffer must hold its name plus terminator.
* \param num_feature_names Receives the number of names copied.
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetGetFeatureNames(
  DatasetHandle handle,
  char** feature_names,
  int* num_feature_names) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  // bind by const reference: avoids copying the whole name vector
  // (still correct if feature_names() returns by value)
  const auto& inside_feature_name = dataset->feature_names();
  *num_feature_names = static_cast<int>(inside_feature_name.size());
  for (int i = 0; i < *num_feature_names; ++i) {
    std::strcpy(feature_names[i], inside_feature_name[i].c_str());
  }
  API_END();
}

/*! \brief Destroy a dataset created by one of the LGBM_DatasetCreate* functions. */
LIGHTGBM_C_EXPORT int LGBM_DatasetFree(DatasetHandle handle) {
  API_BEGIN();
  Dataset* dataset = reinterpret_cast<Dataset*>(handle);
  delete dataset;
  API_END();
}

/*! \brief Write the dataset to filename via Dataset::SaveBinaryFile. */
LIGHTGBM_C_EXPORT int LGBM_DatasetSaveBinary(DatasetHandle handle,
  const char* filename) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->SaveBinaryFile(filename);
  API_END();
}

/*!
* \brief Set a named field of the dataset from a typed buffer.
* \param type C_API_DTYPE_FLOAT32, C_API_DTYPE_INT32 or C_API_DTYPE_FLOAT64.
* \throws std::runtime_error (inside API_BEGIN/API_END) if the type is
*         unsupported or the dataset rejects the field.
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle,
  const char* field_name,
  const void* field_data,
  int num_element,
  int type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  const int32_t len = static_cast<int32_t>(num_element);
  bool is_success = false;
  if (type == C_API_DTYPE_FLOAT32) {
    is_success = dataset->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), len);
  } else if (type == C_API_DTYPE_INT32) {
    is_success = dataset->SetIntField(field_name, reinterpret_cast<const int*>(field_data), len);
  } else if (type == C_API_DTYPE_FLOAT64) {
    is_success = dataset->SetDoubleField(field_name, reinterpret_cast<const double*>(field_data), len);
  }
  // message typo fixed ("erorr" -> "error")
  if (!is_success) { throw std::runtime_error("Input data type error or field not found"); }
  API_END();
}

/*!
* \brief Look up a named field, trying float, then int, then double storage.
* \param out_len Receives the element count (forced to 0 when *out_ptr is null).
* \param out_ptr Receives a pointer to the field data.
* \param out_type Receives the matching C_API_DTYPE_* code.
* \throws std::runtime_error "Field not found" when no lookup succeeds.
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetGetField(DatasetHandle handle,
  const char* field_name,
  int* out_len,
  const void** out_ptr,
  int* out_type) {
  API_BEGIN();
  Dataset* dataset = reinterpret_cast<Dataset*>(handle);
  if (dataset->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (dataset->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else if (dataset->GetDoubleField(field_name, out_len, reinterpret_cast<const double**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT64;
  } else {
    throw std::runtime_error("Field not found");
  }
  if (*out_ptr == nullptr) {
    *out_len = 0;
  }
  API_END();
}

/*! \brief Report the dataset's row count (Dataset::num_data). */
LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumData(DatasetHandle handle,
  int* out) {
  API_BEGIN();
  *out = reinterpret_cast<Dataset*>(handle)->num_data();
  API_END();
}

/*! \brief Report the dataset's total feature count (Dataset::num_total_features). */
LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumFeature(DatasetHandle handle,
  int* out) {
  API_BEGIN();
  Dataset* dataset = reinterpret_cast<Dataset*>(handle);
  const auto feature_count = dataset->num_total_features();
  *out = feature_count;
  API_END();
}

// ---- start of booster

LIGHTGBM_C_EXPORT int LGBM_BoosterCreate(const DatasetHandle train_data,
  const char* parameters,
  BoosterHandle* out) {
  API_BEGIN();
  // Build the Booster under unique_ptr ownership so a throwing constructor
  // leaks nothing; ownership passes to the caller only on success.
  const Dataset* train_set = reinterpret_cast<const Dataset*>(train_data);
  std::unique_ptr<Booster> booster(new Booster(train_set, parameters));
  *out = booster.release();
  API_END();
}

721
LIGHTGBM_C_EXPORT int LGBM_BoosterCreateFromModelfile(
  const char* filename,
  int* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
  // Load a Booster from `filename`; report its iteration count, then hand the
  // handle to the caller.
  std::unique_ptr<Booster> booster(new Booster(filename));
  *out_num_iterations = booster->GetBoosting()->GetCurrentIteration();
  *out = booster.release();
  API_END();
}

732
733
734
735
736
737
738
739
740
741
742
743
LIGHTGBM_C_EXPORT int LGBM_BoosterLoadModelFromString(
  const char* model_str,
  int* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
  // Start from a default Booster, load the serialized model into it, and
  // transfer ownership to the caller only after loading succeeded.
  std::unique_ptr<Booster> booster(new Booster());
  booster->LoadModelFromString(model_str);
  *out_num_iterations = booster->GetBoosting()->GetCurrentIteration();
  *out = booster.release();
  API_END();
}

744
LIGHTGBM_C_EXPORT int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  // Destroy the Booster previously handed out through a BoosterHandle.
  Booster* booster = reinterpret_cast<Booster*>(handle);
  delete booster;
  API_END();
}

750
LIGHTGBM_C_EXPORT int LGBM_BoosterMerge(BoosterHandle handle,
  BoosterHandle other_handle) {
  API_BEGIN();
  // Merge the booster behind `other_handle` into the one behind `handle`.
  Booster* target = reinterpret_cast<Booster*>(handle);
  target->MergeFrom(reinterpret_cast<Booster*>(other_handle));
  API_END();
}

759
LIGHTGBM_C_EXPORT int LGBM_BoosterAddValidData(BoosterHandle handle,
  const DatasetHandle valid_data) {
  API_BEGIN();
  // Register an additional validation dataset with the booster.
  reinterpret_cast<Booster*>(handle)->AddValidData(
    reinterpret_cast<const Dataset*>(valid_data));
  API_END();
}

768
LIGHTGBM_C_EXPORT int LGBM_BoosterResetTrainingData(BoosterHandle handle,
  const DatasetHandle train_data) {
  API_BEGIN();
  // Point the booster at a different training dataset.
  reinterpret_cast<Booster*>(handle)->ResetTrainingData(
    reinterpret_cast<const Dataset*>(train_data));
  API_END();
}

777
LIGHTGBM_C_EXPORT int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  // Forward the new parameter string to Booster::ResetConfig.
  reinterpret_cast<Booster*>(handle)->ResetConfig(parameters);
  API_END();
}

784
LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  // Report the number of classes of the underlying boosting model.
  const Booster* booster = reinterpret_cast<Booster*>(handle);
  *out_len = booster->GetBoosting()->NumberOfClasses();
  API_END();
}

791
LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  // Run one boosting iteration; *is_finished is 1 when TrainOneIter reports
  // true and 0 otherwise.
  Booster* booster = reinterpret_cast<Booster*>(handle);
  *is_finished = booster->TrainOneIter() ? 1 : 0;
  API_END();
}

802
LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
  const float* grad,
  const float* hess,
  int* is_finished) {
  API_BEGIN();
  // Run one boosting iteration using caller-supplied gradients and hessians;
  // *is_finished is 1 when TrainOneIter reports true and 0 otherwise.
  Booster* booster = reinterpret_cast<Booster*>(handle);
  *is_finished = booster->TrainOneIter(grad, hess) ? 1 : 0;
  API_END();
}

816
LIGHTGBM_C_EXPORT int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  // Undo the most recent boosting iteration.
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

823
LIGHTGBM_C_EXPORT int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int* out_iteration) {
  API_BEGIN();
  // Report the boosting model's current iteration count.
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_iteration = boosting->GetCurrentIteration();
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
829

830
LIGHTGBM_C_EXPORT int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  // Report how many evaluation results the booster produces.
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalCounts();
  API_END();
}

837
LIGHTGBM_C_EXPORT int LGBM_BoosterGetEvalNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  // Fill caller-provided `out_strs` with evaluation names; *out_len receives
  // the count returned by Booster::GetEvalNames.
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalNames(out_strs);
  API_END();
}

wxchan's avatar
wxchan committed
844
845
846
847
848
849
850
851
852
853
854
855
856
857
LIGHTGBM_C_EXPORT int LGBM_BoosterGetFeatureNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  // Fill caller-provided `out_strs` with feature names; *out_len receives
  // the count returned by Booster::GetFeatureNames.
  *out_len = reinterpret_cast<Booster*>(handle)->GetFeatureNames(out_strs);
  API_END();
}

LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumFeature(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  // The feature count is derived as the model's largest feature index + 1.
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->MaxFeatureIdx() + 1;
  API_END();
}

858
LIGHTGBM_C_EXPORT int LGBM_BoosterGetEval(BoosterHandle handle,
  int data_idx,
  int* out_len,
  double* out_results) {
  API_BEGIN();
  // Copy the evaluation results for data set `data_idx` into `out_results`.
  // No bounds check is made here: the caller's buffer must already hold
  // enough slots (see LGBM_BoosterGetEvalCounts).
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  auto result_buf = boosting->GetEvalAt(data_idx);
  *out_len = static_cast<int>(result_buf.size());
  double* dst = out_results;
  for (const auto& value : result_buf) {
    *dst++ = static_cast<double>(value);
  }
  API_END();
}

873
LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumPredict(BoosterHandle handle,
  int data_idx,
  int64_t* out_len) {
  API_BEGIN();
  // Report how many stored predictions exist for data set `data_idx`.
  const Booster* booster = reinterpret_cast<Booster*>(handle);
  *out_len = booster->GetBoosting()->GetNumPredictAt(data_idx);
  API_END();
}

882
LIGHTGBM_C_EXPORT int LGBM_BoosterGetPredict(BoosterHandle handle,
  int data_idx,
  int64_t* out_len,
  double* out_result) {
  API_BEGIN();
  // Delegate to Booster::GetPredictAt, which fills `out_result` and
  // `out_len` for data set `data_idx`.
  reinterpret_cast<Booster*>(handle)->GetPredictAt(data_idx, out_result, out_len);
  API_END();
}

892
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForFile(BoosterHandle handle,
  const char* data_filename,
  int data_has_header,
  int predict_type,
  int num_iteration,
  const char* result_filename) {
  API_BEGIN();
  // Predict every row of `data_filename` and write the results to
  // `result_filename`. Any positive `data_has_header` means a header row.
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto predictor = booster->NewPredictor(num_iteration, predict_type);
  const bool has_header = data_has_header > 0;
  predictor.Predict(data_filename, result_filename, has_header);
  API_END();
}

906
907
int64_t GetNumPredOneRow(const Booster* ref_booster, int predict_type, int64_t num_iteration) {
  int64_t num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
Guolin Ke's avatar
Guolin Ke committed
908
909
910
911
912
913
914
915
916
917
918
  if (predict_type == C_API_PREDICT_LEAF_INDEX) {
    int64_t max_iteration = ref_booster->GetBoosting()->GetCurrentIteration();
    if (num_iteration > 0) {
      num_preb_in_one_row *= static_cast<int>(std::min(max_iteration, num_iteration));
    } else {
      num_preb_in_one_row *= max_iteration;
    }
  }
  return num_preb_in_one_row;
}

919
LIGHTGBM_C_EXPORT int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
  int num_row,
  int predict_type,
  int num_iteration,
  int64_t* out_len) {
  API_BEGIN();
  // Output buffer size = rows * values per row (computed in 64-bit).
  const int64_t per_row = GetNumPredOneRow(reinterpret_cast<Booster*>(handle), predict_type, num_iteration);
  *out_len = per_row * num_row;
  API_END();
}

930
// Predict every row of a CSR matrix into `out_result`
// (row-major, num_preb_in_one_row values per row).
// The unnamed int64_t parameter (column count) is unused here.
// Errors raised while predicting are reported through API_END instead of
// aborting the process.
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle,
  const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t,
  int predict_type,
  int num_iteration,
  int64_t* out_len,
  double* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  auto predictor = ref_booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  int64_t num_preb_in_one_row = GetNumPredOneRow(ref_booster, predict_type, num_iteration);
  int nrow = static_cast<int>(nindptr - 1);
  // An exception escaping an OpenMP parallel region terminates the program
  // instead of reaching API_END. So each iteration catches locally, the first
  // message is recorded, and it is re-thrown once the loop has finished.
  std::mutex error_mutex;
  bool has_error = false;
  std::string error_msg;
  auto record_error = [&error_mutex, &has_error, &error_msg](const char* msg) {
    std::lock_guard<std::mutex> lock(error_mutex);
    if (!has_error) {
      has_error = true;
      error_msg = msg;
    }
  };
#pragma omp parallel for schedule(static)
  for (int i = 0; i < nrow; ++i) {
    try {
      auto one_row = get_row_fun(i);
      auto predicton_result = predictor.GetPredictFunction()(one_row);
      for (int j = 0; j < static_cast<int>(predicton_result.size()); ++j) {
        out_result[i * num_preb_in_one_row + j] = static_cast<double>(predicton_result[j]);
      }
    } catch (const std::exception& ex) {
      record_error(ex.what());
    } catch (const std::string& ex) {
      record_error(ex.c_str());
    } catch (...) {
      record_error("unknown exception");
    }
  }
  if (has_error) { throw std::runtime_error(error_msg); }
  *out_len = nrow * num_preb_in_one_row;
  API_END();
}
960

961
// Predict every row of a CSC matrix (num_row rows, ncol_ptr - 1 columns) into
// `out_result` (row-major, num_preb_in_one_row values per row).
// Rows are densified from per-column iterators. Values with
// |v| <= kEpsilon are treated as absent.
// Errors raised while predicting are reported through API_END instead of
// aborting the process.
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle,
  const void* col_ptr,
  int col_ptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t ncol_ptr,
  int64_t nelem,
  int64_t num_row,
  int predict_type,
  int num_iteration,
  int64_t* out_len,
  double* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  auto predictor = ref_booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
  int64_t num_preb_in_one_row = GetNumPredOneRow(ref_booster, predict_type, num_iteration);
  int ncol = static_cast<int>(ncol_ptr - 1);
  // Threading::For dispatches chunks to worker threads (presumably an OpenMP
  // region — TODO confirm). Exceptions must not escape those workers, so each
  // chunk catches locally. The first message is re-thrown after all chunks
  // finish, where API_END can report it.
  std::mutex error_mutex;
  bool has_error = false;
  std::string error_msg;
  auto record_error = [&error_mutex, &has_error, &error_msg](const char* msg) {
    std::lock_guard<std::mutex> lock(error_mutex);
    if (!has_error) {
      has_error = true;
      error_msg = msg;
    }
  };
  // Chunk bounds are int64_t to match Threading::For<int64_t> rather than
  // being narrowed to data_size_t at the call boundary.
  Threading::For<int64_t>(0, num_row,
    [&](int, int64_t start, int64_t end) {
    try {
      // Each chunk needs its own iterators: CSC_RowIterator::Get only moves
      // forward, starting here from this chunk's first row.
      std::vector<CSC_RowIterator> iterators;
      for (int j = 0; j < ncol; ++j) {
        iterators.emplace_back(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, j);
      }
      std::vector<std::pair<int, double>> one_row;
      for (int64_t i = start; i < end; ++i) {
        one_row.clear();
        for (int j = 0; j < ncol; ++j) {
          auto val = iterators[j].Get(static_cast<int>(i));
          if (std::fabs(val) > kEpsilon) {
            one_row.emplace_back(j, val);
          }
        }
        auto predicton_result = predictor.GetPredictFunction()(one_row);
        for (int j = 0; j < static_cast<int>(predicton_result.size()); ++j) {
          out_result[i * num_preb_in_one_row + j] = static_cast<double>(predicton_result[j]);
        }
      }
    } catch (const std::exception& ex) {
      record_error(ex.what());
    } catch (const std::string& ex) {
      record_error(ex.c_str());
    } catch (...) {
      record_error("unknown exception");
    }
  });
  if (has_error) { throw std::runtime_error(error_msg); }
  *out_len = num_row * num_preb_in_one_row;
  API_END();
}

1006
// Predict every row of a dense nrow x ncol matrix into `out_result`
// (row-major, num_preb_in_one_row values per row).
// `is_row_major` selects the input layout.
// Errors raised while predicting are reported through API_END instead of
// aborting the process.
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMat(BoosterHandle handle,
  const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  int predict_type,
  int num_iteration,
  int64_t* out_len,
  double* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  auto predictor = ref_booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
  auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  int64_t num_preb_in_one_row = GetNumPredOneRow(ref_booster, predict_type, num_iteration);
  // An exception escaping an OpenMP parallel region terminates the program
  // instead of reaching API_END. So each iteration catches locally, the first
  // message is recorded, and it is re-thrown once the loop has finished.
  std::mutex error_mutex;
  bool has_error = false;
  std::string error_msg;
  auto record_error = [&error_mutex, &has_error, &error_msg](const char* msg) {
    std::lock_guard<std::mutex> lock(error_mutex);
    if (!has_error) {
      has_error = true;
      error_msg = msg;
    }
  };
#pragma omp parallel for schedule(static)
  for (int i = 0; i < nrow; ++i) {
    try {
      auto one_row = get_row_fun(i);
      auto predicton_result = predictor.GetPredictFunction()(one_row);
      for (int j = 0; j < static_cast<int>(predicton_result.size()); ++j) {
        out_result[i * num_preb_in_one_row + j] = static_cast<double>(predicton_result[j]);
      }
    } catch (const std::exception& ex) {
      record_error(ex.what());
    } catch (const std::string& ex) {
      record_error(ex.c_str());
    } catch (...) {
      record_error("unknown exception");
    }
  }
  if (has_error) { throw std::runtime_error(error_msg); }
  *out_len = nrow * num_preb_in_one_row;
  API_END();
}
1032

1033
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle,
  int num_iteration,
  const char* filename) {
  API_BEGIN();
  // Write the model (up to `num_iteration`) to `filename`.
  reinterpret_cast<Booster*>(handle)->SaveModelToFile(num_iteration, filename);
  API_END();
}

1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle,
  int num_iteration,
  int buffer_len,
  int* out_len,
  char* out_str) {
  API_BEGIN();
  // Serialize the model. *out_len always receives the size needed including
  // the terminating NUL. The text is copied into `out_str` only when
  // `buffer_len` is at least that large.
  const std::string model = reinterpret_cast<Booster*>(handle)->SaveModelToString(num_iteration);
  const int required_len = static_cast<int>(model.size()) + 1;
  *out_len = required_len;
  if (required_len <= buffer_len) {
    std::strcpy(out_str, model.c_str());
  }
  API_END();
}

1057
LIGHTGBM_C_EXPORT int LGBM_BoosterDumpModel(BoosterHandle handle,
  int num_iteration,
  int buffer_len,
  int* out_len,
  char* out_str) {
  API_BEGIN();
  // Dump the model as text. *out_len always receives the size needed
  // including the terminating NUL. The text is copied into `out_str` only
  // when `buffer_len` is at least that large.
  const std::string model = reinterpret_cast<Booster*>(handle)->DumpModel(num_iteration);
  const int required_len = static_cast<int>(model.size()) + 1;
  *out_len = required_len;
  if (required_len <= buffer_len) {
    std::strcpy(out_str, model.c_str());
  }
  API_END();
}
1071

1072
LIGHTGBM_C_EXPORT int LGBM_BoosterGetLeafValue(BoosterHandle handle,
  int tree_idx,
  int leaf_idx,
  double* out_val) {
  API_BEGIN();
  // Read the output value of leaf `leaf_idx` in tree `tree_idx`.
  const auto leaf_value = reinterpret_cast<Booster*>(handle)->GetLeafValue(tree_idx, leaf_idx);
  *out_val = static_cast<double>(leaf_value);
  API_END();
}

1082
LIGHTGBM_C_EXPORT int LGBM_BoosterSetLeafValue(BoosterHandle handle,
  int tree_idx,
  int leaf_idx,
  double val) {
  API_BEGIN();
  // Overwrite the output value of leaf `leaf_idx` in tree `tree_idx`.
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->SetLeafValue(tree_idx, leaf_idx, val);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1092
// ---- start of helper functions
1093
1094
1095

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1096
  if (data_type == C_API_DTYPE_FLOAT32) {
1097
1098
1099
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1100
        std::vector<double> ret(num_col);
1101
1102
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1103
          ret[i] = static_cast<double>(*(tmp_ptr + i));
1104
1105
1106
1107
1108
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1109
        std::vector<double> ret(num_col);
1110
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1111
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
1112
1113
1114
1115
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1116
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1117
1118
1119
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1120
        std::vector<double> ret(num_col);
1121
1122
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1123
          ret[i] = static_cast<double>(*(tmp_ptr + i));
1124
1125
1126
1127
1128
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1129
        std::vector<double> ret(num_col);
1130
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1131
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
1132
1133
1134
1135
1136
        }
        return ret;
      };
    }
  }
Guolin Ke's avatar
Guolin Ke committed
1137
  throw std::runtime_error("unknown data type in RowFunctionFromDenseMatric");
1138
1139
1140
1141
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1142
1143
1144
1145
1146
1147
1148
1149
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
    return [inner_function](int row_idx) {
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
        if (std::fabs(raw_values[i]) > 1e-15) {
          ret.emplace_back(i, raw_values[i]);
1150
        }
Guolin Ke's avatar
Guolin Ke committed
1151
1152
1153
      }
      return ret;
    };
1154
  }
Guolin Ke's avatar
Guolin Ke committed
1155
  return nullptr;
1156
1157
1158
1159
}

std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
1160
  if (data_type == C_API_DTYPE_FLOAT32) {
1161
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1162
    if (indptr_type == C_API_DTYPE_INT32) {
1163
1164
1165
1166
1167
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1168
        for (int64_t i = start; i < end; ++i) {
1169
1170
1171
1172
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1173
    } else if (indptr_type == C_API_DTYPE_INT64) {
1174
1175
1176
1177
1178
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1179
        for (int64_t i = start; i < end; ++i) {
1180
1181
1182
1183
1184
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1185
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1186
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1187
    if (indptr_type == C_API_DTYPE_INT32) {
1188
1189
1190
1191
1192
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1193
        for (int64_t i = start; i < end; ++i) {
1194
1195
1196
1197
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1198
    } else if (indptr_type == C_API_DTYPE_INT64) {
1199
1200
1201
1202
1203
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1204
        for (int64_t i = start; i < end; ++i) {
1205
1206
1207
1208
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1209
1210
    }
  }
Guolin Ke's avatar
Guolin Ke committed
1211
  throw std::runtime_error("unknown data type in RowFunctionFromCSR");
1212
1213
}

Guolin Ke's avatar
Guolin Ke committed
1214
1215
1216
std::function<std::pair<int, double>(int idx)>
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx) {
  CHECK(col_idx < ncol_ptr && col_idx >= 0);
Guolin Ke's avatar
Guolin Ke committed
1217
  if (data_type == C_API_DTYPE_FLOAT32) {
1218
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1219
    if (col_ptr_type == C_API_DTYPE_INT32) {
1220
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1221
1222
1223
1224
1225
1226
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1227
        }
Guolin Ke's avatar
Guolin Ke committed
1228
1229
1230
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1231
      };
Guolin Ke's avatar
Guolin Ke committed
1232
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1233
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1234
1235
1236
1237
1238
1239
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1240
        }
Guolin Ke's avatar
Guolin Ke committed
1241
1242
1243
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1244
      };
Guolin Ke's avatar
Guolin Ke committed
1245
    }
Guolin Ke's avatar
Guolin Ke committed
1246
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1247
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1248
    if (col_ptr_type == C_API_DTYPE_INT32) {
1249
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1250
1251
1252
1253
1254
1255
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1256
        }
Guolin Ke's avatar
Guolin Ke committed
1257
1258
1259
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1260
      };
Guolin Ke's avatar
Guolin Ke committed
1261
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1262
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1263
1264
1265
1266
1267
1268
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1269
        }
Guolin Ke's avatar
Guolin Ke committed
1270
1271
1272
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1273
      };
Guolin Ke's avatar
Guolin Ke committed
1274
1275
1276
    }
  }
  throw std::runtime_error("unknown data type in CSC matrix");
1277
1278
}

Guolin Ke's avatar
Guolin Ke committed
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
// Binds this iterator to column `col_idx` of the given CSC matrix through an
// accessor from IterateFunctionFromCSC.
CSC_RowIterator::CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
  const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx)
  : iter_fun_(IterateFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type,
                                     ncol_ptr, nelem, col_idx)) {
}

// Value of this column at row `idx`, or 0 when no entry is stored there.
// Only advances forward, so successive calls should use non-decreasing `idx`.
double CSC_RowIterator::Get(int idx) {
  // Consume stored entries until the current row reaches `idx` or the column
  // is exhausted.
  while (!is_end_ && cur_idx_ < idx) {
    const std::pair<int, double> entry = iter_fun_(nonzero_idx_);
    if (entry.first < 0) {
      is_end_ = true;
      break;
    }
    cur_idx_ = entry.first;
    cur_val_ = entry.second;
    ++nonzero_idx_;
  }
  if (idx != cur_idx_) {
    return 0.0f;
  }
  return cur_val_;
}

// Return the next stored (row index, value) pair of this column.
// After the last entry it returns (-1, 0.0) and marks the iterator finished.
std::pair<int, double> CSC_RowIterator::NextNonZero() {
  if (is_end_) {
    return std::make_pair(-1, 0.0);
  }
  std::pair<int, double> entry = iter_fun_(nonzero_idx_);
  ++nonzero_idx_;
  if (entry.first < 0) {
    is_end_ = true;
  }
  return entry;
}