c_api.cpp 30.1 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
#include <omp.h>

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/c_api.h>
Guolin Ke's avatar
Guolin Ke committed
6
#include <LightGBM/dataset_loader.h>
Guolin Ke's avatar
Guolin Ke committed
7
8
9
10
11
12
13
14
15
16
#include <LightGBM/dataset.h>
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/config.h>

#include <cstdio>
#include <vector>
#include <string>
#include <cstring>
Guolin Ke's avatar
Guolin Ke committed
17
#include <memory>
Guolin Ke's avatar
Guolin Ke committed
18
#include <stdexcept>
Guolin Ke's avatar
Guolin Ke committed
19
#include <mutex>
Guolin Ke's avatar
Guolin Ke committed
20

Guolin Ke's avatar
Guolin Ke committed
21
22
#include "./application/predictor.hpp"

Guolin Ke's avatar
Guolin Ke committed
23
24
25
26
namespace LightGBM {

class Booster {
public:
Guolin Ke's avatar
Guolin Ke committed
27
28
  explicit Booster(const char* filename) {
    boosting_.reset(Boosting::CreateBoosting(filename));
Guolin Ke's avatar
Guolin Ke committed
29
30
31
  }

  Booster(const Dataset* train_data, 
32
    const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
33
    std::lock_guard<std::mutex> lock(mutex_);
34
35
    auto param = ConfigBase::Str2Map(parameters);
    config_.Set(param);
Guolin Ke's avatar
Guolin Ke committed
36
37
    // create boosting
    if (config_.io_config.input_model.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
38
      Log::Warning("continued train from model is not support for c_api, \
Guolin Ke's avatar
Guolin Ke committed
39
40
        please use continued train with input score");
    }
Guolin Ke's avatar
Guolin Ke committed
41
    boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
42
43
44
45
46
47
    ConstructObjectAndTrainingMetrics(train_data);
    // initialize the boosting
    boosting_->Init(&config_.boosting_config, train_data, objective_fun_.get(),
      Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
  }

Guolin Ke's avatar
Guolin Ke committed
48
  void MergeFrom(const Booster* other) {
Guolin Ke's avatar
Guolin Ke committed
49
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
50
51
52
    boosting_->MergeFrom(other->boosting_.get());
  }

53
54
55
56
57
  ~Booster() {

  }

  void ResetTrainingData(const Dataset* train_data) {
Guolin Ke's avatar
Guolin Ke committed
58
    std::lock_guard<std::mutex> lock(mutex_);
59
60
    train_data_ = train_data;
    ConstructObjectAndTrainingMetrics(train_data_);
61
    // initialize the boosting
62
63
64
65
66
    boosting_->ResetTrainingData(&config_.boosting_config, train_data_, 
      objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
  }

  void ResetConfig(const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
67
    std::lock_guard<std::mutex> lock(mutex_);
68
69
70
71
72
73
74
75
76
    auto param = ConfigBase::Str2Map(parameters);
    if (param.count("num_class")) {
      Log::Fatal("cannot change num class during training");
    }
    if (param.count("boosting_type")) {
      Log::Fatal("cannot change boosting_type during training");
    }
    config_.Set(param);
    ResetTrainingData(train_data_);
77
  }
Guolin Ke's avatar
Guolin Ke committed
78

79
  void AddValidData(const Dataset* valid_data) {
Guolin Ke's avatar
Guolin Ke committed
80
    std::lock_guard<std::mutex> lock(mutex_);
81
82
83
84
85
86
87
88
89
90
    valid_metrics_.emplace_back();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
      Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
Guolin Ke's avatar
Guolin Ke committed
91
  }
92
  bool TrainOneIter() {
Guolin Ke's avatar
Guolin Ke committed
93
94
    std::lock_guard<std::mutex> lock(mutex_);
    return boosting_->TrainOneIter(nullptr, nullptr, false);
95
96
97
  }

  bool TrainOneIter(const float* gradients, const float* hessians) {
Guolin Ke's avatar
Guolin Ke committed
98
99
    std::lock_guard<std::mutex> lock(mutex_);
    return boosting_->TrainOneIter(gradients, hessians, false);
100
101
  }

Guolin Ke's avatar
Guolin Ke committed
102
  void RollbackOneIter() {
Guolin Ke's avatar
Guolin Ke committed
103
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
104
105
106
    boosting_->RollbackOneIter();
  }

Guolin Ke's avatar
Guolin Ke committed
107
  void PrepareForPrediction(int num_iteration, int predict_type) {
Guolin Ke's avatar
Guolin Ke committed
108
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
109
    boosting_->SetNumIterationForPred(num_iteration);
Guolin Ke's avatar
Guolin Ke committed
110
111
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
112
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
113
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
114
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
115
      is_raw_score = true;
Guolin Ke's avatar
Guolin Ke committed
116
117
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
118
    }
Guolin Ke's avatar
Guolin Ke committed
119
    predictor_.reset(new Predictor(boosting_.get(), is_raw_score, is_predict_leaf));
Guolin Ke's avatar
Guolin Ke committed
120
121
  }

Guolin Ke's avatar
Guolin Ke committed
122
123
124
125
  void GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len) {
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

Guolin Ke's avatar
Guolin Ke committed
126
127
  std::vector<double> Predict(const std::vector<std::pair<int, double>>& features) {
    return predictor_->GetPredictFunction()(features);
128
129
  }

130
131
132
133
  void PredictForFile(const char* data_filename, const char* result_filename, bool data_has_header) {
    predictor_->Predict(data_filename, result_filename, data_has_header);
  }

Guolin Ke's avatar
Guolin Ke committed
134
  void SaveModelToFile(int num_iteration, const char* filename) {
Guolin Ke's avatar
Guolin Ke committed
135
    boosting_->SaveModelToFile(num_iteration, filename);
Guolin Ke's avatar
Guolin Ke committed
136
  }
Guolin Ke's avatar
Guolin Ke committed
137
138
139
140
141
142
143
144
145

  int GetEvalCounts() const {
    int ret = 0;
    for (const auto& metric : train_metric_) {
      ret += static_cast<int>(metric->GetName().size());
    }
    return ret;
  }

Guolin Ke's avatar
Guolin Ke committed
146
  int GetEvalNames(char** out_strs) const {
Guolin Ke's avatar
Guolin Ke committed
147
148
149
    int idx = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
Guolin Ke's avatar
Guolin Ke committed
150
        std::strcpy(out_strs[idx], name.c_str());
Guolin Ke's avatar
Guolin Ke committed
151
        ++idx;
Guolin Ke's avatar
Guolin Ke committed
152
153
154
155
156
      }
    }
    return idx;
  }

Guolin Ke's avatar
Guolin Ke committed
157
  const Boosting* GetBoosting() const { return boosting_.get(); }
Guolin Ke's avatar
Guolin Ke committed
158
  
Guolin Ke's avatar
Guolin Ke committed
159
private:
Guolin Ke's avatar
Guolin Ke committed
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183

