c_api.cpp 50.1 KB
Newer Older
1
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
2
3
4

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/utils/threading.h>
Guolin Ke's avatar
Guolin Ke committed
6
#include <LightGBM/c_api.h>
Guolin Ke's avatar
Guolin Ke committed
7
#include <LightGBM/dataset_loader.h>
Guolin Ke's avatar
Guolin Ke committed
8
9
10
11
12
13
14
15
16
17
#include <LightGBM/dataset.h>
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/config.h>

#include <cstdio>
#include <vector>
#include <string>
#include <cstring>
Guolin Ke's avatar
Guolin Ke committed
18
#include <memory>
Guolin Ke's avatar
Guolin Ke committed
19
#include <stdexcept>
wxchan's avatar
wxchan committed
20
#include <mutex>
Guolin Ke's avatar
Guolin Ke committed
21
#include <functional>
Guolin Ke's avatar
Guolin Ke committed
22

Guolin Ke's avatar
Guolin Ke committed
23
#include "./application/predictor.hpp"
Guolin Ke's avatar
Guolin Ke committed
24
#include "./boosting/gbdt.h"
Guolin Ke's avatar
Guolin Ke committed
25

Guolin Ke's avatar
Guolin Ke committed
26
27
28
29
namespace LightGBM {

class Booster {
public:
Guolin Ke's avatar
Guolin Ke committed
30
31
  explicit Booster(const char* filename) {
    boosting_.reset(Boosting::CreateBoosting(filename));
Guolin Ke's avatar
Guolin Ke committed
32
33
  }

34
35
36
37
  Booster() {
    boosting_.reset(Boosting::CreateBoosting("gbdt", nullptr));
  }

Guolin Ke's avatar
Guolin Ke committed
38
  Booster(const Dataset* train_data,
39
          const char* parameters) {
wxchan's avatar
wxchan committed
40
41
    auto param = ConfigBase::Str2Map(parameters);
    config_.Set(param);
42
43
44
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
45
46
    // create boosting
    if (config_.io_config.input_model.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
47
      Log::Warning("continued train from model is not support for c_api, \
Guolin Ke's avatar
Guolin Ke committed
48
49
        please use continued train with input score");
    }
Guolin Ke's avatar
Guolin Ke committed
50

wxchan's avatar
wxchan committed
51
    boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
Guolin Ke's avatar
Guolin Ke committed
52

Guolin Ke's avatar
Guolin Ke committed
53
    // initialize the boosting
Guolin Ke's avatar
Guolin Ke committed
54
    boosting_->Init(&config_.boosting_config, nullptr, objective_fun_.get(),
55
                    Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
56
57

    ResetTrainingData(train_data);
wxchan's avatar
wxchan committed
58
59
60
61
62
  }

  void MergeFrom(const Booster* other) {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->MergeFrom(other->boosting_.get());
Guolin Ke's avatar
Guolin Ke committed
63
64
65
  }

  ~Booster() {
Guolin Ke's avatar
Guolin Ke committed
66

Guolin Ke's avatar
Guolin Ke committed
67
  }
68

wxchan's avatar
wxchan committed
69
70
71
  void ResetTrainingData(const Dataset* train_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    train_data_ = train_data;
Guolin Ke's avatar
Guolin Ke committed
72
73
    // create objective function
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
74
                                                                    config_.objective_config));
Guolin Ke's avatar
Guolin Ke committed
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective function");
    }
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
    }

    // create training metric
    train_metric_.clear();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(
        Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(train_data_->metadata(), train_data_->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
    // reset the boosting
Guolin Ke's avatar
Guolin Ke committed
94
    boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
95
                                 objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
wxchan's avatar
wxchan committed
96
97
98
  }

  void ResetConfig(const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
99
    std::lock_guard<std::mutex> lock(mutex_);
wxchan's avatar
wxchan committed
100
101
102
103
104
105
106
    auto param = ConfigBase::Str2Map(parameters);
    if (param.count("num_class")) {
      Log::Fatal("cannot change num class during training");
    }
    if (param.count("boosting_type")) {
      Log::Fatal("cannot change boosting_type during training");
    }
Guolin Ke's avatar
Guolin Ke committed
107
108
109
    if (param.count("metric")) {
      Log::Fatal("cannot change metric during training");
    }
Guolin Ke's avatar
Guolin Ke committed
110
111

    config_.Set(param);
112
113
114
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
115
116
117
118

    if (param.count("objective")) {
      // create objective function
      objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
119
                                                                      config_.objective_config));
Guolin Ke's avatar
Guolin Ke committed
120
121
122
123
124
125
126
      if (objective_fun_ == nullptr) {
        Log::Warning("Using self-defined objective function");
      }
      // initialize the objective function
      if (objective_fun_ != nullptr) {
        objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
      }
wxchan's avatar
wxchan committed
127
    }
Guolin Ke's avatar
Guolin Ke committed
128
129

    boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
130
                                 objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
131

wxchan's avatar
wxchan committed
132
133
134
135
136
137
138
139
140
141
142
143
144
  }

  void AddValidData(const Dataset* valid_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    valid_metrics_.emplace_back();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
145
                               Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
wxchan's avatar
wxchan committed
146
  }
Guolin Ke's avatar
Guolin Ke committed
147

148
  bool TrainOneIter() {
wxchan's avatar
wxchan committed
149
    std::lock_guard<std::mutex> lock(mutex_);
150
151
152
153
    return boosting_->TrainOneIter(nullptr, nullptr, false);
  }

  bool TrainOneIter(const float* gradients, const float* hessians) {
wxchan's avatar
wxchan committed
154
    std::lock_guard<std::mutex> lock(mutex_);
155
156
157
    return boosting_->TrainOneIter(gradients, hessians, false);
  }

wxchan's avatar
wxchan committed
158
159
160
161
162
  void RollbackOneIter() {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->RollbackOneIter();
  }

Guolin Ke's avatar
Guolin Ke committed
163
  Predictor NewPredictor(int num_iteration, int predict_type) {
wxchan's avatar
wxchan committed
164
165
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->SetNumIterationForPred(num_iteration);
Guolin Ke's avatar
Guolin Ke committed
166
167
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
168
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
169
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
170
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
171
      is_raw_score = true;
Guolin Ke's avatar
Guolin Ke committed
172
173
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
174
    }
Guolin Ke's avatar
Guolin Ke committed
175
176
177
    // not threading safe now
    // boosting_->SetNumIterationForPred may be set by other thread during prediction. 
    return Predictor(boosting_.get(), is_raw_score, is_predict_leaf);
