c_api.cpp 30.4 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
#include <omp.h>

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/c_api.h>
Guolin Ke's avatar
Guolin Ke committed
6
#include <LightGBM/dataset_loader.h>
Guolin Ke's avatar
Guolin Ke committed
7
8
9
10
11
12
13
14
15
16
#include <LightGBM/dataset.h>
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/config.h>

#include <cstdio>
#include <vector>
#include <string>
#include <cstring>
Guolin Ke's avatar
Guolin Ke committed
17
#include <memory>
Guolin Ke's avatar
Guolin Ke committed
18
#include <stdexcept>
Guolin Ke's avatar
Guolin Ke committed
19
#include <mutex>
Guolin Ke's avatar
Guolin Ke committed
20

Guolin Ke's avatar
Guolin Ke committed
21
22
#include "./application/predictor.hpp"

Guolin Ke's avatar
Guolin Ke committed
23
24
25
26
namespace LightGBM {

class Booster {
public:
Guolin Ke's avatar
Guolin Ke committed
27
28
  explicit Booster(const char* filename) {
    boosting_.reset(Boosting::CreateBoosting(filename));
Guolin Ke's avatar
Guolin Ke committed
29
30
31
  }

  Booster(const Dataset* train_data, 
32
    const char* parameters) {
33
34
    auto param = ConfigBase::Str2Map(parameters);
    config_.Set(param);
Guolin Ke's avatar
Guolin Ke committed
35
36
    // create boosting
    if (config_.io_config.input_model.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
37
      Log::Warning("continued train from model is not support for c_api, \
Guolin Ke's avatar
Guolin Ke committed
38
39
        please use continued train with input score");
    }
Guolin Ke's avatar
Guolin Ke committed
40
    boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
41
42
43
44
45
46
    ConstructObjectAndTrainingMetrics(train_data);
    // initialize the boosting
    boosting_->Init(&config_.boosting_config, train_data, objective_fun_.get(),
      Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
  }

Guolin Ke's avatar
Guolin Ke committed
47
  void MergeFrom(const Booster* other) {
Guolin Ke's avatar
Guolin Ke committed
48
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
49
50
51
    boosting_->MergeFrom(other->boosting_.get());
  }

52
53
54
55
56
  ~Booster() {

  }

  void ResetTrainingData(const Dataset* train_data) {
Guolin Ke's avatar
Guolin Ke committed
57
    std::lock_guard<std::mutex> lock(mutex_);
58
59
    train_data_ = train_data;
    ConstructObjectAndTrainingMetrics(train_data_);
60
    // initialize the boosting
61
62
63
64
65
    boosting_->ResetTrainingData(&config_.boosting_config, train_data_, 
      objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
  }

  void ResetConfig(const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
66
    std::lock_guard<std::mutex> lock(mutex_);
67
68
69
70
71
72
73
74
75
    auto param = ConfigBase::Str2Map(parameters);
    if (param.count("num_class")) {
      Log::Fatal("cannot change num class during training");
    }
    if (param.count("boosting_type")) {
      Log::Fatal("cannot change boosting_type during training");
    }
    config_.Set(param);
    ResetTrainingData(train_data_);
76
  }
Guolin Ke's avatar
Guolin Ke committed
77

78
  void AddValidData(const Dataset* valid_data) {
Guolin Ke's avatar
Guolin Ke committed
79
    std::lock_guard<std::mutex> lock(mutex_);
80
81
82
83
84
85
86
87
88
89
    valid_metrics_.emplace_back();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
      Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
Guolin Ke's avatar
Guolin Ke committed
90
  }
91
  bool TrainOneIter() {
Guolin Ke's avatar
Guolin Ke committed
92
93
    std::lock_guard<std::mutex> lock(mutex_);
    return boosting_->TrainOneIter(nullptr, nullptr, false);
94
95
96
  }

  bool TrainOneIter(const float* gradients, const float* hessians) {
Guolin Ke's avatar
Guolin Ke committed
97
98
    std::lock_guard<std::mutex> lock(mutex_);
    return boosting_->TrainOneIter(gradients, hessians, false);
99
100
  }

Guolin Ke's avatar
Guolin Ke committed
101
  void RollbackOneIter() {
Guolin Ke's avatar
Guolin Ke committed
102
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
103
104
105
    boosting_->RollbackOneIter();
  }

Guolin Ke's avatar
Guolin Ke committed
106
  void PrepareForPrediction(int num_iteration, int predict_type) {
Guolin Ke's avatar
Guolin Ke committed
107
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
108
    boosting_->SetNumIterationForPred(num_iteration);
Guolin Ke's avatar
Guolin Ke committed
109
110
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
111
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
112
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
113
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
114
      is_raw_score = true;
Guolin Ke's avatar
Guolin Ke committed
115
116
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
117
    }
Guolin Ke's avatar
Guolin Ke committed
118
    predictor_.reset(new Predictor(boosting_.get(), is_raw_score, is_predict_leaf));
Guolin Ke's avatar
Guolin Ke committed
119
120
  }

Guolin Ke's avatar
Guolin Ke committed
121
122
123
124
  void GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len) {
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

Guolin Ke's avatar
Guolin Ke committed
125
126
  std::vector<double> Predict(const std::vector<std::pair<int, double>>& features) {
    return predictor_->GetPredictFunction()(features);
127
128
  }

129
130
131
132
  void PredictForFile(const char* data_filename, const char* result_filename, bool data_has_header) {
    predictor_->Predict(data_filename, result_filename, data_has_header);
  }

Guolin Ke's avatar
Guolin Ke committed
133
  void SaveModelToFile(int num_iteration, const char* filename) {
Guolin Ke's avatar
Guolin Ke committed
134
    boosting_->SaveModelToFile(num_iteration, filename);
Guolin Ke's avatar
Guolin Ke committed
135
  }
Guolin Ke's avatar
Guolin Ke committed
136
137
138
139
140
141
142
143
144

  int GetEvalCounts() const {
    int ret = 0;
    for (const auto& metric : train_metric_) {
      ret += static_cast<int>(metric->GetName().size());
    }
    return ret;
  }

Guolin Ke's avatar
Guolin Ke committed
145
  int GetEvalNames(char** out_strs) const {
Guolin Ke's avatar
Guolin Ke committed
146
147
148
    int idx = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
Guolin Ke's avatar
Guolin Ke committed
149
        std::strcpy(out_strs[idx], name.c_str());
Guolin Ke's avatar
Guolin Ke committed
150
        ++idx;
Guolin Ke's avatar
Guolin Ke committed
151
152
153
154
155
      }
    }
    return idx;
  }

Guolin Ke's avatar
Guolin Ke committed
156
  const Boosting* GetBoosting() const { return boosting_.get(); }
Guolin Ke's avatar
Guolin Ke committed
157
  
Guolin Ke's avatar
Guolin Ke committed
158
private:
Guolin Ke's avatar
Guolin Ke committed
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182

