c_api.cpp 30.5 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
#include <omp.h>

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/c_api.h>
Guolin Ke's avatar
Guolin Ke committed
6
#include <LightGBM/dataset_loader.h>
Guolin Ke's avatar
Guolin Ke committed
7
8
9
10
11
12
13
14
15
16
#include <LightGBM/dataset.h>
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/config.h>

#include <cstdio>
#include <vector>
#include <string>
#include <cstring>
Guolin Ke's avatar
Guolin Ke committed
17
#include <memory>
Guolin Ke's avatar
Guolin Ke committed
18
#include <stdexcept>
Guolin Ke's avatar
Guolin Ke committed
19
#include <mutex>
Guolin Ke's avatar
Guolin Ke committed
20

Guolin Ke's avatar
Guolin Ke committed
21
22
#include "./application/predictor.hpp"

Guolin Ke's avatar
Guolin Ke committed
23
24
25
26
namespace LightGBM {

class Booster {
public:
Guolin Ke's avatar
Guolin Ke committed
27
28
  explicit Booster(const char* filename) {
    boosting_.reset(Boosting::CreateBoosting(filename));
Guolin Ke's avatar
Guolin Ke committed
29
30
31
  }

  Booster(const Dataset* train_data, 
32
    const char* parameters) {
33
34
    auto param = ConfigBase::Str2Map(parameters);
    config_.Set(param);
Guolin Ke's avatar
Guolin Ke committed
35
36
    // create boosting
    if (config_.io_config.input_model.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
37
      Log::Warning("continued train from model is not support for c_api, \
Guolin Ke's avatar
Guolin Ke committed
38
39
        please use continued train with input score");
    }
Guolin Ke's avatar
Guolin Ke committed
40
    boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
41
42
43
44
45
46
    ConstructObjectAndTrainingMetrics(train_data);
    // initialize the boosting
    boosting_->Init(&config_.boosting_config, train_data, objective_fun_.get(),
      Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
  }

Guolin Ke's avatar
Guolin Ke committed
47
  void MergeFrom(const Booster* other) {
Guolin Ke's avatar
Guolin Ke committed
48
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
49
50
51
    boosting_->MergeFrom(other->boosting_.get());
  }

52
53
54
55
56
  ~Booster() {

  }

  void ResetTrainingData(const Dataset* train_data) {
Guolin Ke's avatar
Guolin Ke committed
57
    std::lock_guard<std::mutex> lock(mutex_);
58
59
    train_data_ = train_data;
    ConstructObjectAndTrainingMetrics(train_data_);
60
    // initialize the boosting
61
62
63
64
65
    boosting_->ResetTrainingData(&config_.boosting_config, train_data_, 
      objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
  }

  void ResetConfig(const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
66
    std::lock_guard<std::mutex> lock(mutex_);
67
68
69
70
71
72
73
74
75
    auto param = ConfigBase::Str2Map(parameters);
    if (param.count("num_class")) {
      Log::Fatal("cannot change num class during training");
    }
    if (param.count("boosting_type")) {
      Log::Fatal("cannot change boosting_type during training");
    }
    config_.Set(param);
    ResetTrainingData(train_data_);
76
  }
Guolin Ke's avatar
Guolin Ke committed
77

78
  void AddValidData(const Dataset* valid_data) {
Guolin Ke's avatar
Guolin Ke committed
79
    std::lock_guard<std::mutex> lock(mutex_);
80
81
82
83
84
85
86
87
88
89
    valid_metrics_.emplace_back();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
      Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
Guolin Ke's avatar
Guolin Ke committed
90
  }
91
  bool TrainOneIter() {
Guolin Ke's avatar
Guolin Ke committed
92
93
    std::lock_guard<std::mutex> lock(mutex_);
    return boosting_->TrainOneIter(nullptr, nullptr, false);
94
95
96
  }

  bool TrainOneIter(const float* gradients, const float* hessians) {
Guolin Ke's avatar
Guolin Ke committed
97
98
    std::lock_guard<std::mutex> lock(mutex_);
    return boosting_->TrainOneIter(gradients, hessians, false);
99
100
  }

Guolin Ke's avatar
Guolin Ke committed
101
  void RollbackOneIter() {
Guolin Ke's avatar
Guolin Ke committed
102
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
103
104
105
    boosting_->RollbackOneIter();
  }

Guolin Ke's avatar
Guolin Ke committed
106
  void PrepareForPrediction(int num_iteration, int predict_type) {
Guolin Ke's avatar
Guolin Ke committed
107
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
108
    boosting_->SetNumIterationForPred(num_iteration);
Guolin Ke's avatar
Guolin Ke committed
109
110
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
111
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
112
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
113
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
114
      is_raw_score = true;
Guolin Ke's avatar
Guolin Ke committed
115
116
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
117
    }
Guolin Ke's avatar
Guolin Ke committed
118
    predictor_.reset(new Predictor(boosting_.get(), is_raw_score, is_predict_leaf));
Guolin Ke's avatar
Guolin Ke committed
119
120
  }

Guolin Ke's avatar
Guolin Ke committed
121
122
123
124
  void GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len) {
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

Guolin Ke's avatar
Guolin Ke committed
125
126
  std::vector<double> Predict(const std::vector<std::pair<int, double>>& features) {
    return predictor_->GetPredictFunction()(features);
127
128
  }

129
130
131
132
  void PredictForFile(const char* data_filename, const char* result_filename, bool data_has_header) {
    predictor_->Predict(data_filename, result_filename, data_has_header);
  }

Guolin Ke's avatar
Guolin Ke committed
133
  void SaveModelToFile(int num_iteration, const char* filename) {
Guolin Ke's avatar
Guolin Ke committed
134
    boosting_->SaveModelToFile(num_iteration, filename);
Guolin Ke's avatar
Guolin Ke committed
135
  }
Guolin Ke's avatar
Guolin Ke committed
136
137
138
139
140
141
142
143
144

  int GetEvalCounts() const {
    int ret = 0;
    for (const auto& metric : train_metric_) {
      ret += static_cast<int>(metric->GetName().size());
    }
    return ret;
  }

Guolin Ke's avatar
Guolin Ke committed
145
  int GetEvalNames(char** out_strs) const {
Guolin Ke's avatar
Guolin Ke committed
146
147
148
    int idx = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
Guolin Ke's avatar
Guolin Ke committed
149
        std::strcpy(out_strs[idx], name.c_str());
Guolin Ke's avatar
Guolin Ke committed
150
        ++idx;
Guolin Ke's avatar
Guolin Ke committed
151
152
153
154
155
      }
    }
    return idx;
  }

Guolin Ke's avatar
Guolin Ke committed
156
  const Boosting* GetBoosting() const { return boosting_.get(); }
Guolin Ke's avatar
Guolin Ke committed
157
  
Guolin Ke's avatar
Guolin Ke committed
158
private:
Guolin Ke's avatar
Guolin Ke committed
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182

  void ConstructObjectAndTrainingMetrics(const Dataset* train_data) {
    // create objective function
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
      config_.objective_config));
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective functions");
    }
    // create training metric
    train_metric_.clear();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(
        Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(train_data->metadata(), train_data->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data->metadata(), train_data->num_data());
    }
  }

183
  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
184
  std::unique_ptr<Boosting> boosting_;
Guolin Ke's avatar
Guolin Ke committed
185
186
187
  /*! \brief All configs */
  OverallConfig config_;
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
188
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
189
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
190
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
191
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
192
  std::unique_ptr<ObjectiveFunction> objective_fun_;
Guolin Ke's avatar
Guolin Ke committed
193
  /*! \brief Using predictor for prediction task */
Guolin Ke's avatar
Guolin Ke committed
194
  std::unique_ptr<Predictor> predictor_;
Guolin Ke's avatar
Guolin Ke committed
195
196
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
197
198
199
};

}
Guolin Ke's avatar
Guolin Ke committed
200
201
202

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
203
/*! \brief Return the message provided by LastErrorMsg(). */
DllExport const char* LGBM_GetLastError() {
  const char* last_message = LastErrorMsg();
  return last_message;
}

Guolin Ke's avatar
Guolin Ke committed
207
/*!
 * \brief Load a dataset from a file.
 * \param filename Path of the data file
 * \param parameters Parameter string used to fill an IOConfig
 * \param reference Optional dataset to align with; nullptr to load standalone
 * \param out Receives the loaded dataset handle
 */
DllExport int LGBM_DatasetCreateFromFile(const char* filename,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto param_map = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param_map);
  DatasetLoader loader(io_config, nullptr);
  loader.SetHeader(filename);
  if (reference != nullptr) {
    // align with the dataset behind the reference handle
    const Dataset* aligned_with = reinterpret_cast<const Dataset*>(*reference);
    *out = loader.LoadFromFileAlignWithOtherDataset(filename, aligned_with);
  } else {
    *out = loader.LoadFromFile(filename);
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
226
/*!
 * \brief Create a dataset from a dense matrix.
 *
 * Without a reference, a sample of rows is drawn and its values with
 * magnitude above 1e-15 are passed per column to
 * DatasetLoader::CostructFromSampleData. With a reference, the feature
 * mappers are copied from it. Then every row is pushed into the dataset.
 *
 * \param data Matrix data, interpreted according to data_type
 * \param nrow Number of rows
 * \param ncol Number of columns
 * \param is_row_major Layout flag forwarded to RowFunctionFromDenseMatric
 * \param parameters Parameter string used to fill an IOConfig
 * \param reference Optional dataset whose feature mappers are reused
 * \param out Receives the created dataset handle
 */
DllExport int LGBM_DatasetCreateFromMat(const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto param_map = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param_map);
  DatasetLoader loader(io_config, nullptr);
  std::unique_ptr<Dataset> dataset;
  auto row_getter = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  if (reference != nullptr) {
    // reuse the feature mappers of the reference dataset
    dataset.reset(new Dataset(nrow, io_config.num_class));
    dataset->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(*reference),
      io_config.is_enable_sparse);
  } else {
    // sample data first
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(ncol);
    for (size_t s = 0; s < sample_indices.size(); ++s) {
      auto sampled_row = row_getter(static_cast<int>(sample_indices[s]));
      for (size_t col = 0; col < sampled_row.size(); ++col) {
        // values at or below 1e-15 in magnitude are not recorded
        if (std::fabs(sampled_row[col]) > 1e-15) {
          sample_values[col].push_back(sampled_row[col]);
        }
      }
    }
    dataset.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

#pragma omp parallel for schedule(guided)
  for (int row_idx = 0; row_idx < nrow; ++row_idx) {
    const int thread_id = omp_get_thread_num();
    auto current_row = row_getter(row_idx);
    dataset->PushOneRow(thread_id, row_idx, current_row);
  }
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
275
/*!
 * \brief Create a dataset from CSR (row-compressed) sparse data.
 *
 * The row count is nindptr - 1. Without a reference, a sample of rows is
 * drawn, its values with magnitude above 1e-15 are grouped by column index
 * and passed to DatasetLoader::CostructFromSampleData. With a reference,
 * the feature mappers are copied from it. Then every row is pushed.
 *
 * \param num_col Declared column count; checked to be at least the number of
 *        columns seen in the sample
 * \param parameters Parameter string used to fill an IOConfig
 * \param reference Optional dataset whose feature mappers are reused
 * \param out Receives the created dataset handle
 */
DllExport int LGBM_DatasetCreateFromCSR(const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t num_col,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto param_map = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param_map);
  DatasetLoader loader(io_config, nullptr);
  std::unique_ptr<Dataset> dataset;
  auto row_getter = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference != nullptr) {
    // reuse the feature mappers of the reference dataset
    dataset.reset(new Dataset(nrow, io_config.num_class));
    dataset->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(*reference),
      io_config.is_enable_sparse);
  } else {
    // sample data first
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values;
    for (size_t s = 0; s < sample_indices.size(); ++s) {
      auto sampled_row = row_getter(static_cast<int>(sample_indices[s]));
      for (std::pair<int, double>& entry : sampled_row) {
        if (!(std::fabs(entry.second) > 1e-15)) { continue; }
        if (static_cast<size_t>(entry.first) >= sample_values.size()) {
          // grow the per-column storage until the column index is valid
          size_t missing = entry.first - sample_values.size() + 1;
          for (size_t k = 0; k < missing; ++k) {
            sample_values.emplace_back();
          }
        }
        // record the value under its column
        sample_values[entry.first].push_back(entry.second);
      }
    }
    CHECK(num_col >= static_cast<int>(sample_values.size()));
    dataset.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

#pragma omp parallel for schedule(guided)
  for (int row_idx = 0; row_idx < nindptr - 1; ++row_idx) {
    const int thread_id = omp_get_thread_num();
    auto current_row = row_getter(row_idx);
    dataset->PushOneRow(thread_id, row_idx, current_row);
  }
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
337
/*!
 * \brief Create a dataset from CSC (column-compressed) sparse data.
 *
 * The column count is ncol_ptr - 1. Without a reference, row indices are
 * sampled and each column is reduced with SampleFromOneColumn before being
 * passed to DatasetLoader::CostructFromSampleData. With a reference, the
 * feature mappers are copied from it. Then every column is pushed.
 *
 * \param num_row Number of rows
 * \param parameters Parameter string used to fill an IOConfig
 * \param reference Optional dataset whose feature mappers are reused
 * \param out Receives the created dataset handle
 */
DllExport int LGBM_DatasetCreateFromCSC(const void* col_ptr,
  int col_ptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t ncol_ptr,
  int64_t nelem,
  int64_t num_row,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto param_map = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param_map);
  DatasetLoader loader(io_config, nullptr);
  std::unique_ptr<Dataset> dataset;
  auto column_getter = ColumnFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem);
  int32_t nrow = static_cast<int32_t>(num_row);
  if (reference != nullptr) {
    // reuse the feature mappers of the reference dataset
    dataset.reset(new Dataset(nrow, io_config.num_class));
    dataset->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(*reference),
      io_config.is_enable_sparse);
  } else {
    Log::Warning("Construct from CSC format is not efficient");
    // sample data first
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
#pragma omp parallel for schedule(guided)
    for (int col_idx = 0; col_idx < static_cast<int>(sample_values.size()); ++col_idx) {
      auto column = column_getter(col_idx);
      sample_values[col_idx] = SampleFromOneColumn(column, sample_indices);
    }
    dataset.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  }

