c_api.cpp 89.4 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/c_api.h>
Guolin Ke's avatar
Guolin Ke committed
6

Guolin Ke's avatar
Guolin Ke committed
7
8
#include <LightGBM/boosting.h>
#include <LightGBM/config.h>
9
10
11
#include <LightGBM/dataset.h>
#include <LightGBM/dataset_loader.h>
#include <LightGBM/metric.h>
12
#include <LightGBM/network.h>
13
14
15
16
17
18
19
#include <LightGBM/objective_function.h>
#include <LightGBM/prediction_early_stop.h>
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>
#include <LightGBM/utils/random.h>
#include <LightGBM/utils/threading.h>
Guolin Ke's avatar
Guolin Ke committed
20

21
22
23
24
25
26
27
28
#include <string>
#include <cstdio>
#include <functional>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <vector>

29
#include "application/predictor.hpp"
Guolin Ke's avatar
Guolin Ke committed
30

Guolin Ke's avatar
Guolin Ke committed
31
32
namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
// Stores the what() text of a caught std::exception through LGBM_SetLastError
// and returns -1.
inline int LGBM_APIHandleException(const std::exception& ex) {
  const char* message = ex.what();
  LGBM_SetLastError(message);
  return -1;
}
// Stores an error message carried as std::string through LGBM_SetLastError
// and returns -1.
inline int LGBM_APIHandleException(const std::string& ex) {
  const char* message = ex.c_str();
  LGBM_SetLastError(message);
  return -1;
}

// Opens a try block; must be paired with a following API_END().
#define API_BEGIN() try {
// Closes the try block opened by API_BEGIN(). A thrown std::exception,
// std::string, or any other object is passed to LGBM_APIHandleException and
// its result (-1) is returned; when nothing is thrown the expansion ends with
// `return 0;`.
#define API_END() } \
catch(std::exception& ex) { return LGBM_APIHandleException(ex); } \
catch(std::string& ex) { return LGBM_APIHandleException(ex); } \
catch(...) { return LGBM_APIHandleException("unknown exception"); } \
return 0;

49
50
51
52
53
54
55
56
// NOTE(review): presumably the number of distinct C_API_PREDICT_* prediction
// types, used to size per-type predictor storage — confirm against its users.
const int PREDICTOR_TYPES = 4;

// Single row predictor to abstract away caching logic
class SingleRowPredictor {
 public:
  PredictFunction predict_function;
  int64_t num_pred_in_one_row;

Guolin Ke's avatar
Guolin Ke committed
57
  SingleRowPredictor(int predict_type, Boosting* boosting, const Config& config, int iter) {
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
    bool is_predict_leaf = false;
    bool is_raw_score = false;
    bool predict_contrib = false;
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
      is_predict_leaf = true;
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
      is_raw_score = true;
    } else if (predict_type == C_API_PREDICT_CONTRIB) {
      predict_contrib = true;
    } else {
      is_raw_score = false;
    }
    early_stop_ = config.pred_early_stop;
    early_stop_freq_ = config.pred_early_stop_freq;
    early_stop_margin_ = config.pred_early_stop_margin;
    iter_ = iter;
Guolin Ke's avatar
Guolin Ke committed
74
    predictor_.reset(new Predictor(boosting, iter_, is_raw_score, is_predict_leaf, predict_contrib,
75
                                   early_stop_, early_stop_freq_, early_stop_margin_));
Guolin Ke's avatar
Guolin Ke committed
76
    num_pred_in_one_row = boosting->NumPredictOneRow(iter_, is_predict_leaf, predict_contrib);
77
    predict_function = predictor_->GetPredictFunction();
Guolin Ke's avatar
Guolin Ke committed
78
    num_total_model_ = boosting->NumberOfTotalModel();
79
  }
80

81
  ~SingleRowPredictor() {}
82

Guolin Ke's avatar
Guolin Ke committed
83
  bool IsPredictorEqual(const Config& config, int iter, Boosting* boosting) {
84
85
86
87
88
    return early_stop_ == config.pred_early_stop &&
      early_stop_freq_ == config.pred_early_stop_freq &&
      early_stop_margin_ == config.pred_early_stop_margin &&
      iter_ == iter &&
      num_total_model_ == boosting->NumberOfTotalModel();
89
  }
Guolin Ke's avatar
Guolin Ke committed
90

91
92
93
94
95
96
97
98
99
 private:
  std::unique_ptr<Predictor> predictor_;
  bool early_stop_;
  int early_stop_freq_;
  double early_stop_margin_;
  int iter_;
  int num_total_model_;
};

Guolin Ke's avatar
Guolin Ke committed
100
class Booster {
Nikita Titov's avatar
Nikita Titov committed
101
 public:
Guolin Ke's avatar
Guolin Ke committed
102
  explicit Booster(const char* filename) {
103
    boosting_.reset(Boosting::CreateBoosting("gbdt", filename));
104
105
  }

Guolin Ke's avatar
Guolin Ke committed
106
  Booster(const Dataset* train_data,
107
          const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
108
    auto param = Config::Str2Map(parameters);
wxchan's avatar
wxchan committed
109
    config_.Set(param);
110
111
112
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
113
    // create boosting
Guolin Ke's avatar
Guolin Ke committed
114
    if (config_.input_model.size() > 0) {
115
116
      Log::Warning("Continued train from model is not supported for c_api,\n"
                   "please use continued train with input score");
Guolin Ke's avatar
Guolin Ke committed
117
    }
Guolin Ke's avatar
Guolin Ke committed
118

Guolin Ke's avatar
Guolin Ke committed
119
    boosting_.reset(Boosting::CreateBoosting(config_.boosting, nullptr));
Guolin Ke's avatar
Guolin Ke committed
120

121
122
    train_data_ = train_data;
    CreateObjectiveAndMetrics();
Guolin Ke's avatar
Guolin Ke committed
123
    // initialize the boosting
Guolin Ke's avatar
Guolin Ke committed
124
    if (config_.tree_learner == std::string("feature")) {
125
      Log::Fatal("Do not support feature parallel in c api");
126
    }
Guolin Ke's avatar
Guolin Ke committed
127
    if (Network::num_machines() == 1 && config_.tree_learner != std::string("serial")) {
128
      Log::Warning("Only find one worker, will switch to serial tree learner");
Guolin Ke's avatar
Guolin Ke committed
129
      config_.tree_learner = "serial";
130
    }
Guolin Ke's avatar
Guolin Ke committed
131
    boosting_->Init(&config_, train_data_, objective_fun_.get(),
132
                    Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
wxchan's avatar
wxchan committed
133
134
135
136
137
  }

  // Merges the boosting object of `other` into this booster's boosting object
  // while holding this booster's mutex.
  void MergeFrom(const Booster* other) {
    std::lock_guard<std::mutex> guard(mutex_);
    auto* other_boosting = other->boosting_.get();
    boosting_->MergeFrom(other_boosting);
  }

  // Nothing to release by hand; members clean up after themselves.
  ~Booster() = default;
142

143
  void CreateObjectiveAndMetrics() {
Guolin Ke's avatar
Guolin Ke committed
144
    // create objective function
Guolin Ke's avatar
Guolin Ke committed
145
146
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective,
                                                                    config_));
Guolin Ke's avatar
Guolin Ke committed
147
148
149
150
151
152
153
154
155
156
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective function");
    }
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
    }

    // create training metric
    train_metric_.clear();
Guolin Ke's avatar
Guolin Ke committed
157
    for (auto metric_type : config_.metric) {
Guolin Ke's avatar
Guolin Ke committed
158
      auto metric = std::unique_ptr<Metric>(
Guolin Ke's avatar
Guolin Ke committed
159
        Metric::CreateMetric(metric_type, config_));
Guolin Ke's avatar
Guolin Ke committed
160
161
162
163
164
      if (metric == nullptr) { continue; }
      metric->Init(train_data_->metadata(), train_data_->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
165
166
167
168
169
170
171
172
173
174
175
  }

  // Switches training to `train_data`. Does nothing when it is the same
  // pointer that is already in use.
  void ResetTrainingData(const Dataset* train_data) {
    if (train_data == train_data_) {
      return;
    }
    std::lock_guard<std::mutex> guard(mutex_);
    train_data_ = train_data;
    // Objective and metrics are initialized against the dataset, so rebuild them.
    CreateObjectiveAndMetrics();
    // reset the boosting
    boosting_->ResetTrainingData(train_data_, objective_fun_.get(),
                                 Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
  }

178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
  static void CheckDatasetResetConfig(
      const Config& old_config,
      const std::unordered_map<std::string, std::string>& new_param) {
    Config new_config;
    new_config.Set(new_param);
    if (new_param.count("data_random_seed") &&
        new_config.data_random_seed != old_config.data_random_seed) {
      Log::Fatal("Cannot change data_random_seed after constructed Dataset handle.");
    }
    if (new_param.count("max_bin") &&
        new_config.max_bin != old_config.max_bin) {
      Log::Fatal("Cannot change max_bin after constructed Dataset handle.");
    }
    if (new_param.count("max_bin_by_feature") &&
        new_config.max_bin_by_feature != old_config.max_bin_by_feature) {
      Log::Fatal(
          "Cannot change max_bin_by_feature after constructed Dataset handle.");
    }
    if (new_param.count("bin_construct_sample_cnt") &&
        new_config.bin_construct_sample_cnt !=
            old_config.bin_construct_sample_cnt) {
      Log::Fatal(
          "Cannot change bin_construct_sample_cnt after constructed Dataset "
          "handle.");
    }
    if (new_param.count("min_data_in_bin") &&
        new_config.min_data_in_bin != old_config.min_data_in_bin) {
      Log::Fatal(
          "Cannot change min_data_in_bin after constructed Dataset handle.");
    }
    if (new_param.count("use_missing") &&
        new_config.use_missing != old_config.use_missing) {
      Log::Fatal("Cannot change use_missing after constructed Dataset handle.");
    }
    if (new_param.count("zero_as_missing") &&
        new_config.zero_as_missing != old_config.zero_as_missing) {
      Log::Fatal(
          "Cannot change zero_as_missing after constructed Dataset handle.");
    }
    if (new_param.count("categorical_feature") &&
        new_config.categorical_feature != old_config.categorical_feature) {
      Log::Fatal(
          "Cannot change categorical_feature after constructed Dataset "
          "handle.");
    }
    if (new_param.count("feature_pre_filter") &&
        new_config.feature_pre_filter != old_config.feature_pre_filter) {
      Log::Fatal(
          "Cannot change feature_pre_filter after constructed Dataset handle.");
    }
    if (new_param.count("is_enable_sparse") &&
        new_config.is_enable_sparse != old_config.is_enable_sparse) {
      Log::Fatal(
          "Cannot change is_enable_sparse after constructed Dataset handle.");
    }
    if (new_param.count("pre_partition") &&
        new_config.pre_partition != old_config.pre_partition) {
      Log::Fatal(
          "Cannot change pre_partition after constructed Dataset handle.");
    }
    if (new_param.count("enable_bundle") &&
        new_config.enable_bundle != old_config.enable_bundle) {
      Log::Fatal(
          "Cannot change enable_bundle after constructed Dataset handle.");
    }
    if (new_param.count("header") && new_config.header != old_config.header) {
      Log::Fatal("Cannot change header after constructed Dataset handle.");
    }
    if (new_param.count("two_round") &&
        new_config.two_round != old_config.two_round) {
      Log::Fatal("Cannot change two_round after constructed Dataset handle.");
    }
    if (new_param.count("label_column") &&
        new_config.label_column != old_config.label_column) {
      Log::Fatal(
          "Cannot change label_column after constructed Dataset handle.");
    }
    if (new_param.count("weight_column") &&
        new_config.weight_column != old_config.weight_column) {
      Log::Fatal(
          "Cannot change weight_column after constructed Dataset handle.");
    }
    if (new_param.count("group_column") &&
        new_config.group_column != old_config.group_column) {
      Log::Fatal(
          "Cannot change group_column after constructed Dataset handle.");
    }
    if (new_param.count("ignore_column") &&
        new_config.ignore_column != old_config.ignore_column) {
      Log::Fatal(
          "Cannot change ignore_column after constructed Dataset handle.");
    }
    if (new_param.count("forcedbins_filename")) {
      Log::Fatal("Cannot change forced bins after constructed Dataset handle.");
    }
    if (new_param.count("min_data_in_leaf") &&
        new_config.min_data_in_leaf < old_config.min_data_in_leaf &&
        old_config.feature_pre_filter) {
      Log::Fatal(
          "Reducing `min_data_in_leaf` with `feature_pre_filter=true` may "
          "cause unexpected behaviour "
          "for features that were pre-filtered by the larger "
          "`min_data_in_leaf`.\n"
          "You need to set `feature_pre_filter=false` to dynamically change "
          "the `min_data_in_leaf`.");
    }
  }

wxchan's avatar
wxchan committed
286
  void ResetConfig(const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
287
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
288
    auto param = Config::Str2Map(parameters);
wxchan's avatar
wxchan committed
289
    if (param.count("num_class")) {
290
      Log::Fatal("Cannot change num_class during training");
wxchan's avatar
wxchan committed
291
    }
Guolin Ke's avatar
Guolin Ke committed
292
293
    if (param.count("boosting")) {
      Log::Fatal("Cannot change boosting during training");
wxchan's avatar
wxchan committed
294
    }
Guolin Ke's avatar
Guolin Ke committed
295
    if (param.count("metric")) {
296
      Log::Fatal("Cannot change metric during training");
Guolin Ke's avatar
Guolin Ke committed
297
    }
298
299
    CheckDatasetResetConfig(config_, param);

Guolin Ke's avatar
Guolin Ke committed
300
    config_.Set(param);
301

302
303
304
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
305
306
307

    if (param.count("objective")) {
      // create objective function
Guolin Ke's avatar
Guolin Ke committed
308
309
      objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective,
                                                                      config_));
Guolin Ke's avatar
Guolin Ke committed
310
311
312
313
314
315
316
      if (objective_fun_ == nullptr) {
        Log::Warning("Using self-defined objective function");
      }
      // initialize the objective function
      if (objective_fun_ != nullptr) {
        objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
      }
317
318
      boosting_->ResetTrainingData(train_data_,
                                   objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
wxchan's avatar
wxchan committed
319
    }
Guolin Ke's avatar
Guolin Ke committed
320

Guolin Ke's avatar
Guolin Ke committed
321
    boosting_->ResetConfig(&config_);
wxchan's avatar
wxchan committed
322
323
324
325
326
  }

  void AddValidData(const Dataset* valid_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    valid_metrics_.emplace_back();
Guolin Ke's avatar
Guolin Ke committed
327
328
    for (auto metric_type : config_.metric) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_));
