c_api.cpp 29.9 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
#include <omp.h>

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/c_api.h>
Guolin Ke's avatar
Guolin Ke committed
6
#include <LightGBM/dataset_loader.h>
Guolin Ke's avatar
Guolin Ke committed
7
8
9
10
11
12
13
14
15
16
#include <LightGBM/dataset.h>
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/config.h>

#include <cstdio>
#include <vector>
#include <string>
#include <cstring>
Guolin Ke's avatar
Guolin Ke committed
17
#include <memory>
Guolin Ke's avatar
Guolin Ke committed
18
#include <stdexcept>
Guolin Ke's avatar
Guolin Ke committed
19
#include <mutex>
Guolin Ke's avatar
Guolin Ke committed
20

Guolin Ke's avatar
Guolin Ke committed
21
22
#include "./application/predictor.hpp"

Guolin Ke's avatar
Guolin Ke committed
23
24
25
26
namespace LightGBM {

class Booster {
public:
Guolin Ke's avatar
Guolin Ke committed
27
28
  explicit Booster(const char* filename) {
    boosting_.reset(Boosting::CreateBoosting(filename));
Guolin Ke's avatar
Guolin Ke committed
29
30
31
  }

  Booster(const Dataset* train_data, 
32
    const char* parameters) {
33
34
    auto param = ConfigBase::Str2Map(parameters);
    config_.Set(param);
Guolin Ke's avatar
Guolin Ke committed
35
36
    // create boosting
    if (config_.io_config.input_model.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
37
      Log::Warning("continued train from model is not support for c_api, \
Guolin Ke's avatar
Guolin Ke committed
38
39
        please use continued train with input score");
    }
Guolin Ke's avatar
Guolin Ke committed
40
    boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
41
42
43
44
45
46
    ConstructObjectAndTrainingMetrics(train_data);
    // initialize the boosting
    boosting_->Init(&config_.boosting_config, train_data, objective_fun_.get(),
      Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
  }

Guolin Ke's avatar
Guolin Ke committed
47
  void MergeFrom(const Booster* other) {
Guolin Ke's avatar
Guolin Ke committed
48
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
49
50
51
    boosting_->MergeFrom(other->boosting_.get());
  }

52
53
54
55
56
  ~Booster() {

  }

  void ResetTrainingData(const Dataset* train_data) {
Guolin Ke's avatar
Guolin Ke committed
57
    std::lock_guard<std::mutex> lock(mutex_);
58
59
    train_data_ = train_data;
    ConstructObjectAndTrainingMetrics(train_data_);
60
    // initialize the boosting
61
62
63
64
65
    boosting_->ResetTrainingData(&config_.boosting_config, train_data_, 
      objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
  }

  void ResetConfig(const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
66
    std::lock_guard<std::mutex> lock(mutex_);
67
68
69
70
71
72
73
74
75
    auto param = ConfigBase::Str2Map(parameters);
    if (param.count("num_class")) {
      Log::Fatal("cannot change num class during training");
    }
    if (param.count("boosting_type")) {
      Log::Fatal("cannot change boosting_type during training");
    }
    config_.Set(param);
    ResetTrainingData(train_data_);
76
  }
Guolin Ke's avatar
Guolin Ke committed
77

78
  void AddValidData(const Dataset* valid_data) {
Guolin Ke's avatar
Guolin Ke committed
79
    std::lock_guard<std::mutex> lock(mutex_);
80
81
82
83
84
85
86
87
88
89
    valid_metrics_.emplace_back();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
      Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
Guolin Ke's avatar
Guolin Ke committed
90
  }
91
  bool TrainOneIter() {
Guolin Ke's avatar
Guolin Ke committed
92
93
    std::lock_guard<std::mutex> lock(mutex_);
    return boosting_->TrainOneIter(nullptr, nullptr, false);
94
95
96
  }

  bool TrainOneIter(const float* gradients, const float* hessians) {
Guolin Ke's avatar
Guolin Ke committed
97
98
    std::lock_guard<std::mutex> lock(mutex_);
    return boosting_->TrainOneIter(gradients, hessians, false);
99
100
  }

Guolin Ke's avatar
Guolin Ke committed
101
  void RollbackOneIter() {
Guolin Ke's avatar
Guolin Ke committed
102
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
103
104
105
    boosting_->RollbackOneIter();
  }

Guolin Ke's avatar
Guolin Ke committed
106
  void PrepareForPrediction(int num_iteration, int predict_type) {
Guolin Ke's avatar
Guolin Ke committed
107
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
108
    boosting_->SetNumIterationForPred(num_iteration);
Guolin Ke's avatar
Guolin Ke committed
109
110
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
111
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
112
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
113
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
114
      is_raw_score = true;
Guolin Ke's avatar
Guolin Ke committed
115
116
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
117
    }
Guolin Ke's avatar
Guolin Ke committed
118
    predictor_.reset(new Predictor(boosting_.get(), is_raw_score, is_predict_leaf));
Guolin Ke's avatar
Guolin Ke committed
119
120
  }

Guolin Ke's avatar
Guolin Ke committed
121
122
123
124
  void GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len) {
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

Guolin Ke's avatar
Guolin Ke committed
125
126
  std::vector<double> Predict(const std::vector<std::pair<int, double>>& features) {
    return predictor_->GetPredictFunction()(features);
127
128
  }

129
130
131
132
  void PredictForFile(const char* data_filename, const char* result_filename, bool data_has_header) {
    predictor_->Predict(data_filename, result_filename, data_has_header);
  }

Guolin Ke's avatar
Guolin Ke committed
133
  void SaveModelToFile(int num_iteration, const char* filename) {
Guolin Ke's avatar
Guolin Ke committed
134
    boosting_->SaveModelToFile(num_iteration, filename);
Guolin Ke's avatar
Guolin Ke committed
135
  }
Guolin Ke's avatar
Guolin Ke committed
136
137
138
139
140
141
142
143
144

  int GetEvalCounts() const {
    int ret = 0;
    for (const auto& metric : train_metric_) {
      ret += static_cast<int>(metric->GetName().size());
    }
    return ret;
  }

Guolin Ke's avatar
Guolin Ke committed
145
  int GetEvalNames(char** out_strs) const {
Guolin Ke's avatar
Guolin Ke committed
146
147
148
    int idx = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
Guolin Ke's avatar
Guolin Ke committed
149
        std::strcpy(out_strs[idx], name.c_str());
Guolin Ke's avatar
Guolin Ke committed
150
        ++idx;
Guolin Ke's avatar
Guolin Ke committed
151
152
153
154
155
      }
    }
    return idx;
  }

Guolin Ke's avatar
Guolin Ke committed
156
  const Boosting* GetBoosting() const { return boosting_.get(); }
Guolin Ke's avatar
Guolin Ke committed
157
  
Guolin Ke's avatar
Guolin Ke committed
158
private:
Guolin Ke's avatar
Guolin Ke committed
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182

