"src/vscode:/vscode.git/clone" did not exist on "ab5ab8a7b6bfddb3a30b7fd02bf799772fce3a6f"
c_api.cpp 31.6 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
#include <omp.h>

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/random.h>
#include <LightGBM/c_api.h>
#include <LightGBM/dataset_loader.h>
#include <LightGBM/dataset.h>
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/config.h>

#include <cmath>
#include <cstdio>
#include <cstring>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "./application/predictor.hpp"

Guolin Ke's avatar
Guolin Ke committed
23
24
25
26
namespace LightGBM {

class Booster {
public:
Guolin Ke's avatar
Guolin Ke committed
27
28
  explicit Booster(const char* filename) {
    boosting_.reset(Boosting::CreateBoosting(filename));
Guolin Ke's avatar
Guolin Ke committed
29
30
31
  }

  Booster(const Dataset* train_data, 
wxchan's avatar
wxchan committed
32
33
34
    const char* parameters) {
    auto param = ConfigBase::Str2Map(parameters);
    config_.Set(param);
Guolin Ke's avatar
Guolin Ke committed
35
36
    // create boosting
    if (config_.io_config.input_model.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
37
      Log::Warning("continued train from model is not support for c_api, \
Guolin Ke's avatar
Guolin Ke committed
38
39
        please use continued train with input score");
    }
wxchan's avatar
wxchan committed
40
41
    boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
    ConstructObjectAndTrainingMetrics(train_data);
Guolin Ke's avatar
Guolin Ke committed
42
    // initialize the boosting
wxchan's avatar
wxchan committed
43
    boosting_->Init(&config_.boosting_config, train_data, objective_fun_.get(),
Guolin Ke's avatar
Guolin Ke committed
44
      Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
wxchan's avatar
wxchan committed
45
46
47
48
49
  }

  void MergeFrom(const Booster* other) {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->MergeFrom(other->boosting_.get());
Guolin Ke's avatar
Guolin Ke committed
50
51
52
  }

  ~Booster() {
Guolin Ke's avatar
Guolin Ke committed
53

Guolin Ke's avatar
Guolin Ke committed
54
  }
55

wxchan's avatar
wxchan committed
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
  void ResetTrainingData(const Dataset* train_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    train_data_ = train_data;
    ConstructObjectAndTrainingMetrics(train_data_);
    // initialize the boosting
    boosting_->ResetTrainingData(&config_.boosting_config, train_data_, 
      objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
  }

  void ResetConfig(const char* parameters) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto param = ConfigBase::Str2Map(parameters);
    if (param.count("num_class")) {
      Log::Fatal("cannot change num class during training");
    }
    if (param.count("boosting_type")) {
      Log::Fatal("cannot change boosting_type during training");
    }
    config_.Set(param);
    if (param.size() == 1 && (param.count("learning_rate") || param.count("shrinkage_rate"))) {
      // only need to set learning rate
      boosting_->ResetShrinkageRate(config_.boosting_config.learning_rate);
    } else {
      ResetTrainingData(train_data_);
    }
  }

  void AddValidData(const Dataset* valid_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    valid_metrics_.emplace_back();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
      Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
  }
96
  bool TrainOneIter() {
wxchan's avatar
wxchan committed
97
    std::lock_guard<std::mutex> lock(mutex_);
98
99
100
101
    return boosting_->TrainOneIter(nullptr, nullptr, false);
  }

  bool TrainOneIter(const float* gradients, const float* hessians) {
wxchan's avatar
wxchan committed
102
    std::lock_guard<std::mutex> lock(mutex_);
103
104
105
    return boosting_->TrainOneIter(gradients, hessians, false);
  }

wxchan's avatar
wxchan committed
106
107
108
109
110
111
112
113
  void RollbackOneIter() {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->RollbackOneIter();
  }

  void PrepareForPrediction(int num_iteration, int predict_type) {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->SetNumIterationForPred(num_iteration);
Guolin Ke's avatar
Guolin Ke committed
114
115
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
116
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
117
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
118
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
119
      is_raw_score = true;
Guolin Ke's avatar
Guolin Ke committed
120
121
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
122
    }
Guolin Ke's avatar
Guolin Ke committed
123
    predictor_.reset(new Predictor(boosting_.get(), is_raw_score, is_predict_leaf));
Guolin Ke's avatar
Guolin Ke committed
124
125
  }

wxchan's avatar
wxchan committed
126
127
128
129
  void GetPredictAt(int data_idx, score_t* out_result, data_size_t* out_len) {
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

Guolin Ke's avatar
Guolin Ke committed
130
131
  std::vector<double> Predict(const std::vector<std::pair<int, double>>& features) {
    return predictor_->GetPredictFunction()(features);
132
133
  }

134
135
136
137
  void PredictForFile(const char* data_filename, const char* result_filename, bool data_has_header) {
    predictor_->Predict(data_filename, result_filename, data_has_header);
  }

wxchan's avatar
wxchan committed
138
139
  void SaveModelToFile(int num_iteration, const char* filename) {
    boosting_->SaveModelToFile(num_iteration, filename);
Guolin Ke's avatar
Guolin Ke committed
140
  }
141

wxchan's avatar
wxchan committed
142
143
144
  std::string DumpModel() {
    return boosting_->DumpModel();
  }
145

wxchan's avatar
wxchan committed
146
147
148
149
150
151
152
  int GetEvalCounts() const {
    int ret = 0;
    for (const auto& metric : train_metric_) {
      ret += static_cast<int>(metric->GetName().size());
    }
    return ret;
  }
153

wxchan's avatar
wxchan committed
154
155
156
157
158
159
160
161
162
163
164
165
166
  int GetEvalNames(char** out_strs) const {
    int idx = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
        std::strcpy(out_strs[idx], name.c_str());
        ++idx;
      }
    }
    return idx;
  }