Guolin Ke's avatar
Guolin Ke committed
178
179
  }

Guolin Ke's avatar
Guolin Ke committed
180
  void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
181
182
183
184
185
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

  void SaveModelToFile(int num_iteration, const char* filename) {
    boosting_->SaveModelToFile(num_iteration, filename);
Guolin Ke's avatar
Guolin Ke committed
186
  }
187

188
189
190
191
192
193
194
195
  void LoadModelFromString(const char* model_str) {
    boosting_->LoadModelFromString(model_str);
  }

  std::string SaveModelToString(int num_iteration) {
    return boosting_->SaveModelToString(num_iteration);
  }

196
197
  std::string DumpModel(int num_iteration) {
    return boosting_->DumpModel(num_iteration);
wxchan's avatar
wxchan committed
198
  }
199

Guolin Ke's avatar
Guolin Ke committed
200
201
202
203
204
205
206
207
208
  double GetLeafValue(int tree_idx, int leaf_idx) const {
    return dynamic_cast<GBDT*>(boosting_.get())->GetLeafValue(tree_idx, leaf_idx);
  }

  void SetLeafValue(int tree_idx, int leaf_idx, double val) {
    std::lock_guard<std::mutex> lock(mutex_);
    dynamic_cast<GBDT*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val);
  }

wxchan's avatar
wxchan committed
209
210
211
212
213
214
215
  int GetEvalCounts() const {
    int ret = 0;
    for (const auto& metric : train_metric_) {
      ret += static_cast<int>(metric->GetName().size());
    }
    return ret;
  }
216

wxchan's avatar
wxchan committed
217
218
219
220
221
222
223
224
225
226
227
  int GetEvalNames(char** out_strs) const {
    int idx = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
        std::strcpy(out_strs[idx], name.c_str());
        ++idx;
      }
    }
    return idx;
  }

wxchan's avatar
wxchan committed
228
229
230
231
232
233
234
235
236
  int GetFeatureNames(char** out_strs) const {
    int idx = 0;
    for (const auto& name : boosting_->FeatureNames()) {
      std::strcpy(out_strs[idx], name.c_str());
      ++idx;
    }
    return idx;
  }

wxchan's avatar
wxchan committed
237
  const Boosting* GetBoosting() const { return boosting_.get(); }
Guolin Ke's avatar
Guolin Ke committed
238

Guolin Ke's avatar
Guolin Ke committed
239
private:
240

wxchan's avatar
wxchan committed
241
  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
242
  std::unique_ptr<Boosting> boosting_;
Guolin Ke's avatar
Guolin Ke committed
243
244
245
  /*! \brief All configs */
  OverallConfig config_;
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
246
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
247
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
248
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
249
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
250
  std::unique_ptr<ObjectiveFunction> objective_fun_;
wxchan's avatar
wxchan committed
251
252
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
253
254
255
};

}
Guolin Ke's avatar
Guolin Ke committed
256
257
258

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
259
260
261
262
263
264
265
266
267
268
// Helper functions that turn raw caller buffers into per-row accessors.
// Only declared here; the definitions live elsewhere (presumably later in this file).

// Returns a callable yielding row `row_idx` of a dense matrix as a vector of doubles.
// NOTE(review): the encoding of `data_type` / `is_row_major` is defined by the C API header.
auto RowFunctionFromDenseMatric(const void* data, int num_row, int num_col,
                                int data_type, int is_row_major)
    -> std::function<std::vector<double>(int row_idx)>;

// Same as above, but each row is returned as (column index, value) pairs.
auto RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col,
                                    int data_type, int is_row_major)
    -> std::function<std::vector<std::pair<int, double>>(int row_idx)>;

// Returns a callable yielding row `idx` of a CSR matrix as (column index, value) pairs.
auto RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
                        const void* data, int data_type, int64_t nindptr, int64_t nelem)
    -> std::function<std::vector<std::pair<int, double>>(int idx)>;
Guolin Ke's avatar
Guolin Ke committed
270
271
272
273
274

/*!
 * \brief Row-wise iterator over a single column of a CSC matrix.
 *
 * Only the declaration lives here; the member functions are defined elsewhere.
 */
class CSC_RowIterator {
public:
  CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
                  const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx);
  ~CSC_RowIterator() = default;

  /*! \brief Value at row `idx`; rows may only be requested in ascending order. */
  double Get(int idx);

  /*! \brief Next non-zero (row, value) pair; a negative row index signals the end of the column. */
  std::pair<int, double> NextNonZero();

private:
  int nonzero_idx_ = 0;
  int cur_idx_ = -1;
  double cur_val_ = 0.0;
  bool is_end_ = false;
  std::function<std::pair<int, double>(int idx)> iter_fun_;
};

// start of c_api functions

291
/*! \brief Return the text currently held by LastErrorMsg() (the last recorded error message). */
LIGHTGBM_C_EXPORT const char* LGBM_GetLastError() {
  const char* const last_error = LastErrorMsg();
  return last_error;
}