  void ConstructObjectAndTrainingMetrics(const Dataset* train_data) {
    // create objective function
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
      config_.objective_config));
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective functions");
    }
    // create training metric
    train_metric_.clear();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(
        Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(train_data->metadata(), train_data->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data->metadata(), train_data->num_data());
    }
  }

183
  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
184
  std::unique_ptr<Boosting> boosting_;
Guolin Ke's avatar
Guolin Ke committed
185
186
187
  /*! \brief All configs */
  OverallConfig config_;
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
188
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
189
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
190
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
191
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
192
  std::unique_ptr<ObjectiveFunction> objective_fun_;
Guolin Ke's avatar
Guolin Ke committed
193
  /*! \brief Using predictor for prediction task */
Guolin Ke's avatar
Guolin Ke committed
194
  std::unique_ptr<Predictor> predictor_;
Guolin Ke's avatar
Guolin Ke committed
195
196
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
197
198
199
};

}
Guolin Ke's avatar
Guolin Ke committed
200
201
202

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
203
/*!
* \brief Expose the last error message recorded by the library.
* \return the string obtained from LastErrorMsg()
*/
DllExport const char* LGBM_GetLastError() {
  const char* message = LastErrorMsg();
  return message;
}

Guolin Ke's avatar
Guolin Ke committed
207
/*!
* \brief Load a Dataset from a data file.
* \param filename path of the data file
* \param parameters parameter string parsed into an IOConfig
* \param reference optional dataset to align with; nullptr to load standalone
* \param out receives the new Dataset handle
*/
DllExport int LGBM_DatasetCreateFromFile(const char* filename,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto params = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(params);
  DatasetLoader loader(io_config, nullptr);
  loader.SetHeader(filename);
  if (reference != nullptr) {
    const Dataset* ref_dataset = reinterpret_cast<const Dataset*>(*reference);
    *out = loader.LoadFromFileAlignWithOtherDataset(filename, ref_dataset);
  } else {
    *out = loader.LoadFromFile(filename);
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
226
/*!
* \brief Build a Dataset from a dense matrix.
*
* Without a reference, bin mappers are constructed from a random sample of
* rows (only values with |v| > 1e-15 are sampled). With a reference, its
* feature mappers are copied. All rows are then pushed in parallel.
* \param data matrix buffer interpreted according to data_type
* \param is_row_major non-zero when rows are contiguous in memory
* \param reference optional dataset whose feature mappers are reused
* \param out receives the new Dataset handle
*/
DllExport int LGBM_DatasetCreateFromMat(const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto params = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(params);
  DatasetLoader loader(io_config, nullptr);
  std::unique_ptr<Dataset> dataset;
  auto fetch_row = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  if (reference != nullptr) {
    // reuse the feature mappers of the reference dataset
    dataset.reset(new Dataset(nrow, io_config.num_class));
    dataset->CopyFeatureMapperFrom(reinterpret_cast<const Dataset*>(*reference),
      io_config.is_enable_sparse);
  } else {
    // sample rows and collect non-zero values per column
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt
      ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(ncol);
    for (const auto row_idx : sample_indices) {
      auto sampled_row = fetch_row(static_cast<int>(row_idx));
      for (size_t col = 0; col < sampled_row.size(); ++col) {
        if (std::fabs(sampled_row[col]) > 1e-15) {
          sample_values[col].push_back(sampled_row[col]);
        }
      }
    }
    dataset.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

#pragma omp parallel for schedule(guided)
  for (int row_idx = 0; row_idx < nrow; ++row_idx) {
    const int thread_id = omp_get_thread_num();
    auto row_values = fetch_row(row_idx);
    dataset->PushOneRow(thread_id, row_idx, row_values);
  }
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
275
/*!
* \brief Build a Dataset from a CSR sparse matrix.
*
* Without a reference, bin mappers are constructed from a random sample of
* rows (only entries with |v| > 1e-15 are sampled). With a reference, its
* feature mappers are copied. All rows are then pushed in parallel.
* \param indptr row pointer array (nindptr entries), type given by indptr_type
* \param indices column index of each stored element
* \param data stored values, type given by data_type
* \param num_col number of columns of the matrix
* \param reference optional dataset whose feature mappers are reused
* \param out receives the new Dataset handle
*/
DllExport int LGBM_DatasetCreateFromCSR(const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t num_col,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
  DatasetLoader loader(io_config, nullptr);
  std::unique_ptr<Dataset> ret;
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference == nullptr) {
    // sample data first
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    // Size by the declared column count so that columns which are all-zero in
    // the sampled rows (in particular trailing ones) are still counted as
    // features. Otherwise the feature count would depend on which rows were
    // sampled, while rows pushed below may reference any column < num_col.
    // This matches LGBM_DatasetCreateFromMat, which always allocates ncol.
    std::vector<std::vector<double>> sample_values(
      num_col > 0 ? static_cast<size_t>(num_col) : 0);
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      auto row = get_row_fun(static_cast<int>(idx));
      for (const auto& inner_data : row) {
        if (std::fabs(inner_data.second) > 1e-15) {
          if (static_cast<size_t>(inner_data.first) >= sample_values.size()) {
            // index beyond num_col: grow so the CHECK below reports it
            sample_values.resize(static_cast<size_t>(inner_data.first) + 1);
          }
          sample_values[inner_data.first].push_back(inner_data.second);
        }
      }
    }
    CHECK(num_col >= static_cast<int64_t>(sample_values.size()));
    ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow, io_config.num_class));
    ret->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(*reference),
      io_config.is_enable_sparse);
  }