wxchan's avatar
wxchan committed
329
330
331
332
333
334
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
335
                               Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
wxchan's avatar
wxchan committed
336
  }
Guolin Ke's avatar
Guolin Ke committed
337

338
  bool TrainOneIter() {
wxchan's avatar
wxchan committed
339
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
340
    return boosting_->TrainOneIter(nullptr, nullptr);
341
342
  }

Guolin Ke's avatar
Guolin Ke committed
343
344
345
346
347
348
349
350
351
352
353
  // Copies a flat, row-major leaf-prediction buffer into a vector of rows and
  // passes it to the boosting object's RefitTree.
  //   leaf_preds: nrow * ncol values; the entry for (row i, column j) is read
  //               from leaf_preds[i * ncol + j]
  //   nrow, ncol: dimensions of that buffer
  void Refit(const int32_t* leaf_preds, int32_t nrow, int32_t ncol) {
    std::lock_guard<std::mutex> lock(mutex_);
    std::vector<std::vector<int32_t>> v_leaf_preds(nrow, std::vector<int32_t>(ncol, 0));
    // Compute offsets in size_t: the product i * ncol does not fit in a 32-bit
    // int once the buffer holds more than INT_MAX elements.
    const size_t row_stride = static_cast<size_t>(ncol);
    for (int32_t i = 0; i < nrow; ++i) {
      const int32_t* row_begin = leaf_preds + static_cast<size_t>(i) * row_stride;
      v_leaf_preds[i].assign(row_begin, row_begin + row_stride);
    }
    boosting_->RefitTree(v_leaf_preds);
  }

354
  bool TrainOneIter(const score_t* gradients, const score_t* hessians) {
wxchan's avatar
wxchan committed
355
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
356
    return boosting_->TrainOneIter(gradients, hessians);
357
358
  }

wxchan's avatar
wxchan committed
359
360
361
362
363
  // Forwards to the boosting object's RollbackOneIter while holding the mutex.
  void RollbackOneIter() {
    std::lock_guard<std::mutex> guard(mutex_);
    auto* model = boosting_.get();
    model->RollbackOneIter();
  }

364
  void PredictSingleRow(int num_iteration, int predict_type, int ncol,
365
366
367
               std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
               const Config& config,
               double* out_result, int64_t* out_len) {
368
369
370
    if (!config.predict_disable_shape_check && ncol != boosting_->MaxFeatureIdx() + 1) {
      Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).\n"\
                 "You can set ``predict_disable_shape_check=true`` to discard this error, but please be aware what you are doing.", ncol, boosting_->MaxFeatureIdx() + 1);
371
    }
372
    std::lock_guard<std::mutex> lock(mutex_);
373
    if (single_row_predictor_[predict_type].get() == nullptr ||
Guolin Ke's avatar
Guolin Ke committed
374
375
        !single_row_predictor_[predict_type]->IsPredictorEqual(config, num_iteration, boosting_.get())) {
      single_row_predictor_[predict_type].reset(new SingleRowPredictor(predict_type, boosting_.get(),
376
                                                                       config, num_iteration));
377
378
379
    }
    auto one_row = get_row_fun(0);
    auto pred_wrt_ptr = out_result;
380
    single_row_predictor_[predict_type]->predict_function(one_row, pred_wrt_ptr);
381

382
    *out_len = single_row_predictor_[predict_type]->num_pred_in_one_row;
383
384
  }

385
  // Validates the input column count (unless config.predict_disable_shape_check
  // is set) and returns a Predictor configured for `predict_type`.
  // Does not take mutex_ itself.
  Predictor CreatePredictor(int num_iteration, int predict_type, int ncol, const Config& config) {
    if (!config.predict_disable_shape_check) {
      const auto expected_ncol = boosting_->MaxFeatureIdx() + 1;
      if (ncol != expected_ncol) {
        Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).\n"
                   "You can set ``predict_disable_shape_check=true`` to discard this error, but please be aware what you are doing.", ncol, expected_ncol);
      }
    }

    // At most one mode flag is set; any other predict_type leaves all of them false.
    const bool leaf_index_mode = (predict_type == C_API_PREDICT_LEAF_INDEX);
    const bool raw_score_mode = !leaf_index_mode && (predict_type == C_API_PREDICT_RAW_SCORE);
    const bool contrib_mode =
        !leaf_index_mode && !raw_score_mode && (predict_type == C_API_PREDICT_CONTRIB);

    Predictor result(boosting_.get(), num_iteration, raw_score_mode, leaf_index_mode, contrib_mode,
                     config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
    return result;
  }

  void Predict(int num_iteration, int predict_type, int nrow, int ncol,
               std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
               const Config& config,
               double* out_result, int64_t* out_len) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto predictor = CreatePredictor(num_iteration, predict_type, ncol, config);
    bool is_predict_leaf = false;
    bool predict_contrib = false;
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
      is_predict_leaf = true;
    } else if (predict_type == C_API_PREDICT_CONTRIB) {
      predict_contrib = true;
    }
Guolin Ke's avatar
Guolin Ke committed
421
    int64_t num_pred_in_one_row = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf, predict_contrib);
Guolin Ke's avatar
Guolin Ke committed
422
    auto pred_fun = predictor.GetPredictFunction();
423
424
    OMP_INIT_EX();
    #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
425
    for (int i = 0; i < nrow; ++i) {
426
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
427
      auto one_row = get_row_fun(i);
Tony-Y's avatar
Tony-Y committed
428
      auto pred_wrt_ptr = out_result + static_cast<size_t>(num_pred_in_one_row) * i;
Guolin Ke's avatar
Guolin Ke committed
429
      pred_fun(one_row, pred_wrt_ptr);
430
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
431
    }
432
    OMP_THROW_EX();
433
    *out_len = num_pred_in_one_row * nrow;
Guolin Ke's avatar
Guolin Ke committed
434
435
  }

436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
  // Shared first stage of the sparse prediction paths: runs the sparse
  // predictor over every row, counts the stored entries, and allocates the
  // output data and indices arrays.
  //   agg_ptr:             must already hold `nrow` elements; element i is
  //                        replaced by `num_matrices` maps of column -> value
  //   out_elements_size:   receives the total number of entries in `agg`
  //   out_data:            receives a new float[] or double[] of that size,
  //                        chosen by `data_type`
  //   out_indices:         receives a new int32_t[] of that size
  //   is_data_float32_ptr: set to true only when `data_type` is float32
  // Raises Log::Fatal for an unknown `data_type`.
  // This function does not take mutex_; PredictSparseCSR and PredictSparseCSC
  // call it while already holding the lock.
  void PredictSparse(int num_iteration, int predict_type, int64_t nrow, int ncol,
                     std::function<std::vector<std::pair<int, double>>(int64_t row_idx)> get_row_fun,
                     const Config& config, int64_t* out_elements_size,
                     std::vector<std::vector<std::unordered_map<int, double>>>* agg_ptr,
                     int32_t** out_indices, void** out_data, int data_type,
                     bool* is_data_float32_ptr, int num_matrices) {
    auto predictor = CreatePredictor(num_iteration, predict_type, ncol, config);
    auto pred_sparse_fun = predictor.GetPredictSparseFunction();
    std::vector<std::vector<std::unordered_map<int, double>>>& agg = *agg_ptr;
    OMP_INIT_EX();
    #pragma omp parallel for schedule(static)
    for (int64_t i = 0; i < nrow; ++i) {
      OMP_LOOP_EX_BEGIN();
      auto one_row = get_row_fun(i);
      agg[i] = std::vector<std::unordered_map<int, double>>(num_matrices);
      pred_sparse_fun(one_row, &agg[i]);
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();
    // calculate the nonzero data and indices size
    // (iterate by reference: copying a row would duplicate all of its maps)
    int64_t elements_size = 0;
    for (const auto& row_vector : agg) {
      for (const auto& matrix_entries : row_vector) {
        elements_size += static_cast<int64_t>(matrix_entries.size());
      }
    }
    *out_elements_size = elements_size;
    *is_data_float32_ptr = false;
    // allocate data and indices arrays
    if (data_type == C_API_DTYPE_FLOAT32) {
      *out_data = new float[elements_size];
      *is_data_float32_ptr = true;
    } else if (data_type == C_API_DTYPE_FLOAT64) {
      *out_data = new double[elements_size];
    } else {
      Log::Fatal("Unknown data type in PredictSparse");
      return;
    }
    *out_indices = new int32_t[elements_size];
  }

  // Runs sparse prediction and writes the result as `num_matrices` CSR
  // matrices stored back to back.
  //   out_indptr:  receives a new int32_t[] or int64_t[] (per `indptr_type`)
  //                holding (nrow + 1) entries per matrix; the values of each
  //                matrix start at 0, i.e. they are relative to that matrix
  //   out_indices: receives the column index of every stored entry
  //   out_data:    receives the value of every stored entry (float or double
  //                per `data_type`)
  //   out_len:     out_len[0] = number of stored entries,
  //                out_len[1] = number of indptr entries
  // Entries of matrix m are placed directly after those of matrix m - 1 in
  // out_indices / out_data.
  // Raises Log::Fatal for an unknown `indptr_type` or `data_type`.
  void PredictSparseCSR(int num_iteration, int predict_type, int64_t nrow, int ncol,
                        std::function<std::vector<std::pair<int, double>>(int64_t row_idx)> get_row_fun,
                        const Config& config,
                        int64_t* out_len, void** out_indptr, int indptr_type,
                        int32_t** out_indices, void** out_data, int data_type) {
    std::lock_guard<std::mutex> lock(mutex_);
    // Get the number of trees per iteration (for multiclass scenario we output multiple sparse matrices)
    int num_matrices = boosting_->NumModelPerIteration();
    bool is_indptr_int32 = false;
    bool is_data_float32 = false;
    int64_t indptr_size = (nrow + 1) * num_matrices;
    if (indptr_type == C_API_DTYPE_INT32) {
      *out_indptr = new int32_t[indptr_size];
      is_indptr_int32 = true;
    } else if (indptr_type == C_API_DTYPE_INT64) {
      *out_indptr = new int64_t[indptr_size];
    } else {
      Log::Fatal("Unknown indptr type in PredictSparseCSR");
      return;
    }
    // aggregated per row feature contribution results
    std::vector<std::vector<std::unordered_map<int, double>>> agg(nrow);
    int64_t elements_size = 0;
    PredictSparse(num_iteration, predict_type, nrow, ncol, get_row_fun, config, &elements_size, &agg,
                  out_indices, out_data, data_type, &is_data_float32, num_matrices);
    const int64_t num_rows = static_cast<int64_t>(agg.size());
    // number of entries of every (matrix, row) pair, kept for parallelization
    std::vector<int> row_sizes(num_matrices * nrow);
    // start of every (matrix, row) pair relative to the start of its matrix
    std::vector<int64_t> row_matrix_offsets(num_matrices * nrow);
    // start of every matrix inside out_indices / out_data
    std::vector<int64_t> matrix_offsets(num_matrices, 0);
    int64_t row_vector_cnt = 0;
    for (int m = 0; m < num_matrices; ++m) {
      int64_t offset_in_matrix = 0;
      for (int64_t i = 0; i < num_rows; ++i) {
        const auto& row_vector = agg[i];
        row_sizes[row_vector_cnt] = static_cast<int>(row_vector[m].size());
        row_matrix_offsets[row_vector_cnt] = offset_in_matrix;
        offset_in_matrix += row_sizes[row_vector_cnt];
        ++row_vector_cnt;
      }
      // The next matrix begins where this one ends. Without this base offset
      // every matrix would be written from element 0 and overwrite the previous one.
      if (m + 1 < num_matrices) {
        matrix_offsets[m + 1] = matrix_offsets[m] + offset_in_matrix;
      }
    }
    // copy vector results to output for each row
    int64_t indptr_index = 0;
    for (int m = 0; m < num_matrices; ++m) {
      if (is_indptr_int32) {
        (reinterpret_cast<int32_t*>(*out_indptr))[indptr_index] = 0;
      } else {
        (reinterpret_cast<int64_t*>(*out_indptr))[indptr_index] = 0;
      }
      indptr_index++;
      int64_t matrix_start_index = m * num_rows;
      OMP_INIT_EX();
      #pragma omp parallel for schedule(static)
      for (int64_t i = 0; i < num_rows; ++i) {
        OMP_LOOP_EX_BEGIN();
        // by reference: copying the row would duplicate every map in it
        const auto& row_vector = agg[i];
        int64_t row_start_index = matrix_start_index + i;
        int64_t element_index = matrix_offsets[m] + row_matrix_offsets[row_start_index];
        int64_t indptr_loop_index = indptr_index + i;
        for (const auto& entry : row_vector[m]) {
          (*out_indices)[element_index] = entry.first;
          if (is_data_float32) {
            (reinterpret_cast<float*>(*out_data))[element_index] = static_cast<float>(entry.second);
          } else {
            (reinterpret_cast<double*>(*out_data))[element_index] = entry.second;
          }
          element_index++;
        }
        // indptr stays relative to the start of the current matrix
        int64_t indptr_value = row_matrix_offsets[row_start_index] + row_sizes[row_start_index];
        if (is_indptr_int32) {
          (reinterpret_cast<int32_t*>(*out_indptr))[indptr_loop_index] = static_cast<int32_t>(indptr_value);
        } else {
          (reinterpret_cast<int64_t*>(*out_indptr))[indptr_loop_index] = indptr_value;
        }
        OMP_LOOP_EX_END();
      }
      OMP_THROW_EX();
      indptr_index += num_rows;
    }
    out_len[0] = elements_size;
    out_len[1] = indptr_size;
  }

  // Runs sparse prediction and writes the result as `num_matrices` CSC
  // matrices stored back to back. Each matrix has ncol + 1 output columns.
  //   out_col_ptr: receives a new int32_t[] or int64_t[] (per `col_ptr_type`)
  //                holding (ncol + 2) entries per matrix; the values of each
  //                matrix start at 0, i.e. they are relative to that matrix
  //   out_indices: receives the row index of every stored entry
  //   out_data:    receives the value of every stored entry (float or double
  //                per `data_type`)
  //   out_len:     out_len[0] = number of stored entries,
  //                out_len[1] = number of col_ptr entries
  // Entries of matrix m are placed directly after those of matrix m - 1 in
  // out_indices / out_data.
  // Raises Log::Fatal for an unknown `col_ptr_type` or `data_type`.
  void PredictSparseCSC(int num_iteration, int predict_type, int64_t nrow, int ncol,
                        std::function<std::vector<std::pair<int, double>>(int64_t row_idx)> get_row_fun,
                        const Config& config,
                        int64_t* out_len, void** out_col_ptr, int col_ptr_type,
                        int32_t** out_indices, void** out_data, int data_type) {
    std::lock_guard<std::mutex> lock(mutex_);
    // Get the number of trees per iteration (for multiclass scenario we output multiple sparse matrices)
    int num_matrices = boosting_->NumModelPerIteration();
    // Kept only for its argument validation: the feature-count check in
    // CreatePredictor runs here before any output array is allocated.
    static_cast<void>(CreatePredictor(num_iteration, predict_type, ncol, config));
    bool is_col_ptr_int32 = false;
    bool is_data_float32 = false;
    int num_output_cols = ncol + 1;
    int col_ptr_size = (num_output_cols + 1) * num_matrices;
    if (col_ptr_type == C_API_DTYPE_INT32) {
      *out_col_ptr = new int32_t[col_ptr_size];
      is_col_ptr_int32 = true;
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
      *out_col_ptr = new int64_t[col_ptr_size];
    } else {
      Log::Fatal("Unknown col_ptr type in PredictSparseCSC");
      return;
    }
    // Stores one column pointer using the element width chosen above.
    const auto store_col_ptr = [&](int index, int64_t value) {
      if (is_col_ptr_int32) {
        (reinterpret_cast<int32_t*>(*out_col_ptr))[index] = static_cast<int32_t>(value);
      } else {
        (reinterpret_cast<int64_t*>(*out_col_ptr))[index] = value;
      }
    };
    // aggregated per row feature contribution results
    std::vector<std::vector<std::unordered_map<int, double>>> agg(nrow);
    int64_t elements_size = 0;
    PredictSparse(num_iteration, predict_type, nrow, ncol, get_row_fun, config, &elements_size, &agg,
                  out_indices, out_data, data_type, &is_data_float32, num_matrices);
    const int64_t num_rows = static_cast<int64_t>(agg.size());
    // calculate number of elements per column to construct
    // the CSC matrix with random access
    std::vector<std::vector<int64_t>> column_sizes(num_matrices);
    for (int m = 0; m < num_matrices; ++m) {
      column_sizes[m] = std::vector<int64_t>(num_output_cols, 0);
      for (int64_t i = 0; i < num_rows; ++i) {
        // by reference: copying the row would duplicate every map in it
        const auto& row_vector = agg[i];
        for (const auto& entry : row_vector[m]) {
          column_sizes[m][entry.first] += 1;
        }
      }
    }
    // keep track of column counts
    std::vector<std::vector<int64_t>> column_counts(num_matrices);
    // keep track of beginning index for each column
    std::vector<std::vector<int64_t>> column_start_indices(num_matrices);
    // keep track of beginning index for each matrix
    std::vector<int64_t> matrix_start_indices(num_matrices, 0);
    int col_ptr_index = 0;
    for (int m = 0; m < num_matrices; ++m) {
      column_start_indices[m] = std::vector<int64_t>(num_output_cols, 0);
      column_counts[m] = std::vector<int64_t>(num_output_cols, 0);
      // first column of every matrix starts at 0
      store_col_ptr(col_ptr_index, 0);
      col_ptr_index++;
      for (int64_t i = 1; i < static_cast<int64_t>(column_sizes[m].size()); ++i) {
        column_start_indices[m][i] = column_sizes[m][i - 1] + column_start_indices[m][i - 1];
        store_col_ptr(col_ptr_index, column_start_indices[m][i]);
        col_ptr_index++;
      }
      int64_t last_elem_index = static_cast<int64_t>(column_sizes[m].size()) - 1;
      int64_t last_column_start_index = column_start_indices[m][last_elem_index];
      int64_t last_column_size = column_sizes[m][last_elem_index];
      // closing pointer of this matrix == number of entries in this matrix
      int64_t matrix_elements = last_column_start_index + last_column_size;
      store_col_ptr(col_ptr_index, matrix_elements);
      // Advance past the closing pointer; otherwise the next matrix's leading
      // 0 would overwrite it and the tail of out_col_ptr would stay unwritten.
      col_ptr_index++;
      // The next matrix begins where this one ends (size of THIS matrix).
      if (m + 1 < num_matrices) {
        matrix_start_indices[m + 1] = matrix_start_indices[m] + matrix_elements;
      }
    }
    for (int m = 0; m < num_matrices; ++m) {
      for (int64_t i = 0; i < num_rows; ++i) {
        const auto& row_vector = agg[i];
        for (const auto& entry : row_vector[m]) {
          int64_t col_idx = entry.first;
          int64_t element_index = column_start_indices[m][col_idx] +
            matrix_start_indices[m] +
            column_counts[m][col_idx];
          // store the row index
          (*out_indices)[element_index] = static_cast<int32_t>(i);
          // update column count
          column_counts[m][col_idx]++;
          if (is_data_float32) {
            (reinterpret_cast<float*>(*out_data))[element_index] = static_cast<float>(entry.second);
          } else {
            (reinterpret_cast<double*>(*out_data))[element_index] = entry.second;
          }
        }
      }
    }
    out_len[0] = elements_size;
    out_len[1] = col_ptr_size;
  }

Guolin Ke's avatar
Guolin Ke committed
666
  void Predict(int num_iteration, int predict_type, const char* data_filename,
Guolin Ke's avatar
Guolin Ke committed
667
               int data_has_header, const Config& config,
cbecker's avatar
cbecker committed
668
               const char* result_filename) {
Guolin Ke's avatar
Guolin Ke committed
669
670
671
    std::lock_guard<std::mutex> lock(mutex_);
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
672
    bool predict_contrib = false;
Guolin Ke's avatar
Guolin Ke committed
673
674
675
676
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
      is_predict_leaf = true;
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
      is_raw_score = true;
677
    } else if (predict_type == C_API_PREDICT_CONTRIB) {
Guolin Ke's avatar
Guolin Ke committed
678
      predict_contrib = true;
Guolin Ke's avatar
Guolin Ke committed
679
680
681
    } else {
      is_raw_score = false;
    }
Guolin Ke's avatar
Guolin Ke committed
682
    Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, predict_contrib,
683
                        config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
Guolin Ke's avatar
Guolin Ke committed
684
    bool bool_data_has_header = data_has_header > 0 ? true : false;
685
    predictor.Predict(data_filename, result_filename, bool_data_has_header, config.predict_disable_shape_check);
Guolin Ke's avatar
Guolin Ke committed
686
687
  }

