c_api.cpp 30.1 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
#include <omp.h>

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/c_api.h>
Guolin Ke's avatar
Guolin Ke committed
6
#include <LightGBM/dataset_loader.h>
Guolin Ke's avatar
Guolin Ke committed
7
8
9
10
11
12
13
14
15
16
#include <LightGBM/dataset.h>
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/config.h>

#include <cstdio>
#include <vector>
#include <string>
#include <cstring>
Guolin Ke's avatar
Guolin Ke committed
17
#include <memory>
Guolin Ke's avatar
Guolin Ke committed
18
#include <stdexcept>
Guolin Ke's avatar
Guolin Ke committed
19
#include <mutex>
Guolin Ke's avatar
Guolin Ke committed
20

Guolin Ke's avatar
Guolin Ke committed
21
22
#include "./application/predictor.hpp"

Guolin Ke's avatar
Guolin Ke committed
23
24
25
26
namespace LightGBM {

class Booster {
public:
Guolin Ke's avatar
Guolin Ke committed
27
28
  explicit Booster(const char* filename) {
    boosting_.reset(Boosting::CreateBoosting(filename));
Guolin Ke's avatar
Guolin Ke committed
29
30
31
  }

  Booster(const Dataset* train_data, 
32
    const char* parameters) {
33
34
    auto param = ConfigBase::Str2Map(parameters);
    config_.Set(param);
Guolin Ke's avatar
Guolin Ke committed
35
36
    // create boosting
    if (config_.io_config.input_model.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
37
      Log::Warning("continued train from model is not support for c_api, \
Guolin Ke's avatar
Guolin Ke committed
38
39
        please use continued train with input score");
    }
Guolin Ke's avatar
Guolin Ke committed
40
    boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
41
42
43
44
45
46
    ConstructObjectAndTrainingMetrics(train_data);
    // initialize the boosting
    boosting_->Init(&config_.boosting_config, train_data, objective_fun_.get(),
      Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
  }

Guolin Ke's avatar
Guolin Ke committed
47
  void MergeFrom(const Booster* other) {
Guolin Ke's avatar
Guolin Ke committed
48
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
49
50
51
    boosting_->MergeFrom(other->boosting_.get());
  }

52
53
54
55
56
  ~Booster() {

  }

  void ResetTrainingData(const Dataset* train_data) {
Guolin Ke's avatar
Guolin Ke committed
57
    std::lock_guard<std::mutex> lock(mutex_);
58
59
    train_data_ = train_data;
    ConstructObjectAndTrainingMetrics(train_data_);
60
    // initialize the boosting
61
62
63
64
65
    boosting_->ResetTrainingData(&config_.boosting_config, train_data_, 
      objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
  }

  void ResetConfig(const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
66
    std::lock_guard<std::mutex> lock(mutex_);
67
68
69
70
71
72
73
74
75
    auto param = ConfigBase::Str2Map(parameters);
    if (param.count("num_class")) {
      Log::Fatal("cannot change num class during training");
    }
    if (param.count("boosting_type")) {
      Log::Fatal("cannot change boosting_type during training");
    }
    config_.Set(param);
    ResetTrainingData(train_data_);
76
  }
Guolin Ke's avatar
Guolin Ke committed
77

78
  void AddValidData(const Dataset* valid_data) {
Guolin Ke's avatar
Guolin Ke committed
79
    std::lock_guard<std::mutex> lock(mutex_);
80
81
82
83
84
85
86
87
88
89
    valid_metrics_.emplace_back();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
      Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
Guolin Ke's avatar
Guolin Ke committed
90
  }
91
  bool TrainOneIter() {
Guolin Ke's avatar
Guolin Ke committed
92
93
    std::lock_guard<std::mutex> lock(mutex_);
    return boosting_->TrainOneIter(nullptr, nullptr, false);
94
95
96
  }

  bool TrainOneIter(const float* gradients, const float* hessians) {
Guolin Ke's avatar
Guolin Ke committed
97
98
    std::lock_guard<std::mutex> lock(mutex_);
    return boosting_->TrainOneIter(gradients, hessians, false);
99
100
  }

Guolin Ke's avatar
Guolin Ke committed
101
  void RollbackOneIter() {
Guolin Ke's avatar
Guolin Ke committed
102
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
103
104
105
    boosting_->RollbackOneIter();
  }

Guolin Ke's avatar
Guolin Ke committed
106
  void PrepareForPrediction(int num_iteration, int predict_type) {
Guolin Ke's avatar
Guolin Ke committed
107
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
108
    boosting_->SetNumIterationForPred(num_iteration);
Guolin Ke's avatar
Guolin Ke committed
109
110
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
111
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
112
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
113
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
114
      is_raw_score = true;
Guolin Ke's avatar
Guolin Ke committed
115
116
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
117
    }
Guolin Ke's avatar
Guolin Ke committed
118
    predictor_.reset(new Predictor(boosting_.get(), is_raw_score, is_predict_leaf));
Guolin Ke's avatar
Guolin Ke committed
119
120
  }

Guolin Ke's avatar
Guolin Ke committed
121
122
123
124
  void GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len) {
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

Guolin Ke's avatar
Guolin Ke committed
125
126
  std::vector<double> Predict(const std::vector<std::pair<int, double>>& features) {
    return predictor_->GetPredictFunction()(features);
127
128
  }

129
130
131
132
  void PredictForFile(const char* data_filename, const char* result_filename, bool data_has_header) {
    predictor_->Predict(data_filename, result_filename, data_has_header);
  }

Guolin Ke's avatar
Guolin Ke committed
133
  void SaveModelToFile(int num_iteration, const char* filename) {
Guolin Ke's avatar
Guolin Ke committed
134
    boosting_->SaveModelToFile(num_iteration, filename);
Guolin Ke's avatar
Guolin Ke committed
135
  }
Guolin Ke's avatar
Guolin Ke committed
136
137
138
139
140
141
142
143
144

  int GetEvalCounts() const {
    int ret = 0;
    for (const auto& metric : train_metric_) {
      ret += static_cast<int>(metric->GetName().size());
    }
    return ret;
  }

Guolin Ke's avatar
Guolin Ke committed
145
  int GetEvalNames(char** out_strs) const {
Guolin Ke's avatar
Guolin Ke committed
146
147
148
    int idx = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
Guolin Ke's avatar
Guolin Ke committed
149
        std::strcpy(out_strs[idx], name.c_str());
Guolin Ke's avatar
Guolin Ke committed
150
        ++idx;
Guolin Ke's avatar
Guolin Ke committed
151
152
153
154
155
      }
    }
    return idx;
  }

Guolin Ke's avatar
Guolin Ke committed
156
  const Boosting* GetBoosting() const { return boosting_.get(); }
Guolin Ke's avatar
Guolin Ke committed
157
  
Guolin Ke's avatar
Guolin Ke committed
158
private:
Guolin Ke's avatar
Guolin Ke committed
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182