  void ConstructObjectAndTrainingMetrics(const Dataset* train_data) {
    // create objective function
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
      config_.objective_config));
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective functions");
    }
    // create training metric
    train_metric_.clear();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(
        Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(train_data->metadata(), train_data->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data->metadata(), train_data->num_data());
    }
  }

183
  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
184
  std::unique_ptr<Boosting> boosting_;
Guolin Ke's avatar
Guolin Ke committed
185
186
187
  /*! \brief All configs */
  OverallConfig config_;
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
188
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
189
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
190
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
191
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
192
  std::unique_ptr<ObjectiveFunction> objective_fun_;
Guolin Ke's avatar
Guolin Ke committed
193
  /*! \brief Using predictor for prediction task */
Guolin Ke's avatar
Guolin Ke committed
194
  std::unique_ptr<Predictor> predictor_;
Guolin Ke's avatar
Guolin Ke committed
195
196
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
197
198
199
};

}
Guolin Ke's avatar
Guolin Ke committed
200
201
202

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
203
DllExport const char* LGBM_GetLastError() {
204
  return LastErrorMsg();
Guolin Ke's avatar
Guolin Ke committed
205
206
207
208
209
210
}

/*!
* \brief Load a dataset from a text/binary file.
*        When reference is given, bins are aligned with that dataset.
*/
DllExport int LGBM_CreateDatasetFromFile(const char* filename,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto params = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(params);
  DatasetLoader loader(io_config, nullptr);
  loader.SetHeader(filename);
  if (reference != nullptr) {
    const Dataset* ref_dataset = reinterpret_cast<const Dataset*>(*reference);
    *out = loader.LoadFromFileAlignWithOtherDataset(filename, ref_dataset);
  } else {
    *out = loader.LoadFromFile(filename);
  }
  API_END();
}

