"python-package/vscode:/vscode.git/clone" did not exist on "f53116af34b0080c83f8becd3dc79dc41b06d11b"
c_api.cpp 64.2 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/c_api.h>

#include <LightGBM/boosting.h>
#include <LightGBM/config.h>
#include <LightGBM/dataset.h>
#include <LightGBM/dataset_loader.h>
#include <LightGBM/metric.h>
#include <LightGBM/network.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/prediction_early_stop.h>
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>
#include <LightGBM/utils/random.h>
#include <LightGBM/utils/threading.h>

#include <cmath>
#include <cstdio>
#include <cstring>
#include <functional>
#include <limits>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <vector>

#include "application/predictor.hpp"
Guolin Ke's avatar
Guolin Ke committed
30

Guolin Ke's avatar
Guolin Ke committed
31
32
namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
/*!
 * \brief Store the message of a caught std::exception as the last error
 * \param ex The caught exception
 * \return Always -1
 */
inline int LGBM_APIHandleException(const std::exception& ex) {
  const char* message = ex.what();
  LGBM_SetLastError(message);
  return -1;
}
/*!
 * \brief Store a caught error string as the last error
 * \param ex The error text
 * \return Always -1
 */
inline int LGBM_APIHandleException(const std::string& ex) {
  const char* message = ex.c_str();
  LGBM_SetLastError(message);
  return -1;
}

/*! \brief Opens the try block that wraps the body of a C API function */
#define API_BEGIN() try {
/*!
 * \brief Closes the try block opened by API_BEGIN().
 *        Any std::exception, std::string or other thrown object is routed to
 *        LGBM_APIHandleException and its result is returned; otherwise the
 *        enclosing function returns 0.
 */
#define API_END() } \
catch(const std::exception& ex) { return LGBM_APIHandleException(ex); } \
catch(const std::string& ex) { return LGBM_APIHandleException(ex); } \
catch(...) { return LGBM_APIHandleException("unknown exception"); } \
return 0;

Guolin Ke's avatar
Guolin Ke committed
49
class Booster {
Nikita Titov's avatar
Nikita Titov committed
50
 public:
Guolin Ke's avatar
Guolin Ke committed
51
  explicit Booster(const char* filename) {
52
    boosting_.reset(Boosting::CreateBoosting("gbdt", filename));
53
54
  }

Guolin Ke's avatar
Guolin Ke committed
55
  Booster(const Dataset* train_data,
56
          const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
57
    auto param = Config::Str2Map(parameters);
wxchan's avatar
wxchan committed
58
    config_.Set(param);
59
60
61
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
62
    // create boosting
Guolin Ke's avatar
Guolin Ke committed
63
    if (config_.input_model.size() > 0) {
64
65
      Log::Warning("Continued train from model is not supported for c_api,\n"
                   "please use continued train with input score");
Guolin Ke's avatar
Guolin Ke committed
66
    }
Guolin Ke's avatar
Guolin Ke committed
67

Guolin Ke's avatar
Guolin Ke committed
68
    boosting_.reset(Boosting::CreateBoosting(config_.boosting, nullptr));
Guolin Ke's avatar
Guolin Ke committed
69

70
71
    train_data_ = train_data;
    CreateObjectiveAndMetrics();
Guolin Ke's avatar
Guolin Ke committed
72
    // initialize the boosting
Guolin Ke's avatar
Guolin Ke committed
73
    if (config_.tree_learner == std::string("feature")) {
74
      Log::Fatal("Do not support feature parallel in c api");
75
    }
Guolin Ke's avatar
Guolin Ke committed
76
    if (Network::num_machines() == 1 && config_.tree_learner != std::string("serial")) {
77
      Log::Warning("Only find one worker, will switch to serial tree learner");
Guolin Ke's avatar
Guolin Ke committed
78
      config_.tree_learner = "serial";
79
    }
Guolin Ke's avatar
Guolin Ke committed
80
    boosting_->Init(&config_, train_data_, objective_fun_.get(),
81
                    Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
wxchan's avatar
wxchan committed
82
83
84
85
86
  }

  void MergeFrom(const Booster* other) {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->MergeFrom(other->boosting_.get());
Guolin Ke's avatar
Guolin Ke committed
87
88
89
90
  }

  ~Booster() {
  }
91

92
  void CreateObjectiveAndMetrics() {
Guolin Ke's avatar
Guolin Ke committed
93
    // create objective function
Guolin Ke's avatar
Guolin Ke committed
94
95
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective,
                                                                    config_));
Guolin Ke's avatar
Guolin Ke committed
96
97
98
99
100
101
102
103
104
105
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective function");
    }
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
    }

    // create training metric
    train_metric_.clear();
Guolin Ke's avatar
Guolin Ke committed
106
    for (auto metric_type : config_.metric) {
Guolin Ke's avatar
Guolin Ke committed
107
      auto metric = std::unique_ptr<Metric>(
Guolin Ke's avatar
Guolin Ke committed
108
        Metric::CreateMetric(metric_type, config_));
Guolin Ke's avatar
Guolin Ke committed
109
110
111
112
113
      if (metric == nullptr) { continue; }
      metric->Init(train_data_->metadata(), train_data_->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
114
115
116
117
118
119
120
121
122
123
124
  }

  void ResetTrainingData(const Dataset* train_data) {
    if (train_data != train_data_) {
      std::lock_guard<std::mutex> lock(mutex_);
      train_data_ = train_data;
      CreateObjectiveAndMetrics();
      // reset the boosting
      boosting_->ResetTrainingData(train_data_,
                                   objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
    }
wxchan's avatar
wxchan committed
125
126
127
  }

  void ResetConfig(const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
128
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
129
    auto param = Config::Str2Map(parameters);
wxchan's avatar
wxchan committed
130
    if (param.count("num_class")) {
131
      Log::Fatal("Cannot change num_class during training");
wxchan's avatar
wxchan committed
132
    }
Guolin Ke's avatar
Guolin Ke committed
133
134
    if (param.count("boosting")) {
      Log::Fatal("Cannot change boosting during training");
wxchan's avatar
wxchan committed
135
    }
Guolin Ke's avatar
Guolin Ke committed
136
    if (param.count("metric")) {
137
      Log::Fatal("Cannot change metric during training");
Guolin Ke's avatar
Guolin Ke committed
138
    }
Guolin Ke's avatar
Guolin Ke committed
139
140

    config_.Set(param);
141
142
143
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
144
145
146

    if (param.count("objective")) {
      // create objective function
Guolin Ke's avatar
Guolin Ke committed
147
148
      objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective,
                                                                      config_));
Guolin Ke's avatar
Guolin Ke committed
149
150
151
152
153
154
155
      if (objective_fun_ == nullptr) {
        Log::Warning("Using self-defined objective function");
      }
      // initialize the objective function
      if (objective_fun_ != nullptr) {
        objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
      }
156
157
      boosting_->ResetTrainingData(train_data_,
                                   objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
wxchan's avatar
wxchan committed
158
    }
Guolin Ke's avatar
Guolin Ke committed
159

Guolin Ke's avatar
Guolin Ke committed
160
    boosting_->ResetConfig(&config_);
wxchan's avatar
wxchan committed
161
162
163
164
165
  }

  void AddValidData(const Dataset* valid_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    valid_metrics_.emplace_back();
Guolin Ke's avatar
Guolin Ke committed
166
167
    for (auto metric_type : config_.metric) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_));
wxchan's avatar
wxchan committed
168
169
170
171
172
173
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
174
                               Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
wxchan's avatar
wxchan committed
175
  }
Guolin Ke's avatar
Guolin Ke committed
176

177
  bool TrainOneIter() {
wxchan's avatar
wxchan committed
178
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
179
    return boosting_->TrainOneIter(nullptr, nullptr);
180
181
  }

Guolin Ke's avatar
Guolin Ke committed
182
183
184
185
186
187
188
189
190
191
192
  void Refit(const int32_t* leaf_preds, int32_t nrow, int32_t ncol) {
    std::lock_guard<std::mutex> lock(mutex_);
    std::vector<std::vector<int32_t>> v_leaf_preds(nrow, std::vector<int32_t>(ncol, 0));
    for (int i = 0; i < nrow; ++i) {
      for (int j = 0; j < ncol; ++j) {
        v_leaf_preds[i][j] = leaf_preds[i * ncol + j];
      }
    }
    boosting_->RefitTree(v_leaf_preds);
  }

193
  bool TrainOneIter(const score_t* gradients, const score_t* hessians) {
wxchan's avatar
wxchan committed
194
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
195
    return boosting_->TrainOneIter(gradients, hessians);
196
197
  }

wxchan's avatar
wxchan committed
198
199
200
201
202
  void RollbackOneIter() {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->RollbackOneIter();
  }