  void ConstructObjectAndTrainingMetrics(const Dataset* train_data) {
    // create objective function
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
      config_.objective_config));
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective functions");
    }
    // create training metric
    train_metric_.clear();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(
        Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(train_data->metadata(), train_data->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data->metadata(), train_data->num_data());
    }
  }

183
  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
184
  std::unique_ptr<Boosting> boosting_;
Guolin Ke's avatar
Guolin Ke committed
185
186
187
  /*! \brief All configs */
  OverallConfig config_;
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
188
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
189
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
190
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
191
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
192
  std::unique_ptr<ObjectiveFunction> objective_fun_;
Guolin Ke's avatar
Guolin Ke committed
193
  /*! \brief Using predictor for prediction task */
Guolin Ke's avatar
Guolin Ke committed
194
  std::unique_ptr<Predictor> predictor_;
Guolin Ke's avatar
Guolin Ke committed
195
196
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
197
198
199
};

}
Guolin Ke's avatar
Guolin Ke committed
200
201
202

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
203
/*!
* \brief Return the string provided by LastErrorMsg().
* \note LastErrorMsg is declared outside this chunk; presumably it holds the
*       text of the last recorded error - confirm against its definition.
*/
DllExport const char* LGBM_GetLastError() {
  const char* last_message = LastErrorMsg();
  return last_message;
}

/*!
* \brief Load a dataset from a file.
* \param filename Path of the data file
* \param parameters Parameter string parsed into an IOConfig
* \param reference Optional pointer to an existing dataset handle; when non-null the
*        new dataset is loaded aligned with that dataset
* \param out Receives the handle of the loaded dataset
* \return Produced by API_END() (macro defined outside this chunk)
*/
DllExport int LGBM_CreateDatasetFromFile(const char* filename,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto parsed_params = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(parsed_params);
  DatasetLoader loader(io_config, nullptr);
  loader.SetHeader(filename);
  if (reference != nullptr) {
    const Dataset* aligned_with = reinterpret_cast<const Dataset*>(*reference);
    *out = loader.LoadFromFileAlignWithOtherDataset(filename, aligned_with);
  } else {
    *out = loader.LoadFromFile(filename);
  }
  API_END();
}