/*!
* \brief Build a dataset from a dense matrix.
*        Without a reference, bin mappers are constructed from a random row
*        sample (non-zero values only); otherwise they are copied from reference.
*/
DllExport int LGBM_CreateDatasetFromMat(const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto params = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(params);
  DatasetLoader loader(io_config, nullptr);
  std::unique_ptr<Dataset> dataset;
  auto row_getter = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  if (reference != nullptr) {
    dataset.reset(new Dataset(nrow, io_config.num_class));
    dataset->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(*reference),
      io_config.is_enable_sparse);
  } else {
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(ncol);
    for (auto row_idx : sample_indices) {
      auto sampled_row = row_getter(static_cast<int>(row_idx));
      for (size_t col = 0; col < sampled_row.size(); ++col) {
        const double value = sampled_row[col];
        // keep only non-zero values for bin construction
        if (std::fabs(value) > 1e-15) {
          sample_values[col].push_back(value);
        }
      }
    }
    dataset.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

#pragma omp parallel for schedule(guided)
  for (int row = 0; row < nrow; ++row) {
    const int thread_id = omp_get_thread_num();
    auto features = row_getter(row);
    dataset->PushOneRow(thread_id, row, features);
  }
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

275
276
/*!
* \brief Build a dataset from a CSR sparse matrix.
* \param num_col Declared number of columns; every column index must be < num_col.
*        Without a reference, bin mappers are built from a random row sample
*        (non-zero values only); otherwise they are copied from reference.
*/
DllExport int LGBM_CreateDatasetFromCSR(const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t num_col,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
  DatasetLoader loader(io_config, nullptr);
  std::unique_ptr<Dataset> ret;
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference == nullptr) {
    // sample data first
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    // One bucket per declared column, allocated up front (as the dense path does),
    // so columns that happen to be all-zero in the sample -- including trailing
    // ones -- still become features and the feature count equals num_col.
    std::vector<std::vector<double>> sample_values(static_cast<size_t>(num_col));
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      auto row = get_row_fun(static_cast<int>(idx));
      for (std::pair<int, double>& inner_data : row) {
        if (std::fabs(inner_data.second) > 1e-15) {
          // column index must lie within the declared matrix width
          CHECK(inner_data.first >= 0 && static_cast<int64_t>(inner_data.first) < num_col);
          sample_values[inner_data.first].push_back(inner_data.second);
        }
      }
    }
    ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow, io_config.num_class));
    ret->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(*reference),
      io_config.is_enable_sparse);
  }

#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
  }
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