  void ConstructObjectAndTrainingMetrics(const Dataset* train_data) {
    // create objective function
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
      config_.objective_config));
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective functions");
    }
    // create training metric
    train_metric_.clear();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(
        Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(train_data->metadata(), train_data->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data->metadata(), train_data->num_data());
    }
  }

184
  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
185
  std::unique_ptr<Boosting> boosting_;
Guolin Ke's avatar
Guolin Ke committed
186
187
188
  /*! \brief All configs */
  OverallConfig config_;
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
189
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
190
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
191
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
192
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
193
  std::unique_ptr<ObjectiveFunction> objective_fun_;
Guolin Ke's avatar
Guolin Ke committed
194
  /*! \brief Using predictor for prediction task */
Guolin Ke's avatar
Guolin Ke committed
195
  std::unique_ptr<Predictor> predictor_;
Guolin Ke's avatar
Guolin Ke committed
196
197
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
198
199
200
};

}
Guolin Ke's avatar
Guolin Ke committed
201
202
203

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
204
/*!
* \brief Get the text of the last error
* \return C string obtained from the string returned by LastErrorMsg()
*/
DllExport const char* LGBM_GetLastError() {
  const auto& last_message = LastErrorMsg();
  return last_message.c_str();
}

/*!
* \brief Load a dataset from a data file
* \param filename File to load; also used to set the loader's header
* \param parameters Parameter string used to fill an IOConfig
* \param reference Optional dataset to align with; nullptr loads standalone
* \param out Receives the loaded dataset
*/
DllExport int LGBM_CreateDatasetFromFile(const char* filename,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto parsed_params = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(parsed_params);
  DatasetLoader loader(io_config, nullptr);
  loader.SetHeader(filename);
  if (reference != nullptr) {
    // align the new dataset with the one referenced by the caller
    const Dataset* aligned_with = reinterpret_cast<const Dataset*>(*reference);
    *out = loader.LoadFromFileAlignWithOtherDataset(filename, aligned_with);
  } else {
    *out = loader.LoadFromFile(filename);
  }
  API_END();
}

