Unverified Commit 8a5ec366 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

Speed up saving and loading model (#1083)

* remove protobuf

* add version number

* remove pmml script

* use float for split gain

* fix warnings

* refine the read model logic of gbdt

* fix compile error

* improve decode speed

* fix some bugs

* fix double accuracy problem

* fix bug

* multi-thread save model

* speed up save model to string

* parallel save/load model

* fix some warnings.

* fix warnings.

* fix a bug

* remove debug output

* fix doc

* fix max_bin warning in tests.

* fix max_bin warning

* fix pylint

* clean code for stringToArray

* clean code for TToString

* remove max_bin

* replace "class" with typename
parent 8d016c12
...@@ -23,7 +23,6 @@ env: ...@@ -23,7 +23,6 @@ env:
- TASK=if-else - TASK=if-else
- TASK=sdist PYTHON_VERSION=3.4 - TASK=sdist PYTHON_VERSION=3.4
- TASK=bdist PYTHON_VERSION=3.5 - TASK=bdist PYTHON_VERSION=3.5
- TASK=proto
- TASK=gpu METHOD=source - TASK=gpu METHOD=source
- TASK=gpu METHOD=pip - TASK=gpu METHOD=pip
...@@ -39,8 +38,6 @@ matrix: ...@@ -39,8 +38,6 @@ matrix:
env: TASK=pylint env: TASK=pylint
- os: osx - os: osx
env: TASK=check-docs env: TASK=check-docs
- os: osx
env: TASK=proto
before_install: before_install:
- test -n $CC && unset CC - test -n $CC && unset CC
......
...@@ -62,18 +62,6 @@ if [[ ${TASK} == "if-else" ]]; then ...@@ -62,18 +62,6 @@ if [[ ${TASK} == "if-else" ]]; then
exit 0 exit 0
fi fi
if [[ ${TASK} == "proto" ]]; then
conda install numpy
source activate test-env
mkdir build && cd build && cmake .. && make lightgbm || exit -1
cd $TRAVIS_BUILD_DIR/tests/cpp_test && ../../lightgbm config=train.conf && ../../lightgbm config=predict.conf output_result=origin.pred || exit -1
cd $TRAVIS_BUILD_DIR && git clone https://github.com/google/protobuf && cd protobuf && ./autogen.sh && ./configure && make && sudo make install && sudo ldconfig
cd $TRAVIS_BUILD_DIR/build && rm -rf * && cmake -DUSE_PROTO=ON .. && make lightgbm || exit -1
cd $TRAVIS_BUILD_DIR/tests/cpp_test && ../../lightgbm config=train.conf model_format=proto && ../../lightgbm config=predict.conf output_result=proto.pred model_format=proto || exit -1
cd $TRAVIS_BUILD_DIR/tests/cpp_test && python test.py || exit -1
exit 0
fi
conda install numpy nose scipy scikit-learn pandas matplotlib pytest conda install numpy nose scipy scikit-learn pandas matplotlib pytest
if [[ ${TASK} == "sdist" ]]; then if [[ ${TASK} == "sdist" ]]; then
......
...@@ -124,25 +124,8 @@ file(GLOB SOURCES ...@@ -124,25 +124,8 @@ file(GLOB SOURCES
src/treelearner/*.cpp src/treelearner/*.cpp
) )
if (USE_PROTO) add_executable(lightgbm src/main.cpp ${SOURCES})
if(MSVC) add_library(_lightgbm SHARED src/c_api.cpp src/lightgbm_R.cpp ${SOURCES})
message(FATAL_ERROR "Cannot use proto with MSVC.")
endif(MSVC)
find_package(Protobuf REQUIRED)
PROTOBUF_GENERATE_CPP(PROTO_SRCS PROTO_HDRS proto/model.proto)
include_directories(${PROTOBUF_INCLUDE_DIRS})
include_directories(${CMAKE_CURRENT_BINARY_DIR})
ADD_DEFINITIONS(-DUSE_PROTO)
SET(PROTO_FILES src/proto/gbdt_model_proto.cpp ${PROTO_HDRS} ${PROTO_SRCS})
endif(USE_PROTO)
add_executable(lightgbm src/main.cpp ${SOURCES} ${PROTO_FILES})
add_library(_lightgbm SHARED src/c_api.cpp src/lightgbm_R.cpp ${SOURCES} ${PROTO_FILES})
if (USE_PROTO)
TARGET_LINK_LIBRARIES(lightgbm ${PROTOBUF_LIBRARIES})
TARGET_LINK_LIBRARIES(_lightgbm ${PROTOBUF_LIBRARIES})
endif(USE_PROTO)
if(MSVC) if(MSVC)
set_target_properties(_lightgbm PROPERTIES OUTPUT_NAME "lib_lightgbm") set_target_properties(_lightgbm PROPERTIES OUTPUT_NAME "lib_lightgbm")
......
...@@ -271,21 +271,6 @@ Following procedure is for the MSVC (Microsoft Visual C++) build. ...@@ -271,21 +271,6 @@ Following procedure is for the MSVC (Microsoft Visual C++) build.
**Note**: ``C:\local\boost_1_64_0\`` and ``C:\local\boost_1_64_0\lib64-msvc-14.0`` are locations of your Boost binaries. You also can set them to the environment variable to avoid ``Set ...`` commands when build. **Note**: ``C:\local\boost_1_64_0\`` and ``C:\local\boost_1_64_0\lib64-msvc-14.0`` are locations of your Boost binaries. You also can set them to the environment variable to avoid ``Set ...`` commands when build.
Protobuf Support
^^^^^^^^^^^^^^^^
If you want to use protobuf to save and load models, install `protobuf c++ version <https://github.com/google/protobuf/blob/master/src/README.md>`__ first.
Then run cmake with USE_PROTO on, for example:
.. code::
cmake -DUSE_PROTO=ON ..
You can then use ``model_format=proto`` in parameters when save and load models.
**Note**: for windows user, it's only tested with mingw.
Docker Docker
^^^^^^ ^^^^^^
......
...@@ -335,20 +335,6 @@ IO Parameters ...@@ -335,20 +335,6 @@ IO Parameters
- file name of prediction result in ``prediction`` task - file name of prediction result in ``prediction`` task
- ``model_format``, default=\ ``text``, type=multi-enum, options=\ ``text``, ``proto``
- format to save and load model
- if ``text``, text string will be used
- if ``proto``, Protocol Buffer binary format will be used
- you can save in multiple formats by joining them with comma, like ``text,proto``. In this case, ``model_format`` will be add as suffix after ``output_model``
- **Note**: loading with multiple formats is not supported
- **Note**: to use this parameter you need to `build version with Protobuf Support <./Installation-Guide.rst#protobuf-support>`__
- ``pre_partition``, default=\ ``false``, type=bool, alias=\ ``is_pre_partition`` - ``pre_partition``, default=\ ``false``, type=bool, alias=\ ``is_pre_partition``
- used for parallel learning (not include feature parallel) - used for parallel learning (not include feature parallel)
......
...@@ -4,10 +4,6 @@ ...@@ -4,10 +4,6 @@
#include <LightGBM/meta.h> #include <LightGBM/meta.h>
#include <LightGBM/config.h> #include <LightGBM/config.h>
#ifdef USE_PROTO
#include "model.pb.h"
#endif // USE_PROTO
#include <vector> #include <vector>
#include <string> #include <string>
#include <map> #include <map>
...@@ -198,26 +194,11 @@ public: ...@@ -198,26 +194,11 @@ public:
/*! /*!
* \brief Restore from a serialized string * \brief Restore from a serialized string
* \param model_str The string of model * \param buffer The content of model
* \return true if succeeded * \param len The length of buffer
*/
virtual bool LoadModelFromString(const std::string& model_str) = 0;
#ifdef USE_PROTO
/*!
* \brief Save model with protobuf
* \param num_iterations Number of model that want to save, -1 means save all
* \param filename Filename that want to save to
*/
virtual void SaveModelToProto(int num_iteration, const char* filename) const = 0;
/*!
* \brief Restore from a serialized protobuf file
* \param filename Filename that want to restore from
* \return true if succeeded * \return true if succeeded
*/ */
virtual bool LoadModelFromProto(const char* filename) = 0; virtual bool LoadModelFromString(const char* buffer, size_t len) = 0;
#endif // USE_PROTO
/*! /*!
* \brief Calculate feature importances * \brief Calculate feature importances
...@@ -283,7 +264,7 @@ public: ...@@ -283,7 +264,7 @@ public:
/*! \brief Disable copy */ /*! \brief Disable copy */
Boosting(const Boosting&) = delete; Boosting(const Boosting&) = delete;
static bool LoadFileToBoosting(Boosting* boosting, const std::string& format, const char* filename); static bool LoadFileToBoosting(Boosting* boosting, const char* filename);
/*! /*!
* \brief Create boosting object * \brief Create boosting object
...@@ -293,7 +274,7 @@ public: ...@@ -293,7 +274,7 @@ public:
* \param filename name of model file, if existing will continue to train from this model * \param filename name of model file, if existing will continue to train from this model
* \return The boosting object * \return The boosting object
*/ */
static Boosting* CreateBoosting(const std::string& type, const std::string& format, const char* filename); static Boosting* CreateBoosting(const std::string& type, const char* filename);
}; };
......
...@@ -105,7 +105,6 @@ public: ...@@ -105,7 +105,6 @@ public:
std::string output_result = "LightGBM_predict_result.txt"; std::string output_result = "LightGBM_predict_result.txt";
std::string convert_model = "gbdt_prediction.cpp"; std::string convert_model = "gbdt_prediction.cpp";
std::string input_model = ""; std::string input_model = "";
std::string model_format = "text";
int verbosity = 1; int verbosity = 1;
int num_iteration_predict = -1; int num_iteration_predict = -1;
bool is_pre_partition = false; bool is_pre_partition = false;
...@@ -449,7 +448,7 @@ struct ParameterAlias { ...@@ -449,7 +448,7 @@ struct ParameterAlias {
const std::unordered_set<std::string> parameter_set({ const std::unordered_set<std::string> parameter_set({
"config", "config_file", "task", "device", "config", "config_file", "task", "device",
"num_threads", "seed", "boosting_type", "objective", "data", "num_threads", "seed", "boosting_type", "objective", "data",
"output_model", "input_model", "output_result", "model_format", "valid_data", "output_model", "input_model", "output_result", "valid_data",
"is_enable_sparse", "is_pre_partition", "is_training_metric", "is_enable_sparse", "is_pre_partition", "is_training_metric",
"ndcg_eval_at", "min_data_in_leaf", "min_sum_hessian_in_leaf", "ndcg_eval_at", "min_data_in_leaf", "min_sum_hessian_in_leaf",
"num_leaves", "feature_fraction", "num_iterations", "num_leaves", "feature_fraction", "num_iterations",
......
...@@ -3,9 +3,6 @@ ...@@ -3,9 +3,6 @@
#include <LightGBM/meta.h> #include <LightGBM/meta.h>
#include <LightGBM/dataset.h> #include <LightGBM/dataset.h>
#ifdef USE_PROTO
#include "model.pb.h"
#endif // USE_PROTO
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -32,15 +29,9 @@ public: ...@@ -32,15 +29,9 @@ public:
/*! /*!
* \brief Construtor, from a string * \brief Construtor, from a string
* \param str Model string * \param str Model string
* \param used_len used count of str
*/ */
explicit Tree(const std::string& str); Tree(const char* str, size_t* used_len);
#ifdef USE_PROTO
/*!
* \brief Construtor, from a protobuf object
* \param model_tree Model protobuf object
*/
explicit Tree(const Model_Tree& model_tree);
#endif // USE_PROTO
~Tree(); ~Tree();
...@@ -62,7 +53,7 @@ public: ...@@ -62,7 +53,7 @@ public:
*/ */
int Split(int leaf, int feature, int real_feature, uint32_t threshold_bin, int Split(int leaf, int feature, int real_feature, uint32_t threshold_bin,
double threshold_double, double left_value, double right_value, double threshold_double, double left_value, double right_value,
data_size_t left_cnt, data_size_t right_cnt, double gain, MissingType missing_type, bool default_left); int left_cnt, int right_cnt, float gain, MissingType missing_type, bool default_left);
/*! /*!
* \brief Performing a split on tree leaves, with categorical feature * \brief Performing a split on tree leaves, with categorical feature
...@@ -82,7 +73,7 @@ public: ...@@ -82,7 +73,7 @@ public:
*/ */
int SplitCategorical(int leaf, int feature, int real_feature, const uint32_t* threshold_bin, int num_threshold_bin, int SplitCategorical(int leaf, int feature, int real_feature, const uint32_t* threshold_bin, int num_threshold_bin,
const uint32_t* threshold, int num_threshold, double left_value, double right_value, const uint32_t* threshold, int num_threshold, double left_value, double right_value,
data_size_t left_cnt, data_size_t right_cnt, double gain, MissingType missing_type); int left_cnt, int right_cnt, float gain, MissingType missing_type);
/*! \brief Get the output of one leaf */ /*! \brief Get the output of one leaf */
inline double LeafOutput(int leaf) const { return leaf_value_[leaf]; } inline double LeafOutput(int leaf) const { return leaf_value_[leaf]; }
...@@ -179,11 +170,6 @@ public: ...@@ -179,11 +170,6 @@ public:
/*! \brief Serialize this object to if-else statement*/ /*! \brief Serialize this object to if-else statement*/
std::string ToIfElse(int index, bool is_predict_leaf_index) const; std::string ToIfElse(int index, bool is_predict_leaf_index) const;
#ifdef USE_PROTO
/*! \brief Serialize this object to protobuf object*/
void ToProto(Model_Tree& model_tree) const;
#endif // USE_PROTO
inline static bool IsZero(double fval) { inline static bool IsZero(double fval) {
if (fval > -kZeroAsMissingValueRange && fval <= kZeroAsMissingValueRange) { if (fval > -kZeroAsMissingValueRange && fval <= kZeroAsMissingValueRange) {
return true; return true;
...@@ -304,7 +290,7 @@ private: ...@@ -304,7 +290,7 @@ private:
} }
inline void Split(int leaf, int feature, int real_feature, inline void Split(int leaf, int feature, int real_feature,
double left_value, double right_value, data_size_t left_cnt, data_size_t right_cnt, double gain); double left_value, double right_value, int left_cnt, int right_cnt, float gain);
/*! /*!
* \brief Find leaf index of which record belongs by features * \brief Find leaf index of which record belongs by features
* \param feature_values Feature value of this record * \param feature_values Feature value of this record
...@@ -385,25 +371,25 @@ private: ...@@ -385,25 +371,25 @@ private:
/*! \brief Store the information for categorical feature handle and mising value handle. */ /*! \brief Store the information for categorical feature handle and mising value handle. */
std::vector<int8_t> decision_type_; std::vector<int8_t> decision_type_;
/*! \brief A non-leaf node's split gain */ /*! \brief A non-leaf node's split gain */
std::vector<double> split_gain_; std::vector<float> split_gain_;
// used for leaf node // used for leaf node
/*! \brief The parent of leaf */ /*! \brief The parent of leaf */
std::vector<int> leaf_parent_; std::vector<int> leaf_parent_;
/*! \brief Output of leaves */ /*! \brief Output of leaves */
std::vector<double> leaf_value_; std::vector<double> leaf_value_;
/*! \brief DataCount of leaves */ /*! \brief DataCount of leaves */
std::vector<data_size_t> leaf_count_; std::vector<int> leaf_count_;
/*! \brief Output of non-leaf nodes */ /*! \brief Output of non-leaf nodes */
std::vector<double> internal_value_; std::vector<double> internal_value_;
/*! \brief DataCount of non-leaf nodes */ /*! \brief DataCount of non-leaf nodes */
std::vector<data_size_t> internal_count_; std::vector<int> internal_count_;
/*! \brief Depth for leaves */ /*! \brief Depth for leaves */
std::vector<int> leaf_depth_; std::vector<int> leaf_depth_;
double shrinkage_; double shrinkage_;
}; };
inline void Tree::Split(int leaf, int feature, int real_feature, inline void Tree::Split(int leaf, int feature, int real_feature,
double left_value, double right_value, data_size_t left_cnt, data_size_t right_cnt, double gain) { double left_value, double right_value, int left_cnt, int right_cnt, float gain) {
int new_node_idx = num_leaves_ - 1; int new_node_idx = num_leaves_ - 1;
// update parent info // update parent info
int parent = leaf_parent_[leaf]; int parent = leaf_parent_[leaf];
......
...@@ -17,11 +17,15 @@ ...@@ -17,11 +17,15 @@
#include <type_traits> #include <type_traits>
#include <iomanip> #include <iomanip>
#ifdef _MSC_VER
#include "intrin.h"
#endif
namespace LightGBM { namespace LightGBM {
namespace Common { namespace Common {
inline char tolower(char in) { inline static char tolower(char in) {
if (in <= 'Z' && in >= 'A') if (in <= 'Z' && in >= 'A')
return in - ('Z' - 'z'); return in - ('Z' - 'z');
return in; return in;
...@@ -128,18 +132,10 @@ inline static std::vector<std::string> Split(const char* c_str, const char* deli ...@@ -128,18 +132,10 @@ inline static std::vector<std::string> Split(const char* c_str, const char* deli
return ret; return ret;
} }
inline static std::string FindFromLines(const std::vector<std::string>& lines, const char* key_word) { template<typename T>
for (auto& line : lines) { inline static const char* Atoi(const char* p, T* out) {
size_t find_pos = line.find(key_word); int sign;
if (find_pos != std::string::npos) { T value;
return line;
}
}
return "";
}
inline static const char* Atoi(const char* p, int* out) {
int sign, value;
while (*p == ' ') { while (*p == ' ') {
++p; ++p;
} }
...@@ -153,14 +149,14 @@ inline static const char* Atoi(const char* p, int* out) { ...@@ -153,14 +149,14 @@ inline static const char* Atoi(const char* p, int* out) {
for (value = 0; *p >= '0' && *p <= '9'; ++p) { for (value = 0; *p >= '0' && *p <= '9'; ++p) {
value = value * 10 + (*p - '0'); value = value * 10 + (*p - '0');
} }
*out = sign * value; *out = static_cast<T>(sign * value);
while (*p == ' ') { while (*p == ' ') {
++p; ++p;
} }
return p; return p;
} }
template<class T> template<typename T>
inline static double Pow(T base, int power) { inline static double Pow(T base, int power) {
if (power < 0) { if (power < 0) {
return 1.0 / Pow(base, -power); return 1.0 / Pow(base, -power);
...@@ -267,7 +263,7 @@ inline static const char* Atof(const char* p, double* out) { ...@@ -267,7 +263,7 @@ inline static const char* Atof(const char* p, double* out) {
return p; return p;
} }
inline bool AtoiAndCheck(const char* p, int* out) { inline static bool AtoiAndCheck(const char* p, int* out) {
const char* after = Atoi(p, out); const char* after = Atoi(p, out);
if (*after != '\0') { if (*after != '\0') {
return false; return false;
...@@ -275,7 +271,7 @@ inline bool AtoiAndCheck(const char* p, int* out) { ...@@ -275,7 +271,7 @@ inline bool AtoiAndCheck(const char* p, int* out) {
return true; return true;
} }
inline bool AtofAndCheck(const char* p, double* out) { inline static bool AtofAndCheck(const char* p, double* out) {
const char* after = Atof(p, out); const char* after = Atof(p, out);
if (*after != '\0') { if (*after != '\0') {
return false; return false;
...@@ -283,6 +279,97 @@ inline bool AtofAndCheck(const char* p, double* out) { ...@@ -283,6 +279,97 @@ inline bool AtofAndCheck(const char* p, double* out) {
return true; return true;
} }
/*!
* \brief Count how many decimal digits are needed to print an unsigned 32-bit value.
*        On MSVC/GCC this uses the classic bit-scan trick (log10(n) ~= log2(n) * 1233 / 4096)
*        plus a one-step table correction; other compilers fall back to a comparison cascade.
* \param n Value to measure (0 counts as one digit)
* \return Digit count in the range [1, 10]
*/
inline static unsigned CountDecimalDigit32(uint32_t n) {
#if defined(_MSC_VER) || defined(__GNUC__)
  // kPowersOf10[d] is the smallest value that requires d + 1 digits.
  static const uint32_t kPowersOf10[] = {
    0,
    10,
    100,
    1000,
    10000,
    100000,
    1000000,
    10000000,
    100000000,
    1000000000
  };
#ifdef _MSC_VER
  // n | 1 guarantees at least one set bit so the scan is defined for n == 0.
  unsigned long msb = 0;
  _BitScanReverse(&msb, n | 1);
  uint32_t approx = (msb + 1) * 1233 >> 12;
#elif __GNUC__
  uint32_t approx = (32 - __builtin_clz(n | 1)) * 1233 >> 12;
#endif
  // The estimate can be one too high; the table lookup corrects it.
  return approx - (n < kPowersOf10[approx]) + 1;
#else
  // Portable fallback: straight comparison ladder.
  if (n < 10) return 1;
  if (n < 100) return 2;
  if (n < 1000) return 3;
  if (n < 10000) return 4;
  if (n < 100000) return 5;
  if (n < 1000000) return 6;
  if (n < 10000000) return 7;
  if (n < 100000000) return 8;
  if (n < 1000000000) return 9;
  return 10;
#endif
}
/*!
* \brief Write the decimal representation of value into buffer, null-terminated.
*        Emits two digits per iteration from a 100-entry pair table, filling the
*        buffer back-to-front after sizing it with CountDecimalDigit32.
* \param value Number to format
* \param buffer Destination; must hold at least 11 bytes (10 digits + null)
*/
inline static void Uint32ToStr(uint32_t value, char* buffer) {
  // Lookup table of all two-digit pairs "00".."99".
  static const char kTwoDigitLut[200] = {
    '0','0','0','1','0','2','0','3','0','4','0','5','0','6','0','7','0','8','0','9',
    '1','0','1','1','1','2','1','3','1','4','1','5','1','6','1','7','1','8','1','9',
    '2','0','2','1','2','2','2','3','2','4','2','5','2','6','2','7','2','8','2','9',
    '3','0','3','1','3','2','3','3','3','4','3','5','3','6','3','7','3','8','3','9',
    '4','0','4','1','4','2','4','3','4','4','4','5','4','6','4','7','4','8','4','9',
    '5','0','5','1','5','2','5','3','5','4','5','5','5','6','5','7','5','8','5','9',
    '6','0','6','1','6','2','6','3','6','4','6','5','6','6','6','7','6','8','6','9',
    '7','0','7','1','7','2','7','3','7','4','7','5','7','6','7','7','7','8','7','9',
    '8','0','8','1','8','2','8','3','8','4','8','5','8','6','8','7','8','8','8','9',
    '9','0','9','1','9','2','9','3','9','4','9','5','9','6','9','7','9','8','9','9'
  };
  // Place the cursor just past the last digit, then write backwards.
  char* cursor = buffer + CountDecimalDigit32(value);
  *cursor = '\0';
  while (value >= 100) {
    const unsigned pair = (value % 100) << 1;
    value /= 100;
    *--cursor = kTwoDigitLut[pair + 1];
    *--cursor = kTwoDigitLut[pair];
  }
  if (value >= 10) {
    const unsigned pair = value << 1;
    *--cursor = kTwoDigitLut[pair + 1];
    *--cursor = kTwoDigitLut[pair];
  } else {
    *--cursor = static_cast<char>(value) + '0';
  }
}
/*!
* \brief Write the decimal representation of a signed 32-bit value into buffer.
*        Emits a leading '-' for negatives, then formats the magnitude.
* \param value Number to format (INT32_MIN handled correctly via unsigned negate)
* \param buffer Destination; must hold at least 12 bytes (sign + 10 digits + null)
*/
inline static void Int32ToStr(int32_t value, char* buffer) {
  uint32_t magnitude = static_cast<uint32_t>(value);
  if (value < 0) {
    *buffer++ = '-';
    // Two's-complement negate in unsigned arithmetic; well-defined for INT32_MIN.
    magnitude = 0u - magnitude;
  }
  Uint32ToStr(magnitude, buffer);
}
/*!
* \brief Format value with up to 17 significant digits ("%.17g"), enough to
*        round-trip any IEEE-754 double exactly.
* \param value Number to format
* \param buffer Destination C string
* \param buffer_len Capacity of buffer, including the terminating null
*/
inline static void DoubleToStr(double value, char* buffer, size_t buffer_len) {
  // snprintf is bounds-checked on every platform, which removes the old
  // MSVC-only sprintf_s/sprintf split that left the non-MSVC path unbounded
  // (buffer_len was ignored there, risking overflow).
  snprintf(buffer, buffer_len, "%.17g", value);
}
inline static const char* SkipSpaceAndTab(const char* p) { inline static const char* SkipSpaceAndTab(const char* p) {
while (*p == ' ' || *p == '\t') { while (*p == ' ' || *p == '\t') {
++p; ++p;
...@@ -299,39 +386,72 @@ inline static const char* SkipReturn(const char* p) { ...@@ -299,39 +386,72 @@ inline static const char* SkipReturn(const char* p) {
template<typename T, typename T2> template<typename T, typename T2>
inline static std::vector<T2> ArrayCast(const std::vector<T>& arr) { inline static std::vector<T2> ArrayCast(const std::vector<T>& arr) {
std::vector<T2> ret; std::vector<T2> ret(arr.size());
for (size_t i = 0; i < arr.size(); ++i) { for (size_t i = 0; i < arr.size(); ++i) {
ret.push_back(static_cast<T2>(arr[i])); ret[i] = static_cast<T2>(arr[i]);
} }
return ret; return ret;
} }
/*!
* \brief Formatting functor that prints one value of type T into a buffer.
*        The primary template covers signed integer types via Int32ToStr;
*        partial specializations handle floating-point and unsigned types.
*/
template<typename T, bool is_float, bool is_unsign>
struct __TToStringHelperFast {
  void operator()(T value, char* buffer, size_t) const {
    Int32ToStr(value, buffer);
  }
};

/*! \brief Floating-point specialization: "%g" formatting, bounds-checked. */
template<typename T>
struct __TToStringHelperFast<T, true, false> {
  void operator()(T value, char* buffer, size_t buf_len) const {
    // snprintf is bounds-checked everywhere, replacing the old MSVC-only
    // sprintf_s/sprintf split whose non-MSVC branch ignored buf_len.
    snprintf(buffer, buf_len, "%g", value);
  }
};
template<typename T> template<typename T>
inline static std::string ArrayToString(const std::vector<T>& arr, char delimiter) { struct __TToStringHelperFast<T, false, true> {
if (arr.empty()) { void operator()(T value, char* buffer, size_t ) const {
Uint32ToStr(value, buffer);
}
};
template<typename T>
inline static std::string ArrayToStringFast(const std::vector<T>& arr, size_t n) {
if (arr.empty() || n == 0) {
return std::string(""); return std::string("");
} }
__TToStringHelperFast<T, std::is_floating_point<T>::value, std::is_unsigned<T>::value> helper;
const size_t buf_len = 16;
std::vector<char> buffer(buf_len);
std::stringstream str_buf; std::stringstream str_buf;
str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2); helper(arr[0], buffer.data(), buf_len);
str_buf << arr[0]; str_buf << buffer.data();
for (size_t i = 1; i < arr.size(); ++i) { for (size_t i = 1; i < std::min(n, arr.size()); ++i) {
str_buf << delimiter; helper(arr[i], buffer.data(), buf_len);
str_buf << arr[i]; str_buf << ' ' << buffer.data();
} }
return str_buf.str(); return str_buf.str();
} }
template<typename T> inline static std::string ArrayToString(const std::vector<double>& arr, size_t n) {
inline static std::string ArrayToString(const std::vector<T>& arr, size_t n, char delimiter) {
if (arr.empty() || n == 0) { if (arr.empty() || n == 0) {
return std::string(""); return std::string("");
} }
const size_t buf_len = 32;
std::vector<char> buffer(buf_len);
std::stringstream str_buf; std::stringstream str_buf;
str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2); DoubleToStr(arr[0], buffer.data(), buf_len);
str_buf << arr[0]; str_buf << buffer.data();
for (size_t i = 1; i < std::min(n, arr.size()); ++i) { for (size_t i = 1; i < std::min(n, arr.size()); ++i) {
str_buf << delimiter; DoubleToStr(arr[i], buffer.data(), buf_len);
str_buf << arr[i]; str_buf << ' ' << buffer.data();
} }
return str_buf.str(); return str_buf.str();
} }
...@@ -339,7 +459,9 @@ inline static std::string ArrayToString(const std::vector<T>& arr, size_t n, cha ...@@ -339,7 +459,9 @@ inline static std::string ArrayToString(const std::vector<T>& arr, size_t n, cha
template<typename T, bool is_float> template<typename T, bool is_float>
struct __StringToTHelper { struct __StringToTHelper {
T operator()(const std::string& str) const { T operator()(const std::string& str) const {
return static_cast<T>(std::stoll(str)); T ret = 0;
Atoi(str.c_str(), &ret);
return ret;
} }
}; };
...@@ -351,25 +473,24 @@ struct __StringToTHelper<T, true> { ...@@ -351,25 +473,24 @@ struct __StringToTHelper<T, true> {
}; };
template<typename T> template<typename T>
inline static std::vector<T> StringToArray(const std::string& str, char delimiter, size_t n) { inline static std::vector<T> StringToArray(const std::string& str, char delimiter) {
if (n == 0) {
return std::vector<T>();
}
std::vector<std::string> strs = Split(str.c_str(), delimiter); std::vector<std::string> strs = Split(str.c_str(), delimiter);
if (strs.size() != n) { std::vector<T> ret;
Log::Fatal("StringToArray error, size doesn't match."); ret.reserve(strs.size());
}
std::vector<T> ret(n);
__StringToTHelper<T, std::is_floating_point<T>::value> helper; __StringToTHelper<T, std::is_floating_point<T>::value> helper;
for (size_t i = 0; i < n; ++i) { for (const auto& s : strs) {
ret[i] = helper(strs[i]); ret.push_back(helper(s));
} }
return ret; return ret;
} }
template<typename T> template<typename T>
inline static std::vector<T> StringToArray(const std::string& str, char delimiter) { inline static std::vector<T> StringToArray(const std::string& str, int n) {
std::vector<std::string> strs = Split(str.c_str(), delimiter); if (n == 0) {
return std::vector<T>();
}
std::vector<std::string> strs = Split(str.c_str(), ' ');
CHECK(strs.size() == static_cast<size_t>(n));
std::vector<T> ret; std::vector<T> ret;
ret.reserve(strs.size()); ret.reserve(strs.size());
__StringToTHelper<T, std::is_floating_point<T>::value> helper; __StringToTHelper<T, std::is_floating_point<T>::value> helper;
...@@ -379,6 +500,37 @@ inline static std::vector<T> StringToArray(const std::string& str, char delimite ...@@ -379,6 +500,37 @@ inline static std::vector<T> StringToArray(const std::string& str, char delimite
return ret; return ret;
} }
/*!
* \brief Parsing functor that reads a single value of type T from a C string.
*        Integer types dispatch to Atoi; the specialization below routes
*        floating-point types through Atof. Returns the position just past
*        the consumed token.
*/
template<typename T, bool is_float>
struct __StringToTHelperFast {
  const char* operator()(const char* str, T* out) const {
    return Atoi(str, out);
  }
};

/*! \brief Floating-point specialization: parse as double, then narrow to T. */
template<typename T>
struct __StringToTHelperFast<T, true> {
  const char* operator()(const char* str, T* out) const {
    double parsed = 0.0;
    const char* next = Atof(str, &parsed);
    *out = static_cast<T>(parsed);
    return next;
  }
};
template<typename T>
inline static std::vector<T> StringToArrayFast(const std::string& str, int n) {
if (n == 0) {
return std::vector<T>();
}
auto p_str = str.c_str();
__StringToTHelperFast<T, std::is_floating_point<T>::value> helper;
std::vector<T> ret(n);
for (int i = 0; i < n; ++i) {
p_str = helper(p_str, &ret[i]);
}
return ret;
}
template<typename T> template<typename T>
inline static std::string Join(const std::vector<T>& strs, const char* delimiter) { inline static std::string Join(const std::vector<T>& strs, const char* delimiter) {
if (strs.empty()) { if (strs.empty()) {
...@@ -411,7 +563,7 @@ inline static std::string Join(const std::vector<T>& strs, size_t start, size_t ...@@ -411,7 +563,7 @@ inline static std::string Join(const std::vector<T>& strs, size_t start, size_t
return str_buf.str(); return str_buf.str();
} }
static inline int64_t Pow2RoundUp(int64_t x) { inline static int64_t Pow2RoundUp(int64_t x) {
int64_t t = 1; int64_t t = 1;
for (int i = 0; i < 64; ++i) { for (int i = 0; i < 64; ++i) {
if (t >= x) { if (t >= x) {
...@@ -426,7 +578,7 @@ static inline int64_t Pow2RoundUp(int64_t x) { ...@@ -426,7 +578,7 @@ static inline int64_t Pow2RoundUp(int64_t x) {
* \brief Do inplace softmax transformaton on p_rec * \brief Do inplace softmax transformaton on p_rec
* \param p_rec The input/output vector of the values. * \param p_rec The input/output vector of the values.
*/ */
inline void Softmax(std::vector<double>* p_rec) { inline static void Softmax(std::vector<double>* p_rec) {
std::vector<double> &rec = *p_rec; std::vector<double> &rec = *p_rec;
double wmax = rec[0]; double wmax = rec[0];
for (size_t i = 1; i < rec.size(); ++i) { for (size_t i = 1; i < rec.size(); ++i) {
...@@ -442,7 +594,7 @@ inline void Softmax(std::vector<double>* p_rec) { ...@@ -442,7 +594,7 @@ inline void Softmax(std::vector<double>* p_rec) {
} }
} }
inline void Softmax(const double* input, double* output, int len) { inline static void Softmax(const double* input, double* output, int len) {
double wmax = input[0]; double wmax = input[0];
for (int i = 1; i < len; ++i) { for (int i = 1; i < len; ++i) {
wmax = std::max(input[i], wmax); wmax = std::max(input[i], wmax);
...@@ -467,7 +619,7 @@ std::vector<const T*> ConstPtrInVectorWrapper(const std::vector<std::unique_ptr< ...@@ -467,7 +619,7 @@ std::vector<const T*> ConstPtrInVectorWrapper(const std::vector<std::unique_ptr<
} }
template<typename T1, typename T2> template<typename T1, typename T2>
inline void SortForPair(std::vector<T1>& keys, std::vector<T2>& values, size_t start, bool is_reverse = false) { inline static void SortForPair(std::vector<T1>& keys, std::vector<T2>& values, size_t start, bool is_reverse = false) {
std::vector<std::pair<T1, T2>> arr; std::vector<std::pair<T1, T2>> arr;
for (size_t i = start; i < keys.size(); ++i) { for (size_t i = start; i < keys.size(); ++i) {
arr.emplace_back(keys[i], values[i]); arr.emplace_back(keys[i], values[i]);
...@@ -537,12 +689,22 @@ inline static double AvoidInf(double x) { ...@@ -537,12 +689,22 @@ inline static double AvoidInf(double x) {
} }
} }
template<class _Iter> inline inline static float AvoidInf(float x) {
if (x >= 1e38) {
return 1e38f;
} else if (x <= -1e38) {
return -1e38f;
} else {
return x;
}
}
template<typename _Iter> inline
static typename std::iterator_traits<_Iter>::value_type* IteratorValType(_Iter) { static typename std::iterator_traits<_Iter>::value_type* IteratorValType(_Iter) {
return (0); return (0);
} }
template<class _RanIt, class _Pr, class _VTRanIt> inline template<typename _RanIt, typename _Pr, typename _VTRanIt> inline
static void ParallelSort(_RanIt _First, _RanIt _Last, _Pr _Pred, _VTRanIt*) { static void ParallelSort(_RanIt _First, _RanIt _Last, _Pr _Pred, _VTRanIt*) {
size_t len = _Last - _First; size_t len = _Last - _First;
const size_t kMinInnerLen = 1024; const size_t kMinInnerLen = 1024;
...@@ -589,14 +751,14 @@ static void ParallelSort(_RanIt _First, _RanIt _Last, _Pr _Pred, _VTRanIt*) { ...@@ -589,14 +751,14 @@ static void ParallelSort(_RanIt _First, _RanIt _Last, _Pr _Pred, _VTRanIt*) {
} }
} }
template<class _RanIt, class _Pr> inline template<typename _RanIt, typename _Pr> inline
static void ParallelSort(_RanIt _First, _RanIt _Last, _Pr _Pred) { static void ParallelSort(_RanIt _First, _RanIt _Last, _Pr _Pred) {
return ParallelSort(_First, _Last, _Pred, IteratorValType(_First)); return ParallelSort(_First, _Last, _Pred, IteratorValType(_First));
} }
// Check that all y[] are in interval [ymin, ymax] (end points included); throws error if not // Check that all y[] are in interval [ymin, ymax] (end points included); throws error if not
template <typename T> template <typename T>
inline void CheckElementsIntervalClosed(const T *y, T ymin, T ymax, int ny, const char *callername) { inline static void CheckElementsIntervalClosed(const T *y, T ymin, T ymax, int ny, const char *callername) {
auto fatal_msg = [&y, &ymin, &ymax, &callername](int i) { auto fatal_msg = [&y, &ymin, &ymax, &callername](int i) {
std::ostringstream os; std::ostringstream os;
os << "[%s]: does not tolerate element [#%i = " << y[i] << "] outside [" << ymin << ", " << ymax << "]"; os << "[%s]: does not tolerate element [#%i = " << y[i] << "] outside [" << ymin << ", " << ymax << "]";
...@@ -627,7 +789,7 @@ inline void CheckElementsIntervalClosed(const T *y, T ymin, T ymax, int ny, cons ...@@ -627,7 +789,7 @@ inline void CheckElementsIntervalClosed(const T *y, T ymin, T ymax, int ny, cons
// One-pass scan over array w with nw elements: find min, max and sum of elements; // One-pass scan over array w with nw elements: find min, max and sum of elements;
// this is useful for checking weight requirements. // this is useful for checking weight requirements.
template <typename T1, typename T2> template <typename T1, typename T2>
inline void ObtainMinMaxSum(const T1 *w, int nw, T1 *mi, T1 *ma, T2 *su) { inline static void ObtainMinMaxSum(const T1 *w, int nw, T1 *mi, T1 *ma, T2 *su) {
T1 minw; T1 minw;
T1 maxw; T1 maxw;
T1 sumw; T1 sumw;
...@@ -669,8 +831,8 @@ inline void ObtainMinMaxSum(const T1 *w, int nw, T1 *mi, T1 *ma, T2 *su) { ...@@ -669,8 +831,8 @@ inline void ObtainMinMaxSum(const T1 *w, int nw, T1 *mi, T1 *ma, T2 *su) {
} }
} }
template<class T> template<typename T>
inline std::vector<uint32_t> ConstructBitset(const T* vals, int n) { inline static std::vector<uint32_t> ConstructBitset(const T* vals, int n) {
std::vector<uint32_t> ret; std::vector<uint32_t> ret;
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
int i1 = vals[i] / 32; int i1 = vals[i] / 32;
...@@ -683,8 +845,8 @@ inline std::vector<uint32_t> ConstructBitset(const T* vals, int n) { ...@@ -683,8 +845,8 @@ inline std::vector<uint32_t> ConstructBitset(const T* vals, int n) {
return ret; return ret;
} }
template<class T> template<typename T>
inline bool FindInBitset(const uint32_t* bits, int n, T pos) { inline static bool FindInBitset(const uint32_t* bits, int n, T pos) {
int i1 = pos / 32; int i1 = pos / 32;
if (i1 >= n) { if (i1 >= n) {
return false; return false;
...@@ -702,6 +864,24 @@ inline static double GetDoubleUpperBound(double a) { ...@@ -702,6 +864,24 @@ inline static double GetDoubleUpperBound(double a) {
return std::nextafter(a, INFINITY);; return std::nextafter(a, INFINITY);;
} }
/*!
* \brief Length of the current line: the number of characters before
*        the first '\0', '\n' or '\r'.
* \param str Pointer to the start of the line
* \return Number of characters in the line, excluding the terminator
*/
inline static size_t GetLine(const char* str) {
  const char* cur = str;
  while (*cur != '\0' && *cur != '\n' && *cur != '\r') {
    ++cur;
  }
  return static_cast<size_t>(cur - str);
}
/*!
* \brief Advance past a single line terminator ("\n", "\r" or "\r\n").
* \param str Pointer positioned at a (possible) line terminator
* \return Pointer to the first character after the terminator
*/
inline static const char* SkipNewLine(const char* str) {
  str += (*str == '\r') ? 1 : 0;
  str += (*str == '\n') ? 1 : 0;
  return str;
}
} // namespace Common } // namespace Common
} // namespace LightGBM } // namespace LightGBM
......
...@@ -148,6 +148,31 @@ public: ...@@ -148,6 +148,31 @@ public:
}); });
} }
/*!
* \brief Read the whole file (filename_) in binary mode.
* \param out_len Receives the total number of bytes read; 0 when the file cannot be opened
* \return Buffer holding the file content; empty when the file cannot be opened
*/
std::vector<char> ReadContent(size_t* out_len) {
  std::vector<char> content;
  *out_len = 0;
  FILE* fp;
#ifdef _MSC_VER
  fopen_s(&fp, filename_, "rb");
#else
  fp = fopen(filename_, "rb");
#endif
  if (fp == NULL) {
    return content;
  }
  // read in large chunks to keep the number of fread calls small
  const size_t kChunkSize = 16 * 1024 * 1024;
  std::vector<char> chunk(kChunkSize);
  for (;;) {
    const size_t read_cnt = fread(chunk.data(), 1, kChunkSize, fp);
    if (read_cnt == 0) {
      break;
    }
    content.insert(content.end(), chunk.begin(), chunk.begin() + read_cnt);
    *out_len += read_cnt;
  }
  fclose(fp);
  return content;
}
INDEX_T SampleFromFile(Random& random, INDEX_T sample_cnt, std::vector<std::string>* out_sampled_data) { INDEX_T SampleFromFile(Random& random, INDEX_T sample_cnt, std::vector<std::string>* out_sampled_data) {
INDEX_T cur_sample_cnt = 0; INDEX_T cur_sample_cnt = 0;
return ReadAllAndProcess( return ReadAllAndProcess(
......
PMML Generator PMML Generator
============== ==============
The script pmml.py can be used to translate the LightGBM models, found in LightGBM_model.txt, to predictive model markup language (PMML). These models can then be imported by other analytics applications. The models that the language can describe include decision trees. The specification of PMML can be found at the Data Mining Group's [website](http://dmg.org/pmml/v4-3/GeneralStructure.html). The old Python convert script was removed because it cannot support the new categorical features.
In order to generate pmml files do the following steps. Please move to https://github.com/jpmml/jpmml-lightgbm
```
lightgbm config=train.conf
python pmml.py LightGBM_model.txt
```
The python script will create a file called **LightGBM_pmml.xml**. Inside the file you will find a `MiningModel` tag. In there you will find `TreeModel` tags. Each `TreeModel` tag contains the pmml translation of a decision tree inside the LightGBM_model.txt file. The model described by the **LightGBM_pmml.xml** file can be transferred to other analytics applications. For instance you can use the pmml file as an input to the jpmml-evaluator API. Follow the steps below to run a model described by **LightGBM_pmml.xml**.
##### Steps to Run jpmml-evaluator
1. Clone the repository
```
git clone https://github.com/jpmml/jpmml-evaluator.git
```
2. Build using maven
```
mvn clean install
```
3. Run the EvaluationExample class on the model file using the following command
```
java -cp example-1.3-SNAPSHOT.jar org.jpmml.evaluator.EvaluationExample --model LightGBM_pmml.xml --input input.csv --output output.csv
```
Note, in order to run the model on the input.csv file, the input.csv file must have the same number of columns as specified by the `DataDictionary` field in the pmml file. Also, the column headers inside the input.csv file must be the same as the column names specified by the `MiningSchema` field. Inside output.csv you will find all the columns inside the input.csv file plus a new column. In the new column you will find the scores calculated by processing each rows data on the model. More information about jpmml-evaluator can be found at its [github repository](https://github.com/jpmml/jpmml-evaluator).
\ No newline at end of file
# coding: utf-8
# pylint: disable = C0111, C0103
"""convert LightGBM model to pmml"""
from __future__ import absolute_import
from sys import argv
from itertools import count
def get_value_string(line):
    """Return everything after the first '=' in a model-file line."""
    eq_pos = line.find('=')
    return line[eq_pos + 1:]
def get_array_strings(line):
    """Split the value part of a model-file line into whitespace-separated tokens."""
    value = get_value_string(line)
    return value.split()
def get_array_ints(line):
    """Parse the value part of a model-file line as a list of ints."""
    return list(map(int, get_array_strings(line)))
def get_field_name(node_id, prev_node_idx, is_child):
    """Name of the feature used by the split that produced this node.

    For a leaf (is_child), the split lives at its parent (leaf_parent);
    otherwise at prev_node_idx. Reads the module-level tree arrays.
    """
    if is_child:
        idx = leaf_parent[node_id]
    else:
        idx = prev_node_idx
    return feature_names[split_feature[idx]]
def get_threshold(node_id, prev_node_idx, is_child):
    """Threshold of the split that produced this node.

    Mirrors get_field_name: leaves look up their parent split via
    leaf_parent, internal nodes use prev_node_idx directly.
    """
    if is_child:
        idx = leaf_parent[node_id]
    else:
        idx = prev_node_idx
    return threshold[idx]
def print_simple_predicate(tab_len, node_id, is_left_child, prev_node_idx, is_leaf):
    """Emit the <SimplePredicate> element describing one branch condition."""
    # decision_type 1 uses equality operators (categorical split);
    # anything else is a numerical <= / > comparison
    categorical = decision_type[prev_node_idx] == 1
    if is_left_child:
        op = 'equal' if categorical else 'lessOrEqual'
    else:
        op = 'notEqual' if categorical else 'greaterThan'
    out_('\t' * (tab_len + 1) + ("<SimplePredicate field=\"{0}\" " + " operator=\"{1}\" value=\"{2}\" />").format(
        get_field_name(node_id, prev_node_idx, is_leaf), op, get_threshold(node_id, prev_node_idx, is_leaf)))
def print_nodes_pmml(node_id, tab_len, is_left_child, prev_node_idx):
    """Recursively emit <Node> elements for the subtree rooted at node_id.

    Negative node ids are bitwise-complemented leaf indices (LightGBM
    model-file convention); non-negative ids are internal nodes.
    """
    is_leaf = node_id < 0
    if is_leaf:
        node_id = ~node_id  # decode leaf index
        score, recordCount = leaf_value[node_id], leaf_count[node_id]
    else:
        score, recordCount = internal_value[node_id], internal_count[node_id]
    out_('\t' * tab_len + ("<Node id=\"{0}\" score=\"{1}\" " + " recordCount=\"{2}\">").format(
        next(unique_id), score, recordCount))
    print_simple_predicate(tab_len, node_id, is_left_child, prev_node_idx, is_leaf)
    if not is_leaf:
        # visit children depth-first, left before right
        print_nodes_pmml(left_child[node_id], tab_len + 1, True, node_id)
        print_nodes_pmml(right_child[node_id], tab_len + 1, False, node_id)
    out_('\t' * tab_len + "</Node>")
# print out the pmml for a decision tree
def print_pmml():
    """Emit one <TreeModel> element for the tree arrays currently in scope.

    Relies on the module-level arrays (feature_names, internal_value,
    internal_count, left_child, right_child, unique_id, ...) populated by
    the main loop before each call.
    """
    # specify the objective as function name and binarySplit for
    # splitCharacteristic because each node has 2 children
    out_("\t\t\t\t<TreeModel functionName=\"regression\" splitCharacteristic=\"binarySplit\">")
    out_("\t\t\t\t\t<MiningSchema>")
    # list each feature name as a mining field, and treat all outliers as is,
    # unless specified
    for feature in feature_names:
        out_("\t\t\t\t\t\t<MiningField name=\"%s\"/>" % (feature))
    out_("\t\t\t\t\t</MiningSchema>")
    # begin printing out the decision tree: the root is node 0, its
    # children recurse starting at tab depth 6
    out_("\t\t\t\t\t<Node id=\"{0}\" score=\"{1}\" recordCount=\"{2}\">".format(
        next(unique_id), internal_value[0], internal_count[0]))
    out_("\t\t\t\t\t\t<True/>")
    print_nodes_pmml(left_child[0], 6, True, 0)
    print_nodes_pmml(right_child[0], 6, False, 0)
    out_("\t\t\t\t\t</Node>")
    out_("\t\t\t\t</TreeModel>")
if len(argv) != 2:
    raise ValueError('usage: pmml.py <input model file>')
# open the model file and then process it
with open(argv[1], 'r') as model_in:
    # drop blank lines, then skip the first 6 header lines of the model file
    model_content = iter([line for line in model_in.read().splitlines() if line][6:])
    feature_names = get_array_strings(next(model_content))
    feature_infos = get_array_strings(next(model_content))
    segment_id = count(1)
    with open('LightGBM_pmml.xml', 'w') as pmml_out:
        # helper: write one line to the output pmml file
        def out_(string):
            pmml_out.write(string + '\n')
        out_(
            "<PMML version=\"4.3\" \n" +
            "\t\txmlns=\"http://www.dmg.org/PMML-4_3\"\n" +
            "\t\txmlns:xsi=\"http://www.w3.org/2001/XMLSchema-instance\"\n" +
            "\t\txsi:schemaLocation=\"http://www.dmg.org/PMML-4_3 http://dmg.org/pmml/v4-3/pmml-4-3.xsd\">")
        out_("\t<Header copyright=\"Microsoft\">")
        out_("\t\t<Application name=\"LightGBM\"/>")
        out_("\t</Header>")
        # print out data dictionary entries for each column
        out_("\t<DataDictionary numberOfFields=\"%d\">" % len(feature_names))
        # not adding any interval definition, all values are currently valid
        for feature in feature_names:
            out_("\t\t<DataField name=\"" + feature + "\" optype=\"continuous\" dataType=\"double\"/>")
        out_("\t</DataDictionary>")
        out_("\t<MiningModel functionName=\"regression\">")
        out_("\t\t<MiningSchema>")
        # list each feature name as a mining field, and treat all outliers
        # as is, unless specified
        for feature in feature_names:
            out_("\t\t\t<MiningField name=\"%s\"/>" % (feature))
        out_("\t\t</MiningSchema>")
        out_("\t\t<Segmentation multipleModelMethod=\"sum\">")
        # read the arrays that describe each tree; they are used to traverse
        # and re-create the decision tree as a pmml segment
        while True:
            tree_start = next(model_content, '')
            if not tree_start.startswith('Tree'):
                # no more trees in the model file
                break
            out_("\t\t\t<Segment id=\"%d\">" % next(segment_id))
            out_("\t\t\t\t<True/>")
            tree_no = tree_start[5:]
            num_leaves = int(get_value_string(next(model_content)))
            split_feature = get_array_ints(next(model_content))
            split_gain = next(model_content)  # unused
            threshold = get_array_strings(next(model_content))
            decision_type = get_array_ints(next(model_content))
            left_child = get_array_ints(next(model_content))
            right_child = get_array_ints(next(model_content))
            leaf_parent = get_array_ints(next(model_content))
            leaf_value = get_array_strings(next(model_content))
            leaf_count = get_array_strings(next(model_content))
            internal_value = get_array_strings(next(model_content))
            internal_count = get_array_strings(next(model_content))
            shrinkage = get_value_string(next(model_content))
            has_categorical = get_value_string(next(model_content))
            # node ids restart at 1 inside each tree
            unique_id = count(1)
            print_pmml()
            out_("\t\t\t</Segment>")
        out_("\t\t</Segmentation>")
        out_("\t</MiningModel>")
        out_("</PMML>")
syntax = "proto3";

package LightGBM;

// Serialized LightGBM boosting model: global metadata plus one Tree
// message per tree in the ensemble.
message Model {
  string name = 1;                      // submodel/boosting type name
  uint32 num_class = 2;                 // number of classes
  uint32 num_tree_per_iteration = 3;    // trees produced per boosting iteration
  uint32 label_index = 4;               // column index of the label
  uint32 max_feature_idx = 5;           // largest feature index used by the model
  string objective = 6;                 // objective function name
  bool average_output = 7;              // average (rather than sum) tree outputs -- inferred from name
  repeated string feature_names = 8;    // one name per feature
  repeated string feature_infos = 9;    // per-feature info strings, same format as the text model file
  // One decision tree. Array fields parallel the arrays of the text model
  // format (one entry per internal node or per leaf).
  message Tree {
    uint32 num_leaves = 1;
    uint32 num_cat = 2;                 // number of categorical splits in this tree
    repeated uint32 split_feature = 3;  // per internal node: feature used for the split
    repeated double split_gain = 4;     // per internal node: gain of the split
    repeated double threshold = 5;      // per internal node: split threshold
    repeated uint32 decision_type = 6;  // per internal node: split kind flags
    repeated sint32 left_child = 7;     // per internal node: left child; negative means ~(leaf index)
    repeated sint32 right_child = 8;    // per internal node: right child; negative means ~(leaf index)
    repeated double leaf_value = 9;     // per leaf: output value
    repeated uint32 leaf_count = 10;    // per leaf: number of training records
    repeated double internal_value = 11;  // per internal node: output value
    repeated double internal_count = 12;  // per internal node: number of training records
    repeated sint32 cat_boundaries = 13;  // offsets into cat_threshold, one range per categorical split
    repeated uint32 cat_threshold = 14;   // 32-bit bitset words encoding categorical split values
    double shrinkage = 15;              // shrinkage (learning rate) applied to this tree
  }
  repeated Tree trees = 10;
}
...@@ -555,7 +555,7 @@ class _InnerPredictor(object): ...@@ -555,7 +555,7 @@ class _InnerPredictor(object):
class Dataset(object): class Dataset(object):
"""Dataset in LightGBM.""" """Dataset in LightGBM."""
def __init__(self, data, label=None, max_bin=None, reference=None, def __init__(self, data, label=None, reference=None,
weight=None, group=None, init_score=None, silent=False, weight=None, group=None, init_score=None, silent=False,
feature_name='auto', categorical_feature='auto', params=None, feature_name='auto', categorical_feature='auto', params=None,
free_raw_data=True): free_raw_data=True):
...@@ -568,9 +568,6 @@ class Dataset(object): ...@@ -568,9 +568,6 @@ class Dataset(object):
If string, it represents the path to txt file. If string, it represents the path to txt file.
label : list, numpy 1-D array or None, optional (default=None) label : list, numpy 1-D array or None, optional (default=None)
Label of the data. Label of the data.
max_bin : int or None, optional (default=None)
Max number of discrete bins for features.
If None, default value from parameters of CLI-version will be used.
reference : Dataset or None, optional (default=None) reference : Dataset or None, optional (default=None)
If this is Dataset for validation, training data should be used as reference. If this is Dataset for validation, training data should be used as reference.
weight : list, numpy 1-D array or None, optional (default=None) weight : list, numpy 1-D array or None, optional (default=None)
...@@ -597,7 +594,6 @@ class Dataset(object): ...@@ -597,7 +594,6 @@ class Dataset(object):
self.handle = None self.handle = None
self.data = data self.data = data
self.label = label self.label = label
self.max_bin = max_bin
self.reference = reference self.reference = reference
self.weight = weight self.weight = weight
self.group = group self.group = group
...@@ -620,7 +616,7 @@ class Dataset(object): ...@@ -620,7 +616,7 @@ class Dataset(object):
_safe_call(_LIB.LGBM_DatasetFree(self.handle)) _safe_call(_LIB.LGBM_DatasetFree(self.handle))
self.handle = None self.handle = None
def _lazy_init(self, data, label=None, max_bin=None, reference=None, def _lazy_init(self, data, label=None, reference=None,
weight=None, group=None, init_score=None, predictor=None, weight=None, group=None, init_score=None, predictor=None,
silent=False, feature_name='auto', silent=False, feature_name='auto',
categorical_feature='auto', params=None): categorical_feature='auto', params=None):
...@@ -640,12 +636,7 @@ class Dataset(object): ...@@ -640,12 +636,7 @@ class Dataset(object):
if key in args_names: if key in args_names:
warnings.warn('{0} keyword has been found in `params` and will be ignored. ' warnings.warn('{0} keyword has been found in `params` and will be ignored. '
'Please use {0} argument of the Dataset constructor to pass this parameter.'.format(key)) 'Please use {0} argument of the Dataset constructor to pass this parameter.'.format(key))
self.max_bin = max_bin
self.predictor = predictor self.predictor = predictor
if self.max_bin is not None:
params["max_bin"] = self.max_bin
warnings.warn('The `max_bin` parameter is deprecated and will be removed in 2.0.12 version. '
'Please use `params` to pass this parameter.', LGBMDeprecationWarning)
if "verbosity" in params: if "verbosity" in params:
params.setdefault("verbose", params.pop("verbosity")) params.setdefault("verbose", params.pop("verbosity"))
if silent: if silent:
...@@ -821,7 +812,7 @@ class Dataset(object): ...@@ -821,7 +812,7 @@ class Dataset(object):
if self.reference is not None: if self.reference is not None:
if self.used_indices is None: if self.used_indices is None:
# create valid # create valid
self._lazy_init(self.data, label=self.label, max_bin=self.max_bin, reference=self.reference, self._lazy_init(self.data, label=self.label, reference=self.reference,
weight=self.weight, group=self.group, init_score=self.init_score, predictor=self._predictor, weight=self.weight, group=self.group, init_score=self.init_score, predictor=self._predictor,
silent=self.silent, feature_name=self.feature_name, params=self.params) silent=self.silent, feature_name=self.feature_name, params=self.params)
else: else:
...@@ -839,7 +830,7 @@ class Dataset(object): ...@@ -839,7 +830,7 @@ class Dataset(object):
raise ValueError("Label should not be None.") raise ValueError("Label should not be None.")
else: else:
# create train # create train
self._lazy_init(self.data, label=self.label, max_bin=self.max_bin, self._lazy_init(self.data, label=self.label,
weight=self.weight, group=self.group, init_score=self.init_score, weight=self.weight, group=self.group, init_score=self.init_score,
predictor=self._predictor, silent=self.silent, feature_name=self.feature_name, predictor=self._predictor, silent=self.silent, feature_name=self.feature_name,
categorical_feature=self.categorical_feature, params=self.params) categorical_feature=self.categorical_feature, params=self.params)
...@@ -874,7 +865,7 @@ class Dataset(object): ...@@ -874,7 +865,7 @@ class Dataset(object):
self : Dataset self : Dataset
Returns self. Returns self.
""" """
ret = Dataset(data, label=label, max_bin=self.max_bin, reference=self, ret = Dataset(data, label=label, reference=self,
weight=weight, group=group, init_score=init_score, weight=weight, group=group, init_score=init_score,
silent=silent, params=params, free_raw_data=self.free_raw_data) silent=silent, params=params, free_raw_data=self.free_raw_data)
ret._predictor = self._predictor ret._predictor = self._predictor
......
...@@ -133,7 +133,7 @@ class LGBMModel(_LGBMModelBase): ...@@ -133,7 +133,7 @@ class LGBMModel(_LGBMModelBase):
"""Implementation of the scikit-learn API for LightGBM.""" """Implementation of the scikit-learn API for LightGBM."""
def __init__(self, boosting_type="gbdt", num_leaves=31, max_depth=-1, def __init__(self, boosting_type="gbdt", num_leaves=31, max_depth=-1,
learning_rate=0.1, n_estimators=10, max_bin=255, learning_rate=0.1, n_estimators=10,
subsample_for_bin=200000, objective=None, subsample_for_bin=200000, objective=None,
min_split_gain=0., min_child_weight=1e-3, min_child_samples=20, min_split_gain=0., min_child_weight=1e-3, min_child_samples=20,
subsample=1., subsample_freq=1, colsample_bytree=1., subsample=1., subsample_freq=1, colsample_bytree=1.,
...@@ -156,8 +156,6 @@ class LGBMModel(_LGBMModelBase): ...@@ -156,8 +156,6 @@ class LGBMModel(_LGBMModelBase):
Boosting learning rate. Boosting learning rate.
n_estimators : int, optional (default=10) n_estimators : int, optional (default=10)
Number of boosted trees to fit. Number of boosted trees to fit.
max_bin : int, optional (default=255)
Number of bucketed bins for feature values.
subsample_for_bin : int, optional (default=50000) subsample_for_bin : int, optional (default=50000)
Number of samples for constructing bins. Number of samples for constructing bins.
objective : string, callable or None, optional (default=None) objective : string, callable or None, optional (default=None)
...@@ -246,7 +244,6 @@ class LGBMModel(_LGBMModelBase): ...@@ -246,7 +244,6 @@ class LGBMModel(_LGBMModelBase):
self.max_depth = max_depth self.max_depth = max_depth
self.learning_rate = learning_rate self.learning_rate = learning_rate
self.n_estimators = n_estimators self.n_estimators = n_estimators
self.max_bin = max_bin
self.subsample_for_bin = subsample_for_bin self.subsample_for_bin = subsample_for_bin
self.min_split_gain = min_split_gain self.min_split_gain = min_split_gain
self.min_child_weight = min_child_weight self.min_child_weight = min_child_weight
...@@ -410,7 +407,7 @@ class LGBMModel(_LGBMModelBase): ...@@ -410,7 +407,7 @@ class LGBMModel(_LGBMModelBase):
self._n_features = X.shape[1] self._n_features = X.shape[1]
def _construct_dataset(X, y, sample_weight, init_score, group, params): def _construct_dataset(X, y, sample_weight, init_score, group, params):
ret = Dataset(X, label=y, max_bin=self.max_bin, weight=sample_weight, group=group, params=params) ret = Dataset(X, label=y, weight=sample_weight, group=group, params=params)
ret.set_init_score(init_score) ret.set_init_score(init_score)
return ret return ret
......
...@@ -180,7 +180,6 @@ void Application::InitTrain() { ...@@ -180,7 +180,6 @@ void Application::InitTrain() {
// create boosting // create boosting
boosting_.reset( boosting_.reset(
Boosting::CreateBoosting(config_.boosting_type, Boosting::CreateBoosting(config_.boosting_type,
config_.io_config.model_format.c_str(),
config_.io_config.input_model.c_str())); config_.io_config.input_model.c_str()));
// create objective function // create objective function
objective_fun_.reset( objective_fun_.reset(
...@@ -204,26 +203,7 @@ void Application::InitTrain() { ...@@ -204,26 +203,7 @@ void Application::InitTrain() {
void Application::Train() { void Application::Train() {
Log::Info("Started training..."); Log::Info("Started training...");
boosting_->Train(config_.io_config.snapshot_freq, config_.io_config.output_model); boosting_->Train(config_.io_config.snapshot_freq, config_.io_config.output_model);
std::vector<std::string> model_formats = Common::Split(config_.io_config.model_format.c_str(), ','); boosting_->SaveModelToFile(-1, config_.io_config.output_model.c_str());
bool save_with_multiple_format = (model_formats.size() > 1);
for (auto model_format: model_formats) {
std::string save_file_name = config_.io_config.output_model;
if (save_with_multiple_format) {
// use suffix to distinguish different model format
save_file_name += "." + model_format;
}
if (model_format == std::string("text")) {
boosting_->SaveModelToFile(-1, save_file_name.c_str());
} else if (model_format == std::string("proto")) {
#ifdef USE_PROTO
boosting_->SaveModelToProto(-1, save_file_name.c_str());
#else
Log::Fatal("Please cmake with -DUSE_PROTO=ON to use protobuf.");
#endif // USE_PROTO
} else {
Log::Fatal("Unknown model format during saving: %s", model_format.c_str());
}
}
// convert model to if-else statement code // convert model to if-else statement code
if (config_.convert_model_language == std::string("cpp")) { if (config_.convert_model_language == std::string("cpp")) {
boosting_->SaveModelToIfElse(-1, config_.io_config.convert_model.c_str()); boosting_->SaveModelToIfElse(-1, config_.io_config.convert_model.c_str());
...@@ -244,16 +224,13 @@ void Application::Predict() { ...@@ -244,16 +224,13 @@ void Application::Predict() {
void Application::InitPredict() { void Application::InitPredict() {
boosting_.reset( boosting_.reset(
Boosting::CreateBoosting("gbdt", config_.io_config.model_format.c_str(), Boosting::CreateBoosting("gbdt", config_.io_config.input_model.c_str()));
config_.io_config.input_model.c_str()));
Log::Info("Finished initializing prediction"); Log::Info("Finished initializing prediction");
} }
void Application::ConvertModel() { void Application::ConvertModel() {
boosting_.reset( boosting_.reset(
Boosting::CreateBoosting(config_.boosting_type, Boosting::CreateBoosting(config_.boosting_type, config_.io_config.input_model.c_str()));
config_.io_config.model_format.c_str(),
config_.io_config.input_model.c_str()));
boosting_->SaveModelToIfElse(-1, config_.io_config.convert_model.c_str()); boosting_->SaveModelToIfElse(-1, config_.io_config.convert_model.c_str());
} }
......
...@@ -12,34 +12,22 @@ std::string GetBoostingTypeFromModelFile(const char* filename) { ...@@ -12,34 +12,22 @@ std::string GetBoostingTypeFromModelFile(const char* filename) {
return type; return type;
} }
bool Boosting::LoadFileToBoosting(Boosting* boosting, const std::string& format, const char* filename) { bool Boosting::LoadFileToBoosting(Boosting* boosting, const char* filename) {
auto start_time = std::chrono::steady_clock::now();
if (boosting != nullptr) { if (boosting != nullptr) {
if (format == std::string("text")) {
TextReader<size_t> model_reader(filename, true); TextReader<size_t> model_reader(filename, true);
model_reader.ReadAllLines(); size_t buffer_len = 0;
std::stringstream str_buf; auto buffer = model_reader.ReadContent(&buffer_len);
for (auto& line : model_reader.Lines()) { if (!boosting->LoadModelFromString(buffer.data(), buffer_len)) {
str_buf << line << '\n';
}
if (!boosting->LoadModelFromString(str_buf.str())) {
return false;
}
} else if (format == std::string("proto")) {
#ifdef USE_PROTO
if (!boosting->LoadModelFromProto(filename)) {
return false; return false;
} }
#else
Log::Fatal("Please cmake with -DUSE_PROTO=ON to use protobuf.");
#endif // USE_PROTO
} else {
Log::Fatal("Unknown model format during loading: %s", format.c_str());
}
} }
std::chrono::duration<double, std::milli> delta = (std::chrono::steady_clock::now() - start_time);
Log::Info("time for loading model: %f seconds", 1e-3*delta);
return true; return true;
} }
Boosting* Boosting::CreateBoosting(const std::string& type, const std::string& format, const char* filename) { Boosting* Boosting::CreateBoosting(const std::string& type, const char* filename) {
if (filename == nullptr || filename[0] == '\0') { if (filename == nullptr || filename[0] == '\0') {
if (type == std::string("gbdt")) { if (type == std::string("gbdt")) {
return new GBDT(); return new GBDT();
...@@ -54,7 +42,7 @@ Boosting* Boosting::CreateBoosting(const std::string& type, const std::string& f ...@@ -54,7 +42,7 @@ Boosting* Boosting::CreateBoosting(const std::string& type, const std::string& f
} }
} else { } else {
std::unique_ptr<Boosting> ret; std::unique_ptr<Boosting> ret;
if (format == std::string("proto") || GetBoostingTypeFromModelFile(filename) == std::string("tree")) { if (GetBoostingTypeFromModelFile(filename) == std::string("tree")) {
if (type == std::string("gbdt")) { if (type == std::string("gbdt")) {
ret.reset(new GBDT()); ret.reset(new GBDT());
} else if (type == std::string("dart")) { } else if (type == std::string("dart")) {
...@@ -66,7 +54,7 @@ Boosting* Boosting::CreateBoosting(const std::string& type, const std::string& f ...@@ -66,7 +54,7 @@ Boosting* Boosting::CreateBoosting(const std::string& type, const std::string& f
} else { } else {
Log::Fatal("unknown boosting type %s", type.c_str()); Log::Fatal("unknown boosting type %s", type.c_str());
} }
LoadFileToBoosting(ret.get(), format, filename); LoadFileToBoosting(ret.get(), filename);
} else { } else {
Log::Fatal("unknown model format or submodel type in model file %s", filename); Log::Fatal("unknown model format or submodel type in model file %s", filename);
} }
......
...@@ -588,7 +588,7 @@ std::string GBDT::OutputMetric(int iter) { ...@@ -588,7 +588,7 @@ std::string GBDT::OutputMetric(int iter) {
<< " : " << scores[k]; << " : " << scores[k];
Log::Info(tmp_buf.str().c_str()); Log::Info(tmp_buf.str().c_str());
if (early_stopping_round_ > 0) { if (early_stopping_round_ > 0) {
msg_buf << tmp_buf.str() << std::endl; msg_buf << tmp_buf.str() << '\n';
} }
} }
} }
...@@ -608,7 +608,7 @@ std::string GBDT::OutputMetric(int iter) { ...@@ -608,7 +608,7 @@ std::string GBDT::OutputMetric(int iter) {
Log::Info(tmp_buf.str().c_str()); Log::Info(tmp_buf.str().c_str());
} }
if (early_stopping_round_ > 0) { if (early_stopping_round_ > 0) {
msg_buf << tmp_buf.str() << std::endl; msg_buf << tmp_buf.str() << '\n';
} }
} }
if (ret.empty() && early_stopping_round_ > 0) { if (ret.empty() && early_stopping_round_ > 0) {
......
...@@ -241,25 +241,9 @@ public: ...@@ -241,25 +241,9 @@ public:
virtual std::string SaveModelToString(int num_iterations) const override; virtual std::string SaveModelToString(int num_iterations) const override;
/*! /*!
* \brief Restore from a serialized string * \brief Restore from a serialized buffer
*/ */
bool LoadModelFromString(const std::string& model_str) override; bool LoadModelFromString(const char* buffer, size_t len) override;
#ifdef USE_PROTO
/*!
* \brief Save model with protobuf
* \param num_iterations Number of model that want to save, -1 means save all
* \param filename Filename that want to save to
*/
void SaveModelToProto(int num_iteration, const char* filename) const override;
/*!
* \brief Restore from a serialized protobuf file
* \param filename Filename that want to restore from
* \return true if succeeded
*/
bool LoadModelFromProto(const char* filename) override;
#endif // USE_PROTO
/*! /*!
* \brief Calculate feature importances * \brief Calculate feature importances
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment