"python-package/vscode:/vscode.git/clone" did not exist on "0e576575852fa543bea00056a2801b247edc8283"
c_api.cpp 50.3 KB
Newer Older
1
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
2
3
4

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/utils/threading.h>
Guolin Ke's avatar
Guolin Ke committed
6
#include <LightGBM/c_api.h>
Guolin Ke's avatar
Guolin Ke committed
7
#include <LightGBM/dataset_loader.h>
Guolin Ke's avatar
Guolin Ke committed
8
9
10
11
12
13
14
15
16
17
#include <LightGBM/dataset.h>
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/config.h>

#include <cmath>
#include <cstdio>
#include <cstring>
#include <exception>
#include <functional>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
Guolin Ke's avatar
Guolin Ke committed
22

Guolin Ke's avatar
Guolin Ke committed
23
#include "./application/predictor.hpp"
Guolin Ke's avatar
Guolin Ke committed
24
#include "./boosting/gbdt.h"
Guolin Ke's avatar
Guolin Ke committed
25

Guolin Ke's avatar
Guolin Ke committed
26
27
28
29
namespace LightGBM {

class Booster {
public:
Guolin Ke's avatar
Guolin Ke committed
30
31
  explicit Booster(const char* filename) {
    boosting_.reset(Boosting::CreateBoosting(filename));
Guolin Ke's avatar
Guolin Ke committed
32
33
  }

34
35
36
37
  Booster() {
    boosting_.reset(Boosting::CreateBoosting("gbdt", nullptr));
  }

Guolin Ke's avatar
Guolin Ke committed
38
  Booster(const Dataset* train_data,
39
          const char* parameters) {
wxchan's avatar
wxchan committed
40
41
    auto param = ConfigBase::Str2Map(parameters);
    config_.Set(param);
42
43
44
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
45
46
    // create boosting
    if (config_.io_config.input_model.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
47
      Log::Warning("continued train from model is not support for c_api, \
Guolin Ke's avatar
Guolin Ke committed
48
49
        please use continued train with input score");
    }
Guolin Ke's avatar
Guolin Ke committed
50

wxchan's avatar
wxchan committed
51
    boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
Guolin Ke's avatar
Guolin Ke committed
52

Guolin Ke's avatar
Guolin Ke committed
53
    // initialize the boosting
Guolin Ke's avatar
Guolin Ke committed
54
    boosting_->Init(&config_.boosting_config, nullptr, objective_fun_.get(),
55
                    Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
56
57

    ResetTrainingData(train_data);
wxchan's avatar
wxchan committed
58
59
60
61
62
  }

  void MergeFrom(const Booster* other) {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->MergeFrom(other->boosting_.get());
Guolin Ke's avatar
Guolin Ke committed
63
64
65
  }

  ~Booster() {
Guolin Ke's avatar
Guolin Ke committed
66

Guolin Ke's avatar
Guolin Ke committed
67
  }
68

wxchan's avatar
wxchan committed
69
70
71
  void ResetTrainingData(const Dataset* train_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    train_data_ = train_data;
Guolin Ke's avatar
Guolin Ke committed
72
73
    // create objective function
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
74
                                                                    config_.objective_config));
Guolin Ke's avatar
Guolin Ke committed
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective function");
    }
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
    }

    // create training metric
    train_metric_.clear();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(
        Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(train_data_->metadata(), train_data_->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
    // reset the boosting
Guolin Ke's avatar
Guolin Ke committed
94
    boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
95
                                 objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
wxchan's avatar
wxchan committed
96
97
98
  }

  void ResetConfig(const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
99
    std::lock_guard<std::mutex> lock(mutex_);
wxchan's avatar
wxchan committed
100
101
102
103
104
105
106
    auto param = ConfigBase::Str2Map(parameters);
    if (param.count("num_class")) {
      Log::Fatal("cannot change num class during training");
    }
    if (param.count("boosting_type")) {
      Log::Fatal("cannot change boosting_type during training");
    }
Guolin Ke's avatar
Guolin Ke committed
107
108
109
    if (param.count("metric")) {
      Log::Fatal("cannot change metric during training");
    }
Guolin Ke's avatar
Guolin Ke committed
110
111

    config_.Set(param);
112
113
114
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
115
116
117
118

    if (param.count("objective")) {
      // create objective function
      objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
119
                                                                      config_.objective_config));
Guolin Ke's avatar
Guolin Ke committed
120
121
122
123
124
125
126
      if (objective_fun_ == nullptr) {
        Log::Warning("Using self-defined objective function");
      }
      // initialize the objective function
      if (objective_fun_ != nullptr) {
        objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
      }
wxchan's avatar
wxchan committed
127
    }
Guolin Ke's avatar
Guolin Ke committed
128
129

    boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
130
                                 objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
131

wxchan's avatar
wxchan committed
132
133
134
135
136
137
138
139
140
141
142
143
144
  }

  void AddValidData(const Dataset* valid_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    valid_metrics_.emplace_back();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
145
                               Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
wxchan's avatar
wxchan committed
146
  }
Guolin Ke's avatar
Guolin Ke committed
147

148
  bool TrainOneIter() {
wxchan's avatar
wxchan committed
149
    std::lock_guard<std::mutex> lock(mutex_);
150
151
152
153
    return boosting_->TrainOneIter(nullptr, nullptr, false);
  }

  bool TrainOneIter(const float* gradients, const float* hessians) {
wxchan's avatar
wxchan committed
154
    std::lock_guard<std::mutex> lock(mutex_);
155
156
157
    return boosting_->TrainOneIter(gradients, hessians, false);
  }

wxchan's avatar
wxchan committed
158
159
160
161
162
  void RollbackOneIter() {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->RollbackOneIter();
  }

Guolin Ke's avatar
Guolin Ke committed
163
  Predictor NewPredictor(int num_iteration, int predict_type) {
wxchan's avatar
wxchan committed
164
165
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->SetNumIterationForPred(num_iteration);
Guolin Ke's avatar
Guolin Ke committed
166
167
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
168
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
169
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
170
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
171
      is_raw_score = true;
Guolin Ke's avatar
Guolin Ke committed
172
173
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
174
    }
Guolin Ke's avatar
Guolin Ke committed
175
176
177
    // not threading safe now
    // boosting_->SetNumIterationForPred may be set by other thread during prediction. 
    return Predictor(boosting_.get(), is_raw_score, is_predict_leaf);
Guolin Ke's avatar
Guolin Ke committed
178
179
  }

Guolin Ke's avatar
Guolin Ke committed
180
  void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
181
182
183
184
185
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

  void SaveModelToFile(int num_iteration, const char* filename) {
    boosting_->SaveModelToFile(num_iteration, filename);
Guolin Ke's avatar
Guolin Ke committed
186
  }