295
/*!
 * \brief Load a dataset from `filename`.
 *
 * When `reference` is given, the file is loaded aligned with that dataset
 * (LoadFromFileAlignWithOtherDataset). Otherwise it is loaded standalone.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromFile(const char* filename,
                                                 const char* parameters,
                                                 const DatasetHandle reference,
                                                 DatasetHandle* out) {
  API_BEGIN();
  auto config_map = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(config_map);
  DatasetLoader file_loader(io_config, nullptr, 1, filename);
  if (reference != nullptr) {
    const Dataset* ref_dataset = reinterpret_cast<const Dataset*>(reference);
    *out = file_loader.LoadFromFileAlignWithOtherDataset(filename, ref_dataset);
  } else {
    *out = file_loader.LoadFromFile(filename);
  }
  API_END();
}

313

314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
/*!
 * \brief Construct a dataset's bin layout from pre-sampled per-column values.
 *
 * `sample_data[c]` / `sample_indices[c]` hold `num_per_col[c]` sampled
 * entries for column `c`. The resulting dataset is sized for `num_total_row` rows.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromSampledColumn(double** sample_data,
                                                          int** sample_indices,
                                                          int32_t ncol,
                                                          const int* num_per_col,
                                                          int32_t num_sample_row,
                                                          int32_t num_total_row,
                                                          const char* parameters,
                                                          DatasetHandle* out) {
  API_BEGIN();
  auto config_map = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(config_map);
  const data_size_t total_rows = static_cast<data_size_t>(num_total_row);
  DatasetLoader sample_loader(io_config, nullptr, 1, nullptr);
  *out = sample_loader.CostructFromSampleData(sample_data, sample_indices, ncol, num_per_col,
                                              num_sample_row, total_rows);
  API_END();
}

333

Guolin Ke's avatar
Guolin Ke committed
334
/*!
 * \brief Create an empty dataset of `num_total_row` rows set up via
 *        Dataset::CreateValid against `reference`. Rows are pushed afterwards.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateByReference(const DatasetHandle reference,
                                                    int64_t num_total_row,
                                                    DatasetHandle* out) {
  API_BEGIN();
  const Dataset* ref_dataset = reinterpret_cast<const Dataset*>(reference);
  std::unique_ptr<Dataset> fresh(new Dataset(static_cast<data_size_t>(num_total_row)));
  fresh->CreateValid(ref_dataset);
  *out = fresh.release();
  API_END();
}

/*!
 * \brief Push `nrow` row-major dense rows into `dataset` starting at `start_row`.
 *
 * Once the chunk ends exactly at the dataset's row count, FinishLoad() is called.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetPushRows(DatasetHandle dataset,
                                           const void* data,
                                           int data_type,
                                           int32_t nrow,
                                           int32_t ncol,
                                           int32_t start_row) {
  API_BEGIN();
  Dataset* target = reinterpret_cast<Dataset*>(dataset);
  // the trailing 1 selects row-major layout
  auto row_at = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, 1);
#pragma omp parallel for schedule(static)
  for (int row = 0; row < nrow; ++row) {
    const int thread_id = omp_get_thread_num();
    auto values = row_at(row);
    target->PushOneRow(thread_id, start_row + row, values);
  }
  const bool is_last_chunk = (start_row + nrow == target->num_data());
  if (is_last_chunk) {
    target->FinishLoad();
  }
  API_END();
}

/*!
 * \brief Push the `nindptr - 1` CSR rows into `dataset` starting at `start_row`.
 *
 * The column-count argument is unused. FinishLoad() is called once the chunk
 * ends exactly at the dataset's row count.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset,
                                                const void* indptr,
                                                int indptr_type,
                                                const int32_t* indices,
                                                const void* data,
                                                int data_type,
                                                int64_t nindptr,
                                                int64_t nelem,
                                                int64_t,
                                                int64_t start_row) {
  API_BEGIN();
  Dataset* target = reinterpret_cast<Dataset*>(dataset);
  auto row_at = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  const int32_t rows_in_chunk = static_cast<int32_t>(nindptr - 1);
#pragma omp parallel for schedule(static)
  for (int r = 0; r < rows_in_chunk; ++r) {
    const int thread_id = omp_get_thread_num();
    auto sparse_row = row_at(r);
    const data_size_t global_row = static_cast<data_size_t>(start_row + r);
    target->PushOneRow(thread_id, global_row, sparse_row);
  }
  const bool is_last_chunk = (start_row + rows_in_chunk == static_cast<int64_t>(target->num_data()));
  if (is_last_chunk) {
    target->FinishLoad();
  }
  API_END();
}

393
/*!
 * \brief Build a dataset from a dense matrix.
 *
 * Without `reference`, bin boundaries are constructed from a random sample of
 * at most `bin_construct_sample_cnt` rows (non-zero entries only). With
 * `reference`, the dataset is set up via Dataset::CreateValid. All rows are
 * then pushed in parallel and the dataset is finalized.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromMat(const void* data,
                                                int data_type,
                                                int32_t nrow,
                                                int32_t ncol,
                                                int is_row_major,
                                                const char* parameters,
                                                const DatasetHandle reference,
                                                DatasetHandle* out) {
  API_BEGIN();
  auto config_map = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(config_map);
  std::unique_ptr<Dataset> dataset;
  auto row_at = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);

  if (reference != nullptr) {
    dataset.reset(new Dataset(nrow));
    dataset->CreateValid(reinterpret_cast<const Dataset*>(reference));
  } else {
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(io_config.bin_construct_sample_cnt > nrow ? nrow : io_config.bin_construct_sample_cnt);
    auto sampled_rows = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> col_values(ncol);
    std::vector<std::vector<int>> col_rows(ncol);
    int sample_pos = 0;
    for (const auto row_id : sampled_rows) {
      auto row = row_at(static_cast<int>(row_id));
      for (size_t col = 0; col < row.size(); ++col) {
        // keep only non-zero entries; col_rows stores the position within the sample
        if (std::fabs(row[col]) > kEpsilon) {
          col_values[col].emplace_back(row[col]);
          col_rows[col].emplace_back(sample_pos);
        }
      }
      ++sample_pos;
    }
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
    dataset.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(col_values),
                                                Common::Vector2Ptr<int>(col_rows),
                                                static_cast<int>(col_values.size()),
                                                Common::VectorSize<double>(col_values).data(),
                                                sample_cnt, nrow));
  }

#pragma omp parallel for schedule(static)
  for (int r = 0; r < nrow; ++r) {
    const int thread_id = omp_get_thread_num();
    auto values = row_at(r);
    dataset->PushOneRow(thread_id, r, values);
  }
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

447
/*!
 * \brief Build a dataset from a CSR matrix with `nindptr - 1` rows and `num_col` columns.
 *
 * Without `reference`, bin boundaries are constructed from a random sample of
 * at most `bin_construct_sample_cnt` rows (non-zero entries only). With
 * `reference`, the dataset is set up via Dataset::CreateValid. All rows are
 * then pushed in parallel and the dataset is finalized.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSR(const void* indptr,
                                                int indptr_type,
                                                const int32_t* indices,
                                                const void* data,
                                                int data_type,
                                                int64_t nindptr,
                                                int64_t nelem,
                                                int64_t num_col,
                                                const char* parameters,
                                                const DatasetHandle reference,
                                                DatasetHandle* out) {
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
  std::unique_ptr<Dataset> ret;
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference == nullptr) {
    // sample data first
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    // Pre-size to num_col so that trailing columns with no non-zero value in
    // the sampled rows still become features (consistent with the CSC path,
    // which pre-sizes to ncol_ptr - 1). Growing on demand alone would
    // silently drop them.
    const size_t declared_cols = num_col > 0 ? static_cast<size_t>(num_col) : 0;
    std::vector<std::vector<double>> sample_values(declared_cols);
    std::vector<std::vector<int>> sample_idx(declared_cols);
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      auto row = get_row_fun(static_cast<int>(idx));
      for (std::pair<int, double>& inner_data : row) {
        // still grow if the data references a column beyond num_col; the
        // CHECK below then reports the inconsistency
        if (static_cast<size_t>(inner_data.first) >= sample_values.size()) {
          sample_values.resize(inner_data.first + 1);
          sample_idx.resize(inner_data.first + 1);
        }
        if (std::fabs(inner_data.second) > kEpsilon) {
          sample_values[inner_data.first].emplace_back(inner_data.second);
          sample_idx[inner_data.first].emplace_back(static_cast<int>(i));
        }
      }
    }
    CHECK(num_col >= static_cast<int>(sample_values.size()));
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values),
                                            Common::Vector2Ptr<int>(sample_idx),
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow));
    ret->CreateValid(
      reinterpret_cast<const Dataset*>(reference));
  }

#pragma omp parallel for schedule(static)
  for (int i = 0; i < nrow; ++i) {
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
  }
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

510
/*!
 * \brief Build a dataset from a CSC matrix with `ncol_ptr - 1` columns and `num_row` rows.
 *
 * Without `reference`, bin boundaries are constructed from a random sample of
 * at most `bin_construct_sample_cnt` rows, column by column. With `reference`,
 * the dataset is set up via Dataset::CreateValid. Non-zero entries are then
 * pushed column by column and the dataset is finalized.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr,
                                                int col_ptr_type,
                                                const int32_t* indices,
                                                const void* data,
                                                int data_type,
                                                int64_t ncol_ptr,
                                                int64_t nelem,
                                                int64_t num_row,
                                                const char* parameters,
                                                const DatasetHandle reference,
                                                DatasetHandle* out) {
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
  std::unique_ptr<Dataset> ret;
  int32_t nrow = static_cast<int32_t>(num_row);
  if (reference == nullptr) {
    // sample data first
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
    std::vector<std::vector<int>> sample_idx(ncol_ptr - 1);
    // each column writes only its own slot, so columns are sampled in parallel
#pragma omp parallel for schedule(static)
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
      CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
      for (int j = 0; j < sample_cnt; j++) {
        // CSC_RowIterator::Get needs ascending rows; assumes Random::Sample
        // returns sorted indices -- TODO confirm
        auto val = col_it.Get(sample_indices[j]);
        if (std::fabs(val) > kEpsilon) {
          sample_values[i].emplace_back(val);
          sample_idx[i].emplace_back(j);
        }
      }
    }
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values),
                                            Common::Vector2Ptr<int>(sample_idx),
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow));
    ret->CreateValid(
      reinterpret_cast<const Dataset*>(reference));
  }

#pragma omp parallel for schedule(static)
  for (int i = 0; i < ncol_ptr - 1; ++i) {
    const int tid = omp_get_thread_num();
    int feature_idx = ret->InnerFeatureIndex(i);
    // columns that did not become a feature are skipped
    if (feature_idx < 0) { continue; }
    int group = ret->Feature2Group(feature_idx);
    int sub_feature = ret->Feture2SubFeature(feature_idx);
    CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
    for (;;) {
      auto nz = col_it.NextNonZero();
      const int row_idx = nz.first;
      // A negative index means the column is exhausted. The upper bound is
      // checked on the freshly fetched index so that an entry past the last
      // row is never pushed (the old loop tested the previous index instead).
      if (row_idx < 0 || row_idx >= nrow) { break; }
      ret->PushOneData(tid, row_idx, group, sub_feature, nz.second);
    }
  }
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

579
LIGHTGBM_C_EXPORT int LGBM_DatasetGetSubset(
  const DatasetHandle handle,
  const int32_t* used_row_indices,
  int32_t num_used_row_indices,
  const char* parameters,
  DatasetHandle* out) {
  API_BEGIN();
  // Parse the parameter string into an IOConfig. The resulting config is not
  // read afterwards; NOTE(review): presumably kept so Set() still validates
  // the caller's parameters — confirm before removing.
  auto param_map = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param_map);
  const Dataset* source = reinterpret_cast<const Dataset*>(handle);
  // New dataset sized for the selected rows, reusing the source's feature
  // mappers, then populated with the rows listed in used_row_indices.
  std::unique_ptr<Dataset> subset(new Dataset(num_used_row_indices));
  subset->CopyFeatureMapperFrom(source);
  subset->CopySubset(source, used_row_indices, num_used_row_indices, true);
  // Ownership passes to the caller through the handle.
  *out = subset.release();
  API_END();
}

597
LIGHTGBM_C_EXPORT int LGBM_DatasetSetFeatureNames(
  DatasetHandle handle,
  const char** feature_names,
  int num_feature_names) {
  API_BEGIN();
  Dataset* dataset = reinterpret_cast<Dataset*>(handle);
  // Copy the caller's C strings into owned std::string values; a
  // non-positive count yields an empty list.
  std::vector<std::string> names;
  const char** cursor = feature_names;
  for (int remaining = num_feature_names; remaining > 0; --remaining, ++cursor) {
    names.push_back(std::string(*cursor));
  }
  dataset->set_feature_names(names);
  API_END();
}

611
LIGHTGBM_C_EXPORT int LGBM_DatasetGetFeatureNames(
  DatasetHandle handle,
  char** feature_names,
  int* num_feature_names) {
  API_BEGIN();
  Dataset* dataset = reinterpret_cast<Dataset*>(handle);
  const auto& names = dataset->feature_names();
  const int count = static_cast<int>(names.size());
  *num_feature_names = count;
  // The caller supplies the buffers: feature_names[k] must be large enough for
  // the k-th name plus its NUL terminator (no length check is done here).
  for (int k = 0; k < count; ++k) {
    std::strcpy(feature_names[k], names[k].c_str());
  }
  API_END();
}

625
LIGHTGBM_C_EXPORT int LGBM_DatasetFree(DatasetHandle handle) {
  API_BEGIN();
  // Destroy a Dataset previously handed out through a DatasetHandle
  // (e.g. by LGBM_DatasetGetSubset).
  Dataset* dataset = reinterpret_cast<Dataset*>(handle);
  delete dataset;
  API_END();
}

631
LIGHTGBM_C_EXPORT int LGBM_DatasetSaveBinary(DatasetHandle handle,
                                             const char* filename) {
  API_BEGIN();
  // Delegate to Dataset::SaveBinaryFile to write the dataset to `filename`.
  reinterpret_cast<Dataset*>(handle)->SaveBinaryFile(filename);
  API_END();
}

639
/*!
 * Set a named metadata field of a dataset from a caller-owned array.
 * \param field_name  name of the field to set
 * \param field_data  pointer to num_element values of the type given by `type`
 * \param num_element number of values in field_data
 * \param type        C_API_DTYPE_FLOAT32, C_API_DTYPE_INT32 or C_API_DTYPE_FLOAT64
 * Throws std::runtime_error (reported via API_END) when the type code is
 * unsupported or the dataset rejects the field.
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle,
                                           const char* field_name,
                                           const void* field_data,
                                           int num_element,
                                           int type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  const int32_t len = static_cast<int32_t>(num_element);
  bool is_success = false;
  // Dispatch on the declared element type; an unknown type code leaves
  // is_success false and is reported below.
  if (type == C_API_DTYPE_FLOAT32) {
    is_success = dataset->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), len);
  } else if (type == C_API_DTYPE_INT32) {
    is_success = dataset->SetIntField(field_name, reinterpret_cast<const int*>(field_data), len);
  } else if (type == C_API_DTYPE_FLOAT64) {
    is_success = dataset->SetDoubleField(field_name, reinterpret_cast<const double*>(field_data), len);
  }
  if (!is_success) { throw std::runtime_error("Input data type error or field not found"); }
  API_END();
}

658
LIGHTGBM_C_EXPORT int LGBM_DatasetGetField(DatasetHandle handle,
                                           const char* field_name,
                                           int* out_len,
                                           const void** out_ptr,
                                           int* out_type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  // Try the float, int and double getters in turn. The first that recognises
  // field_name fills out_len/out_ptr, and that getter determines *out_type.
  if (dataset->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (dataset->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else if (dataset->GetDoubleField(field_name, out_len, reinterpret_cast<const double**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT64;
  } else {
    throw std::runtime_error("Field not found");
  }
  // A recognised field with no data comes back as nullptr; report it as empty.
  if (*out_ptr == nullptr) { *out_len = 0; }
  API_END();
}

681
LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumData(DatasetHandle handle,
                                             int* out) {
  API_BEGIN();
  // Row count as reported by Dataset::num_data().
  *out = reinterpret_cast<Dataset*>(handle)->num_data();
  API_END();
}

689
LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumFeature(DatasetHandle handle,
                                                int* out) {
  API_BEGIN();
  // Feature count as reported by Dataset::num_total_features().
  *out = reinterpret_cast<Dataset*>(handle)->num_total_features();
  API_END();
}
696
697
698

// ---- start of booster

699
LIGHTGBM_C_EXPORT int LGBM_BoosterCreate(const DatasetHandle train_data,
                                         const char* parameters,
                                         BoosterHandle* out) {
  API_BEGIN();
  // Construct a Booster over the training dataset and parameter string,
  // then transfer ownership to the caller via *out.
  std::unique_ptr<Booster> booster(
    new Booster(reinterpret_cast<const Dataset*>(train_data), parameters));
  *out = booster.release();
  API_END();
}

709
LIGHTGBM_C_EXPORT int LGBM_BoosterCreateFromModelfile(
  const char* filename,
  int* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
  // Load a booster from a model file. The unique_ptr frees it if reading the
  // iteration count throws before ownership is handed out.
  std::unique_ptr<Booster> booster(new Booster(filename));
  *out_num_iterations = booster->GetBoosting()->GetCurrentIteration();
  *out = booster.release();
  API_END();
}

720
721
722
723
724
725
726
727
728
729
730
731
LIGHTGBM_C_EXPORT int LGBM_BoosterLoadModelFromString(
  const char* model_str,
  int* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
  // Start from a default Booster, then replace its model with the one
  // serialised in model_str. The unique_ptr releases it on any failure.
  std::unique_ptr<Booster> booster(new Booster());
  booster->LoadModelFromString(model_str);
  *out_num_iterations = booster->GetBoosting()->GetCurrentIteration();
  *out = booster.release();
  API_END();
}

732
LIGHTGBM_C_EXPORT int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  // Destroy a Booster previously handed out through a BoosterHandle.
  Booster* booster = reinterpret_cast<Booster*>(handle);
  delete booster;
  API_END();
}

738
LIGHTGBM_C_EXPORT int LGBM_BoosterMerge(BoosterHandle handle,
                                        BoosterHandle other_handle) {
  API_BEGIN();
  // Merge the model behind other_handle into the booster behind handle.
  Booster* target = reinterpret_cast<Booster*>(handle);
  target->MergeFrom(reinterpret_cast<Booster*>(other_handle));
  API_END();
}

747
LIGHTGBM_C_EXPORT int LGBM_BoosterAddValidData(BoosterHandle handle,
                                               const DatasetHandle valid_data) {
  API_BEGIN();
  // Register an additional validation dataset with the booster.
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->AddValidData(reinterpret_cast<const Dataset*>(valid_data));
  API_END();
}

756
LIGHTGBM_C_EXPORT int LGBM_BoosterResetTrainingData(BoosterHandle handle,
                                                    const DatasetHandle train_data) {
  API_BEGIN();
  // Point the booster at a new training dataset.
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->ResetTrainingData(reinterpret_cast<const Dataset*>(train_data));
  API_END();
}

765
LIGHTGBM_C_EXPORT int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  // Re-apply a parameter string to an existing booster.
  reinterpret_cast<Booster*>(handle)->ResetConfig(parameters);
  API_END();
}

772
LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  // Class count as reported by the underlying boosting model.
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->NumberOfClasses();
  API_END();
}

779
LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  // Run one boosting iteration with the built-in objective. *is_finished
  // becomes 1 when TrainOneIter() returns true, otherwise 0.
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const bool finished = booster->TrainOneIter();
  *is_finished = finished ? 1 : 0;
  API_END();
}

790
LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
                                                      const float* grad,
                                                      const float* hess,
                                                      int* is_finished) {
  API_BEGIN();
  // Run one boosting iteration using caller-supplied gradients/hessians.
  // *is_finished becomes 1 when TrainOneIter() returns true, otherwise 0.
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const bool finished = booster->TrainOneIter(grad, hess);
  *is_finished = finished ? 1 : 0;
  API_END();
}

804
LIGHTGBM_C_EXPORT int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  // Undo the most recent training iteration via Booster::RollbackOneIter.
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

811
LIGHTGBM_C_EXPORT int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int* out_iteration) {
  API_BEGIN();
  // Iteration counter of the underlying boosting model.
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_iteration = boosting->GetCurrentIteration();
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
817

818
LIGHTGBM_C_EXPORT int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  // Forward Booster::GetEvalCounts() to the caller.
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalCounts();
  API_END();
}

825
LIGHTGBM_C_EXPORT int LGBM_BoosterGetEvalNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  // Booster::GetEvalNames writes the names into out_strs (presumably
  // caller-allocated buffers — confirm) and returns how many it wrote.
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalNames(out_strs);
  API_END();
}

wxchan's avatar
wxchan committed
832
833
834
835
836
837
838
839
840
841
842
843
844
845
LIGHTGBM_C_EXPORT int LGBM_BoosterGetFeatureNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  // Booster::GetFeatureNames writes the names into out_strs (presumably
  // caller-allocated buffers — confirm) and returns how many it wrote.
  *out_len = reinterpret_cast<Booster*>(handle)->GetFeatureNames(out_strs);
  API_END();
}

LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumFeature(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  // Feature count = highest feature index known to the model, plus one.
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  const int max_idx = boosting->MaxFeatureIdx();
  *out_len = max_idx + 1;
  API_END();
}

846
LIGHTGBM_C_EXPORT int LGBM_BoosterGetEval(BoosterHandle handle,
                                          int data_idx,
                                          int* out_len,
                                          double* out_results) {
  API_BEGIN();
  // Copy the evaluation results for dataset `data_idx` into out_results,
  // which must have room for every result, and report how many there are.
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto metrics = booster->GetBoosting()->GetEvalAt(data_idx);
  *out_len = static_cast<int>(metrics.size());
  double* dst = out_results;
  for (const auto& value : metrics) {
    *dst++ = static_cast<double>(value);
  }
  API_END();
}

861
LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumPredict(BoosterHandle handle,
                                                int data_idx,
                                                int64_t* out_len) {
  API_BEGIN();
  // Size of the prediction buffer the model reports for dataset `data_idx`.
  Booster* booster = reinterpret_cast<Booster*>(handle);
  *out_len = booster->GetBoosting()->GetNumPredictAt(data_idx);
  API_END();
}

870
LIGHTGBM_C_EXPORT int LGBM_BoosterGetPredict(BoosterHandle handle,
                                             int data_idx,
                                             int64_t* out_len,
                                             double* out_result) {
  API_BEGIN();
  // Booster::GetPredictAt fills out_result and out_len for dataset `data_idx`.
  reinterpret_cast<Booster*>(handle)->GetPredictAt(data_idx, out_result, out_len);
  API_END();
}

880
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForFile(BoosterHandle handle,
                                                 const char* data_filename,
                                                 int data_has_header,
                                                 int predict_type,
                                                 int num_iteration,
                                                 const char* result_filename) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto predictor = booster->NewPredictor(num_iteration, predict_type);
  // Any positive data_has_header value is treated as "true".
  const bool has_header = (data_has_header > 0);
  predictor.Predict(data_filename, result_filename, has_header);
  API_END();
}

894
895
int64_t GetNumPredOneRow(const Booster* ref_booster, int predict_type, int64_t num_iteration) {
  int64_t num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
Guolin Ke's avatar
Guolin Ke committed
896
897
898
899
900
901
902
903
904
905
906
  if (predict_type == C_API_PREDICT_LEAF_INDEX) {
    int64_t max_iteration = ref_booster->GetBoosting()->GetCurrentIteration();
    if (num_iteration > 0) {
      num_preb_in_one_row *= static_cast<int>(std::min(max_iteration, num_iteration));
    } else {
      num_preb_in_one_row *= max_iteration;
    }
  }
  return num_preb_in_one_row;
}

907
LIGHTGBM_C_EXPORT int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
                                                 int num_row,
                                                 int predict_type,
                                                 int num_iteration,
                                                 int64_t* out_len) {
  API_BEGIN();
  // Total output size = rows x outputs-per-row, computed in 64-bit.
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const int64_t per_row = GetNumPredOneRow(booster, predict_type, num_iteration);
  *out_len = static_cast<int64_t>(num_row) * per_row;
  API_END();
}

918
/*!
 * Predict for every row of a CSR matrix.
 * Row i's outputs are written to out_result[i * per_row_outputs ...], where
 * per_row_outputs comes from GetNumPredOneRow. *out_len receives
 * (nindptr - 1) * per_row_outputs. The int64_t parameter after nelem is unused.
 */
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle,
                                                const void* indptr,
                                                int indptr_type,
                                                const int32_t* indices,
                                                const void* data,
                                                int data_type,
                                                int64_t nindptr,
                                                int64_t nelem,
                                                int64_t,
                                                int predict_type,
                                                int num_iteration,
                                                int64_t* out_len,
                                                double* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  auto predictor = ref_booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  int64_t num_preb_in_one_row = GetNumPredOneRow(ref_booster, predict_type, num_iteration);
  int nrow = static_cast<int>(nindptr - 1);

  // OpenMP forbids an exception from propagating out of a parallel region
  // (in practice the process is terminated instead of reaching API_END).
  // Record the first failure inside the loop and rethrow it afterwards.
  std::mutex error_mutex;
  bool has_error = false;
  std::string error_msg;
  auto record_error = [&error_mutex, &has_error, &error_msg](const char* msg) {
    std::lock_guard<std::mutex> lock(error_mutex);
    if (!has_error) {
      has_error = true;
      error_msg = msg;
    }
  };