/*!
* \brief Load a dataset from a binary file using a default-constructed config.
* \param filename Path of the binary dataset file
* \param out Receives the handle of the loaded dataset
* \return Produced by API_END() (macro defined outside this chunk)
*/
DllExport int LGBM_CreateDatasetFromBinaryFile(const char* filename,
  DatesetHandle* out) {
  API_BEGIN();
  OverallConfig default_config;
  DatasetLoader bin_loader(default_config.io_config, nullptr);
  // NOTE(review): the trailing 0 and 1 look like rank / machine count - confirm in DatasetLoader
  *out = bin_loader.LoadFromBinFile(filename, 0, 1);
  API_END();
}

/*!
* \brief Build a dataset from a dense matrix.
*
* Without a reference dataset, rows are sampled first and the sampled non-zero
* values (|v| > 1e-15) per column are handed to the loader to construct the
* dataset; with a reference, the feature mappers are copied from it. All rows
* are then pushed in an OpenMP parallel loop.
*
* \param data Matrix storage, interpreted according to data_type
* \param data_type One of the C_API_DTYPE_* codes, consumed by RowFunctionFromDenseMatric
* \param nrow Number of rows
* \param ncol Number of columns
* \param is_row_major Non-zero when data is stored row-major
* \param parameters Parameter string parsed into an IOConfig
* \param reference Optional pointer to an existing dataset handle
* \param out Receives the handle of the new dataset (ownership released to the caller)
* \return Produced by API_END() (macro defined outside this chunk)
*/
DllExport int LGBM_CreateDatasetFromMat(const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto parsed_params = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(parsed_params);
  DatasetLoader loader(io_config, nullptr);
  std::unique_ptr<Dataset> dataset;
  auto fetch_row = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  if (reference != nullptr) {
    // reuse the bin mappers of the reference dataset
    dataset.reset(new Dataset(nrow, io_config.num_class));
    dataset->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(*reference),
      io_config.is_enable_sparse);
  } else {
    // sample rows, collect the non-zero values of every column
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(ncol);
    for (size_t s = 0; s < sample_indices.size(); ++s) {
      auto sampled_row = fetch_row(static_cast<int>(sample_indices[s]));
      for (size_t col = 0; col < sampled_row.size(); ++col) {
        const auto value = sampled_row[col];
        if (std::fabs(value) > 1e-15) {
          sample_values[col].push_back(value);
        }
      }
    }
    dataset.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    const int tid = omp_get_thread_num();
    auto current_row = fetch_row(i);
    dataset->PushOneRow(tid, i, current_row);
  }
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

284
285
/*!
* \brief Build a dataset from a CSR (row-compressed) sparse matrix.
*
* Without a reference dataset, rows are sampled and their non-zero entries
* (|v| > 1e-15) are grouped per feature index, growing the per-feature list on
* demand; with a reference, feature mappers are copied from it. All
* nindptr - 1 rows are then pushed in an OpenMP parallel loop.
*
* \param indptr Row pointer array, typed by indptr_type
* \param indices Column index of every stored element
* \param data Stored values, typed by data_type
* \param nindptr Length of indptr (number of rows + 1)
* \param nelem Number of stored elements
* \param num_col Number of columns; checked to be >= the largest sampled feature index + 1
* \param parameters Parameter string parsed into an IOConfig
* \param reference Optional pointer to an existing dataset handle
* \param out Receives the handle of the new dataset
* \return Produced by API_END() (macro defined outside this chunk)
*/
DllExport int LGBM_CreateDatasetFromCSR(const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t num_col,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto parsed_params = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(parsed_params);
  DatasetLoader loader(io_config, nullptr);
  std::unique_ptr<Dataset> dataset;
  auto fetch_row = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference != nullptr) {
    dataset.reset(new Dataset(nrow, io_config.num_class));
    dataset->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(*reference),
      io_config.is_enable_sparse);
  } else {
    // sample data first
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values;
    for (size_t s = 0; s < sample_indices.size(); ++s) {
      auto sampled_row = fetch_row(static_cast<int>(sample_indices[s]));
      for (std::pair<int, double>& entry : sampled_row) {
        if (std::fabs(entry.second) > 1e-15) {
          const size_t feature = static_cast<size_t>(entry.first);
          // grow the per-feature storage until this feature index is addressable
          while (sample_values.size() <= feature) {
            sample_values.emplace_back();
          }
          sample_values[feature].push_back(entry.second);
        }
      }
    }
    CHECK(num_col >= static_cast<int>(sample_values.size()));
    dataset.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nindptr - 1; ++i) {
    const int tid = omp_get_thread_num();
    auto current_row = fetch_row(i);
    dataset->PushOneRow(tid, i, current_row);
  }
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

