c_api.cpp 30.7 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
#include <omp.h>

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/c_api.h>
Guolin Ke's avatar
Guolin Ke committed
6
#include <LightGBM/dataset_loader.h>
Guolin Ke's avatar
Guolin Ke committed
7
8
9
10
11
12
13
14
15
16
#include <LightGBM/dataset.h>
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/config.h>

#include <cstdio>
#include <vector>
#include <string>
#include <cstring>
Guolin Ke's avatar
Guolin Ke committed
17
#include <memory>
Guolin Ke's avatar
Guolin Ke committed
18
#include <stdexcept>
Guolin Ke's avatar
Guolin Ke committed
19
#include <mutex>
Guolin Ke's avatar
Guolin Ke committed
20

Guolin Ke's avatar
Guolin Ke committed
21
22
#include "./application/predictor.hpp"

Guolin Ke's avatar
Guolin Ke committed
23
24
25
26
namespace LightGBM {

class Booster {
public:
Guolin Ke's avatar
Guolin Ke committed
27
28
  explicit Booster(const char* filename) {
    boosting_.reset(Boosting::CreateBoosting(filename));
Guolin Ke's avatar
Guolin Ke committed
29
30
31
  }

  Booster(const Dataset* train_data, 
32
    const char* parameters) {
33
34
    auto param = ConfigBase::Str2Map(parameters);
    config_.Set(param);
Guolin Ke's avatar
Guolin Ke committed
35
36
    // create boosting
    if (config_.io_config.input_model.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
37
      Log::Warning("continued train from model is not support for c_api, \
Guolin Ke's avatar
Guolin Ke committed
38
39
        please use continued train with input score");
    }
Guolin Ke's avatar
Guolin Ke committed
40
    boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
41
42
43
44
45
46
    ConstructObjectAndTrainingMetrics(train_data);
    // initialize the boosting
    boosting_->Init(&config_.boosting_config, train_data, objective_fun_.get(),
      Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
  }

Guolin Ke's avatar
Guolin Ke committed
47
  void MergeFrom(const Booster* other) {
Guolin Ke's avatar
Guolin Ke committed
48
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
49
50
51
    boosting_->MergeFrom(other->boosting_.get());
  }

52
53
54
55
56
  ~Booster() {

  }

  void ResetTrainingData(const Dataset* train_data) {
Guolin Ke's avatar
Guolin Ke committed
57
    std::lock_guard<std::mutex> lock(mutex_);
58
59
    train_data_ = train_data;
    ConstructObjectAndTrainingMetrics(train_data_);
60
    // initialize the boosting
61
62
63
64
65
    boosting_->ResetTrainingData(&config_.boosting_config, train_data_, 
      objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
  }

  void ResetConfig(const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
66
    std::lock_guard<std::mutex> lock(mutex_);
67
68
69
70
71
72
73
74
    auto param = ConfigBase::Str2Map(parameters);
    if (param.count("num_class")) {
      Log::Fatal("cannot change num class during training");
    }
    if (param.count("boosting_type")) {
      Log::Fatal("cannot change boosting_type during training");
    }
    config_.Set(param);
75
76
77
78
79
80
    if (param.size() == 1 && (param.count("learning_rate") || param.count("shrinkage_rate"))) {
      // only need to set learning rate
      boosting_->ResetShrinkageRate(config_.boosting_config.learning_rate);
    } else {
      ResetTrainingData(train_data_);
    }
81
  }
Guolin Ke's avatar
Guolin Ke committed
82

83
  void AddValidData(const Dataset* valid_data) {
Guolin Ke's avatar
Guolin Ke committed
84
    std::lock_guard<std::mutex> lock(mutex_);
85
86
87
88
89
90
91
92
93
94
    valid_metrics_.emplace_back();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
      Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
Guolin Ke's avatar
Guolin Ke committed
95
  }
96
  bool TrainOneIter() {
Guolin Ke's avatar
Guolin Ke committed
97
98
    std::lock_guard<std::mutex> lock(mutex_);
    return boosting_->TrainOneIter(nullptr, nullptr, false);
99
100
101
  }

  bool TrainOneIter(const float* gradients, const float* hessians) {
Guolin Ke's avatar
Guolin Ke committed
102
103
    std::lock_guard<std::mutex> lock(mutex_);
    return boosting_->TrainOneIter(gradients, hessians, false);
104
105
  }

Guolin Ke's avatar
Guolin Ke committed
106
  void RollbackOneIter() {
Guolin Ke's avatar
Guolin Ke committed
107
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
108
109
110
    boosting_->RollbackOneIter();
  }

Guolin Ke's avatar
Guolin Ke committed
111
  void PrepareForPrediction(int num_iteration, int predict_type) {
Guolin Ke's avatar
Guolin Ke committed
112
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
113
    boosting_->SetNumIterationForPred(num_iteration);
Guolin Ke's avatar
Guolin Ke committed
114
115
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
116
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
117
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
118
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
119
      is_raw_score = true;
Guolin Ke's avatar
Guolin Ke committed
120
121
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
122
    }
Guolin Ke's avatar
Guolin Ke committed
123
    predictor_.reset(new Predictor(boosting_.get(), is_raw_score, is_predict_leaf));
Guolin Ke's avatar
Guolin Ke committed
124
125
  }

Guolin Ke's avatar
Guolin Ke committed
126
127
128
129
  void GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len) {
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

Guolin Ke's avatar
Guolin Ke committed
130
131
  std::vector<double> Predict(const std::vector<std::pair<int, double>>& features) {
    return predictor_->GetPredictFunction()(features);
132
133
  }

134
135
136
137
  void PredictForFile(const char* data_filename, const char* result_filename, bool data_has_header) {
    predictor_->Predict(data_filename, result_filename, data_has_header);
  }

Guolin Ke's avatar
Guolin Ke committed
138
  void SaveModelToFile(int num_iteration, const char* filename) {
Guolin Ke's avatar
Guolin Ke committed
139
    boosting_->SaveModelToFile(num_iteration, filename);
Guolin Ke's avatar
Guolin Ke committed
140
  }
Guolin Ke's avatar
Guolin Ke committed
141
142
143
144
145
146
147
148
149

  int GetEvalCounts() const {
    int ret = 0;
    for (const auto& metric : train_metric_) {
      ret += static_cast<int>(metric->GetName().size());
    }
    return ret;
  }

Guolin Ke's avatar
Guolin Ke committed
150
  int GetEvalNames(char** out_strs) const {
Guolin Ke's avatar
Guolin Ke committed
151
152
153
    int idx = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
Guolin Ke's avatar
Guolin Ke committed
154
        std::strcpy(out_strs[idx], name.c_str());
Guolin Ke's avatar
Guolin Ke committed
155
        ++idx;
Guolin Ke's avatar
Guolin Ke committed
156
157
158
159
160
      }
    }
    return idx;
  }

Guolin Ke's avatar
Guolin Ke committed
161
  const Boosting* GetBoosting() const { return boosting_.get(); }
Guolin Ke's avatar
Guolin Ke committed
162
  
Guolin Ke's avatar
Guolin Ke committed
163
private:
Guolin Ke's avatar
Guolin Ke committed
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187

  void ConstructObjectAndTrainingMetrics(const Dataset* train_data) {
    // create objective function
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
      config_.objective_config));
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective functions");
    }
    // create training metric
    train_metric_.clear();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(
        Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(train_data->metadata(), train_data->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data->metadata(), train_data->num_data());
    }
  }

188
  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
189
  std::unique_ptr<Boosting> boosting_;
Guolin Ke's avatar
Guolin Ke committed
190
191
192
  /*! \brief All configs */
  OverallConfig config_;
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
193
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
194
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
195
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
196
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
197
  std::unique_ptr<ObjectiveFunction> objective_fun_;
Guolin Ke's avatar
Guolin Ke committed
198
  /*! \brief Using predictor for prediction task */
Guolin Ke's avatar
Guolin Ke committed
199
  std::unique_ptr<Predictor> predictor_;
Guolin Ke's avatar
Guolin Ke committed
200
201
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
202
203
204
};

}
Guolin Ke's avatar
Guolin Ke committed
205
206
207

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
208
/*! \brief Return the message produced by LastErrorMsg() */
DllExport const char* LGBM_GetLastError() {
  const char* last_message = LastErrorMsg();
  return last_message;
}

Guolin Ke's avatar
Guolin Ke committed
212
/*!
* \brief Load a data set from a file
* \param filename Path of the data file
* \param parameters Parameter string used to build the IOConfig
* \param reference Optional data set to align with; nullptr loads stand-alone
* \param out Receives the handle of the loaded data set
*/
DllExport int LGBM_DatasetCreateFromFile(const char* filename,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto param_map = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param_map);
  DatasetLoader loader(io_config, nullptr);
  loader.SetHeader(filename);
  if (reference != nullptr) {
    // align the new data set with the one behind the reference handle
    const Dataset* aligned_with = reinterpret_cast<const Dataset*>(*reference);
    *out = loader.LoadFromFileAlignWithOtherDataset(filename, aligned_with);
  } else {
    *out = loader.LoadFromFile(filename);
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
231
/*!
* \brief Build a data set from a dense matrix
* \param data Pointer to the matrix values
* \param data_type Element type code passed to RowFunctionFromDenseMatric
* \param nrow Number of rows
* \param ncol Number of columns
* \param is_row_major Layout flag passed to RowFunctionFromDenseMatric
* \param parameters Parameter string used to build the IOConfig
* \param reference Optional data set whose feature mappers are copied
* \param out Receives the handle of the new data set
*/
DllExport int LGBM_DatasetCreateFromMat(const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto param_map = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param_map);
  DatasetLoader loader(io_config, nullptr);
  std::unique_ptr<Dataset> dataset;
  auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  if (reference != nullptr) {
    // reuse the feature mappers of the reference data set
    dataset.reset(new Dataset(nrow, io_config.num_class));
    dataset->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(*reference),
      io_config.is_enable_sparse);
  } else {
    // sample rows first, keeping the values whose magnitude exceeds 1e-15
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(ncol);
    for (auto sampled_row : sample_indices) {
      auto row = get_row_fun(static_cast<int>(sampled_row));
      for (size_t col = 0; col < row.size(); ++col) {
        if (std::fabs(row[col]) > 1e-15) {
          sample_values[col].push_back(row[col]);
        }
      }
    }
    dataset.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

#pragma omp parallel for schedule(guided)
  for (int row_idx = 0; row_idx < nrow; ++row_idx) {
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(row_idx);
    dataset->PushOneRow(tid, row_idx, one_row);
  }
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
280
/*!
* \brief Build a data set from a CSR sparse matrix
* \param indptr Row pointer array
* \param indptr_type Element type code of indptr
* \param indices Column indices
* \param data Non-zero values
* \param data_type Element type code of data
* \param nindptr Length of indptr; the number of rows is nindptr - 1
* \param nelem Number of stored elements
* \param num_col Number of columns; checked against the sampled feature count
* \param parameters Parameter string used to build the IOConfig
* \param reference Optional data set whose feature mappers are copied
* \param out Receives the handle of the new data set
*/
DllExport int LGBM_DatasetCreateFromCSR(const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t num_col,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto param_map = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param_map);
  DatasetLoader loader(io_config, nullptr);
  std::unique_ptr<Dataset> dataset;
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference != nullptr) {
    // reuse the feature mappers of the reference data set
    dataset.reset(new Dataset(nrow, io_config.num_class));
    dataset->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(*reference),
      io_config.is_enable_sparse);
  } else {
    // sample rows first
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values;
    for (auto sampled_row : sample_indices) {
      auto row = get_row_fun(static_cast<int>(sampled_row));
      for (std::pair<int, double>& entry : row) {
        // only values whose magnitude exceeds 1e-15 are sampled
        if (!(std::fabs(entry.second) > 1e-15)) { continue; }
        // grow the per-feature lists until the feature index is addressable
        while (sample_values.size() <= static_cast<size_t>(entry.first)) {
          sample_values.emplace_back();
        }
        sample_values[entry.first].push_back(entry.second);
      }
    }
    CHECK(num_col >= static_cast<int>(sample_values.size()));
    dataset.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

#pragma omp parallel for schedule(guided)
  for (int row_idx = 0; row_idx < nindptr - 1; ++row_idx) {
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(row_idx);
    dataset->PushOneRow(tid, row_idx, one_row);
  }
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
342
/*!
* \brief Build a data set from a CSC sparse matrix
* \param col_ptr Column pointer array
* \param col_ptr_type Element type code of col_ptr
* \param indices Row indices
* \param data Non-zero values
* \param data_type Element type code of data
* \param ncol_ptr Length of col_ptr; the number of columns is ncol_ptr - 1
* \param nelem Number of stored elements
* \param num_row Number of rows
* \param parameters Parameter string used to build the IOConfig
* \param reference Optional data set whose feature mappers are copied
* \param out Receives the handle of the new data set
*/
DllExport int LGBM_DatasetCreateFromCSC(const void* col_ptr,
  int col_ptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t ncol_ptr,
  int64_t nelem,
  int64_t num_row,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto param_map = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param_map);
  DatasetLoader loader(io_config, nullptr);
  std::unique_ptr<Dataset> dataset;
  auto get_col_fun = ColumnFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem);
  int32_t nrow = static_cast<int32_t>(num_row);
  if (reference != nullptr) {
    // reuse the feature mappers of the reference data set
    dataset.reset(new Dataset(nrow, io_config.num_class));
    dataset->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(*reference),
      io_config.is_enable_sparse);
  } else {
    Log::Warning("Construct from CSC format is not efficient");
    // sample data first
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
    const int num_sampled_cols = static_cast<int>(sample_values.size());
#pragma omp parallel for schedule(guided)
    for (int col_idx = 0; col_idx < num_sampled_cols; ++col_idx) {
      auto cur_col = get_col_fun(col_idx);
      sample_values[col_idx] = SampleFromOneColumn(cur_col, sample_indices);
    }
    dataset.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

#pragma omp parallel for schedule(guided)
  for (int col_idx = 0; col_idx < ncol_ptr - 1; ++col_idx) {
    const int tid = omp_get_thread_num();
    auto one_col = get_col_fun(col_idx);
    dataset->PushOneColumn(tid, col_idx, one_col);
  }
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
392
/*!
* \brief Create a new data set from selected rows of an existing one
* \param handle Pointer to the handle of the full data set
* \param used_row_indices Indices of the rows to keep
* \param num_used_row_indices Number of entries in used_row_indices
* \param parameters Parameter string used to build the IOConfig
* \param out Receives the handle of the subset
*/
DllExport int LGBM_DatasetGetSubset(
  const DatesetHandle* handle,
  const int32_t* used_row_indices,
  int32_t num_used_row_indices,
  const char* parameters,
  DatesetHandle* out) {
  API_BEGIN();
  auto param_map = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param_map);
  const Dataset* source = reinterpret_cast<const Dataset*>(*handle);
  std::unique_ptr<Dataset> subset(
    source->Subset(used_row_indices, num_used_row_indices, io_config.is_enable_sparse));
  subset->FinishLoad();
  *out = subset.release();
  API_END();
}

