application.h 2.06 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
#ifndef LIGHTGBM_APPLICATION_H_
#define LIGHTGBM_APPLICATION_H_

#include <LightGBM/meta.h>
#include <LightGBM/config.h>

#include <vector>

namespace LightGBM {

/*!
* \brief Forward declarations. This header only stores pointers to these
*        types, so their full definitions are not required here.
*/
class Boosting;
class Dataset;
class Metric;
class ObjectiveFunction;

/*!
* \brief The entrance of LightGBM. This application has two tasks:
* Train and Predict.
* The Train task trains a new model.
* The Predict task predicts the scores of test data and saves the scores to local disk.
*/
class Application {
public:
  Application(int argc, char** argv);

  /*! \brief Destructor */
  ~Application();

  /*! \brief To call this funciton  to run application*/
  inline void Run();

private:
  /*! 
  * \brief Global Sync by minimal, will return minimal of global
  * \param local Local data
  * \return Global minimal data
  */
  template<typename T>
  T GlobalSyncUpByMin(T& local);

  /*! \brief Load parametes from command line and config file*/
  void LoadParameters(int argc, char** argv);

  /*! \brief Load data, including training data and validation data*/
  void LoadData();

  /*! \brief Some initial works before training*/
  void InitTrain();

  /*! \brief The training logic */
  void Train();

  /*! \brief Initialize the enviroment needed by prediction */
  void InitPredict();

  /*! \brief Load model */
  void LoadModel();

  /*! \brief The prediction logic */
  void Predict();

  /*! \brief All configs */
  OverallConfig config_;
  /*! \brief Training data */
  Dataset* train_data_;
  /*! \brief Validation data */
  std::vector<Dataset*> valid_datas_;
  /*! \brief Metric for training data */
  std::vector<Metric*> train_metric_;
  /*! \brief Metrics for validation data */
  std::vector<std::vector<Metric*>> valid_metrics_;
  /*! \brief Boosting object */
  Boosting* boosting_;
  /*! \brief Training objective function */
  ObjectiveFunction* objective_fun_;
};


/*!
* \brief Dispatch on the configured task type: a predict task runs
*        InitPredict() then Predict(); any other task runs InitTrain()
*        then Train().
*/
inline void Application::Run() {
  const bool is_predict_task = (config_.task_type == TaskType::kPredict);
  if (!is_predict_task) {
    InitTrain();
    Train();
    return;
  }
  InitPredict();
  Predict();
}

}  // namespace LightGBM

Guolin Ke's avatar
Guolin Ke committed
92
#endif   // LightGBM_APPLICATION_H_