c_api.cpp 64 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/c_api.h>
Guolin Ke's avatar
Guolin Ke committed
6

Guolin Ke's avatar
Guolin Ke committed
7
8
#include <LightGBM/boosting.h>
#include <LightGBM/config.h>
9
10
11
#include <LightGBM/dataset.h>
#include <LightGBM/dataset_loader.h>
#include <LightGBM/metric.h>
12
#include <LightGBM/network.h>
13
14
15
16
17
18
19
#include <LightGBM/objective_function.h>
#include <LightGBM/prediction_early_stop.h>
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>
#include <LightGBM/utils/random.h>
#include <LightGBM/utils/threading.h>
Guolin Ke's avatar
Guolin Ke committed
20
21

#include <string>
22
23
#include <cstdio>
#include <functional>
Guolin Ke's avatar
Guolin Ke committed
24
#include <memory>
wxchan's avatar
wxchan committed
25
#include <mutex>
26
27
#include <stdexcept>
#include <vector>
Guolin Ke's avatar
Guolin Ke committed
28

29
#include "application/predictor.hpp"
Guolin Ke's avatar
Guolin Ke committed
30

Guolin Ke's avatar
Guolin Ke committed
31
32
namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
/*!
 * \brief Store the text of a caught exception as the last API error.
 * \param ex Exception whose what() message is passed to LGBM_SetLastError
 * \return Always -1
 */
inline int LGBM_APIHandleException(const std::exception& ex) {
  const char* message = ex.what();
  LGBM_SetLastError(message);
  return -1;
}

/*!
 * \brief Store a plain string as the last API error.
 * \param ex Error text passed to LGBM_SetLastError
 * \return Always -1
 */
inline int LGBM_APIHandleException(const std::string& ex) {
  const char* message = ex.c_str();
  LGBM_SetLastError(message);
  return -1;
}

// API_BEGIN() opens a try block and API_END() closes it. Between them goes the
// body of an exported C function. Any std::exception, thrown std::string or
// other exception is turned into a -1 return after its message has been handed
// to LGBM_APIHandleException; the non-throwing path returns 0.
// The handlers only read the exception object, so it is caught by const reference.
#define API_BEGIN() try {
#define API_END() } \
catch(const std::exception& ex) { return LGBM_APIHandleException(ex); } \
catch(const std::string& ex) { return LGBM_APIHandleException(ex); } \
catch(...) { return LGBM_APIHandleException("unknown exception"); } \
return 0;

Guolin Ke's avatar
Guolin Ke committed
49
class Booster {
Nikita Titov's avatar
Nikita Titov committed
50
 public:
Guolin Ke's avatar
Guolin Ke committed
51
  explicit Booster(const char* filename) {
52
    boosting_.reset(Boosting::CreateBoosting("gbdt", filename));
53
54
  }

Guolin Ke's avatar
Guolin Ke committed
55
  Booster(const Dataset* train_data,
56
          const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
57
    auto param = Config::Str2Map(parameters);
wxchan's avatar
wxchan committed
58
    config_.Set(param);
59
60
61
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
62
    // create boosting
Guolin Ke's avatar
Guolin Ke committed
63
    if (config_.input_model.size() > 0) {
64
65
      Log::Warning("Continued train from model is not supported for c_api,\n"
                   "please use continued train with input score");
Guolin Ke's avatar
Guolin Ke committed
66
    }
Guolin Ke's avatar
Guolin Ke committed
67

Guolin Ke's avatar
Guolin Ke committed
68
    boosting_.reset(Boosting::CreateBoosting(config_.boosting, nullptr));
Guolin Ke's avatar
Guolin Ke committed
69

70
71
    train_data_ = train_data;
    CreateObjectiveAndMetrics();
Guolin Ke's avatar
Guolin Ke committed
72
    // initialize the boosting
Guolin Ke's avatar
Guolin Ke committed
73
    if (config_.tree_learner == std::string("feature")) {
74
      Log::Fatal("Do not support feature parallel in c api");
75
    }
Guolin Ke's avatar
Guolin Ke committed
76
    if (Network::num_machines() == 1 && config_.tree_learner != std::string("serial")) {
77
      Log::Warning("Only find one worker, will switch to serial tree learner");
Guolin Ke's avatar
Guolin Ke committed
78
      config_.tree_learner = "serial";
79
    }
Guolin Ke's avatar
Guolin Ke committed
80
    boosting_->Init(&config_, train_data_, objective_fun_.get(),
81
                    Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
wxchan's avatar
wxchan committed
82
83
84
85
86
  }

  /*!
   * \brief Merge the model held by another booster into this one.
   * \param other Booster whose boosting object is passed to Boosting::MergeFrom
   * \note Only this booster's mutex is taken; `other` is not locked here.
   */
  void MergeFrom(const Booster* other) {
    std::lock_guard<std::mutex> guard(mutex_);
    auto* source = other->boosting_.get();
    boosting_->MergeFrom(source);
  }

  /*! \brief Nothing to release by hand; members clean up after themselves */
  ~Booster() = default;
91

92
  void CreateObjectiveAndMetrics() {
Guolin Ke's avatar
Guolin Ke committed
93
    // create objective function
Guolin Ke's avatar
Guolin Ke committed
94
95
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective,
                                                                    config_));
Guolin Ke's avatar
Guolin Ke committed
96
97
98
99
100
101
102
103
104
105
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective function");
    }
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
    }

    // create training metric
    train_metric_.clear();
Guolin Ke's avatar
Guolin Ke committed
106
    for (auto metric_type : config_.metric) {
Guolin Ke's avatar
Guolin Ke committed
107
      auto metric = std::unique_ptr<Metric>(
Guolin Ke's avatar
Guolin Ke committed
108
        Metric::CreateMetric(metric_type, config_));
Guolin Ke's avatar
Guolin Ke committed
109
110
111
112
113
      if (metric == nullptr) { continue; }
      metric->Init(train_data_->metadata(), train_data_->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
114
115
116
117
118
119
120
121
122
123
124
  }

  /*!
   * \brief Switch the booster to a different training dataset.
   *
   * Does nothing when the pointer equals the dataset already in use. Otherwise
   * the objective and training metrics are rebuilt and the boosting object is
   * told about the new data.
   * \param train_data New training dataset; the pointer is stored, not copied
   */
  void ResetTrainingData(const Dataset* train_data) {
    // Take the lock before looking at train_data_: other methods write it
    // while holding mutex_, so comparing it unlocked would be a data race.
    std::lock_guard<std::mutex> lock(mutex_);
    if (train_data == train_data_) {
      return;
    }
    train_data_ = train_data;
    CreateObjectiveAndMetrics();
    // reset the boosting
    boosting_->ResetTrainingData(train_data_,
                                 objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
  }

  void ResetConfig(const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
128
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
129
    auto param = Config::Str2Map(parameters);
wxchan's avatar
wxchan committed
130
    if (param.count("num_class")) {
131
      Log::Fatal("Cannot change num_class during training");
wxchan's avatar
wxchan committed
132
    }
Guolin Ke's avatar
Guolin Ke committed
133
134
    if (param.count("boosting")) {
      Log::Fatal("Cannot change boosting during training");
wxchan's avatar
wxchan committed
135
    }
Guolin Ke's avatar
Guolin Ke committed
136
    if (param.count("metric")) {
137
      Log::Fatal("Cannot change metric during training");
Guolin Ke's avatar
Guolin Ke committed
138
    }
Guolin Ke's avatar
Guolin Ke committed
139
140

    config_.Set(param);
141
142
143
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
144
145
146

    if (param.count("objective")) {
      // create objective function
Guolin Ke's avatar
Guolin Ke committed
147
148
      objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective,
                                                                      config_));
Guolin Ke's avatar
Guolin Ke committed
149
150
151
152
153
154
155
      if (objective_fun_ == nullptr) {
        Log::Warning("Using self-defined objective function");
      }
      // initialize the objective function
      if (objective_fun_ != nullptr) {
        objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
      }
156
157
      boosting_->ResetTrainingData(train_data_,
                                   objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
wxchan's avatar
wxchan committed
158
    }
Guolin Ke's avatar
Guolin Ke committed
159

Guolin Ke's avatar
Guolin Ke committed
160
    boosting_->ResetConfig(&config_);
wxchan's avatar
wxchan committed
161
162
163
164
165
  }

  void AddValidData(const Dataset* valid_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    valid_metrics_.emplace_back();
Guolin Ke's avatar
Guolin Ke committed
166
167
    for (auto metric_type : config_.metric) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_));
wxchan's avatar
wxchan committed
168
169
170
171
172
173
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
174
                               Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
wxchan's avatar
wxchan committed
175
  }
Guolin Ke's avatar
Guolin Ke committed
176

177
  bool TrainOneIter() {
wxchan's avatar
wxchan committed
178
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
179
    return boosting_->TrainOneIter(nullptr, nullptr);
180
181
  }

Guolin Ke's avatar
Guolin Ke committed
182
183
184
185
186
187
188
189
190
191
192
  /*!
   * \brief Refit the trees using a dense matrix of leaf predictions.
   * \param leaf_preds Flat array read as nrow consecutive groups of ncol values
   * \param nrow Number of rows in leaf_preds
   * \param ncol Number of values per row in leaf_preds
   */
  void Refit(const int32_t* leaf_preds, int32_t nrow, int32_t ncol) {
    std::lock_guard<std::mutex> lock(mutex_);
    std::vector<std::vector<int32_t>> v_leaf_preds(nrow, std::vector<int32_t>(ncol, 0));
    for (int32_t i = 0; i < nrow; ++i) {
      // Compute the flat offset in size_t: i * ncol in 32-bit int overflows
      // once the matrix holds more than INT_MAX elements.
      const size_t row_offset = static_cast<size_t>(i) * static_cast<size_t>(ncol);
      for (int32_t j = 0; j < ncol; ++j) {
        v_leaf_preds[i][j] = leaf_preds[row_offset + static_cast<size_t>(j)];
      }
    }
    boosting_->RefitTree(v_leaf_preds);
  }