#pragma omp parallel for schedule(static)
  for (int i = 0; i < nrow; ++i) {
    try {
      auto one_row = get_row_fun(i);
      auto predicton_result = predictor.GetPredictFunction()(one_row);
      for (int j = 0; j < static_cast<int>(predicton_result.size()); ++j) {
        out_result[i * num_preb_in_one_row + j] = static_cast<double>(predicton_result[j]);
      }
    } catch (const std::exception& ex) {
      record_error(ex.what());
    } catch (...) {
      record_error("unknown exception");
    }
  }
  if (has_error) { throw std::runtime_error(error_msg); }
  *out_len = nrow * num_preb_in_one_row;
  API_END();
}
948

949
/*!
 * Predict for each of num_row rows of a CSC matrix.
 * Each worker builds one column iterator per feature. For every row it
 * gathers the entries with |value| > kEpsilon and runs the predictor on
 * them. Row i's outputs go to out_result[i * per_row_outputs ...], and
 * *out_len receives num_row * per_row_outputs.
 */
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle,
                                                const void* col_ptr,
                                                int col_ptr_type,
                                                const int32_t* indices,
                                                const void* data,
                                                int data_type,
                                                int64_t ncol_ptr,
                                                int64_t nelem,
                                                int64_t num_row,
                                                int predict_type,
                                                int num_iteration,
                                                int64_t* out_len,
                                                double* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  auto predictor = ref_booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
  int64_t num_preb_in_one_row = GetNumPredOneRow(ref_booster, predict_type, num_iteration);
  int ncol = static_cast<int>(ncol_ptr - 1);

  // Workers may run inside an OpenMP region inside Threading::For, where an
  // escaping exception would terminate the process. Record the first
  // failure here and rethrow it after all workers have finished.
  std::mutex error_mutex;
  bool has_error = false;
  std::string error_msg;
  auto record_error = [&error_mutex, &has_error, &error_msg](const char* msg) {
    std::lock_guard<std::mutex> lock(error_mutex);
    if (!has_error) {
      has_error = true;
      error_msg = msg;
    }
  };

  // start/end are declared int64_t to match Threading::For<int64_t>, so the
  // row bounds are not narrowed to a 32-bit type on the way in.
  Threading::For<int64_t>(0, num_row,
                          [&predictor, &record_error, out_result, num_preb_in_one_row, ncol,
                           col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem]
  (int, int64_t start, int64_t end) {
    try {
      std::vector<CSC_RowIterator> iterators;
      for (int j = 0; j < ncol; ++j) {
        iterators.emplace_back(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, j);
      }
      std::vector<std::pair<int, double>> one_row;
      for (int64_t i = start; i < end; ++i) {
        one_row.clear();
        for (int j = 0; j < ncol; ++j) {
          auto val = iterators[j].Get(static_cast<int>(i));
          if (std::fabs(val) > kEpsilon) {
            one_row.emplace_back(j, val);
          }
        }
        auto predicton_result = predictor.GetPredictFunction()(one_row);
        for (int j = 0; j < static_cast<int>(predicton_result.size()); ++j) {
          out_result[i * num_preb_in_one_row + j] = static_cast<double>(predicton_result[j]);
        }
      }
    } catch (const std::exception& ex) {
      record_error(ex.what());
    } catch (...) {
      record_error("unknown exception");
    }
  });
  if (has_error) { throw std::runtime_error(error_msg); }
  *out_len = num_row * num_preb_in_one_row;
  API_END();
}