412
/*!
* \brief Destroy the data set behind a handle
* \param handle Handle of the data set to delete
*/
DllExport int LGBM_DatasetFree(DatesetHandle handle) {
  API_BEGIN();
  Dataset* dataset = reinterpret_cast<Dataset*>(handle);
  delete dataset;
  API_END();
}

/*!
* \brief Write a data set to a binary file
* \param handle Handle of the data set
* \param filename Path passed to Dataset::SaveBinaryFile
*/
DllExport int LGBM_DatasetSaveBinary(DatesetHandle handle,
  const char* filename) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->SaveBinaryFile(filename);
  API_END();
}

/*!
* \brief Set a named field of a data set from a float32 or int32 array
* \param handle Handle of the data set
* \param field_name Name of the field
* \param field_data Pointer to the values
* \param num_element Number of values
* \param type C_API_DTYPE_FLOAT32 or C_API_DTYPE_INT32; any other value is an error
*/
DllExport int LGBM_DatasetSetField(DatesetHandle handle,
  const char* field_name,
  const void* field_data,
  int64_t num_element,
  int type) {
  API_BEGIN();
  Dataset* dataset = reinterpret_cast<Dataset*>(handle);
  bool field_was_set = false;
  if (type == C_API_DTYPE_FLOAT32) {
    const float* values = reinterpret_cast<const float*>(field_data);
    field_was_set = dataset->SetFloatField(field_name, values, static_cast<int32_t>(num_element));
  } else if (type == C_API_DTYPE_INT32) {
    const int* values = reinterpret_cast<const int*>(field_data);
    field_was_set = dataset->SetIntField(field_name, values, static_cast<int32_t>(num_element));
  }
  if (!field_was_set) {
    throw std::runtime_error("Input data type erorr or field not found");
  }
  API_END();
}

/*!
* \brief Look up a named field of a data set, trying float fields before int fields
* \param handle Handle of the data set
* \param field_name Name of the field
* \param out_len Receives the number of values; set to 0 when the pointer is null
* \param out_ptr Receives a pointer to the values
* \param out_type Receives C_API_DTYPE_FLOAT32 or C_API_DTYPE_INT32
*/
DllExport int LGBM_DatasetGetField(DatesetHandle handle,
  const char* field_name,
  int64_t* out_len,
  const void** out_ptr,
  int* out_type) {
  API_BEGIN();
  Dataset* dataset = reinterpret_cast<Dataset*>(handle);
  bool found = dataset->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr));
  if (found) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else {
    // not a float field: fall back to the int fields
    found = dataset->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr));
    if (found) {
      *out_type = C_API_DTYPE_INT32;
    }
  }
  if (!found) {
    throw std::runtime_error("Field not found");
  }
  if (*out_ptr == nullptr) {
    *out_len = 0;
  }
  API_END();
}

/*!
* \brief Report Dataset::num_data() for a data set
* \param handle Handle of the data set
* \param out Receives the value
*/
DllExport int LGBM_DatasetGetNumData(DatesetHandle handle,
  int64_t* out) {
  API_BEGIN();
  *out = reinterpret_cast<Dataset*>(handle)->num_data();
  API_END();
}

