"tests/git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "bb88d92ef27acc0731bf3b348920560410aa2e78"
c_api.cpp 29.8 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
#include <omp.h>

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/random.h>
#include <LightGBM/c_api.h>
#include <LightGBM/dataset_loader.h>
#include <LightGBM/dataset.h>
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/config.h>

#include <cstdio>
#include <vector>
#include <string>
#include <cstring>
#include <limits>
#include <memory>
#include <stdexcept>

#include "./application/predictor.hpp"

Guolin Ke's avatar
Guolin Ke committed
22
23
24
25
namespace LightGBM {

class Booster {
public:
Guolin Ke's avatar
Guolin Ke committed
26
27
  explicit Booster(const char* filename) {
    boosting_.reset(Boosting::CreateBoosting(filename));
Guolin Ke's avatar
Guolin Ke committed
28
29
30
  }

  Booster(const Dataset* train_data, 
31
    const char* parameters) {
32
33
    auto param = ConfigBase::Str2Map(parameters);
    config_.Set(param);
Guolin Ke's avatar
Guolin Ke committed
34
35
    // create boosting
    if (config_.io_config.input_model.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
36
      Log::Warning("continued train from model is not support for c_api, \
Guolin Ke's avatar
Guolin Ke committed
37
38
        please use continued train with input score");
    }
Guolin Ke's avatar
Guolin Ke committed
39
    boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, ""));
40
41
42
43
44
45
    ConstructObjectAndTrainingMetrics(train_data);
    // initialize the boosting
    boosting_->Init(&config_.boosting_config, train_data, objective_fun_.get(),
      Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
  }

Guolin Ke's avatar
Guolin Ke committed
46
47
48
49
  void MergeFrom(const Booster* other) {
    boosting_->MergeFrom(other->boosting_.get());
  }

50
51
52
53
54
  ~Booster() {

  }

  void ConstructObjectAndTrainingMetrics(const Dataset* train_data) {
Guolin Ke's avatar
Guolin Ke committed
55
    // create objective function
Guolin Ke's avatar
Guolin Ke committed
56
57
58
59
60
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
      config_.objective_config));
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective functions");
    }
Guolin Ke's avatar
Guolin Ke committed
61
    // create training metric
62
    train_metric_.clear();
Guolin Ke's avatar
Guolin Ke committed
63
    for (auto metric_type : config_.metric_types) {
Guolin Ke's avatar
Guolin Ke committed
64
65
      auto metric = std::unique_ptr<Metric>(
        Metric::CreateMetric(metric_type, config_.metric_config));
Guolin Ke's avatar
Guolin Ke committed
66
      if (metric == nullptr) { continue; }
67
      metric->Init(train_data->metadata(), train_data->num_data());
Guolin Ke's avatar
Guolin Ke committed
68
      train_metric_.push_back(std::move(metric));
Guolin Ke's avatar
Guolin Ke committed
69
    }
Guolin Ke's avatar
Guolin Ke committed
70
    train_metric_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
71
    // initialize the objective function
Guolin Ke's avatar
Guolin Ke committed
72
    if (objective_fun_ != nullptr) {
73
      objective_fun_->Init(train_data->metadata(), train_data->num_data());
Guolin Ke's avatar
Guolin Ke committed
74
75
    }
  }
Guolin Ke's avatar
Guolin Ke committed
76