994
/*!
 * Predict for every row of a dense nrow x ncol matrix (row- or column-major
 * according to is_row_major). Row i's outputs go to
 * out_result[i * per_row_outputs ...], where per_row_outputs comes from
 * GetNumPredOneRow. *out_len receives nrow * per_row_outputs.
 */
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMat(BoosterHandle handle,
                                                const void* data,
                                                int data_type,
                                                int32_t nrow,
                                                int32_t ncol,
                                                int is_row_major,
                                                int predict_type,
                                                int num_iteration,
                                                int64_t* out_len,
                                                double* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  auto predictor = ref_booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
  auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  int64_t num_preb_in_one_row = GetNumPredOneRow(ref_booster, predict_type, num_iteration);

  // OpenMP forbids an exception from propagating out of a parallel region
  // (in practice the process is terminated instead of reaching API_END).
  // Record the first failure inside the loop and rethrow it afterwards.
  std::mutex error_mutex;
  bool has_error = false;
  std::string error_msg;
  auto record_error = [&error_mutex, &has_error, &error_msg](const char* msg) {
    std::lock_guard<std::mutex> lock(error_mutex);
    if (!has_error) {
      has_error = true;
      error_msg = msg;
    }
  };

#pragma omp parallel for schedule(static)
  for (int i = 0; i < nrow; ++i) {
    try {
      auto one_row = get_row_fun(i);
      auto predicton_result = predictor.GetPredictFunction()(one_row);
      for (int j = 0; j < static_cast<int>(predicton_result.size()); ++j) {
        out_result[i * num_preb_in_one_row + j] = static_cast<double>(predicton_result[j]);
      }
    } catch (const std::exception& ex) {
      record_error(ex.what());
    } catch (...) {
      record_error("unknown exception");
    }
  }
  if (has_error) { throw std::runtime_error(error_msg); }
  *out_len = nrow * num_preb_in_one_row;
  API_END();
}
1020