Guolin Ke's avatar
Guolin Ke committed
688
  void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
689
690
691
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

692
693
  void SaveModelToFile(int start_iteration, int num_iteration, const char* filename) {
    boosting_->SaveModelToFile(start_iteration, num_iteration, filename);
Guolin Ke's avatar
Guolin Ke committed
694
  }
695

696
  /*!
   * \brief Load a model from a NUL-terminated string.
   * \param model_str Model text; its length is measured with std::strlen, so it must be a valid C string
   */
  void LoadModelFromString(const char* model_str) {
    boosting_->LoadModelFromString(model_str, std::strlen(model_str));
  }

701
702
  /*!
   * \brief Return the string the boosting object produces for SaveModelToString.
   * \param start_iteration Passed through unchanged
   * \param num_iteration Passed through unchanged
   * \return The string returned by the boosting object
   */
  std::string SaveModelToString(int start_iteration, int num_iteration) {
    std::string serialized = boosting_->SaveModelToString(start_iteration, num_iteration);
    return serialized;
  }

705
  std::string DumpModel(int start_iteration, int num_iteration) {
706
    return boosting_->DumpModel(start_iteration, num_iteration);
wxchan's avatar
wxchan committed
707
  }
708

709
710
711
712
  /*!
   * \brief Return the feature importance values computed by the boosting object.
   * \param num_iteration Passed through unchanged
   * \param importance_type Passed through unchanged
   * \return The vector returned by the boosting object
   */
  std::vector<double> FeatureImportance(int num_iteration, int importance_type) {
    std::vector<double> importances = boosting_->FeatureImportance(num_iteration, importance_type);
    return importances;
  }

713
714
715
716
717
718
719
720
721
722
  /*!
   * \brief Query the boosting object's upper bound value while holding mutex_.
   * \return Result of GetUpperBoundValue() on the boosting object
   */
  double UpperBoundValue() const {
    std::lock_guard<std::mutex> guard(mutex_);
    const double upper_bound = boosting_->GetUpperBoundValue();
    return upper_bound;
  }

  /*!
   * \brief Query the boosting object's lower bound value while holding mutex_.
   * \return Result of GetLowerBoundValue() on the boosting object
   */
  double LowerBoundValue() const {
    std::lock_guard<std::mutex> guard(mutex_);
    const double lower_bound = boosting_->GetLowerBoundValue();
    return lower_bound;
  }

Guolin Ke's avatar
Guolin Ke committed
723
  double GetLeafValue(int tree_idx, int leaf_idx) const {
Guolin Ke's avatar
Guolin Ke committed
724
    return dynamic_cast<GBDTBase*>(boosting_.get())->GetLeafValue(tree_idx, leaf_idx);
Guolin Ke's avatar
Guolin Ke committed
725
726
727
728
  }

  void SetLeafValue(int tree_idx, int leaf_idx, double val) {
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
729
    dynamic_cast<GBDTBase*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val);
Guolin Ke's avatar
Guolin Ke committed
730
731
  }

732
  void ShuffleModels(int start_iter, int end_iter) {
733
    std::lock_guard<std::mutex> lock(mutex_);
734
    boosting_->ShuffleModels(start_iter, end_iter);
735
736
  }

wxchan's avatar
wxchan committed
737
738
739
740
741
742
743
  /*!
   * \brief Count the metric names over all training metrics.
   * \return Sum of GetName().size() over every entry of train_metric_
   */
  int GetEvalCounts() const {
    int total_names = 0;
    for (size_t i = 0; i < train_metric_.size(); ++i) {
      total_names += static_cast<int>(train_metric_[i]->GetName().size());
    }
    return total_names;
  }
744

745
746
  /*!
   * \brief Copy the names of all training metrics into caller-provided buffers.
   * \param out_strs Array of at least len buffers, each of at least buffer_len bytes
   * \param len Number of buffers in out_strs; names beyond this count are not copied
   * \param buffer_len Size of each buffer; longer names are truncated and always NUL-terminated
   * \param out_buffer_len Receives the largest size (including the terminator) needed by any name
   * \return Total number of metric names, which may exceed len
   * \note With buffer_len == 0 nothing is written to out_strs; only the count and
   *       out_buffer_len are produced (previously this wrote at index buffer_len - 1, i.e. out of bounds).
   */
  int GetEvalNames(char** out_strs, const int len, const size_t buffer_len, size_t *out_buffer_len) const {
    *out_buffer_len = 0;
    int idx = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
        // bytes needed for this name including its terminator
        const size_t required = name.size() + 1;
        if (idx < len && buffer_len > 0) {
          std::memcpy(out_strs[idx], name.c_str(), std::min(required, buffer_len));
          // guarantees termination even when the name was truncated
          out_strs[idx][buffer_len - 1] = '\0';
        }
        *out_buffer_len = std::max(required, *out_buffer_len);
        ++idx;
      }
    }
    return idx;
  }

761
762
  /*!
   * \brief Copy the feature names reported by the boosting object into caller-provided buffers.
   * \param out_strs Array of at least len buffers, each of at least buffer_len bytes
   * \param len Number of buffers in out_strs; names beyond this count are not copied
   * \param buffer_len Size of each buffer; longer names are truncated and always NUL-terminated
   * \param out_buffer_len Receives the largest size (including the terminator) needed by any name
   * \return Total number of feature names, which may exceed len
   * \note With buffer_len == 0 nothing is written to out_strs; only the count and
   *       out_buffer_len are produced (previously this wrote at index buffer_len - 1, i.e. out of bounds).
   */
  int GetFeatureNames(char** out_strs, const int len, const size_t buffer_len, size_t *out_buffer_len) const {
    *out_buffer_len = 0;
    int idx = 0;
    for (const auto& name : boosting_->FeatureNames()) {
      // bytes needed for this name including its terminator
      const size_t required = name.size() + 1;
      if (idx < len && buffer_len > 0) {
        std::memcpy(out_strs[idx], name.c_str(), std::min(required, buffer_len));
        // guarantees termination even when the name was truncated
        out_strs[idx][buffer_len - 1] = '\0';
      }
      *out_buffer_len = std::max(required, *out_buffer_len);
      ++idx;
    }
    return idx;
  }

wxchan's avatar
wxchan committed
775
  /*! \brief Read-only, non-owning pointer to the boosting object held by this Booster */
  const Boosting* GetBoosting() const {
    return boosting_.get();
  }
Guolin Ke's avatar
Guolin Ke committed
776

Nikita Titov's avatar
Nikita Titov committed
777
 private:
wxchan's avatar
wxchan committed
778
  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
779
  std::unique_ptr<Boosting> boosting_;
780
  std::unique_ptr<SingleRowPredictor> single_row_predictor_[PREDICTOR_TYPES];
781

Guolin Ke's avatar
Guolin Ke committed
782
  /*! \brief All configs */
Guolin Ke's avatar
Guolin Ke committed
783
  Config config_;
Guolin Ke's avatar
Guolin Ke committed
784
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
785
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
786
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
787
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
788
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
789
  std::unique_ptr<ObjectiveFunction> objective_fun_;
wxchan's avatar
wxchan committed
790
  /*! \brief mutex for threading safe call */
791
  mutable std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
792
793
};

794
}  // namespace LightGBM
Guolin Ke's avatar
Guolin Ke committed
795