/*!
* \brief Report Dataset::num_total_features() for a data set
* \param handle Handle of the data set
* \param out Receives the value
*/
DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle,
  int64_t* out) {
  API_BEGIN();
  *out = reinterpret_cast<Dataset*>(handle)->num_total_features();
  API_END();
}
478
479
480
481
482
483
484


// ---- start of booster

/*!
* \brief Create a booster for training
* \param train_data Handle of the training data set
* \param parameters Parameter string passed to the Booster constructor
* \param out Receives the handle of the new booster
*/
DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
  const char* parameters,
  BoosterHandle* out) {
  API_BEGIN();
  const Dataset* training_set = reinterpret_cast<const Dataset*>(train_data);
  std::unique_ptr<Booster> booster(new Booster(training_set, parameters));
  *out = booster.release();
  API_END();
}

492
/*!
* \brief Create a booster from a saved model file
* \param filename Path of the model file
* \param out_num_iterations Receives NumberOfTotalModel() / NumberOfClasses()
* \param out Receives the handle of the new booster
*/
DllExport int LGBM_BoosterCreateFromModelfile(
  const char* filename,
  int64_t* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
  std::unique_ptr<Booster> booster(new Booster(filename));
  const Boosting* boosting = booster->GetBoosting();
  *out_num_iterations = static_cast<int64_t>(
    boosting->NumberOfTotalModel() / boosting->NumberOfClasses());
  *out = booster.release();
  API_END();
}

/*!
* \brief Destroy the booster behind a handle
* \param handle Handle of the booster to delete
*/
DllExport int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  delete booster;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
510
511
512
513
514
515
516
517
/*!
* \brief Merge the model of one booster into another
* \param handle Handle of the booster that receives the merge
* \param other_handle Handle of the booster passed to Booster::MergeFrom
*/
DllExport int LGBM_BoosterMerge(BoosterHandle handle,
  BoosterHandle other_handle) {
  API_BEGIN();
  Booster* target = reinterpret_cast<Booster*>(handle);
  Booster* source = reinterpret_cast<Booster*>(other_handle);
  target->MergeFrom(source);
  API_END();
}
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536

/*!
* \brief Attach a validation data set to a booster
* \param handle Handle of the booster
* \param valid_data Handle of the validation data set
*/
DllExport int LGBM_BoosterAddValidData(BoosterHandle handle,
  const DatesetHandle valid_data) {
  API_BEGIN();
  const Dataset* validation_set = reinterpret_cast<const Dataset*>(valid_data);
  reinterpret_cast<Booster*>(handle)->AddValidData(validation_set);
  API_END();
}

/*!
* \brief Replace the training data of a booster
* \param handle Handle of the booster
* \param train_data Handle of the new training data set
*/
DllExport int LGBM_BoosterResetTrainingData(BoosterHandle handle,
  const DatesetHandle train_data) {
  API_BEGIN();
  const Dataset* training_set = reinterpret_cast<const Dataset*>(train_data);
  reinterpret_cast<Booster*>(handle)->ResetTrainingData(training_set);
  API_END();
}