203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
  void PredictSingleRow(int num_iteration, int predict_type,
               std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
               const Config& config,
               double* out_result, int64_t* out_len) {
    std::lock_guard<std::mutex> lock(mutex_);

    if (single_row_predictor_.get() == nullptr) {
      bool is_predict_leaf = false;
      bool is_raw_score = false;
      bool predict_contrib = false;
      if (predict_type == C_API_PREDICT_LEAF_INDEX) {
        is_predict_leaf = true;
      } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
        is_raw_score = true;
      } else if (predict_type == C_API_PREDICT_CONTRIB) {
        predict_contrib = true;
      } else {
        is_raw_score = false;
      }

223
      // TODO(eisber): config could be optimized away... (maybe using lambda callback?)
224
      single_row_predictor_.reset(new Predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, predict_contrib,
225
                                                config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin));
226
227
228
229
230
231
232
233
234
235
236
237
      single_row_num_pred_in_one_row_ = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf, predict_contrib);
      single_row_predict_function_ = single_row_predictor_->GetPredictFunction();
    }

    auto one_row = get_row_fun(0);
    auto pred_wrt_ptr = out_result;
    single_row_predict_function_(one_row, pred_wrt_ptr);

    *out_len = single_row_num_pred_in_one_row_;
  }


Guolin Ke's avatar
Guolin Ke committed
238
239
  void Predict(int num_iteration, int predict_type, int nrow,
               std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
240
               const Config& config,
Guolin Ke's avatar
Guolin Ke committed
241
               double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
242
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
243
244
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
245
    bool predict_contrib = false;
Guolin Ke's avatar
Guolin Ke committed
246
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
247
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
248
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
249
      is_raw_score = true;
250
    } else if (predict_type == C_API_PREDICT_CONTRIB) {
Guolin Ke's avatar
Guolin Ke committed
251
      predict_contrib = true;
Guolin Ke's avatar
Guolin Ke committed
252
253
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
254
    }
Guolin Ke's avatar
Guolin Ke committed
255

Guolin Ke's avatar
Guolin Ke committed
256
    Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, predict_contrib,
257
                        config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
Guolin Ke's avatar
Guolin Ke committed
258
    int64_t num_pred_in_one_row = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf, predict_contrib);
Guolin Ke's avatar
Guolin Ke committed
259
    auto pred_fun = predictor.GetPredictFunction();
260
261
    OMP_INIT_EX();
    #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
262
    for (int i = 0; i < nrow; ++i) {
263
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
264
      auto one_row = get_row_fun(i);
Tony-Y's avatar
Tony-Y committed
265
      auto pred_wrt_ptr = out_result + static_cast<size_t>(num_pred_in_one_row) * i;
Guolin Ke's avatar
Guolin Ke committed
266
      pred_fun(one_row, pred_wrt_ptr);
267
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
268
    }
269
    OMP_THROW_EX();
270
    *out_len = num_pred_in_one_row * nrow;
Guolin Ke's avatar
Guolin Ke committed
271
272
273
  }

  void Predict(int num_iteration, int predict_type, const char* data_filename,
Guolin Ke's avatar
Guolin Ke committed
274
               int data_has_header, const Config& config,
cbecker's avatar
cbecker committed
275
               const char* result_filename) {
Guolin Ke's avatar
Guolin Ke committed
276
277
278
    std::lock_guard<std::mutex> lock(mutex_);
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
279
    bool predict_contrib = false;
Guolin Ke's avatar
Guolin Ke committed
280
281
282
283
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
      is_predict_leaf = true;
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
      is_raw_score = true;
284
    } else if (predict_type == C_API_PREDICT_CONTRIB) {
Guolin Ke's avatar
Guolin Ke committed
285
      predict_contrib = true;
Guolin Ke's avatar
Guolin Ke committed
286
287
288
    } else {
      is_raw_score = false;
    }
Guolin Ke's avatar
Guolin Ke committed
289
    Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, predict_contrib,
290
                        config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
Guolin Ke's avatar
Guolin Ke committed
291
292
    bool bool_data_has_header = data_has_header > 0 ? true : false;
    predictor.Predict(data_filename, result_filename, bool_data_has_header);
Guolin Ke's avatar
Guolin Ke committed
293
294
  }

Guolin Ke's avatar
Guolin Ke committed
295
  void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
296
297
298
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

299
300
  void SaveModelToFile(int start_iteration, int num_iteration, const char* filename) {
    boosting_->SaveModelToFile(start_iteration, num_iteration, filename);
Guolin Ke's avatar
Guolin Ke committed
301
  }
302

303
  void LoadModelFromString(const char* model_str) {
304
305
    size_t len = std::strlen(model_str);
    boosting_->LoadModelFromString(model_str, len);
306
307
  }

308
309
  std::string SaveModelToString(int start_iteration, int num_iteration) {
    return boosting_->SaveModelToString(start_iteration, num_iteration);
310
311
  }

312
  std::string DumpModel(int start_iteration, int num_iteration) {
313
    return boosting_->DumpModel(start_iteration, num_iteration);
wxchan's avatar
wxchan committed
314
  }
315

316
317
318
319
  std::vector<double> FeatureImportance(int num_iteration, int importance_type) {
    return boosting_->FeatureImportance(num_iteration, importance_type);
  }

Guolin Ke's avatar
Guolin Ke committed
320
  double GetLeafValue(int tree_idx, int leaf_idx) const {
Guolin Ke's avatar
Guolin Ke committed
321
    return dynamic_cast<GBDTBase*>(boosting_.get())->GetLeafValue(tree_idx, leaf_idx);
Guolin Ke's avatar
Guolin Ke committed
322
323
324
325
  }

  void SetLeafValue(int tree_idx, int leaf_idx, double val) {
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
326
    dynamic_cast<GBDTBase*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val);
Guolin Ke's avatar
Guolin Ke committed
327
328
  }

329
  void ShuffleModels(int start_iter, int end_iter) {
330
    std::lock_guard<std::mutex> lock(mutex_);
331
    boosting_->ShuffleModels(start_iter, end_iter);
332
333
  }

wxchan's avatar
wxchan committed
334
335
336
337
338
339
340
  int GetEvalCounts() const {
    int ret = 0;
    for (const auto& metric : train_metric_) {
      ret += static_cast<int>(metric->GetName().size());
    }
    return ret;
  }
341

wxchan's avatar
wxchan committed
342
343
344
345
  int GetEvalNames(char** out_strs) const {
    int idx = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
Guolin Ke's avatar
Guolin Ke committed
346
        std::memcpy(out_strs[idx], name.c_str(), name.size() + 1);
wxchan's avatar
wxchan committed
347
348
349
350
351
352
        ++idx;
      }
    }
    return idx;
  }

wxchan's avatar
wxchan committed
353
354
355
  int GetFeatureNames(char** out_strs) const {
    int idx = 0;
    for (const auto& name : boosting_->FeatureNames()) {
Guolin Ke's avatar
Guolin Ke committed
356
      std::memcpy(out_strs[idx], name.c_str(), name.size() + 1);
wxchan's avatar
wxchan committed
357
358
359
360
361
      ++idx;
    }
    return idx;
  }

wxchan's avatar
wxchan committed
362
  const Boosting* GetBoosting() const { return boosting_.get(); }
Guolin Ke's avatar
Guolin Ke committed
363

Nikita Titov's avatar
Nikita Titov committed
364
 private:
wxchan's avatar
wxchan committed
365
  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
366
  std::unique_ptr<Boosting> boosting_;
367
368
369
370
  std::unique_ptr<Predictor> single_row_predictor_;
  PredictFunction single_row_predict_function_;
  int64_t single_row_num_pred_in_one_row_;

Guolin Ke's avatar
Guolin Ke committed
371
  /*! \brief All configs */
Guolin Ke's avatar
Guolin Ke committed
372
  Config config_;
Guolin Ke's avatar
Guolin Ke committed
373
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
374
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
375
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
376
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
377
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
378
  std::unique_ptr<ObjectiveFunction> objective_fun_;
wxchan's avatar
wxchan committed
379
380
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
381
382
};

383
}  // namespace LightGBM
Guolin Ke's avatar
Guolin Ke committed
384
385
386

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
387
388
389
390
391
392
393
394
// Helper factories that turn raw input buffers into per-row accessor callbacks.
// They are declared here so that the C API functions below can use them.

/*! \brief Accessor over a dense matrix; the callback returns a std::vector<double> for a row index */
std::function<std::vector<double>(int row_idx)> RowFunctionFromDenseMatric(
    const void* data,
    int num_row,
    int num_col,
    int data_type,
    int is_row_major);

/*! \brief Accessor over a dense matrix; the callback returns (int, double) pairs for a row index */
std::function<std::vector<std::pair<int, double>>(int row_idx)> RowPairFunctionFromDenseMatric(
    const void* data,
    int num_row,
    int num_col,
    int data_type,
    int is_row_major);

/*! \brief Accessor over an array of row pointers; the callback returns (int, double) pairs for a row index */
std::function<std::vector<std::pair<int, double>>(int row_idx)> RowPairFunctionFromDenseRows(
    const void** data,
    int num_col,
    int data_type);