#pragma omp parallel for schedule(guided)
  for (int col_idx = 0; col_idx < ncol_ptr - 1; ++col_idx) {
    const int thread_id = omp_get_thread_num();
    auto current_col = column_getter(col_idx);
    dataset->PushOneColumn(thread_id, col_idx, current_col);
  }
  dataset->FinishLoad();
  *out = dataset.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
387
/*!
 * \brief Create a new dataset from selected rows of an existing one.
 * \param handle Pointer to the handle of the full dataset
 * \param used_row_indices Row indices to keep
 * \param num_used_row_indices Length of used_row_indices
 * \param parameters Parameter string used to fill an IOConfig
 * \param out Receives the subset dataset handle
 */
DllExport int LGBM_DatasetGetSubset(
  const DatesetHandle* handle,
  const int32_t* used_row_indices,
  int32_t num_used_row_indices,
  const char* parameters,
  DatesetHandle* out) {
  API_BEGIN();
  auto param_map = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param_map);
  const Dataset* source = reinterpret_cast<const Dataset*>(*handle);
  std::unique_ptr<Dataset> subset(
    source->Subset(used_row_indices,
      num_used_row_indices,
      io_config.is_enable_sparse));
  subset->FinishLoad();
  *out = subset.release();
  API_END();
}

407
/*! \brief Delete the Dataset behind the handle. */
DllExport int LGBM_DatasetFree(DatesetHandle handle) {
  API_BEGIN();
  Dataset* doomed = reinterpret_cast<Dataset*>(handle);
  delete doomed;
  API_END();
}

/*!
 * \brief Save the dataset behind the handle via Dataset::SaveBinaryFile.
 * \param filename Destination file name
 */
DllExport int LGBM_DatasetSaveBinary(DatesetHandle handle,
  const char* filename) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->SaveBinaryFile(filename);
  API_END();
}

/*!
 * \brief Set a named field of the dataset from a float32 or int32 array.
 * \param field_name Name of the field
 * \param field_data Array, interpreted according to type
 * \param num_element Number of elements (narrowed to int32_t)
 * \param type C_API_DTYPE_FLOAT32 or C_API_DTYPE_INT32
 * Throws std::runtime_error when the type is neither of the two or the
 * setter reports failure.
 */
DllExport int LGBM_DatasetSetField(DatesetHandle handle,
  const char* field_name,
  const void* field_data,
  int64_t num_element,
  int type) {
  API_BEGIN();
  Dataset* target = reinterpret_cast<Dataset*>(handle);
  const int32_t element_count = static_cast<int32_t>(num_element);
  bool field_was_set = false;
  if (type == C_API_DTYPE_FLOAT32) {
    field_was_set = target->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), element_count);
  } else if (type == C_API_DTYPE_INT32) {
    field_was_set = target->SetIntField(field_name, reinterpret_cast<const int*>(field_data), element_count);
  }
  if (!field_was_set) {
    throw std::runtime_error("Input data type erorr or field not found");
  }
  API_END();
}