193
  bool TrainOneIter(const score_t* gradients, const score_t* hessians) {
wxchan's avatar
wxchan committed
194
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
195
    return boosting_->TrainOneIter(gradients, hessians);
196
197
  }

wxchan's avatar
wxchan committed
198
199
200
201
202
  /*! \brief Forward a one-iteration rollback request to the boosting object, under the booster lock */
  void RollbackOneIter() {
    std::lock_guard<std::mutex> guard(mutex_);
    boosting_->RollbackOneIter();
  }

203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
  void PredictSingleRow(int num_iteration, int predict_type,
               std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
               const Config& config,
               double* out_result, int64_t* out_len) {
    std::lock_guard<std::mutex> lock(mutex_);

    if (single_row_predictor_.get() == nullptr) {
      bool is_predict_leaf = false;
      bool is_raw_score = false;
      bool predict_contrib = false;
      if (predict_type == C_API_PREDICT_LEAF_INDEX) {
        is_predict_leaf = true;
      } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
        is_raw_score = true;
      } else if (predict_type == C_API_PREDICT_CONTRIB) {
        predict_contrib = true;
      } else {
        is_raw_score = false;
      }

223
      // TODO(eisber): config could be optimized away... (maybe using lambda callback?)
224
      single_row_predictor_.reset(new Predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, predict_contrib,
225
                                                config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin));
226
227
228
229
230
231
232
233
234
235
236
237
      single_row_num_pred_in_one_row_ = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf, predict_contrib);
      single_row_predict_function_ = single_row_predictor_->GetPredictFunction();
    }

    auto one_row = get_row_fun(0);
    auto pred_wrt_ptr = out_result;
    single_row_predict_function_(one_row, pred_wrt_ptr);

    *out_len = single_row_num_pred_in_one_row_;
  }


Guolin Ke's avatar
Guolin Ke committed
238
239
  void Predict(int num_iteration, int predict_type, int nrow,
               std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
240
               const Config& config,
Guolin Ke's avatar
Guolin Ke committed
241
               double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
242
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
243
244
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
245
    bool predict_contrib = false;
Guolin Ke's avatar
Guolin Ke committed
246
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
247
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
248
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
249
      is_raw_score = true;
250
    } else if (predict_type == C_API_PREDICT_CONTRIB) {
Guolin Ke's avatar
Guolin Ke committed
251
      predict_contrib = true;
Guolin Ke's avatar
Guolin Ke committed
252
253
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
254
    }
Guolin Ke's avatar
Guolin Ke committed
255

Guolin Ke's avatar
Guolin Ke committed
256
    Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, predict_contrib,
257
                        config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
Guolin Ke's avatar
Guolin Ke committed
258
    int64_t num_pred_in_one_row = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf, predict_contrib);
Guolin Ke's avatar
Guolin Ke committed
259
    auto pred_fun = predictor.GetPredictFunction();
260
261
    OMP_INIT_EX();
    #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
262
    for (int i = 0; i < nrow; ++i) {
263
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
264
      auto one_row = get_row_fun(i);
Tony-Y's avatar
Tony-Y committed
265
      auto pred_wrt_ptr = out_result + static_cast<size_t>(num_pred_in_one_row) * i;
Guolin Ke's avatar
Guolin Ke committed
266
      pred_fun(one_row, pred_wrt_ptr);
267
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
268
    }
269
    OMP_THROW_EX();
270
    *out_len = num_pred_in_one_row * nrow;
Guolin Ke's avatar
Guolin Ke committed
271
272
273
  }

  void Predict(int num_iteration, int predict_type, const char* data_filename,
Guolin Ke's avatar
Guolin Ke committed
274
               int data_has_header, const Config& config,
cbecker's avatar
cbecker committed
275
               const char* result_filename) {
Guolin Ke's avatar
Guolin Ke committed
276
277
278
    std::lock_guard<std::mutex> lock(mutex_);
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
279
    bool predict_contrib = false;
Guolin Ke's avatar
Guolin Ke committed
280
281
282
283
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
      is_predict_leaf = true;
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
      is_raw_score = true;
284
    } else if (predict_type == C_API_PREDICT_CONTRIB) {
Guolin Ke's avatar
Guolin Ke committed
285
      predict_contrib = true;
Guolin Ke's avatar
Guolin Ke committed
286
287
288
    } else {
      is_raw_score = false;
    }
Guolin Ke's avatar
Guolin Ke committed
289
    Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, predict_contrib,
290
                        config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
Guolin Ke's avatar
Guolin Ke committed
291
292
    bool bool_data_has_header = data_has_header > 0 ? true : false;
    predictor.Predict(data_filename, result_filename, bool_data_has_header);
Guolin Ke's avatar
Guolin Ke committed
293
294
  }

Guolin Ke's avatar
Guolin Ke committed
295
  void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
296
297
298
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

299
300
  /*!
   * \brief Forward to Boosting::SaveModelToFile.
   * \param start_iteration Forwarded unchanged
   * \param num_iteration Forwarded unchanged
   * \param filename Destination path, forwarded unchanged
   * \note Any value returned by the boosting object is not inspected.
   */
  void SaveModelToFile(int start_iteration, int num_iteration, const char* filename) {
    auto* model = boosting_.get();
    model->SaveModelToFile(start_iteration, num_iteration, filename);
  }
302

303
  /*!
   * \brief Load a model from a NUL-terminated string.
   * \param model_str Model text; its length is measured with std::strlen
   */
  void LoadModelFromString(const char* model_str) {
    boosting_->LoadModelFromString(model_str, std::strlen(model_str));
  }

308
309
  /*!
   * \brief Forward to Boosting::SaveModelToString.
   * \param start_iteration Forwarded unchanged
   * \param num_iteration Forwarded unchanged
   * \return The string produced by the boosting object
   */
  std::string SaveModelToString(int start_iteration, int num_iteration) {
    std::string text = boosting_->SaveModelToString(start_iteration, num_iteration);
    return text;
  }

312
  std::string DumpModel(int start_iteration, int num_iteration) {
313
    return boosting_->DumpModel(start_iteration, num_iteration);
wxchan's avatar
wxchan committed
314
  }
315

316
317
318
319
  /*!
   * \brief Forward to Boosting::FeatureImportance.
   * \param num_iteration Forwarded unchanged
   * \param importance_type Forwarded unchanged
   * \return The vector produced by the boosting object
   */
  std::vector<double> FeatureImportance(int num_iteration, int importance_type) {
    std::vector<double> importance = boosting_->FeatureImportance(num_iteration, importance_type);
    return importance;
  }

Guolin Ke's avatar
Guolin Ke committed
320
  double GetLeafValue(int tree_idx, int leaf_idx) const {
Guolin Ke's avatar
Guolin Ke committed
321
    return dynamic_cast<GBDTBase*>(boosting_.get())->GetLeafValue(tree_idx, leaf_idx);
Guolin Ke's avatar
Guolin Ke committed
322
323
324
325
  }

  void SetLeafValue(int tree_idx, int leaf_idx, double val) {
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
326
    dynamic_cast<GBDTBase*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val);
Guolin Ke's avatar
Guolin Ke committed
327
328
  }

329
  void ShuffleModels(int start_iter, int end_iter) {
330
    std::lock_guard<std::mutex> lock(mutex_);
331
    boosting_->ShuffleModels(start_iter, end_iter);
332
333
  }

wxchan's avatar
wxchan committed
334
335
336
337
338
339
340
  /*!
   * \brief Count the evaluation names exposed by the training metrics.
   * \return Sum of GetName().size() over all training metrics
   */
  int GetEvalCounts() const {
    int total = 0;
    for (size_t i = 0; i < train_metric_.size(); ++i) {
      const auto& names = train_metric_[i]->GetName();
      total += static_cast<int>(names.size());
    }
    return total;
  }
341

wxchan's avatar
wxchan committed
342
343
344
345
  /*!
   * \brief Copy every training-metric name into caller-provided buffers.
   * \param out_strs Array of destination buffers; buffer k receives the k-th name
   *        including its terminating NUL. No buffer size is checked here.
   * \return Number of names written
   */
  int GetEvalNames(char** out_strs) const {
    int written = 0;
    for (size_t m = 0; m < train_metric_.size(); ++m) {
      const auto& names = train_metric_[m]->GetName();
      for (const auto& name : names) {
        // size() + 1 so the terminating NUL is copied too.
        std::memcpy(out_strs[written], name.c_str(), name.size() + 1);
        written += 1;
      }
    }
    return written;
  }

wxchan's avatar
wxchan committed
353
354
355
  /*!
   * \brief Copy every feature name reported by the boosting object into caller-provided buffers.
   * \param out_strs Array of destination buffers; buffer k receives the k-th name
   *        including its terminating NUL. No buffer size is checked here.
   * \return Number of names written
   */
  int GetFeatureNames(char** out_strs) const {
    int written = 0;
    for (const auto& feature_name : boosting_->FeatureNames()) {
      char* destination = out_strs[written];
      // size() + 1 so the terminating NUL is copied too.
      std::memcpy(destination, feature_name.c_str(), feature_name.size() + 1);
      written += 1;
    }
    return written;
  }

wxchan's avatar
wxchan committed
362
  /*! \brief Read-only access to the boosting object held by this booster */
  const Boosting* GetBoosting() const {
    return boosting_.get();
  }
Guolin Ke's avatar
Guolin Ke committed
363