1021
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle,
                                            int num_iteration,
                                            const char* filename) {
  API_BEGIN();
  // Delegate to Booster::SaveModelToFile with the requested iteration count.
  reinterpret_cast<Booster*>(handle)->SaveModelToFile(num_iteration, filename);
  API_END();
}

1030
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle,
                                                    int num_iteration,
                                                    int buffer_len,
                                                    int* out_len,
                                                    char* out_str) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const std::string serialized = booster->SaveModelToString(num_iteration);
  // Always report the required size, including the NUL terminator. Copy the
  // text only when it fits in the caller's buffer_len bytes.
  const int needed = static_cast<int>(serialized.size()) + 1;
  *out_len = needed;
  if (needed <= buffer_len) {
    std::strcpy(out_str, serialized.c_str());
  }
  API_END();
}

1045
LIGHTGBM_C_EXPORT int LGBM_BoosterDumpModel(BoosterHandle handle,
                                            int num_iteration,
                                            int buffer_len,
                                            int* out_len,
                                            char* out_str) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const std::string dumped = booster->DumpModel(num_iteration);
  // Always report the required size, including the NUL terminator. Copy the
  // text only when it fits in the caller's buffer_len bytes.
  const int needed = static_cast<int>(dumped.size()) + 1;
  *out_len = needed;
  if (needed <= buffer_len) {
    std::strcpy(out_str, dumped.c_str());
  }
  API_END();
}
1059