346
347
/*!
* \brief Build a dataset from a CSC (column-compressed) sparse matrix.
*
* Without a reference dataset a warning is logged, row indices are sampled and
* each of the ncol_ptr - 1 columns is sampled in parallel via
* SampleFromOneColumn; with a reference, feature mappers are copied from it.
* All columns are then pushed in an OpenMP parallel loop.
*
* \param col_ptr Column pointer array, typed by col_ptr_type
* \param indices Row index of every stored element
* \param data Stored values, typed by data_type
* \param ncol_ptr Length of col_ptr (number of columns + 1)
* \param nelem Number of stored elements
* \param num_row Number of rows
* \param parameters Parameter string parsed into an IOConfig
* \param reference Optional pointer to an existing dataset handle
* \param out Receives the handle of the new dataset
* \return Produced by API_END() (macro defined outside this chunk)
*/
DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr,
  int col_ptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t ncol_ptr,
  int64_t nelem,
  int64_t num_row,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto parsed_params = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(parsed_params);
  DatasetLoader loader(io_config, nullptr);
  std::unique_ptr<Dataset> dataset;
  auto fetch_col = ColumnFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem);
  int32_t nrow = static_cast<int32_t>(num_row);
  if (reference != nullptr) {
    dataset.reset(new Dataset(nrow, io_config.num_class));
    dataset->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(*reference),
      io_config.is_enable_sparse);
  } else {
    Log::Warning("Construct from CSC format is not efficient");
    // sample data first
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
#pragma omp parallel for schedule(guided)
    for (int c = 0; c < static_cast<int>(sample_values.size()); ++c) {
      auto column = fetch_col(c);
      sample_values[c] = SampleFromOneColumn(column, sample_indices);
    }
    dataset.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

#pragma omp parallel for schedule(guided)
  for (int i = 0; i < ncol_ptr - 1; ++i) {
    const int tid = omp_get_thread_num();
    auto current_col = fetch_col(i);
    dataset->PushOneColumn(tid, i, current_col);
  }
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

396
/*!
* \brief Destroy the Dataset behind a handle.
* \param handle Dataset handle; it is deleted and must not be used afterwards
* \return Produced by API_END() (macro defined outside this chunk)
*/
DllExport int LGBM_DatasetFree(DatesetHandle handle) {
  API_BEGIN();
  Dataset* doomed = reinterpret_cast<Dataset*>(handle);
  delete doomed;
  API_END();
}

/*!
* \brief Write a dataset to a binary file via Dataset::SaveBinaryFile.
* \param handle Dataset handle
* \param filename Destination path
* \return Produced by API_END() (macro defined outside this chunk)
*/
DllExport int LGBM_DatasetSaveBinary(DatesetHandle handle,
  const char* filename) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->SaveBinaryFile(filename);
  API_END();
}

/*!
* \brief Set a named field of a dataset from a float32 or int32 array.
* \param handle Dataset handle
* \param field_name Name of the field
* \param field_data Array of num_element values of the given type
* \param num_element Number of values; narrowed to int32_t before use
* \param type C_API_DTYPE_FLOAT32 or C_API_DTYPE_INT32; any other value fails
* \return Produced by API_END(); a std::runtime_error is thrown inside the
*         API_BEGIN/API_END region when the type is unsupported or the setter
*         reports failure
*/
DllExport int LGBM_DatasetSetField(DatesetHandle handle,
  const char* field_name,
  const void* field_data,
  int64_t num_element,
  int type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  const int32_t count = static_cast<int32_t>(num_element);
  bool stored = false;
  if (type == C_API_DTYPE_FLOAT32) {
    const float* as_float = reinterpret_cast<const float*>(field_data);
    stored = dataset->SetFloatField(field_name, as_float, count);
  } else if (type == C_API_DTYPE_INT32) {
    const int* as_int = reinterpret_cast<const int*>(field_data);
    stored = dataset->SetIntField(field_name, as_int, count);
  }
  if (!stored) {
    throw std::runtime_error("Input data type erorr or field not found");
  }
  API_END();
}