337
338
/*!
* \brief Build a dataset from a CSC sparse matrix (column by column).
*        Without a reference, each column is sampled at the same random rows
*        to build bin mappers; otherwise mappers are copied from reference.
*/
DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr,
  int col_ptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t ncol_ptr,
  int64_t nelem,
  int64_t num_row,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto params = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(params);
  DatasetLoader loader(io_config, nullptr);
  std::unique_ptr<Dataset> dataset;
  auto column_getter = ColumnFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem);
  int32_t nrow = static_cast<int32_t>(num_row);
  if (reference != nullptr) {
    dataset.reset(new Dataset(nrow, io_config.num_class));
    dataset->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(*reference),
      io_config.is_enable_sparse);
  } else {
    Log::Warning("Construct from CSC format is not efficient");
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
    const int num_sampled_cols = static_cast<int>(sample_values.size());
#pragma omp parallel for schedule(guided)
    for (int col = 0; col < num_sampled_cols; ++col) {
      auto column = column_getter(col);
      sample_values[col] = SampleFromOneColumn(column, sample_indices);
    }
    dataset.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

#pragma omp parallel for schedule(guided)
  for (int col = 0; col < ncol_ptr - 1; ++col) {
    const int thread_id = omp_get_thread_num();
    auto column = column_getter(col);
    dataset->PushOneColumn(thread_id, col, column);
  }
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

387
/*! \brief Destroy a dataset previously returned through a DatesetHandle. */
DllExport int LGBM_DatasetFree(DatesetHandle handle) {
  API_BEGIN();
  Dataset* dataset = reinterpret_cast<Dataset*>(handle);
  delete dataset;
  API_END();
}

/*! \brief Write the dataset to filename via Dataset::SaveBinaryFile. */
DllExport int LGBM_DatasetSaveBinary(DatesetHandle handle,
  const char* filename) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->SaveBinaryFile(filename);
  API_END();
}

/*!
* \brief Set a metadata field (float32 or int32 data) on a dataset.
* \param num_element Number of elements; narrowed to int32_t for the setters
* \param type C_API_DTYPE_FLOAT32 or C_API_DTYPE_INT32; any other value fails
* \note Throws std::runtime_error (reported through API_END) when the type is
*       unsupported or the setter rejects the field.
*/
DllExport int LGBM_DatasetSetField(DatesetHandle handle,
  const char* field_name,
  const void* field_data,
  int64_t num_element,
  int type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  const int32_t len = static_cast<int32_t>(num_element);
  bool is_success = false;
  if (type == C_API_DTYPE_FLOAT32) {
    is_success = dataset->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), len);
  } else if (type == C_API_DTYPE_INT32) {
    is_success = dataset->SetIntField(field_name, reinterpret_cast<const int*>(field_data), len);
  }
  if (!is_success) { throw std::runtime_error("Input data type error or field not found"); }
  API_END();
}

/*!
* \brief Look up a metadata field, trying the float accessor then the int one.
*        out_type reports which accessor succeeded; out_len is forced to 0 when
*        the returned pointer is null.
*/
DllExport int LGBM_DatasetGetField(DatesetHandle handle,
  const char* field_name,
  int64_t* out_len,
  const void** out_ptr,
  int* out_type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  if (dataset->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (dataset->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else {
    throw std::runtime_error("Field not found");
  }
  if (*out_ptr == nullptr) { *out_len = 0; }
  API_END();
}

/*! \brief Report Dataset::num_data() through out. */
DllExport int LGBM_DatasetGetNumData(DatesetHandle handle,
  int64_t* out) {
  API_BEGIN();
  const Dataset* dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_data();
  API_END();
}

/*! \brief Report Dataset::num_total_features() through out. */
DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle,
  int64_t* out) {
  API_BEGIN();
  const Dataset* dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_total_features();
  API_END();
}
453
454
455
456
457
458
459


// ---- start of booster

/*! \brief Create a training booster for train_data; ownership passes to *out. */
DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
  const char* parameters,
  BoosterHandle* out) {
  API_BEGIN();
  const Dataset* dataset = reinterpret_cast<const Dataset*>(train_data);
  std::unique_ptr<Booster> booster(new Booster(dataset, parameters));
  *out = booster.release();
  API_END();
}

467
/*!
* \brief Load a booster from a model file.
* \param num_total_model Receives Boosting::NumberOfTotalModel() of the loaded model
*/
DllExport int LGBM_BoosterCreateFromModelfile(
  const char* filename,
  int64_t* num_total_model,
  BoosterHandle* out) {
  API_BEGIN();
  std::unique_ptr<Booster> booster(new Booster(filename));
  const auto model_count = booster->GetBoosting()->NumberOfTotalModel();
  *num_total_model = static_cast<int64_t>(model_count);
  *out = booster.release();
  API_END();
}