187

188
189
190
191
192
193
194
195
  void LoadModelFromString(const char* model_str) {
    boosting_->LoadModelFromString(model_str);
  }

  std::string SaveModelToString(int num_iteration) {
    return boosting_->SaveModelToString(num_iteration);
  }

196
197
  std::string DumpModel(int num_iteration) {
    return boosting_->DumpModel(num_iteration);
wxchan's avatar
wxchan committed
198
  }
199

Guolin Ke's avatar
Guolin Ke committed
200
201
202
203
204
205
206
207
208
  double GetLeafValue(int tree_idx, int leaf_idx) const {
    return dynamic_cast<GBDT*>(boosting_.get())->GetLeafValue(tree_idx, leaf_idx);
  }

  void SetLeafValue(int tree_idx, int leaf_idx, double val) {
    std::lock_guard<std::mutex> lock(mutex_);
    dynamic_cast<GBDT*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val);
  }

wxchan's avatar
wxchan committed
209
210
211
212
213
214
215
  int GetEvalCounts() const {
    int ret = 0;
    for (const auto& metric : train_metric_) {
      ret += static_cast<int>(metric->GetName().size());
    }
    return ret;
  }
216

wxchan's avatar
wxchan committed
217
218
219
220
221
222
223
224
225
226
227
  int GetEvalNames(char** out_strs) const {
    int idx = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
        std::strcpy(out_strs[idx], name.c_str());
        ++idx;
      }
    }
    return idx;
  }

wxchan's avatar
wxchan committed
228
229
230
231
232
233
234
235
236
  int GetFeatureNames(char** out_strs) const {
    int idx = 0;
    for (const auto& name : boosting_->FeatureNames()) {
      std::strcpy(out_strs[idx], name.c_str());
      ++idx;
    }
    return idx;
  }

wxchan's avatar
wxchan committed
237
  const Boosting* GetBoosting() const { return boosting_.get(); }
Guolin Ke's avatar
Guolin Ke committed
238

Guolin Ke's avatar
Guolin Ke committed
239
private:
240

wxchan's avatar
wxchan committed
241
  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
242
  std::unique_ptr<Boosting> boosting_;
Guolin Ke's avatar
Guolin Ke committed
243
244
245
  /*! \brief All configs */
  OverallConfig config_;
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
246
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
247
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
248
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
249
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
250
  std::unique_ptr<ObjectiveFunction> objective_fun_;
wxchan's avatar
wxchan committed
251
252
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
253
254
255
};

}
Guolin Ke's avatar
Guolin Ke committed
256
257
258

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
259
260
261
262
263
264
265
266
267
268
// ---------------------------------------------------------------------------
// Helpers (defined elsewhere) that adapt raw C-API buffers into per-row
// accessor functions used by the dataset-construction entry points below.
// ---------------------------------------------------------------------------

/*!
 * \brief Accessor for a dense num_row x num_col matrix: returns all values of
 *        row `row_idx`. `is_row_major` selects the layout of `data`.
 */
std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data,
                           int num_row,
                           int num_col,
                           int data_type,
                           int is_row_major);

/*!
 * \brief Like RowFunctionFromDenseMatric, but the row is returned as
 *        (index, value) pairs. NOTE(review): which entries are kept is decided
 *        by the definition -- confirm there.
 */
std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data,
                               int num_row,
                               int num_col,
                               int data_type,
                               int is_row_major);

/*!
 * \brief Accessor for a CSR matrix: returns row `idx` as (column, value) pairs.
 *        `nindptr` is the length of `indptr`, `nelem` the number of stored values.
 */
std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr,
                   int indptr_type,
                   const int32_t* indices,
                   const void* data,
                   int data_type,
                   int64_t nindptr,
                   int64_t nelem);
Guolin Ke's avatar
Guolin Ke committed
270
271
272
273
274

/*!
 * \brief Row iterator over a single column (`col_idx`) of a CSC matrix.
 *
 * Member functions are defined elsewhere in this file.
 */
class CSC_RowIterator {
public:
  CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
                  const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx);
  // No user-declared destructor (Rule of Zero): declaring one, even an empty
  // one, suppresses the implicit move constructor/assignment, so iter_fun_
  // would be copied instead of moved.

  /*! \brief Value at row `idx`; rows may only be requested in ascending order. */
  double Get(int idx);
  /*! \brief Next non-zero (row, value) pair; a negative row index means no more data. */
  std::pair<int, double> NextNonZero();

private:
  int nonzero_idx_ = 0;
  int cur_idx_ = -1;
  double cur_val_ = 0.0;
  bool is_end_ = false;
  std::function<std::pair<int, double>(int idx)> iter_fun_;
};

// start of c_api functions

291
/*!
 * \brief Return the text held by LastErrorMsg(); presumably the message of the
 *        most recent error reported through API_END (confirm in its definition).
 */
LIGHTGBM_C_EXPORT const char* LGBM_GetLastError() {
  const char* message = LastErrorMsg();
  return message;
}