/*!
* \brief Load a dataset from a binary file using a default-constructed config
* \param filename Binary file to load
* \param out Receives the loaded dataset
*/
DllExport int LGBM_CreateDatasetFromBinaryFile(const char* filename,
  DatesetHandle* out) {
  API_BEGIN();
  OverallConfig default_config;
  DatasetLoader loader(default_config.io_config, nullptr);
  // NOTE(review): 0 and 1 are presumably rank and number of machines - confirm
  auto loaded_dataset = loader.LoadFromBinFile(filename, 0, 1);
  *out = loaded_dataset;
  API_END();
}

/*!
* \brief Build a dataset from a dense matrix
* \param data Matrix buffer, interpreted according to data_type
* \param data_type Element type tag passed to RowFunctionFromDenseMatric
* \param nrow Number of rows
* \param ncol Number of columns
* \param is_row_major Layout flag passed to RowFunctionFromDenseMatric
* \param parameters Parameter string used to fill an IOConfig
* \param reference Optional dataset whose feature mappers are copied
* \param out Receives the constructed dataset
*/
DllExport int LGBM_CreateDatasetFromMat(const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto parsed_params = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(parsed_params);
  DatasetLoader loader(io_config, nullptr);
  std::unique_ptr<Dataset> dataset;
  auto read_row = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  if (reference != nullptr) {
    // reuse the feature mappers of the reference dataset
    dataset.reset(new Dataset(nrow, io_config.num_class));
    dataset->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(*reference),
      io_config.is_enable_sparse);
  } else {
    // sample rows first and collect the non-zero values of every column
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(ncol);
    for (const auto sampled_row : sample_indices) {
      auto row = read_row(static_cast<int>(sampled_row));
      for (size_t col = 0; col < row.size(); ++col) {
        // values with magnitude <= 1e-15 are treated as zero and not recorded
        if (std::fabs(row[col]) > 1e-15) {
          sample_values[col].push_back(row[col]);
        }
      }
    }
    dataset.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

  // push every row, each OpenMP thread passing its own thread id
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    const int tid = omp_get_thread_num();
    auto one_row = read_row(i);
    dataset->PushOneRow(tid, i, one_row);
  }
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

285
286
/*!
* \brief Build a dataset from a CSR sparse matrix
* \param indptr Row pointer array, interpreted according to indptr_type
* \param indptr_type Type tag of indptr
* \param indices Column indices
* \param data Values, interpreted according to data_type
* \param data_type Type tag of data
* \param nindptr Length of indptr; the row count is nindptr - 1
* \param nelem Number of stored elements
* \param num_col Number of columns; checked against the sampled feature count
* \param parameters Parameter string used to fill an IOConfig
* \param reference Optional dataset whose feature mappers are copied
* \param out Receives the constructed dataset
*/
DllExport int LGBM_CreateDatasetFromCSR(const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t num_col,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto parsed_params = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(parsed_params);
  DatasetLoader loader(io_config, nullptr);
  std::unique_ptr<Dataset> dataset;
  auto read_row = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference != nullptr) {
    // reuse the feature mappers of the reference dataset
    dataset.reset(new Dataset(nrow, io_config.num_class));
    dataset->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(*reference),
      io_config.is_enable_sparse);
  } else {
    // sample rows first and collect the non-zero values per feature
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values;
    for (const auto sampled_row : sample_indices) {
      auto row = read_row(static_cast<int>(sampled_row));
      for (std::pair<int, double>& entry : row) {
        if (!(std::fabs(entry.second) > 1e-15)) { continue; }
        // grow the feature set until this feature index has a slot
        const size_t feature_idx = static_cast<size_t>(entry.first);
        while (sample_values.size() <= feature_idx) {
          sample_values.emplace_back();
        }
        // record the feature value
        sample_values[entry.first].push_back(entry.second);
      }
    }
    CHECK(num_col >= static_cast<int>(sample_values.size()));
    dataset.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

  // push every row, each OpenMP thread passing its own thread id
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nindptr - 1; ++i) {
    const int tid = omp_get_thread_num();
    auto one_row = read_row(i);
    dataset->PushOneRow(tid, i, one_row);
  }
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