/*! \brief Destroy a booster previously returned through a BoosterHandle. */
DllExport int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  delete booster;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
484
485
486
487
488
489
490
491
/*! \brief Merge the models of other_handle into handle. */
DllExport int LGBM_BoosterMerge(BoosterHandle handle,
  BoosterHandle other_handle) {
  API_BEGIN();
  Booster* target = reinterpret_cast<Booster*>(handle);
  const Booster* source = reinterpret_cast<Booster*>(other_handle);
  target->MergeFrom(source);
  API_END();
}
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510

/*! \brief Attach a validation dataset to the booster. */
DllExport int LGBM_BoosterAddValidData(BoosterHandle handle,
  const DatesetHandle valid_data) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->AddValidData(
    reinterpret_cast<const Dataset*>(valid_data));
  API_END();
}

/*! \brief Replace the booster's training dataset. */
DllExport int LGBM_BoosterResetTrainingData(BoosterHandle handle,
  const DatesetHandle train_data) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ResetTrainingData(
    reinterpret_cast<const Dataset*>(train_data));
  API_END();
}

511
512
513
/*! \brief Apply a new parameter string to the booster via Booster::ResetConfig. */
DllExport int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ResetConfig(parameters);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
518
519
520
521
522
523
524
/*! \brief Report Boosting::NumberOfClasses() through out_len. */
DllExport int LGBM_BoosterGetNumClasses(BoosterHandle handle, int64_t* out_len) {
  API_BEGIN();
  const Boosting* boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->NumberOfClasses();
  API_END();
}

525
/*! \brief Train one iteration; is_finished is 1 when TrainOneIter() returns true, else 0. */
DllExport int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const bool finished = booster->TrainOneIter();
  *is_finished = finished ? 1 : 0;
  API_END();
}

/*!
* \brief Train one iteration with caller-provided gradients and hessians.
*        is_finished is 1 when TrainOneIter() returns true, else 0.
*/
DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
  const float* grad,
  const float* hess,
  int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const bool finished = booster->TrainOneIter(grad, hess);
  *is_finished = finished ? 1 : 0;
  API_END();
}

550
551
552
553
554
555
556
557
558
559
560
561
562
/*! \brief Undo the booster's most recent iteration. */
DllExport int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

/*! \brief Report Boosting::GetCurrentIteration() through out_iteration. */
DllExport int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int64_t* out_iteration) {
  API_BEGIN();
  const Boosting* boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_iteration = boosting->GetCurrentIteration();
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
/*!
* \brief Get the number of evaluation results (names) of the training metrics
* \param out_len Receives the count from Booster::GetEvalCounts()
*/
DllExport int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int64_t* out_len) {
  API_BEGIN();
  const Booster* booster = reinterpret_cast<Booster*>(handle);
  *out_len = booster->GetEvalCounts();
  API_END();
}

/*!
* \brief Get number of eval
* \return total number of eval result
*/
Guolin Ke's avatar
Guolin Ke committed
578
DllExport int LGBM_BoosterGetEvalNames(BoosterHandle handle, int64_t* out_len, char** out_strs) {
Guolin Ke's avatar
Guolin Ke committed
579
580
581
582
583
584
585
586
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  *out_len = ref_booster->GetEvalNames(out_strs);
  API_END();
}


/*!
* \brief Copy the evaluation results for dataset data_idx into out_results,
*        converted to float; out_len receives the number of values.
*/
DllExport int LGBM_BoosterGetEval(BoosterHandle handle,
  int data_idx,
  int64_t* out_len,
  float* out_results) {
  API_BEGIN();
  const Boosting* boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  auto eval_values = boosting->GetEvalAt(data_idx);
  const size_t count = eval_values.size();
  *out_len = static_cast<int64_t>(count);
  for (size_t k = 0; k < count; ++k) {
    out_results[k] = static_cast<float>(eval_values[k]);
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
601
/*!
* \brief Copy the current scores for dataset data_idx into out_result;
*        out_len receives the number of values written.
*/
DllExport int LGBM_BoosterGetPredict(BoosterHandle handle,
  int data_idx,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  int written = 0;
  booster->GetPredictAt(data_idx, out_result, &written);
  *out_len = static_cast<int64_t>(written);
  API_END();
}

613
614
/*!
* \brief Predict every row of data_filename and write results to result_filename.
* \param data_has_header Any value > 0 means the file has a header line
*/
DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle,
  const char* data_filename,
  int data_has_header,
  int predict_type,
  int64_t num_iteration,
  const char* result_filename) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);
  const bool has_header = (data_has_header > 0);
  booster->PredictForFile(data_filename, result_filename, has_header);
  API_END();
}