/*!
* \brief Look up a named field of a dataset, trying float fields first and int fields second.
* \param handle Dataset handle
* \param field_name Name of the field
* \param out_len Receives the element count; forced to 0 when the returned pointer is null
* \param out_ptr Receives a pointer to the field data
* \param out_type Receives C_API_DTYPE_FLOAT32 or C_API_DTYPE_INT32
* \return Produced by API_END(); a std::runtime_error("Field not found") is thrown
*         inside the API_BEGIN/API_END region when neither lookup succeeds
*/
DllExport int LGBM_DatasetGetField(DatesetHandle handle,
  const char* field_name,
  int64_t* out_len,
  const void** out_ptr,
  int* out_type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  if (dataset->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (dataset->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else {
    throw std::runtime_error("Field not found");
  }
  // a field that exists but holds no data reports length zero
  if (*out_ptr == nullptr) {
    *out_len = 0;
  }
  API_END();
}

/*!
* \brief Report Dataset::num_data() for a dataset.
* \param handle Dataset handle
* \param out Receives the value
* \return Produced by API_END() (macro defined outside this chunk)
*/
DllExport int LGBM_DatasetGetNumData(DatesetHandle handle,
  int64_t* out) {
  API_BEGIN();
  *out = reinterpret_cast<Dataset*>(handle)->num_data();
  API_END();
}

/*!
* \brief Report Dataset::num_total_features() for a dataset.
* \param handle Dataset handle
* \param out Receives the value
* \return Produced by API_END() (macro defined outside this chunk)
*/
DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle,
  int64_t* out) {
  API_BEGIN();
  *out = reinterpret_cast<Dataset*>(handle)->num_total_features();
  API_END();
}
462
463
464
465
466
467
468


// ---- start of booster

/*!
* \brief Create a Booster bound to a training dataset.
* \param train_data Handle of the training dataset
* \param parameters Parameter string forwarded to the Booster constructor
* \param out Receives the new booster handle (ownership passes to the caller)
* \return Produced by API_END() (macro defined outside this chunk)
*/
DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
  const char* parameters,
  BoosterHandle* out) {
  API_BEGIN();
  // hold the object in a unique_ptr until the handle is handed out
  std::unique_ptr<Booster> booster(
    new Booster(reinterpret_cast<const Dataset*>(train_data), parameters));
  *out = booster.release();
  API_END();
}

476
/*!
* \brief Create a Booster from a model file.
* \param filename Path of the model file
* \param num_total_model Receives Boosting::NumberOfTotalModel() of the loaded model
* \param out Receives the new booster handle (ownership passes to the caller)
* \return Produced by API_END() (macro defined outside this chunk)
*/
DllExport int LGBM_BoosterCreateFromModelfile(
  const char* filename,
  int64_t* num_total_model,
  BoosterHandle* out) {
  API_BEGIN();
  std::unique_ptr<Booster> booster(new Booster(filename));
  const Boosting* model = booster->GetBoosting();
  *num_total_model = static_cast<int64_t>(model->NumberOfTotalModel());
  *out = booster.release();
  API_END();
}

/*!
* \brief Destroy the Booster behind a handle.
* \param handle Booster handle; it is deleted and must not be used afterwards
* \return Produced by API_END() (macro defined outside this chunk)
*/
DllExport int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  Booster* doomed = reinterpret_cast<Booster*>(handle);
  delete doomed;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
493
494
495
496
497
498
499
500
/*!
* \brief Merge the model of other_handle into handle via Booster::MergeFrom.
* \param handle Booster that receives the merge
* \param other_handle Booster whose model is merged in
* \return Produced by API_END() (macro defined outside this chunk)
*/
DllExport int LGBM_BoosterMerge(BoosterHandle handle,
  BoosterHandle other_handle) {
  API_BEGIN();
  Booster* target = reinterpret_cast<Booster*>(handle);
  Booster* source = reinterpret_cast<Booster*>(other_handle);
  target->MergeFrom(source);
  API_END();
}
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519

/*!
* \brief Attach a validation dataset to a booster via Booster::AddValidData.
* \param handle Booster handle
* \param valid_data Handle of the validation dataset
* \return Produced by API_END() (macro defined outside this chunk)
*/
DllExport int LGBM_BoosterAddValidData(BoosterHandle handle,
  const DatesetHandle valid_data) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->AddValidData(reinterpret_cast<const Dataset*>(valid_data));
  API_END();
}

/*!
* \brief Replace the training dataset of a booster via Booster::ResetTrainingData.
* \param handle Booster handle
* \param train_data Handle of the new training dataset
* \return Produced by API_END() (macro defined outside this chunk)
*/
DllExport int LGBM_BoosterResetTrainingData(BoosterHandle handle,
  const DatesetHandle train_data) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->ResetTrainingData(reinterpret_cast<const Dataset*>(train_data));
  API_END();
}

520
521
522
/*!
* \brief Apply a new parameter string to a booster via Booster::ResetConfig.
* \param handle Booster handle
* \param parameters Parameter string
* \return Produced by API_END() (macro defined outside this chunk)
*/
DllExport int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ResetConfig(parameters);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
527
528
529
530
531
532
533
/*!
* \brief Report Boosting::NumberOfClasses() of a booster.
* \param handle Booster handle
* \param out_len Receives the value
* \return Produced by API_END() (macro defined outside this chunk)
*/
DllExport int LGBM_BoosterGetNumClasses(BoosterHandle handle, int64_t* out_len) {
  API_BEGIN();
  const Boosting* model = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = model->NumberOfClasses();
  API_END();
}