77
  void ResetTrainingData(const Dataset* train_data) {
78
79
    train_data_ = train_data;
    ConstructObjectAndTrainingMetrics(train_data_);
80
    // initialize the boosting
81
82
83
84
85
86
87
88
89
90
91
92
93
94
    boosting_->ResetTrainingData(&config_.boosting_config, train_data_, 
      objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
  }

  void ResetConfig(const char* parameters) {
    auto param = ConfigBase::Str2Map(parameters);
    if (param.count("num_class")) {
      Log::Fatal("cannot change num class during training");
    }
    if (param.count("boosting_type")) {
      Log::Fatal("cannot change boosting_type during training");
    }
    config_.Set(param);
    ResetTrainingData(train_data_);
95
  }
Guolin Ke's avatar
Guolin Ke committed
96

97
98
99
100
101
102
103
104
105
106
107
  void AddValidData(const Dataset* valid_data) {
    valid_metrics_.emplace_back();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
      Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
Guolin Ke's avatar
Guolin Ke committed
108
  }
109
110
111
112
113
114
115
116
  bool TrainOneIter() {
    return boosting_->TrainOneIter(nullptr, nullptr, false);
  }

  bool TrainOneIter(const float* gradients, const float* hessians) {
    return boosting_->TrainOneIter(gradients, hessians, false);
  }

Guolin Ke's avatar
Guolin Ke committed
117
118
  void PrepareForPrediction(int num_iteration, int predict_type) {
    boosting_->SetNumIterationForPred(num_iteration);
Guolin Ke's avatar
Guolin Ke committed
119
120
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
121
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
122
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
123
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
124
      is_raw_score = true;
Guolin Ke's avatar
Guolin Ke committed
125
126
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
127
    }
Guolin Ke's avatar
Guolin Ke committed
128
    predictor_.reset(new Predictor(boosting_.get(), is_raw_score, is_predict_leaf));
Guolin Ke's avatar
Guolin Ke committed
129
130
  }

Guolin Ke's avatar
Guolin Ke committed
131
132
133
134
  void GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len) {
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

Guolin Ke's avatar
Guolin Ke committed
135
136
  std::vector<double> Predict(const std::vector<std::pair<int, double>>& features) {
    return predictor_->GetPredictFunction()(features);
137
138
  }

139
140
141
142
  void PredictForFile(const char* data_filename, const char* result_filename, bool data_has_header) {
    predictor_->Predict(data_filename, result_filename, data_has_header);
  }

Guolin Ke's avatar
Guolin Ke committed
143
144
  void SaveModelToFile(int num_iteration, const char* filename) {
    boosting_->SaveModelToFile(num_iteration, true, filename);
Guolin Ke's avatar
Guolin Ke committed
145
  }
Guolin Ke's avatar
Guolin Ke committed
146
147
148
149
150
151
152
153
154

  int GetEvalCounts() const {
    int ret = 0;
    for (const auto& metric : train_metric_) {
      ret += static_cast<int>(metric->GetName().size());
    }
    return ret;
  }

Guolin Ke's avatar
Guolin Ke committed
155
  int GetEvalNames(char** out_strs) const {
Guolin Ke's avatar
Guolin Ke committed
156
157
158
    int idx = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
Guolin Ke's avatar
Guolin Ke committed
159
160
161
162
163
164
165
166
        int j = 0;
        auto name_cstr = name.c_str();
        while (name_cstr[j] != '\0') {
          out_strs[idx][j] = name_cstr[j];
          ++j;
        }
        out_strs[idx][j] = '\0';
        ++idx;
Guolin Ke's avatar
Guolin Ke committed
167
168
169
170
171
      }
    }
    return idx;
  }

172
173
174
175
176

  void RollbackOneIter() {
    boosting_->RollbackOneIter();
  }

Guolin Ke's avatar
Guolin Ke committed
177
  const Boosting* GetBoosting() const { return boosting_.get(); }
Guolin Ke's avatar
Guolin Ke committed
178
  
Guolin Ke's avatar
Guolin Ke committed
179
private:
180
  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
181
  std::unique_ptr<Boosting> boosting_;
Guolin Ke's avatar
Guolin Ke committed
182
183
184
  /*! \brief All configs */
  OverallConfig config_;
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
185
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
186
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
187
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
188
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
189
  std::unique_ptr<ObjectiveFunction> objective_fun_;
Guolin Ke's avatar
Guolin Ke committed
190
  /*! \brief Using predictor for prediction task */
Guolin Ke's avatar
Guolin Ke committed
191
  std::unique_ptr<Predictor> predictor_;
192