#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
  }
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
337
/*!
* \brief Build a Dataset from a CSC sparse matrix.
*
* Without a reference, bin mappers are built by sampling rows column by column
* (logs a warning that this path is not efficient). With a reference, its
* feature mappers are copied. All columns are then pushed in parallel.
* \param col_ptr column pointer array (ncol_ptr entries), type given by col_ptr_type
* \param indices row index of each stored element
* \param data stored values, type given by data_type
* \param num_row number of rows of the matrix
* \param reference optional dataset whose feature mappers are reused
* \param out receives the new Dataset handle
*/
DllExport int LGBM_DatasetCreateFromCSC(const void* col_ptr,
  int col_ptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t ncol_ptr,
  int64_t nelem,
  int64_t num_row,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto params = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(params);
  DatasetLoader loader(io_config, nullptr);
  std::unique_ptr<Dataset> dataset;
  auto fetch_col = ColumnFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem);
  int32_t nrow = static_cast<int32_t>(num_row);
  if (reference != nullptr) {
    // reuse the feature mappers of the reference dataset
    dataset.reset(new Dataset(nrow, io_config.num_class));
    dataset->CopyFeatureMapperFrom(reinterpret_cast<const Dataset*>(*reference),
      io_config.is_enable_sparse);
  } else {
    Log::Warning("Construct from CSC format is not efficient");
    // pick the sampled row indices once, then extract them from every column
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt
      ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
    const int num_sampled_cols = static_cast<int>(sample_values.size());
#pragma omp parallel for schedule(guided)
    for (int col = 0; col < num_sampled_cols; ++col) {
      auto column = fetch_col(col);
      sample_values[col] = SampleFromOneColumn(column, sample_indices);
    }
    dataset.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

#pragma omp parallel for schedule(guided)
  for (int col = 0; col < ncol_ptr - 1; ++col) {
    const int thread_id = omp_get_thread_num();
    auto column = fetch_col(col);
    dataset->PushOneColumn(thread_id, col, column);
  }
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
/*!
* \brief Create a new Dataset holding only the selected rows of full_data.
* \param full_data source dataset handle
* \param used_row_indices indices of the rows to keep
* \param num_used_row_indices length of used_row_indices
* \param parameters parameter string; is_enable_sparse is forwarded to Subset
* \param out receives the new Dataset handle
*/
DllExport int LGBM_DatasetGetSubset(
  const DatesetHandle* full_data,
  const int32_t* used_row_indices,
  const int32_t num_used_row_indices,
  const char* parameters,
  DatesetHandle* out) {
  API_BEGIN();
  auto params = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(params);
  const Dataset* source = reinterpret_cast<const Dataset*>(*full_data);
  std::unique_ptr<Dataset> subset(
    source->Subset(used_row_indices, num_used_row_indices, io_config.is_enable_sparse));
  subset->FinishLoad();
  *out = subset.release();
  API_END();
}

407
/*!
* \brief Destroy a Dataset previously returned through a DatesetHandle.
*/
DllExport int LGBM_DatasetFree(DatesetHandle handle) {
  API_BEGIN();
  Dataset* dataset = reinterpret_cast<Dataset*>(handle);
  delete dataset;
  API_END();
}

/*!
* \brief Write the Dataset to filename via Dataset::SaveBinaryFile.
*/
DllExport int LGBM_DatasetSaveBinary(DatesetHandle handle,
  const char* filename) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->SaveBinaryFile(filename);
  API_END();
}

/*!
* \brief Set a named field of a Dataset from a caller-provided array.
* \param field_name name of the field to set
* \param field_data array of num_element values, element type given by type
* \param num_element number of elements; must fit in int32_t
* \param type C_API_DTYPE_FLOAT32 or C_API_DTYPE_INT32
* \throws std::runtime_error on an out-of-range count, an unsupported type, or
*         an unknown field
*/
DllExport int LGBM_DatasetSetField(DatesetHandle handle,
  const char* field_name,
  const void* field_data,
  int64_t num_element,
  int type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  // The element count is narrowed to int32_t for the setters below; reject
  // negative counts and counts that would not survive that narrowing instead
  // of silently passing a truncated length.
  const int32_t len = static_cast<int32_t>(num_element);
  if (num_element < 0 || static_cast<int64_t>(len) != num_element) {
    throw std::runtime_error("Number of elements is out of range");
  }
  bool is_success = false;
  if (type == C_API_DTYPE_FLOAT32) {
    is_success = dataset->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), len);
  } else if (type == C_API_DTYPE_INT32) {
    is_success = dataset->SetIntField(field_name, reinterpret_cast<const int*>(field_data), len);
  }
  if (!is_success) { throw std::runtime_error("Input data type error or field not found"); }
  API_END();
}

/*!
* \brief Get a pointer to a named field of a Dataset.
*
* Tries the float fields first, then the int fields. When the field exists but
* its data pointer is null, out_len is forced to 0.
* \param out_len receives the number of elements
* \param out_ptr receives a pointer into the Dataset's storage
* \param out_type receives C_API_DTYPE_FLOAT32 or C_API_DTYPE_INT32
* \throws std::runtime_error when no field with that name exists
*/
DllExport int LGBM_DatasetGetField(DatesetHandle handle,
  const char* field_name,
  int64_t* out_len,
  const void** out_ptr,
  int* out_type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  if (dataset->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (dataset->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else {
    throw std::runtime_error("Field not found");
  }
  if (*out_ptr == nullptr) { *out_len = 0; }
  API_END();
}

/*!
* \brief Report the number of rows of a Dataset.
*/
DllExport int LGBM_DatasetGetNumData(DatesetHandle handle,
  int64_t* out) {
  API_BEGIN();
  *out = reinterpret_cast<Dataset*>(handle)->num_data();
  API_END();
}

/*!
* \brief Report the total number of features of a Dataset.
*/
DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle,
  int64_t* out) {
  API_BEGIN();
  *out = reinterpret_cast<Dataset*>(handle)->num_total_features();
  API_END();
}
473
474
475
476
477
478
479


// ---- start of booster

/*!
* \brief Create a Booster that trains on train_data with the given parameters.
* \param out receives the new Booster handle
*/
DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
  const char* parameters,
  BoosterHandle* out) {
  API_BEGIN();
  std::unique_ptr<Booster> booster(
    new Booster(reinterpret_cast<const Dataset*>(train_data), parameters));
  *out = booster.release();
  API_END();
}

487
/*!
* \brief Load a Booster from a model file.
* \param num_total_model receives the number of models in the loaded booster
* \param out receives the new Booster handle
*/
DllExport int LGBM_BoosterCreateFromModelfile(
  const char* filename,
  int64_t* num_total_model,
  BoosterHandle* out) {
  API_BEGIN();
  std::unique_ptr<Booster> booster(new Booster(filename));
  const Boosting* boosting = booster->GetBoosting();
  *num_total_model = static_cast<int64_t>(boosting->NumberOfTotalModel());
  *out = booster.release();
  API_END();
}

/*!
* \brief Destroy a Booster previously returned through a BoosterHandle.
*/
DllExport int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  delete booster;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
504
505
506
507
508
509
510
511
/*!
* \brief Merge the models of other_handle into handle.
*/
DllExport int LGBM_BoosterMerge(BoosterHandle handle,
  BoosterHandle other_handle) {
  API_BEGIN();
  Booster* target = reinterpret_cast<Booster*>(handle);
  target->MergeFrom(reinterpret_cast<Booster*>(other_handle));
  API_END();
}
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530