534
/*!
* \brief Run one training iteration using the booster's own objective.
* \param handle Booster handle
* \param is_finished Receives 1 when Booster::TrainOneIter() returns true, otherwise 0
* \return Produced by API_END() (macro defined outside this chunk)
*/
DllExport int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  *is_finished = booster->TrainOneIter() ? 1 : 0;
  API_END();
}

/*!
* \brief Run one training iteration with caller-supplied gradients and hessians.
* \param handle Booster handle
* \param grad Gradient array
* \param hess Hessian array
* \param is_finished Receives 1 when Booster::TrainOneIter(grad, hess) returns true, otherwise 0
* \return Produced by API_END() (macro defined outside this chunk)
*/
DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
  const float* grad,
  const float* hess,
  int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  *is_finished = booster->TrainOneIter(grad, hess) ? 1 : 0;
  API_END();
}

559
560
561
562
563
564
565
566
567
568
569
570
571
/*!
* \brief Forward to Booster::RollbackOneIter for the given booster.
* \param handle Booster handle
* \return Produced by API_END() (macro defined outside this chunk)
*/
DllExport int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

/*!
* \brief Report Boosting::GetCurrentIteration() of a booster.
* \param handle Booster handle
* \param out_iteration Receives the value
* \return Produced by API_END() (macro defined outside this chunk)
*/
DllExport int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int64_t* out_iteration) {
  API_BEGIN();
  const Boosting* model = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_iteration = model->GetCurrentIteration();
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
/*!
* \brief Get the number of evaluation results a booster reports
* \param handle Booster handle
* \param out_len Receives Booster::GetEvalCounts()
* \return Produced by API_END() (macro defined outside this chunk)
*/
DllExport int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int64_t* out_len) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalCounts();
  API_END();
}

/*!
* \brief Get number of eval
* \return total number of eval result
*/
Guolin Ke's avatar
Guolin Ke committed
587
DllExport int LGBM_BoosterGetEvalNames(BoosterHandle handle, int64_t* out_len, char** out_strs) {
Guolin Ke's avatar
Guolin Ke committed
588
589
590
591
592
593
594
595
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  *out_len = ref_booster->GetEvalNames(out_strs);
  API_END();
}


/*!
* \brief Fetch the evaluation results for one dataset index.
* \param handle Booster handle
* \param data_idx Dataset index passed to Boosting::GetEvalAt
* \param out_len Receives the number of results
* \param out_results Caller-allocated array; every result is narrowed to float and
*        stored in order, so it must hold at least *out_len values
* \return Produced by API_END() (macro defined outside this chunk)
*/
DllExport int LGBM_BoosterGetEval(BoosterHandle handle,
  int data_idx,
  int64_t* out_len,
  float* out_results) {
  API_BEGIN();
  const Boosting* model = reinterpret_cast<Booster*>(handle)->GetBoosting();
  auto scores = model->GetEvalAt(data_idx);
  const size_t score_count = scores.size();
  *out_len = static_cast<int64_t>(score_count);
  for (size_t k = 0; k != score_count; ++k) {
    out_results[k] = static_cast<float>(scores[k]);
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
610
/*!
* \brief Fetch the predictions the booster holds for one dataset index.
* \param handle Booster handle
* \param data_idx Dataset index passed to Booster::GetPredictAt
* \param out_len Receives the number of values written
* \param out_result Caller-allocated array filled by Booster::GetPredictAt
* \return Produced by API_END() (macro defined outside this chunk)
*/
DllExport int LGBM_BoosterGetPredict(BoosterHandle handle,
  int data_idx,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  // GetPredictAt reports the length through an int, widen it afterwards
  int written = 0;
  reinterpret_cast<Booster*>(handle)->GetPredictAt(data_idx, out_result, &written);
  *out_len = static_cast<int64_t>(written);
  API_END();
}

622
623
/*!
* \brief Predict for a data file and write the result to another file.
* \param handle Booster handle
* \param data_filename Path of the input data file
* \param data_has_header Treated as true when greater than 0
* \param predict_type Prediction type forwarded to Booster::PrepareForPrediction
* \param num_iteration Number of iterations to use; narrowed to int
* \param result_filename Path of the output file
* \return Produced by API_END() (macro defined outside this chunk)
*/
DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle,
  const char* data_filename,
  int data_has_header,
  int predict_type,
  int64_t num_iteration,
  const char* result_filename) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);
  const bool has_header = data_has_header > 0;
  booster->PredictForFile(data_filename, result_filename, has_header);
  API_END();
}