Guolin Ke's avatar
Guolin Ke committed
193
194
195
};

}
Guolin Ke's avatar
Guolin Ke committed
196
197
198

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
199
/*!
* \brief Return the message stored by LastErrorMsg() as a C string.
*/
DllExport const char* LGBM_GetLastError() {
  const char* message = LastErrorMsg().c_str();
  return message;
}

/*!
* \brief Load a dataset from a text file.
*        When reference is given, the new dataset is aligned with it
*        (via DatasetLoader::LoadFromFileAlignWithOtherDataset).
*/
DllExport int LGBM_CreateDatasetFromFile(const char* filename,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto param_map = ConfigBase::Str2Map(parameters);
  IOConfig cfg;
  cfg.Set(param_map);
  DatasetLoader dataset_loader(cfg, nullptr);
  dataset_loader.SetHeader(filename);
  if (reference != nullptr) {
    const Dataset* ref_dataset = reinterpret_cast<const Dataset*>(*reference);
    *out = dataset_loader.LoadFromFileAlignWithOtherDataset(filename, ref_dataset);
  } else {
    *out = dataset_loader.LoadFromFile(filename);
  }
  API_END();
}

/*!
* \brief Load a dataset previously saved in binary form.
*        Only default IO options are used here.
*/
DllExport int LGBM_CreateDatasetFromBinaryFile(const char* filename,
  DatesetHandle* out) {
  API_BEGIN();
  OverallConfig default_config;
  DatasetLoader bin_loader(default_config.io_config, nullptr);
  // NOTE(review): 0, 1 presumably mean rank 0 of 1 machine — confirm against LoadFromBinFile
  *out = bin_loader.LoadFromBinFile(filename, 0, 1);
  API_END();
}

/*!
* \brief Build a dataset from a dense matrix.
*        Without a reference, bin boundaries are built from a random row
*        sample; with one, the reference's feature mappers are copied.
*/
DllExport int LGBM_CreateDatasetFromMat(const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto param_map = ConfigBase::Str2Map(parameters);
  IOConfig cfg;
  cfg.Set(param_map);
  DatasetLoader dataset_loader(cfg, nullptr);
  std::unique_ptr<Dataset> dataset;
  auto row_getter = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  if (reference != nullptr) {
    // reuse the bin mappers of the reference dataset
    dataset.reset(new Dataset(nrow, cfg.num_class));
    dataset->CopyFeatureMapperFrom(reinterpret_cast<const Dataset*>(*reference),
      cfg.is_enable_sparse);
  } else {
    // sample rows and keep only non-zero values per column
    Random sampler(cfg.data_random_seed);
    const int num_samples = static_cast<int>(nrow < cfg.bin_construct_sample_cnt ? nrow : cfg.bin_construct_sample_cnt);
    auto picked_rows = sampler.Sample(nrow, num_samples);
    std::vector<std::vector<double>> sampled_cols(ncol);
    for (auto row_id : picked_rows) {
      const auto values = row_getter(static_cast<int>(row_id));
      for (size_t col = 0; col < values.size(); ++col) {
        const double v = values[col];
        if (std::fabs(v) > 1e-15) {
          sampled_cols[col].push_back(v);
        }
      }
    }
    dataset.reset(dataset_loader.CostructFromSampleData(sampled_cols, num_samples, nrow));
  }

#pragma omp parallel for schedule(guided)
  for (int r = 0; r < nrow; ++r) {
    const int thread_id = omp_get_thread_num();
    auto row_values = row_getter(r);
    dataset->PushOneRow(thread_id, r, row_values);
  }
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

280
281
/*!
* \brief Build a dataset from a CSR matrix.
*        Without a reference, bin boundaries are built from a random row
*        sample; with one, the reference's feature mappers are copied.
* \param num_col Number of columns of the matrix; must be at least
*        (largest sampled non-zero column index + 1)
*/
DllExport int LGBM_CreateDatasetFromCSR(const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t num_col,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
  DatasetLoader loader(io_config, nullptr);
  std::unique_ptr<Dataset> ret;
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference == nullptr) {
    // sample data first
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values;
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      auto row = get_row_fun(static_cast<int>(idx));
      for (const std::pair<int, double>& inner_data : row) {
        if (std::fabs(inner_data.second) > 1e-15) {
          if (static_cast<size_t>(inner_data.first) >= sample_values.size()) {
            // grow the feature set so this column index is addressable
            sample_values.resize(static_cast<size_t>(inner_data.first) + 1);
          }
          sample_values[inner_data.first].push_back(inner_data.second);
        }
      }
    }
    CHECK(num_col >= static_cast<int64_t>(sample_values.size()));
    // Columns past the last non-zero sampled column still belong to the
    // matrix; pad to num_col so the feature count matches the caller's
    // matrix, as the dense-matrix path does with ncol.
    sample_values.resize(static_cast<size_t>(num_col));
    ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow, io_config.num_class));
    ret->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(*reference),
      io_config.is_enable_sparse);
  }