347
348
/*!
* \brief Build a dataset from a CSC sparse matrix
* \param col_ptr Column pointer array, interpreted according to col_ptr_type
* \param col_ptr_type Type tag of col_ptr
* \param indices Row indices
* \param data Values, interpreted according to data_type
* \param data_type Type tag of data
* \param ncol_ptr Length of col_ptr; the column count is ncol_ptr - 1
* \param nelem Number of stored elements
* \param num_row Number of rows
* \param parameters Parameter string used to fill an IOConfig
* \param reference Optional dataset whose feature mappers are copied
* \param out Receives the constructed dataset
*/
DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr,
  int col_ptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t ncol_ptr,
  int64_t nelem,
  int64_t num_row,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto parsed_params = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(parsed_params);
  DatasetLoader loader(io_config, nullptr);
  std::unique_ptr<Dataset> dataset;
  auto read_column = ColumnFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem);
  int32_t nrow = static_cast<int32_t>(num_row);
  if (reference != nullptr) {
    // reuse the feature mappers of the reference dataset
    dataset.reset(new Dataset(nrow, io_config.num_class));
    dataset->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(*reference),
      io_config.is_enable_sparse);
  } else {
    Log::Warning("Construct from CSC format is not efficient");
    // sample row indices first, then pick the sampled values column by column
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
    const int num_sampled_columns = static_cast<int>(sample_values.size());
#pragma omp parallel for schedule(guided)
    for (int i = 0; i < num_sampled_columns; ++i) {
      auto cur_col = read_column(i);
      sample_values[i] = SampleFromOneColumn(cur_col, sample_indices);
    }
    dataset.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

  // push every column, each OpenMP thread passing its own thread id
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < ncol_ptr - 1; ++i) {
    const int tid = omp_get_thread_num();
    auto one_col = read_column(i);
    dataset->PushOneColumn(tid, i, one_col);
  }
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

397
/*!
* \brief Destroy a dataset
* \param handle Dataset handle; deleted as a Dataset*
*/
DllExport int LGBM_DatasetFree(DatesetHandle handle) {
  API_BEGIN();
  Dataset* dataset = reinterpret_cast<Dataset*>(handle);
  delete dataset;
  API_END();
}

/*!
* \brief Save a dataset through Dataset::SaveBinaryFile
* \param handle Dataset handle
* \param filename Passed to Dataset::SaveBinaryFile
*/
DllExport int LGBM_DatasetSaveBinary(DatesetHandle handle,
  const char* filename) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->SaveBinaryFile(filename);
  API_END();
}

/*!
* \brief Set a named field of a dataset
* \param handle Dataset handle
* \param field_name Name of the field
* \param field_data Field values, interpreted according to type
* \param num_element Number of values; narrowed to int32_t
* \param type C_API_DTYPE_FLOAT32 or C_API_DTYPE_INT32; other values fail
* \note Throws std::runtime_error when the type is unsupported or the setter
*       reports failure
*/
DllExport int LGBM_DatasetSetField(DatesetHandle handle,
  const char* field_name,
  const void* field_data,
  int64_t num_element,
  int type) {
  API_BEGIN();
  Dataset* target = reinterpret_cast<Dataset*>(handle);
  const int32_t element_count = static_cast<int32_t>(num_element);
  bool field_was_set = false;
  if (type == C_API_DTYPE_FLOAT32) {
    const float* float_values = reinterpret_cast<const float*>(field_data);
    field_was_set = target->SetFloatField(field_name, float_values, element_count);
  } else if (type == C_API_DTYPE_INT32) {
    const int* int_values = reinterpret_cast<const int*>(field_data);
    field_was_set = target->SetIntField(field_name, int_values, element_count);
  }
  if (!field_was_set) {
    throw std::runtime_error("Input data type erorr or field not found");
  }
  API_END();
}