  const Boosting* GetBoosting() const { return boosting_.get(); }
  
Guolin Ke's avatar
Guolin Ke committed
167
private:
168

wxchan's avatar
wxchan committed
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
  void ConstructObjectAndTrainingMetrics(const Dataset* train_data) {
    // create objective function
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
      config_.objective_config));
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective functions");
    }
    // create training metric
    train_metric_.clear();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(
        Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(train_data->metadata(), train_data->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data->metadata(), train_data->num_data());
    }
  }

  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
193
  std::unique_ptr<Boosting> boosting_;
Guolin Ke's avatar
Guolin Ke committed
194
195
196
  /*! \brief All configs */
  OverallConfig config_;
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
197
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
198
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
199
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
200
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
201
  std::unique_ptr<ObjectiveFunction> objective_fun_;
Guolin Ke's avatar
Guolin Ke committed
202
  /*! \brief Using predictor for prediction task */
Guolin Ke's avatar
Guolin Ke committed
203
  std::unique_ptr<Predictor> predictor_;
wxchan's avatar
wxchan committed
204
205
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
206
207
208
};

}
Guolin Ke's avatar
Guolin Ke committed
209
210
211

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
212
/*! \brief Return the message recorded for the most recent C API failure */
DllExport const char* LGBM_GetLastError() {
  return LastErrorMsg();
}

wxchan's avatar
wxchan committed
216
/*!
* \brief Load a dataset from a text/binary file.
* \param reference optional dataset whose feature binning the new one is aligned with
* \param out receives the new Dataset (free with LGBM_DatasetFree)
*/
DllExport int LGBM_DatasetCreateFromFile(const char* filename,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
  DatasetLoader loader(io_config, nullptr, filename);
  if (reference == nullptr) {
    *out = loader.LoadFromFile(filename);
  } else {
    *out = loader.LoadFromFileAlignWithOtherDataset(filename,
      reinterpret_cast<const Dataset*>(*reference));
  }
  API_END();
}