537
538
539
/*!
* \brief Apply a new parameter string to a booster
* \param handle Handle of the booster
* \param parameters Parameter string passed to Booster::ResetConfig
*/
DllExport int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ResetConfig(parameters);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
544
545
546
547
548
549
550
/*!
* \brief Report Boosting::NumberOfClasses() for a booster
* \param handle Handle of the booster
* \param out_len Receives the value
*/
DllExport int LGBM_BoosterGetNumClasses(BoosterHandle handle, int64_t* out_len) {
  API_BEGIN();
  const Boosting* boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->NumberOfClasses();
  API_END();
}

551
/*!
* \brief Run one training iteration
* \param handle Handle of the booster
* \param is_finished Receives 1 when Booster::TrainOneIter returns true, otherwise 0
*/
DllExport int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const bool finished = booster->TrainOneIter();
  *is_finished = finished ? 1 : 0;
  API_END();
}

/*!
* \brief Run one training iteration with caller-supplied gradients and hessians
* \param handle Handle of the booster
* \param grad Gradients passed to Booster::TrainOneIter
* \param hess Hessians passed to Booster::TrainOneIter
* \param is_finished Receives 1 when Booster::TrainOneIter returns true, otherwise 0
*/
DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
  const float* grad,
  const float* hess,
  int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const bool finished = booster->TrainOneIter(grad, hess);
  *is_finished = finished ? 1 : 0;
  API_END();
}

576
577
578
579
580
581
582
583
584
585
586
587
588
/*!
* \brief Forward to Booster::RollbackOneIter
* \param handle Handle of the booster
*/
DllExport int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

/*!
* \brief Report Boosting::GetCurrentIteration() for a booster
* \param handle Handle of the booster
* \param out_iteration Receives the value
*/
DllExport int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int64_t* out_iteration) {
  API_BEGIN();
  const Boosting* boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_iteration = boosting->GetCurrentIteration();
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
/*!
* \brief Get the number of evaluation results
* \param handle Handle of the booster
* \param out_len Receives the value of Booster::GetEvalCounts()
*/
DllExport int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int64_t* out_len) {
  API_BEGIN();
  const int eval_count = reinterpret_cast<Booster*>(handle)->GetEvalCounts();
  *out_len = eval_count;
  API_END();
}

/*!
* \brief Get number of eval
* \return total number of eval result
*/
Guolin Ke's avatar
Guolin Ke committed
604
DllExport int LGBM_BoosterGetEvalNames(BoosterHandle handle, int64_t* out_len, char** out_strs) {
Guolin Ke's avatar
Guolin Ke committed
605
606
607
608
609
610
611
612
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  *out_len = ref_booster->GetEvalNames(out_strs);
  API_END();
}


/*!
* \brief Copy the evaluation results of one data set into a float buffer
* \param handle Handle of the booster
* \param data_idx Index passed to Boosting::GetEvalAt
* \param out_len Receives the number of results
* \param out_results Caller-provided buffer that receives the results as float
*/
DllExport int LGBM_BoosterGetEval(BoosterHandle handle,
  int data_idx,
  int64_t* out_len,
  float* out_results) {
  API_BEGIN();
  const Boosting* boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  auto eval_values = boosting->GetEvalAt(data_idx);
  const size_t num_values = eval_values.size();
  *out_len = static_cast<int64_t>(num_values);
  for (size_t k = 0; k < num_values; ++k) {
    out_results[k] = static_cast<float>(eval_values[k]);
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
627
/*!
* \brief Fetch the predictions that the booster holds for one data set
* \param handle Handle of the booster
* \param data_idx Index passed to Booster::GetPredictAt
* \param out_len Receives the length reported by Booster::GetPredictAt
* \param out_result Caller-provided buffer handed to Booster::GetPredictAt
*/
DllExport int LGBM_BoosterGetPredict(BoosterHandle handle,
  int data_idx,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  // GetPredictAt reports the length through an int; widen it afterwards
  int num_written = 0;
  booster->GetPredictAt(data_idx, out_result, &num_written);
  *out_len = static_cast<int64_t>(num_written);
  API_END();
}

639
640
/*!
* \brief Predict for a data file and write the results to another file
* \param handle Handle of the booster
* \param data_filename Path of the input data file
* \param data_has_header Values greater than 0 are treated as true
* \param predict_type One of the C_API_PREDICT_* values
* \param num_iteration Narrowed to int and passed to Booster::PrepareForPrediction
* \param result_filename Path of the output file
*/
DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle,
  const char* data_filename,
  int data_has_header,
  int predict_type,
  int64_t num_iteration,
  const char* result_filename) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);
  const bool has_header = (data_has_header > 0);
  booster->PredictForFile(data_filename, result_filename, has_header);
  API_END();
}