636
/*!
* \brief Predict for every row of a CSR (row-compressed) sparse matrix.
*
* Each row yields a fixed number of values: NumberOfClasses() normally, or, for
* C_API_PREDICT_LEAF_INDEX, NumberOfClasses() times the number of iterations
* (num_iteration when positive, otherwise NumberOfTotalModel() / NumberOfClasses()).
* Row i is written starting at out_result[i * values_per_row].
*
* \param handle Booster handle
* \param indptr Row pointer array, typed by indptr_type
* \param indices Column index of every stored element
* \param data Stored values, typed by data_type
* \param nindptr Length of indptr (number of rows + 1)
* \param nelem Number of stored elements
* \param predict_type Prediction type forwarded to Booster::PrepareForPrediction
* \param num_iteration Number of iterations to use; narrowed to int
* \param out_len Receives rows * values_per_row
* \param out_result Caller-allocated array of at least *out_len floats
* \return Produced by API_END() (macro defined outside this chunk)
* \note The unnamed int64_t parameter is accepted but unused.
*/
DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
  const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t,
  int predict_type,
  int64_t num_iteration,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  ref_booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);

  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  int num_pred_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
  if (predict_type == C_API_PREDICT_LEAF_INDEX) {
    if (num_iteration > 0) {
      num_pred_in_one_row *= static_cast<int>(num_iteration);
    } else {
      num_pred_in_one_row *= ref_booster->GetBoosting()->NumberOfTotalModel() / num_pred_in_one_row;
    }
  }
  int nrow = static_cast<int>(nindptr - 1);
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    auto one_row = get_row_fun(i);
    auto prediction_result = ref_booster->Predict(one_row);
    // 64-bit offset: rows * values_per_row can exceed INT_MAX for leaf-index output
    float* row_out = out_result + static_cast<int64_t>(i) * num_pred_in_one_row;
    for (int j = 0; j < static_cast<int>(prediction_result.size()); ++j) {
      row_out[j] = static_cast<float>(prediction_result[j]);
    }
  }
  // widen before multiplying so the reported length cannot overflow int
  *out_len = static_cast<int64_t>(nrow) * num_pred_in_one_row;
  API_END();
}
674
675
676

/*!
* \brief Predict for every row of a dense matrix.
*
* Each row yields a fixed number of values: NumberOfClasses() normally, or, for
* C_API_PREDICT_LEAF_INDEX, NumberOfClasses() times the number of iterations
* (num_iteration when positive, otherwise NumberOfTotalModel() / NumberOfClasses()).
* Row i is written starting at out_result[i * values_per_row].
*
* \param handle Booster handle
* \param data Matrix storage, interpreted according to data_type
* \param data_type One of the C_API_DTYPE_* codes, consumed by RowPairFunctionFromDenseMatric
* \param nrow Number of rows
* \param ncol Number of columns
* \param is_row_major Non-zero when data is stored row-major
* \param predict_type Prediction type forwarded to Booster::PrepareForPrediction
* \param num_iteration Number of iterations to use; narrowed to int
* \param out_len Receives nrow * values_per_row
* \param out_result Caller-allocated array of at least *out_len floats
* \return Produced by API_END() (macro defined outside this chunk)
*/
DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
  const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  int predict_type,
  int64_t num_iteration,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  ref_booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);

  auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  int num_pred_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
  if (predict_type == C_API_PREDICT_LEAF_INDEX) {
    if (num_iteration > 0) {
      num_pred_in_one_row *= static_cast<int>(num_iteration);
    } else {
      num_pred_in_one_row *= ref_booster->GetBoosting()->NumberOfTotalModel() / num_pred_in_one_row;
    }
  }
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    auto one_row = get_row_fun(i);
    auto prediction_result = ref_booster->Predict(one_row);
    // 64-bit offset: nrow * values_per_row can exceed INT_MAX for leaf-index output
    float* row_out = out_result + static_cast<int64_t>(i) * num_pred_in_one_row;
    for (int j = 0; j < static_cast<int>(prediction_result.size()); ++j) {
      row_out[j] = static_cast<float>(prediction_result[j]);
    }
  }
  // widen before multiplying so the reported length cannot overflow int
  *out_len = static_cast<int64_t>(nrow) * num_pred_in_one_row;
  API_END();
}
709
710

// Writes the model owned by the booster behind `handle` to `filename`.
//
// handle        - opaque handle, reinterpreted as a Booster*.
// num_iteration - forwarded unchanged to Booster::SaveModelToFile.
// filename      - destination path, forwarded unchanged.
//
// The return value is produced by the API_BEGIN / API_END macro pair.
DllExport int LGBM_BoosterSaveModel(BoosterHandle handle,
  int num_iteration,
  const char* filename) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  ref_booster->SaveModelToFile(num_iteration, filename);
  API_END();
}
718