wxchan's avatar
wxchan committed
234
/*!
* \brief Build a dataset from a dense nrow x ncol matrix.
*        Without a reference, bin mappers are built from a random sample of rows
*        (at most bin_construct_sample_cnt); otherwise they are copied from reference.
* \param out receives the new Dataset (free with LGBM_DatasetFree)
*/
DllExport int LGBM_DatasetCreateFromMat(const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
  std::unique_ptr<Dataset> ret;
  auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  if (reference == nullptr) {
    // sample data first
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    // per-column list of the non-zero values seen in the sampled rows
    std::vector<std::vector<double>> sample_values(ncol);
    for (const auto idx : sample_indices) {
      auto row = get_row_fun(static_cast<int>(idx));
      for (size_t j = 0; j < row.size(); ++j) {
        if (std::fabs(row[j]) > 1e-15) {
          sample_values[j].push_back(row[j]);
        }
      }
    }
    DatasetLoader loader(io_config, nullptr, nullptr);
    ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow, io_config.num_class));
    ret->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(*reference),
      io_config.is_enable_sparse);
  }

#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
  }
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

wxchan's avatar
wxchan committed
283
/*!
* \brief Build a dataset from a CSR matrix with nindptr - 1 rows.
*        Without a reference, bin mappers are built from a random sample of rows;
*        otherwise they are copied from reference.
* \param num_col declared number of columns; must be >= the largest sampled feature index + 1
* \param out receives the new Dataset (free with LGBM_DatasetFree)
*/
DllExport int LGBM_DatasetCreateFromCSR(const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t num_col,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
  std::unique_ptr<Dataset> ret;
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  const int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference == nullptr) {
    // sample data first
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    // per-feature non-zero sampled values; grows to the largest feature index seen
    std::vector<std::vector<double>> sample_values;
    for (const auto idx : sample_indices) {
      auto row = get_row_fun(static_cast<int>(idx));
      for (const auto& inner_data : row) {
        if (std::fabs(inner_data.second) > 1e-15) {
          const size_t feature_idx = static_cast<size_t>(inner_data.first);
          if (feature_idx >= sample_values.size()) {
            // expand the feature set up to this index
            sample_values.resize(feature_idx + 1);
          }
          sample_values[feature_idx].push_back(inner_data.second);
        }
      }
    }
    CHECK(num_col >= static_cast<int64_t>(sample_values.size()));
    DatasetLoader loader(io_config, nullptr, nullptr);
    ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow, io_config.num_class));
    ret->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(*reference),
      io_config.is_enable_sparse);
  }

#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
  }
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

wxchan's avatar
wxchan committed
345
/*!
* \brief Build a dataset from a CSC matrix with ncol_ptr - 1 columns and num_row rows.
*        Without a reference, bin mappers are built by sampling each column;
*        otherwise they are copied from reference.
* \param out receives the new Dataset (free with LGBM_DatasetFree)
*/
DllExport int LGBM_DatasetCreateFromCSC(const void* col_ptr,
  int col_ptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t ncol_ptr,
  int64_t nelem,
  int64_t num_row,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
  std::unique_ptr<Dataset> ret;
  auto get_col_fun = ColumnFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem);
  const int32_t nrow = static_cast<int32_t>(num_row);
  const int ncol = static_cast<int>(ncol_ptr - 1);
  if (reference == nullptr) {
    Log::Warning("Construct from CSC format is not efficient");
    // sample data first
    Random rand(io_config.data_random_seed);
    const int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(ncol);
#pragma omp parallel for schedule(guided)
    for (int i = 0; i < ncol; ++i) {
      auto cur_col = get_col_fun(i);
      sample_values[i] = SampleFromOneColumn(cur_col, sample_indices);
    }
    DatasetLoader loader(io_config, nullptr, nullptr);
    ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow, io_config.num_class));
    ret->CopyFeatureMapperFrom(
      reinterpret_cast<const Dataset*>(*reference),
      io_config.is_enable_sparse);
  }

#pragma omp parallel for schedule(guided)
  for (int i = 0; i < ncol; ++i) {
    const int tid = omp_get_thread_num();
    auto one_col = get_col_fun(i);
    ret->PushOneColumn(tid, i, one_col);
  }
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