#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
  }
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

342
343
/*!
* \brief Build a dataset from a CSC matrix.
*        Without a reference, each column is sampled at the same random row
*        indices to build bin boundaries; with one, the reference's feature
*        mappers are copied.
*/
DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr,
  int col_ptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t ncol_ptr,
  int64_t nelem,
  int64_t num_row,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto param_map = ConfigBase::Str2Map(parameters);
  IOConfig cfg;
  cfg.Set(param_map);
  DatasetLoader dataset_loader(cfg, nullptr);
  std::unique_ptr<Dataset> dataset;
  auto col_getter = ColumnFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem);
  int32_t nrow = static_cast<int32_t>(num_row);
  if (reference != nullptr) {
    // reuse the bin mappers of the reference dataset
    dataset.reset(new Dataset(nrow, cfg.num_class));
    dataset->CopyFeatureMapperFrom(reinterpret_cast<const Dataset*>(*reference),
      cfg.is_enable_sparse);
  } else {
    Log::Warning("Construct from CSC format is not efficient");
    Random sampler(cfg.data_random_seed);
    const int num_samples = static_cast<int>(nrow < cfg.bin_construct_sample_cnt ? nrow : cfg.bin_construct_sample_cnt);
    auto picked_rows = sampler.Sample(nrow, num_samples);
    std::vector<std::vector<double>> sampled_cols(ncol_ptr - 1);
    const int num_cols = static_cast<int>(sampled_cols.size());
#pragma omp parallel for schedule(guided)
    for (int c = 0; c < num_cols; ++c) {
      auto column = col_getter(c);
      sampled_cols[c] = SampleFromOneColumn(column, picked_rows);
    }
    dataset.reset(dataset_loader.CostructFromSampleData(sampled_cols, num_samples, nrow));
  }

#pragma omp parallel for schedule(guided)
  for (int c = 0; c < ncol_ptr - 1; ++c) {
    const int thread_id = omp_get_thread_num();
    auto column = col_getter(c);
    dataset->PushOneColumn(thread_id, c, column);
  }
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

392
/*!
* \brief Destroy a dataset created by one of the LGBM_CreateDataset* functions.
*/
DllExport int LGBM_DatasetFree(DatesetHandle handle) {
  API_BEGIN();
  Dataset* dataset = reinterpret_cast<Dataset*>(handle);
  delete dataset;
  API_END();
}

/*!
* \brief Write the dataset to filename via Dataset::SaveBinaryFile.
*/
DllExport int LGBM_DatasetSaveBinary(DatesetHandle handle,
  const char* filename) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->SaveBinaryFile(filename);
  API_END();
}