Guolin Ke's avatar
Guolin Ke committed
719
// ---- start of some help functions
720
721
722

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
723
  if (data_type == C_API_DTYPE_FLOAT32) {
724
725
726
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
727
        std::vector<double> ret(num_col);
728
729
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
730
          ret[i] = static_cast<double>(*(tmp_ptr + i));
731
732
733
734
735
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
736
        std::vector<double> ret(num_col);
737
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
738
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
739
740
741
742
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
743
  } else if (data_type == C_API_DTYPE_FLOAT64) {
744
745
746
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
747
        std::vector<double> ret(num_col);
748
749
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
750
          ret[i] = static_cast<double>(*(tmp_ptr + i));
751
752
753
754
755
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
756
        std::vector<double> ret(num_col);
757
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
758
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
759
760
761
762
763
        }
        return ret;
      };
    }
  }
Guolin Ke's avatar
Guolin Ke committed
764
  throw std::runtime_error("unknown data type in RowFunctionFromDenseMatric");
765
766
767
768
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
769
770
771
772
773
774
775
776
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
    return [inner_function](int row_idx) {
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
        if (std::fabs(raw_values[i]) > 1e-15) {
          ret.emplace_back(i, raw_values[i]);
777
        }
Guolin Ke's avatar
Guolin Ke committed
778
779
780
      }
      return ret;
    };
781
  }
Guolin Ke's avatar
Guolin Ke committed
782
  return nullptr;
783
784
785
786
}

std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
787
  if (data_type == C_API_DTYPE_FLOAT32) {
788
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
789
    if (indptr_type == C_API_DTYPE_INT32) {
790
791
792
793
794
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
795
        for (int64_t i = start; i < end; ++i) {
796
797
798
799
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
800
    } else if (indptr_type == C_API_DTYPE_INT64) {
801
802
803
804
805
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
806
        for (int64_t i = start; i < end; ++i) {
807
808
809
810
811
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
812
  } else if (data_type == C_API_DTYPE_FLOAT64) {
813
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
814
    if (indptr_type == C_API_DTYPE_INT32) {
815
816
817
818
819
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
820
        for (int64_t i = start; i < end; ++i) {
821
822
823
824
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
825
    } else if (indptr_type == C_API_DTYPE_INT64) {
826
827
828
829
830
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
831
        for (int64_t i = start; i < end; ++i) {
832
833
834
835
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
836
837
838
    } 
  } 
  throw std::runtime_error("unknown data type in RowFunctionFromCSR");
839
840
841
842
}

std::function<std::vector<std::pair<int, double>>(int idx)>
ColumnFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
843
  if (data_type == C_API_DTYPE_FLOAT32) {
844
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
845
    if (col_ptr_type == C_API_DTYPE_INT32) {
846
847
848
849
850
851
852
853
854
855
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
856
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
857
858
859
860
861
862
863
864
865
866
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
867
    } 
Guolin Ke's avatar
Guolin Ke committed
868
  } else if (data_type == C_API_DTYPE_FLOAT64) {
869
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
870
    if (col_ptr_type == C_API_DTYPE_INT32) {
871
872
873
874
875
876
877
878
879
880
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
881
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
882
883
884
885
886
887
888
889
890
891
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
892
893
894
    } 
  } 
  throw std::runtime_error("unknown data type in ColumnFunctionFromCSC");
895
896
}

897
// Picks, from one sparse column, the values that belong to the requested rows.
//
// data    - (row index, value) pairs of the column.
// indices - row indices to sample.
//
// Returns the values of those entries in `data` whose row index appears in
// `indices`, in the order of `indices`. Requested rows that have no entry in
// `data` contribute nothing, so the result may be shorter than `indices`.
//
// A single forward-only cursor walks `data`, so matches are only found when
// both `data` (by row index) and `indices` are in ascending order. The cursor
// is not advanced after a match, so a repeated index yields a repeated value.
std::vector<double> SampleFromOneColumn(const std::vector<std::pair<int, double>>& data, const std::vector<int>& indices) {
  std::vector<double> ret;
  auto cursor = data.begin();
  const auto data_end = data.end();
  for (const int row_idx : indices) {
    // skip column entries that lie before the requested row
    while (cursor != data_end && cursor->first < row_idx) {
      ++cursor;
    }
    if (cursor != data_end && cursor->first == row_idx) {
      ret.push_back(cursor->second);
    }
  }
  return ret;
}