/*!
 * \brief Look up a named field of the dataset.
 *
 * The float getter is tried first, then the int getter; out_type reports
 * which one succeeded. Throws std::runtime_error when neither does.
 *
 * \param field_name Name of the field
 * \param out_len Receives the element count; forced to 0 when the returned
 *        pointer is null
 * \param out_ptr Receives the pointer to the field data
 * \param out_type Receives C_API_DTYPE_FLOAT32 or C_API_DTYPE_INT32
 */
DllExport int LGBM_DatasetGetField(DatesetHandle handle,
  const char* field_name,
  int64_t* out_len,
  const void** out_ptr,
  int* out_type) {
  API_BEGIN();
  Dataset* source = reinterpret_cast<Dataset*>(handle);
  if (source->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (source->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else {
    throw std::runtime_error("Field not found");
  }
  if (*out_ptr == nullptr) {
    *out_len = 0;
  }
  API_END();
}

/*!
 * \brief Report Dataset::num_data() of the dataset behind the handle.
 * \param out Receives the value
 */
DllExport int LGBM_DatasetGetNumData(DatesetHandle handle,
  int64_t* out) {
  API_BEGIN();
  *out = reinterpret_cast<Dataset*>(handle)->num_data();
  API_END();
}

/*!
 * \brief Report Dataset::num_total_features() of the dataset behind the handle.
 * \param out Receives the value
 */
DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle,
  int64_t* out) {
  API_BEGIN();
  *out = reinterpret_cast<Dataset*>(handle)->num_total_features();
  API_END();
}
473
474
475
476
477
478
479


// ---- start of booster

/*!
 * \brief Create a Booster for training.
 * \param train_data Handle of the training dataset
 * \param parameters Parameter string passed to the Booster constructor
 * \param out Receives the new booster handle
 */
DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
  const char* parameters,
  BoosterHandle* out) {
  API_BEGIN();
  const Dataset* training_set = reinterpret_cast<const Dataset*>(train_data);
  // hold the object in a unique_ptr until it is handed to the caller
  std::unique_ptr<Booster> booster(new Booster(training_set, parameters));
  *out = booster.release();
  API_END();
}

487
/*!
 * \brief Create a Booster from a model file.
 * \param filename Path of the model file
 * \param out_num_iterations Receives NumberOfTotalModel() / NumberOfClasses()
 * \param out Receives the new booster handle
 */
DllExport int LGBM_BoosterCreateFromModelfile(
  const char* filename,
  int64_t* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
  std::unique_ptr<Booster> booster(new Booster(filename));
  const Boosting* model = booster->GetBoosting();
  *out_num_iterations = static_cast<int64_t>(
    model->NumberOfTotalModel() / model->NumberOfClasses());
  *out = booster.release();
  API_END();
}

/*! \brief Delete the Booster behind the handle. */
DllExport int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  Booster* doomed = reinterpret_cast<Booster*>(handle);
  delete doomed;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
505
506
507
508
509
510
511
512
/*!
 * \brief Merge the model of other_handle into the booster behind handle.
 */
DllExport int LGBM_BoosterMerge(BoosterHandle handle,
  BoosterHandle other_handle) {
  API_BEGIN();
  Booster* destination = reinterpret_cast<Booster*>(handle);
  Booster* source = reinterpret_cast<Booster*>(other_handle);
  destination->MergeFrom(source);
  API_END();
}
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531

/*!
 * \brief Register a validation dataset with the booster.
 * \param valid_data Handle of the validation dataset
 */
DllExport int LGBM_BoosterAddValidData(BoosterHandle handle,
  const DatesetHandle valid_data) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->AddValidData(reinterpret_cast<const Dataset*>(valid_data));
  API_END();
}

/*!
 * \brief Replace the training dataset of the booster.
 * \param train_data Handle of the new training dataset
 */
DllExport int LGBM_BoosterResetTrainingData(BoosterHandle handle,
  const DatesetHandle train_data) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->ResetTrainingData(reinterpret_cast<const Dataset*>(train_data));
  API_END();
}

532
533
534
/*!
 * \brief Forward a parameter string to Booster::ResetConfig.
 */