Nikita Titov's avatar
Nikita Titov committed
364
 private:
wxchan's avatar
wxchan committed
365
  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
366
  std::unique_ptr<Boosting> boosting_;
367
368
369
370
  std::unique_ptr<Predictor> single_row_predictor_;
  PredictFunction single_row_predict_function_;
  int64_t single_row_num_pred_in_one_row_;

Guolin Ke's avatar
Guolin Ke committed
371
  /*! \brief All configs */
Guolin Ke's avatar
Guolin Ke committed
372
  Config config_;
Guolin Ke's avatar
Guolin Ke committed
373
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
374
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
375
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
376
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
377
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
378
  std::unique_ptr<ObjectiveFunction> objective_fun_;
wxchan's avatar
wxchan committed
379
380
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
381
382
};

383
}  // namespace LightGBM
Guolin Ke's avatar
Guolin Ke committed
384
385
386

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
387
388
389
390
391
392
393
394
// some helper functions used to convert data

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major);

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major);

395
396
397
std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseRows(const void** data, int num_col, int data_type);

Guolin Ke's avatar
Guolin Ke committed
398
399
std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
400
                   const void* data, int data_type, int64_t nindptr, int64_t nelem);
Guolin Ke's avatar
Guolin Ke committed
401
402
403

// Row iterator over one column of a CSC matrix
class CSC_RowIterator {
 public:
  CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
                  const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx);
  ~CSC_RowIterator() = default;

  // Value at row idx; indices must be requested in ascending order.
  double Get(int idx);

  // Next non-zero (row, value) pair; a negative row index means no more data.
  std::pair<int, double> NextNonZero();

 private:
  // Iteration state; each member starts from the in-class default below.
  int nonzero_idx_ = 0;
  int cur_idx_ = -1;
  double cur_val_ = 0.0;
  bool is_end_ = false;
  // Accessor mapping a position to an (index, value) pair.
  std::function<std::pair<int, double>(int idx)> iter_fun_;
};

// start of c_api functions

Guolin Ke's avatar
Guolin Ke committed
423
const char* LGBM_GetLastError() {
wxchan's avatar
wxchan committed
424
  return LastErrorMsg();
Guolin Ke's avatar
Guolin Ke committed
425
426
}