wxchan's avatar
wxchan committed
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
/*!
* \brief Create a new dataset holding only the rows listed in used_row_indices.
* \param handle pointer to the source dataset handle
* \param out receives the subset (free with LGBM_DatasetFree)
*/
DllExport int LGBM_DatasetGetSubset(
  const DatesetHandle* handle,
  const int32_t* used_row_indices,
  int32_t num_used_row_indices,
  const char* parameters,
  DatesetHandle* out) {
  API_BEGIN();
  const auto parsed = ConfigBase::Str2Map(parameters);
  IOConfig subset_config;
  subset_config.Set(parsed);
  const Dataset* source = reinterpret_cast<const Dataset*>(*handle);
  std::unique_ptr<Dataset> subset(source->Subset(
    used_row_indices, num_used_row_indices, subset_config.is_enable_sparse));
  subset->FinishLoad();
  *out = subset.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
415
416
417
418
419
420
421
422
423
424
425
426
427
428
/*!
* \brief Copy num_feature_names C strings into the dataset's feature names.
*/
DllExport int LGBM_DatasetSetFeatureNames(
  DatesetHandle handle,
  const char** feature_names,
  int64_t num_feature_names) {
  API_BEGIN();
  Dataset* target = reinterpret_cast<Dataset*>(handle);
  std::vector<std::string> names;
  int64_t k = 0;
  while (k < num_feature_names) {
    names.push_back(std::string(feature_names[k]));
    ++k;
  }
  target->set_feature_names(names);
  API_END();
}

429
/*! \brief Delete a Dataset previously returned through a DatesetHandle */
DllExport int LGBM_DatasetFree(DatesetHandle handle) {
  API_BEGIN();
  delete reinterpret_cast<Dataset*>(handle);
  API_END();
}

/*! \brief Save the dataset to filename via Dataset::SaveBinaryFile */
DllExport int LGBM_DatasetSaveBinary(DatesetHandle handle,
  const char* filename) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  dataset->SaveBinaryFile(filename);
  API_END();
}

/*!
* \brief Set a named field of the dataset.
* \param field_data num_element values of the type given by `type`
* \param type C_API_DTYPE_FLOAT32 or C_API_DTYPE_INT32; any other value fails
*/
DllExport int LGBM_DatasetSetField(DatesetHandle handle,
  const char* field_name,
  const void* field_data,
  int64_t num_element,
  int type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  bool is_success = false;
  if (type == C_API_DTYPE_FLOAT32) {
    is_success = dataset->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), static_cast<int32_t>(num_element));
  } else if (type == C_API_DTYPE_INT32) {
    is_success = dataset->SetIntField(field_name, reinterpret_cast<const int*>(field_data), static_cast<int32_t>(num_element));
  }
  if (!is_success) { throw std::runtime_error("Input data type error or field not found"); }
  API_END();
}

/*!
* \brief Look up a named field, trying the float fields first, then the int fields.
* \param out_ptr receives the field's data pointer; when it is null, out_len is set to 0
* \param out_type receives C_API_DTYPE_FLOAT32 or C_API_DTYPE_INT32
*/
DllExport int LGBM_DatasetGetField(DatesetHandle handle,
  const char* field_name,
  int64_t* out_len,
  const void** out_ptr,
  int* out_type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  bool is_success = false;
  if (dataset->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
    is_success = true;
  } else if (dataset->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
    is_success = true;
  }
  if (!is_success) { throw std::runtime_error("Field not found"); }
  if (*out_ptr == nullptr) { *out_len = 0; }
  API_END();
}

/*! \brief Number of rows in the dataset */
DllExport int LGBM_DatasetGetNumData(DatesetHandle handle,
  int64_t* out) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_data();
  API_END();
}

/*! \brief Total number of features of the dataset (Dataset::num_total_features) */
DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle,
  int64_t* out) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_total_features();
  API_END();
}
495
496
497
498
499
500
501


// ---- start of booster

/*!
* \brief Create a booster that trains on train_data with the given parameters.
* \param out receives the new Booster (free with LGBM_BoosterFree)
*/
DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
  const char* parameters,
  BoosterHandle* out) {
  API_BEGIN();
  const Dataset* p_train_data = reinterpret_cast<const Dataset*>(train_data);
  auto ret = std::unique_ptr<Booster>(new Booster(p_train_data, parameters));
  *out = ret.release();
  API_END();
}