653
/*!
* \brief Predict for the rows of a CSR sparse matrix
* \param handle Handle of the booster
* \param indptr Row pointer array
* \param indptr_type Element type code of indptr
* \param indices Column indices
* \param data Non-zero values
* \param data_type Element type code of data
* \param nindptr Length of indptr; the number of rows is nindptr - 1
* \param nelem Number of stored elements
* \param (unnamed) Number of columns; not used by this function
* \param predict_type One of the C_API_PREDICT_* values
* \param num_iteration Narrowed to int and passed to Booster::PrepareForPrediction
* \param out_len Receives rows * values-per-row
* \param out_result Caller-provided buffer; row i is written starting at
*        offset i * values-per-row
*/
DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
  const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t,
  int predict_type,
  int64_t num_iteration,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  ref_booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);

  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  // one value per class; leaf-index prediction produces one per class and iteration
  int num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
  if (predict_type == C_API_PREDICT_LEAF_INDEX) {
    if (num_iteration > 0) {
      num_preb_in_one_row *= static_cast<int>(num_iteration);
    } else {
      num_preb_in_one_row *= ref_booster->GetBoosting()->NumberOfTotalModel() / num_preb_in_one_row;
    }
  }
  int nrow = static_cast<int>(nindptr - 1);
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    auto one_row = get_row_fun(i);
    auto predicton_result = ref_booster->Predict(one_row);
    // widen before multiplying: rows * values-per-row can exceed the int range
    const int64_t row_offset = static_cast<int64_t>(i) * num_preb_in_one_row;
    for (int j = 0; j < static_cast<int>(predicton_result.size()); ++j) {
      out_result[row_offset + j] = static_cast<float>(predicton_result[j]);
    }
  }
  *out_len = static_cast<int64_t>(nrow) * num_preb_in_one_row;
  API_END();
}
691
692
693

// Predicts for every row of a dense matrix and stores the results, converted
// to float, in out_result.
//   handle         booster handle; reinterpreted as a Booster*
//   data           matrix values (element type selected by data_type)
//   nrow, ncol     matrix dimensions
//   is_row_major   non-zero when data is laid out row by row
//   predict_type   forwarded to PrepareForPrediction; for
//                  C_API_PREDICT_LEAF_INDEX each row yields one value per model
//   num_iteration  iteration limit; when <= 0 in leaf-index mode all models
//                  are counted
//   out_len        receives nrow * values-per-row
//   out_result     caller-provided buffer; row i is written starting at
//                  offset i * values-per-row
// The return value and error handling are supplied by API_BEGIN / API_END.
DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
  const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  int predict_type,
  int64_t num_iteration,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  ref_booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);

  auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  // Kept in 64 bits so that the output offsets and the total length below
  // cannot overflow a 32-bit int for large inputs.
  int64_t num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
  if (predict_type == C_API_PREDICT_LEAF_INDEX) {
    if (num_iteration > 0) {
      num_preb_in_one_row *= static_cast<int>(num_iteration);
    } else {
      num_preb_in_one_row *= ref_booster->GetBoosting()->NumberOfTotalModel() / num_preb_in_one_row;
    }
  }
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    auto one_row = get_row_fun(i);
    auto predicton_result = ref_booster->Predict(one_row);
    // 64-bit base offset of this row inside out_result
    const int64_t row_offset = static_cast<int64_t>(i) * num_preb_in_one_row;
    for (int j = 0; j < static_cast<int>(predicton_result.size()); ++j) {
      out_result[row_offset + j] = static_cast<float>(predicton_result[j]);
    }
  }
  *out_len = static_cast<int64_t>(nrow) * num_preb_in_one_row;
  API_END();
}
726
727

// Saves the booster's model to a file.
//   handle         booster handle; reinterpreted as a Booster*
//   num_iteration  forwarded unchanged to Booster::SaveModelToFile
//   filename       destination path
// The return value and error handling are supplied by API_BEGIN / API_END.
DllExport int LGBM_BoosterSaveModel(BoosterHandle handle,
  int num_iteration,
  const char* filename) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->SaveModelToFile(num_iteration, filename);
  API_END();
}
735