/*!
* \brief Get a named field of a dataset
* \param handle Dataset handle
* \param field_name Name of the field
* \param out_len Receives the number of values; forced to 0 when *out_ptr is null
* \param out_ptr Receives a pointer to the field values
* \param out_type Receives C_API_DTYPE_FLOAT32 or C_API_DTYPE_INT32
* \note Throws std::runtime_error when neither the float nor the int lookup succeeds
*/
DllExport int LGBM_DatasetGetField(DatesetHandle handle,
  const char* field_name,
  int64_t* out_len,
  const void** out_ptr,
  int* out_type) {
  API_BEGIN();
  Dataset* source = reinterpret_cast<Dataset*>(handle);
  // try the float fields first, then the int fields
  if (source->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (source->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else {
    throw std::runtime_error("Field not found");
  }
  if (*out_ptr == nullptr) {
    *out_len = 0;
  }
  API_END();
}

/*!
* \brief Get the number of data rows of a dataset
* \param handle Dataset handle
* \param out Receives Dataset::num_data()
*/
DllExport int LGBM_DatasetGetNumData(DatesetHandle handle,
  int64_t* out) {
  API_BEGIN();
  *out = reinterpret_cast<Dataset*>(handle)->num_data();
  API_END();
}

/*!
* \brief Get the total number of features of a dataset
* \param handle Dataset handle
* \param out Receives Dataset::num_total_features()
*/
DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle,
  int64_t* out) {
  API_BEGIN();
  *out = reinterpret_cast<Dataset*>(handle)->num_total_features();
  API_END();
}
463
464
465
466
467
468
469


// ---- start of booster

/*!
* \brief Create a booster for training
* \param train_data Handle of the training dataset
* \param parameters Parameter string passed to the Booster constructor
* \param out Receives the new booster; ownership passes to the caller
*/
DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
  const char* parameters,
  BoosterHandle* out) {
  API_BEGIN();
  // hold the booster in a unique_ptr until it is handed over through *out
  std::unique_ptr<Booster> booster(
    new Booster(reinterpret_cast<const Dataset*>(train_data), parameters));
  *out = booster.release();
  API_END();
}

477
/*!
* \brief Create a booster from a model file
* \param filename Model file passed to the Booster constructor
* \param num_total_model Receives Boosting::NumberOfTotalModel() of the loaded model
* \param out Receives the new booster; ownership passes to the caller
*/
DllExport int LGBM_BoosterCreateFromModelfile(
  const char* filename,
  int64_t* num_total_model,
  BoosterHandle* out) {
  API_BEGIN();
  std::unique_ptr<Booster> booster(new Booster(filename));
  const Boosting* loaded_model = booster->GetBoosting();
  *num_total_model = static_cast<int64_t>(loaded_model->NumberOfTotalModel());
  *out = booster.release();
  API_END();
}

/*!
* \brief Destroy a booster
* \param handle Booster handle; deleted as a Booster*
*/
DllExport int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  delete booster;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
494
495
496
497
498
499
500
501
/*!
* \brief Merge the model of another booster into this booster
* \param handle Booster that receives the merge
* \param other_handle Booster passed to Booster::MergeFrom
*/
DllExport int LGBM_BoosterMerge(BoosterHandle handle,
  BoosterHandle other_handle) {
  API_BEGIN();
  Booster* target = reinterpret_cast<Booster*>(handle);
  Booster* source = reinterpret_cast<Booster*>(other_handle);
  target->MergeFrom(source);
  API_END();
}
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520

/*!
* \brief Add a validation dataset to a booster
* \param handle Booster handle
* \param valid_data Handle of the validation dataset
*/
DllExport int LGBM_BoosterAddValidData(BoosterHandle handle,
  const DatesetHandle valid_data) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->AddValidData(reinterpret_cast<const Dataset*>(valid_data));
  API_END();
}

/*!
* \brief Replace the training dataset of a booster
* \param handle Booster handle
* \param train_data Handle of the new training dataset
*/
DllExport int LGBM_BoosterResetTrainingData(BoosterHandle handle,
  const DatesetHandle train_data) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->ResetTrainingData(reinterpret_cast<const Dataset*>(train_data));
  API_END();
}

521
522
523
/*!
* \brief Reset the parameters of a booster
* \param handle Booster handle
* \param parameters Parameter string passed to Booster::ResetConfig
*/
DllExport int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->ResetConfig(parameters);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
528
529
530
531
532
533
534
/*!
* \brief Get the number of classes of a booster's model
* \param handle Booster handle
* \param out_len Receives Boosting::NumberOfClasses()
*/
DllExport int LGBM_BoosterGetNumClasses(BoosterHandle handle, int64_t* out_len) {
  API_BEGIN();
  const Boosting* model = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = model->NumberOfClasses();
  API_END();
}