DllExport int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ResetConfig(parameters);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
539
540
541
542
543
544
545
/*!
 * \brief Report Boosting::NumberOfClasses() of the booster.
 * \param out_len Receives the value
 */
DllExport int LGBM_BoosterGetNumClasses(BoosterHandle handle, int64_t* out_len) {
  API_BEGIN();
  const Boosting* model = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = model->NumberOfClasses();
  API_END();
}

546
/*!
 * \brief Run one training iteration.
 * \param is_finished Receives 1 when Booster::TrainOneIter returns true, else 0
 */
DllExport int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const bool finished = booster->TrainOneIter();
  *is_finished = finished ? 1 : 0;
  API_END();
}

/*!
 * \brief Run one training iteration with caller-supplied gradients/hessians.
 * \param grad Gradient array forwarded to Booster::TrainOneIter
 * \param hess Hessian array forwarded to Booster::TrainOneIter
 * \param is_finished Receives 1 when TrainOneIter returns true, else 0
 */
DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
  const float* grad,
  const float* hess,
  int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const bool finished = booster->TrainOneIter(grad, hess);
  *is_finished = finished ? 1 : 0;
  API_END();
}

571
572
573
574
575
576
577
578
579
580
581
582
583
/*! \brief Forward to Booster::RollbackOneIter. */
DllExport int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

/*!
 * \brief Report Boosting::GetCurrentIteration() of the booster.
 * \param out_iteration Receives the value
 */
DllExport int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int64_t* out_iteration) {
  API_BEGIN();
  const Boosting* model = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_iteration = model->GetCurrentIteration();
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
/*!
* \brief Get the number of evaluation results
* \param out_len Receives Booster::GetEvalCounts()
*/
DllExport int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int64_t* out_len) {
  API_BEGIN();
  const Booster* booster = reinterpret_cast<Booster*>(handle);
  *out_len = booster->GetEvalCounts();
  API_END();
}

/*!
* \brief Get number of eval
* \return total number of eval result
*/
Guolin Ke's avatar
Guolin Ke committed
599
DllExport int LGBM_BoosterGetEvalNames(BoosterHandle handle, int64_t* out_len, char** out_strs) {
Guolin Ke's avatar
Guolin Ke committed
600
601
602
603
604
605
606
607
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  *out_len = ref_booster->GetEvalNames(out_strs);
  API_END();
}


/*!
 * \brief Fetch evaluation results for one dataset.
 * \param data_idx Dataset index forwarded to Boosting::GetEvalAt
 * \param out_len Receives the number of results
 * \param out_results Caller-provided array; each result is written as float
 */
DllExport int LGBM_BoosterGetEval(BoosterHandle handle,
  int data_idx,
  int64_t* out_len,
  float* out_results) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto eval_values = booster->GetBoosting()->GetEvalAt(data_idx);
  const size_t num_values = eval_values.size();
  *out_len = static_cast<int64_t>(num_values);
  for (size_t pos = 0; pos < num_values; ++pos) {
    out_results[pos] = static_cast<float>(eval_values[pos]);
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
622
/*!
 * \brief Fetch the current predictions for one dataset.
 * \param data_idx Dataset index forwarded to Booster::GetPredictAt
 * \param out_len Receives the length reported by GetPredictAt
 * \param out_result Caller-provided array filled by GetPredictAt
 */
DllExport int LGBM_BoosterGetPredict(BoosterHandle handle,
  int data_idx,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  // GetPredictAt reports the length through an int; widen it afterwards
  int written = 0;
  booster->GetPredictAt(data_idx, out_result, &written);
  *out_len = static_cast<int64_t>(written);
  API_END();
}

634
635
/*!
 * \brief Predict for a data file and write the results to another file.
 * \param data_filename Input data file
 * \param data_has_header Treated as true when greater than 0
 * \param predict_type Forwarded to Booster::PrepareForPrediction
 * \param num_iteration Forwarded (as int) to Booster::PrepareForPrediction
 * \param result_filename Output file
 */
DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle,
  const char* data_filename,
  int data_has_header,
  int predict_type,
  int64_t num_iteration,
  const char* result_filename) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);
  const bool has_header = data_has_header > 0;
  booster->PredictForFile(data_filename, result_filename, has_header);
  API_END();
}