Guolin Ke's avatar
Guolin Ke committed
736
// ---- start of some helper functions
737
738
739

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
740
  if (data_type == C_API_DTYPE_FLOAT32) {
741
742
743
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
744
        std::vector<double> ret(num_col);
745
746
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
747
          ret[i] = static_cast<double>(*(tmp_ptr + i));
748
749
750
751
752
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
753
        std::vector<double> ret(num_col);
754
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
755
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
756
757
758
759
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
760
  } else if (data_type == C_API_DTYPE_FLOAT64) {
761
762
763
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
764
        std::vector<double> ret(num_col);
765
766
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
767
          ret[i] = static_cast<double>(*(tmp_ptr + i));
768
769
770
771
772
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
773
        std::vector<double> ret(num_col);
774
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
775
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
776
777
778
779
780
        }
        return ret;
      };
    }
  }
Guolin Ke's avatar
Guolin Ke committed
781
  throw std::runtime_error("unknown data type in RowFunctionFromDenseMatric");
782
783
784
785
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
786
787
788
789
790
791
792
793
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
    return [inner_function](int row_idx) {
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
        if (std::fabs(raw_values[i]) > 1e-15) {
          ret.emplace_back(i, raw_values[i]);
794
        }
Guolin Ke's avatar
Guolin Ke committed
795
796
797
      }
      return ret;
    };
798
  }
Guolin Ke's avatar
Guolin Ke committed
799
  return nullptr;
800
801
802
803
}

std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
804
  if (data_type == C_API_DTYPE_FLOAT32) {
805
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
806
    if (indptr_type == C_API_DTYPE_INT32) {
807
808
809
810
811
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
812
        for (int64_t i = start; i < end; ++i) {
813
814
815
816
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
817
    } else if (indptr_type == C_API_DTYPE_INT64) {
818
819
820
821
822
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
823
        for (int64_t i = start; i < end; ++i) {
824
825
826
827
828
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
829
  } else if (data_type == C_API_DTYPE_FLOAT64) {
830
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
831
    if (indptr_type == C_API_DTYPE_INT32) {
832
833
834
835
836
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
837
        for (int64_t i = start; i < end; ++i) {
838
839
840
841
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
842
    } else if (indptr_type == C_API_DTYPE_INT64) {
843
844
845
846
847
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
848
        for (int64_t i = start; i < end; ++i) {
849
850
851
852
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
853
854
855
    } 
  } 
  throw std::runtime_error("unknown data type in RowFunctionFromCSR");
856
857
858
859
}

std::function<std::vector<std::pair<int, double>>(int idx)>
ColumnFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
860
  if (data_type == C_API_DTYPE_FLOAT32) {
861
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
862
    if (col_ptr_type == C_API_DTYPE_INT32) {
863
864
865
866
867
868
869
870
871
872
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
873
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
874
875
876
877
878
879
880
881
882
883
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
884
    } 
Guolin Ke's avatar
Guolin Ke committed
885
  } else if (data_type == C_API_DTYPE_FLOAT64) {
886
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
887
    if (col_ptr_type == C_API_DTYPE_INT32) {
888
889
890
891
892
893
894
895
896
897
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
898
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
899
900
901
902
903
904
905
906
907
908
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
909
910
911
    } 
  } 
  throw std::runtime_error("unknown data type in ColumnFunctionFromCSC");
912
913
}

914
// Picks the values of one sparse column at the requested rows.
//   data     (row index, value) entries of the column
//   indices  row indices to sample
// A single forward-only cursor walks `data` while `indices` is traversed in
// order; for each requested row the cursor skips entries with a smaller row
// index and the value is appended when the row index matches. Requested rows
// without a matching entry contribute nothing to the result.
std::vector<double> SampleFromOneColumn(const std::vector<std::pair<int, double>>& data, const std::vector<int>& indices) {
  std::vector<double> sampled;
  const size_t num_entries = data.size();
  size_t cursor = 0;
  for (size_t k = 0; k < indices.size(); ++k) {
    const int wanted_row = indices[k];
    // advance past entries that belong to earlier rows
    for (; cursor < num_entries && data[cursor].first < wanted_row; ++cursor) {
    }
    if (cursor >= num_entries) {
      // the cursor never moves backwards, so nothing further can match
      break;
    }
    if (data[cursor].first == wanted_row) {
      sampled.push_back(data[cursor].second);
    }
  }
  return sampled;
}