wxchan's avatar
wxchan committed
509
/*!
* \brief Load a booster from a model file.
* \param out_num_iterations receives NumberOfTotalModel() / NumberOfClasses()
* \param out receives the new Booster (free with LGBM_BoosterFree)
*/
DllExport int LGBM_BoosterCreateFromModelfile(
  const char* filename,
  int64_t* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
  auto ret = std::unique_ptr<Booster>(new Booster(filename));
  const Boosting* boosting = ret->GetBoosting();
  *out_num_iterations = static_cast<int64_t>(boosting->NumberOfTotalModel()
    / boosting->NumberOfClasses());
  *out = ret.release();
  API_END();
}

/*! \brief Delete a Booster previously returned through a BoosterHandle */
DllExport int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  delete reinterpret_cast<Booster*>(handle);
  API_END();
}

wxchan's avatar
wxchan committed
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
/*! \brief Merge the models of other_handle into handle (Booster::MergeFrom) */
DllExport int LGBM_BoosterMerge(BoosterHandle handle,
  BoosterHandle other_handle) {
  API_BEGIN();
  const Booster* source = reinterpret_cast<Booster*>(other_handle);
  reinterpret_cast<Booster*>(handle)->MergeFrom(source);
  API_END();
}

/*! \brief Attach a validation dataset to the booster (Booster::AddValidData) */
DllExport int LGBM_BoosterAddValidData(BoosterHandle handle,
  const DatesetHandle valid_data) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  booster->AddValidData(reinterpret_cast<const Dataset*>(valid_data));
  API_END();
}

/*! \brief Replace the booster's training dataset (Booster::ResetTrainingData) */
DllExport int LGBM_BoosterResetTrainingData(BoosterHandle handle,
  const DatesetHandle train_data) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  booster->ResetTrainingData(reinterpret_cast<const Dataset*>(train_data));
  API_END();
}

/*! \brief Apply a new parameter string to the booster (Booster::ResetConfig) */
DllExport int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ResetConfig(parameters);
  API_END();
}

/*! \brief Number of classes of the underlying boosting model */
DllExport int LGBM_BoosterGetNumClasses(BoosterHandle handle, int64_t* out_len) {
  API_BEGIN();
  const Boosting* boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->NumberOfClasses();
  API_END();
}

568
/*!
* \brief Run one training iteration with the configured objective.
* \param is_finished set to 1 when Booster::TrainOneIter returns true, else 0
*/
DllExport int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  *is_finished = ref_booster->TrainOneIter() ? 1 : 0;
  API_END();
}

/*!
* \brief Run one training iteration with caller-supplied gradients and hessians.
* \param is_finished set to 1 when Booster::TrainOneIter returns true, else 0
*/
DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
  const float* grad,
  const float* hess,
  int* is_finished) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  *is_finished = ref_booster->TrainOneIter(grad, hess) ? 1 : 0;
  API_END();
}

wxchan's avatar
wxchan committed
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
/*! \brief Undo the most recent training iteration (Booster::RollbackOneIter) */
DllExport int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

/*! \brief Current iteration as reported by Boosting::GetCurrentIteration */
DllExport int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int64_t* out_iteration) {
  API_BEGIN();
  const Boosting* boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_iteration = boosting->GetCurrentIteration();
  API_END();
}
/*!
* \brief Get the number of evaluation results for the training data,
*        i.e. the total number of metric names across all training metrics.
* \param handle booster handle
* \param out_len receives the count
*/
DllExport int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int64_t* out_len) {
  API_BEGIN();
  const Booster* booster = reinterpret_cast<const Booster*>(handle);
  const int count = booster->GetEvalCounts();
  *out_len = count;
  API_END();
}

/*!
* \brief Get the names of the evaluation results for the training data.
* \param handle booster handle
* \param out_len receives the number of names written
* \param out_strs caller-allocated array of char buffers; each name is copied with
*        strcpy and no bounds checking, so the caller must provide at least
*        LGBM_BoosterGetEvalCounts() buffers, each large enough for its name
*/
DllExport int LGBM_BoosterGetEvalNames(BoosterHandle handle, int64_t* out_len, char** out_strs) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  *out_len = ref_booster->GetEvalNames(out_strs);
  API_END();
}