295
/*!
 * \brief Load a Dataset from a data file.
 * \param filename   Path of the data file.
 * \param parameters Parameter string, parsed into an IOConfig.
 * \param reference  Optional existing Dataset to align the new one with
 *                   (LoadFromFileAlignWithOtherDataset); nullptr for none.
 * \param out        Receives the handle of the created Dataset.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromFile(const char* filename,
                                                 const char* parameters,
                                                 const DatasetHandle reference,
                                                 DatasetHandle* out) {
  API_BEGIN();
  auto parsed_params = ConfigBase::Str2Map(parameters);
  IOConfig cfg;
  cfg.Set(parsed_params);
  DatasetLoader file_loader(cfg, nullptr, 1, filename);
  const Dataset* reference_set = reinterpret_cast<const Dataset*>(reference);
  *out = (reference_set != nullptr)
    ? file_loader.LoadFromFileAlignWithOtherDataset(filename, reference_set)
    : file_loader.LoadFromFile(filename);
  API_END();
}

313

314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
/*!
 * \brief Construct a Dataset from pre-sampled, per-column data
 *        (DatasetLoader::CostructFromSampleData).
 * \param sample_data    sample_data[c] holds the sampled values of column c.
 * \param sample_indices sample_indices[c][k] is the position within the sample
 *                       of sample_data[c][k].
 * \param ncol           Number of columns.
 * \param num_per_col    num_per_col[c] is the length of sample_data[c].
 * \param num_sample_row Number of sampled rows.
 * \param num_total_row  Number of rows of the full dataset.
 * \param parameters     Parameter string, parsed into an IOConfig.
 * \param out            Receives the handle of the created Dataset.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromSampledColumn(double** sample_data,
                                                          int** sample_indices,
                                                          int32_t ncol,
                                                          const int* num_per_col,
                                                          int32_t num_sample_row,
                                                          int32_t num_total_row,
                                                          const char* parameters,
                                                          DatasetHandle* out) {
  API_BEGIN();
  auto parsed_params = ConfigBase::Str2Map(parameters);
  IOConfig cfg;
  cfg.Set(parsed_params);
  const auto total_rows = static_cast<data_size_t>(num_total_row);
  DatasetLoader sample_loader(cfg, nullptr, 1, nullptr);
  *out = sample_loader.CostructFromSampleData(sample_data,
                                              sample_indices,
                                              ncol,
                                              num_per_col,
                                              num_sample_row,
                                              total_rows);
  API_END();
}

333

Guolin Ke's avatar
Guolin Ke committed
334
/*!
 * \brief Create a Dataset of \p num_total_row rows whose layout comes from
 *        \p reference (Dataset::CreateValid); rows are pushed in afterwards.
 * \param out Receives the handle of the created Dataset.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateByReference(const DatasetHandle reference,
                                                    int64_t num_total_row,
                                                    DatasetHandle* out) {
  API_BEGIN();
  const Dataset* reference_set = reinterpret_cast<const Dataset*>(reference);
  std::unique_ptr<Dataset> created(new Dataset(static_cast<data_size_t>(num_total_row)));
  created->CreateValid(reference_set);
  *out = created.release();
  API_END();
}

/*!
 * \brief Push a block of dense, row-major rows into a Dataset being filled.
 * \param dataset   Target Dataset handle.
 * \param data      Row-major buffer of nrow x ncol values of type data_type.
 * \param start_row Dataset row index of the first pushed row.
 *
 * FinishLoad() is called once the block ends exactly at the dataset's last row.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetPushRows(DatasetHandle dataset,
                                           const void* data,
                                           int data_type,
                                           int32_t nrow,
                                           int32_t ncol,
                                           int32_t start_row) {
  API_BEGIN();
  auto p_dataset = reinterpret_cast<Dataset*>(dataset);
  auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, 1);
  // An exception must not propagate out of an OpenMP parallel region (the
  // OpenMP spec forbids it; in practice it calls std::terminate). Capture the
  // first one and rethrow it after the loop so it reaches API_END.
  std::exception_ptr omp_error = nullptr;
#pragma omp parallel for schedule(static)
  for (int i = 0; i < nrow; ++i) {
    try {
      const int tid = omp_get_thread_num();
      auto one_row = get_row_fun(i);
      p_dataset->PushOneRow(tid, start_row + i, one_row);
    } catch (...) {
#pragma omp critical
      {
        if (omp_error == nullptr) { omp_error = std::current_exception(); }
      }
    }
  }
  if (omp_error != nullptr) { std::rethrow_exception(omp_error); }
  if (start_row + nrow == p_dataset->num_data()) {
    p_dataset->FinishLoad();
  }
  API_END();
}

/*!
 * \brief Push a block of CSR rows into a Dataset being filled.
 * \param dataset   Target Dataset handle.
 * \param nindptr   Length of indptr; the block holds nindptr - 1 rows.
 * \param start_row Dataset row index of the first pushed row.
 *
 * The unnamed int64_t parameter (number of columns) is not used here.
 * FinishLoad() is called once the block ends exactly at the dataset's last row.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset,
                                                const void* indptr,
                                                int indptr_type,
                                                const int32_t* indices,
                                                const void* data,
                                                int data_type,
                                                int64_t nindptr,
                                                int64_t nelem,
                                                int64_t,
                                                int64_t start_row) {
  API_BEGIN();
  auto p_dataset = reinterpret_cast<Dataset*>(dataset);
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  // Exceptions must not escape the OpenMP region: keep the first one and
  // rethrow it after the loop so it reaches API_END.
  std::exception_ptr omp_error = nullptr;
#pragma omp parallel for schedule(static)
  for (int i = 0; i < nrow; ++i) {
    try {
      const int tid = omp_get_thread_num();
      auto one_row = get_row_fun(i);
      p_dataset->PushOneRow(tid,
                            static_cast<data_size_t>(start_row + i), one_row);
    } catch (...) {
#pragma omp critical
      {
        if (omp_error == nullptr) { omp_error = std::current_exception(); }
      }
    }
  }
  if (omp_error != nullptr) { std::rethrow_exception(omp_error); }
  if (start_row + nrow == static_cast<int64_t>(p_dataset->num_data())) {
    p_dataset->FinishLoad();
  }
  API_END();
}

393
/*!
 * \brief Create a Dataset from a dense nrow x ncol matrix.
 * \param is_row_major Layout of \p data (non-zero: row-major).
 * \param parameters   Parameter string, parsed into an IOConfig.
 * \param reference    If non-null, the new Dataset takes its layout from this
 *                     one (CreateValid) instead of being built from a sample.
 * \param out          Receives the handle of the created Dataset.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromMat(const void* data,
                                                int data_type,
                                                int32_t nrow,
                                                int32_t ncol,
                                                int is_row_major,
                                                const char* parameters,
                                                const DatasetHandle reference,
                                                DatasetHandle* out) {
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
  std::unique_ptr<Dataset> ret;
  auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);

  if (reference == nullptr) {
    // Sample at most bin_construct_sample_cnt rows. Per column keep values with
    // |v| > kEpsilon, together with their position *within the sample*.
    Random rand(io_config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    sample_cnt = static_cast<int>(sample_indices.size());
    std::vector<std::vector<double>> sample_values(ncol);
    std::vector<std::vector<int>> sample_idx(ncol);
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      auto row = get_row_fun(static_cast<int>(idx));
      for (size_t j = 0; j < row.size(); ++j) {
        if (std::fabs(row[j]) > kEpsilon) {
          sample_values[j].emplace_back(row[j]);
          sample_idx[j].emplace_back(static_cast<int>(i));
        }
      }
    }
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow));
    ret->CreateValid(
      reinterpret_cast<const Dataset*>(reference));
  }

  // Exceptions must not escape the OpenMP region: keep the first one and
  // rethrow it after the loop so it reaches API_END.
  std::exception_ptr omp_error = nullptr;
#pragma omp parallel for schedule(static)
  for (int i = 0; i < nrow; ++i) {
    try {
      const int tid = omp_get_thread_num();
      auto one_row = get_row_fun(i);
      ret->PushOneRow(tid, i, one_row);
    } catch (...) {
#pragma omp critical
      {
        if (omp_error == nullptr) { omp_error = std::current_exception(); }
      }
    }
  }
  if (omp_error != nullptr) { std::rethrow_exception(omp_error); }
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

448
/*!
 * \brief Create a Dataset from a CSR matrix with nindptr - 1 rows and num_col columns.
 * \param num_col    Number of columns; every column index must be < num_col.
 * \param parameters Parameter string, parsed into an IOConfig.
 * \param reference  If non-null, the new Dataset takes its layout from this
 *                   one (CreateValid) instead of being built from a sample.
 * \param out        Receives the handle of the created Dataset.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSR(const void* indptr,
                                                int indptr_type,
                                                const int32_t* indices,
                                                const void* data,
                                                int data_type,
                                                int64_t nindptr,
                                                int64_t nelem,
                                                int64_t num_col,
                                                const char* parameters,
                                                const DatasetHandle reference,
                                                DatasetHandle* out) {
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
  std::unique_ptr<Dataset> ret;
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference == nullptr) {
    CHECK(num_col >= 0);
    // sample data first (at most bin_construct_sample_cnt rows)
    Random rand(io_config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    sample_cnt = static_cast<int>(sample_indices.size());
    // One bucket per declared column (as CreateFromMat does). Previously the
    // column count was the largest index seen in the *sample*, so columns that
    // were all zero in the sample -- trailing ones included -- were dropped.
    std::vector<std::vector<double>> sample_values(static_cast<size_t>(num_col));
    std::vector<std::vector<int>> sample_idx(static_cast<size_t>(num_col));
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      auto row = get_row_fun(static_cast<int>(idx));
      for (const std::pair<int, double>& inner_data : row) {
        CHECK(inner_data.first >= 0 && inner_data.first < num_col);
        if (std::fabs(inner_data.second) > kEpsilon) {
          sample_values[inner_data.first].emplace_back(inner_data.second);
          // position within the sample, not the original row id
          sample_idx[inner_data.first].emplace_back(static_cast<int>(i));
        }
      }
    }
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow));
    ret->CreateValid(
      reinterpret_cast<const Dataset*>(reference));
  }

  // Exceptions must not escape the OpenMP region: keep the first one and
  // rethrow it after the loop so it reaches API_END.
  std::exception_ptr omp_error = nullptr;
#pragma omp parallel for schedule(static)
  for (int i = 0; i < nrow; ++i) {
    try {
      const int tid = omp_get_thread_num();
      auto one_row = get_row_fun(i);
      ret->PushOneRow(tid, i, one_row);
    } catch (...) {
#pragma omp critical
      {
        if (omp_error == nullptr) { omp_error = std::current_exception(); }
      }
    }
  }
  if (omp_error != nullptr) { std::rethrow_exception(omp_error); }
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

512
/*!
 * \brief Create a Dataset from a CSC matrix with ncol_ptr - 1 columns and num_row rows.
 * \param parameters Parameter string, parsed into an IOConfig.
 * \param reference  If non-null, the new Dataset takes its layout from this
 *                   one (CreateValid) instead of being built from a sample.
 * \param out        Receives the handle of the created Dataset.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr,
                                                int col_ptr_type,
                                                const int32_t* indices,
                                                const void* data,
                                                int data_type,
                                                int64_t ncol_ptr,
                                                int64_t nelem,
                                                int64_t num_row,
                                                const char* parameters,
                                                const DatasetHandle reference,
                                                DatasetHandle* out) {
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
  std::unique_ptr<Dataset> ret;
  int32_t nrow = static_cast<int32_t>(num_row);
  const int ncol = static_cast<int>(ncol_ptr - 1);
  // Exceptions must not escape an OpenMP parallel region: each parallel loop
  // below records the first one here, and it is rethrown after that loop so
  // it reaches API_END.
  std::exception_ptr omp_error = nullptr;
  if (reference == nullptr) {
    // sample data first (at most bin_construct_sample_cnt rows)
    Random rand(io_config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    sample_cnt = static_cast<int>(sample_indices.size());
    std::vector<std::vector<double>> sample_values(ncol);
    std::vector<std::vector<int>> sample_idx(ncol);
    // Each iteration writes only to its own column's buckets.
#pragma omp parallel for schedule(static)
    for (int i = 0; i < ncol; ++i) {
      try {
        CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
        // sample_indices are visited in order, as Get() requires ascending rows
        for (int j = 0; j < sample_cnt; j++) {
          auto val = col_it.Get(sample_indices[j]);
          if (std::fabs(val) > kEpsilon) {
            sample_values[i].emplace_back(val);
            // position within the sample, not the original row id
            sample_idx[i].emplace_back(j);
          }
        }
      } catch (...) {
#pragma omp critical
        {
          if (omp_error == nullptr) { omp_error = std::current_exception(); }
        }
      }
    }
    if (omp_error != nullptr) { std::rethrow_exception(omp_error); }
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow));
    ret->CreateValid(
      reinterpret_cast<const Dataset*>(reference));
  }

  // Push column by column; columns without an inner feature index are skipped.
#pragma omp parallel for schedule(static)
  for (int i = 0; i < ncol; ++i) {
    try {
      const int tid = omp_get_thread_num();
      int feature_idx = ret->InnerFeatureIndex(i);
      if (feature_idx >= 0) {
        int group = ret->Feature2Group(feature_idx);
        int sub_feature = ret->Feture2SubFeature(feature_idx);
        CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
        int row_idx = 0;
        while (row_idx < nrow) {
          auto pair = col_it.NextNonZero();
          row_idx = pair.first;
          // no more data
          if (row_idx < 0) { break; }
          ret->PushOneData(tid, row_idx, group, sub_feature, pair.second);
        }
      }
    } catch (...) {
#pragma omp critical
      {
        if (omp_error == nullptr) { omp_error = std::current_exception(); }
      }
    }
  }
  if (omp_error != nullptr) { std::rethrow_exception(omp_error); }
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

582
// Creates a new Dataset holding only the rows listed in used_row_indices.
// The bin/feature mappers are copied from the full dataset first, so the
// subset shares its feature discretisation.
// `parameters` is parsed into an IOConfig that is not read afterwards
// (NOTE(review): presumably kept so bad parameter strings are still rejected).
LIGHTGBM_C_EXPORT int LGBM_DatasetGetSubset(
  const DatasetHandle handle,
  const int32_t* used_row_indices,
  int32_t num_used_row_indices,
  const char* parameters,
  DatasetHandle* out) {
  API_BEGIN();
  auto parsed_params = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(parsed_params);
  const Dataset* source = reinterpret_cast<const Dataset*>(handle);
  std::unique_ptr<Dataset> subset(new Dataset(num_used_row_indices));
  subset->CopyFeatureMapperFrom(source);
  subset->CopySubset(source, used_row_indices, num_used_row_indices, true);
  *out = subset.release();
  API_END();
}

600
// Replaces the dataset's feature names with the first num_feature_names
// NUL-terminated strings of feature_names.
LIGHTGBM_C_EXPORT int LGBM_DatasetSetFeatureNames(
  DatasetHandle handle,
  const char** feature_names,
  int num_feature_names) {
  API_BEGIN();
  Dataset* ref_dataset = reinterpret_cast<Dataset*>(handle);
  std::vector<std::string> names;
  int idx = 0;
  while (idx < num_feature_names) {
    names.emplace_back(feature_names[idx]);
    ++idx;
  }
  ref_dataset->set_feature_names(names);
  API_END();
}

614
// Copies every feature name of the dataset into caller-provided buffers.
// feature_names must hold at least as many preallocated char buffers as the
// dataset has feature names. Each buffer must be large enough for the name
// plus its NUL terminator; this signature carries no buffer lengths, so no
// bounds check is possible here.
// *num_feature_names receives the number of names written.
LIGHTGBM_C_EXPORT int LGBM_DatasetGetFeatureNames(
  DatasetHandle handle,
  char** feature_names,
  int* num_feature_names) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  // Bind by const reference. This avoids copying the whole vector of names
  // when the accessor returns a reference, and stays correct (via lifetime
  // extension) if it returns by value.
  const auto& inside_feature_name = dataset->feature_names();
  *num_feature_names = static_cast<int>(inside_feature_name.size());
  for (int i = 0; i < *num_feature_names; ++i) {
    std::strcpy(feature_names[i], inside_feature_name[i].c_str());
  }
  API_END();
}

628
// Destroys a Dataset previously handed out through a DatasetHandle.
LIGHTGBM_C_EXPORT int LGBM_DatasetFree(DatasetHandle handle) {
  API_BEGIN();
  Dataset* owned = reinterpret_cast<Dataset*>(handle);
  delete owned;
  API_END();
}

634
// Writes the dataset to `filename` using Dataset::SaveBinaryFile.
LIGHTGBM_C_EXPORT int LGBM_DatasetSaveBinary(DatasetHandle handle,
                                             const char* filename) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->SaveBinaryFile(filename);
  API_END();
}

642
// Sets a named metadata field of the dataset from a typed buffer.
// `type` selects how field_data is read: C_API_DTYPE_FLOAT32 (float),
// C_API_DTYPE_INT32 (int) or C_API_DTYPE_FLOAT64 (double). num_element is the
// number of entries in field_data.
// Throws std::runtime_error in two cases: the type is unsupported, or the
// dataset's setter rejects the field name for that type. API_END presumably
// converts the exception into an error return code.
LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle,
                                           const char* field_name,
                                           const void* field_data,
                                           int num_element,
                                           int type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  bool is_success = false;
  if (type == C_API_DTYPE_FLOAT32) {
    is_success = dataset->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), static_cast<int32_t>(num_element));
  } else if (type == C_API_DTYPE_INT32) {
    is_success = dataset->SetIntField(field_name, reinterpret_cast<const int*>(field_data), static_cast<int32_t>(num_element));
  } else if (type == C_API_DTYPE_FLOAT64) {
    is_success = dataset->SetDoubleField(field_name, reinterpret_cast<const double*>(field_data), static_cast<int32_t>(num_element));
  }
  if (!is_success) { throw std::runtime_error("Input data type error or field not found"); }
  API_END();
}

661
// Looks up a metadata field by name and exposes a pointer to its storage.
// The float, int and double getters are tried in that order. The first one
// that accepts field_name determines *out_type.
// Throws std::runtime_error("Field not found") if none of them accepts it.
// If the field's data pointer is null, *out_len is forced to 0.
LIGHTGBM_C_EXPORT int LGBM_DatasetGetField(DatasetHandle handle,
                                           const char* field_name,
                                           int* out_len,
                                           const void** out_ptr,
                                           int* out_type) {
  API_BEGIN();
  Dataset* ref_dataset = reinterpret_cast<Dataset*>(handle);
  if (ref_dataset->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (ref_dataset->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else if (ref_dataset->GetDoubleField(field_name, out_len, reinterpret_cast<const double**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT64;
  } else {
    throw std::runtime_error("Field not found");
  }
  if (*out_ptr == nullptr) {
    *out_len = 0;
  }
  API_END();
}

684
// Reports Dataset::num_data() through *out.
LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumData(DatasetHandle handle,
                                             int* out) {
  API_BEGIN();
  *out = reinterpret_cast<Dataset*>(handle)->num_data();
  API_END();
}

692
// Reports Dataset::num_total_features() through *out.
LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumFeature(DatasetHandle handle,
                                                int* out) {
  API_BEGIN();
  *out = reinterpret_cast<Dataset*>(handle)->num_total_features();
  API_END();
}
699
700
701

// ---- start of booster

702
// Creates a Booster bound to the given training dataset and configured from
// the `parameters` string. Ownership of the new Booster passes to *out.
LIGHTGBM_C_EXPORT int LGBM_BoosterCreate(const DatasetHandle train_data,
                                         const char* parameters,
                                         BoosterHandle* out) {
  API_BEGIN();
  auto training_set = reinterpret_cast<const Dataset*>(train_data);
  std::unique_ptr<Booster> booster(new Booster(training_set, parameters));
  *out = booster.release();
  API_END();
}

712
// Builds a Booster from a saved model file.
// *out_num_iterations receives the loaded model's current iteration count,
// and *out takes ownership of the new Booster.
LIGHTGBM_C_EXPORT int LGBM_BoosterCreateFromModelfile(
  const char* filename,
  int* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
  std::unique_ptr<Booster> booster(new Booster(filename));
  *out_num_iterations = booster->GetBoosting()->GetCurrentIteration();
  *out = booster.release();
  API_END();
}

723
724
725
726
727
728
729
730
731
732
733
734
// Builds a Booster from an in-memory model string.
// A default-constructed Booster is created first, then the model text is
// loaded into it. *out_num_iterations receives the resulting iteration count,
// and *out takes ownership of the Booster.
LIGHTGBM_C_EXPORT int LGBM_BoosterLoadModelFromString(
  const char* model_str,
  int* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
  std::unique_ptr<Booster> booster(new Booster());
  booster->LoadModelFromString(model_str);
  *out_num_iterations = booster->GetBoosting()->GetCurrentIteration();
  *out = booster.release();
  API_END();
}

735
// Destroys a Booster previously handed out through a BoosterHandle.
LIGHTGBM_C_EXPORT int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  Booster* owned = reinterpret_cast<Booster*>(handle);
  delete owned;
  API_END();
}

741
// Merges the booster behind other_handle into the booster behind handle
// via Booster::MergeFrom. The booster behind other_handle is not freed here.
LIGHTGBM_C_EXPORT int LGBM_BoosterMerge(BoosterHandle handle,
                                        BoosterHandle other_handle) {
  API_BEGIN();
  auto target = reinterpret_cast<Booster*>(handle);
  auto source = reinterpret_cast<Booster*>(other_handle);
  target->MergeFrom(source);
  API_END();
}

750
// Registers an additional validation dataset with the booster.
LIGHTGBM_C_EXPORT int LGBM_BoosterAddValidData(BoosterHandle handle,
                                               const DatasetHandle valid_data) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->AddValidData(
    reinterpret_cast<const Dataset*>(valid_data));
  API_END();
}

759
// Replaces the booster's training dataset via Booster::ResetTrainingData.
LIGHTGBM_C_EXPORT int LGBM_BoosterResetTrainingData(BoosterHandle handle,
                                                    const DatasetHandle train_data) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ResetTrainingData(
    reinterpret_cast<const Dataset*>(train_data));
  API_END();
}

768
// Re-applies a parameter string to the booster via Booster::ResetConfig.
LIGHTGBM_C_EXPORT int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  auto booster = reinterpret_cast<Booster*>(handle);
  booster->ResetConfig(parameters);
  API_END();
}

775
// Reports the underlying boosting's NumberOfClasses() through *out_len.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetBoosting()->NumberOfClasses();
  API_END();
}

782
// Runs one boosting iteration using the booster's configured objective.
// *is_finished is 1 when TrainOneIter() returns true, otherwise 0.
LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  *is_finished = booster->TrainOneIter() ? 1 : 0;
  API_END();
}

793
// Runs one boosting iteration using caller-supplied gradients and hessians
// (NOTE(review): the buffer length expected by TrainOneIter is not visible
// here; presumably one entry per training row per class).
// *is_finished is 1 when TrainOneIter(grad, hess) returns true, otherwise 0.
LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
                                                      const float* grad,
                                                      const float* hess,
                                                      int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  *is_finished = booster->TrainOneIter(grad, hess) ? 1 : 0;
  API_END();
}

807
// Undoes the most recent boosting iteration via Booster::RollbackOneIter.
LIGHTGBM_C_EXPORT int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

814
// Reports the underlying boosting's GetCurrentIteration() through *out_iteration.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int* out_iteration) {
  API_BEGIN();
  auto booster = reinterpret_cast<Booster*>(handle);
  *out_iteration = booster->GetBoosting()->GetCurrentIteration();
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
820

821
// Reports Booster::GetEvalCounts() through *out_len.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalCounts();
  API_END();
}

828
// Fills out_strs with the evaluation metric names (presumably into
// caller-allocated buffers; confirm against Booster::GetEvalNames).
// *out_len receives the count that Booster::GetEvalNames returns.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetEvalNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  auto booster = reinterpret_cast<Booster*>(handle);
  const int written = booster->GetEvalNames(out_strs);
  *out_len = written;
  API_END();
}

wxchan's avatar
wxchan committed
835
836
837
838
839
840
841
842
843
844
845
846
847
848
// Fills out_strs with the model's feature names (presumably into
// caller-allocated buffers; confirm against Booster::GetFeatureNames).
// *out_len receives the count that Booster::GetFeatureNames returns.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetFeatureNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  auto booster = reinterpret_cast<Booster*>(handle);
  const int written = booster->GetFeatureNames(out_strs);
  *out_len = written;
  API_END();
}

// Reports the model's feature count, computed as MaxFeatureIdx() + 1.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumFeature(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  const int max_feature_idx = reinterpret_cast<Booster*>(handle)->GetBoosting()->MaxFeatureIdx();
  *out_len = max_feature_idx + 1;
  API_END();
}

849
// Copies the evaluation results for dataset `data_idx` into out_results.
// out_results must have room for every returned value; *out_len receives
// how many were written.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetEval(BoosterHandle handle,
                                          int data_idx,
                                          int* out_len,
                                          double* out_results) {
  API_BEGIN();
  auto evals = reinterpret_cast<Booster*>(handle)->GetBoosting()->GetEvalAt(data_idx);
  *out_len = static_cast<int>(evals.size());
  double* dst = out_results;
  for (const auto& value : evals) {
    *dst++ = static_cast<double>(value);
  }
  API_END();
}

864
// Reports how many prediction values the boosting holds for dataset
// `data_idx`, via GetNumPredictAt.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumPredict(BoosterHandle handle,
                                                int data_idx,
                                                int64_t* out_len) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  *out_len = booster->GetBoosting()->GetNumPredictAt(data_idx);
  API_END();
}

873
// Copies the stored predictions for dataset `data_idx` into out_result
// using Booster::GetPredictAt, which also sets *out_len.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetPredict(BoosterHandle handle,
                                             int data_idx,
                                             int64_t* out_len,
                                             double* out_result) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->GetPredictAt(data_idx, out_result, out_len);
  API_END();
}

883
// Predicts every row of data_filename and writes the results to
// result_filename. Any positive data_has_header means the input's first line
// is a header.
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForFile(BoosterHandle handle,
                                                 const char* data_filename,
                                                 int data_has_header,
                                                 int predict_type,
                                                 int num_iteration,
                                                 const char* result_filename) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto predictor = booster->NewPredictor(num_iteration, predict_type);
  const bool has_header = data_has_header > 0;
  predictor.Predict(data_filename, result_filename, has_header);
  API_END();
}

897
898
int64_t GetNumPredOneRow(const Booster* ref_booster, int predict_type, int64_t num_iteration) {
  int64_t num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
Guolin Ke's avatar
Guolin Ke committed
899
900
901
902
903
904
905
906
907
908
909
  if (predict_type == C_API_PREDICT_LEAF_INDEX) {
    int64_t max_iteration = ref_booster->GetBoosting()->GetCurrentIteration();
    if (num_iteration > 0) {
      num_preb_in_one_row *= static_cast<int>(std::min(max_iteration, num_iteration));
    } else {
      num_preb_in_one_row *= max_iteration;
    }
  }
  return num_preb_in_one_row;
}

910
// Computes the output buffer length needed to predict num_row rows with the
// given prediction type and iteration limit.
LIGHTGBM_C_EXPORT int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
                                                 int num_row,
                                                 int predict_type,
                                                 int num_iteration,
                                                 int64_t* out_len) {
  API_BEGIN();
  const Booster* booster = reinterpret_cast<Booster*>(handle);
  const int64_t per_row = GetNumPredOneRow(booster, predict_type, num_iteration);
  *out_len = per_row * num_row;
  API_END();
}

921
// Predicts every row of a CSR matrix, processing rows in parallel with
// OpenMP. Row r's predictions are written starting at
// out_result[r * per_row], where per_row comes from GetNumPredOneRow.
// *out_len receives the total count. The unnamed int64_t parameter is not
// used.
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle,
                                                const void* indptr,
                                                int indptr_type,
                                                const int32_t* indices,
                                                const void* data,
                                                int data_type,
                                                int64_t nindptr,
                                                int64_t nelem,
                                                int64_t,
                                                int predict_type,
                                                int num_iteration,
                                                int64_t* out_len,
                                                double* out_result) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto predictor = booster->NewPredictor(num_iteration, predict_type);
  auto row_getter = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  const int64_t per_row = GetNumPredOneRow(booster, predict_type, num_iteration);
  const int num_rows = static_cast<int>(nindptr - 1);
#pragma omp parallel for schedule(static)
  for (int row = 0; row < num_rows; ++row) {
    auto features = row_getter(row);
    auto preds = predictor.GetPredictFunction()(features);
    double* dst = out_result + row * per_row;
    for (size_t k = 0; k < preds.size(); ++k) {
      dst[k] = static_cast<double>(preds[k]);
    }
  }
  *out_len = num_rows * per_row;
  API_END();
}
951

952
// Predicts num_row rows of a CSC (column-compressed) matrix.
// Rows are split into chunks by Threading::For. Each chunk builds its own
// column iterators, reconstructs every row as sparse (feature, value) pairs,
// and writes the row's predictions starting at
// out_result[row * num_preb_in_one_row]. *out_len receives the total count.
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle,
                                                const void* col_ptr,
                                                int col_ptr_type,
                                                const int32_t* indices,
                                                const void* data,
                                                int data_type,
                                                int64_t ncol_ptr,
                                                int64_t nelem,
                                                int64_t num_row,
                                                int predict_type,
                                                int num_iteration,
                                                int64_t* out_len,
                                                double* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  auto predictor = ref_booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
  int64_t num_preb_in_one_row = GetNumPredOneRow(ref_booster, predict_type, num_iteration);
  int ncol = static_cast<int>(ncol_ptr - 1);

  // The loop runs over an int64_t index range (Threading::For<int64_t>), so
  // the chunk bounds are received as int64_t. Declaring them data_size_t
  // would silently narrow them through the implicit argument conversion.
  Threading::For<int64_t>(0, num_row,
                          [&predictor, &out_result, num_preb_in_one_row, ncol, col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem]
  (int, int64_t start, int64_t end) {
    // One iterator per column, local to this chunk.
    std::vector<CSC_RowIterator> iterators;
    for (int j = 0; j < ncol; ++j) {
      iterators.emplace_back(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, j);
    }
    // Reused buffer holding the current row as (column index, value) pairs.
    std::vector<std::pair<int, double>> one_row;
    for (int64_t i = start; i < end; ++i) {
      one_row.clear();
      for (int j = 0; j < ncol; ++j) {
        auto val = iterators[j].Get(static_cast<int>(i));
        // Values with magnitude <= kEpsilon are treated as absent.
        if (std::fabs(val) > kEpsilon) {
          one_row.emplace_back(j, val);
        }
      }
      auto predicton_result = predictor.GetPredictFunction()(one_row);
      for (int j = 0; j < static_cast<int>(predicton_result.size()); ++j) {
        out_result[i * num_preb_in_one_row + j] = static_cast<double>(predicton_result[j]);
      }
    }
  });
  *out_len = num_row * num_preb_in_one_row;
  API_END();
}

997
// Predicts every row of a dense nrow x ncol matrix (row- or column-major,
// selected by is_row_major), processing rows in parallel with OpenMP.
// Row r's predictions are written starting at out_result[r * per_row].
// *out_len receives the total count.
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMat(BoosterHandle handle,
                                                const void* data,
                                                int data_type,
                                                int32_t nrow,
                                                int32_t ncol,
                                                int is_row_major,
                                                int predict_type,
                                                int num_iteration,
                                                int64_t* out_len,
                                                double* out_result) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto predictor = booster->NewPredictor(num_iteration, predict_type);
  auto row_getter = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  const int64_t per_row = GetNumPredOneRow(booster, predict_type, num_iteration);
#pragma omp parallel for schedule(static)
  for (int row = 0; row < nrow; ++row) {
    auto features = row_getter(row);
    auto preds = predictor.GetPredictFunction()(features);
    double* dst = out_result + row * per_row;
    for (size_t k = 0; k < preds.size(); ++k) {
      dst[k] = static_cast<double>(preds[k]);
    }
  }
  *out_len = nrow * per_row;
  API_END();
}
1023

1024
// Saves the model, limited by num_iteration, to `filename` via
// Booster::SaveModelToFile.
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle,
                                            int num_iteration,
                                            const char* filename) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->SaveModelToFile(num_iteration, filename);
  API_END();
}

1033
// Serialises the model into out_str.
// *out_len always receives the size needed, including the NUL terminator.
// The text is copied only when buffer_len is at least that size; otherwise
// out_str is left untouched so the caller can retry with a larger buffer.
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle,
                                                    int num_iteration,
                                                    int buffer_len,
                                                    int* out_len,
                                                    char* out_str) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const std::string serialized = booster->SaveModelToString(num_iteration);
  const int required_len = static_cast<int>(serialized.size()) + 1;
  *out_len = required_len;
  if (buffer_len >= required_len) {
    std::strcpy(out_str, serialized.c_str());
  }
  API_END();
}

1048
// Writes the Booster::DumpModel text into out_str.
// *out_len always receives the size needed, including the NUL terminator.
// The text is copied only when buffer_len is at least that size.
LIGHTGBM_C_EXPORT int LGBM_BoosterDumpModel(BoosterHandle handle,
                                            int num_iteration,
                                            int buffer_len,
                                            int* out_len,
                                            char* out_str) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const std::string dumped = booster->DumpModel(num_iteration);
  const int required_len = static_cast<int>(dumped.size()) + 1;
  *out_len = required_len;
  if (buffer_len >= required_len) {
    std::strcpy(out_str, dumped.c_str());
  }
  API_END();
}
1062

1063
// Reads the output value of leaf `leaf_idx` in tree `tree_idx`.
LIGHTGBM_C_EXPORT int LGBM_BoosterGetLeafValue(BoosterHandle handle,
                                               int tree_idx,
                                               int leaf_idx,
                                               double* out_val) {
  API_BEGIN();
  auto booster = reinterpret_cast<Booster*>(handle);
  const auto leaf_value = booster->GetLeafValue(tree_idx, leaf_idx);
  *out_val = static_cast<double>(leaf_value);
  API_END();
}

1073
// Overwrites the output value of leaf `leaf_idx` in tree `tree_idx`.
LIGHTGBM_C_EXPORT int LGBM_BoosterSetLeafValue(BoosterHandle handle,
                                               int tree_idx,
                                               int leaf_idx,
                                               double val) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->SetLeafValue(tree_idx, leaf_idx, val);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1083
// ---- start of some helper functions
1084
1085
1086

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1087
  if (data_type == C_API_DTYPE_FLOAT32) {
1088
1089
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
1090
      return [data_ptr, num_col, num_row] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1091
        std::vector<double> ret(num_col);
1092
1093
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1094
          ret[i] = static_cast<double>(*(tmp_ptr + i));
1095
1096
1097
1098
        }
        return ret;
      };
    } else {
1099
      return [data_ptr, num_col, num_row] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1100
        std::vector<double> ret(num_col);
1101
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1102
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
1103
1104
1105
1106
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1107
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1108
1109
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
1110
      return [data_ptr, num_col, num_row] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1111
        std::vector<double> ret(num_col);
1112
1113
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1114
          ret[i] = static_cast<double>(*(tmp_ptr + i));
1115
1116
1117
1118
        }
        return ret;
      };
    } else {
1119
      return [data_ptr, num_col, num_row] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1120
        std::vector<double> ret(num_col);
1121
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1122
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
1123
1124
1125
1126
1127
        }
        return ret;
      };
    }
  }
Guolin Ke's avatar
Guolin Ke committed
1128
  throw std::runtime_error("unknown data type in RowFunctionFromDenseMatric");
1129
1130
1131
1132
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1133
1134
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
1135
    return [inner_function] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1136
1137
1138
1139
1140
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
        if (std::fabs(raw_values[i]) > 1e-15) {
          ret.emplace_back(i, raw_values[i]);
1141
        }
Guolin Ke's avatar
Guolin Ke committed
1142
1143
1144
      }
      return ret;
    };
1145
  }
Guolin Ke's avatar
Guolin Ke committed
1146
  return nullptr;
1147
1148
1149
1150
}

std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
1151
  if (data_type == C_API_DTYPE_FLOAT32) {
1152
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1153
    if (indptr_type == C_API_DTYPE_INT32) {
1154
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
1155
      return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
1156
1157
1158
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1159
        for (int64_t i = start; i < end; ++i) {
1160
1161
1162
1163
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1164
    } else if (indptr_type == C_API_DTYPE_INT64) {
1165
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
1166
      return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
1167
1168
1169
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1170
        for (int64_t i = start; i < end; ++i) {
1171
1172
1173
1174
1175
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1176
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1177
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1178
    if (indptr_type == C_API_DTYPE_INT32) {
1179
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
1180
      return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
1181
1182
1183
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1184
        for (int64_t i = start; i < end; ++i) {
1185
1186
1187
1188
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1189
    } else if (indptr_type == C_API_DTYPE_INT64) {
1190
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
1191
      return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
1192
1193
1194
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1195
        for (int64_t i = start; i < end; ++i) {
1196
1197
1198
1199
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1200
1201
    }
  }
Guolin Ke's avatar
Guolin Ke committed
1202
  throw std::runtime_error("unknown data type in RowFunctionFromCSR");
1203
1204
}

Guolin Ke's avatar
Guolin Ke committed
1205
1206
1207
std::function<std::pair<int, double>(int idx)>
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx) {
  CHECK(col_idx < ncol_ptr && col_idx >= 0);
Guolin Ke's avatar
Guolin Ke committed
1208
  if (data_type == C_API_DTYPE_FLOAT32) {
1209
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1210
    if (col_ptr_type == C_API_DTYPE_INT32) {
1211
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1212
1213
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1214
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1215
1216
1217
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1218
        }
Guolin Ke's avatar
Guolin Ke committed
1219
1220
1221
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1222
      };
Guolin Ke's avatar
Guolin Ke committed
1223
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1224
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1225
1226
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1227
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1228
1229
1230
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1231
        }
Guolin Ke's avatar
Guolin Ke committed
1232
1233
1234
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1235
      };
Guolin Ke's avatar
Guolin Ke committed
1236
    }
Guolin Ke's avatar
Guolin Ke committed
1237
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1238
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1239
    if (col_ptr_type == C_API_DTYPE_INT32) {
1240
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1241
1242
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1243
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1244
1245
1246
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1247
        }
Guolin Ke's avatar
Guolin Ke committed
1248
1249
1250
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1251
      };
Guolin Ke's avatar
Guolin Ke committed
1252
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1253
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1254
1255
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1256
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1257
1258
1259
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1260
        }
Guolin Ke's avatar
Guolin Ke committed
1261
1262
1263
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1264
      };
Guolin Ke's avatar
Guolin Ke committed
1265
1266
1267
    }
  }
  throw std::runtime_error("unknown data type in CSC matrix");
1268
1269
}

Guolin Ke's avatar
Guolin Ke committed
1270
// Sets up iteration over the stored entries of column `col_idx` of a CSC matrix.
// All pointer/type handling is delegated to the accessor built by IterateFunctionFromCSC.
CSC_RowIterator::CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
                                 const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx)
  : iter_fun_(IterateFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type,
                                     ncol_ptr, nelem, col_idx)) {
}

// Returns the value stored at row `idx` of this column, or 0 if no entry is stored there.
// The cursor only moves forward, so queries are expected in non-decreasing `idx` order:
// once the cursor has passed a row, that row reads as 0.
double CSC_RowIterator::Get(int idx) {
  // Advance the cursor until it reaches or passes `idx`, or the column runs out.
  while (!is_end_ && cur_idx_ < idx) {
    const std::pair<int, double> entry = iter_fun_(nonzero_idx_);
    if (entry.first < 0) {
      // Sentinel from the accessor: no more stored entries in this column.
      is_end_ = true;
      break;
    }
    cur_idx_ = entry.first;
    cur_val_ = entry.second;
    ++nonzero_idx_;
  }
  return (cur_idx_ == idx) ? cur_val_ : 0.0;
}

// Returns the next stored (row index, value) pair of this column and advances past it.
// Once the column is exhausted, returns (-1, 0.0) and marks the iterator as finished.
std::pair<int, double> CSC_RowIterator::NextNonZero() {
  if (is_end_) {
    return std::make_pair(-1, 0.0);
  }
  const std::pair<int, double> entry = iter_fun_(nonzero_idx_++);
  if (entry.first < 0) {
    is_end_ = true;
  }
  return entry;
}