/*!
* \brief Set a metadata field of the dataset.
* \param field_data Pointer to num_element values of the given type
* \param num_element Number of values; must fit in int32_t
* \param type C_API_DTYPE_FLOAT32 or C_API_DTYPE_INT32
* \throws std::runtime_error if num_element is out of range, the type is
*         unsupported, or the field is not found
*/
DllExport int LGBM_DatasetSetField(DatesetHandle handle,
  const char* field_name,
  const void* field_data,
  int64_t num_element,
  int type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  // The Dataset setters take an int32_t length; reject values that the
  // narrowing cast would otherwise silently corrupt.
  if (num_element < 0 ||
      num_element > static_cast<int64_t>(std::numeric_limits<int32_t>::max())) {
    throw std::runtime_error("Number of elements out of range");
  }
  const int32_t len = static_cast<int32_t>(num_element);
  bool is_success = false;
  if (type == C_API_DTYPE_FLOAT32) {
    is_success = dataset->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), len);
  } else if (type == C_API_DTYPE_INT32) {
    is_success = dataset->SetIntField(field_name, reinterpret_cast<const int*>(field_data), len);
  }
  if (!is_success) { throw std::runtime_error("Input data type error or field not found"); }
  API_END();
}

/*!
* \brief Look up a metadata field, trying the float fields first, then the int fields.
* \param out_len Receives the field length (forced to 0 when the field pointer is null)
* \param out_ptr Receives a pointer to the field data
* \param out_type Receives C_API_DTYPE_FLOAT32 or C_API_DTYPE_INT32
* \throws std::runtime_error if no field with that name exists
*/
DllExport int LGBM_DatasetGetField(DatesetHandle handle,
  const char* field_name,
  int64_t* out_len,
  const void** out_ptr,
  int* out_type) {
  API_BEGIN();
  Dataset* dataset = reinterpret_cast<Dataset*>(handle);
  if (dataset->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (dataset->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else {
    throw std::runtime_error("Field not found");
  }
  if (*out_ptr == nullptr) {
    *out_len = 0;
  }
  API_END();
}

/*!
* \brief Report the number of rows (Dataset::num_data) in *out.
*/
DllExport int LGBM_DatasetGetNumData(DatesetHandle handle,
  int64_t* out) {
  API_BEGIN();
  Dataset* ds = reinterpret_cast<Dataset*>(handle);
  *out = static_cast<int64_t>(ds->num_data());
  API_END();
}

/*!
* \brief Report the total feature count (Dataset::num_total_features) in *out.
*/
DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle,
  int64_t* out) {
  API_BEGIN();
  Dataset* ds = reinterpret_cast<Dataset*>(handle);
  *out = static_cast<int64_t>(ds->num_total_features());
  API_END();
}
458
459
460
461
462
463
464


// ---- start of booster

/*!
* \brief Create a booster that trains on train_data with the given parameters.
*        Ownership of the new booster passes to the caller via *out.
*/
DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
  const char* parameters,
  BoosterHandle* out) {
  API_BEGIN();
  std::unique_ptr<Booster> booster(
    new Booster(reinterpret_cast<const Dataset*>(train_data), parameters));
  *out = booster.release();
  API_END();
}

472
/*!
* \brief Load a booster from a model file.
* \param num_total_model Receives Boosting::NumberOfTotalModel of the loaded model
* \param out Receives the new booster; the caller owns it
*/
DllExport int LGBM_BoosterCreateFromModelfile(
  const char* filename,
  int64_t* num_total_model,
  BoosterHandle* out) {
  API_BEGIN();
  std::unique_ptr<Booster> booster(new Booster(filename));
  const auto total_models = booster->GetBoosting()->NumberOfTotalModel();
  *num_total_model = static_cast<int64_t>(total_models);
  *out = booster.release();
  API_END();
}

/*!
* \brief Destroy a booster created by LGBM_BoosterCreate*.
*/
DllExport int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  delete booster;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
489
490
491
492
493
494
495
496
/*!
* \brief Merge the models of other_handle into handle.
*/
DllExport int LGBM_BoosterMerge(BoosterHandle handle,
  BoosterHandle other_handle) {
  API_BEGIN();
  Booster* target = reinterpret_cast<Booster*>(handle);
  target->MergeFrom(reinterpret_cast<Booster*>(other_handle));
  API_END();
}
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515

/*!
* \brief Register valid_data as a validation dataset of the booster.
*/
DllExport int LGBM_BoosterAddValidData(BoosterHandle handle,
  const DatesetHandle valid_data) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->AddValidData(reinterpret_cast<const Dataset*>(valid_data));
  API_END();
}