/*!
* \brief Attach a validation dataset to the Booster.
*/
DllExport int LGBM_BoosterAddValidData(BoosterHandle handle,
  const DatesetHandle valid_data) {
  API_BEGIN();
  auto booster = reinterpret_cast<Booster*>(handle);
  booster->AddValidData(reinterpret_cast<const Dataset*>(valid_data));
  API_END();
}

/*!
* \brief Replace the Booster's training dataset.
*/
DllExport int LGBM_BoosterResetTrainingData(BoosterHandle handle,
  const DatesetHandle train_data) {
  API_BEGIN();
  auto booster = reinterpret_cast<Booster*>(handle);
  booster->ResetTrainingData(reinterpret_cast<const Dataset*>(train_data));
  API_END();
}

531
532
533
/*!
* \brief Update the Booster configuration from a parameter string.
*/
DllExport int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ResetConfig(parameters);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
538
539
540
541
542
543
544
/*!
* \brief Report the number of classes of the Booster's model.
*/
DllExport int LGBM_BoosterGetNumClasses(BoosterHandle handle, int64_t* out_len) {
  API_BEGIN();
  const Boosting* boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->NumberOfClasses();
  API_END();
}

545
/*!
* \brief Run one boosting iteration with the built-in objective.
* \param is_finished set to 1 when TrainOneIter returns true, else 0
*/
DllExport int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  *is_finished = booster->TrainOneIter() ? 1 : 0;
  API_END();
}

/*!
* \brief Run one boosting iteration with caller-supplied gradients and hessians.
* \param is_finished set to 1 when TrainOneIter returns true, else 0
*/
DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
  const float* grad,
  const float* hess,
  int* is_finished) {
  API_BEGIN();
  const bool finished = reinterpret_cast<Booster*>(handle)->TrainOneIter(grad, hess);
  *is_finished = finished ? 1 : 0;
  API_END();
}

570
571
572
573
574
575
576
577
578
579
580
581
582
/*!
* \brief Undo the most recent boosting iteration.
*/
DllExport int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

/*!
* \brief Report the Booster's current iteration.
*/
DllExport int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int64_t* out_iteration) {
  API_BEGIN();
  const Boosting* boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_iteration = boosting->GetCurrentIteration();
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
/*!
* \brief Get the number of evaluation results for the training data, i.e. the
*        sum over all training metrics of the names each metric exposes.
* \param out_len receives the count
*/
DllExport int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int64_t* out_len) {
  API_BEGIN();
  const Booster* booster = reinterpret_cast<const Booster*>(handle);
  *out_len = booster->GetEvalCounts();
  API_END();
}

/*!
* \brief Get number of eval
* \return total number of eval result
*/
Guolin Ke's avatar
Guolin Ke committed
598
DllExport int LGBM_BoosterGetEvalNames(BoosterHandle handle, int64_t* out_len, char** out_strs) {
Guolin Ke's avatar
Guolin Ke committed
599
600
601
602
603
604
605
606
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  *out_len = ref_booster->GetEvalNames(out_strs);
  API_END();
}


/*!
* \brief Get the evaluation results for dataset data_idx.
* \param out_len receives the number of results
* \param out_results caller buffer receiving the results as float
*/
DllExport int LGBM_BoosterGetEval(BoosterHandle handle,
  int data_idx,
  int64_t* out_len,
  float* out_results) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto eval_values = booster->GetBoosting()->GetEvalAt(data_idx);
  *out_len = static_cast<int64_t>(eval_values.size());
  float* dst = out_results;
  for (const auto value : eval_values) {
    *dst++ = static_cast<float>(value);
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
621
/*!
* \brief Copy the current scores of dataset data_idx into out_result.
* \param out_len receives the number of values written
*/
DllExport int LGBM_BoosterGetPredict(BoosterHandle handle,
  int data_idx,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  int num_written = 0;
  booster->GetPredictAt(data_idx, out_result, &num_written);
  *out_len = static_cast<int64_t>(num_written);
  API_END();
}

633
634
/*!
* \brief Predict every row of data_filename and write results to result_filename.
* \param data_has_header values > 0 mean the data file has a header line
*/
DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle,
  const char* data_filename,
  int data_has_header,
  int predict_type,
  int64_t num_iteration,
  const char* result_filename) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);
  const bool has_header = (data_has_header > 0);
  booster->PredictForFile(data_filename, result_filename, has_header);
  API_END();
}