535
/*!
* \brief Train one iteration
* \param handle Booster handle
* \param is_finished Receives 1 when Booster::TrainOneIter returns true, else 0
*/
DllExport int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const bool finished = booster->TrainOneIter();
  *is_finished = finished ? 1 : 0;
  API_END();
}

/*!
* \brief Train one iteration with caller-supplied gradients and hessians
* \param handle Booster handle
* \param grad Gradients passed to Booster::TrainOneIter
* \param hess Hessians passed to Booster::TrainOneIter
* \param is_finished Receives 1 when Booster::TrainOneIter returns true, else 0
*/
DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
  const float* grad,
  const float* hess,
  int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const bool finished = booster->TrainOneIter(grad, hess);
  *is_finished = finished ? 1 : 0;
  API_END();
}

560
561
562
563
564
565
566
567
568
569
570
571
572
/*!
* \brief Roll back one iteration through Booster::RollbackOneIter
* \param handle Booster handle
*/
DllExport int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

/*!
* \brief Get the current iteration of a booster's model
* \param handle Booster handle
* \param out_iteration Receives Boosting::GetCurrentIteration()
*/
DllExport int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int64_t* out_iteration) {
  API_BEGIN();
  const Boosting* model = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_iteration = model->GetCurrentIteration();
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
/*!
* \brief Get the number of evaluation names of a booster
* \param handle Booster handle
* \param out_len Receives Booster::GetEvalCounts()
*/
DllExport int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int64_t* out_len) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const int eval_count = booster->GetEvalCounts();
  *out_len = eval_count;
  API_END();
}

/*!
* \brief Get number of eval
* \return total number of eval result
*/
Guolin Ke's avatar
Guolin Ke committed
588
DllExport int LGBM_BoosterGetEvalNames(BoosterHandle handle, int64_t* out_len, char** out_strs) {
Guolin Ke's avatar
Guolin Ke committed
589
590
591
592
593
594
595
596
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  *out_len = ref_booster->GetEvalNames(out_strs);
  API_END();
}


/*!
* \brief Get evaluation results for one dataset
* \param handle Booster handle
* \param data Dataset index passed to Boosting::GetEvalAt
* \param out_len Receives the number of results
* \param out_results Caller-provided array; receives the results as float
*/
DllExport int LGBM_BoosterGetEval(BoosterHandle handle,
  int data,
  int64_t* out_len,
  float* out_results) {
  API_BEGIN();
  const Boosting* model = reinterpret_cast<Booster*>(handle)->GetBoosting();
  auto eval_values = model->GetEvalAt(data);
  *out_len = static_cast<int64_t>(eval_values.size());
  size_t write_pos = 0;
  for (const auto& value : eval_values) {
    out_results[write_pos] = static_cast<float>(value);
    ++write_pos;
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
611
612
/*!
* \brief Get the predictions for one dataset
* \param handle Booster handle
* \param data Dataset index passed to Booster::GetPredictAt
* \param out_len Receives the length reported by Booster::GetPredictAt
* \param out_result Caller-provided array passed to Booster::GetPredictAt
*/
DllExport int LGBM_BoosterGetPredict(BoosterHandle handle,
  int data,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  // GetPredictAt reports the length through an int; widen it for the caller
  int result_len = 0;
  booster->GetPredictAt(data, out_result, &result_len);
  *out_len = static_cast<int64_t>(result_len);
  API_END();
}

623
624
/*!
* \brief Predict for a data file and write the results to a file
* \param handle Booster handle
* \param data_filename Input data file
* \param data_has_header Treated as true when greater than 0
* \param predict_type Prediction type passed to Booster::PrepareForPrediction
* \param num_iteration Narrowed to int and passed to Booster::PrepareForPrediction
* \param result_filename Output file
*/
DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle,
  const char* data_filename,
  int data_has_header,
  int predict_type,
  int64_t num_iteration,
  const char* result_filename) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);
  const bool has_header = (data_has_header > 0);
  booster->PredictForFile(data_filename, result_filename, has_header);
  API_END();
}