/*!
* \brief Replace the booster's training dataset with train_data.
*/
DllExport int LGBM_BoosterResetTrainingData(BoosterHandle handle,
  const DatesetHandle train_data) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->ResetTrainingData(reinterpret_cast<const Dataset*>(train_data));
  API_END();
}

516
517
518
/*!
* \brief Apply a new parameter string to the booster (Booster::ResetConfig).
*/
DllExport int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ResetConfig(parameters);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
523
524
525
526
527
528
529
/*!
* \brief Report Boosting::NumberOfClasses in *out_len.
*/
DllExport int LGBM_BoosterGetNumClasses(BoosterHandle handle, int64_t* out_len) {
  API_BEGIN();
  const Boosting* model = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = model->NumberOfClasses();
  API_END();
}

530
/*!
* \brief Train one iteration with the configured objective.
* \param is_finished Set to 1 when TrainOneIter returns true, else 0
*/
DllExport int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const bool finished = booster->TrainOneIter();
  *is_finished = finished ? 1 : 0;
  API_END();
}

/*!
* \brief Train one iteration from caller-supplied gradients and hessians.
* \param is_finished Set to 1 when TrainOneIter returns true, else 0
*/
DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
  const float* grad,
  const float* hess,
  int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const bool finished = booster->TrainOneIter(grad, hess);
  *is_finished = finished ? 1 : 0;
  API_END();
}

555
556
557
558
559
560
561
562
563
564
565
566
567
/*!
* \brief Undo the most recent training iteration.
*/
DllExport int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

/*!
* \brief Report Boosting::GetCurrentIteration in *out_iteration.
*/
DllExport int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int64_t* out_iteration) {
  API_BEGIN();
  const Boosting* model = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_iteration = model->GetCurrentIteration();
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
/*!
* \brief Get the number of evaluation results (total metric names over all training metrics)
* \param out_len Receives the count
*/
DllExport int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int64_t* out_len) {
  API_BEGIN();
  const Booster* booster = reinterpret_cast<Booster*>(handle);
  const int count = booster->GetEvalCounts();
  *out_len = count;
  API_END();
}

/*!
* \brief Get number of eval
* \return total number of eval result
*/
Guolin Ke's avatar
Guolin Ke committed
583
DllExport int LGBM_BoosterGetEvalNames(BoosterHandle handle, int64_t* out_len, char** out_strs) {
Guolin Ke's avatar
Guolin Ke committed
584
585
586
587
588
589
590
591
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  *out_len = ref_booster->GetEvalNames(out_strs);
  API_END();
}


/*!
* \brief Copy the evaluation results for dataset index data into out_results.
* \param out_len Receives the number of results written
* \param out_results Caller-allocated buffer; values are narrowed to float
*/
DllExport int LGBM_BoosterGetEval(BoosterHandle handle,
  int data,
  int64_t* out_len,
  float* out_results) {
  API_BEGIN();
  const Boosting* model = reinterpret_cast<Booster*>(handle)->GetBoosting();
  auto scores = model->GetEvalAt(data);
  *out_len = static_cast<int64_t>(scores.size());
  size_t pos = 0;
  for (const auto& score : scores) {
    out_results[pos++] = static_cast<float>(score);
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
606
607
/*!
* \brief Copy the current scores of dataset index data into out_result.
* \param out_len Receives the number of values written
*/
DllExport int LGBM_BoosterGetPredict(BoosterHandle handle,
  int data,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  int written = 0;
  booster->GetPredictAt(data, out_result, &written);
  *out_len = static_cast<int64_t>(written);
  API_END();
}

618
619
/*!
* \brief Predict every row of data_filename and write the results to result_filename.
* \param data_has_header Any positive value means the data file has a header line
*/
DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle,
  const char* data_filename,
  int data_has_header,
  int predict_type,
  int64_t num_iteration,
  const char* result_filename) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);
  const bool has_header = data_has_header > 0;
  booster->PredictForFile(data_filename, result_filename, has_header);
  API_END();
}