648
/*!
 * \brief Predict for CSR (row-compressed) sparse data.
 *
 * The row count is nindptr - 1. Each row's predictions are written to
 * out_result starting at row * (predictions per row). Predictions per row
 * is NumberOfClasses(), multiplied for C_API_PREDICT_LEAF_INDEX by
 * num_iteration (when > 0) or by NumberOfTotalModel() / NumberOfClasses().
 *
 * \param predict_type Forwarded to Booster::PrepareForPrediction
 * \param num_iteration Forwarded (as int) to Booster::PrepareForPrediction
 * \param out_len Receives rows * predictions per row
 * \param out_result Caller-provided array large enough for out_len floats
 */
DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
  const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t,
  int predict_type,
  int64_t num_iteration,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  ref_booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);

  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  int num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
  if (predict_type == C_API_PREDICT_LEAF_INDEX) {
    if (num_iteration > 0) {
      num_preb_in_one_row *= static_cast<int>(num_iteration);
    } else {
      num_preb_in_one_row *= ref_booster->GetBoosting()->NumberOfTotalModel() / num_preb_in_one_row;
    }
  }
  const int nrow = static_cast<int>(nindptr - 1);
  // Offsets are computed in 64 bits: rows * predictions-per-row can exceed
  // the range of int for large inputs.
  const int64_t row_stride = num_preb_in_one_row;
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    auto one_row = get_row_fun(i);
    auto predicton_result = ref_booster->Predict(one_row);
    float* row_out = out_result + static_cast<int64_t>(i) * row_stride;
    for (int j = 0; j < static_cast<int>(predicton_result.size()); ++j) {
      row_out[j] = static_cast<float>(predicton_result[j]);
    }
  }
  *out_len = static_cast<int64_t>(nrow) * row_stride;
  API_END();
}
686
687
688

/*!
 * \brief Predict every row of a dense matrix and store the results in out_result.
 *
 * Row i writes its predictions to
 * out_result[i * K .. i * K + (number of values returned by Predict) - 1],
 * where K is the number of classes, multiplied for leaf-index prediction by
 * the number of iterations (num_iteration if positive, otherwise
 * NumberOfTotalModel() / number of classes).
 *
 * \param handle Opaque handle; reinterpreted as a Booster*
 * \param data Matrix values, element type selected by data_type
 * \param data_type Type tag of data, forwarded to RowPairFunctionFromDenseMatric
 * \param nrow Number of rows
 * \param ncol Number of columns
 * \param is_row_major Layout flag, forwarded to RowPairFunctionFromDenseMatric
 * \param predict_type Prediction type; C_API_PREDICT_LEAF_INDEX enlarges K
 * \param num_iteration Iteration limit; narrowed to int
 * \param out_len Receives nrow * K as a 64-bit value
 * \param out_result Caller-provided buffer; must hold at least nrow * K floats
 * \return Whatever API_END() yields (no explicit return appears in this body)
 */
DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
  const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  int predict_type,
  int64_t num_iteration,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  ref_booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);

  auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  int num_pred_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
  if (predict_type == C_API_PREDICT_LEAF_INDEX) {
    if (num_iteration > 0) {
      num_pred_in_one_row *= static_cast<int>(num_iteration);
    } else {
      num_pred_in_one_row *= ref_booster->GetBoosting()->NumberOfTotalModel() / num_pred_in_one_row;
    }
  }
  // 64-bit stride: i * stride must not be evaluated in 32-bit int, which
  // overflows once the output holds more than 2^31 - 1 values.
  const int64_t row_stride = num_pred_in_one_row;
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    auto one_row = get_row_fun(i);
    auto prediction_result = ref_booster->Predict(one_row);
    float* out_row = out_result + row_stride * i;
    for (int j = 0; j < static_cast<int>(prediction_result.size()); ++j) {
      out_row[j] = static_cast<float>(prediction_result[j]);
    }
  }
  *out_len = row_stride * nrow;
  API_END();
}
721
722

/*!
 * \brief Save the model held by a booster to a file.
 * \param handle Opaque handle; reinterpreted as a Booster*
 * \param num_iteration Forwarded unchanged to Booster::SaveModelToFile
 * \param filename Destination path, forwarded unchanged
 * \return Whatever API_END() yields (no explicit return appears in this body)
 */
DllExport int LGBM_BoosterSaveModel(BoosterHandle handle,
  int num_iteration,
  const char* filename) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->SaveModelToFile(num_iteration, filename);
  API_END();
}
730