647
/*!
* \brief Predict every row of a CSR matrix into out_result.
*
* Each row gets a fixed slot of num_preb_in_one_row values: the number of
* classes, times the number of iterations for leaf-index prediction.
* \param out_len receives the total number of values (rows * slot size)
* \param out_result caller buffer of at least rows * slot size floats
*/
DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
  const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t,
  int predict_type,
  int64_t num_iteration,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  ref_booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);

  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  // 64-bit: for leaf-index output this is num_class * num_iteration and the
  // total (rows * slot) can easily exceed INT_MAX.
  int64_t num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
  if (predict_type == C_API_PREDICT_LEAF_INDEX) {
    if (num_iteration > 0) {
      num_preb_in_one_row *= static_cast<int>(num_iteration);
    } else {
      num_preb_in_one_row *= ref_booster->GetBoosting()->NumberOfTotalModel() / num_preb_in_one_row;
    }
  }
  int nrow = static_cast<int>(nindptr - 1);
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    auto one_row = get_row_fun(i);
    auto prediction_result = ref_booster->Predict(one_row);
    float* row_out = out_result + static_cast<int64_t>(i) * num_preb_in_one_row;
    for (size_t j = 0; j < prediction_result.size(); ++j) {
      row_out[j] = static_cast<float>(prediction_result[j]);
    }
  }
  *out_len = static_cast<int64_t>(nrow) * num_preb_in_one_row;
  API_END();
}
685
686
687

/*!
* \brief Predict every row of a dense matrix into out_result.
*
* Each row gets a fixed slot of num_preb_in_one_row values: the number of
* classes, times the number of iterations for leaf-index prediction.
* \param is_row_major non-zero when rows are contiguous in memory
* \param out_len receives the total number of values (nrow * slot size)
* \param out_result caller buffer of at least nrow * slot size floats
*/
DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
  const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  int predict_type,
  int64_t num_iteration,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  ref_booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);

  auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  // 64-bit: for leaf-index output this is num_class * num_iteration and the
  // total (nrow * slot) can easily exceed INT_MAX.
  int64_t num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
  if (predict_type == C_API_PREDICT_LEAF_INDEX) {
    if (num_iteration > 0) {
      num_preb_in_one_row *= static_cast<int>(num_iteration);
    } else {
      num_preb_in_one_row *= ref_booster->GetBoosting()->NumberOfTotalModel() / num_preb_in_one_row;
    }
  }
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    auto one_row = get_row_fun(i);
    auto prediction_result = ref_booster->Predict(one_row);
    float* row_out = out_result + static_cast<int64_t>(i) * num_preb_in_one_row;
    for (size_t j = 0; j < prediction_result.size(); ++j) {
      row_out[j] = static_cast<float>(prediction_result[j]);
    }
  }
  *out_len = static_cast<int64_t>(nrow) * num_preb_in_one_row;
  API_END();
}
720
721

/*!
 * \brief Write the booster's model to `filename`.
 * \param handle        Booster handle (reinterpreted as Booster*).
 * \param num_iteration Forwarded unchanged to Booster::SaveModelToFile.
 * \param filename      Destination path.
 */
DllExport int LGBM_BoosterSaveModel(BoosterHandle handle,
  int num_iteration,
  const char* filename) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->SaveModelToFile(num_iteration, filename);
  API_END();
}
729