637
/*!
* \brief Predict for a CSR sparse matrix
* \param handle Booster handle
* \param indptr Row pointer array, interpreted according to indptr_type
* \param indptr_type Type tag of indptr
* \param indices Column indices
* \param data Values, interpreted according to data_type
* \param data_type Type tag of data
* \param nindptr Length of indptr; the row count is nindptr - 1
* \param nelem Number of stored elements
* \param (unnamed) Unused int64_t argument kept for interface compatibility
* \param predict_type Prediction type passed to Booster::PrepareForPrediction
* \param num_iteration Narrowed to int and passed to Booster::PrepareForPrediction
* \param out_len Receives nrow * (predictions per row)
* \param out_result Caller-provided array; the results of row i start at
*        i * (predictions per row)
*/
DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
  const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t,
  int predict_type,
  int64_t num_iteration,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  ref_booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);

  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  int num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
  if (predict_type == C_API_PREDICT_LEAF_INDEX) {
    // leaf prediction: one value per class for every iteration that is used
    if (num_iteration > 0) {
      num_preb_in_one_row *= static_cast<int>(num_iteration);
    } else {
      num_preb_in_one_row *= ref_booster->GetBoosting()->NumberOfTotalModel() / num_preb_in_one_row;
    }
  }
  int nrow = static_cast<int>(nindptr - 1);
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    auto one_row = get_row_fun(i);
    auto predicton_result = ref_booster->Predict(one_row);
    // compute the offset in 64 bits: i * num_preb_in_one_row can exceed INT_MAX
    const int64_t row_offset = static_cast<int64_t>(i) * num_preb_in_one_row;
    for (int j = 0; j < static_cast<int>(predicton_result.size()); ++j) {
      out_result[row_offset + j] = static_cast<float>(predicton_result[j]);
    }
  }
  // widen before multiplying so the reported length cannot overflow int
  *out_len = static_cast<int64_t>(nrow) * num_preb_in_one_row;
  API_END();
}
675
676
677

/*!
* \brief Predict for a dense matrix
* \param handle Booster handle
* \param data Matrix buffer, interpreted according to data_type
* \param data_type Element type tag passed to RowPairFunctionFromDenseMatric
* \param nrow Number of rows
* \param ncol Number of columns
* \param is_row_major Layout flag passed to RowPairFunctionFromDenseMatric
* \param predict_type Prediction type passed to Booster::PrepareForPrediction
* \param num_iteration Narrowed to int and passed to Booster::PrepareForPrediction
* \param out_len Receives nrow * (predictions per row)
* \param out_result Caller-provided array; the results of row i start at
*        i * (predictions per row)
*/
DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
  const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  int predict_type,
  int64_t num_iteration,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  ref_booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);

  auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  int num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
  if (predict_type == C_API_PREDICT_LEAF_INDEX) {
    // leaf prediction: one value per class for every iteration that is used
    if (num_iteration > 0) {
      num_preb_in_one_row *= static_cast<int>(num_iteration);
    } else {
      num_preb_in_one_row *= ref_booster->GetBoosting()->NumberOfTotalModel() / num_preb_in_one_row;
    }
  }
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    auto one_row = get_row_fun(i);
    auto predicton_result = ref_booster->Predict(one_row);
    // compute the offset in 64 bits: i * num_preb_in_one_row can exceed INT_MAX
    const int64_t row_offset = static_cast<int64_t>(i) * num_preb_in_one_row;
    for (int j = 0; j < static_cast<int>(predicton_result.size()); ++j) {
      out_result[row_offset + j] = static_cast<float>(predicton_result[j]);
    }
  }
  // widen before multiplying so the reported length cannot overflow int
  *out_len = static_cast<int64_t>(nrow) * num_preb_in_one_row;
  API_END();
}
710
711

// Saves the booster behind `handle` by forwarding num_iteration and filename
// to Booster::SaveModelToFile.
// NOTE(review): API_BEGIN/API_END are defined elsewhere in this file;
// presumably they turn exceptions into the returned status code - confirm.
DllExport int LGBM_BoosterSaveModel(BoosterHandle handle,
  int num_iteration,
  const char* filename) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->SaveModelToFile(num_iteration, filename);
  API_END();
}
719