Guolin Ke's avatar
Guolin Ke committed
731
// ---- start of some helper functions
732
733
734

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
735
  if (data_type == C_API_DTYPE_FLOAT32) {
736
737
738
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
739
        std::vector<double> ret(num_col);
740
741
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
742
          ret[i] = static_cast<double>(*(tmp_ptr + i));
743
744
745
746
747
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
748
        std::vector<double> ret(num_col);
749
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
750
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
751
752
753
754
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
755
  } else if (data_type == C_API_DTYPE_FLOAT64) {
756
757
758
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
759
        std::vector<double> ret(num_col);
760
761
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
762
          ret[i] = static_cast<double>(*(tmp_ptr + i));
763
764
765
766
767
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
768
        std::vector<double> ret(num_col);
769
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
770
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
771
772
773
774
775
        }
        return ret;
      };
    }
  }
Guolin Ke's avatar
Guolin Ke committed
776
  throw std::runtime_error("unknown data type in RowFunctionFromDenseMatric");
777
778
779
780
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
781
782
783
784
785
786
787
788
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
    return [inner_function](int row_idx) {
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
        if (std::fabs(raw_values[i]) > 1e-15) {
          ret.emplace_back(i, raw_values[i]);
789
        }
Guolin Ke's avatar
Guolin Ke committed
790
791
792
      }
      return ret;
    };
793
  }
Guolin Ke's avatar
Guolin Ke committed
794
  return nullptr;
795
796
797
798
}

std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
799
  if (data_type == C_API_DTYPE_FLOAT32) {
800
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
801
    if (indptr_type == C_API_DTYPE_INT32) {
802
803
804
805
806
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
807
        for (int64_t i = start; i < end; ++i) {
808
809
810
811
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
812
    } else if (indptr_type == C_API_DTYPE_INT64) {
813
814
815
816
817
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
818
        for (int64_t i = start; i < end; ++i) {
819
820
821
822
823
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
824
  } else if (data_type == C_API_DTYPE_FLOAT64) {
825
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
826
    if (indptr_type == C_API_DTYPE_INT32) {
827
828
829
830
831
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
832
        for (int64_t i = start; i < end; ++i) {
833
834
835
836
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
837
    } else if (indptr_type == C_API_DTYPE_INT64) {
838
839
840
841
842
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
843
        for (int64_t i = start; i < end; ++i) {
844
845
846
847
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
848
849
850
    } 
  } 
  throw std::runtime_error("unknown data type in RowFunctionFromCSR");
851
852
853
854
}

std::function<std::vector<std::pair<int, double>>(int idx)>
ColumnFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
855
  if (data_type == C_API_DTYPE_FLOAT32) {
856
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
857
    if (col_ptr_type == C_API_DTYPE_INT32) {
858
859
860
861
862
863
864
865
866
867
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
868
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
869
870
871
872
873
874
875
876
877
878
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
879
    } 
Guolin Ke's avatar
Guolin Ke committed
880
  } else if (data_type == C_API_DTYPE_FLOAT64) {
881
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
882
    if (col_ptr_type == C_API_DTYPE_INT32) {
883
884
885
886
887
888
889
890
891
892
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
893
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
894
895
896
897
898
899
900
901
902
903
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
904
905
906
    } 
  } 
  throw std::runtime_error("unknown data type in ColumnFunctionFromCSC");
907
908
}

909
/*!
 * \brief Pick the values of a sparse column at the requested row indices.
 *
 * Both inputs are walked front to back with a single forward-only cursor over
 * `data`, so a match is only found for entries the cursor has not yet passed.
 * Requested rows with no matching entry contribute nothing, so the result may
 * be shorter than `indices`.
 *
 * \param data (row index, value) pairs of one column
 * \param indices Row indices to sample
 * \return Values of the matched entries, in the order they were matched
 */
std::vector<double> SampleFromOneColumn(const std::vector<std::pair<int, double>>& data, const std::vector<int>& indices) {
  std::vector<double> sampled;
  auto cursor = data.begin();
  const auto data_end = data.end();
  for (const int wanted_row : indices) {
    // skip entries that lie before the requested row
    for (; cursor != data_end && cursor->first < wanted_row; ++cursor) {
    }
    if (cursor != data_end && cursor->first == wanted_row) {
      sampled.push_back(cursor->second);
    }
  }
  return sampled;
}