/*! \brief Accessor over CSR data; the callback returns (int, double) pairs for an index */
std::function<std::vector<std::pair<int, double>>(int idx)> RowFunctionFromCSR(
    const void* indptr,
    int indptr_type,
    const int32_t* indices,
    const void* data,
    int data_type,
    int64_t nindptr,
    int64_t nelem);
Guolin Ke's avatar
Guolin Ke committed
401
402
403

/*! \brief Row iterator over one column of a CSC matrix */
class CSC_RowIterator {
 public:
  CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
                  const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx);
  ~CSC_RowIterator() {}

  /*! \brief Value at idx; indices may only be requested in ascending order */
  double Get(int idx);

  /*! \brief Next non-zero (index, value) pair; an index below 0 means there is no more data */
  std::pair<int, double> NextNonZero();

 private:
  int nonzero_idx_{0};
  int cur_idx_{-1};
  double cur_val_{0.0f};
  bool is_end_{false};
  std::function<std::pair<int, double>(int idx)> iter_fun_;
};

// start of c_api functions

Guolin Ke's avatar
Guolin Ke committed
423
const char* LGBM_GetLastError() {
wxchan's avatar
wxchan committed
424
  return LastErrorMsg();
Guolin Ke's avatar
Guolin Ke committed
425
426
}