Guolin Ke's avatar
Guolin Ke committed
427
int LGBM_DatasetCreateFromFile(const char* filename,
428
429
430
                               const char* parameters,
                               const DatasetHandle reference,
                               DatasetHandle* out) {
431
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
432
433
  auto param = Config::Str2Map(parameters);
  Config config;
434
435
436
437
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
438
  DatasetLoader loader(config, nullptr, 1, filename);
Guolin Ke's avatar
Guolin Ke committed
439
  if (reference == nullptr) {
440
441
442
443
444
    if (Network::num_machines() == 1) {
      *out = loader.LoadFromFile(filename, "");
    } else {
      *out = loader.LoadFromFile(filename, "", Network::rank(), Network::num_machines());
    }
Guolin Ke's avatar
Guolin Ke committed
445
  } else {
446
    *out = loader.LoadFromFileAlignWithOtherDataset(filename, "",
447
                                                    reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
448
  }
449
  API_END();
Guolin Ke's avatar
Guolin Ke committed
450
451
}

452

Guolin Ke's avatar
Guolin Ke committed
453
int LGBM_DatasetCreateFromSampledColumn(double** sample_data,
454
455
456
457
458
459
460
                                        int** sample_indices,
                                        int32_t ncol,
                                        const int* num_per_col,
                                        int32_t num_sample_row,
                                        int32_t num_total_row,
                                        const char* parameters,
                                        DatasetHandle* out) {
461
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
462
463
  auto param = Config::Str2Map(parameters);
  Config config;
464
465
466
467
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
468
  DatasetLoader loader(config, nullptr, 1, nullptr);
469
470
471
472
  *out = loader.CostructFromSampleData(sample_data, sample_indices, ncol, num_per_col,
                                       num_sample_row,
                                       static_cast<data_size_t>(num_total_row));
  API_END();
Guolin Ke's avatar
Guolin Ke committed
473
474
}

475

Guolin Ke's avatar
Guolin Ke committed
476
int LGBM_DatasetCreateByReference(const DatasetHandle reference,
477
478
                                  int64_t num_total_row,
                                  DatasetHandle* out) {
Guolin Ke's avatar
Guolin Ke committed
479
480
481
482
483
484
485
486
  API_BEGIN();
  std::unique_ptr<Dataset> ret;
  ret.reset(new Dataset(static_cast<data_size_t>(num_total_row)));
  ret->CreateValid(reinterpret_cast<const Dataset*>(reference));
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
487
int LGBM_DatasetPushRows(DatasetHandle dataset,
488
489
490
491
492
                         const void* data,
                         int data_type,
                         int32_t nrow,
                         int32_t ncol,
                         int32_t start_row) {
Guolin Ke's avatar
Guolin Ke committed
493
494
495
  API_BEGIN();
  auto p_dataset = reinterpret_cast<Dataset*>(dataset);
  auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, 1);
496
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
497
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
498
  for (int i = 0; i < nrow; ++i) {
499
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
500
501
502
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    p_dataset->PushOneRow(tid, start_row + i, one_row);
503
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
504
  }
505
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
506
507
508
509
510
511
  if (start_row + nrow == p_dataset->num_data()) {
    p_dataset->FinishLoad();
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
512
/*!
 * \brief Push a CSR block of rows into an existing dataset.
 * \param dataset Handle of the target dataset
 * \param indptr Forwarded to RowFunctionFromCSR
 * \param indptr_type Forwarded to RowFunctionFromCSR
 * \param indices Forwarded to RowFunctionFromCSR
 * \param data Forwarded to RowFunctionFromCSR
 * \param data_type Forwarded to RowFunctionFromCSR
 * \param nindptr Length of indptr; the number of rows pushed is nindptr - 1
 * \param nelem Forwarded to RowFunctionFromCSR
 * \param start_row Dataset row index that receives the first pushed row
 * \return 0 on success, -1 when an exception was caught
 * \note The unnamed int64_t parameter is not used.
 */
int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset,
                              const void* indptr,
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t,
                              int64_t start_row) {
  API_BEGIN();
  Dataset* target = reinterpret_cast<Dataset*>(dataset);
  auto row_reader = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  const int32_t nrow = static_cast<int32_t>(nindptr - 1);
  OMP_INIT_EX();
  #pragma omp parallel for schedule(static)
  for (int i = 0; i < nrow; ++i) {
    OMP_LOOP_EX_BEGIN();
    const int thread_id = omp_get_thread_num();
    auto row_pairs = row_reader(i);
    target->PushOneRow(thread_id,
                       static_cast<data_size_t>(start_row + i), row_pairs);
    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();
  // Loading is finished once the pushed block ends exactly at the dataset's row count.
  const bool reached_end = (start_row + nrow == static_cast<int64_t>(target->num_data()));
  if (reached_end) {
    target->FinishLoad();
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
543
int LGBM_DatasetCreateFromMat(const void* data,
544
545
546
547
548
549
550
                              int data_type,
                              int32_t nrow,
                              int32_t ncol,
                              int is_row_major,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
  return LGBM_DatasetCreateFromMats(1,
                                    &data,
                                    data_type,
                                    &nrow,
                                    ncol,
                                    is_row_major,
                                    parameters,
                                    reference,
                                    out);
}


int LGBM_DatasetCreateFromMats(int32_t nmat,
                               const void** data,
                               int data_type,
                               int32_t* nrow,
                               int32_t ncol,
                               int is_row_major,
                               const char* parameters,
                               const DatasetHandle reference,
                               DatasetHandle* out) {
572
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
573
574
  auto param = Config::Str2Map(parameters);
  Config config;
575
576
577
578
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
579
  std::unique_ptr<Dataset> ret;
580
581
582
583
584
585
586
587
588
  int32_t total_nrow = 0;
  for (int j = 0; j < nmat; ++j) {
    total_nrow += nrow[j];
  }

  std::vector<std::function<std::vector<double>(int row_idx)>> get_row_fun;
  for (int j = 0; j < nmat; ++j) {
    get_row_fun.push_back(RowFunctionFromDenseMatric(data[j], nrow[j], ncol, data_type, is_row_major));
  }
589

Guolin Ke's avatar
Guolin Ke committed
590
591
  if (reference == nullptr) {
    // sample data first
Guolin Ke's avatar
Guolin Ke committed
592
    Random rand(config.data_random_seed);
593
594
    int sample_cnt = static_cast<int>(total_nrow < config.bin_construct_sample_cnt ? total_nrow : config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(total_nrow, sample_cnt);
595
    sample_cnt = static_cast<int>(sample_indices.size());
596
    std::vector<std::vector<double>> sample_values(ncol);
Guolin Ke's avatar
Guolin Ke committed
597
    std::vector<std::vector<int>> sample_idx(ncol);
598
599
600

    int offset = 0;
    int j = 0;
Guolin Ke's avatar
Guolin Ke committed
601
    for (size_t i = 0; i < sample_indices.size(); ++i) {
Guolin Ke's avatar
Guolin Ke committed
602
      auto idx = sample_indices[i];
603
604
605
606
      while ((idx - offset) >= nrow[j]) {
        offset += nrow[j];
        ++j;
      }
607

608
609
610
611
612
      auto row = get_row_fun[j](static_cast<int>(idx - offset));
      for (size_t k = 0; k < row.size(); ++k) {
        if (std::fabs(row[k]) > kZeroThreshold || std::isnan(row[k])) {
          sample_values[k].emplace_back(row[k]);
          sample_idx[k].emplace_back(static_cast<int>(i));
Guolin Ke's avatar
Guolin Ke committed
613
        }
Guolin Ke's avatar
Guolin Ke committed
614
615
      }
    }
Guolin Ke's avatar
Guolin Ke committed
616
    DatasetLoader loader(config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
617
618
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
619
620
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
621
                                            sample_cnt, total_nrow));
Guolin Ke's avatar
Guolin Ke committed
622
  } else {
623
    ret.reset(new Dataset(total_nrow));
Guolin Ke's avatar
Guolin Ke committed
624
    ret->CreateValid(
625
      reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
626
  }
627
628
629
630
631
632
633
634
635
636
637
638
639
640
  int32_t start_row = 0;
  for (int j = 0; j < nmat; ++j) {
    OMP_INIT_EX();
    #pragma omp parallel for schedule(static)
    for (int i = 0; i < nrow[j]; ++i) {
      OMP_LOOP_EX_BEGIN();
      const int tid = omp_get_thread_num();
      auto one_row = get_row_fun[j](i);
      ret->PushOneRow(tid, start_row + i, one_row);
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();

    start_row += nrow[j];
Guolin Ke's avatar
Guolin Ke committed
641
642
  }
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
643
  *out = ret.release();
644
  API_END();
645
646
}

Guolin Ke's avatar
Guolin Ke committed
647
int LGBM_DatasetCreateFromCSR(const void* indptr,
648
649
650
651
652
653
654
655
656
657
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t num_col,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
658
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
659
660
  auto param = Config::Str2Map(parameters);
  Config config;
661
662
663
664
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
665
  std::unique_ptr<Dataset> ret;
666
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
667
668
669
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference == nullptr) {
    // sample data first
Guolin Ke's avatar
Guolin Ke committed
670
671
    Random rand(config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt);
672
    auto sample_indices = rand.Sample(nrow, sample_cnt);
673
    sample_cnt = static_cast<int>(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
674
675
    std::vector<std::vector<double>> sample_values(num_col);
    std::vector<std::vector<int>> sample_idx(num_col);
676
677
678
679
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      auto row = get_row_fun(static_cast<int>(idx));
      for (std::pair<int, double>& inner_data : row) {
Guolin Ke's avatar
Guolin Ke committed
680
        CHECK(inner_data.first < num_col);
Guolin Ke's avatar
Guolin Ke committed
681
        if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
Guolin Ke's avatar
Guolin Ke committed
682
683
          sample_values[inner_data.first].emplace_back(inner_data.second);
          sample_idx[inner_data.first].emplace_back(static_cast<int>(i));
684
685
686
        }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
687
    DatasetLoader loader(config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
688
689
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
690
691
692
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
693
  } else {
694
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
695
    ret->CreateValid(
696
      reinterpret_cast<const Dataset*>(reference));
697
  }
698
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
699
  #pragma omp parallel for schedule(static)
700
  for (int i = 0; i < nindptr - 1; ++i) {
701
    OMP_LOOP_EX_BEGIN();
702
703
704
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
705
    OMP_LOOP_EX_END();
706
  }
707
  OMP_THROW_EX();
708
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
709
  *out = ret.release();
710
  API_END();
711
712
}

713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
int LGBM_DatasetCreateFromCSRFunc(void* get_row_funptr,
                              int num_rows,
                              int64_t num_col,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
  API_BEGIN();

  auto get_row_fun = *static_cast<std::function<void(int idx, std::vector<std::pair<int, double>>&)>*>(get_row_funptr);

  auto param = Config::Str2Map(parameters);
  Config config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  std::unique_ptr<Dataset> ret;
  int32_t nrow = num_rows;
  if (reference == nullptr) {
    // sample data first
    Random rand(config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    sample_cnt = static_cast<int>(sample_indices.size());
    std::vector<std::vector<double>> sample_values(num_col);
    std::vector<std::vector<int>> sample_idx(num_col);
    // local buffer to re-use memory
    std::vector<std::pair<int, double>> buffer;
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      get_row_fun(static_cast<int>(idx), buffer);
      for (std::pair<int, double>& inner_data : buffer) {
        CHECK(inner_data.first < num_col);
        if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
          sample_values[inner_data.first].emplace_back(inner_data.second);
          sample_idx[inner_data.first].emplace_back(static_cast<int>(i));
        }
      }
    }
    DatasetLoader loader(config, nullptr, 1, nullptr);
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow));
    ret->CreateValid(
      reinterpret_cast<const Dataset*>(reference));
  }
763

764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
  OMP_INIT_EX();
  std::vector<std::pair<int, double>> threadBuffer;
  #pragma omp parallel for schedule(static) private(threadBuffer)
  for (int i = 0; i < num_rows; ++i) {
    OMP_LOOP_EX_BEGIN();
    {
            const int tid = omp_get_thread_num();
            get_row_fun(i, threadBuffer);

            ret->PushOneRow(tid, i, threadBuffer);
    }
    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
783
int LGBM_DatasetCreateFromCSC(const void* col_ptr,
784
785
786
787
788
789
790
791
792
793
                              int col_ptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t ncol_ptr,
                              int64_t nelem,
                              int64_t num_row,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
794
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
795
796
  auto param = Config::Str2Map(parameters);
  Config config;
797
798
799
800
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
801
  std::unique_ptr<Dataset> ret;
Guolin Ke's avatar
Guolin Ke committed
802
803
804
  int32_t nrow = static_cast<int32_t>(num_row);
  if (reference == nullptr) {
    // sample data first
Guolin Ke's avatar
Guolin Ke committed
805
806
    Random rand(config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
807
    auto sample_indices = rand.Sample(nrow, sample_cnt);
808
    sample_cnt = static_cast<int>(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
809
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
Guolin Ke's avatar
Guolin Ke committed
810
    std::vector<std::vector<int>> sample_idx(ncol_ptr - 1);
811
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
812
    #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
813
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
814
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
815
816
817
      CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
      for (int j = 0; j < sample_cnt; j++) {
        auto val = col_it.Get(sample_indices[j]);
Guolin Ke's avatar
Guolin Ke committed
818
        if (std::fabs(val) > kZeroThreshold || std::isnan(val)) {
Guolin Ke's avatar
Guolin Ke committed
819
820
          sample_values[i].emplace_back(val);
          sample_idx[i].emplace_back(j);
Guolin Ke's avatar
Guolin Ke committed
821
822
        }
      }
823
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
824
    }
825
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
826
    DatasetLoader loader(config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
827
828
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
829
830
831
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
Guolin Ke's avatar
Guolin Ke committed
832
  } else {
833
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
834
    ret->CreateValid(
835
      reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
836
  }
837
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
838
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
839
  for (int i = 0; i < ncol_ptr - 1; ++i) {
840
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
841
    const int tid = omp_get_thread_num();
Guolin Ke's avatar
Guolin Ke committed
842
    int feature_idx = ret->InnerFeatureIndex(i);
Guolin Ke's avatar
Guolin Ke committed
843
    if (feature_idx < 0) { continue; }
Guolin Ke's avatar
Guolin Ke committed
844
845
    int group = ret->Feature2Group(feature_idx);
    int sub_feature = ret->Feture2SubFeature(feature_idx);
Guolin Ke's avatar
Guolin Ke committed
846
847
848
849
850
851
852
    CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
    int row_idx = 0;
    while (row_idx < nrow) {
      auto pair = col_it.NextNonZero();
      row_idx = pair.first;
      // no more data
      if (row_idx < 0) { break; }
Guolin Ke's avatar
Guolin Ke committed
853
      ret->PushOneData(tid, row_idx, group, sub_feature, pair.second);
Guolin Ke's avatar
Guolin Ke committed
854
    }
855
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
856
  }
857
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
858
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
859
  *out = ret.release();
860
  API_END();
Guolin Ke's avatar
Guolin Ke committed
861
862
}

Guolin Ke's avatar
Guolin Ke committed
863
int LGBM_DatasetGetSubset(
864
  const DatasetHandle handle,
wxchan's avatar
wxchan committed
865
866
867
  const int32_t* used_row_indices,
  int32_t num_used_row_indices,
  const char* parameters,
Guolin Ke's avatar
typo  
Guolin Ke committed
868
  DatasetHandle* out) {
wxchan's avatar
wxchan committed
869
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
870
871
  auto param = Config::Str2Map(parameters);
  Config config;
872
873
874
875
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
876
  auto full_dataset = reinterpret_cast<const Dataset*>(handle);
Guolin Ke's avatar
Guolin Ke committed
877
  CHECK(num_used_row_indices > 0);
878
879
880
  const int32_t lower = 0;
  const int32_t upper = full_dataset->num_data() - 1;
  Common::CheckElementsIntervalClosed(used_row_indices, lower, upper, num_used_row_indices, "Used indices of subset");
Guolin Ke's avatar
Guolin Ke committed
881
  auto ret = std::unique_ptr<Dataset>(new Dataset(num_used_row_indices));
882
  ret->CopyFeatureMapperFrom(full_dataset);
Guolin Ke's avatar
Guolin Ke committed
883
  ret->CopySubset(full_dataset, used_row_indices, num_used_row_indices, true);
wxchan's avatar
wxchan committed
884
885
886
887
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
888
int LGBM_DatasetSetFeatureNames(
Guolin Ke's avatar
typo  
Guolin Ke committed
889
  DatasetHandle handle,
Guolin Ke's avatar
Guolin Ke committed
890
  const char** feature_names,
Guolin Ke's avatar
Guolin Ke committed
891
  int num_feature_names) {
Guolin Ke's avatar
Guolin Ke committed
892
893
894
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  std::vector<std::string> feature_names_str;
Guolin Ke's avatar
Guolin Ke committed
895
  for (int i = 0; i < num_feature_names; ++i) {
Guolin Ke's avatar
Guolin Ke committed
896
897
898
899
900
901
    feature_names_str.emplace_back(feature_names[i]);
  }
  dataset->set_feature_names(feature_names_str);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
902
int LGBM_DatasetGetFeatureNames(
903
904
  DatasetHandle handle,
  char** feature_names,
Guolin Ke's avatar
Guolin Ke committed
905
  int* num_feature_names) {
906
907
908
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  auto inside_feature_name = dataset->feature_names();
Guolin Ke's avatar
Guolin Ke committed
909
910
  *num_feature_names = static_cast<int>(inside_feature_name.size());
  for (int i = 0; i < *num_feature_names; ++i) {
Guolin Ke's avatar
Guolin Ke committed
911
    std::memcpy(feature_names[i], inside_feature_name[i].c_str(), inside_feature_name[i].size() + 1);
912
913
914
915
  }
  API_END();
}

916
#pragma warning(disable : 4702)
Guolin Ke's avatar
Guolin Ke committed
917
/*!
 * \brief Destroy the Dataset behind a handle.
 * \param handle Handle of the Dataset to delete
 * \return 0 on success, -1 if an exception was caught (message recorded through LGBM_SetLastError)
 */
int LGBM_DatasetFree(DatasetHandle handle) {
  API_BEGIN();
  Dataset* doomed = reinterpret_cast<Dataset*>(handle);
  delete doomed;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
923
int LGBM_DatasetSaveBinary(DatasetHandle handle,
924
                           const char* filename) {
925
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
926
927
  auto dataset = reinterpret_cast<Dataset*>(handle);
  dataset->SaveBinaryFile(filename);
928
  API_END();
929
930
}

931
932
933
934
935
936
937
938
/*!
 * \brief Dump a Dataset through Dataset::DumpTextFile.
 * \param handle   Handle of the Dataset
 * \param filename Path passed to DumpTextFile
 * \return 0 on success, -1 if an exception was caught (message recorded through LGBM_SetLastError)
 */
int LGBM_DatasetDumpText(DatasetHandle handle,
                         const char* filename) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->DumpTextFile(filename);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
939
/*!
 * \brief Set a named field of a Dataset from a typed array.
 *
 * `type` selects which setter is used: C_API_DTYPE_FLOAT32 -> SetFloatField,
 * C_API_DTYPE_INT32 -> SetIntField, C_API_DTYPE_FLOAT64 -> SetDoubleField.
 * Any other type, or a setter reporting failure, raises an error.
 *
 * \param handle      Handle of the Dataset
 * \param field_name  Name of the field
 * \param field_data  Pointer to the field values, interpreted according to `type`
 * \param num_element Number of values in `field_data`
 * \param type        One of the C_API_DTYPE_* tags listed above
 * \return 0 on success, -1 if an exception was caught (message recorded through LGBM_SetLastError)
 */
int LGBM_DatasetSetField(DatasetHandle handle,
                         const char* field_name,
                         const void* field_data,
                         int num_element,
                         int type) {
  API_BEGIN();
  Dataset* target = reinterpret_cast<Dataset*>(handle);
  const int32_t count = static_cast<int32_t>(num_element);
  bool stored = false;
  if (type == C_API_DTYPE_FLOAT32) {
    stored = target->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), count);
  } else if (type == C_API_DTYPE_INT32) {
    stored = target->SetIntField(field_name, reinterpret_cast<const int*>(field_data), count);
  } else if (type == C_API_DTYPE_FLOAT64) {
    stored = target->SetDoubleField(field_name, reinterpret_cast<const double*>(field_data), count);
  }
  if (!stored) {
    throw std::runtime_error("Input data type error or field not found");
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
958
/*!
 * \brief Look up a named field of a Dataset and report its data pointer, length and type.
 *
 * The getters are tried in this order: float, int, double, int8. The first one that
 * succeeds determines `*out_type`. If none succeeds an error is raised. When the
 * resulting pointer is null, `*out_len` is set to 0.
 *
 * \param handle     Handle of the Dataset
 * \param field_name Name of the field
 * \param out_len    Receives the number of elements
 * \param out_ptr    Receives the pointer to the field data
 * \param out_type   Receives the matching C_API_DTYPE_* tag
 * \return 0 on success, -1 if an exception was caught (message recorded through LGBM_SetLastError)
 */
int LGBM_DatasetGetField(DatasetHandle handle,
                         const char* field_name,
                         int* out_len,
                         const void** out_ptr,
                         int* out_type) {
  API_BEGIN();
  Dataset* source = reinterpret_cast<Dataset*>(handle);
  bool found = true;
  if (source->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (source->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else if (source->GetDoubleField(field_name, out_len, reinterpret_cast<const double**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT64;
  } else if (source->GetInt8Field(field_name, out_len, reinterpret_cast<const int8_t**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT8;
  } else {
    found = false;
  }
  if (!found) {
    throw std::runtime_error("Field not found");
  }
  // a field that exists but holds no data is reported with length 0
  if (*out_ptr == nullptr) {
    *out_len = 0;
  }
  API_END();
}

984
985
986
987
988
989
990
/*!
 * \brief Forward a parameter string to Dataset::ResetConfig.
 * \param handle     Handle of the Dataset
 * \param parameters Parameter string passed to ResetConfig
 * \return 0 on success, -1 if an exception was caught (message recorded through LGBM_SetLastError)
 */
int LGBM_DatasetUpdateParam(DatasetHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->ResetConfig(parameters);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
991
int LGBM_DatasetGetNumData(DatasetHandle handle,
992
                           int* out) {
993
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
994
995
  auto dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_data();
996
  API_END();
997
998
}

Guolin Ke's avatar
Guolin Ke committed
999
int LGBM_DatasetGetNumFeature(DatasetHandle handle,
1000
                              int* out) {
1001
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1002
1003
  auto dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_total_features();
1004
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1005
}
1006

1007
1008
1009
1010
1011
1012
1013
1014
1015
/*!
 * \brief Call Dataset::addFeaturesFrom on `target`, passing `source`.
 * \param target Handle of the Dataset that receives the features
 * \param source Handle of the Dataset the features are taken from
 * \return 0 on success, -1 if an exception was caught (message recorded through LGBM_SetLastError)
 */
int LGBM_DatasetAddFeaturesFrom(DatasetHandle target,
                                DatasetHandle source) {
  API_BEGIN();
  Dataset* receiver = reinterpret_cast<Dataset*>(target);
  Dataset* donor = reinterpret_cast<Dataset*>(source);
  receiver->addFeaturesFrom(donor);
  API_END();
}

1016
1017
// ---- start of booster

Guolin Ke's avatar
Guolin Ke committed
1018
int LGBM_BoosterCreate(const DatasetHandle train_data,
1019
1020
                       const char* parameters,
                       BoosterHandle* out) {
1021
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1022
  const Dataset* p_train_data = reinterpret_cast<const Dataset*>(train_data);
wxchan's avatar
wxchan committed
1023
1024
  auto ret = std::unique_ptr<Booster>(new Booster(p_train_data, parameters));
  *out = ret.release();
1025
  API_END();
1026
1027
}

Guolin Ke's avatar
Guolin Ke committed
1028
int LGBM_BoosterCreateFromModelfile(
1029
  const char* filename,
Guolin Ke's avatar
Guolin Ke committed
1030
  int* out_num_iterations,
1031
  BoosterHandle* out) {
1032
  API_BEGIN();
wxchan's avatar
wxchan committed
1033
  auto ret = std::unique_ptr<Booster>(new Booster(filename));
Guolin Ke's avatar
Guolin Ke committed
1034
  *out_num_iterations = ret->GetBoosting()->GetCurrentIteration();
wxchan's avatar
wxchan committed
1035
  *out = ret.release();
1036
  API_END();
1037
1038
}

Guolin Ke's avatar
Guolin Ke committed
1039
int LGBM_BoosterLoadModelFromString(
1040
1041
1042
1043
  const char* model_str,
  int* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
wxchan's avatar
wxchan committed
1044
  auto ret = std::unique_ptr<Booster>(new Booster(nullptr));
1045
1046
1047
1048
1049
1050
  ret->LoadModelFromString(model_str);
  *out_num_iterations = ret->GetBoosting()->GetCurrentIteration();
  *out = ret.release();
  API_END();
}

1051
#pragma warning(disable : 4702)
Guolin Ke's avatar
Guolin Ke committed
1052
/*!
 * \brief Destroy the Booster behind a handle.
 * \param handle Handle of the Booster to delete
 * \return 0 on success, -1 if an exception was caught (message recorded through LGBM_SetLastError)
 */
int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  Booster* doomed = reinterpret_cast<Booster*>(handle);
  delete doomed;
  API_END();
}

1058
/*!
 * \brief Forward to Booster::ShuffleModels.
 * \param handle     Handle of the Booster
 * \param start_iter First argument passed to ShuffleModels
 * \param end_iter   Second argument passed to ShuffleModels
 * \return 0 on success, -1 if an exception was caught (message recorded through LGBM_SetLastError)
 */
int LGBM_BoosterShuffleModels(BoosterHandle handle, int start_iter, int end_iter) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ShuffleModels(start_iter, end_iter);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1065
/*!
 * \brief Call Booster::MergeFrom on `handle`, passing `other_handle`.
 * \param handle       Handle of the Booster that MergeFrom is called on
 * \param other_handle Handle of the Booster passed as argument
 * \return 0 on success, -1 if an exception was caught (message recorded through LGBM_SetLastError)
 */
int LGBM_BoosterMerge(BoosterHandle handle,
                      BoosterHandle other_handle) {
  API_BEGIN();
  Booster* receiver = reinterpret_cast<Booster*>(handle);
  Booster* donor = reinterpret_cast<Booster*>(other_handle);
  receiver->MergeFrom(donor);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1074
int LGBM_BoosterAddValidData(BoosterHandle handle,
1075
                             const DatasetHandle valid_data) {
wxchan's avatar
wxchan committed
1076
1077
1078
1079
1080
1081
1082
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const Dataset* p_dataset = reinterpret_cast<const Dataset*>(valid_data);
  ref_booster->AddValidData(p_dataset);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1083
int LGBM_BoosterResetTrainingData(BoosterHandle handle,
1084
                                  const DatasetHandle train_data) {
wxchan's avatar
wxchan committed
1085
1086
1087
1088
1089
1090
1091
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const Dataset* p_dataset = reinterpret_cast<const Dataset*>(train_data);
  ref_booster->ResetTrainingData(p_dataset);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1092
/*!
 * \brief Forward a parameter string to Booster::ResetConfig.
 * \param handle     Handle of the Booster
 * \param parameters Parameter string passed to ResetConfig
 * \return 0 on success, -1 if an exception was caught (message recorded through LGBM_SetLastError)
 */
int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ResetConfig(parameters);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1099
/*!
 * \brief Report NumberOfClasses() of the Booster's boosting object.
 * \param handle  Handle of the Booster
 * \param out_len Receives the value of NumberOfClasses()
 * \return 0 on success, -1 if an exception was caught (message recorded through LGBM_SetLastError)
 */
int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->NumberOfClasses();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1106
1107
1108
1109
1110
1111
1112
/*!
 * \brief Forward to Booster::Refit.
 * \param handle     Handle of the Booster
 * \param leaf_preds First argument passed to Refit
 * \param nrow       Second argument passed to Refit
 * \param ncol       Third argument passed to Refit
 * \return 0 on success, -1 if an exception was caught (message recorded through LGBM_SetLastError)
 */
int LGBM_BoosterRefit(BoosterHandle handle, const int32_t* leaf_preds, int32_t nrow, int32_t ncol) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->Refit(leaf_preds, nrow, ncol);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1113
/*!
 * \brief Run one call of Booster::TrainOneIter().
 * \param handle      Handle of the Booster
 * \param is_finished Receives 1 if TrainOneIter() returned true, 0 otherwise
 * \return 0 on success, -1 if an exception was caught (message recorded through LGBM_SetLastError)
 */
int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  *is_finished = booster->TrainOneIter() ? 1 : 0;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1124
/*!
 * \brief Run one call of Booster::TrainOneIter(grad, hess) with caller-supplied arrays.
 *
 * When SCORE_T_USE_DOUBLE is defined this function only reports a fatal error.
 *
 * \param handle      Handle of the Booster
 * \param grad        First array passed to TrainOneIter
 * \param hess        Second array passed to TrainOneIter
 * \param is_finished Receives 1 if TrainOneIter returned true, 0 otherwise
 * \return 0 on success, -1 if an exception was caught (message recorded through LGBM_SetLastError)
 */
int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
                                    const float* grad,
                                    const float* hess,
                                    int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  #ifdef SCORE_T_USE_DOUBLE
  Log::Fatal("Don't support custom loss function when SCORE_T_USE_DOUBLE is enabled");
  #else
  *is_finished = booster->TrainOneIter(grad, hess) ? 1 : 0;
  #endif
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1142
/*!
 * \brief Forward to Booster::RollbackOneIter.
 * \param handle Handle of the Booster
 * \return 0 on success, -1 if an exception was caught (message recorded through LGBM_SetLastError)
 */
int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1149
/*!
 * \brief Report GetCurrentIteration() of the Booster's boosting object.
 * \param handle        Handle of the Booster
 * \param out_iteration Receives the value of GetCurrentIteration()
 * \return 0 on success, -1 if an exception was caught (message recorded through LGBM_SetLastError)
 */
int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int* out_iteration) {
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_iteration = boosting->GetCurrentIteration();
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
1155

1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
/*!
 * \brief Report NumModelPerIteration() of the Booster's boosting object.
 * \param handle                 Handle of the Booster
 * \param out_tree_per_iteration Receives the value of NumModelPerIteration()
 * \return 0 on success, -1 if an exception was caught (message recorded through LGBM_SetLastError)
 */
int LGBM_BoosterNumModelPerIteration(BoosterHandle handle, int* out_tree_per_iteration) {
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_tree_per_iteration = boosting->NumModelPerIteration();
  API_END();
}

/*!
 * \brief Report NumberOfTotalModel() of the Booster's boosting object.
 * \param handle     Handle of the Booster
 * \param out_models Receives the value of NumberOfTotalModel()
 * \return 0 on success, -1 if an exception was caught (message recorded through LGBM_SetLastError)
 */
int LGBM_BoosterNumberOfTotalModel(BoosterHandle handle, int* out_models) {
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_models = boosting->NumberOfTotalModel();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1170
/*!
 * \brief Report Booster::GetEvalCounts().
 * \param handle  Handle of the Booster
 * \param out_len Receives the value of GetEvalCounts()
 * \return 0 on success, -1 if an exception was caught (message recorded through LGBM_SetLastError)
 */
int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const auto eval_count = booster->GetEvalCounts();
  *out_len = eval_count;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1177
/*!
 * \brief Forward to Booster::GetEvalNames, which receives the caller's string buffers.
 * \param handle   Handle of the Booster
 * \param out_len  Receives the value returned by GetEvalNames
 * \param out_strs Array of buffers passed to GetEvalNames
 * \return 0 on success, -1 if an exception was caught (message recorded through LGBM_SetLastError)
 */
int LGBM_BoosterGetEvalNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const auto name_count = booster->GetEvalNames(out_strs);
  *out_len = name_count;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1184
/*!
 * \brief Forward to Booster::GetFeatureNames, which receives the caller's string buffers.
 * \param handle   Handle of the Booster
 * \param out_len  Receives the value returned by GetFeatureNames
 * \param out_strs Array of buffers passed to GetFeatureNames
 * \return 0 on success, -1 if an exception was caught (message recorded through LGBM_SetLastError)
 */
int LGBM_BoosterGetFeatureNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const auto name_count = booster->GetFeatureNames(out_strs);
  *out_len = name_count;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1191
/*!
 * \brief Report the feature count of a Booster, computed as MaxFeatureIdx() + 1.
 * \param handle  Handle of the Booster
 * \param out_len Receives MaxFeatureIdx() + 1
 * \return 0 on success, -1 if an exception was caught (message recorded through LGBM_SetLastError)
 */
int LGBM_BoosterGetNumFeature(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->MaxFeatureIdx() + 1;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1198
/*!
 * \brief Copy the results of GetEvalAt(data_idx) into a caller-provided array.
 * \param handle      Handle of the Booster
 * \param data_idx    Index passed to GetEvalAt
 * \param out_len     Receives the number of results
 * \param out_results Receives the results converted to double; must have room for all of them
 * \return 0 on success, -1 if an exception was caught (message recorded through LGBM_SetLastError)
 */
int LGBM_BoosterGetEval(BoosterHandle handle,
                        int data_idx,
                        int* out_len,
                        double* out_results) {
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  auto scores = boosting->GetEvalAt(data_idx);
  const size_t score_count = scores.size();
  *out_len = static_cast<int>(score_count);
  for (size_t k = 0; k < score_count; ++k) {
    out_results[k] = static_cast<double>(scores[k]);
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1213
int LGBM_BoosterGetNumPredict(BoosterHandle handle,
1214
1215
                              int data_idx,
                              int64_t* out_len) {
Guolin Ke's avatar
Guolin Ke committed
1216
1217
1218
1219
1220
1221
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->GetNumPredictAt(data_idx);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1222
int LGBM_BoosterGetPredict(BoosterHandle handle,
1223
1224
1225
                           int data_idx,
                           int64_t* out_len,
                           double* out_result) {
1226
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1227
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1228
  ref_booster->GetPredictAt(data_idx, out_result, out_len);
1229
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1230
1231
}

Guolin Ke's avatar
Guolin Ke committed
1232
int LGBM_BoosterPredictForFile(BoosterHandle handle,
1233
1234
1235
1236
                               const char* data_filename,
                               int data_has_header,
                               int predict_type,
                               int num_iteration,
1237
                               const char* parameter,
1238
                               const char* result_filename) {
1239
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1240
1241
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1242
1243
1244
1245
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1246
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
cbecker's avatar
cbecker committed
1247
  ref_booster->Predict(num_iteration, predict_type, data_filename, data_has_header,
Guolin Ke's avatar
Guolin Ke committed
1248
                       config, result_filename);
1249
  API_END();
1250
1251
}

Guolin Ke's avatar
Guolin Ke committed
1252
/*!
 * \brief Compute the number of output values for a prediction over `num_row` rows.
 *
 * The result is `num_row` multiplied (in 64-bit arithmetic) by NumPredictOneRow(),
 * which is told whether `predict_type` equals C_API_PREDICT_LEAF_INDEX or
 * C_API_PREDICT_CONTRIB.
 *
 * \param handle        Handle of the Booster
 * \param num_row       Number of rows
 * \param predict_type  Prediction type tag
 * \param num_iteration Iteration count passed to NumPredictOneRow
 * \param out_len       Receives the computed length
 * \return 0 on success, -1 if an exception was caught (message recorded through LGBM_SetLastError)
 */
int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
                               int num_row,
                               int predict_type,
                               int num_iteration,
                               int64_t* out_len) {
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  const bool wants_leaf_index = (predict_type == C_API_PREDICT_LEAF_INDEX);
  const bool wants_contrib = (predict_type == C_API_PREDICT_CONTRIB);
  const auto values_per_row = boosting->NumPredictOneRow(num_iteration, wants_leaf_index, wants_contrib);
  *out_len = static_cast<int64_t>(num_row) * values_per_row;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1264
int LGBM_BoosterPredictForCSR(BoosterHandle handle,
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
                              const void* indptr,
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t,
                              int predict_type,
                              int num_iteration,
1275
                              const char* parameter,
1276
1277
                              int64_t* out_len,
                              double* out_result) {
1278
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1279
1280
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1281
1282
1283
1284
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1285
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1286
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
Guolin Ke's avatar
Guolin Ke committed
1287
  int nrow = static_cast<int>(nindptr - 1);
cbecker's avatar
cbecker committed
1288
  ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
1289
                       config, out_result, out_len);
1290
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1291
}
1292

1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
/*!
 * \brief Run prediction for one row supplied in CSR form.
 *
 * Same setup as the multi-row CSR entry point (Config parsed from `parameter`,
 * optional OpenMP thread count), but delegates to Booster::PredictSingleRow,
 * which is not given a row count. The unnamed int64_t argument is unused here.
 * \return 0 on success, -1 when API_END catches an exception
 */
int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle,
                              const void* indptr,
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t,
                              int predict_type,
                              int num_iteration,
                              const char* parameter,
                              int64_t* out_len,
                              double* out_result) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto param_map = Config::Str2Map(parameter);
  Config predict_config;
  predict_config.Set(param_map);
  if (predict_config.num_threads > 0) {
    omp_set_num_threads(predict_config.num_threads);
  }
  auto row_reader = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  booster->PredictSingleRow(num_iteration, predict_type, row_reader,
                            predict_config, out_result, out_len);
  API_END();
}


Guolin Ke's avatar
Guolin Ke committed
1322
int LGBM_BoosterPredictForCSC(BoosterHandle handle,
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
                              const void* col_ptr,
                              int col_ptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t ncol_ptr,
                              int64_t nelem,
                              int64_t num_row,
                              int predict_type,
                              int num_iteration,
1333
                              const char* parameter,
1334
1335
                              int64_t* out_len,
                              double* out_result) {
Guolin Ke's avatar
Guolin Ke committed
1336
1337
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
Guolin Ke's avatar
Guolin Ke committed
1338
1339
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  int num_threads = 1;
  #pragma omp parallel
  #pragma omp master
  {
    num_threads = omp_get_num_threads();
  }
Guolin Ke's avatar
Guolin Ke committed
1350
  int ncol = static_cast<int>(ncol_ptr - 1);
Guolin Ke's avatar
Guolin Ke committed
1351
1352
1353
1354
1355
  std::vector<std::vector<CSC_RowIterator>> iterators(num_threads, std::vector<CSC_RowIterator>());
  for (int i = 0; i < num_threads; ++i) {
    for (int j = 0; j < ncol; ++j) {
      iterators[i].emplace_back(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, j);
    }
Guolin Ke's avatar
Guolin Ke committed
1356
1357
  }
  std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun =
1358
    [&iterators, ncol] (int i) {
Guolin Ke's avatar
Guolin Ke committed
1359
    std::vector<std::pair<int, double>> one_row;
Guolin Ke's avatar
Guolin Ke committed
1360
    const int tid = omp_get_thread_num();
Guolin Ke's avatar
Guolin Ke committed
1361
    for (int j = 0; j < ncol; ++j) {
Guolin Ke's avatar
Guolin Ke committed
1362
      auto val = iterators[tid][j].Get(i);
Guolin Ke's avatar
Guolin Ke committed
1363
      if (std::fabs(val) > kZeroThreshold || std::isnan(val)) {
Guolin Ke's avatar
Guolin Ke committed
1364
        one_row.emplace_back(j, val);
Guolin Ke's avatar
Guolin Ke committed
1365
1366
      }
    }
Guolin Ke's avatar
Guolin Ke committed
1367
1368
    return one_row;
  };
Guolin Ke's avatar
Guolin Ke committed
1369
  ref_booster->Predict(num_iteration, predict_type, static_cast<int>(num_row), get_row_fun, config,
cbecker's avatar
cbecker committed
1370
                       out_result, out_len);
Guolin Ke's avatar
Guolin Ke committed
1371
1372
1373
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1374
int LGBM_BoosterPredictForMat(BoosterHandle handle,
1375
1376
1377
1378
1379
1380
1381
                              const void* data,
                              int data_type,
                              int32_t nrow,
                              int32_t ncol,
                              int is_row_major,
                              int predict_type,
                              int num_iteration,
1382
                              const char* parameter,
1383
1384
                              int64_t* out_len,
                              double* out_result) {
1385
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1386
1387
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1388
1389
1390
1391
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1392
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1393
  auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
cbecker's avatar
cbecker committed
1394
  ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
1395
                       config, out_result, out_len);
1396
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1397
}
1398

1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
/*!
 * \brief Run prediction for a single dense row.
 *
 * The buffer is treated as a matrix with exactly one row: the accessor is built
 * with a row count of 1 and handed to Booster::PredictSingleRow together with
 * the Config parsed from `parameter`.
 * \return 0 on success, -1 when API_END catches an exception
 */
int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle,
                              const void* data,
                              int data_type,
                              int32_t ncol,
                              int is_row_major,
                              int predict_type,
                              int num_iteration,
                              const char* parameter,
                              int64_t* out_len,
                              double* out_result) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto param_map = Config::Str2Map(parameter);
  Config predict_config;
  predict_config.Set(param_map);
  if (predict_config.num_threads > 0) {
    omp_set_num_threads(predict_config.num_threads);
  }
  // a single row: num_row is fixed to 1
  auto row_reader = RowPairFunctionFromDenseMatric(data, 1, ncol, data_type, is_row_major);
  booster->PredictSingleRow(num_iteration, predict_type, row_reader,
                            predict_config, out_result, out_len);
  API_END();
}


1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
/*!
 * \brief Run prediction over rows given as an array of row pointers.
 *
 * `data[i]` is passed to RowPairFunctionFromDenseRows as the start of row i
 * (each row holding `ncol` values of `data_type`). The Config is parsed from
 * `parameter`, `num_threads` is applied to OpenMP when positive, and the work
 * is delegated to Booster::Predict with `nrow` rows.
 * \return 0 on success, -1 when API_END catches an exception
 */
int LGBM_BoosterPredictForMats(BoosterHandle handle,
                               const void** data,
                               int data_type,
                               int32_t nrow,
                               int32_t ncol,
                               int predict_type,
                               int num_iteration,
                               const char* parameter,
                               int64_t* out_len,
                               double* out_result) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto param_map = Config::Str2Map(parameter);
  Config predict_config;
  predict_config.Set(param_map);
  if (predict_config.num_threads > 0) {
    omp_set_num_threads(predict_config.num_threads);
  }
  auto row_reader = RowPairFunctionFromDenseRows(data, ncol, data_type);
  booster->Predict(num_iteration, predict_type, nrow, row_reader,
                   predict_config, out_result, out_len);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1448
int LGBM_BoosterSaveModel(BoosterHandle handle,
1449
                          int start_iteration,
1450
1451
                          int num_iteration,
                          const char* filename) {
1452
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1453
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1454
  ref_booster->SaveModelToFile(start_iteration, num_iteration, filename);
wxchan's avatar
wxchan committed
1455
1456
1457
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1458
int LGBM_BoosterSaveModelToString(BoosterHandle handle,
1459
                                  int start_iteration,
1460
                                  int num_iteration,
1461
                                  int64_t buffer_len,
1462
                                  int64_t* out_len,
1463
                                  char* out_str) {
1464
1465
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1466
  std::string model = ref_booster->SaveModelToString(start_iteration, num_iteration);
1467
  *out_len = static_cast<int64_t>(model.size()) + 1;
1468
  if (*out_len <= buffer_len) {
Guolin Ke's avatar
Guolin Ke committed
1469
    std::memcpy(out_str, model.c_str(), *out_len);
1470
1471
1472
1473
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1474
int LGBM_BoosterDumpModel(BoosterHandle handle,
1475
                          int start_iteration,
1476
                          int num_iteration,
1477
1478
                          int64_t buffer_len,
                          int64_t* out_len,
1479
                          char* out_str) {
wxchan's avatar
wxchan committed
1480
1481
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1482
  std::string model = ref_booster->DumpModel(start_iteration, num_iteration);
1483
  *out_len = static_cast<int64_t>(model.size()) + 1;
wxchan's avatar
wxchan committed
1484
  if (*out_len <= buffer_len) {
Guolin Ke's avatar
Guolin Ke committed
1485
    std::memcpy(out_str, model.c_str(), *out_len);
wxchan's avatar
wxchan committed
1486
  }
1487
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1488
}
1489

Guolin Ke's avatar
Guolin Ke committed
1490
/*!
 * \brief Read the value of one leaf.
 *
 * Stores the result of Booster::GetLeafValue(tree_idx, leaf_idx), converted to
 * double, into `*out_val`.
 * \return 0 on success, -1 when API_END catches an exception
 */
int LGBM_BoosterGetLeafValue(BoosterHandle handle,
                             int tree_idx,
                             int leaf_idx,
                             double* out_val) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const double leaf_value = static_cast<double>(booster->GetLeafValue(tree_idx, leaf_idx));
  *out_val = leaf_value;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1500
/*!
 * \brief Overwrite the value of one leaf.
 *
 * Forwards `tree_idx`, `leaf_idx` and `val` unchanged to Booster::SetLeafValue.
 * \return 0 on success, -1 when API_END catches an exception
 */
int LGBM_BoosterSetLeafValue(BoosterHandle handle,
                             int tree_idx,
                             int leaf_idx,
                             double val) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->SetLeafValue(tree_idx, leaf_idx, val);
  API_END();
}

1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
/*!
 * \brief Copy feature importances into a caller-provided array.
 *
 * Every element of the vector returned by Booster::FeatureImportance is written
 * to `out_results` in order. The size of `out_results` is not checked here, so
 * the caller must supply at least as many doubles as that vector holds.
 * \return 0 on success, -1 when API_END catches an exception
 */
int LGBM_BoosterFeatureImportance(BoosterHandle handle,
                                  int num_iteration,
                                  int importance_type,
                                  double* out_results) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const std::vector<double> importances = booster->FeatureImportance(num_iteration, importance_type);
  double* dst = out_results;
  for (const double importance : importances) {
    *dst = importance;
    ++dst;
  }
  API_END();
}

1523
1524
1525
1526
1527
int LGBM_NetworkInit(const char* machines,
                     int local_listen_port,
                     int listen_time_out,
                     int num_machines) {
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1528
  Config config;
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
  config.machines = Common::RemoveQuotationSymbol(std::string(machines));
  config.local_listen_port = local_listen_port;
  config.num_machines = num_machines;
  config.time_out = listen_time_out;
  if (num_machines > 1) {
    Network::Init(config);
  }
  API_END();
}

/*!
 * \brief Tear down the network layer by calling Network::Dispose.
 * \return 0 on success, -1 when API_END catches an exception
 */
int LGBM_NetworkFree() {
  API_BEGIN();
  Network::Dispose();
  API_END();
}

1545
1546
1547
int LGBM_NetworkInitWithFunctions(int num_machines, int rank,
                                  void* reduce_scatter_ext_fun,
                                  void* allgather_ext_fun) {
ww's avatar
ww committed
1548
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1549
  if (num_machines > 1) {
1550
    Network::Init(num_machines, rank, (ReduceScatterFunction)reduce_scatter_ext_fun, (AllgatherFunction)allgather_ext_fun);
ww's avatar
ww committed
1551
1552
1553
  }
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
1554

Guolin Ke's avatar
Guolin Ke committed
1555
// ---- start of some helper functions
1556
1557
1558

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1559
  if (data_type == C_API_DTYPE_FLOAT32) {
1560
1561
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
1562
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1563
        std::vector<double> ret(num_col);
1564
        auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
1565
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1566
          ret[i] = static_cast<double>(*(tmp_ptr + i));
1567
1568
1569
1570
        }
        return ret;
      };
    } else {
1571
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1572
        std::vector<double> ret(num_col);
1573
        for (int i = 0; i < num_col; ++i) {
1574
          ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
1575
1576
1577
1578
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1579
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1580
1581
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
1582
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1583
        std::vector<double> ret(num_col);
1584
        auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
1585
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1586
          ret[i] = static_cast<double>(*(tmp_ptr + i));
1587
1588
1589
1590
        }
        return ret;
      };
    } else {
1591
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1592
        std::vector<double> ret(num_col);
1593
        for (int i = 0; i < num_col; ++i) {
1594
          ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
1595
1596
1597
1598
1599
        }
        return ret;
      };
    }
  }
1600
  throw std::runtime_error("Unknown data type in RowFunctionFromDenseMatric");
1601
1602
1603
1604
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1605
1606
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
1607
    return [inner_function] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1608
1609
1610
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
1611
        if (std::fabs(raw_values[i]) > kZeroThreshold || std::isnan(raw_values[i])) {
Guolin Ke's avatar
Guolin Ke committed
1612
          ret.emplace_back(i, raw_values[i]);
1613
        }
Guolin Ke's avatar
Guolin Ke committed
1614
1615
1616
      }
      return ret;
    };
1617
  }
Guolin Ke's avatar
Guolin Ke committed
1618
  return nullptr;
1619
1620
}

1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
/*!
 * \brief Create a function that returns row `row_idx` as (column index, value)
 *        pairs when `data` is an array of pointers to individual rows.
 *
 * Each call treats `data[row_idx]` as a one-row, row-major matrix with
 * `num_col` columns and reads it through RowFunctionFromDenseMatric. An entry
 * is kept when it is NaN or its magnitude exceeds kZeroThreshold.
 */
std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseRows(const void** data, int num_col, int data_type) {
  return [=](int row_idx) {
    // the reader is rebuilt per call because every row has its own base pointer
    auto single_row_reader = RowFunctionFromDenseMatric(data[row_idx], 1, num_col, data_type, /* is_row_major */ true);
    const auto dense_values = single_row_reader(0);
    const int value_count = static_cast<int>(dense_values.size());
    std::vector<std::pair<int, double>> sparse_row;
    for (int col = 0; col < value_count; ++col) {
      const double value = dense_values[col];
      if (std::isnan(value) || std::fabs(value) > kZeroThreshold) {
        sparse_row.emplace_back(col, value);
      }
    }
    return sparse_row;
  };
}

1637
std::function<std::vector<std::pair<int, double>>(int idx)>
1638
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t , int64_t ) {
Guolin Ke's avatar
Guolin Ke committed
1639
  if (data_type == C_API_DTYPE_FLOAT32) {
1640
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1641
    if (indptr_type == C_API_DTYPE_INT32) {
1642
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
1643
      return [=] (int idx) {
1644
1645
1646
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1647
1648
1649
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1650
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1651
          ret.emplace_back(indices[i], data_ptr[i]);
1652
1653
1654
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1655
    } else if (indptr_type == C_API_DTYPE_INT64) {
1656
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
1657
      return [=] (int idx) {
1658
1659
1660
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1661
1662
1663
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1664
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1665
          ret.emplace_back(indices[i], data_ptr[i]);
1666
1667
1668
1669
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1670
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1671
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1672
    if (indptr_type == C_API_DTYPE_INT32) {
1673
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
1674
      return [=] (int idx) {
1675
1676
1677
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1678
1679
1680
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1681
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1682
          ret.emplace_back(indices[i], data_ptr[i]);
1683
1684
1685
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1686
    } else if (indptr_type == C_API_DTYPE_INT64) {
1687
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
1688
      return [=] (int idx) {
1689
1690
1691
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1692
1693
1694
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1695
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1696
          ret.emplace_back(indices[i], data_ptr[i]);
1697
1698
1699
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1700
1701
    }
  }
1702
  throw std::runtime_error("Unknown data type in RowFunctionFromCSR");
1703
1704
}

Guolin Ke's avatar
Guolin Ke committed
1705
std::function<std::pair<int, double>(int idx)>
1706
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t , int col_idx) {
Guolin Ke's avatar
Guolin Ke committed
1707
  CHECK(col_idx < ncol_ptr && col_idx >= 0);
Guolin Ke's avatar
Guolin Ke committed
1708
  if (data_type == C_API_DTYPE_FLOAT32) {
1709
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1710
    if (col_ptr_type == C_API_DTYPE_INT32) {
1711
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1712
1713
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1714
      return [=] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1715
1716
1717
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1718
        }
Guolin Ke's avatar
Guolin Ke committed
1719
1720
1721
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1722
      };
Guolin Ke's avatar
Guolin Ke committed
1723
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1724
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1725
1726
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1727
      return [=] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1728
1729
1730
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1731
        }
Guolin Ke's avatar
Guolin Ke committed
1732
1733
1734
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1735
      };
Guolin Ke's avatar
Guolin Ke committed
1736
    }
Guolin Ke's avatar
Guolin Ke committed
1737
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1738
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1739
    if (col_ptr_type == C_API_DTYPE_INT32) {
1740
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1741
1742
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1743
      return [=] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1744
1745
1746
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1747
        }
Guolin Ke's avatar
Guolin Ke committed
1748
1749
1750
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1751
      };
Guolin Ke's avatar
Guolin Ke committed
1752
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1753
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1754
1755
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1756
      return [=] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1757
1758
1759
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1760
        }
Guolin Ke's avatar
Guolin Ke committed
1761
1762
1763
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1764
      };
Guolin Ke's avatar
Guolin Ke committed
1765
1766
    }
  }
1767
  throw std::runtime_error("Unknown data type in CSC matrix");
1768
1769
}

Guolin Ke's avatar
Guolin Ke committed
1770
/*!
 * \brief Bind the iterator to column `col_idx` of a CSC matrix.
 *
 * All arguments are forwarded to IterateFunctionFromCSC; the resulting entry
 * reader is stored in iter_fun_.
 */
CSC_RowIterator::CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
                                 const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx)
  : iter_fun_(IterateFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, col_idx)) {
}

/*!
 * \brief Value of this column at row `idx`, or zero when no entry is stored there.
 *
 * Stored entries are consumed through iter_fun_ until the cached row index
 * reaches or passes `idx`, or the column is exhausted. The cursor only moves
 * forward, so the cached entry is returned only when it sits exactly at `idx`.
 */
double CSC_RowIterator::Get(int idx) {
  while (!is_end_ && idx > cur_idx_) {
    const auto entry = iter_fun_(nonzero_idx_);
    if (entry.first < 0) {
      // a negative row index marks the end of the column
      is_end_ = true;
      break;
    }
    cur_idx_ = entry.first;
    cur_val_ = entry.second;
    ++nonzero_idx_;
  }
  if (idx != cur_idx_) {
    return 0.0;
  }
  return cur_val_;
}

/*!
 * \brief Return the next stored (row index, value) entry of this column.
 *
 * Once the column is exhausted every call returns (-1, 0.0). The entry that
 * first signals exhaustion (negative row index) is itself returned to the
 * caller and flips is_end_.
 */
std::pair<int, double> CSC_RowIterator::NextNonZero() {
  if (is_end_) {
    return std::make_pair(-1, 0.0);
  }
  auto entry = iter_fun_(nonzero_idx_);
  ++nonzero_idx_;
  if (entry.first < 0) {
    is_end_ = true;
  }
  return entry;
}