Guolin Ke's avatar
Guolin Ke committed
720
// ---- start of some helper functions
721
722
723

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
724
  if (data_type == C_API_DTYPE_FLOAT32) {
725
726
727
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
728
        std::vector<double> ret(num_col);
729
730
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
731
          ret[i] = static_cast<double>(*(tmp_ptr + i));
732
733
734
735
736
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
737
        std::vector<double> ret(num_col);
738
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
739
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
740
741
742
743
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
744
  } else if (data_type == C_API_DTYPE_FLOAT64) {
745
746
747
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
748
        std::vector<double> ret(num_col);
749
750
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
751
          ret[i] = static_cast<double>(*(tmp_ptr + i));
752
753
754
755
756
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
757
        std::vector<double> ret(num_col);
758
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
759
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
760
761
762
763
764
        }
        return ret;
      };
    }
  }
Guolin Ke's avatar
Guolin Ke committed
765
  throw std::runtime_error("unknown data type in RowFunctionFromDenseMatric");
766
767
768
769
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
770
771
772
773
774
775
776
777
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
    return [inner_function](int row_idx) {
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
        if (std::fabs(raw_values[i]) > 1e-15) {
          ret.emplace_back(i, raw_values[i]);
778
        }
Guolin Ke's avatar
Guolin Ke committed
779
780
781
      }
      return ret;
    };
782
  }
Guolin Ke's avatar
Guolin Ke committed
783
  return nullptr;
784
785
786
787
}

std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
788
  if (data_type == C_API_DTYPE_FLOAT32) {
789
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
790
    if (indptr_type == C_API_DTYPE_INT32) {
791
792
793
794
795
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
796
        for (int64_t i = start; i < end; ++i) {
797
798
799
800
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
801
    } else if (indptr_type == C_API_DTYPE_INT64) {
802
803
804
805
806
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
807
        for (int64_t i = start; i < end; ++i) {
808
809
810
811
812
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
813
  } else if (data_type == C_API_DTYPE_FLOAT64) {
814
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
815
    if (indptr_type == C_API_DTYPE_INT32) {
816
817
818
819
820
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
821
        for (int64_t i = start; i < end; ++i) {
822
823
824
825
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
826
    } else if (indptr_type == C_API_DTYPE_INT64) {
827
828
829
830
831
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
832
        for (int64_t i = start; i < end; ++i) {
833
834
835
836
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
837
838
839
    } 
  } 
  throw std::runtime_error("unknown data type in RowFunctionFromCSR");
840
841
842
843
}

std::function<std::vector<std::pair<int, double>>(int idx)>
ColumnFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
844
  if (data_type == C_API_DTYPE_FLOAT32) {
845
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
846
    if (col_ptr_type == C_API_DTYPE_INT32) {
847
848
849
850
851
852
853
854
855
856
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
857
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
858
859
860
861
862
863
864
865
866
867
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
868
    } 
Guolin Ke's avatar
Guolin Ke committed
869
  } else if (data_type == C_API_DTYPE_FLOAT64) {
870
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
871
    if (col_ptr_type == C_API_DTYPE_INT32) {
872
873
874
875
876
877
878
879
880
881
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
882
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
883
884
885
886
887
888
889
890
891
892
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
893
894
895
    } 
  } 
  throw std::runtime_error("unknown data type in ColumnFunctionFromCSC");
896
897
}

898
// Collects, from one sparse column, the values stored at the requested rows.
//
// data    : (row index, value) entries of the column.
// indices : rows to sample.
// Returns the values of the entries whose row index appears in `indices`, in
// the order of `indices`. Requested rows with no entry contribute nothing, so
// the result may be shorter than `indices`.
//
// The column is walked with a single forward cursor that never rewinds.
// NOTE(review): this assumes both `data` (by row index) and `indices` are
// sorted ascending - confirm against callers.
std::vector<double> SampleFromOneColumn(const std::vector<std::pair<int, double>>& data, const std::vector<int>& indices) {
  std::vector<double> ret;
  // Capacity hint only; the vector still grows if more values are appended.
  ret.reserve(data.size() < indices.size() ? data.size() : indices.size());
  size_t j = 0;
  for (const int row_idx : indices) {
    // Skip entries that lie before the requested row.
    while (j < data.size() && data[j].first < row_idx) {
      ++j;
    }
    if (j == data.size()) {
      // Column exhausted: no later request can match, so stop scanning.
      break;
    }
    if (data[j].first == row_idx) {
      ret.push_back(data[j].second);
    }
  }
  return ret;
}