/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
#ifndef LIGHTGBM_APPLICATION_H_
#define LIGHTGBM_APPLICATION_H_

#include <LightGBM/config.h>
#include <LightGBM/meta.h>

#include <memory>
#include <vector>

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
16
class DatasetLoader;
Guolin Ke's avatar
Guolin Ke committed
17
18
19
20
21
22
class Dataset;
class Boosting;
class ObjectiveFunction;
class Metric;

/*!
Qiwei Ye's avatar
Qiwei Ye committed
23
* \brief The main entrance of LightGBM. this application has two tasks:
24
25
*        Train and Predict.
*        Train task will train a new model
Seong-Jin Kim's avatar
Seong-Jin Kim committed
26
*        Predict task will predict the scores of test data using existing model,
zhangyafeikimi's avatar
zhangyafeikimi committed
27
*        and save the score to disk.
Guolin Ke's avatar
Guolin Ke committed
28
29
*/
class Application {
Nikita Titov's avatar
Nikita Titov committed
30
 public:
Guolin Ke's avatar
Guolin Ke committed
31
32
33
34
35
  Application(int argc, char** argv);

  /*! \brief Destructor */
  ~Application();

36
  /*! \brief To call this function to run application*/
Guolin Ke's avatar
Guolin Ke committed
37
38
  inline void Run();

Nikita Titov's avatar
Nikita Titov committed
39
 private:
zhangyafeikimi's avatar
zhangyafeikimi committed
40
  /*! \brief Load parameters from command line and config file*/
Guolin Ke's avatar
Guolin Ke committed
41
42
43
44
45
  void LoadParameters(int argc, char** argv);

  /*! \brief Load data, including training data and validation data*/
  void LoadData();

Qiwei Ye's avatar
Qiwei Ye committed
46
  /*! \brief Initialization before training*/
Guolin Ke's avatar
Guolin Ke committed
47
48
  void InitTrain();

Qiwei Ye's avatar
Qiwei Ye committed
49
  /*! \brief Main Training logic */
Guolin Ke's avatar
Guolin Ke committed
50
51
  void Train();

Qiwei Ye's avatar
Qiwei Ye committed
52
  /*! \brief Initializations before prediction */
Guolin Ke's avatar
Guolin Ke committed
53
54
  void InitPredict();

Qiwei Ye's avatar
Qiwei Ye committed
55
  /*! \brief Main predicting logic */
Guolin Ke's avatar
Guolin Ke committed
56
57
  void Predict();

58
59
60
  /*! \brief Main Convert model logic */
  void ConvertModel();

Guolin Ke's avatar
Guolin Ke committed
61
  /*! \brief All configs */
Guolin Ke's avatar
Guolin Ke committed
62
  Config config_;
Guolin Ke's avatar
Guolin Ke committed
63
  /*! \brief Training data */
Guolin Ke's avatar
Guolin Ke committed
64
  std::unique_ptr<Dataset> train_data_;
Guolin Ke's avatar
Guolin Ke committed
65
  /*! \brief Validation data */
Guolin Ke's avatar
Guolin Ke committed
66
  std::vector<std::unique_ptr<Dataset>> valid_datas_;
Guolin Ke's avatar
Guolin Ke committed
67
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
68
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
69
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
70
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
71
  /*! \brief Boosting object */
Guolin Ke's avatar
Guolin Ke committed
72
  std::unique_ptr<Boosting> boosting_;
Guolin Ke's avatar
Guolin Ke committed
73
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
74
  std::unique_ptr<ObjectiveFunction> objective_fun_;
Guolin Ke's avatar
Guolin Ke committed
75
76
77
78
};


inline void Application::Run() {
Guolin Ke's avatar
Guolin Ke committed
79
  if (config_.task == TaskType::kPredict || config_.task == TaskType::KRefitTree) {
Guolin Ke's avatar
Guolin Ke committed
80
81
    InitPredict();
    Predict();
Guolin Ke's avatar
Guolin Ke committed
82
  } else if (config_.task == TaskType::kConvertModel) {
83
    ConvertModel();
Guolin Ke's avatar
Guolin Ke committed
84
85
86
87
88
89
90
91
  } else {
    InitTrain();
    Train();
  }
}

}  // namespace LightGBM

#endif   // LIGHTGBM_APPLICATION_H_