Guolin Ke's avatar
Guolin Ke committed
730
// ---- start of some helper functions
731
732
733

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
734
  if (data_type == C_API_DTYPE_FLOAT32) {
735
736
737
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
738
        std::vector<double> ret(num_col);
739
740
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
741
          ret[i] = static_cast<double>(*(tmp_ptr + i));
742
743
744
745
746
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
747
        std::vector<double> ret(num_col);
748
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
749
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
750
751
752
753
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
754
  } else if (data_type == C_API_DTYPE_FLOAT64) {
755
756
757
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
758
        std::vector<double> ret(num_col);
759
760
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
761
          ret[i] = static_cast<double>(*(tmp_ptr + i));
762
763
764
765
766
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
767
        std::vector<double> ret(num_col);
768
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
769
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
770
771
772
773
774
        }
        return ret;
      };
    }
  }
Guolin Ke's avatar
Guolin Ke committed
775
  throw std::runtime_error("unknown data type in RowFunctionFromDenseMatric");
776
777
778
779
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
780
781
782
783
784
785
786
787
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
    return [inner_function](int row_idx) {
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
        if (std::fabs(raw_values[i]) > 1e-15) {
          ret.emplace_back(i, raw_values[i]);
788
        }
Guolin Ke's avatar
Guolin Ke committed
789
790
791
      }
      return ret;
    };
792
  }
Guolin Ke's avatar
Guolin Ke committed
793
  return nullptr;
794
795
796
797
}

std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
798
  if (data_type == C_API_DTYPE_FLOAT32) {
799
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
800
    if (indptr_type == C_API_DTYPE_INT32) {
801
802
803
804
805
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
806
        for (int64_t i = start; i < end; ++i) {
807
808
809
810
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
811
    } else if (indptr_type == C_API_DTYPE_INT64) {
812
813
814
815
816
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
817
        for (int64_t i = start; i < end; ++i) {
818
819
820
821
822
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
823
  } else if (data_type == C_API_DTYPE_FLOAT64) {
824
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
825
    if (indptr_type == C_API_DTYPE_INT32) {
826
827
828
829
830
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
831
        for (int64_t i = start; i < end; ++i) {
832
833
834
835
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
836
    } else if (indptr_type == C_API_DTYPE_INT64) {
837
838
839
840
841
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
842
        for (int64_t i = start; i < end; ++i) {
843
844
845
846
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
847
848
849
    } 
  } 
  throw std::runtime_error("unknown data type in RowFunctionFromCSR");
850
851
852
853
}

std::function<std::vector<std::pair<int, double>>(int idx)>
ColumnFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
854
  if (data_type == C_API_DTYPE_FLOAT32) {
855
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
856
    if (col_ptr_type == C_API_DTYPE_INT32) {
857
858
859
860
861
862
863
864
865
866
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
867
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
868
869
870
871
872
873
874
875
876
877
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
878
    } 
Guolin Ke's avatar
Guolin Ke committed
879
  } else if (data_type == C_API_DTYPE_FLOAT64) {
880
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
881
    if (col_ptr_type == C_API_DTYPE_INT32) {
882
883
884
885
886
887
888
889
890
891
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
892
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
893
894
895
896
897
898
899
900
901
902
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
903
904
905
    } 
  } 
  throw std::runtime_error("unknown data type in ColumnFunctionFromCSC");
906
907
}

908
/*!
 * \brief Pick the stored values of a sparse column at the requested rows.
 *
 * `data` holds (row index, value) pairs and `indices` the rows to sample.
 * A single forward-only cursor walks `data`, so correct results rely on both
 * sequences being ascending. Rows with no stored entry contribute nothing.
 * A repeated row index yields the same value again.
 */
std::vector<double> SampleFromOneColumn(const std::vector<std::pair<int, double>>& data, const std::vector<int>& indices) {
  std::vector<double> sampled;
  auto cursor = data.begin();
  const auto last = data.end();
  for (const int wanted_row : indices) {
    // Skip stored entries that lie before the wanted row; the cursor never moves back.
    while (cursor != last && cursor->first < wanted_row) {
      ++cursor;
    }
    if (cursor != last && cursor->first == wanted_row) {
      sampled.push_back(cursor->second);
    }
  }
  return sampled;
}