/*!
* \brief Copy the evaluation results of dataset data_idx into out_results.
* \param out_len receives the number of values written
* \param out_results caller-allocated buffer (values narrowed to float)
*/
DllExport int LGBM_BoosterGetEval(BoosterHandle handle,
  int data_idx,
  int64_t* out_len,
  float* out_results) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  auto boosting = ref_booster->GetBoosting();
  const auto result_buf = boosting->GetEvalAt(data_idx);
  *out_len = static_cast<int64_t>(result_buf.size());
  for (size_t i = 0; i < result_buf.size(); ++i) {
    out_results[i] = static_cast<float>(result_buf[i]);
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
644
/*!
* \brief Copy the current scores of dataset data_idx into out_result.
* \param out_len receives the number of values written
*/
DllExport int LGBM_BoosterGetPredict(BoosterHandle handle,
  int data_idx,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  // matches the data_size_t* parameter of Booster::GetPredictAt
  data_size_t len = 0;
  ref_booster->GetPredictAt(data_idx, out_result, &len);
  *out_len = static_cast<int64_t>(len);
  API_END();
}

656
657
/*!
* \brief Predict every row of data_filename and write the results to result_filename.
* \param data_has_header treated as true when > 0
*/
DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle,
  const char* data_filename,
  int data_has_header,
  int predict_type,
  int64_t num_iteration,
  const char* result_filename) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  ref_booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);
  const bool bool_data_has_header = data_has_header > 0;
  ref_booster->PredictForFile(data_filename, result_filename, bool_data_has_header);
  API_END();
}

670
/*!
* \brief Predict every row of a CSR matrix (nindptr - 1 rows) into out_result.
*        Row i's values start at out_result[i * k], where k is NumberOfClasses(),
*        multiplied by the iteration count for C_API_PREDICT_LEAF_INDEX.
* \param out_len receives nrow * k
*/
DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
  const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t,
  int predict_type,
  int64_t num_iteration,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  ref_booster->PrepareForPrediction(static_cast<int>(num_iteration), predict_type);

  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  int num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
  if (predict_type == C_API_PREDICT_LEAF_INDEX) {
    if (num_iteration > 0) {
      num_preb_in_one_row *= static_cast<int>(num_iteration);
    } else {
      num_preb_in_one_row *= ref_booster->GetBoosting()->NumberOfTotalModel() / num_preb_in_one_row;
    }
  }
  const int nrow = static_cast<int>(nindptr - 1);
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    auto one_row = get_row_fun(i);
    auto prediction = ref_booster->Predict(one_row);
    // 64-bit offset: i * num_preb_in_one_row can exceed INT_MAX for large outputs
    const int64_t offset = static_cast<int64_t>(i) * num_preb_in_one_row;
    for (int j = 0; j < static_cast<int>(prediction.size()); ++j) {
      out_result[offset + j] = static_cast<float>(prediction[j]);
    }
  }
  *out_len = static_cast<int64_t>(nrow) * num_preb_in_one_row;
  API_END();
}
708
709
710

743
744

DllExport int LGBM_BoosterSaveModel(BoosterHandle handle,
wxchan's avatar
wxchan committed
745
  int num_iteration,
Guolin Ke's avatar
Guolin Ke committed
746
  const char* filename) {
747
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
748
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
wxchan's avatar
wxchan committed
749
750
751
752
753
754
755
756
757
758
759
  ref_booster->SaveModelToFile(num_iteration, filename);
  API_END();
}

DllExport int LGBM_BoosterDumpModel(BoosterHandle handle,
  int buffer_len,
  int64_t* out_len,
  char** out_str) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  std::string model = ref_booster->DumpModel();
Guolin Ke's avatar
Guolin Ke committed
760
  *out_len = static_cast<int64_t>(model.size()) + 1;
wxchan's avatar
wxchan committed
761
762
763
  if (*out_len <= buffer_len) {
    std::strcpy(*out_str, model.c_str());
  }
764
  API_END();