Guolin Ke's avatar
Guolin Ke committed
427
int LGBM_DatasetCreateFromFile(const char* filename,
428
429
430
                               const char* parameters,
                               const DatasetHandle reference,
                               DatasetHandle* out) {
431
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
432
433
  auto param = Config::Str2Map(parameters);
  Config config;
434
435
436
437
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
438
  DatasetLoader loader(config, nullptr, 1, filename);
Guolin Ke's avatar
Guolin Ke committed
439
  if (reference == nullptr) {
440
441
442
443
444
    if (Network::num_machines() == 1) {
      *out = loader.LoadFromFile(filename, "");
    } else {
      *out = loader.LoadFromFile(filename, "", Network::rank(), Network::num_machines());
    }
Guolin Ke's avatar
Guolin Ke committed
445
  } else {
446
    *out = loader.LoadFromFileAlignWithOtherDataset(filename, "",
447
                                                    reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
448
  }
449
  API_END();
Guolin Ke's avatar
Guolin Ke committed
450
451
}

452

Guolin Ke's avatar
Guolin Ke committed
453
int LGBM_DatasetCreateFromSampledColumn(double** sample_data,
454
455
456
457
458
459
460
                                        int** sample_indices,
                                        int32_t ncol,
                                        const int* num_per_col,
                                        int32_t num_sample_row,
                                        int32_t num_total_row,
                                        const char* parameters,
                                        DatasetHandle* out) {
461
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
462
463
  auto param = Config::Str2Map(parameters);
  Config config;
464
465
466
467
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
468
  DatasetLoader loader(config, nullptr, 1, nullptr);
469
470
471
472
  *out = loader.CostructFromSampleData(sample_data, sample_indices, ncol, num_per_col,
                                       num_sample_row,
                                       static_cast<data_size_t>(num_total_row));
  API_END();
Guolin Ke's avatar
Guolin Ke committed
473
474
}

475

Guolin Ke's avatar
Guolin Ke committed
476
int LGBM_DatasetCreateByReference(const DatasetHandle reference,
477
478
                                  int64_t num_total_row,
                                  DatasetHandle* out) {
Guolin Ke's avatar
Guolin Ke committed
479
480
481
482
483
484
485
486
  API_BEGIN();
  std::unique_ptr<Dataset> ret;
  ret.reset(new Dataset(static_cast<data_size_t>(num_total_row)));
  ret->CreateValid(reinterpret_cast<const Dataset*>(reference));
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
487
int LGBM_DatasetPushRows(DatasetHandle dataset,
488
489
490
491
492
                         const void* data,
                         int data_type,
                         int32_t nrow,
                         int32_t ncol,
                         int32_t start_row) {
Guolin Ke's avatar
Guolin Ke committed
493
494
495
  API_BEGIN();
  auto p_dataset = reinterpret_cast<Dataset*>(dataset);
  auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, 1);
496
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
497
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
498
  for (int i = 0; i < nrow; ++i) {
499
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
500
501
502
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    p_dataset->PushOneRow(tid, start_row + i, one_row);
503
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
504
  }
505
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
506
507
508
509
510
511
  if (start_row + nrow == p_dataset->num_data()) {
    p_dataset->FinishLoad();
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
512
/*!
 * \brief Push a chunk of CSR rows into an existing dataset
 * \param dataset Handle of the dataset that receives the rows
 * \param indptr CSR row pointer array
 * \param indptr_type Type code of indptr
 * \param indices CSR column indices
 * \param data CSR values
 * \param data_type Type code of data
 * \param nindptr Length of indptr; the chunk holds nindptr - 1 rows
 * \param nelem Number of stored elements
 * \param start_row Row index in the dataset at which the chunk starts
 * \return 0 on success, -1 if an exception was caught
 * \note The unnamed int64_t parameter is not used by this function
 */
int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset,
                              const void* indptr,
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t,
                              int64_t start_row) {
  API_BEGIN();
  Dataset* target = reinterpret_cast<Dataset*>(dataset);
  auto read_row = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  OMP_INIT_EX();
  #pragma omp parallel for schedule(static)
  for (int row = 0; row < nrow; ++row) {
    OMP_LOOP_EX_BEGIN();
    const int thread_id = omp_get_thread_num();
    auto row_pairs = read_row(row);
    const auto dest_row = static_cast<data_size_t>(start_row + row);
    target->PushOneRow(thread_id, dest_row, row_pairs);
    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();
  // finish loading once the chunk ends exactly at the dataset's row count
  const bool reached_end = (start_row + nrow == static_cast<int64_t>(target->num_data()));
  if (reached_end) {
    target->FinishLoad();
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
543
int LGBM_DatasetCreateFromMat(const void* data,
544
545
546
547
548
549
550
                              int data_type,
                              int32_t nrow,
                              int32_t ncol,
                              int is_row_major,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
  return LGBM_DatasetCreateFromMats(1,
                                    &data,
                                    data_type,
                                    &nrow,
                                    ncol,
                                    is_row_major,
                                    parameters,
                                    reference,
                                    out);
}


int LGBM_DatasetCreateFromMats(int32_t nmat,
                               const void** data,
                               int data_type,
                               int32_t* nrow,
                               int32_t ncol,
                               int is_row_major,
                               const char* parameters,
                               const DatasetHandle reference,
                               DatasetHandle* out) {
572
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
573
574
  auto param = Config::Str2Map(parameters);
  Config config;
575
576
577
578
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
579
  std::unique_ptr<Dataset> ret;
580
581
582
583
584
585
586
587
588
  int32_t total_nrow = 0;
  for (int j = 0; j < nmat; ++j) {
    total_nrow += nrow[j];
  }

  std::vector<std::function<std::vector<double>(int row_idx)>> get_row_fun;
  for (int j = 0; j < nmat; ++j) {
    get_row_fun.push_back(RowFunctionFromDenseMatric(data[j], nrow[j], ncol, data_type, is_row_major));
  }
589

Guolin Ke's avatar
Guolin Ke committed
590
591
  if (reference == nullptr) {
    // sample data first
Guolin Ke's avatar
Guolin Ke committed
592
    Random rand(config.data_random_seed);
593
594
    int sample_cnt = static_cast<int>(total_nrow < config.bin_construct_sample_cnt ? total_nrow : config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(total_nrow, sample_cnt);
595
    sample_cnt = static_cast<int>(sample_indices.size());
596
    std::vector<std::vector<double>> sample_values(ncol);
Guolin Ke's avatar
Guolin Ke committed
597
    std::vector<std::vector<int>> sample_idx(ncol);
598
599
600

    int offset = 0;
    int j = 0;
Guolin Ke's avatar
Guolin Ke committed
601
    for (size_t i = 0; i < sample_indices.size(); ++i) {
Guolin Ke's avatar
Guolin Ke committed
602
      auto idx = sample_indices[i];
603
604
605
606
      while ((idx - offset) >= nrow[j]) {
        offset += nrow[j];
        ++j;
      }
607

608
609
610
611
612
      auto row = get_row_fun[j](static_cast<int>(idx - offset));
      for (size_t k = 0; k < row.size(); ++k) {
        if (std::fabs(row[k]) > kZeroThreshold || std::isnan(row[k])) {
          sample_values[k].emplace_back(row[k]);
          sample_idx[k].emplace_back(static_cast<int>(i));
Guolin Ke's avatar
Guolin Ke committed
613
        }
Guolin Ke's avatar
Guolin Ke committed
614
615
      }
    }
Guolin Ke's avatar
Guolin Ke committed
616
    DatasetLoader loader(config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
617
618
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
619
620
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
621
                                            sample_cnt, total_nrow));
Guolin Ke's avatar
Guolin Ke committed
622
  } else {
623
    ret.reset(new Dataset(total_nrow));
Guolin Ke's avatar
Guolin Ke committed
624
    ret->CreateValid(
625
      reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
626
  }
627
628
629
630
631
632
633
634
635
636
637
638
639
640
  int32_t start_row = 0;
  for (int j = 0; j < nmat; ++j) {
    OMP_INIT_EX();
    #pragma omp parallel for schedule(static)
    for (int i = 0; i < nrow[j]; ++i) {
      OMP_LOOP_EX_BEGIN();
      const int tid = omp_get_thread_num();
      auto one_row = get_row_fun[j](i);
      ret->PushOneRow(tid, start_row + i, one_row);
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();

    start_row += nrow[j];
Guolin Ke's avatar
Guolin Ke committed
641
642
  }
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
643
  *out = ret.release();
644
  API_END();
645
646
}

Guolin Ke's avatar
Guolin Ke committed
647
int LGBM_DatasetCreateFromCSR(const void* indptr,
648
649
650
651
652
653
654
655
656
657
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t num_col,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
658
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
659
660
  auto param = Config::Str2Map(parameters);
  Config config;
661
662
663
664
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
665
  std::unique_ptr<Dataset> ret;
666
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
667
668
669
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference == nullptr) {
    // sample data first
Guolin Ke's avatar
Guolin Ke committed
670
671
    Random rand(config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt);
672
    auto sample_indices = rand.Sample(nrow, sample_cnt);
673
    sample_cnt = static_cast<int>(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
674
675
    std::vector<std::vector<double>> sample_values(num_col);
    std::vector<std::vector<int>> sample_idx(num_col);
676
677
678
679
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      auto row = get_row_fun(static_cast<int>(idx));
      for (std::pair<int, double>& inner_data : row) {
Guolin Ke's avatar
Guolin Ke committed
680
        CHECK(inner_data.first < num_col);
Guolin Ke's avatar
Guolin Ke committed
681
        if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
Guolin Ke's avatar
Guolin Ke committed
682
683
          sample_values[inner_data.first].emplace_back(inner_data.second);
          sample_idx[inner_data.first].emplace_back(static_cast<int>(i));
684
685
686
        }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
687
    DatasetLoader loader(config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
688
689
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
690
691
692
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
693
  } else {
694
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
695
    ret->CreateValid(
696
      reinterpret_cast<const Dataset*>(reference));
697
  }
698
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
699
  #pragma omp parallel for schedule(static)
700
  for (int i = 0; i < nindptr - 1; ++i) {
701
    OMP_LOOP_EX_BEGIN();
702
703
704
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
705
    OMP_LOOP_EX_END();
706
  }
707
  OMP_THROW_EX();
708
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
709
  *out = ret.release();
710
  API_END();
711
712
}

713
int LGBM_DatasetCreateFromCSRFunc(void* get_row_funptr,
714
715
716
717
718
                                  int num_rows,
                                  int64_t num_col,
                                  const char* parameters,
                                  const DatasetHandle reference,
                                  DatasetHandle* out) {
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
  API_BEGIN();

  auto get_row_fun = *static_cast<std::function<void(int idx, std::vector<std::pair<int, double>>&)>*>(get_row_funptr);

  auto param = Config::Str2Map(parameters);
  Config config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  std::unique_ptr<Dataset> ret;
  int32_t nrow = num_rows;
  if (reference == nullptr) {
    // sample data first
    Random rand(config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    sample_cnt = static_cast<int>(sample_indices.size());
    std::vector<std::vector<double>> sample_values(num_col);
    std::vector<std::vector<int>> sample_idx(num_col);
    // local buffer to re-use memory
    std::vector<std::pair<int, double>> buffer;
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      get_row_fun(static_cast<int>(idx), buffer);
      for (std::pair<int, double>& inner_data : buffer) {
        CHECK(inner_data.first < num_col);
        if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
          sample_values[inner_data.first].emplace_back(inner_data.second);
          sample_idx[inner_data.first].emplace_back(static_cast<int>(i));
        }
      }
    }
    DatasetLoader loader(config, nullptr, 1, nullptr);
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow));
    ret->CreateValid(
      reinterpret_cast<const Dataset*>(reference));
  }
763

764
765
766
767
768
769
  OMP_INIT_EX();
  std::vector<std::pair<int, double>> threadBuffer;
  #pragma omp parallel for schedule(static) private(threadBuffer)
  for (int i = 0; i < num_rows; ++i) {
    OMP_LOOP_EX_BEGIN();
    {
770
771
772
      const int tid = omp_get_thread_num();
      get_row_fun(i, threadBuffer);
      ret->PushOneRow(tid, i, threadBuffer);
773
774
775
776
777
778
779
780
781
    }
    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
782
int LGBM_DatasetCreateFromCSC(const void* col_ptr,
783
784
785
786
787
788
789
790
791
792
                              int col_ptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t ncol_ptr,
                              int64_t nelem,
                              int64_t num_row,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
793
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
794
795
  auto param = Config::Str2Map(parameters);
  Config config;
796
797
798
799
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
800
  std::unique_ptr<Dataset> ret;
Guolin Ke's avatar
Guolin Ke committed
801
802
803
  int32_t nrow = static_cast<int32_t>(num_row);
  if (reference == nullptr) {
    // sample data first
Guolin Ke's avatar
Guolin Ke committed
804
805
    Random rand(config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
806
    auto sample_indices = rand.Sample(nrow, sample_cnt);
807
    sample_cnt = static_cast<int>(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
808
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
Guolin Ke's avatar
Guolin Ke committed
809
    std::vector<std::vector<int>> sample_idx(ncol_ptr - 1);
810
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
811
    #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
812
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
813
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
814
815
816
      CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
      for (int j = 0; j < sample_cnt; j++) {
        auto val = col_it.Get(sample_indices[j]);
Guolin Ke's avatar
Guolin Ke committed
817
        if (std::fabs(val) > kZeroThreshold || std::isnan(val)) {
Guolin Ke's avatar
Guolin Ke committed
818
819
          sample_values[i].emplace_back(val);
          sample_idx[i].emplace_back(j);
Guolin Ke's avatar
Guolin Ke committed
820
821
        }
      }
822
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
823
    }
824
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
825
    DatasetLoader loader(config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
826
827
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
828
829
830
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
Guolin Ke's avatar
Guolin Ke committed
831
  } else {
832
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
833
    ret->CreateValid(
834
      reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
835
  }
836
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
837
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
838
  for (int i = 0; i < ncol_ptr - 1; ++i) {
839
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
840
    const int tid = omp_get_thread_num();
Guolin Ke's avatar
Guolin Ke committed
841
    int feature_idx = ret->InnerFeatureIndex(i);
Guolin Ke's avatar
Guolin Ke committed
842
    if (feature_idx < 0) { continue; }
Guolin Ke's avatar
Guolin Ke committed
843
844
    int group = ret->Feature2Group(feature_idx);
    int sub_feature = ret->Feture2SubFeature(feature_idx);
Guolin Ke's avatar
Guolin Ke committed
845
846
847
848
849
850
851
    CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
    int row_idx = 0;
    while (row_idx < nrow) {
      auto pair = col_it.NextNonZero();
      row_idx = pair.first;
      // no more data
      if (row_idx < 0) { break; }
Guolin Ke's avatar
Guolin Ke committed
852
      ret->PushOneData(tid, row_idx, group, sub_feature, pair.second);
Guolin Ke's avatar
Guolin Ke committed
853
    }
854
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
855
  }
856
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
857
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
858
  *out = ret.release();
859
  API_END();
Guolin Ke's avatar
Guolin Ke committed
860
861
}

Guolin Ke's avatar
Guolin Ke committed
862
int LGBM_DatasetGetSubset(
863
  const DatasetHandle handle,
wxchan's avatar
wxchan committed
864
865
866
  const int32_t* used_row_indices,
  int32_t num_used_row_indices,
  const char* parameters,
Guolin Ke's avatar
typo  
Guolin Ke committed
867
  DatasetHandle* out) {
wxchan's avatar
wxchan committed
868
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
869
870
  auto param = Config::Str2Map(parameters);
  Config config;
871
872
873
874
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
875
  auto full_dataset = reinterpret_cast<const Dataset*>(handle);
Guolin Ke's avatar
Guolin Ke committed
876
  CHECK(num_used_row_indices > 0);
877
878
879
  const int32_t lower = 0;
  const int32_t upper = full_dataset->num_data() - 1;
  Common::CheckElementsIntervalClosed(used_row_indices, lower, upper, num_used_row_indices, "Used indices of subset");
Guolin Ke's avatar
Guolin Ke committed
880
  auto ret = std::unique_ptr<Dataset>(new Dataset(num_used_row_indices));
881
  ret->CopyFeatureMapperFrom(full_dataset);
Guolin Ke's avatar
Guolin Ke committed
882
  ret->CopySubset(full_dataset, used_row_indices, num_used_row_indices, true);
wxchan's avatar
wxchan committed
883
884
885
886
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
887
int LGBM_DatasetSetFeatureNames(
Guolin Ke's avatar
typo  
Guolin Ke committed
888
  DatasetHandle handle,
Guolin Ke's avatar
Guolin Ke committed
889
  const char** feature_names,
Guolin Ke's avatar
Guolin Ke committed
890
  int num_feature_names) {
Guolin Ke's avatar
Guolin Ke committed
891
892
893
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  std::vector<std::string> feature_names_str;
Guolin Ke's avatar
Guolin Ke committed
894
  for (int i = 0; i < num_feature_names; ++i) {
Guolin Ke's avatar
Guolin Ke committed
895
896
897
898
899
900
    feature_names_str.emplace_back(feature_names[i]);
  }
  dataset->set_feature_names(feature_names_str);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
901
int LGBM_DatasetGetFeatureNames(
902
903
  DatasetHandle handle,
  char** feature_names,
Guolin Ke's avatar
Guolin Ke committed
904
  int* num_feature_names) {
905
906
907
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  auto inside_feature_name = dataset->feature_names();
Guolin Ke's avatar
Guolin Ke committed
908
909
  *num_feature_names = static_cast<int>(inside_feature_name.size());
  for (int i = 0; i < *num_feature_names; ++i) {
Guolin Ke's avatar
Guolin Ke committed
910
    std::memcpy(feature_names[i], inside_feature_name[i].c_str(), inside_feature_name[i].size() + 1);
911
912
913
914
  }
  API_END();
}

915
#pragma warning(disable : 4702)
Guolin Ke's avatar
Guolin Ke committed
916
int LGBM_DatasetFree(DatasetHandle handle) {
917
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
918
  delete reinterpret_cast<Dataset*>(handle);
919
  API_END();
920
921
}

Guolin Ke's avatar
Guolin Ke committed
922
int LGBM_DatasetSaveBinary(DatasetHandle handle,
923
                           const char* filename) {
924
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
925
926
  auto dataset = reinterpret_cast<Dataset*>(handle);
  dataset->SaveBinaryFile(filename);
927
  API_END();
928
929
}

930
931
932
933
934
935
936
937
/*!
 * \brief Dump a dataset by calling Dataset::DumpTextFile.
 * \param handle Dataset to dump
 * \param filename Destination path passed to DumpTextFile
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_DatasetDumpText(DatasetHandle handle,
                         const char* filename) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->DumpTextFile(filename);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
938
/*!
 * \brief Set a named field of a dataset from a typed array.
 *
 * `type` selects how `field_data` is interpreted: float32, int32 or float64.
 * Any other type code, or a setter that reports failure, raises an error.
 *
 * \param handle Dataset to modify
 * \param field_name Name of the field
 * \param field_data Pointer to the first element
 * \param num_element Number of elements in `field_data`
 * \param type One of C_API_DTYPE_FLOAT32, C_API_DTYPE_INT32, C_API_DTYPE_FLOAT64
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_DatasetSetField(DatasetHandle handle,
                         const char* field_name,
                         const void* field_data,
                         int num_element,
                         int type) {
  API_BEGIN();
  Dataset* target = reinterpret_cast<Dataset*>(handle);
  const int32_t count = static_cast<int32_t>(num_element);
  bool accepted = false;
  switch (type) {
    case C_API_DTYPE_FLOAT32:
      accepted = target->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), count);
      break;
    case C_API_DTYPE_INT32:
      accepted = target->SetIntField(field_name, reinterpret_cast<const int*>(field_data), count);
      break;
    case C_API_DTYPE_FLOAT64:
      accepted = target->SetDoubleField(field_name, reinterpret_cast<const double*>(field_data), count);
      break;
    default:
      // unsupported type code: reported through the error below
      break;
  }
  if (!accepted) {
    throw std::runtime_error("Input data type error or field not found");
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
957
/*!
 * \brief Look up a named field of a dataset.
 *
 * The getters are tried in the order float32, int32, float64, int8; the first
 * one that succeeds determines `out_type`. If none succeeds an error is raised.
 *
 * \param handle Dataset to query
 * \param field_name Name of the field
 * \param out_len Receives the element count (forced to 0 when the data pointer is null)
 * \param out_ptr Receives a pointer to the field data
 * \param out_type Receives the C_API_DTYPE_* code of the data
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_DatasetGetField(DatasetHandle handle,
                         const char* field_name,
                         int* out_len,
                         const void** out_ptr,
                         int* out_type) {
  API_BEGIN();
  Dataset* source = reinterpret_cast<Dataset*>(handle);
  if (source->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (source->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else if (source->GetDoubleField(field_name, out_len, reinterpret_cast<const double**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT64;
  } else if (source->GetInt8Field(field_name, out_len, reinterpret_cast<const int8_t**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT8;
  } else {
    throw std::runtime_error("Field not found");
  }
  // a null data pointer is reported as an empty field
  if (*out_ptr == nullptr) {
    *out_len = 0;
  }
  API_END();
}

983
984
985
986
987
988
989
/*!
 * \brief Pass a parameter string to Dataset::ResetConfig.
 * \param handle Dataset to update
 * \param parameters Parameter string
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_DatasetUpdateParam(DatasetHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->ResetConfig(parameters);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
990
int LGBM_DatasetGetNumData(DatasetHandle handle,
991
                           int* out) {
992
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
993
994
  auto dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_data();
995
  API_END();
996
997
}

Guolin Ke's avatar
Guolin Ke committed
998
int LGBM_DatasetGetNumFeature(DatasetHandle handle,
999
                              int* out) {
1000
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1001
1002
  auto dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_total_features();
1003
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1004
}
1005

1006
1007
1008
1009
1010
1011
1012
1013
1014
/*!
 * \brief Call Dataset::addFeaturesFrom on `target` with `source`.
 * \param target Dataset that receives the features
 * \param source Dataset the features are taken from
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_DatasetAddFeaturesFrom(DatasetHandle target,
                                DatasetHandle source) {
  API_BEGIN();
  Dataset* donor = reinterpret_cast<Dataset*>(source);
  reinterpret_cast<Dataset*>(target)->addFeaturesFrom(donor);
  API_END();
}

1015
1016
// ---- start of booster

Guolin Ke's avatar
Guolin Ke committed
1017
int LGBM_BoosterCreate(const DatasetHandle train_data,
1018
1019
                       const char* parameters,
                       BoosterHandle* out) {
1020
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1021
  const Dataset* p_train_data = reinterpret_cast<const Dataset*>(train_data);
wxchan's avatar
wxchan committed
1022
1023
  auto ret = std::unique_ptr<Booster>(new Booster(p_train_data, parameters));
  *out = ret.release();
1024
  API_END();
1025
1026
}

Guolin Ke's avatar
Guolin Ke committed
1027
int LGBM_BoosterCreateFromModelfile(
1028
  const char* filename,
Guolin Ke's avatar
Guolin Ke committed
1029
  int* out_num_iterations,
1030
  BoosterHandle* out) {
1031
  API_BEGIN();
wxchan's avatar
wxchan committed
1032
  auto ret = std::unique_ptr<Booster>(new Booster(filename));
Guolin Ke's avatar
Guolin Ke committed
1033
  *out_num_iterations = ret->GetBoosting()->GetCurrentIteration();
wxchan's avatar
wxchan committed
1034
  *out = ret.release();
1035
  API_END();
1036
1037
}

Guolin Ke's avatar
Guolin Ke committed
1038
int LGBM_BoosterLoadModelFromString(
1039
1040
1041
1042
  const char* model_str,
  int* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
wxchan's avatar
wxchan committed
1043
  auto ret = std::unique_ptr<Booster>(new Booster(nullptr));
1044
1045
1046
1047
1048
1049
  ret->LoadModelFromString(model_str);
  *out_num_iterations = ret->GetBoosting()->GetCurrentIteration();
  *out = ret.release();
  API_END();
}

1050
#pragma warning(disable : 4702)
Guolin Ke's avatar
Guolin Ke committed
1051
int LGBM_BoosterFree(BoosterHandle handle) {
1052
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1053
  delete reinterpret_cast<Booster*>(handle);
1054
  API_END();
1055
1056
}

1057
/*!
 * \brief Forward to Booster::ShuffleModels.
 * \param handle Booster to modify
 * \param start_iter First argument passed to ShuffleModels
 * \param end_iter Second argument passed to ShuffleModels
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_BoosterShuffleModels(BoosterHandle handle, int start_iter, int end_iter) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ShuffleModels(start_iter, end_iter);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1064
/*!
 * \brief Call Booster::MergeFrom on `handle` with `other_handle`.
 * \param handle Booster that receives the merge
 * \param other_handle Booster passed to MergeFrom
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_BoosterMerge(BoosterHandle handle,
                      BoosterHandle other_handle) {
  API_BEGIN();
  Booster* other = reinterpret_cast<Booster*>(other_handle);
  reinterpret_cast<Booster*>(handle)->MergeFrom(other);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1073
int LGBM_BoosterAddValidData(BoosterHandle handle,
1074
                             const DatasetHandle valid_data) {
wxchan's avatar
wxchan committed
1075
1076
1077
1078
1079
1080
1081
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const Dataset* p_dataset = reinterpret_cast<const Dataset*>(valid_data);
  ref_booster->AddValidData(p_dataset);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1082
int LGBM_BoosterResetTrainingData(BoosterHandle handle,
1083
                                  const DatasetHandle train_data) {
wxchan's avatar
wxchan committed
1084
1085
1086
1087
1088
1089
1090
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const Dataset* p_dataset = reinterpret_cast<const Dataset*>(train_data);
  ref_booster->ResetTrainingData(p_dataset);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1091
/*!
 * \brief Pass a parameter string to Booster::ResetConfig.
 * \param handle Booster to modify
 * \param parameters Parameter string
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ResetConfig(parameters);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1098
/*!
 * \brief Report NumberOfClasses() of the booster's boosting object.
 * \param handle Booster to query
 * \param out_len Receives the value
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetBoosting()->NumberOfClasses();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1105
1106
1107
1108
1109
1110
1111
/*!
 * \brief Forward to Booster::Refit.
 * \param handle Booster to refit
 * \param leaf_preds Leaf prediction array passed to Refit
 * \param nrow Row count passed to Refit
 * \param ncol Column count passed to Refit
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_BoosterRefit(BoosterHandle handle, const int32_t* leaf_preds, int32_t nrow, int32_t ncol) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->Refit(leaf_preds, nrow, ncol);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1112
/*!
 * \brief Run one training iteration via Booster::TrainOneIter().
 * \param handle Booster to train
 * \param is_finished Receives 1 if TrainOneIter() returned true, otherwise 0
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  *is_finished = booster->TrainOneIter() ? 1 : 0;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1123
/*!
 * \brief Run one training iteration with caller-supplied gradients and hessians.
 *
 * When SCORE_T_USE_DOUBLE is defined this always fails through Log::Fatal.
 *
 * \param handle Booster to train
 * \param grad Gradient array passed to TrainOneIter
 * \param hess Hessian array passed to TrainOneIter
 * \param is_finished Receives 1 if TrainOneIter returned true, otherwise 0
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
                                    const float* grad,
                                    const float* hess,
                                    int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  #ifdef SCORE_T_USE_DOUBLE
  Log::Fatal("Don't support custom loss function when SCORE_T_USE_DOUBLE is enabled");
  #else
  *is_finished = booster->TrainOneIter(grad, hess) ? 1 : 0;
  #endif
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1141
/*!
 * \brief Forward to Booster::RollbackOneIter.
 * \param handle Booster to modify
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1148
/*!
 * \brief Report GetCurrentIteration() of the booster's boosting object.
 * \param handle Booster to query
 * \param out_iteration Receives the value
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int* out_iteration) {
  API_BEGIN();
  *out_iteration = reinterpret_cast<Booster*>(handle)->GetBoosting()->GetCurrentIteration();
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
1154

1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
/*!
 * \brief Report NumModelPerIteration() of the booster's boosting object.
 * \param handle Booster to query
 * \param out_tree_per_iteration Receives the value
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_BoosterNumModelPerIteration(BoosterHandle handle, int* out_tree_per_iteration) {
  API_BEGIN();
  *out_tree_per_iteration = reinterpret_cast<Booster*>(handle)->GetBoosting()->NumModelPerIteration();
  API_END();
}

/*!
 * \brief Report NumberOfTotalModel() of the booster's boosting object.
 * \param handle Booster to query
 * \param out_models Receives the value
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_BoosterNumberOfTotalModel(BoosterHandle handle, int* out_models) {
  API_BEGIN();
  *out_models = reinterpret_cast<Booster*>(handle)->GetBoosting()->NumberOfTotalModel();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1169
/*!
 * \brief Report Booster::GetEvalCounts().
 * \param handle Booster to query
 * \param out_len Receives the value
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalCounts();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1176
/*!
 * \brief Forward to Booster::GetEvalNames.
 * \param handle Booster to query
 * \param out_len Receives the value returned by GetEvalNames
 * \param out_strs Buffer array passed to GetEvalNames
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_BoosterGetEvalNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalNames(out_strs);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1183
/*!
 * \brief Forward to Booster::GetFeatureNames.
 * \param handle Booster to query
 * \param out_len Receives the value returned by GetFeatureNames
 * \param out_strs Buffer array passed to GetFeatureNames
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_BoosterGetFeatureNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetFeatureNames(out_strs);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1190
/*!
 * \brief Report MaxFeatureIdx() + 1 of the booster's boosting object.
 * \param handle Booster to query
 * \param out_len Receives the value
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_BoosterGetNumFeature(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  const auto max_feature_idx = reinterpret_cast<Booster*>(handle)->GetBoosting()->MaxFeatureIdx();
  *out_len = max_feature_idx + 1;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1197
/*!
 * \brief Copy the evaluation results for one dataset index into a caller buffer.
 *
 * The results come from GetEvalAt(data_idx); `out_results` must have room for
 * all of them.
 *
 * \param handle Booster to query
 * \param data_idx Index passed to GetEvalAt
 * \param out_len Receives the number of results
 * \param out_results Receives the results, converted to double
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_BoosterGetEval(BoosterHandle handle,
                        int data_idx,
                        int* out_len,
                        double* out_results) {
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  auto scores = boosting->GetEvalAt(data_idx);
  const size_t num_scores = scores.size();
  *out_len = static_cast<int>(num_scores);
  size_t pos = 0;
  while (pos < num_scores) {
    out_results[pos] = static_cast<double>(scores[pos]);
    ++pos;
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1212
int LGBM_BoosterGetNumPredict(BoosterHandle handle,
1213
1214
                              int data_idx,
                              int64_t* out_len) {
Guolin Ke's avatar
Guolin Ke committed
1215
1216
1217
1218
1219
1220
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->GetNumPredictAt(data_idx);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1221
int LGBM_BoosterGetPredict(BoosterHandle handle,
1222
1223
1224
                           int data_idx,
                           int64_t* out_len,
                           double* out_result) {
1225
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1226
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1227
  ref_booster->GetPredictAt(data_idx, out_result, out_len);
1228
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1229
1230
}

Guolin Ke's avatar
Guolin Ke committed
1231
int LGBM_BoosterPredictForFile(BoosterHandle handle,
1232
1233
1234
1235
                               const char* data_filename,
                               int data_has_header,
                               int predict_type,
                               int num_iteration,
1236
                               const char* parameter,
1237
                               const char* result_filename) {
1238
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1239
1240
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1241
1242
1243
1244
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1245
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
cbecker's avatar
cbecker committed
1246
  ref_booster->Predict(num_iteration, predict_type, data_filename, data_has_header,
Guolin Ke's avatar
Guolin Ke committed
1247
                       config, result_filename);
1248
  API_END();
1249
1250
}

Guolin Ke's avatar
Guolin Ke committed
1251
/*!
 * \brief Compute the number of output values for a prediction call.
 *
 * The result is `num_row` multiplied by NumPredictOneRow(), evaluated in 64-bit.
 *
 * \param handle Booster to query
 * \param num_row Number of rows to be predicted
 * \param predict_type Prediction type code; leaf-index and contribution types are flagged to NumPredictOneRow
 * \param num_iteration Iteration count passed to NumPredictOneRow
 * \param out_len Receives the total
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
                               int num_row,
                               int predict_type,
                               int num_iteration,
                               int64_t* out_len) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const bool is_leaf_index = (predict_type == C_API_PREDICT_LEAF_INDEX);
  const bool is_contrib = (predict_type == C_API_PREDICT_CONTRIB);
  *out_len = static_cast<int64_t>(num_row) * booster->GetBoosting()->NumPredictOneRow(
    num_iteration, is_leaf_index, is_contrib);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1263
int LGBM_BoosterPredictForCSR(BoosterHandle handle,
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
                              const void* indptr,
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t,
                              int predict_type,
                              int num_iteration,
1274
                              const char* parameter,
1275
1276
                              int64_t* out_len,
                              double* out_result) {
1277
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1278
1279
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1280
1281
1282
1283
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1284
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1285
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
Guolin Ke's avatar
Guolin Ke committed
1286
  int nrow = static_cast<int>(nindptr - 1);
cbecker's avatar
cbecker committed
1287
  ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
1288
                       config, out_result, out_len);
1289
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1290
}
1291

1292
// Single-row variant of the CSR prediction entry point.
// Builds a CSR row reader from the supplied arrays and forwards it to
// Booster::PredictSingleRow. The unnamed int64_t argument is not used here.
// Returns 0 on success and -1 when an exception was caught (see API_END).
int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle,
                                       const void* indptr,
                                       int indptr_type,
                                       const int32_t* indices,
                                       const void* data,
                                       int data_type,
                                       int64_t nindptr,
                                       int64_t nelem,
                                       int64_t,
                                       int predict_type,
                                       int num_iteration,
                                       const char* parameter,
                                       int64_t* out_len,
                                       double* out_result) {
  API_BEGIN();
  // Parse the parameter string into a Config used for this call.
  auto parsed_params = Config::Str2Map(parameter);
  Config predict_config;
  predict_config.Set(parsed_params);
  // Only override the OpenMP thread count when one was explicitly requested.
  if (predict_config.num_threads > 0) {
    omp_set_num_threads(predict_config.num_threads);
  }
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto row_reader = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  booster->PredictSingleRow(num_iteration, predict_type, row_reader,
                            predict_config, out_result, out_len);
  API_END();
}


Guolin Ke's avatar
Guolin Ke committed
1320
int LGBM_BoosterPredictForCSC(BoosterHandle handle,
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
                              const void* col_ptr,
                              int col_ptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t ncol_ptr,
                              int64_t nelem,
                              int64_t num_row,
                              int predict_type,
                              int num_iteration,
1331
                              const char* parameter,
1332
1333
                              int64_t* out_len,
                              double* out_result) {
Guolin Ke's avatar
Guolin Ke committed
1334
1335
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
Guolin Ke's avatar
Guolin Ke committed
1336
1337
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  int num_threads = 1;
  #pragma omp parallel
  #pragma omp master
  {
    num_threads = omp_get_num_threads();
  }
Guolin Ke's avatar
Guolin Ke committed
1348
  int ncol = static_cast<int>(ncol_ptr - 1);
Guolin Ke's avatar
Guolin Ke committed
1349
1350
1351
1352
1353
  std::vector<std::vector<CSC_RowIterator>> iterators(num_threads, std::vector<CSC_RowIterator>());
  for (int i = 0; i < num_threads; ++i) {
    for (int j = 0; j < ncol; ++j) {
      iterators[i].emplace_back(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, j);
    }
Guolin Ke's avatar
Guolin Ke committed
1354
1355
  }
  std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun =
1356
    [&iterators, ncol] (int i) {
Guolin Ke's avatar
Guolin Ke committed
1357
    std::vector<std::pair<int, double>> one_row;
Guolin Ke's avatar
Guolin Ke committed
1358
    const int tid = omp_get_thread_num();
Guolin Ke's avatar
Guolin Ke committed
1359
    for (int j = 0; j < ncol; ++j) {
Guolin Ke's avatar
Guolin Ke committed
1360
      auto val = iterators[tid][j].Get(i);
Guolin Ke's avatar
Guolin Ke committed
1361
      if (std::fabs(val) > kZeroThreshold || std::isnan(val)) {
Guolin Ke's avatar
Guolin Ke committed
1362
        one_row.emplace_back(j, val);
Guolin Ke's avatar
Guolin Ke committed
1363
1364
      }
    }
Guolin Ke's avatar
Guolin Ke committed
1365
1366
    return one_row;
  };
Guolin Ke's avatar
Guolin Ke committed
1367
  ref_booster->Predict(num_iteration, predict_type, static_cast<int>(num_row), get_row_fun, config,
cbecker's avatar
cbecker committed
1368
                       out_result, out_len);
Guolin Ke's avatar
Guolin Ke committed
1369
1370
1371
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1372
int LGBM_BoosterPredictForMat(BoosterHandle handle,
1373
1374
1375
1376
1377
1378
1379
                              const void* data,
                              int data_type,
                              int32_t nrow,
                              int32_t ncol,
                              int is_row_major,
                              int predict_type,
                              int num_iteration,
1380
                              const char* parameter,
1381
1382
                              int64_t* out_len,
                              double* out_result) {
1383
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1384
1385
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1386
1387
1388
1389
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1390
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1391
  auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
cbecker's avatar
cbecker committed
1392
  ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
1393
                       config, out_result, out_len);
1394
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1395
}
1396

1397
// Single-row variant of the dense-matrix prediction entry point.
// The input is treated as a matrix with exactly one row of ncol values and is
// forwarded to Booster::PredictSingleRow.
// Returns 0 on success and -1 when an exception was caught (see API_END).
int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle,
                                       const void* data,
                                       int data_type,
                                       int32_t ncol,
                                       int is_row_major,
                                       int predict_type,
                                       int num_iteration,
                                       const char* parameter,
                                       int64_t* out_len,
                                       double* out_result) {
  API_BEGIN();
  // Parse the parameter string into a Config used for this call.
  auto parsed_params = Config::Str2Map(parameter);
  Config predict_config;
  predict_config.Set(parsed_params);
  // Only override the OpenMP thread count when one was explicitly requested.
  if (predict_config.num_threads > 0) {
    omp_set_num_threads(predict_config.num_threads);
  }
  Booster* booster = reinterpret_cast<Booster*>(handle);
  // The row count is fixed to 1 for the single-row path.
  auto row_reader = RowPairFunctionFromDenseMatric(data, 1, ncol, data_type, is_row_major);
  booster->PredictSingleRow(num_iteration, predict_type, row_reader,
                            predict_config, out_result, out_len);
  API_END();
}


1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
// Runs prediction over nrow dense rows supplied as an array of row pointers
// (data[i] points at the ncol values of row i). Rows are read through
// RowPairFunctionFromDenseRows and the work is done by Booster::Predict.
// Returns 0 on success and -1 when an exception was caught (see API_END).
int LGBM_BoosterPredictForMats(BoosterHandle handle,
                               const void** data,
                               int data_type,
                               int32_t nrow,
                               int32_t ncol,
                               int predict_type,
                               int num_iteration,
                               const char* parameter,
                               int64_t* out_len,
                               double* out_result) {
  API_BEGIN();
  // Parse the parameter string into a Config used for this call.
  auto parsed_params = Config::Str2Map(parameter);
  Config predict_config;
  predict_config.Set(parsed_params);
  // Only override the OpenMP thread count when one was explicitly requested.
  if (predict_config.num_threads > 0) {
    omp_set_num_threads(predict_config.num_threads);
  }
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto row_reader = RowPairFunctionFromDenseRows(data, ncol, data_type);
  booster->Predict(num_iteration, predict_type, nrow, row_reader,
                   predict_config, out_result, out_len);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1444
int LGBM_BoosterSaveModel(BoosterHandle handle,
1445
                          int start_iteration,
1446
1447
                          int num_iteration,
                          const char* filename) {
1448
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1449
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1450
  ref_booster->SaveModelToFile(start_iteration, num_iteration, filename);
wxchan's avatar
wxchan committed
1451
1452
1453
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1454
int LGBM_BoosterSaveModelToString(BoosterHandle handle,
1455
                                  int start_iteration,
1456
                                  int num_iteration,
1457
                                  int64_t buffer_len,
1458
                                  int64_t* out_len,
1459
                                  char* out_str) {
1460
1461
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1462
  std::string model = ref_booster->SaveModelToString(start_iteration, num_iteration);
1463
  *out_len = static_cast<int64_t>(model.size()) + 1;
1464
  if (*out_len <= buffer_len) {
Guolin Ke's avatar
Guolin Ke committed
1465
    std::memcpy(out_str, model.c_str(), *out_len);
1466
1467
1468
1469
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1470
int LGBM_BoosterDumpModel(BoosterHandle handle,
1471
                          int start_iteration,
1472
                          int num_iteration,
1473
1474
                          int64_t buffer_len,
                          int64_t* out_len,
1475
                          char* out_str) {
wxchan's avatar
wxchan committed
1476
1477
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1478
  std::string model = ref_booster->DumpModel(start_iteration, num_iteration);
1479
  *out_len = static_cast<int64_t>(model.size()) + 1;
wxchan's avatar
wxchan committed
1480
  if (*out_len <= buffer_len) {
Guolin Ke's avatar
Guolin Ke committed
1481
    std::memcpy(out_str, model.c_str(), *out_len);
wxchan's avatar
wxchan committed
1482
  }
1483
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1484
}
1485

Guolin Ke's avatar
Guolin Ke committed
1486
// Reads the value of leaf `leaf_idx` in tree `tree_idx` from the booster and
// stores it, converted to double, in *out_val.
// Returns 0 on success and -1 when an exception was caught (see API_END).
int LGBM_BoosterGetLeafValue(BoosterHandle handle,
                             int tree_idx,
                             int leaf_idx,
                             double* out_val) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const auto leaf_value = booster->GetLeafValue(tree_idx, leaf_idx);
  *out_val = static_cast<double>(leaf_value);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1496
// Overwrites the value of leaf `leaf_idx` in tree `tree_idx` with `val` by
// delegating to Booster::SetLeafValue.
// Returns 0 on success and -1 when an exception was caught (see API_END).
int LGBM_BoosterSetLeafValue(BoosterHandle handle,
                             int tree_idx,
                             int leaf_idx,
                             double val) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->SetLeafValue(tree_idx, leaf_idx, val);
  API_END();
}

1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
// Computes feature importances via Booster::FeatureImportance and copies them,
// in order, into out_results. The caller must supply a buffer with room for
// every value the booster returns; no length is checked here.
// Returns 0 on success and -1 when an exception was caught (see API_END).
int LGBM_BoosterFeatureImportance(BoosterHandle handle,
                                  int num_iteration,
                                  int importance_type,
                                  double* out_results) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const std::vector<double> importances = booster->FeatureImportance(num_iteration, importance_type);
  double* dst = out_results;
  for (const double importance : importances) {
    *dst = importance;
    ++dst;
  }
  API_END();
}

1519
1520
1521
1522
1523
int LGBM_NetworkInit(const char* machines,
                     int local_listen_port,
                     int listen_time_out,
                     int num_machines) {
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1524
  Config config;
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
  config.machines = Common::RemoveQuotationSymbol(std::string(machines));
  config.local_listen_port = local_listen_port;
  config.num_machines = num_machines;
  config.time_out = listen_time_out;
  if (num_machines > 1) {
    Network::Init(config);
  }
  API_END();
}

// Tears down the Network layer by calling Network::Dispose().
// Returns 0 on success and -1 when an exception was caught (see API_END).
int LGBM_NetworkFree() {
  API_BEGIN();
  {
    Network::Dispose();
  }
  API_END();
}

1541
1542
1543
int LGBM_NetworkInitWithFunctions(int num_machines, int rank,
                                  void* reduce_scatter_ext_fun,
                                  void* allgather_ext_fun) {
ww's avatar
ww committed
1544
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1545
  if (num_machines > 1) {
1546
    Network::Init(num_machines, rank, (ReduceScatterFunction)reduce_scatter_ext_fun, (AllgatherFunction)allgather_ext_fun);
ww's avatar
ww committed
1547
1548
1549
  }
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
1550

Guolin Ke's avatar
Guolin Ke committed
1551
// ---- start of some helper functions
1552
1553
1554

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1555
  if (data_type == C_API_DTYPE_FLOAT32) {
1556
1557
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
1558
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1559
        std::vector<double> ret(num_col);
1560
        auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
1561
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1562
          ret[i] = static_cast<double>(*(tmp_ptr + i));
1563
1564
1565
1566
        }
        return ret;
      };
    } else {
1567
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1568
        std::vector<double> ret(num_col);
1569
        for (int i = 0; i < num_col; ++i) {
1570
          ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
1571
1572
1573
1574
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1575
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1576
1577
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
1578
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1579
        std::vector<double> ret(num_col);
1580
        auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
1581
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1582
          ret[i] = static_cast<double>(*(tmp_ptr + i));
1583
1584
1585
1586
        }
        return ret;
      };
    } else {
1587
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1588
        std::vector<double> ret(num_col);
1589
        for (int i = 0; i < num_col; ++i) {
1590
          ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
1591
1592
1593
1594
1595
        }
        return ret;
      };
    }
  }
1596
  throw std::runtime_error("Unknown data type in RowFunctionFromDenseMatric");
1597
1598
1599
1600
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1601
1602
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
1603
    return [inner_function] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1604
1605
1606
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
1607
        if (std::fabs(raw_values[i]) > kZeroThreshold || std::isnan(raw_values[i])) {
Guolin Ke's avatar
Guolin Ke committed
1608
          ret.emplace_back(i, raw_values[i]);
1609
        }
Guolin Ke's avatar
Guolin Ke committed
1610
1611
1612
      }
      return ret;
    };
1613
  }
Guolin Ke's avatar
Guolin Ke committed
1614
  return nullptr;
1615
1616
}

1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
// `data` is an array of pointers, one per row, each pointing at num_col values.
// Returns a function that reads row `row_idx` through data[row_idx] and yields
// it as (column index, value) pairs, keeping only entries that are NaN or
// whose magnitude exceeds kZeroThreshold.
// The per-row dense reader is created on every call, so an unsupported
// data_type is reported when a row is read, not when this function returns.
std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseRows(const void** data, int num_col, int data_type) {
  return [=](int row_idx) {
    // Treat the selected row as a 1 x num_col row-major matrix.
    auto single_row_reader = RowFunctionFromDenseMatric(data[row_idx], 1, num_col, data_type, /* is_row_major */ true);
    auto dense_row = single_row_reader(0);
    std::vector<std::pair<int, double>> sparse_row;
    const int width = static_cast<int>(dense_row.size());
    for (int col = 0; col < width; ++col) {
      if (std::isnan(dense_row[col]) || std::fabs(dense_row[col]) > kZeroThreshold) {
        sparse_row.emplace_back(col, dense_row[col]);
      }
    }
    return sparse_row;
  };
}

1633
std::function<std::vector<std::pair<int, double>>(int idx)>
1634
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t , int64_t ) {
Guolin Ke's avatar
Guolin Ke committed
1635
  if (data_type == C_API_DTYPE_FLOAT32) {
1636
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1637
    if (indptr_type == C_API_DTYPE_INT32) {
1638
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
1639
      return [=] (int idx) {
1640
1641
1642
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1643
1644
1645
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1646
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1647
          ret.emplace_back(indices[i], data_ptr[i]);
1648
1649
1650
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1651
    } else if (indptr_type == C_API_DTYPE_INT64) {
1652
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
1653
      return [=] (int idx) {
1654
1655
1656
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1657
1658
1659
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1660
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1661
          ret.emplace_back(indices[i], data_ptr[i]);
1662
1663
1664
1665
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1666
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1667
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1668
    if (indptr_type == C_API_DTYPE_INT32) {
1669
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
1670
      return [=] (int idx) {
1671
1672
1673
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1674
1675
1676
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1677
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1678
          ret.emplace_back(indices[i], data_ptr[i]);
1679
1680
1681
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1682
    } else if (indptr_type == C_API_DTYPE_INT64) {
1683
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
1684
      return [=] (int idx) {
1685
1686
1687
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1688
1689
1690
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1691
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1692
          ret.emplace_back(indices[i], data_ptr[i]);
1693
1694
1695
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1696
1697
    }
  }
1698
  throw std::runtime_error("Unknown data type in RowFunctionFromCSR");
1699
1700
}

Guolin Ke's avatar
Guolin Ke committed
1701
std::function<std::pair<int, double>(int idx)>
1702
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t , int col_idx) {
Guolin Ke's avatar
Guolin Ke committed
1703
  CHECK(col_idx < ncol_ptr && col_idx >= 0);
Guolin Ke's avatar
Guolin Ke committed
1704
  if (data_type == C_API_DTYPE_FLOAT32) {
1705
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1706
    if (col_ptr_type == C_API_DTYPE_INT32) {
1707
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1708
1709
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1710
      return [=] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1711
1712
1713
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1714
        }
Guolin Ke's avatar
Guolin Ke committed
1715
1716
1717
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1718
      };
Guolin Ke's avatar
Guolin Ke committed
1719
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1720
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1721
1722
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1723
      return [=] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1724
1725
1726
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1727
        }
Guolin Ke's avatar
Guolin Ke committed
1728
1729
1730
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1731
      };
Guolin Ke's avatar
Guolin Ke committed
1732
    }
Guolin Ke's avatar
Guolin Ke committed
1733
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1734
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1735
    if (col_ptr_type == C_API_DTYPE_INT32) {
1736
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1737
1738
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1739
      return [=] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1740
1741
1742
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1743
        }
Guolin Ke's avatar
Guolin Ke committed
1744
1745
1746
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1747
      };
Guolin Ke's avatar
Guolin Ke committed
1748
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1749
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1750
1751
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1752
      return [=] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1753
1754
1755
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1756
        }
Guolin Ke's avatar
Guolin Ke committed
1757
1758
1759
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1760
      };
Guolin Ke's avatar
Guolin Ke committed
1761
1762
    }
  }
1763
  throw std::runtime_error("Unknown data type in CSC matrix");
1764
1765
}

Guolin Ke's avatar
Guolin Ke committed
1766
// Creates an iterator over column `col_idx` of a CSC matrix. All arguments are
// forwarded unchanged to IterateFunctionFromCSC, whose result is stored as the
// function used to fetch the column's stored entries one by one.
CSC_RowIterator::CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
                                 const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx)
    : iter_fun_(IterateFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type,
                                       ncol_ptr, nelem, col_idx)) {
}

// Returns the column's value at row `idx`, or 0 when no entry is stored there.
// The iterator only moves forward: stored entries are consumed until the
// current entry's row index is at least `idx` or the column is exhausted.
// The last consumed entry is remembered in cur_idx_ / cur_val_, so asking for
// the same row again returns the same value without advancing.
double CSC_RowIterator::Get(int idx) {
  while (!is_end_ && cur_idx_ < idx) {
    const auto entry = iter_fun_(nonzero_idx_);
    if (entry.first < 0) {
      // A negative row index signals that the column has no more entries.
      is_end_ = true;
    } else {
      cur_idx_ = entry.first;
      cur_val_ = entry.second;
      ++nonzero_idx_;
    }
  }
  return (cur_idx_ == idx) ? cur_val_ : 0.0;
}

// Returns the next stored (row index, value) entry of the column and advances
// past it. Once the underlying function reports the end (negative row index),
// that result is returned and every later call yields (-1, 0.0).
std::pair<int, double> CSC_RowIterator::NextNonZero() {
  if (is_end_) {
    return std::make_pair(-1, 0.0);
  }
  auto entry = iter_fun_(nonzero_idx_);
  ++nonzero_idx_;
  if (entry.first < 0) {
    is_end_ = true;
  }
  return entry;
}