application.h 2.08 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#ifndef LIGHTGBM_APPLICATION_H_
#define LIGHTGBM_APPLICATION_H_

#include <LightGBM/meta.h>
#include <LightGBM/config.h>

#include <vector>

namespace LightGBM {

/*!
* \brief Forward declarations. This header only refers to these types through
* pointers, so their full definitions are not required here.
*/
class Dataset;
class Boosting;
class ObjectiveFunction;
class Metric;

/*!
* \brief The entry point of LightGBM. The application has two tasks:
* Train and Predict.
* The Train task trains a new model.
* The Predict task predicts the scores of test data and then saves the scores to local disk.
*/
class Application {
public:
  /*!
  * \brief Constructor
  * \param argc Number of command line arguments
  * \param argv Command line arguments
  */
  Application(int argc, char** argv);

  /*! \brief Destructor */
  ~Application();

  /*!
  * \brief Call this function to run the application.
  * Runs the prediction task when the configured task type is
  * TaskType::kPredict, and the training task otherwise.
  */
  inline void Run();

private:
  /*!
  * \brief Global sync by minimum; returns the global minimum
  * \param local Local data
  * \return Global minimal data
  */
  template<typename T>
  T GlobalSyncUpByMin(T& local);

  /*!
  * \brief Load parameters from command line and config file
  * \param argc Number of command line arguments
  * \param argv Command line arguments
  */
  void LoadParameters(int argc, char** argv);

  /*! \brief Load data, including training data and validation data */
  void LoadData();

  /*! \brief Initialization work done before training */
  void InitTrain();

  /*! \brief The training logic */
  void Train();

  /*! \brief Initialize the environment needed by prediction */
  void InitPredict();

  /*! \brief Load model */
  void LoadModel();

  /*! \brief The prediction logic */
  void Predict();

  // NOTE(review): ownership of the raw pointers below is not visible in this
  // header -- confirm against the constructor/destructor in the .cpp file.
  // The single pointers default to nullptr so they never hold an
  // indeterminate value, whichever constructor path runs.

  /*! \brief All configs */
  OverallConfig config_;
  /*! \brief Training data */
  Dataset* train_data_ = nullptr;
  /*! \brief Validation data */
  std::vector<Dataset*> valid_datas_;
  /*! \brief Metrics for training data */
  std::vector<Metric*> train_metric_;
  /*! \brief Metrics for validation data */
  std::vector<std::vector<Metric*>> valid_metrics_;
  /*! \brief Boosting object */
  Boosting* boosting_ = nullptr;
  /*! \brief Training objective function */
  ObjectiveFunction* objective_fun_ = nullptr;
};


/*!
* \brief Dispatch to the configured task: initialize and run prediction when
* the task type is TaskType::kPredict; initialize and run training otherwise.
*/
inline void Application::Run() {
  const bool is_predict_task = (config_.task_type == TaskType::kPredict);
  if (!is_predict_task) {
    // Every task type other than prediction takes the training path.
    InitTrain();
    Train();
    return;
  }
  InitPredict();
  Predict();
}

}  // namespace LightGBM

#endif  // LIGHTGBM_APPLICATION_H_