796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
// explicitly declare symbols from LightGBM namespace
using LightGBM::AllgatherFunction;
using LightGBM::Booster;
using LightGBM::Common::CheckElementsIntervalClosed;
using LightGBM::Common::RemoveQuotationSymbol;
using LightGBM::Common::Vector2Ptr;
using LightGBM::Common::VectorSize;
using LightGBM::Config;
using LightGBM::data_size_t;
using LightGBM::Dataset;
using LightGBM::DatasetLoader;
using LightGBM::kZeroThreshold;
using LightGBM::LGBM_APIHandleException;
using LightGBM::Log;
using LightGBM::Network;
using LightGBM::Random;
using LightGBM::ReduceScatterFunction;
Guolin Ke's avatar
Guolin Ke committed
813

Guolin Ke's avatar
Guolin Ke committed
814
815
816
817
818
819
820
821
// some helper functions used to convert data

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major);

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major);

822
823
824
std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseRows(const void** data, int num_col, int data_type);

825
826
template<typename T>
std::function<std::vector<std::pair<int, double>>(T idx)>
Guolin Ke's avatar
Guolin Ke committed
827
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
828
                   const void* data, int data_type, int64_t nindptr, int64_t nelem);
Guolin Ke's avatar
Guolin Ke committed
829
830
831

// Row iterator over one column of a CSC matrix
class CSC_RowIterator {
 public:
  // col_idx selects the column this iterator walks over
  CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
                  const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx);
  ~CSC_RowIterator() = default;
  // return the value at idx; indices can only be requested in ascending order
  double Get(int idx);
  // return the next non-zero pair; a negative index means there is no more data
  std::pair<int, double> NextNonZero();

 private:
  int nonzero_idx_ = 0;
  int cur_idx_ = -1;
  double cur_val_ = 0.0;
  bool is_end_ = false;
  std::function<std::pair<int, double>(int idx)> iter_fun_;
};

// start of c_api functions

Guolin Ke's avatar
Guolin Ke committed
851
const char* LGBM_GetLastError() {
wxchan's avatar
wxchan committed
852
  return LastErrorMsg();
Guolin Ke's avatar
Guolin Ke committed
853
854
}

855
856
857
858
859
860
/*!
 * \brief Hand a logging callback to Log::ResetCallBack.
 * \param callback Function receiving a C string; passed through unchanged
 * \return 0 on success, -1 if an exception was caught (message stored via LGBM_SetLastError)
 */