627
/*!
* \brief Predict every row of a CSR matrix into out_result (row-major).
*        Each row occupies num_class values, or num_class * iterations values
*        for C_API_PREDICT_LEAF_INDEX.
* \param out_len Receives nrow * values-per-row
*/
DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
  const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t,
  int predict_type,
  int64_t num_iteration,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  ref_booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);

  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  int num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
  if (predict_type == C_API_PREDICT_LEAF_INDEX) {
    if (num_iteration > 0) {
      num_preb_in_one_row *= static_cast<int>(num_iteration);
    } else {
      num_preb_in_one_row *= ref_booster->GetBoosting()->NumberOfTotalModel() / num_preb_in_one_row;
    }
  }
  int nrow = static_cast<int>(nindptr - 1);
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    auto one_row = get_row_fun(i);
    auto predicton_result = ref_booster->Predict(one_row);
    // 64-bit offset: rows * values-per-row can exceed INT_MAX (e.g. leaf-index
    // output with many iterations), which would be signed overflow in int.
    const int64_t row_offset = static_cast<int64_t>(i) * num_preb_in_one_row;
    for (int j = 0; j < static_cast<int>(predicton_result.size()); ++j) {
      out_result[row_offset + j] = static_cast<float>(predicton_result[j]);
    }
  }
  *out_len = static_cast<int64_t>(nrow) * num_preb_in_one_row;
  API_END();
}
665
666
667

/*!
* \brief Predict every row of a dense matrix into out_result (row-major).
*        Each row occupies num_class values, or num_class * iterations values
*        for C_API_PREDICT_LEAF_INDEX.
* \param out_len Receives nrow * values-per-row
*/
DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
  const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  int predict_type,
  int64_t num_iteration,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  ref_booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);

  auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  int num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
  if (predict_type == C_API_PREDICT_LEAF_INDEX) {
    if (num_iteration > 0) {
      num_preb_in_one_row *= static_cast<int>(num_iteration);
    } else {
      num_preb_in_one_row *= ref_booster->GetBoosting()->NumberOfTotalModel() / num_preb_in_one_row;
    }
  }
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    auto one_row = get_row_fun(i);
    auto predicton_result = ref_booster->Predict(one_row);
    // 64-bit offset: rows * values-per-row can exceed INT_MAX (e.g. leaf-index
    // output with many iterations), which would be signed overflow in int.
    const int64_t row_offset = static_cast<int64_t>(i) * num_preb_in_one_row;
    for (int j = 0; j < static_cast<int>(predicton_result.size()); ++j) {
      out_result[row_offset + j] = static_cast<float>(predicton_result[j]);
    }
  }
  *out_len = static_cast<int64_t>(nrow) * num_preb_in_one_row;
  API_END();
}
700
701

/*!
 * \brief Persist the model held by a Booster handle to a file.
 * \param handle Booster handle.
 * \param num_iteration Iteration count forwarded to Booster::SaveModelToFile.
 * \param filename Destination path forwarded to Booster::SaveModelToFile.
 */
DllExport int LGBM_BoosterSaveModel(BoosterHandle handle,
  int num_iteration,
  const char* filename) {
  API_BEGIN();
  // Forward straight to the wrapped Booster; API_BEGIN/API_END (defined
  // elsewhere) bracket the call.
  reinterpret_cast<Booster*>(handle)->SaveModelToFile(num_iteration, filename);
  API_END();
}
709