632
/*!
* \brief Predict every row of a CSR matrix.
* \param num_iteration Iterations to use; <= 0 means all (clamped to the model size)
* \param out_len Receives the number of values written (rows * values per row)
* \param out_result Caller-allocated buffer, row-major, values per row =
*        num_class, or num_class * used iterations for leaf-index prediction
*/
DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
  const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t,
  int predict_type,
  int64_t num_iteration,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  ref_booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);

  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  const int num_class = ref_booster->GetBoosting()->NumberOfClasses();
  int64_t num_preb_in_one_row = num_class;
  if (predict_type == C_API_PREDICT_LEAF_INDEX) {
    // One leaf index per tree used; never count more iterations than the
    // model has, otherwise out_len and the row stride overstate the output.
    const int64_t total_iteration = ref_booster->GetBoosting()->NumberOfTotalModel() / num_class;
    const int64_t used_iteration =
      (num_iteration > 0 && num_iteration < total_iteration) ? num_iteration : total_iteration;
    num_preb_in_one_row *= used_iteration;
  }
  const int nrow = static_cast<int>(nindptr - 1);
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    auto one_row = get_row_fun(i);
    auto predicton_result = ref_booster->Predict(one_row);
    // 64-bit offset: rows * values-per-row can exceed INT_MAX
    float* row_out = out_result + static_cast<int64_t>(i) * num_preb_in_one_row;
    for (size_t j = 0; j < predicton_result.size(); ++j) {
      row_out[j] = static_cast<float>(predicton_result[j]);
    }
  }
  *out_len = static_cast<int64_t>(nrow) * num_preb_in_one_row;
  API_END();
}
670
671
672

/*!
* \brief Predict every row of a dense matrix.
* \param num_iteration Iterations to use; <= 0 means all (clamped to the model size)
* \param out_len Receives the number of values written (rows * values per row)
* \param out_result Caller-allocated buffer, row-major, values per row =
*        num_class, or num_class * used iterations for leaf-index prediction
*/
DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
  const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  int predict_type,
  int64_t num_iteration,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  ref_booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);

  auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  const int num_class = ref_booster->GetBoosting()->NumberOfClasses();
  int64_t num_preb_in_one_row = num_class;
  if (predict_type == C_API_PREDICT_LEAF_INDEX) {
    // One leaf index per tree used; never count more iterations than the
    // model has, otherwise out_len and the row stride overstate the output.
    const int64_t total_iteration = ref_booster->GetBoosting()->NumberOfTotalModel() / num_class;
    const int64_t used_iteration =
      (num_iteration > 0 && num_iteration < total_iteration) ? num_iteration : total_iteration;
    num_preb_in_one_row *= used_iteration;
  }
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    auto one_row = get_row_fun(i);
    auto predicton_result = ref_booster->Predict(one_row);
    // 64-bit offset: rows * values-per-row can exceed INT_MAX
    float* row_out = out_result + static_cast<int64_t>(i) * num_preb_in_one_row;
    for (size_t j = 0; j < predicton_result.size(); ++j) {
      row_out[j] = static_cast<float>(predicton_result[j]);
    }
  }
  *out_len = static_cast<int64_t>(nrow) * num_preb_in_one_row;
  API_END();
}
705
706

// Writes the model held by the booster behind `handle` to `filename`.
// num_iteration is forwarded unchanged to Booster::SaveModelToFile.
// Any error is reported through the API_BEGIN / API_END wrapper.
DllExport int LGBM_BoosterSaveModel(BoosterHandle handle,
  int num_iteration,
  const char* filename) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->SaveModelToFile(num_iteration, filename);
  API_END();
}
714