1060
LIGHTGBM_C_EXPORT int LGBM_BoosterGetLeafValue(BoosterHandle handle,
                                               int tree_idx,
                                               int leaf_idx,
                                               double* out_val) {
  API_BEGIN();
  // Read the output value of leaf `leaf_idx` in tree `tree_idx`.
  const auto leaf_value = reinterpret_cast<Booster*>(handle)->GetLeafValue(tree_idx, leaf_idx);
  *out_val = static_cast<double>(leaf_value);
  API_END();
}

1070
LIGHTGBM_C_EXPORT int LGBM_BoosterSetLeafValue(BoosterHandle handle,
                                               int tree_idx,
                                               int leaf_idx,
                                               double val) {
  API_BEGIN();
  // Overwrite the output value of leaf `leaf_idx` in tree `tree_idx`.
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->SetLeafValue(tree_idx, leaf_idx, val);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1080
// ---- start of helper functions
1081
1082
1083

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1084
  if (data_type == C_API_DTYPE_FLOAT32) {
1085
1086
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
1087
      return [data_ptr, num_col, num_row] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1088
        std::vector<double> ret(num_col);
1089
1090
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1091
          ret[i] = static_cast<double>(*(tmp_ptr + i));
1092
1093
1094
1095
        }
        return ret;
      };
    } else {
1096
      return [data_ptr, num_col, num_row] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1097
        std::vector<double> ret(num_col);
1098
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1099
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
1100
1101
1102
1103
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1104
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1105
1106
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
1107
      return [data_ptr, num_col, num_row] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1108
        std::vector<double> ret(num_col);
1109
1110
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1111
          ret[i] = static_cast<double>(*(tmp_ptr + i));
1112
1113
1114
1115
        }
        return ret;
      };
    } else {
1116
      return [data_ptr, num_col, num_row] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1117
        std::vector<double> ret(num_col);
1118
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1119
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
1120
1121
1122
1123
1124
        }
        return ret;
      };
    }
  }