Guolin Ke's avatar
Guolin Ke committed
710
// ---- start of some helper functions
711
712
713

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
714
  if (data_type == C_API_DTYPE_FLOAT32) {
715
716
717
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
718
        std::vector<double> ret(num_col);
719
720
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
721
          ret[i] = static_cast<double>(*(tmp_ptr + i));
722
723
724
725
726
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
727
        std::vector<double> ret(num_col);
728
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
729
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
730
731
732
733
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
734
  } else if (data_type == C_API_DTYPE_FLOAT64) {
735
736
737
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
738
        std::vector<double> ret(num_col);
739
740
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
741
          ret[i] = static_cast<double>(*(tmp_ptr + i));
742
743
744
745
746
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
747
        std::vector<double> ret(num_col);
748
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
749
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
750
751
752
753
754
        }
        return ret;
      };
    }
  }
Guolin Ke's avatar
Guolin Ke committed
755
  throw std::runtime_error("unknown data type in RowFunctionFromDenseMatric");
756
757
758
759
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
760
761
762
763
764
765
766
767
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
    return [inner_function](int row_idx) {
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
        if (std::fabs(raw_values[i]) > 1e-15) {
          ret.emplace_back(i, raw_values[i]);
768
        }
Guolin Ke's avatar
Guolin Ke committed
769
770
771
      }
      return ret;
    };
772
  }
Guolin Ke's avatar
Guolin Ke committed
773
  return nullptr;
774
775
776
777
}

std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
778
  if (data_type == C_API_DTYPE_FLOAT32) {
779
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
780
    if (indptr_type == C_API_DTYPE_INT32) {
781
782
783
784
785
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
786
        for (int64_t i = start; i < end; ++i) {
787
788
789
790
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
791
    } else if (indptr_type == C_API_DTYPE_INT64) {
792
793
794
795
796
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
797
        for (int64_t i = start; i < end; ++i) {
798
799
800
801
802
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
803
  } else if (data_type == C_API_DTYPE_FLOAT64) {
804
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
805
    if (indptr_type == C_API_DTYPE_INT32) {
806
807
808
809
810
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
811
        for (int64_t i = start; i < end; ++i) {
812
813
814
815
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
816
    } else if (indptr_type == C_API_DTYPE_INT64) {
817
818
819
820
821
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
822
        for (int64_t i = start; i < end; ++i) {
823
824
825
826
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
827
828
829
    } 
  } 
  throw std::runtime_error("unknown data type in RowFunctionFromCSR");
830
831
832
833
}

std::function<std::vector<std::pair<int, double>>(int idx)>
ColumnFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
834
  if (data_type == C_API_DTYPE_FLOAT32) {
835
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
836
    if (col_ptr_type == C_API_DTYPE_INT32) {
837
838
839
840
841
842
843
844
845
846
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
847
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
848
849
850
851
852
853
854
855
856
857
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
858
    } 
Guolin Ke's avatar
Guolin Ke committed
859
  } else if (data_type == C_API_DTYPE_FLOAT64) {
860
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
861
    if (col_ptr_type == C_API_DTYPE_INT32) {
862
863
864
865
866
867
868
869
870
871
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
872
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
873
874
875
876
877
878
879
880
881
882
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
883
884
885
    } 
  } 
  throw std::runtime_error("unknown data type in ColumnFunctionFromCSC");
886
887
}

888
/*!
 * \brief Pick the stored values of a sparse column at the requested rows.
 *
 * Walks `data` ((row, value) pairs) with a single forward cursor while
 * iterating `indices`. For each requested row whose entry is present, the
 * value is appended; requested rows without a stored entry are skipped.
 * The single forward pass assumes both inputs are sorted by row ascending.
 *
 * \param data Sparse column as (row index, value) pairs.
 * \param indices Row indices to sample.
 * \return Values found at the requested rows, in request order.
 */
std::vector<double> SampleFromOneColumn(const std::vector<std::pair<int, double>>& data, const std::vector<int>& indices) {
  std::vector<double> sampled;
  auto cursor = data.begin();
  const auto stop = data.end();
  for (const int wanted : indices) {
    // Skip stored entries that lie before the wanted row.
    while (cursor != stop && cursor->first < wanted) {
      ++cursor;
    }
    if (cursor != stop && cursor->first == wanted) {
      sampled.push_back(cursor->second);
    }
  }
  return sampled;
}