Guolin Ke's avatar
Guolin Ke committed
765
}
766

Guolin Ke's avatar
Guolin Ke committed
767
// ---- start of some help functions
768
769
770

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
771
  if (data_type == C_API_DTYPE_FLOAT32) {
772
773
774
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
775
        std::vector<double> ret(num_col);
776
777
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
778
          ret[i] = static_cast<double>(*(tmp_ptr + i));
779
780
781
782
783
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
784
        std::vector<double> ret(num_col);
785
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
786
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
787
788
789
790
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
791
  } else if (data_type == C_API_DTYPE_FLOAT64) {
792
793
794
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
795
        std::vector<double> ret(num_col);
796
797
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
798
          ret[i] = static_cast<double>(*(tmp_ptr + i));
799
800
801
802
803
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
804
        std::vector<double> ret(num_col);
805
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
806
          ret[i] = static_cast<double>(*(data_ptr + num_row * i + row_idx));
807
808
809
810
811
        }
        return ret;
      };
    }
  }
Guolin Ke's avatar
Guolin Ke committed
812
  throw std::runtime_error("unknown data type in RowFunctionFromDenseMatric");
813
814
815
816
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
817
818
819
820
821
822
823
824
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
    return [inner_function](int row_idx) {
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
        if (std::fabs(raw_values[i]) > 1e-15) {
          ret.emplace_back(i, raw_values[i]);
825
        }
Guolin Ke's avatar
Guolin Ke committed
826
827
828
      }
      return ret;
    };
829
  }
Guolin Ke's avatar
Guolin Ke committed
830
  return nullptr;
831
832
833
834
}

std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
835
  if (data_type == C_API_DTYPE_FLOAT32) {
836
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
837
    if (indptr_type == C_API_DTYPE_INT32) {
838
839
840
841
842
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
843
        for (int64_t i = start; i < end; ++i) {
844
845
846
847
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
848
    } else if (indptr_type == C_API_DTYPE_INT64) {
849
850
851
852
853
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
854
        for (int64_t i = start; i < end; ++i) {
855
856
857
858
859
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
860
  } else if (data_type == C_API_DTYPE_FLOAT64) {
861
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
862
    if (indptr_type == C_API_DTYPE_INT32) {
863
864
865
866
867
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
868
        for (int64_t i = start; i < end; ++i) {
869
870
871
872
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
873
    } else if (indptr_type == C_API_DTYPE_INT64) {
874
875
876
877
878
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
879
        for (int64_t i = start; i < end; ++i) {
880
881
882
883
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
884
885
886
    } 
  } 
  throw std::runtime_error("unknown data type in RowFunctionFromCSR");
887
888
889
890
}

std::function<std::vector<std::pair<int, double>>(int idx)>
ColumnFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
891
  if (data_type == C_API_DTYPE_FLOAT32) {
892
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
893
    if (col_ptr_type == C_API_DTYPE_INT32) {
894
895
896
897
898
899
900
901
902
903
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
904
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
905
906
907
908
909
910
911
912
913
914
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
915
    } 
Guolin Ke's avatar
Guolin Ke committed
916
  } else if (data_type == C_API_DTYPE_FLOAT64) {
917
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
918
    if (col_ptr_type == C_API_DTYPE_INT32) {
919
920
921
922
923
924
925
926
927
928
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
929
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
930
931
932
933
934
935
936
937
938
939
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
940
941
942
    } 
  } 
  throw std::runtime_error("unknown data type in ColumnFunctionFromCSC");
943
944
}

945
std::vector<double> SampleFromOneColumn(const std::vector<std::pair<int, double>>& data, const std::vector<int>& indices) {
946
947
948
949
950
951
952
953
954
955
956
  size_t j = 0;
  std::vector<double> ret;
  for (auto row_idx : indices) {
    while (j < data.size() && data[j].first < static_cast<int>(row_idx)) {
      ++j;
    }
    if (j < data.size() && data[j].first == static_cast<int>(row_idx)) {
      ret.push_back(data[j].second);
    }
  }
  return ret;
Guolin Ke's avatar
Guolin Ke committed
957
}