Guolin Ke's avatar
Guolin Ke committed
1125
  throw std::runtime_error("unknown data type in RowFunctionFromDenseMatric");
1126
1127
1128
1129
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1130
1131
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
1132
    return [inner_function] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1133
1134
1135
1136
1137
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
        if (std::fabs(raw_values[i]) > 1e-15) {
          ret.emplace_back(i, raw_values[i]);
1138
        }
Guolin Ke's avatar
Guolin Ke committed
1139
1140
1141
      }
      return ret;
    };
1142
  }
Guolin Ke's avatar
Guolin Ke committed
1143
  return nullptr;
1144
1145
1146
1147
}

std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
1148
  if (data_type == C_API_DTYPE_FLOAT32) {
1149
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1150
    if (indptr_type == C_API_DTYPE_INT32) {
1151
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
1152
      return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
1153
1154
1155
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1156
        for (int64_t i = start; i < end; ++i) {
1157
1158
1159
1160
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1161
    } else if (indptr_type == C_API_DTYPE_INT64) {
1162
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
1163
      return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
1164
1165
1166
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1167
        for (int64_t i = start; i < end; ++i) {
1168
1169
1170
1171
1172
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1173
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1174
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1175
    if (indptr_type == C_API_DTYPE_INT32) {
1176
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
1177
      return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
1178
1179
1180
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1181
        for (int64_t i = start; i < end; ++i) {
1182
1183
1184
1185
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1186
    } else if (indptr_type == C_API_DTYPE_INT64) {
1187
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
1188
      return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
1189
1190
1191
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1192
        for (int64_t i = start; i < end; ++i) {
1193
1194
1195
1196
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1197
1198
    }
  }
Guolin Ke's avatar
Guolin Ke committed
1199
  throw std::runtime_error("unknown data type in RowFunctionFromCSR");
1200
1201
}

Guolin Ke's avatar
Guolin Ke committed
1202
1203
1204
std::function<std::pair<int, double>(int idx)>
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx) {
  CHECK(col_idx < ncol_ptr && col_idx >= 0);
Guolin Ke's avatar
Guolin Ke committed
1205
  if (data_type == C_API_DTYPE_FLOAT32) {
1206
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1207
    if (col_ptr_type == C_API_DTYPE_INT32) {
1208
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1209
1210
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1211
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1212
1213
1214
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1215
        }
Guolin Ke's avatar
Guolin Ke committed
1216
1217
1218
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1219
      };
Guolin Ke's avatar
Guolin Ke committed
1220
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1221
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1222
1223
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1224
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1225
1226
1227
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1228
        }
Guolin Ke's avatar
Guolin Ke committed
1229
1230
1231
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1232
      };
Guolin Ke's avatar
Guolin Ke committed
1233
    }
Guolin Ke's avatar
Guolin Ke committed
1234
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1235
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1236
    if (col_ptr_type == C_API_DTYPE_INT32) {
1237
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1238
1239
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1240
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1241
1242
1243
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1244
        }
Guolin Ke's avatar
Guolin Ke committed
1245
1246
1247
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1248
      };
Guolin Ke's avatar
Guolin Ke committed
1249
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1250
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1251
1252
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1253
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1254
1255
1256
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1257
        }
Guolin Ke's avatar
Guolin Ke committed
1258
1259
1260
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1261
      };
Guolin Ke's avatar
Guolin Ke committed
1262
1263
1264
    }
  }
  throw std::runtime_error("unknown data type in CSC matrix");
1265
1266
}

Guolin Ke's avatar
Guolin Ke committed
1267
/*!
 * \brief Bind the iterator to column col_idx of a CSC matrix.
 *
 * All arguments are forwarded unchanged to IterateFunctionFromCSC.
 */
CSC_RowIterator::CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
                                 const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx)
  : iter_fun_(IterateFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type,
                                     ncol_ptr, nelem, col_idx)) {
}

/*!
 * \brief Return the value stored at row idx of the bound column, or 0 if none is stored.
 *
 * The iterator only moves forward: it advances through stored entries until it
 * reaches a row >= idx or the column is exhausted.
 */
double CSC_RowIterator::Get(int idx) {
  // The increment of nonzero_idx_ is skipped when the end marker is hit (break).
  for (; !is_end_ && cur_idx_ < idx; ++nonzero_idx_) {
    auto entry = iter_fun_(nonzero_idx_);
    if (entry.first < 0) {
      is_end_ = true;
      break;
    }
    cur_idx_ = entry.first;
    cur_val_ = entry.second;
  }
  return (cur_idx_ == idx) ? cur_val_ : 0.0;
}

/*!
 * \brief Return the next stored (row, value) pair of the bound column.
 *
 * Returns (-1, 0.0) once the column is exhausted. The position counter
 * advances even on the call that first observes the end marker.
 */
std::pair<int, double> CSC_RowIterator::NextNonZero() {
  if (is_end_) {
    return std::make_pair(-1, 0.0);
  }
  auto entry = iter_fun_(nonzero_idx_);
  ++nonzero_idx_;
  if (entry.first < 0) {
    is_end_ = true;
  }
  return entry;
}