int LGBM_RegisterLogCallback(void (*callback)(const char*)) {
  API_BEGIN();
  Log::ResetCallBack(callback);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
861
int LGBM_DatasetCreateFromFile(const char* filename,
862
863
864
                               const char* parameters,
                               const DatasetHandle reference,
                               DatasetHandle* out) {
865
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
866
867
  auto param = Config::Str2Map(parameters);
  Config config;
868
869
870
871
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
872
  DatasetLoader loader(config, nullptr, 1, filename);
Guolin Ke's avatar
Guolin Ke committed
873
  if (reference == nullptr) {
874
    if (Network::num_machines() == 1) {
875
      *out = loader.LoadFromFile(filename);
876
    } else {
877
      *out = loader.LoadFromFile(filename, Network::rank(), Network::num_machines());
878
    }
Guolin Ke's avatar
Guolin Ke committed
879
  } else {
880
    *out = loader.LoadFromFileAlignWithOtherDataset(filename,
881
                                                    reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
882
  }
883
  API_END();
Guolin Ke's avatar
Guolin Ke committed
884
885
}

886

Guolin Ke's avatar
Guolin Ke committed
887
int LGBM_DatasetCreateFromSampledColumn(double** sample_data,
888
889
890
891
892
893
894
                                        int** sample_indices,
                                        int32_t ncol,
                                        const int* num_per_col,
                                        int32_t num_sample_row,
                                        int32_t num_total_row,
                                        const char* parameters,
                                        DatasetHandle* out) {
895
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
896
897
  auto param = Config::Str2Map(parameters);
  Config config;
898
899
900
901
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
902
  DatasetLoader loader(config, nullptr, 1, nullptr);
903
904
905
906
  *out = loader.CostructFromSampleData(sample_data, sample_indices, ncol, num_per_col,
                                       num_sample_row,
                                       static_cast<data_size_t>(num_total_row));
  API_END();
Guolin Ke's avatar
Guolin Ke committed
907
908
}

909

Guolin Ke's avatar
Guolin Ke committed
910
int LGBM_DatasetCreateByReference(const DatasetHandle reference,
911
912
                                  int64_t num_total_row,
                                  DatasetHandle* out) {
Guolin Ke's avatar
Guolin Ke committed
913
914
915
916
917
918
919
920
  API_BEGIN();
  std::unique_ptr<Dataset> ret;
  ret.reset(new Dataset(static_cast<data_size_t>(num_total_row)));
  ret->CreateValid(reinterpret_cast<const Dataset*>(reference));
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
921
int LGBM_DatasetPushRows(DatasetHandle dataset,
922
923
924
925
926
                         const void* data,
                         int data_type,
                         int32_t nrow,
                         int32_t ncol,
                         int32_t start_row) {
Guolin Ke's avatar
Guolin Ke committed
927
928
929
  API_BEGIN();
  auto p_dataset = reinterpret_cast<Dataset*>(dataset);
  auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, 1);
930
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
931
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
932
  for (int i = 0; i < nrow; ++i) {
933
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
934
935
936
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    p_dataset->PushOneRow(tid, start_row + i, one_row);
937
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
938
  }
939
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
940
941
942
943
944
945
  if (start_row + nrow == p_dataset->num_data()) {
    p_dataset->FinishLoad();
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
946
/*!
 * \brief Push a block of CSR rows into an existing dataset.
 * \param dataset Handle of the dataset to fill
 * \param indptr CSR row pointer array
 * \param indptr_type Type code of indptr
 * \param indices CSR column indices
 * \param data CSR values
 * \param data_type Type code of data
 * \param nindptr Number of entries in indptr; the block holds nindptr - 1 rows
 * \param nelem Number of stored elements
 * \param start_row Dataset row index that receives the first row of the block
 * \return 0 on success, -1 if an exception was caught (message stored via LGBM_SetLastError)
 * \note The unnamed int64_t parameter is not used. FinishLoad() is called once the pushed
 *       block ends exactly at the dataset's num_data().
 */
int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset,
                              const void* indptr,
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t,
                              int64_t start_row) {
  API_BEGIN();
  Dataset* target = reinterpret_cast<Dataset*>(dataset);
  auto read_row = RowFunctionFromCSR<int>(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  const int32_t nrow = static_cast<int32_t>(nindptr - 1);
  OMP_INIT_EX();
  #pragma omp parallel for schedule(static)
  for (int local_row = 0; local_row < nrow; ++local_row) {
    OMP_LOOP_EX_BEGIN();
    const int thread_id = omp_get_thread_num();
    auto row_entries = read_row(local_row);
    const data_size_t dest_row = static_cast<data_size_t>(start_row + local_row);
    target->PushOneRow(thread_id, dest_row, row_entries);
    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();
  const bool reached_last_row = (start_row + nrow == static_cast<int64_t>(target->num_data()));
  if (reached_last_row) {
    target->FinishLoad();
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
977
int LGBM_DatasetCreateFromMat(const void* data,
978
979
980
981
982
983
984
                              int data_type,
                              int32_t nrow,
                              int32_t ncol,
                              int is_row_major,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
  return LGBM_DatasetCreateFromMats(1,
                                    &data,
                                    data_type,
                                    &nrow,
                                    ncol,
                                    is_row_major,
                                    parameters,
                                    reference,
                                    out);
}


int LGBM_DatasetCreateFromMats(int32_t nmat,
                               const void** data,
                               int data_type,
                               int32_t* nrow,
                               int32_t ncol,
                               int is_row_major,
                               const char* parameters,
                               const DatasetHandle reference,
                               DatasetHandle* out) {
1006
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1007
1008
  auto param = Config::Str2Map(parameters);
  Config config;
1009
1010
1011
1012
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1013
  std::unique_ptr<Dataset> ret;
1014
1015
1016
1017
1018
1019
1020
1021
1022
  int32_t total_nrow = 0;
  for (int j = 0; j < nmat; ++j) {
    total_nrow += nrow[j];
  }

  std::vector<std::function<std::vector<double>(int row_idx)>> get_row_fun;
  for (int j = 0; j < nmat; ++j) {
    get_row_fun.push_back(RowFunctionFromDenseMatric(data[j], nrow[j], ncol, data_type, is_row_major));
  }
1023

Guolin Ke's avatar
Guolin Ke committed
1024
1025
  if (reference == nullptr) {
    // sample data first
Guolin Ke's avatar
Guolin Ke committed
1026
    Random rand(config.data_random_seed);
1027
1028
    int sample_cnt = static_cast<int>(total_nrow < config.bin_construct_sample_cnt ? total_nrow : config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(total_nrow, sample_cnt);
1029
    sample_cnt = static_cast<int>(sample_indices.size());
1030
    std::vector<std::vector<double>> sample_values(ncol);
Guolin Ke's avatar
Guolin Ke committed
1031
    std::vector<std::vector<int>> sample_idx(ncol);
1032
1033
1034

    int offset = 0;
    int j = 0;
Guolin Ke's avatar
Guolin Ke committed
1035
    for (size_t i = 0; i < sample_indices.size(); ++i) {
Guolin Ke's avatar
Guolin Ke committed
1036
      auto idx = sample_indices[i];
1037
1038
1039
1040
      while ((idx - offset) >= nrow[j]) {
        offset += nrow[j];
        ++j;
      }
1041

1042
1043
1044
1045
1046
      auto row = get_row_fun[j](static_cast<int>(idx - offset));
      for (size_t k = 0; k < row.size(); ++k) {
        if (std::fabs(row[k]) > kZeroThreshold || std::isnan(row[k])) {
          sample_values[k].emplace_back(row[k]);
          sample_idx[k].emplace_back(static_cast<int>(i));
Guolin Ke's avatar
Guolin Ke committed
1047
        }
Guolin Ke's avatar
Guolin Ke committed
1048
1049
      }
    }
Guolin Ke's avatar
Guolin Ke committed
1050
    DatasetLoader loader(config, nullptr, 1, nullptr);
1051
1052
    ret.reset(loader.CostructFromSampleData(Vector2Ptr<double>(&sample_values).data(),
                                            Vector2Ptr<int>(&sample_idx).data(),
1053
                                            ncol,
1054
                                            VectorSize<double>(sample_values).data(),
1055
                                            sample_cnt, total_nrow));
Guolin Ke's avatar
Guolin Ke committed
1056
  } else {
1057
    ret.reset(new Dataset(total_nrow));
Guolin Ke's avatar
Guolin Ke committed
1058
    ret->CreateValid(
1059
      reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
1060
  }
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
  int32_t start_row = 0;
  for (int j = 0; j < nmat; ++j) {
    OMP_INIT_EX();
    #pragma omp parallel for schedule(static)
    for (int i = 0; i < nrow[j]; ++i) {
      OMP_LOOP_EX_BEGIN();
      const int tid = omp_get_thread_num();
      auto one_row = get_row_fun[j](i);
      ret->PushOneRow(tid, start_row + i, one_row);
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();

    start_row += nrow[j];
Guolin Ke's avatar
Guolin Ke committed
1075
1076
  }
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
1077
  *out = ret.release();
1078
  API_END();
1079
1080
}

Guolin Ke's avatar
Guolin Ke committed
1081
int LGBM_DatasetCreateFromCSR(const void* indptr,
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t num_col,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
1092
  API_BEGIN();
1093
1094
1095
1096
1097
  if (num_col <= 0) {
    Log::Fatal("The number of columns should be greater than zero.");
  } else if (num_col >= INT32_MAX) {
    Log::Fatal("The number of columns should be smaller than INT32_MAX.");
  }
Guolin Ke's avatar
Guolin Ke committed
1098
1099
  auto param = Config::Str2Map(parameters);
  Config config;
1100
1101
1102
1103
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1104
  std::unique_ptr<Dataset> ret;
1105
  auto get_row_fun = RowFunctionFromCSR<int>(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
1106
1107
1108
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference == nullptr) {
    // sample data first
Guolin Ke's avatar
Guolin Ke committed
1109
1110
    Random rand(config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt);
1111
    auto sample_indices = rand.Sample(nrow, sample_cnt);
1112
    sample_cnt = static_cast<int>(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
1113
1114
    std::vector<std::vector<double>> sample_values(num_col);
    std::vector<std::vector<int>> sample_idx(num_col);
1115
1116
1117
1118
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      auto row = get_row_fun(static_cast<int>(idx));
      for (std::pair<int, double>& inner_data : row) {
Nikita Titov's avatar
Nikita Titov committed
1119
        CHECK_LT(inner_data.first, num_col);
Guolin Ke's avatar
Guolin Ke committed
1120
        if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
Guolin Ke's avatar
Guolin Ke committed
1121
1122
          sample_values[inner_data.first].emplace_back(inner_data.second);
          sample_idx[inner_data.first].emplace_back(static_cast<int>(i));
1123
1124
1125
        }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
1126
    DatasetLoader loader(config, nullptr, 1, nullptr);
1127
1128
    ret.reset(loader.CostructFromSampleData(Vector2Ptr<double>(&sample_values).data(),
                                            Vector2Ptr<int>(&sample_idx).data(),
1129
                                            static_cast<int>(num_col),
1130
                                            VectorSize<double>(sample_values).data(),
1131
                                            sample_cnt, nrow));
1132
  } else {
1133
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
1134
    ret->CreateValid(
1135
      reinterpret_cast<const Dataset*>(reference));
1136
  }
1137
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
1138
  #pragma omp parallel for schedule(static)
1139
  for (int i = 0; i < nindptr - 1; ++i) {
1140
    OMP_LOOP_EX_BEGIN();
1141
1142
1143
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
1144
    OMP_LOOP_EX_END();
1145
  }
1146
  OMP_THROW_EX();
1147
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
1148
  *out = ret.release();
1149
  API_END();
1150
1151
}

1152
int LGBM_DatasetCreateFromCSRFunc(void* get_row_funptr,
1153
1154
1155
1156
1157
                                  int num_rows,
                                  int64_t num_col,
                                  const char* parameters,
                                  const DatasetHandle reference,
                                  DatasetHandle* out) {
1158
  API_BEGIN();
1159
1160
1161
1162
1163
  if (num_col <= 0) {
    Log::Fatal("The number of columns should be greater than zero.");
  } else if (num_col >= INT32_MAX) {
    Log::Fatal("The number of columns should be smaller than INT32_MAX.");
  }
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
  auto get_row_fun = *static_cast<std::function<void(int idx, std::vector<std::pair<int, double>>&)>*>(get_row_funptr);
  auto param = Config::Str2Map(parameters);
  Config config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  std::unique_ptr<Dataset> ret;
  int32_t nrow = num_rows;
  if (reference == nullptr) {
    // sample data first
    Random rand(config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    sample_cnt = static_cast<int>(sample_indices.size());
    std::vector<std::vector<double>> sample_values(num_col);
    std::vector<std::vector<int>> sample_idx(num_col);
    // local buffer to re-use memory
    std::vector<std::pair<int, double>> buffer;
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      get_row_fun(static_cast<int>(idx), buffer);
      for (std::pair<int, double>& inner_data : buffer) {
Nikita Titov's avatar
Nikita Titov committed
1187
        CHECK_LT(inner_data.first, num_col);
1188
1189
1190
1191
1192
1193
1194
        if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
          sample_values[inner_data.first].emplace_back(inner_data.second);
          sample_idx[inner_data.first].emplace_back(static_cast<int>(i));
        }
      }
    }
    DatasetLoader loader(config, nullptr, 1, nullptr);
1195
1196
    ret.reset(loader.CostructFromSampleData(Vector2Ptr<double>(&sample_values).data(),
                                            Vector2Ptr<int>(&sample_idx).data(),
1197
                                            static_cast<int>(num_col),
1198
                                            VectorSize<double>(sample_values).data(),
1199
1200
1201
1202
1203
1204
                                            sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow));
    ret->CreateValid(
      reinterpret_cast<const Dataset*>(reference));
  }
1205

1206
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
1207
1208
  std::vector<std::pair<int, double>> thread_buffer;
  #pragma omp parallel for schedule(static) private(thread_buffer)
1209
1210
1211
  for (int i = 0; i < num_rows; ++i) {
    OMP_LOOP_EX_BEGIN();
    {
1212
      const int tid = omp_get_thread_num();
Guolin Ke's avatar
Guolin Ke committed
1213
1214
      get_row_fun(i, thread_buffer);
      ret->PushOneRow(tid, i, thread_buffer);
1215
1216
1217
1218
1219
1220
1221
1222
1223
    }
    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1224
int LGBM_DatasetCreateFromCSC(const void* col_ptr,
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
                              int col_ptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t ncol_ptr,
                              int64_t nelem,
                              int64_t num_row,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
1235
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1236
1237
  auto param = Config::Str2Map(parameters);
  Config config;
1238
1239
1240
1241
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1242
  std::unique_ptr<Dataset> ret;
Guolin Ke's avatar
Guolin Ke committed
1243
1244
1245
  int32_t nrow = static_cast<int32_t>(num_row);
  if (reference == nullptr) {
    // sample data first
Guolin Ke's avatar
Guolin Ke committed
1246
1247
    Random rand(config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
1248
    auto sample_indices = rand.Sample(nrow, sample_cnt);
1249
    sample_cnt = static_cast<int>(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
1250
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
Guolin Ke's avatar
Guolin Ke committed
1251
    std::vector<std::vector<int>> sample_idx(ncol_ptr - 1);
1252
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
1253
    #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
1254
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
1255
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1256
1257
1258
      CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
      for (int j = 0; j < sample_cnt; j++) {
        auto val = col_it.Get(sample_indices[j]);
Guolin Ke's avatar
Guolin Ke committed
1259
        if (std::fabs(val) > kZeroThreshold || std::isnan(val)) {
Guolin Ke's avatar
Guolin Ke committed
1260
1261
          sample_values[i].emplace_back(val);
          sample_idx[i].emplace_back(j);
Guolin Ke's avatar
Guolin Ke committed
1262
1263
        }
      }
1264
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1265
    }
1266
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1267
    DatasetLoader loader(config, nullptr, 1, nullptr);
1268
1269
    ret.reset(loader.CostructFromSampleData(Vector2Ptr<double>(&sample_values).data(),
                                            Vector2Ptr<int>(&sample_idx).data(),
1270
                                            static_cast<int>(sample_values.size()),
1271
                                            VectorSize<double>(sample_values).data(),
1272
                                            sample_cnt, nrow));
Guolin Ke's avatar
Guolin Ke committed
1273
  } else {
1274
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
1275
    ret->CreateValid(
1276
      reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
1277
  }
1278
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
1279
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
1280
  for (int i = 0; i < ncol_ptr - 1; ++i) {
1281
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1282
    const int tid = omp_get_thread_num();
Guolin Ke's avatar
Guolin Ke committed
1283
    int feature_idx = ret->InnerFeatureIndex(i);
Guolin Ke's avatar
Guolin Ke committed
1284
    if (feature_idx < 0) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1285
1286
    int group = ret->Feature2Group(feature_idx);
    int sub_feature = ret->Feture2SubFeature(feature_idx);
Guolin Ke's avatar
Guolin Ke committed
1287
    CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
Guolin Ke's avatar
Guolin Ke committed
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
    auto bin_mapper = ret->FeatureBinMapper(feature_idx);
    if (bin_mapper->GetDefaultBin() == bin_mapper->GetMostFreqBin()) {
      int row_idx = 0;
      while (row_idx < nrow) {
        auto pair = col_it.NextNonZero();
        row_idx = pair.first;
        // no more data
        if (row_idx < 0) { break; }
        ret->PushOneData(tid, row_idx, group, sub_feature, pair.second);
      }
    } else {
      for (int row_idx = 0; row_idx < nrow; ++row_idx) {
        auto val = col_it.Get(row_idx);
        ret->PushOneData(tid, row_idx, group, sub_feature, val);
      }
Guolin Ke's avatar
Guolin Ke committed
1303
    }
1304
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1305
  }
1306
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1307
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
1308
  *out = ret.release();
1309
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1310
1311
}

Guolin Ke's avatar
Guolin Ke committed
1312
int LGBM_DatasetGetSubset(
1313
  const DatasetHandle handle,
wxchan's avatar
wxchan committed
1314
1315
1316
  const int32_t* used_row_indices,
  int32_t num_used_row_indices,
  const char* parameters,
Guolin Ke's avatar
typo  
Guolin Ke committed
1317
  DatasetHandle* out) {
wxchan's avatar
wxchan committed
1318
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1319
1320
  auto param = Config::Str2Map(parameters);
  Config config;
1321
1322
1323
1324
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
1325
  auto full_dataset = reinterpret_cast<const Dataset*>(handle);
1326
  CHECK_GT(num_used_row_indices, 0);
1327
1328
  const int32_t lower = 0;
  const int32_t upper = full_dataset->num_data() - 1;
1329
  CheckElementsIntervalClosed(used_row_indices, lower, upper, num_used_row_indices, "Used indices of subset");
1330
1331
1332
  if (!std::is_sorted(used_row_indices, used_row_indices + num_used_row_indices)) {
    Log::Fatal("used_row_indices should be sorted in Subset");
  }
Guolin Ke's avatar
Guolin Ke committed
1333
  auto ret = std::unique_ptr<Dataset>(new Dataset(num_used_row_indices));
1334
  ret->CopyFeatureMapperFrom(full_dataset);
1335
  ret->CopySubrow(full_dataset, used_row_indices, num_used_row_indices, true);
wxchan's avatar
wxchan committed
1336
1337
1338
1339
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1340
int LGBM_DatasetSetFeatureNames(
Guolin Ke's avatar
typo  
Guolin Ke committed
1341
  DatasetHandle handle,
Guolin Ke's avatar
Guolin Ke committed
1342
  const char** feature_names,
Guolin Ke's avatar
Guolin Ke committed
1343
  int num_feature_names) {
Guolin Ke's avatar
Guolin Ke committed
1344
1345
1346
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  std::vector<std::string> feature_names_str;
Guolin Ke's avatar
Guolin Ke committed
1347
  for (int i = 0; i < num_feature_names; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1348
1349
1350
1351
1352
1353
    feature_names_str.emplace_back(feature_names[i]);
  }
  dataset->set_feature_names(feature_names_str);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1354
int LGBM_DatasetGetFeatureNames(
1355
1356
1357
1358
1359
1360
    DatasetHandle handle,
    const int len,
    int* num_feature_names,
    const size_t buffer_len,
    size_t* out_buffer_len,
    char** feature_names) {
1361
  API_BEGIN();
1362
  *out_buffer_len = 0;
1363
1364
  auto dataset = reinterpret_cast<Dataset*>(handle);
  auto inside_feature_name = dataset->feature_names();
Guolin Ke's avatar
Guolin Ke committed
1365
1366
  *num_feature_names = static_cast<int>(inside_feature_name.size());
  for (int i = 0; i < *num_feature_names; ++i) {
1367
1368
1369
1370
1371
    if (i < len) {
      std::memcpy(feature_names[i], inside_feature_name[i].c_str(), std::min(inside_feature_name[i].size() + 1, buffer_len));
      feature_names[i][buffer_len - 1] = '\0';
    }
    *out_buffer_len = std::max(inside_feature_name[i].size() + 1, *out_buffer_len);
1372
1373
1374
1375
  }
  API_END();
}

1376
1377
1378
#ifdef _MSC_VER
  #pragma warning(disable : 4702)
#endif
Guolin Ke's avatar
Guolin Ke committed
1379
/*!
 * \brief Destroy the Dataset object behind a handle.
 * \param handle Handle of the dataset to delete
 * \return 0 on success, -1 when an exception was caught by API_END
 */
int LGBM_DatasetFree(DatasetHandle handle) {
  API_BEGIN();
  Dataset* dataset = reinterpret_cast<Dataset*>(handle);
  delete dataset;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1385
int LGBM_DatasetSaveBinary(DatasetHandle handle,
1386
                           const char* filename) {
1387
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1388
1389
  auto dataset = reinterpret_cast<Dataset*>(handle);
  dataset->SaveBinaryFile(filename);
1390
  API_END();
1391
1392
}

1393
1394
1395
1396
1397
1398
1399
1400
/*!
 * \brief Ask a dataset to write itself through Dataset::DumpTextFile.
 * \param handle Handle of the dataset
 * \param filename Path forwarded to DumpTextFile
 * \return 0 on success, -1 when an exception was caught by API_END
 */
int LGBM_DatasetDumpText(DatasetHandle handle,
                         const char* filename) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->DumpTextFile(filename);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1401
/*!
 * \brief Set a named field of a dataset from a typed array.
 * \param handle Handle of the dataset
 * \param field_name Name of the field to set
 * \param field_data Pointer to the values, interpreted according to type
 * \param num_element Number of values in field_data
 * \param type One of C_API_DTYPE_FLOAT32, C_API_DTYPE_INT32 or C_API_DTYPE_FLOAT64
 * \return 0 on success, -1 when an exception was caught by API_END
 */
int LGBM_DatasetSetField(DatasetHandle handle,
                         const char* field_name,
                         const void* field_data,
                         int num_element,
                         int type) {
  API_BEGIN();
  Dataset* target = reinterpret_cast<Dataset*>(handle);
  const int32_t count = static_cast<int32_t>(num_element);
  // stays false when the type is unsupported or the setter rejects the field
  bool accepted = false;
  switch (type) {
    case C_API_DTYPE_FLOAT32:
      accepted = target->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), count);
      break;
    case C_API_DTYPE_INT32:
      accepted = target->SetIntField(field_name, reinterpret_cast<const int*>(field_data), count);
      break;
    case C_API_DTYPE_FLOAT64:
      accepted = target->SetDoubleField(field_name, reinterpret_cast<const double*>(field_data), count);
      break;
    default:
      break;
  }
  if (!accepted) {
    Log::Fatal("Input data type error or field not found");
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1420
/*!
 * \brief Look up a named field of a dataset and expose its storage.
 * \param handle Handle of the dataset
 * \param field_name Name of the field to fetch
 * \param out_len Receives the number of elements (0 when the pointer is null)
 * \param out_ptr Receives a pointer to the field data
 * \param out_type Receives C_API_DTYPE_FLOAT32, C_API_DTYPE_INT32 or C_API_DTYPE_FLOAT64
 * \return 0 on success, -1 when an exception was caught by API_END
 */
int LGBM_DatasetGetField(DatasetHandle handle,
                         const char* field_name,
                         int* out_len,
                         const void** out_ptr,
                         int* out_type) {
  API_BEGIN();
  Dataset* source = reinterpret_cast<Dataset*>(handle);
  // the getters are probed in a fixed order: float, then int, then double
  bool found = true;
  if (source->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (source->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else if (source->GetDoubleField(field_name, out_len, reinterpret_cast<const double**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT64;
  } else {
    found = false;
  }
  if (!found) {
    Log::Fatal("Field not found");
  }
  // a null data pointer is reported to the caller as an empty field
  if (*out_ptr == nullptr) {
    *out_len = 0;
  }
  API_END();
}

1443
int LGBM_DatasetUpdateParamChecking(const char* old_parameters, const char* new_parameters) {
1444
  API_BEGIN();
1445
1446
1447
1448
1449
  auto old_param = Config::Str2Map(old_parameters);
  Config old_config;
  old_config.Set(old_param);
  auto new_param = Config::Str2Map(new_parameters);
  Booster::CheckDatasetResetConfig(old_config, new_param);
1450
1451
1452
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1453
int LGBM_DatasetGetNumData(DatasetHandle handle,
1454
                           int* out) {
1455
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1456
1457
  auto dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_data();
1458
  API_END();
1459
1460
}

Guolin Ke's avatar
Guolin Ke committed
1461
int LGBM_DatasetGetNumFeature(DatasetHandle handle,
1462
                              int* out) {
1463
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1464
1465
  auto dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_total_features();
1466
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1467
}
1468

1469
1470
1471
1472
1473
/*!
 * \brief Forward to Dataset::AddFeaturesFrom on the target dataset.
 * \param target Handle of the dataset that is modified
 * \param source Handle of the dataset passed to AddFeaturesFrom
 * \return 0 on success, -1 when an exception was caught by API_END
 */
int LGBM_DatasetAddFeaturesFrom(DatasetHandle target,
                                DatasetHandle source) {
  API_BEGIN();
  Dataset* receiver = reinterpret_cast<Dataset*>(target);
  Dataset* donor = reinterpret_cast<Dataset*>(source);
  receiver->AddFeaturesFrom(donor);
  API_END();
}

1478
1479
// ---- start of booster

Guolin Ke's avatar
Guolin Ke committed
1480
int LGBM_BoosterCreate(const DatasetHandle train_data,
1481
1482
                       const char* parameters,
                       BoosterHandle* out) {
1483
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1484
  const Dataset* p_train_data = reinterpret_cast<const Dataset*>(train_data);
wxchan's avatar
wxchan committed
1485
1486
  auto ret = std::unique_ptr<Booster>(new Booster(p_train_data, parameters));
  *out = ret.release();
1487
  API_END();
1488
1489
}

Guolin Ke's avatar
Guolin Ke committed
1490
int LGBM_BoosterCreateFromModelfile(
1491
  const char* filename,
Guolin Ke's avatar
Guolin Ke committed
1492
  int* out_num_iterations,
1493
  BoosterHandle* out) {
1494
  API_BEGIN();
wxchan's avatar
wxchan committed
1495
  auto ret = std::unique_ptr<Booster>(new Booster(filename));
Guolin Ke's avatar
Guolin Ke committed
1496
  *out_num_iterations = ret->GetBoosting()->GetCurrentIteration();
wxchan's avatar
wxchan committed
1497
  *out = ret.release();
1498
  API_END();
1499
1500
}

Guolin Ke's avatar
Guolin Ke committed
1501
int LGBM_BoosterLoadModelFromString(
1502
1503
1504
1505
  const char* model_str,
  int* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
wxchan's avatar
wxchan committed
1506
  auto ret = std::unique_ptr<Booster>(new Booster(nullptr));
1507
1508
1509
1510
1511
1512
  ret->LoadModelFromString(model_str);
  *out_num_iterations = ret->GetBoosting()->GetCurrentIteration();
  *out = ret.release();
  API_END();
}

1513
1514
1515
#ifdef _MSC_VER
  #pragma warning(disable : 4702)
#endif
Guolin Ke's avatar
Guolin Ke committed
1516
/*!
 * \brief Destroy the Booster object behind a handle.
 * \param handle Handle of the booster to delete
 * \return 0 on success, -1 when an exception was caught by API_END
 */
int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  delete booster;
  API_END();
}

1522
/*!
 * \brief Forward to Booster::ShuffleModels.
 * \param handle Handle of the booster
 * \param start_iter First argument passed to ShuffleModels
 * \param end_iter Second argument passed to ShuffleModels
 * \return 0 on success, -1 when an exception was caught by API_END
 */
int LGBM_BoosterShuffleModels(BoosterHandle handle, int start_iter, int end_iter) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ShuffleModels(start_iter, end_iter);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1529
/*!
 * \brief Forward to Booster::MergeFrom on the first booster.
 * \param handle Handle of the booster that is modified
 * \param other_handle Handle of the booster passed to MergeFrom
 * \return 0 on success, -1 when an exception was caught by API_END
 */
int LGBM_BoosterMerge(BoosterHandle handle,
                      BoosterHandle other_handle) {
  API_BEGIN();
  Booster* receiver = reinterpret_cast<Booster*>(handle);
  Booster* donor = reinterpret_cast<Booster*>(other_handle);
  receiver->MergeFrom(donor);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1538
int LGBM_BoosterAddValidData(BoosterHandle handle,
1539
                             const DatasetHandle valid_data) {
wxchan's avatar
wxchan committed
1540
1541
1542
1543
1544
1545
1546
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const Dataset* p_dataset = reinterpret_cast<const Dataset*>(valid_data);
  ref_booster->AddValidData(p_dataset);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1547
int LGBM_BoosterResetTrainingData(BoosterHandle handle,
1548
                                  const DatasetHandle train_data) {
wxchan's avatar
wxchan committed
1549
1550
1551
1552
1553
1554
1555
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const Dataset* p_dataset = reinterpret_cast<const Dataset*>(train_data);
  ref_booster->ResetTrainingData(p_dataset);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1556
/*!
 * \brief Forward to Booster::ResetConfig.
 * \param handle Handle of the booster
 * \param parameters Parameter string passed to ResetConfig
 * \return 0 on success, -1 when an exception was caught by API_END
 */
int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ResetConfig(parameters);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1563
/*!
 * \brief Report NumberOfClasses() of the booster's boosting object.
 * \param handle Handle of the booster
 * \param out_len Receives the number of classes
 * \return 0 on success, -1 when an exception was caught by API_END
 */
int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto boosting = booster->GetBoosting();
  *out_len = boosting->NumberOfClasses();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1570
1571
1572
1573
1574
1575
1576
/*!
 * \brief Forward to Booster::Refit.
 * \param handle Handle of the booster
 * \param leaf_preds Leaf prediction values passed to Refit
 * \param nrow Row count passed to Refit
 * \param ncol Column count passed to Refit
 * \return 0 on success, -1 when an exception was caught by API_END
 */
int LGBM_BoosterRefit(BoosterHandle handle, const int32_t* leaf_preds, int32_t nrow, int32_t ncol) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->Refit(leaf_preds, nrow, ncol);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1577
/*!
 * \brief Run one call of Booster::TrainOneIter.
 * \param handle Handle of the booster
 * \param is_finished Receives 1 when TrainOneIter() returned true, otherwise 0
 * \return 0 on success, -1 when an exception was caught by API_END
 */
int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const bool finished = booster->TrainOneIter();
  *is_finished = finished ? 1 : 0;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1588
/*!
 * \brief Run one call of Booster::TrainOneIter with caller-supplied gradients and hessians.
 *        When SCORE_T_USE_DOUBLE is defined, the call is rejected through Log::Fatal instead.
 * \param handle Handle of the booster
 * \param grad Gradient values passed to TrainOneIter
 * \param hess Hessian values passed to TrainOneIter
 * \param is_finished Receives 1 when TrainOneIter returned true, otherwise 0
 * \return 0 on success, -1 when an exception was caught by API_END
 */
int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
                                    const float* grad,
                                    const float* hess,
                                    int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  #ifdef SCORE_T_USE_DOUBLE
  Log::Fatal("Don't support custom loss function when SCORE_T_USE_DOUBLE is enabled");
  #else
  const bool finished = booster->TrainOneIter(grad, hess);
  *is_finished = finished ? 1 : 0;
  #endif
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1606
/*!
 * \brief Forward to Booster::RollbackOneIter.
 * \param handle Handle of the booster
 * \return 0 on success, -1 when an exception was caught by API_END
 */
int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1613
/*!
 * \brief Report GetCurrentIteration() of the booster's boosting object.
 * \param handle Handle of the booster
 * \param out_iteration Receives the current iteration
 * \return 0 on success, -1 when an exception was caught by API_END
 */
int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int* out_iteration) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto boosting = booster->GetBoosting();
  *out_iteration = boosting->GetCurrentIteration();
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
1619

1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
/*!
 * \brief Report NumModelPerIteration() of the booster's boosting object.
 * \param handle Handle of the booster
 * \param out_tree_per_iteration Receives the number of models per iteration
 * \return 0 on success, -1 when an exception was caught by API_END
 */
int LGBM_BoosterNumModelPerIteration(BoosterHandle handle, int* out_tree_per_iteration) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto boosting = booster->GetBoosting();
  *out_tree_per_iteration = boosting->NumModelPerIteration();
  API_END();
}

/*!
 * \brief Report NumberOfTotalModel() of the booster's boosting object.
 * \param handle Handle of the booster
 * \param out_models Receives the total number of models
 * \return 0 on success, -1 when an exception was caught by API_END
 */
int LGBM_BoosterNumberOfTotalModel(BoosterHandle handle, int* out_models) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto boosting = booster->GetBoosting();
  *out_models = boosting->NumberOfTotalModel();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1634
/*!
 * \brief Report the value of Booster::GetEvalCounts().
 * \param handle Handle of the booster
 * \param out_len Receives the evaluation count
 * \return 0 on success, -1 when an exception was caught by API_END
 */
int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalCounts();
  API_END();
}

1641
1642
1643
1644
1645
1646
/*!
 * \brief Forward to Booster::GetEvalNames and report its return value.
 * \param handle Handle of the booster
 * \param len Number of buffers available in out_strs
 * \param out_len Receives the value returned by GetEvalNames
 * \param buffer_len Size of each buffer in out_strs
 * \param out_buffer_len Output size value filled in by GetEvalNames
 * \param out_strs Caller-provided string buffers
 * \return 0 on success, -1 when an exception was caught by API_END
 */
int LGBM_BoosterGetEvalNames(BoosterHandle handle,
                             const int len,
                             int* out_len,
                             const size_t buffer_len,
                             size_t* out_buffer_len,
                             char** out_strs) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const auto written = booster->GetEvalNames(out_strs, len, buffer_len, out_buffer_len);
  *out_len = written;
  API_END();
}

1653
1654
1655
1656
1657
1658
/*!
 * \brief Forward to Booster::GetFeatureNames and report its return value.
 * \param handle Handle of the booster
 * \param len Number of buffers available in out_strs
 * \param out_len Receives the value returned by GetFeatureNames
 * \param buffer_len Size of each buffer in out_strs
 * \param out_buffer_len Output size value filled in by GetFeatureNames
 * \param out_strs Caller-provided string buffers
 * \return 0 on success, -1 when an exception was caught by API_END
 */
int LGBM_BoosterGetFeatureNames(BoosterHandle handle,
                                const int len,
                                int* out_len,
                                const size_t buffer_len,
                                size_t* out_buffer_len,
                                char** out_strs) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const auto written = booster->GetFeatureNames(out_strs, len, buffer_len, out_buffer_len);
  *out_len = written;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1665
/*!
 * \brief Report the feature count as MaxFeatureIdx() + 1 of the boosting object.
 * \param handle Handle of the booster
 * \param out_len Receives the feature count
 * \return 0 on success, -1 when an exception was caught by API_END
 */
int LGBM_BoosterGetNumFeature(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto boosting = booster->GetBoosting();
  // indices are zero-based, so the count is the largest index plus one
  *out_len = boosting->MaxFeatureIdx() + 1;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1672
/*!
 * \brief Copy the evaluation results for one dataset index into a caller buffer.
 * \param handle Handle of the booster
 * \param data_idx Index passed to GetEvalAt
 * \param out_len Receives the number of results written
 * \param out_results Caller-provided buffer; must be able to hold every result,
 *        no bound is checked here
 * \return 0 on success, -1 when an exception was caught by API_END
 */
int LGBM_BoosterGetEval(BoosterHandle handle,
                        int data_idx,
                        int* out_len,
                        double* out_results) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto metrics = booster->GetBoosting()->GetEvalAt(data_idx);
  *out_len = static_cast<int>(metrics.size());
  double* cursor = out_results;
  for (const auto& metric : metrics) {
    *cursor = static_cast<double>(metric);
    ++cursor;
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1687
int LGBM_BoosterGetNumPredict(BoosterHandle handle,
1688
1689
                              int data_idx,
                              int64_t* out_len) {
Guolin Ke's avatar
Guolin Ke committed
1690
1691
1692
1693
1694
1695
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->GetNumPredictAt(data_idx);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1696
int LGBM_BoosterGetPredict(BoosterHandle handle,
1697
1698
1699
                           int data_idx,
                           int64_t* out_len,
                           double* out_result) {
1700
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1701
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1702
  ref_booster->GetPredictAt(data_idx, out_result, out_len);
1703
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1704
1705
}

Guolin Ke's avatar
Guolin Ke committed
1706
int LGBM_BoosterPredictForFile(BoosterHandle handle,
1707
1708
1709
1710
                               const char* data_filename,
                               int data_has_header,
                               int predict_type,
                               int num_iteration,
1711
                               const char* parameter,
1712
                               const char* result_filename) {
1713
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1714
1715
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1716
1717
1718
1719
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1720
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
cbecker's avatar
cbecker committed
1721
  ref_booster->Predict(num_iteration, predict_type, data_filename, data_has_header,
Guolin Ke's avatar
Guolin Ke committed
1722
                       config, result_filename);
1723
  API_END();
1724
1725
}

Guolin Ke's avatar
Guolin Ke committed
1726
/*!
 * \brief Compute how many output values a prediction over num_row rows produces.
 * \param handle Handle of the booster
 * \param num_row Number of rows to be predicted
 * \param predict_type Prediction type; compared against C_API_PREDICT_LEAF_INDEX
 *        and C_API_PREDICT_CONTRIB
 * \param num_iteration Iteration count passed to NumPredictOneRow
 * \param out_len Receives num_row multiplied by the per-row output count
 * \return 0 on success, -1 when an exception was caught by API_END
 */
int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
                               int num_row,
                               int predict_type,
                               int num_iteration,
                               int64_t* out_len) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const bool wants_leaf_index = predict_type == C_API_PREDICT_LEAF_INDEX;
  const bool wants_contrib = predict_type == C_API_PREDICT_CONTRIB;
  const auto per_row = booster->GetBoosting()->NumPredictOneRow(
    num_iteration, wants_leaf_index, wants_contrib);
  // widen before multiplying so the product is computed in 64 bits
  *out_len = static_cast<int64_t>(num_row) * per_row;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1738
int LGBM_BoosterPredictForCSR(BoosterHandle handle,
1739
1740
1741
1742
1743
1744
1745
                              const void* indptr,
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
1746
                              int64_t num_col,
1747
1748
                              int predict_type,
                              int num_iteration,
1749
                              const char* parameter,
1750
1751
                              int64_t* out_len,
                              double* out_result) {
1752
  API_BEGIN();
1753
1754
1755
1756
1757
  if (num_col <= 0) {
    Log::Fatal("The number of columns should be greater than zero.");
  } else if (num_col >= INT32_MAX) {
    Log::Fatal("The number of columns should be smaller than INT32_MAX.");
  }
Guolin Ke's avatar
Guolin Ke committed
1758
1759
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1760
1761
1762
1763
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1764
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1765
  auto get_row_fun = RowFunctionFromCSR<int>(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
Guolin Ke's avatar
Guolin Ke committed
1766
  int nrow = static_cast<int>(nindptr - 1);
1767
  ref_booster->Predict(num_iteration, predict_type, nrow, static_cast<int>(num_col), get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
1768
                       config, out_result, out_len);
1769
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1770
}
1771

1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
/*!
 * \brief Predict for sparse input (CSR or CSC) and return the result in sparse form.
 * \param handle Handle of the booster
 * \param indptr Row (CSR) or column (CSC) pointer array
 * \param indptr_type Type tag of indptr, also passed on for the output indptr
 * \param indices Indices of the stored values
 * \param data Stored values
 * \param data_type Type tag of data, also passed on for the output data
 * \param nindptr Number of entries in indptr
 * \param nelem Number of stored values
 * \param num_col_or_row Column count for CSR input (must be > 0 and < INT32_MAX);
 *        forwarded as-is to PredictSparseCSC for CSC input
 * \param predict_type Prediction type passed to the booster
 * \param num_iteration Iteration count passed to the booster
 * \param parameter Parameter string; num_threads (when > 0) is applied to OpenMP
 * \param matrix_type C_API_MATRIX_TYPE_CSR or C_API_MATRIX_TYPE_CSC
 * \param out_len Output length pointer passed to the booster
 * \param out_indptr Receives the output pointer array
 * \param out_indices Receives the output index array
 * \param out_data Receives the output value array
 * \return 0 on success, -1 when an exception was caught by API_END
 */
int LGBM_BoosterPredictSparseOutput(BoosterHandle handle,
                                    const void* indptr,
                                    int indptr_type,
                                    const int32_t* indices,
                                    const void* data,
                                    int data_type,
                                    int64_t nindptr,
                                    int64_t nelem,
                                    int64_t num_col_or_row,
                                    int predict_type,
                                    int num_iteration,
                                    const char* parameter,
                                    int matrix_type,
                                    int64_t* out_len,
                                    void** out_indptr,
                                    int32_t** out_indices,
                                    void** out_data) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  Config config;
  config.Set(Config::Str2Map(parameter));
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  if (matrix_type == C_API_MATRIX_TYPE_CSR) {
    // the column count is narrowed to int below, so it has to fit
    if (num_col_or_row <= 0) {
      Log::Fatal("The number of columns should be greater than zero.");
    }
    if (num_col_or_row >= INT32_MAX) {
      Log::Fatal("The number of columns should be smaller than INT32_MAX.");
    }
    auto row_reader = RowFunctionFromCSR<int64_t>(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
    const int64_t row_count = nindptr - 1;
    const int col_count = static_cast<int>(num_col_or_row);
    booster->PredictSparseCSR(num_iteration, predict_type, row_count, col_count, row_reader,
                              config, out_len, out_indptr, indptr_type, out_indices, out_data, data_type);
  } else if (matrix_type == C_API_MATRIX_TYPE_CSC) {
    const int thread_count = OMP_NUM_THREADS();
    const int ncol = static_cast<int>(nindptr - 1);
    // one set of column iterators per thread, selected by thread id in the row reader
    std::vector<std::vector<CSC_RowIterator>> iterators(thread_count, std::vector<CSC_RowIterator>());
    for (auto& thread_iterators : iterators) {
      for (int col = 0; col < ncol; ++col) {
        thread_iterators.emplace_back(indptr, indptr_type, indices, data, data_type, nindptr, nelem, col);
      }
    }
    std::function<std::vector<std::pair<int, double>>(int64_t row_idx)> row_reader =
      [&iterators, ncol](int64_t row) {
      std::vector<std::pair<int, double>> entries;
      entries.reserve(ncol);
      const int thread_id = omp_get_thread_num();
      auto& columns = iterators[thread_id];
      for (int col = 0; col < ncol; ++col) {
        auto value = columns[col].Get(static_cast<int>(row));
        // keep NaN and anything whose magnitude exceeds the zero threshold
        if (std::fabs(value) > kZeroThreshold || std::isnan(value)) {
          entries.emplace_back(col, value);
        }
      }
      return entries;
    };
    booster->PredictSparseCSC(num_iteration, predict_type, num_col_or_row, ncol, row_reader, config,
                              out_len, out_indptr, indptr_type, out_indices, out_data, data_type);
  } else {
    Log::Fatal("Unknown matrix type in LGBM_BoosterPredictSparseOutput");
  }
  API_END();
}

/*!
 * \brief Release the three output buffers of a sparse prediction.
 * \param indptr Pointer array to free, typed according to indptr_type
 * \param indices Index array to free
 * \param data Value array to free, typed according to data_type
 * \param indptr_type C_API_DTYPE_INT32 or C_API_DTYPE_INT64
 * \param data_type C_API_DTYPE_FLOAT32 or C_API_DTYPE_FLOAT64
 * \return 0 on success, -1 when an exception was caught by API_END
 * \note The buffers are arrays, so they are released with delete[]; using the
 *       scalar form on memory obtained from new[] is undefined behaviour.
 *       NOTE(review): assumes the sparse prediction routines allocate these
 *       buffers with new[] - confirm against the allocation site.
 */
int LGBM_BoosterFreePredictSparse(void* indptr, int32_t* indices, void* data, int indptr_type, int data_type) {
  API_BEGIN();
  // the element type must be restored before deleting: delete[] on void* is not allowed
  if (indptr_type == C_API_DTYPE_INT32) {
    delete[] reinterpret_cast<int32_t*>(indptr);
  } else if (indptr_type == C_API_DTYPE_INT64) {
    delete[] reinterpret_cast<int64_t*>(indptr);
  } else {
    Log::Fatal("Unknown indptr type in LGBM_BoosterFreePredictSparse");
  }
  delete[] indices;
  if (data_type == C_API_DTYPE_FLOAT32) {
    delete[] reinterpret_cast<float*>(data);
  } else if (data_type == C_API_DTYPE_FLOAT64) {
    delete[] reinterpret_cast<double*>(data);
  } else {
    Log::Fatal("Unknown data type in LGBM_BoosterFreePredictSparse");
  }
  API_END();
}

1857
int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle,
1858
1859
1860
1861
1862
1863
1864
                                       const void* indptr,
                                       int indptr_type,
                                       const int32_t* indices,
                                       const void* data,
                                       int data_type,
                                       int64_t nindptr,
                                       int64_t nelem,
1865
                                       int64_t num_col,
1866
1867
1868
1869
1870
                                       int predict_type,
                                       int num_iteration,
                                       const char* parameter,
                                       int64_t* out_len,
                                       double* out_result) {
1871
  API_BEGIN();
1872
1873
1874
1875
1876
  if (num_col <= 0) {
    Log::Fatal("The number of columns should be greater than zero.");
  } else if (num_col >= INT32_MAX) {
    Log::Fatal("The number of columns should be smaller than INT32_MAX.");
  }
1877
1878
1879
1880
1881
1882
1883
  auto param = Config::Str2Map(parameter);
  Config config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1884
  auto get_row_fun = RowFunctionFromCSR<int>(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
1885
  ref_booster->PredictSingleRow(num_iteration, predict_type, static_cast<int32_t>(num_col), get_row_fun, config, out_result, out_len);
1886
1887
1888
1889
  API_END();
}


Guolin Ke's avatar
Guolin Ke committed
1890
int LGBM_BoosterPredictForCSC(BoosterHandle handle,
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
                              const void* col_ptr,
                              int col_ptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t ncol_ptr,
                              int64_t nelem,
                              int64_t num_row,
                              int predict_type,
                              int num_iteration,
1901
                              const char* parameter,
1902
1903
                              int64_t* out_len,
                              double* out_result) {
Guolin Ke's avatar
Guolin Ke committed
1904
1905
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
Guolin Ke's avatar
Guolin Ke committed
1906
1907
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1908
1909
1910
1911
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
1912
  int num_threads = OMP_NUM_THREADS();
Guolin Ke's avatar
Guolin Ke committed
1913
  int ncol = static_cast<int>(ncol_ptr - 1);
Guolin Ke's avatar
Guolin Ke committed
1914
1915
1916
1917
1918
  std::vector<std::vector<CSC_RowIterator>> iterators(num_threads, std::vector<CSC_RowIterator>());
  for (int i = 0; i < num_threads; ++i) {
    for (int j = 0; j < ncol; ++j) {
      iterators[i].emplace_back(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, j);
    }
Guolin Ke's avatar
Guolin Ke committed
1919
1920
  }
  std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun =
Guolin Ke's avatar
Guolin Ke committed
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
      [&iterators, ncol](int i) {
        std::vector<std::pair<int, double>> one_row;
        one_row.reserve(ncol);
        const int tid = omp_get_thread_num();
        for (int j = 0; j < ncol; ++j) {
          auto val = iterators[tid][j].Get(i);
          if (std::fabs(val) > kZeroThreshold || std::isnan(val)) {
            one_row.emplace_back(j, val);
          }
        }
        return one_row;
      };
1933
  ref_booster->Predict(num_iteration, predict_type, static_cast<int>(num_row), ncol, get_row_fun, config,
cbecker's avatar
cbecker committed
1934
                       out_result, out_len);
Guolin Ke's avatar
Guolin Ke committed
1935
1936
1937
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1938
int LGBM_BoosterPredictForMat(BoosterHandle handle,
1939
1940
1941
1942
1943
1944
1945
                              const void* data,
                              int data_type,
                              int32_t nrow,
                              int32_t ncol,
                              int is_row_major,
                              int predict_type,
                              int num_iteration,
1946
                              const char* parameter,
1947
1948
                              int64_t* out_len,
                              double* out_result) {
1949
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1950
1951
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1952
1953
1954
1955
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1956
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1957
  auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
1958
  ref_booster->Predict(num_iteration, predict_type, nrow, ncol, get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
1959
                       config, out_result, out_len);
1960
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1961
}
1962

1963
/*!
 * \brief Run prediction for a single dense row.
 *
 * Same flow as the dense-matrix entry point, except the row accessor is built
 * over a matrix with exactly one row and the work is delegated to
 * Booster::PredictSingleRow.
 *
 * \param handle Booster handle; reinterpreted as a Booster*
 * \param data Pointer to the values of the row
 * \param data_type Element type tag, forwarded to RowPairFunctionFromDenseMatric
 * \param ncol Number of columns in the row
 * \param is_row_major Layout flag, forwarded to RowPairFunctionFromDenseMatric
 * \param predict_type Forwarded unchanged to Booster::PredictSingleRow
 * \param num_iteration Forwarded unchanged to Booster::PredictSingleRow
 * \param parameter Parameter string parsed with Config::Str2Map
 * \param[out] out_len Forwarded to Booster::PredictSingleRow
 * \param[out] out_result Forwarded to Booster::PredictSingleRow
 * \return 0 on success, -1 when an exception is caught by API_END
 */
int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle,
                                       const void* data,
                                       int data_type,
                                       int32_t ncol,
                                       int is_row_major,
                                       int predict_type,
                                       int num_iteration,
                                       const char* parameter,
                                       int64_t* out_len,
                                       double* out_result) {
  API_BEGIN();
  auto parsed_params = Config::Str2Map(parameter);
  Config predict_config;
  predict_config.Set(parsed_params);
  // Only override the OpenMP thread count when the caller asked for one.
  if (predict_config.num_threads > 0) {
    omp_set_num_threads(predict_config.num_threads);
  }
  Booster* booster = reinterpret_cast<Booster*>(handle);
  // The input is treated as a matrix with exactly one row.
  auto row_reader = RowPairFunctionFromDenseMatric(data, 1, ncol, data_type, is_row_major);
  booster->PredictSingleRow(num_iteration, predict_type, ncol, row_reader,
                            predict_config, out_result, out_len);
  API_END();
}


1987
1988
1989
1990
1991
1992
1993
1994
1995
1996
1997
1998
1999
2000
2001
2002
2003
2004
2005
/*!
 * \brief Run prediction over rows supplied as an array of row pointers.
 *
 * Parses `parameter` into a Config, applies its `num_threads` to OpenMP when
 * positive, wraps the row pointers with RowPairFunctionFromDenseRows and
 * delegates to Booster::Predict.
 *
 * \param handle Booster handle; reinterpreted as a Booster*
 * \param data Array of pointers, one per row
 * \param data_type Element type tag, forwarded to RowPairFunctionFromDenseRows
 * \param nrow Number of rows
 * \param ncol Number of columns in each row
 * \param predict_type Forwarded unchanged to Booster::Predict
 * \param num_iteration Forwarded unchanged to Booster::Predict
 * \param parameter Parameter string parsed with Config::Str2Map
 * \param[out] out_len Forwarded to Booster::Predict
 * \param[out] out_result Forwarded to Booster::Predict
 * \return 0 on success, -1 when an exception is caught by API_END
 */
int LGBM_BoosterPredictForMats(BoosterHandle handle,
                               const void** data,
                               int data_type,
                               int32_t nrow,
                               int32_t ncol,
                               int predict_type,
                               int num_iteration,
                               const char* parameter,
                               int64_t* out_len,
                               double* out_result) {
  API_BEGIN();
  auto parsed_params = Config::Str2Map(parameter);
  Config predict_config;
  predict_config.Set(parsed_params);
  // Only override the OpenMP thread count when the caller asked for one.
  if (predict_config.num_threads > 0) {
    omp_set_num_threads(predict_config.num_threads);
  }
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto row_reader = RowPairFunctionFromDenseRows(data, ncol, data_type);
  booster->Predict(num_iteration, predict_type, nrow, ncol, row_reader,
                   predict_config, out_result, out_len);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
2010
int LGBM_BoosterSaveModel(BoosterHandle handle,
2011
                          int start_iteration,
2012
2013
                          int num_iteration,
                          const char* filename) {
2014
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
2015
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
2016
  ref_booster->SaveModelToFile(start_iteration, num_iteration, filename);
wxchan's avatar
wxchan committed
2017
2018
2019
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
2020
int LGBM_BoosterSaveModelToString(BoosterHandle handle,
2021
                                  int start_iteration,
2022
                                  int num_iteration,
2023
                                  int64_t buffer_len,
2024
                                  int64_t* out_len,
2025
                                  char* out_str) {
2026
2027
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
2028
  std::string model = ref_booster->SaveModelToString(start_iteration, num_iteration);
2029
  *out_len = static_cast<int64_t>(model.size()) + 1;
2030
  if (*out_len <= buffer_len) {
Guolin Ke's avatar
Guolin Ke committed
2031
    std::memcpy(out_str, model.c_str(), *out_len);
2032
2033
2034
2035
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
2036
int LGBM_BoosterDumpModel(BoosterHandle handle,
2037
                          int start_iteration,
2038
                          int num_iteration,
2039
2040
                          int64_t buffer_len,
                          int64_t* out_len,
2041
                          char* out_str) {
wxchan's avatar
wxchan committed
2042
2043
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
2044
  std::string model = ref_booster->DumpModel(start_iteration, num_iteration);
2045
  *out_len = static_cast<int64_t>(model.size()) + 1;
wxchan's avatar
wxchan committed
2046
  if (*out_len <= buffer_len) {
Guolin Ke's avatar
Guolin Ke committed
2047
    std::memcpy(out_str, model.c_str(), *out_len);
wxchan's avatar
wxchan committed
2048
  }
2049
  API_END();
Guolin Ke's avatar
Guolin Ke committed
2050
}
2051

Guolin Ke's avatar
Guolin Ke committed
2052
/*!
 * \brief Read the value of one leaf of one tree.
 *
 * \param handle Booster handle; reinterpreted as a Booster*
 * \param tree_idx Tree index, forwarded to Booster::GetLeafValue
 * \param leaf_idx Leaf index, forwarded to Booster::GetLeafValue
 * \param[out] out_val Receives the leaf value converted to double
 * \return 0 on success, -1 when an exception is caught by API_END
 */
int LGBM_BoosterGetLeafValue(BoosterHandle handle,
                             int tree_idx,
                             int leaf_idx,
                             double* out_val) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const double leaf_value = static_cast<double>(booster->GetLeafValue(tree_idx, leaf_idx));
  *out_val = leaf_value;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
2062
/*!
 * \brief Overwrite the value of one leaf of one tree.
 *
 * Thin wrapper that forwards its arguments to Booster::SetLeafValue.
 *
 * \param handle Booster handle; reinterpreted as a Booster*
 * \param tree_idx Tree index, forwarded to Booster::SetLeafValue
 * \param leaf_idx Leaf index, forwarded to Booster::SetLeafValue
 * \param val New leaf value
 * \return 0 on success, -1 when an exception is caught by API_END
 */
int LGBM_BoosterSetLeafValue(BoosterHandle handle,
                             int tree_idx,
                             int leaf_idx,
                             double val) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->SetLeafValue(tree_idx, leaf_idx, val);
  API_END();
}

2072
2073
2074
2075
2076
2077
2078
2079
2080
2081
2082
2083
2084
/*!
 * \brief Copy the booster's feature importances into a caller buffer.
 *
 * \param handle Booster handle; reinterpreted as a Booster*
 * \param num_iteration Forwarded unchanged to Booster::FeatureImportance
 * \param importance_type Forwarded unchanged to Booster::FeatureImportance
 * \param[out] out_results Destination array; must have room for as many
 *             doubles as Booster::FeatureImportance returns
 * \return 0 on success, -1 when an exception is caught by API_END
 */
int LGBM_BoosterFeatureImportance(BoosterHandle handle,
                                  int num_iteration,
                                  int importance_type,
                                  double* out_results) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const std::vector<double> importances = booster->FeatureImportance(num_iteration, importance_type);
  // Write the values back to back, in the order the booster returned them.
  double* write_cursor = out_results;
  for (const double importance : importances) {
    *write_cursor = importance;
    ++write_cursor;
  }
  API_END();
}

2085
2086
2087
2088
2089
2090
2091
2092
2093
2094
2095
2096
2097
2098
2099
2100
2101
2102
/*!
 * \brief Fetch the value reported by Booster::UpperBoundValue.
 *
 * \param handle Booster handle; reinterpreted as a Booster*
 * \param[out] out_results Receives the value returned by UpperBoundValue
 * \return 0 on success, -1 when an exception is caught by API_END
 */
int LGBM_BoosterGetUpperBoundValue(BoosterHandle handle,
                                   double* out_results) {
  API_BEGIN();
  *out_results = reinterpret_cast<Booster*>(handle)->UpperBoundValue();
  API_END();
}

/*!
 * \brief Fetch the value reported by Booster::LowerBoundValue.
 *
 * \param handle Booster handle; reinterpreted as a Booster*
 * \param[out] out_results Receives the value returned by LowerBoundValue
 * \return 0 on success, -1 when an exception is caught by API_END
 */
int LGBM_BoosterGetLowerBoundValue(BoosterHandle handle,
                                   double* out_results) {
  API_BEGIN();
  *out_results = reinterpret_cast<Booster*>(handle)->LowerBoundValue();
  API_END();
}

2103
2104
2105
2106
2107
int LGBM_NetworkInit(const char* machines,
                     int local_listen_port,
                     int listen_time_out,
                     int num_machines) {
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
2108
  Config config;
2109
  config.machines = RemoveQuotationSymbol(std::string(machines));
2110
2111
2112
2113
2114
2115
2116
2117
2118
2119
2120
2121
2122
2123
2124
  config.local_listen_port = local_listen_port;
  config.num_machines = num_machines;
  config.time_out = listen_time_out;
  if (num_machines > 1) {
    Network::Init(config);
  }
  API_END();
}

/*!
 * \brief Tear down the network layer by calling Network::Dispose.
 * \return 0 on success, -1 when an exception is caught by API_END
 */
int LGBM_NetworkFree() {
  API_BEGIN();
  Network::Dispose();
  API_END();
}

2125
2126
2127
int LGBM_NetworkInitWithFunctions(int num_machines, int rank,
                                  void* reduce_scatter_ext_fun,
                                  void* allgather_ext_fun) {
ww's avatar
ww committed
2128
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
2129
  if (num_machines > 1) {
2130
    Network::Init(num_machines, rank, (ReduceScatterFunction)reduce_scatter_ext_fun, (AllgatherFunction)allgather_ext_fun);
ww's avatar
ww committed
2131
2132
2133
  }
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
2134

Guolin Ke's avatar
Guolin Ke committed
2135
// ---- start of some help functions
2136
2137
2138

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
2139
  if (data_type == C_API_DTYPE_FLOAT32) {
2140
2141
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
2142
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
2143
        std::vector<double> ret(num_col);
2144
        auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
2145
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
2146
          ret[i] = static_cast<double>(*(tmp_ptr + i));
2147
2148
2149
2150
        }
        return ret;
      };
    } else {
2151
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
2152
        std::vector<double> ret(num_col);
2153
        for (int i = 0; i < num_col; ++i) {
2154
          ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
2155
2156
2157
2158
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
2159
  } else if (data_type == C_API_DTYPE_FLOAT64) {
2160
2161
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
2162
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
2163
        std::vector<double> ret(num_col);
2164
        auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
2165
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
2166
          ret[i] = static_cast<double>(*(tmp_ptr + i));
2167
2168
2169
2170
        }
        return ret;
      };
    } else {
2171
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
2172
        std::vector<double> ret(num_col);
2173
        for (int i = 0; i < num_col; ++i) {
2174
          ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
2175
2176
2177
2178
2179
        }
        return ret;
      };
    }
  }
2180
  Log::Fatal("Unknown data type in RowFunctionFromDenseMatric");
2181
  return nullptr;
2182
2183
2184
2185
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
2186
2187
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
2188
    return [inner_function] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
2189
2190
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
Guolin Ke's avatar
Guolin Ke committed
2191
      ret.reserve(raw_values.size());
Guolin Ke's avatar
Guolin Ke committed
2192
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
2193
        if (std::fabs(raw_values[i]) > kZeroThreshold || std::isnan(raw_values[i])) {
Guolin Ke's avatar
Guolin Ke committed
2194
          ret.emplace_back(i, raw_values[i]);
2195
        }
Guolin Ke's avatar
Guolin Ke committed
2196
2197
2198
      }
      return ret;
    };
2199
  }
Guolin Ke's avatar
Guolin Ke committed
2200
  return nullptr;
2201
2202
}

2203
2204
2205
2206
2207
2208
2209
// data is array of pointers to individual rows
std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseRows(const void** data, int num_col, int data_type) {
  return [=](int row_idx) {
    auto inner_function = RowFunctionFromDenseMatric(data[row_idx], 1, num_col, data_type, /* is_row_major */ true);
    auto raw_values = inner_function(0);
    std::vector<std::pair<int, double>> ret;
Guolin Ke's avatar
Guolin Ke committed
2210
    ret.reserve(raw_values.size());
2211
2212
2213
2214
2215
2216
2217
2218
2219
    for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
      if (std::fabs(raw_values[i]) > kZeroThreshold || std::isnan(raw_values[i])) {
        ret.emplace_back(i, raw_values[i]);
      }
    }
    return ret;
  };
}

2220
2221
template<typename T>
std::function<std::vector<std::pair<int, double>>(T idx)>
2222
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t , int64_t ) {
Guolin Ke's avatar
Guolin Ke committed
2223
  if (data_type == C_API_DTYPE_FLOAT32) {
2224
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
2225
    if (indptr_type == C_API_DTYPE_INT32) {
2226
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
2227
      return [=] (T idx) {
2228
2229
2230
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
2231
2232
2233
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
2234
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
2235
          ret.emplace_back(indices[i], data_ptr[i]);
2236
2237
2238
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
2239
    } else if (indptr_type == C_API_DTYPE_INT64) {
2240
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
2241
      return [=] (T idx) {
2242
2243
2244
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
2245
2246
2247
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
2248
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
2249
          ret.emplace_back(indices[i], data_ptr[i]);
2250
2251
2252
2253
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
2254
  } else if (data_type == C_API_DTYPE_FLOAT64) {
2255
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
2256
    if (indptr_type == C_API_DTYPE_INT32) {
2257
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
2258
      return [=] (T idx) {
2259
2260
2261
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
2262
2263
2264
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
2265
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
2266
          ret.emplace_back(indices[i], data_ptr[i]);
2267
2268
2269
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
2270
    } else if (indptr_type == C_API_DTYPE_INT64) {
2271
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
2272
      return [=] (T idx) {
2273
2274
2275
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
2276
2277
2278
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
2279
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
2280
          ret.emplace_back(indices[i], data_ptr[i]);
2281
2282
2283
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
2284
2285
    }
  }
2286
  Log::Fatal("Unknown data type in RowFunctionFromCSR");
2287
  return nullptr;
2288
2289
}

Guolin Ke's avatar
Guolin Ke committed
2290
std::function<std::pair<int, double>(int idx)>
2291
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t , int col_idx) {
Guolin Ke's avatar
Guolin Ke committed
2292
  CHECK(col_idx < ncol_ptr && col_idx >= 0);
Guolin Ke's avatar
Guolin Ke committed
2293
  if (data_type == C_API_DTYPE_FLOAT32) {
2294
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
2295
    if (col_ptr_type == C_API_DTYPE_INT32) {
2296
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
2297
2298
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
2299
2300
      return [=] (int offset) {
        int64_t i = static_cast<int64_t>(start + offset);
Guolin Ke's avatar
Guolin Ke committed
2301
2302
        if (i >= end) {
          return std::make_pair(-1, 0.0);
2303
        }
Guolin Ke's avatar
Guolin Ke committed
2304
2305
2306
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
2307
      };
Guolin Ke's avatar
Guolin Ke committed
2308
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
2309
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
2310
2311
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
2312
2313
      return [=] (int offset) {
        int64_t i = static_cast<int64_t>(start + offset);
Guolin Ke's avatar
Guolin Ke committed
2314
2315
        if (i >= end) {
          return std::make_pair(-1, 0.0);
2316
        }
Guolin Ke's avatar
Guolin Ke committed
2317
2318
2319
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
2320
      };
Guolin Ke's avatar
Guolin Ke committed
2321
    }
Guolin Ke's avatar
Guolin Ke committed
2322
  } else if (data_type == C_API_DTYPE_FLOAT64) {
2323
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
2324
    if (col_ptr_type == C_API_DTYPE_INT32) {
2325
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
2326
2327
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
2328
2329
      return [=] (int offset) {
        int64_t i = static_cast<int64_t>(start + offset);
Guolin Ke's avatar
Guolin Ke committed
2330
2331
        if (i >= end) {
          return std::make_pair(-1, 0.0);
2332
        }
Guolin Ke's avatar
Guolin Ke committed
2333
2334
2335
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
2336
      };
Guolin Ke's avatar
Guolin Ke committed
2337
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
2338
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
2339
2340
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
2341
2342
      return [=] (int offset) {
        int64_t i = static_cast<int64_t>(start + offset);
Guolin Ke's avatar
Guolin Ke committed
2343
2344
        if (i >= end) {
          return std::make_pair(-1, 0.0);
2345
        }
Guolin Ke's avatar
Guolin Ke committed
2346
2347
2348
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
2349
      };
Guolin Ke's avatar
Guolin Ke committed
2350
2351
    }
  }
2352
  Log::Fatal("Unknown data type in CSC matrix");
2353
  return nullptr;
2354
2355
}

Guolin Ke's avatar
Guolin Ke committed
2356
/*!
 * \brief Construct an iterator over one column of a CSC matrix.
 *
 * All arguments are forwarded unchanged to IterateFunctionFromCSC; the
 * callable it returns is stored in `iter_fun_`.
 */
CSC_RowIterator::CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
                                 const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx)
    : iter_fun_(IterateFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type,
                                       ncol_ptr, nelem, col_idx)) {
}

/*!
 * \brief Return the column's value at row `idx`, or 0 when none is stored.
 *
 * Stored entries are consumed from `iter_fun_` until the current row index
 * reaches `idx` or the column is exhausted. The cursor only moves forward, so
 * an `idx` below the last consumed row index yields 0.
 *
 * \param idx Row index to look up
 */
double CSC_RowIterator::Get(int idx) {
  while (!is_end_ && cur_idx_ < idx) {
    const auto entry = iter_fun_(nonzero_idx_);
    if (entry.first < 0) {
      // A negative index marks the end of the column.
      is_end_ = true;
      break;
    }
    cur_idx_ = entry.first;
    cur_val_ = entry.second;
    ++nonzero_idx_;
  }
  if (idx != cur_idx_) {
    return 0.0f;
  }
  return cur_val_;
}

/*!
 * \brief Return the next stored (row index, value) pair of the column.
 *
 * Once the column is exhausted, every call returns (-1, 0.0). The pair that
 * first signals exhaustion is whatever `iter_fun_` returned, which has a
 * negative first member.
 */
std::pair<int, double> CSC_RowIterator::NextNonZero() {
  if (is_end_) {
    return std::make_pair(-1, 0.0);
  }
  auto entry = iter_fun_(nonzero_idx_);
  ++nonzero_idx_;
  // is_end_ is known to be false here, so this only ever flips it to true.
  is_end_ = entry.first < 0;
  return entry;
}