Guolin Ke's avatar
Guolin Ke committed
715
// ---- start of some helper functions
716
717
718

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
719
  if (data_type == C_API_DTYPE_FLOAT32) {
720
721
722
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
723
        std::vector<double> ret(num_col);
724
725
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
726
          ret[i] = static_cast<double>(*(tmp_ptr + i));
727
728
729
730
731
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
732
        std::vector<double> ret(num_col);
733
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
734
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
735
736
737
738
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
739
  } else if (data_type == C_API_DTYPE_FLOAT64) {
740
741
742
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
743
        std::vector<double> ret(num_col);
744
745
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
746
          ret[i] = static_cast<double>(*(tmp_ptr + i));
747
748
749
750
751
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
752
        std::vector<double> ret(num_col);
753
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
754
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
755
756
757
758
759
        }
        return ret;
      };
    }
  }
Guolin Ke's avatar
Guolin Ke committed
760
  throw std::runtime_error("unknown data type in RowFunctionFromDenseMatric");
761
762
763
764
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
765
766
767
768
769
770
771
772
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
    return [inner_function](int row_idx) {
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
        if (std::fabs(raw_values[i]) > 1e-15) {
          ret.emplace_back(i, raw_values[i]);
773
        }
Guolin Ke's avatar
Guolin Ke committed
774
775
776
      }
      return ret;
    };
777
  }
Guolin Ke's avatar
Guolin Ke committed
778
  return nullptr;
779
780
781
782
}

std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
783
  if (data_type == C_API_DTYPE_FLOAT32) {
784
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
785
    if (indptr_type == C_API_DTYPE_INT32) {
786
787
788
789
790
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
791
        for (int64_t i = start; i < end; ++i) {
792
793
794
795
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
796
    } else if (indptr_type == C_API_DTYPE_INT64) {
797
798
799
800
801
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
802
        for (int64_t i = start; i < end; ++i) {
803
804
805
806
807
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
808
  } else if (data_type == C_API_DTYPE_FLOAT64) {
809
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
810
    if (indptr_type == C_API_DTYPE_INT32) {
811
812
813
814
815
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
816
        for (int64_t i = start; i < end; ++i) {
817
818
819
820
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
821
    } else if (indptr_type == C_API_DTYPE_INT64) {
822
823
824
825
826
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
827
        for (int64_t i = start; i < end; ++i) {
828
829
830
831
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
832
833
834
    } 
  } 
  throw std::runtime_error("unknown data type in RowFunctionFromCSR");
835
836
837
838
}

std::function<std::vector<std::pair<int, double>>(int idx)>
ColumnFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
839
  if (data_type == C_API_DTYPE_FLOAT32) {
840
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
841
    if (col_ptr_type == C_API_DTYPE_INT32) {
842
843
844
845
846
847
848
849
850
851
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
852
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
853
854
855
856
857
858
859
860
861
862
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
863
    } 
Guolin Ke's avatar
Guolin Ke committed
864
  } else if (data_type == C_API_DTYPE_FLOAT64) {
865
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
866
    if (col_ptr_type == C_API_DTYPE_INT32) {
867
868
869
870
871
872
873
874
875
876
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
877
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
878
879
880
881
882
883
884
885
886
887
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
888
889
890
    } 
  } 
  throw std::runtime_error("unknown data type in ColumnFunctionFromCSC");
891
892
}

893
// Picks the stored values of one sparse column at the requested row positions.
//
// `data` is a list of (row index, value) pairs and `indices` the rows to sample.
// Both are walked with a single forward-only cursor, so both are expected to be
// sorted in ascending row order. Requested rows with no stored entry are skipped
// rather than emitted as zero. A repeated requested row yields its value again.
std::vector<double> SampleFromOneColumn(const std::vector<std::pair<int, double>>& data, const std::vector<int>& indices) {
  std::vector<double> sampled;
  auto cursor = data.cbegin();
  const auto stop = data.cend();
  for (const int wanted_row : indices) {
    // Advance past entries whose row lies before the one we want.
    while (cursor != stop && cursor->first < wanted_row) {
      ++cursor;
    }
    const bool hit = cursor != stop && cursor->first == wanted_row;
    if (hit) {
      sampled.push_back(cursor->second);
    }
  }
  return sampled;
}