"csrc/vscode:/vscode.git/clone" did not exist on "cf5cb1e33eed16b2f0d5fe6268bf5705a4d0ea5a"
Unverified Commit 8a5ec366 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

Speed up saving and loading model (#1083)

* remove protobuf

* add version number

* remove pmml script

* use float for split gain

* fix warnings

* refine the read model logic of gbdt

* fix compile error

* improve decode speed

* fix some bugs

* fix double accuracy problem

* fix bug

* multi-thread save model

* speed up save model to string

* parallel save/load model

* fix some warnings.

* fix warnings.

* fix a bug

* remove debug output

* fix doc

* fix max_bin warning in tests.

* fix max_bin warning

* fix pylint

* clean code for stringToArray

* clean code for TToString

* remove max_bin

* replace "class" with typename
parent 8d016c12
......@@ -23,7 +23,6 @@ env:
- TASK=if-else
- TASK=sdist PYTHON_VERSION=3.4
- TASK=bdist PYTHON_VERSION=3.5
- TASK=proto
- TASK=gpu METHOD=source
- TASK=gpu METHOD=pip
......@@ -39,8 +38,6 @@ matrix:
env: TASK=pylint
- os: osx
env: TASK=check-docs
- os: osx
env: TASK=proto
before_install:
- test -n $CC && unset CC
......
......@@ -62,18 +62,6 @@ if [[ ${TASK} == "if-else" ]]; then
exit 0
fi
if [[ ${TASK} == "proto" ]]; then
conda install numpy
source activate test-env
mkdir build && cd build && cmake .. && make lightgbm || exit -1
cd $TRAVIS_BUILD_DIR/tests/cpp_test && ../../lightgbm config=train.conf && ../../lightgbm config=predict.conf output_result=origin.pred || exit -1
cd $TRAVIS_BUILD_DIR && git clone https://github.com/google/protobuf && cd protobuf && ./autogen.sh && ./configure && make && sudo make install && sudo ldconfig
cd $TRAVIS_BUILD_DIR/build && rm -rf * && cmake -DUSE_PROTO=ON .. && make lightgbm || exit -1
cd $TRAVIS_BUILD_DIR/tests/cpp_test && ../../lightgbm config=train.conf model_format=proto && ../../lightgbm config=predict.conf output_result=proto.pred model_format=proto || exit -1
cd $TRAVIS_BUILD_DIR/tests/cpp_test && python test.py || exit -1
exit 0
fi
conda install numpy nose scipy scikit-learn pandas matplotlib pytest
if [[ ${TASK} == "sdist" ]]; then
......
......@@ -124,25 +124,8 @@ file(GLOB SOURCES
src/treelearner/*.cpp
)
if (USE_PROTO)
if(MSVC)
message(FATAL_ERROR "Cannot use proto with MSVC.")
endif(MSVC)
find_package(Protobuf REQUIRED)
PROTOBUF_GENERATE_CPP(PROTO_SRCS PROTO_HDRS proto/model.proto)
include_directories(${PROTOBUF_INCLUDE_DIRS})
include_directories(${CMAKE_CURRENT_BINARY_DIR})
ADD_DEFINITIONS(-DUSE_PROTO)
SET(PROTO_FILES src/proto/gbdt_model_proto.cpp ${PROTO_HDRS} ${PROTO_SRCS})
endif(USE_PROTO)
add_executable(lightgbm src/main.cpp ${SOURCES} ${PROTO_FILES})
add_library(_lightgbm SHARED src/c_api.cpp src/lightgbm_R.cpp ${SOURCES} ${PROTO_FILES})
if (USE_PROTO)
TARGET_LINK_LIBRARIES(lightgbm ${PROTOBUF_LIBRARIES})
TARGET_LINK_LIBRARIES(_lightgbm ${PROTOBUF_LIBRARIES})
endif(USE_PROTO)
add_executable(lightgbm src/main.cpp ${SOURCES})
add_library(_lightgbm SHARED src/c_api.cpp src/lightgbm_R.cpp ${SOURCES})
if(MSVC)
set_target_properties(_lightgbm PROPERTIES OUTPUT_NAME "lib_lightgbm")
......
......@@ -271,21 +271,6 @@ Following procedure is for the MSVC (Microsoft Visual C++) build.
**Note**: ``C:\local\boost_1_64_0\`` and ``C:\local\boost_1_64_0\lib64-msvc-14.0`` are locations of your Boost binaries. You also can set them to the environment variable to avoid ``Set ...`` commands when build.
Protobuf Support
^^^^^^^^^^^^^^^^
If you want to use protobuf to save and load models, install `protobuf c++ version <https://github.com/google/protobuf/blob/master/src/README.md>`__ first.
Then run cmake with USE_PROTO on, for example:
.. code::
cmake -DUSE_PROTO=ON ..
You can then use ``model_format=proto`` in parameters when save and load models.
**Note**: for windows user, it's only tested with mingw.
Docker
^^^^^^
......
......@@ -335,20 +335,6 @@ IO Parameters
- file name of prediction result in ``prediction`` task
- ``model_format``, default=\ ``text``, type=multi-enum, options=\ ``text``, ``proto``
- format to save and load model
- if ``text``, text string will be used
- if ``proto``, Protocol Buffer binary format will be used
- you can save in multiple formats by joining them with comma, like ``text,proto``. In this case, ``model_format`` will be add as suffix after ``output_model``
- **Note**: loading with multiple formats is not supported
- **Note**: to use this parameter you need to `build version with Protobuf Support <./Installation-Guide.rst#protobuf-support>`__
- ``pre_partition``, default=\ ``false``, type=bool, alias=\ ``is_pre_partition``
- used for parallel learning (not include feature parallel)
......
......@@ -4,10 +4,6 @@
#include <LightGBM/meta.h>
#include <LightGBM/config.h>
#ifdef USE_PROTO
#include "model.pb.h"
#endif // USE_PROTO
#include <vector>
#include <string>
#include <map>
......@@ -198,26 +194,11 @@ public:
/*!
* \brief Restore from a serialized string
* \param model_str The string of model
* \return true if succeeded
*/
virtual bool LoadModelFromString(const std::string& model_str) = 0;
#ifdef USE_PROTO
/*!
* \brief Save model with protobuf
* \param num_iterations Number of model that want to save, -1 means save all
* \param filename Filename that want to save to
*/
virtual void SaveModelToProto(int num_iteration, const char* filename) const = 0;
/*!
* \brief Restore from a serialized protobuf file
* \param filename Filename that want to restore from
* \param buffer The content of model
* \param len The length of buffer
* \return true if succeeded
*/
virtual bool LoadModelFromProto(const char* filename) = 0;
#endif // USE_PROTO
virtual bool LoadModelFromString(const char* buffer, size_t len) = 0;
/*!
* \brief Calculate feature importances
......@@ -283,7 +264,7 @@ public:
/*! \brief Disable copy */
Boosting(const Boosting&) = delete;
static bool LoadFileToBoosting(Boosting* boosting, const std::string& format, const char* filename);
static bool LoadFileToBoosting(Boosting* boosting, const char* filename);
/*!
* \brief Create boosting object
......@@ -293,7 +274,7 @@ public:
* \param filename name of model file, if existing will continue to train from this model
* \return The boosting object
*/
static Boosting* CreateBoosting(const std::string& type, const std::string& format, const char* filename);
static Boosting* CreateBoosting(const std::string& type, const char* filename);
};
......
......@@ -105,7 +105,6 @@ public:
std::string output_result = "LightGBM_predict_result.txt";
std::string convert_model = "gbdt_prediction.cpp";
std::string input_model = "";
std::string model_format = "text";
int verbosity = 1;
int num_iteration_predict = -1;
bool is_pre_partition = false;
......@@ -449,7 +448,7 @@ struct ParameterAlias {
const std::unordered_set<std::string> parameter_set({
"config", "config_file", "task", "device",
"num_threads", "seed", "boosting_type", "objective", "data",
"output_model", "input_model", "output_result", "model_format", "valid_data",
"output_model", "input_model", "output_result", "valid_data",
"is_enable_sparse", "is_pre_partition", "is_training_metric",
"ndcg_eval_at", "min_data_in_leaf", "min_sum_hessian_in_leaf",
"num_leaves", "feature_fraction", "num_iterations",
......
......@@ -3,9 +3,6 @@
#include <LightGBM/meta.h>
#include <LightGBM/dataset.h>
#ifdef USE_PROTO
#include "model.pb.h"
#endif // USE_PROTO
#include <string>
#include <vector>
......@@ -32,15 +29,9 @@ public:
/*!
* \brief Construtor, from a string
* \param str Model string
* \param used_len used count of str
*/
explicit Tree(const std::string& str);
#ifdef USE_PROTO
/*!
* \brief Construtor, from a protobuf object
* \param model_tree Model protobuf object
*/
explicit Tree(const Model_Tree& model_tree);
#endif // USE_PROTO
Tree(const char* str, size_t* used_len);
~Tree();
......@@ -62,7 +53,7 @@ public:
*/
int Split(int leaf, int feature, int real_feature, uint32_t threshold_bin,
double threshold_double, double left_value, double right_value,
data_size_t left_cnt, data_size_t right_cnt, double gain, MissingType missing_type, bool default_left);
int left_cnt, int right_cnt, float gain, MissingType missing_type, bool default_left);
/*!
* \brief Performing a split on tree leaves, with categorical feature
......@@ -82,7 +73,7 @@ public:
*/
int SplitCategorical(int leaf, int feature, int real_feature, const uint32_t* threshold_bin, int num_threshold_bin,
const uint32_t* threshold, int num_threshold, double left_value, double right_value,
data_size_t left_cnt, data_size_t right_cnt, double gain, MissingType missing_type);
int left_cnt, int right_cnt, float gain, MissingType missing_type);
/*! \brief Get the output of one leaf */
inline double LeafOutput(int leaf) const { return leaf_value_[leaf]; }
......@@ -179,11 +170,6 @@ public:
/*! \brief Serialize this object to if-else statement*/
std::string ToIfElse(int index, bool is_predict_leaf_index) const;
#ifdef USE_PROTO
/*! \brief Serialize this object to protobuf object*/
void ToProto(Model_Tree& model_tree) const;
#endif // USE_PROTO
inline static bool IsZero(double fval) {
if (fval > -kZeroAsMissingValueRange && fval <= kZeroAsMissingValueRange) {
return true;
......@@ -304,7 +290,7 @@ private:
}
inline void Split(int leaf, int feature, int real_feature,
double left_value, double right_value, data_size_t left_cnt, data_size_t right_cnt, double gain);
double left_value, double right_value, int left_cnt, int right_cnt, float gain);
/*!
* \brief Find leaf index of which record belongs by features
* \param feature_values Feature value of this record
......@@ -385,25 +371,25 @@ private:
/*! \brief Store the information for categorical feature handle and mising value handle. */
std::vector<int8_t> decision_type_;
/*! \brief A non-leaf node's split gain */
std::vector<double> split_gain_;
std::vector<float> split_gain_;
// used for leaf node
/*! \brief The parent of leaf */
std::vector<int> leaf_parent_;
/*! \brief Output of leaves */
std::vector<double> leaf_value_;
/*! \brief DataCount of leaves */
std::vector<data_size_t> leaf_count_;
std::vector<int> leaf_count_;
/*! \brief Output of non-leaf nodes */
std::vector<double> internal_value_;
/*! \brief DataCount of non-leaf nodes */
std::vector<data_size_t> internal_count_;
std::vector<int> internal_count_;
/*! \brief Depth for leaves */
std::vector<int> leaf_depth_;
double shrinkage_;
};
inline void Tree::Split(int leaf, int feature, int real_feature,
double left_value, double right_value, data_size_t left_cnt, data_size_t right_cnt, double gain) {
double left_value, double right_value, int left_cnt, int right_cnt, float gain) {
int new_node_idx = num_leaves_ - 1;
// update parent info
int parent = leaf_parent_[leaf];
......
......@@ -17,11 +17,15 @@
#include <type_traits>
#include <iomanip>
#ifdef _MSC_VER
#include "intrin.h"
#endif
namespace LightGBM {
namespace Common {
inline char tolower(char in) {
inline static char tolower(char in) {
if (in <= 'Z' && in >= 'A')
return in - ('Z' - 'z');
return in;
......@@ -128,18 +132,10 @@ inline static std::vector<std::string> Split(const char* c_str, const char* deli
return ret;
}
inline static std::string FindFromLines(const std::vector<std::string>& lines, const char* key_word) {
for (auto& line : lines) {
size_t find_pos = line.find(key_word);
if (find_pos != std::string::npos) {
return line;
}
}
return "";
}
inline static const char* Atoi(const char* p, int* out) {
int sign, value;
template<typename T>
inline static const char* Atoi(const char* p, T* out) {
int sign;
T value;
while (*p == ' ') {
++p;
}
......@@ -153,14 +149,14 @@ inline static const char* Atoi(const char* p, int* out) {
for (value = 0; *p >= '0' && *p <= '9'; ++p) {
value = value * 10 + (*p - '0');
}
*out = sign * value;
*out = static_cast<T>(sign * value);
while (*p == ' ') {
++p;
}
return p;
}
template<class T>
template<typename T>
inline static double Pow(T base, int power) {
if (power < 0) {
return 1.0 / Pow(base, -power);
......@@ -267,7 +263,7 @@ inline static const char* Atof(const char* p, double* out) {
return p;
}
inline bool AtoiAndCheck(const char* p, int* out) {
inline static bool AtoiAndCheck(const char* p, int* out) {
const char* after = Atoi(p, out);
if (*after != '\0') {
return false;
......@@ -275,7 +271,7 @@ inline bool AtoiAndCheck(const char* p, int* out) {
return true;
}
inline bool AtofAndCheck(const char* p, double* out) {
inline static bool AtofAndCheck(const char* p, double* out) {
const char* after = Atof(p, out);
if (*after != '\0') {
return false;
......@@ -283,6 +279,97 @@ inline bool AtofAndCheck(const char* p, double* out) {
return true;
}
inline static unsigned CountDecimalDigit32(uint32_t n) {
#if defined(_MSC_VER) || defined(__GNUC__)
static const uint32_t powers_of_10[] = {
0,
10,
100,
1000,
10000,
100000,
1000000,
10000000,
100000000,
1000000000
};
#ifdef _MSC_VER
unsigned long i = 0;
_BitScanReverse(&i, n | 1);
uint32_t t = (i + 1) * 1233 >> 12;
#elif __GNUC__
uint32_t t = (32 - __builtin_clz(n | 1)) * 1233 >> 12;
#endif
return t - (n < powers_of_10[t]) + 1;
#else
if (n < 10) return 1;
if (n < 100) return 2;
if (n < 1000) return 3;
if (n < 10000) return 4;
if (n < 100000) return 5;
if (n < 1000000) return 6;
if (n < 10000000) return 7;
if (n < 100000000) return 8;
if (n < 1000000000) return 9;
return 10;
#endif
}
inline static void Uint32ToStr(uint32_t value, char* buffer) {
const char kDigitsLut[200] = {
'0','0','0','1','0','2','0','3','0','4','0','5','0','6','0','7','0','8','0','9',
'1','0','1','1','1','2','1','3','1','4','1','5','1','6','1','7','1','8','1','9',
'2','0','2','1','2','2','2','3','2','4','2','5','2','6','2','7','2','8','2','9',
'3','0','3','1','3','2','3','3','3','4','3','5','3','6','3','7','3','8','3','9',
'4','0','4','1','4','2','4','3','4','4','4','5','4','6','4','7','4','8','4','9',
'5','0','5','1','5','2','5','3','5','4','5','5','5','6','5','7','5','8','5','9',
'6','0','6','1','6','2','6','3','6','4','6','5','6','6','6','7','6','8','6','9',
'7','0','7','1','7','2','7','3','7','4','7','5','7','6','7','7','7','8','7','9',
'8','0','8','1','8','2','8','3','8','4','8','5','8','6','8','7','8','8','8','9',
'9','0','9','1','9','2','9','3','9','4','9','5','9','6','9','7','9','8','9','9'
};
unsigned digit = CountDecimalDigit32(value);
buffer += digit;
*buffer = '\0';
while (value >= 100) {
const unsigned i = (value % 100) << 1;
value /= 100;
*--buffer = kDigitsLut[i + 1];
*--buffer = kDigitsLut[i];
}
if (value < 10) {
*--buffer = char(value) + '0';
}
else {
const unsigned i = value << 1;
*--buffer = kDigitsLut[i + 1];
*--buffer = kDigitsLut[i];
}
}
inline static void Int32ToStr(int32_t value, char* buffer) {
uint32_t u = static_cast<uint32_t>(value);
if (value < 0) {
*buffer++ = '-';
u = ~u + 1;
}
Uint32ToStr(u, buffer);
}
inline static void DoubleToStr(double value, char* buffer, size_t
#ifdef _MSC_VER
buffer_len
#endif
) {
#ifdef _MSC_VER
sprintf_s(buffer, buffer_len, "%.17g", value);
#else
sprintf(buffer, "%.17g", value);
#endif
}
inline static const char* SkipSpaceAndTab(const char* p) {
while (*p == ' ' || *p == '\t') {
++p;
......@@ -299,39 +386,72 @@ inline static const char* SkipReturn(const char* p) {
template<typename T, typename T2>
inline static std::vector<T2> ArrayCast(const std::vector<T>& arr) {
std::vector<T2> ret;
std::vector<T2> ret(arr.size());
for (size_t i = 0; i < arr.size(); ++i) {
ret.push_back(static_cast<T2>(arr[i]));
ret[i] = static_cast<T2>(arr[i]);
}
return ret;
}
template<typename T, bool is_float, bool is_unsign>
struct __TToStringHelperFast {
void operator()(T value, char* buffer, size_t ) const {
Int32ToStr(value, buffer);
}
};
template<typename T>
struct __TToStringHelperFast<T, true, false> {
void operator()(T value, char* buffer, size_t
#ifdef _MSC_VER
buf_len
#endif
) const {
#ifdef _MSC_VER
sprintf_s(buffer, buf_len, "%g", value);
#else
sprintf(buffer, "%g", value);
#endif
}
};
template<typename T>
inline static std::string ArrayToString(const std::vector<T>& arr, char delimiter) {
if (arr.empty()) {
struct __TToStringHelperFast<T, false, true> {
void operator()(T value, char* buffer, size_t ) const {
Uint32ToStr(value, buffer);
}
};
template<typename T>
inline static std::string ArrayToStringFast(const std::vector<T>& arr, size_t n) {
if (arr.empty() || n == 0) {
return std::string("");
}
__TToStringHelperFast<T, std::is_floating_point<T>::value, std::is_unsigned<T>::value> helper;
const size_t buf_len = 16;
std::vector<char> buffer(buf_len);
std::stringstream str_buf;
str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
str_buf << arr[0];
for (size_t i = 1; i < arr.size(); ++i) {
str_buf << delimiter;
str_buf << arr[i];
helper(arr[0], buffer.data(), buf_len);
str_buf << buffer.data();
for (size_t i = 1; i < std::min(n, arr.size()); ++i) {
helper(arr[i], buffer.data(), buf_len);
str_buf << ' ' << buffer.data();
}
return str_buf.str();
}
template<typename T>
inline static std::string ArrayToString(const std::vector<T>& arr, size_t n, char delimiter) {
inline static std::string ArrayToString(const std::vector<double>& arr, size_t n) {
if (arr.empty() || n == 0) {
return std::string("");
}
const size_t buf_len = 32;
std::vector<char> buffer(buf_len);
std::stringstream str_buf;
str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
str_buf << arr[0];
DoubleToStr(arr[0], buffer.data(), buf_len);
str_buf << buffer.data();
for (size_t i = 1; i < std::min(n, arr.size()); ++i) {
str_buf << delimiter;
str_buf << arr[i];
DoubleToStr(arr[i], buffer.data(), buf_len);
str_buf << ' ' << buffer.data();
}
return str_buf.str();
}
......@@ -339,7 +459,9 @@ inline static std::string ArrayToString(const std::vector<T>& arr, size_t n, cha
template<typename T, bool is_float>
struct __StringToTHelper {
T operator()(const std::string& str) const {
return static_cast<T>(std::stoll(str));
T ret = 0;
Atoi(str.c_str(), &ret);
return ret;
}
};
......@@ -351,25 +473,24 @@ struct __StringToTHelper<T, true> {
};
template<typename T>
inline static std::vector<T> StringToArray(const std::string& str, char delimiter, size_t n) {
if (n == 0) {
return std::vector<T>();
}
inline static std::vector<T> StringToArray(const std::string& str, char delimiter) {
std::vector<std::string> strs = Split(str.c_str(), delimiter);
if (strs.size() != n) {
Log::Fatal("StringToArray error, size doesn't match.");
}
std::vector<T> ret(n);
std::vector<T> ret;
ret.reserve(strs.size());
__StringToTHelper<T, std::is_floating_point<T>::value> helper;
for (size_t i = 0; i < n; ++i) {
ret[i] = helper(strs[i]);
for (const auto& s : strs) {
ret.push_back(helper(s));
}
return ret;
}
template<typename T>
inline static std::vector<T> StringToArray(const std::string& str, char delimiter) {
std::vector<std::string> strs = Split(str.c_str(), delimiter);
inline static std::vector<T> StringToArray(const std::string& str, int n) {
if (n == 0) {
return std::vector<T>();
}
std::vector<std::string> strs = Split(str.c_str(), ' ');
CHECK(strs.size() == static_cast<size_t>(n));
std::vector<T> ret;
ret.reserve(strs.size());
__StringToTHelper<T, std::is_floating_point<T>::value> helper;
......@@ -379,6 +500,37 @@ inline static std::vector<T> StringToArray(const std::string& str, char delimite
return ret;
}
template<typename T, bool is_float>
struct __StringToTHelperFast {
const char* operator()(const char*p, T* out) const {
return Atoi(p, out);
}
};
template<typename T>
struct __StringToTHelperFast<T, true> {
const char* operator()(const char*p, T* out) const {
double tmp = 0.0f;
auto ret = Atof(p, &tmp);
*out= static_cast<T>(tmp);
return ret;
}
};
template<typename T>
inline static std::vector<T> StringToArrayFast(const std::string& str, int n) {
if (n == 0) {
return std::vector<T>();
}
auto p_str = str.c_str();
__StringToTHelperFast<T, std::is_floating_point<T>::value> helper;
std::vector<T> ret(n);
for (int i = 0; i < n; ++i) {
p_str = helper(p_str, &ret[i]);
}
return ret;
}
template<typename T>
inline static std::string Join(const std::vector<T>& strs, const char* delimiter) {
if (strs.empty()) {
......@@ -411,7 +563,7 @@ inline static std::string Join(const std::vector<T>& strs, size_t start, size_t
return str_buf.str();
}
static inline int64_t Pow2RoundUp(int64_t x) {
inline static int64_t Pow2RoundUp(int64_t x) {
int64_t t = 1;
for (int i = 0; i < 64; ++i) {
if (t >= x) {
......@@ -426,7 +578,7 @@ static inline int64_t Pow2RoundUp(int64_t x) {
* \brief Do inplace softmax transformaton on p_rec
* \param p_rec The input/output vector of the values.
*/
inline void Softmax(std::vector<double>* p_rec) {
inline static void Softmax(std::vector<double>* p_rec) {
std::vector<double> &rec = *p_rec;
double wmax = rec[0];
for (size_t i = 1; i < rec.size(); ++i) {
......@@ -442,7 +594,7 @@ inline void Softmax(std::vector<double>* p_rec) {
}
}
inline void Softmax(const double* input, double* output, int len) {
inline static void Softmax(const double* input, double* output, int len) {
double wmax = input[0];
for (int i = 1; i < len; ++i) {
wmax = std::max(input[i], wmax);
......@@ -467,7 +619,7 @@ std::vector<const T*> ConstPtrInVectorWrapper(const std::vector<std::unique_ptr<
}
template<typename T1, typename T2>
inline void SortForPair(std::vector<T1>& keys, std::vector<T2>& values, size_t start, bool is_reverse = false) {
inline static void SortForPair(std::vector<T1>& keys, std::vector<T2>& values, size_t start, bool is_reverse = false) {
std::vector<std::pair<T1, T2>> arr;
for (size_t i = start; i < keys.size(); ++i) {
arr.emplace_back(keys[i], values[i]);
......@@ -537,12 +689,22 @@ inline static double AvoidInf(double x) {
}
}
template<class _Iter> inline
inline static float AvoidInf(float x) {
if (x >= 1e38) {
return 1e38f;
} else if (x <= -1e38) {
return -1e38f;
} else {
return x;
}
}
template<typename _Iter> inline
static typename std::iterator_traits<_Iter>::value_type* IteratorValType(_Iter) {
return (0);
}
template<class _RanIt, class _Pr, class _VTRanIt> inline
template<typename _RanIt, typename _Pr, typename _VTRanIt> inline
static void ParallelSort(_RanIt _First, _RanIt _Last, _Pr _Pred, _VTRanIt*) {
size_t len = _Last - _First;
const size_t kMinInnerLen = 1024;
......@@ -589,14 +751,14 @@ static void ParallelSort(_RanIt _First, _RanIt _Last, _Pr _Pred, _VTRanIt*) {
}
}
template<class _RanIt, class _Pr> inline
template<typename _RanIt, typename _Pr> inline
static void ParallelSort(_RanIt _First, _RanIt _Last, _Pr _Pred) {
return ParallelSort(_First, _Last, _Pred, IteratorValType(_First));
}
// Check that all y[] are in interval [ymin, ymax] (end points included); throws error if not
template <typename T>
inline void CheckElementsIntervalClosed(const T *y, T ymin, T ymax, int ny, const char *callername) {
inline static void CheckElementsIntervalClosed(const T *y, T ymin, T ymax, int ny, const char *callername) {
auto fatal_msg = [&y, &ymin, &ymax, &callername](int i) {
std::ostringstream os;
os << "[%s]: does not tolerate element [#%i = " << y[i] << "] outside [" << ymin << ", " << ymax << "]";
......@@ -627,7 +789,7 @@ inline void CheckElementsIntervalClosed(const T *y, T ymin, T ymax, int ny, cons
// One-pass scan over array w with nw elements: find min, max and sum of elements;
// this is useful for checking weight requirements.
template <typename T1, typename T2>
inline void ObtainMinMaxSum(const T1 *w, int nw, T1 *mi, T1 *ma, T2 *su) {
inline static void ObtainMinMaxSum(const T1 *w, int nw, T1 *mi, T1 *ma, T2 *su) {
T1 minw;
T1 maxw;
T1 sumw;
......@@ -669,8 +831,8 @@ inline void ObtainMinMaxSum(const T1 *w, int nw, T1 *mi, T1 *ma, T2 *su) {
}
}
template<class T>
inline std::vector<uint32_t> ConstructBitset(const T* vals, int n) {
template<typename T>
inline static std::vector<uint32_t> ConstructBitset(const T* vals, int n) {
std::vector<uint32_t> ret;
for (int i = 0; i < n; ++i) {
int i1 = vals[i] / 32;
......@@ -683,8 +845,8 @@ inline std::vector<uint32_t> ConstructBitset(const T* vals, int n) {
return ret;
}
template<class T>
inline bool FindInBitset(const uint32_t* bits, int n, T pos) {
template<typename T>
inline static bool FindInBitset(const uint32_t* bits, int n, T pos) {
int i1 = pos / 32;
if (i1 >= n) {
return false;
......@@ -702,6 +864,24 @@ inline static double GetDoubleUpperBound(double a) {
return std::nextafter(a, INFINITY);;
}
inline static size_t GetLine(const char* str) {
auto start = str;
while (*str != '\0' && *str != '\n' && *str != '\r') {
++str;
}
return str - start;
}
inline static const char* SkipNewLine(const char* str) {
if (*str == '\r') {
++str;
}
if (*str == '\n') {
++str;
}
return str;
}
} // namespace Common
} // namespace LightGBM
......
......@@ -148,6 +148,31 @@ public:
});
}
std::vector<char> ReadContent(size_t* out_len) {
std::vector<char> ret;
*out_len = 0;
FILE* file;
#ifdef _MSC_VER
fopen_s(&file, filename_, "rb");
#else
file = fopen(filename_, "rb");
#endif
if (file == NULL) {
return ret;
}
const size_t buffer_size = 16 * 1024 * 1024;
auto buffer_read = std::vector<char>(buffer_size);
size_t read_cnt = 0;
do {
read_cnt = fread(buffer_read.data(), 1, buffer_size, file);
ret.insert(ret.end(), buffer_read.begin(), buffer_read.begin() + read_cnt);
*out_len += read_cnt;
} while (read_cnt > 0);
// close file
fclose(file);
return ret;
}
INDEX_T SampleFromFile(Random& random, INDEX_T sample_cnt, std::vector<std::string>* out_sampled_data) {
INDEX_T cur_sample_cnt = 0;
return ReadAllAndProcess(
......
PMML Generator
==============
The script pmml.py can be used to translate the LightGBM models, found in LightGBM_model.txt, to predictive model markup language (PMML). These models can then be imported by other analytics applications. The models that the language can describe includes decision trees. The specification of PMML can be found here at the Data Mining Group's [website](http://dmg.org/pmml/v4-3/GeneralStructure.html).
The old python convert script is removed due to it cannot support the new categorical features.
In order to generate pmml files do the following steps.
```
lightgbm config=train.conf
python pmml.py LightGBM_model.txt
```
The python script will create a file called **LightGBM_pmml.xml**. Inside the file you will find a `MiningModel` tag. In there you will find `TreeModel` tags. Each `TreeModel` tag contains the pmml translation of a decision tree inside the LightGBM_model.txt file. The model described by the **LightGBM_pmml.xml** file can be transferred to other analytics applications. For instance you can use the pmml file as an input to the jpmml-evaluator API. Follow the steps below to run a model described by **LightGBM_pmml.xml**.
##### Steps to Run jpmml-evaluator
1. Clone the repository
```
git clone https://github.com/jpmml/jpmml-evaluator.git
```
2. Build using maven
```
mvn clean install
```
3. Run the EvaluationExample class on the model file using the following command
```
java -cp example-1.3-SNAPSHOT.jar org.jpmml.evaluator.EvaluationExample --model LightGBM_pmml.xml --input input.csv --output output.csv
```
Note, in order to run the model on the input.csv file, the input.csv file must have the same number of columns as specified by the `DataDictionary` field in the pmml file. Also, the column headers inside the input.csv file must be the same as the column names specified by the `MiningSchema` field. Inside output.csv you will find all the columns inside the input.csv file plus a new column. In the new column you will find the scores calculated by processing each rows data on the model. More information about jpmml-evaluator can be found at its [github repository](https://github.com/jpmml/jpmml-evaluator).
\ No newline at end of file
Please move to https://github.com/jpmml/jpmml-lightgbm
# coding: utf-8
# pylint: disable = C0111, C0103
"""convert LightGBM model to pmml"""
from __future__ import absolute_import
from sys import argv
from itertools import count
def get_value_string(line):
return line[line.find('=') + 1:]
def get_array_strings(line):
return get_value_string(line).split()
def get_array_ints(line):
return [int(token) for token in get_array_strings(line)]
def get_field_name(node_id, prev_node_idx, is_child):
idx = leaf_parent[node_id] if is_child else prev_node_idx
return feature_names[split_feature[idx]]
def get_threshold(node_id, prev_node_idx, is_child):
idx = leaf_parent[node_id] if is_child else prev_node_idx
return threshold[idx]
def print_simple_predicate(tab_len, node_id, is_left_child, prev_node_idx, is_leaf):
if is_left_child:
op = 'equal' if decision_type[prev_node_idx] == 1 else 'lessOrEqual'
else:
op = 'notEqual' if decision_type[prev_node_idx] == 1 else 'greaterThan'
out_('\t' * (tab_len + 1) + ("<SimplePredicate field=\"{0}\" " + " operator=\"{1}\" value=\"{2}\" />").format(
get_field_name(node_id, prev_node_idx, is_leaf), op, get_threshold(node_id, prev_node_idx, is_leaf)))
def print_nodes_pmml(node_id, tab_len, is_left_child, prev_node_idx):
if node_id < 0:
node_id = ~node_id
score = leaf_value[node_id]
recordCount = leaf_count[node_id]
is_leaf = True
else:
score = internal_value[node_id]
recordCount = internal_count[node_id]
is_leaf = False
out_('\t' * tab_len + ("<Node id=\"{0}\" score=\"{1}\" " + " recordCount=\"{2}\">").format(
next(unique_id), score, recordCount))
print_simple_predicate(tab_len, node_id, is_left_child, prev_node_idx, is_leaf)
if not is_leaf:
print_nodes_pmml(left_child[node_id], tab_len + 1, True, node_id)
print_nodes_pmml(right_child[node_id], tab_len + 1, False, node_id)
out_('\t' * tab_len + "</Node>")
# print out the pmml for a decision tree
def print_pmml():
# specify the objective as function name and binarySplit for
# splitCharacteristic because each node has 2 children
out_("\t\t\t\t<TreeModel functionName=\"regression\" splitCharacteristic=\"binarySplit\">")
out_("\t\t\t\t\t<MiningSchema>")
# list each feature name as a mining field, and treat all outliers as is,
# unless specified
for feature in feature_names:
out_("\t\t\t\t\t\t<MiningField name=\"%s\"/>" % (feature))
out_("\t\t\t\t\t</MiningSchema>")
# begin printing out the decision tree
out_("\t\t\t\t\t<Node id=\"{0}\" score=\"{1}\" recordCount=\"{2}\">".format(
next(unique_id), internal_value[0], internal_count[0]))
out_("\t\t\t\t\t\t<True/>")
print_nodes_pmml(left_child[0], 6, True, 0)
print_nodes_pmml(right_child[0], 6, False, 0)
out_("\t\t\t\t\t</Node>")
out_("\t\t\t\t</TreeModel>")
if len(argv) != 2:
raise ValueError('usage: pmml.py <input model file>')
# open the model file and then process it
with open(argv[1], 'r') as model_in:
# ignore first 6 and empty lines
model_content = iter([line for line in model_in.read().splitlines() if line][6:])
feature_names = get_array_strings(next(model_content))
feature_infos = get_array_strings(next(model_content))
segment_id = count(1)
with open('LightGBM_pmml.xml', 'w') as pmml_out:
def out_(string):
pmml_out.write(string + '\n')
out_(
"<PMML version=\"4.3\" \n" +
"\t\txmlns=\"http://www.dmg.org/PMML-4_3\"\n" +
"\t\txmlns:xsi=\"http://www.w3.org/2001/XMLSchema-instance\"\n" +
"\t\txsi:schemaLocation=\"http://www.dmg.org/PMML-4_3 http://dmg.org/pmml/v4-3/pmml-4-3.xsd\">")
out_("\t<Header copyright=\"Microsoft\">")
out_("\t\t<Application name=\"LightGBM\"/>")
out_("\t</Header>")
# print out data dictionary entries for each column
out_("\t<DataDictionary numberOfFields=\"%d\">" % len(feature_names))
# not adding any interval definition, all values are currently
# valid
for feature in feature_names:
out_("\t\t<DataField name=\"" + feature + "\" optype=\"continuous\" dataType=\"double\"/>")
out_("\t</DataDictionary>")
out_("\t<MiningModel functionName=\"regression\">")
out_("\t\t<MiningSchema>")
# list each feature name as a mining field, and treat all outliers
# as is, unless specified
for feature in feature_names:
out_("\t\t\t<MiningField name=\"%s\"/>" % (feature))
out_("\t\t</MiningSchema>")
out_("\t\t<Segmentation multipleModelMethod=\"sum\">")
# read each array that contains pertinent information for the pmml
# these arrays will be used to recreate the traverse the decision tree
while True:
tree_start = next(model_content, '')
if not tree_start.startswith('Tree'):
break
out_("\t\t\t<Segment id=\"%d\">" % next(segment_id))
out_("\t\t\t\t<True/>")
tree_no = tree_start[5:]
num_leaves = int(get_value_string(next(model_content)))
split_feature = get_array_ints(next(model_content))
split_gain = next(model_content) # unused
threshold = get_array_strings(next(model_content))
decision_type = get_array_ints(next(model_content))
left_child = get_array_ints(next(model_content))
right_child = get_array_ints(next(model_content))
leaf_parent = get_array_ints(next(model_content))
leaf_value = get_array_strings(next(model_content))
leaf_count = get_array_strings(next(model_content))
internal_value = get_array_strings(next(model_content))
internal_count = get_array_strings(next(model_content))
shrinkage = get_value_string(next(model_content))
has_categorical = get_value_string(next(model_content))
unique_id = count(1)
print_pmml()
out_("\t\t\t</Segment>")
out_("\t\t</Segmentation>")
out_("\t</MiningModel>")
out_("</PMML>")
syntax = "proto3";
package LightGBM;
message Model {
string name = 1;
uint32 num_class = 2;
uint32 num_tree_per_iteration = 3;
uint32 label_index = 4;
uint32 max_feature_idx = 5;
string objective = 6;
bool average_output = 7;
repeated string feature_names = 8;
repeated string feature_infos = 9;
message Tree {
uint32 num_leaves = 1;
uint32 num_cat = 2;
repeated uint32 split_feature = 3;
repeated double split_gain = 4;
repeated double threshold = 5;
repeated uint32 decision_type = 6;
repeated sint32 left_child = 7;
repeated sint32 right_child = 8;
repeated double leaf_value = 9;
repeated uint32 leaf_count = 10;
repeated double internal_value = 11;
repeated double internal_count = 12;
repeated sint32 cat_boundaries = 13;
repeated uint32 cat_threshold = 14;
double shrinkage = 15;
}
repeated Tree trees = 10;
}
......@@ -555,7 +555,7 @@ class _InnerPredictor(object):
class Dataset(object):
"""Dataset in LightGBM."""
def __init__(self, data, label=None, max_bin=None, reference=None,
def __init__(self, data, label=None, reference=None,
weight=None, group=None, init_score=None, silent=False,
feature_name='auto', categorical_feature='auto', params=None,
free_raw_data=True):
......@@ -568,9 +568,6 @@ class Dataset(object):
If string, it represents the path to txt file.
label : list, numpy 1-D array or None, optional (default=None)
Label of the data.
max_bin : int or None, optional (default=None)
Max number of discrete bins for features.
If None, default value from parameters of CLI-version will be used.
reference : Dataset or None, optional (default=None)
If this is Dataset for validation, training data should be used as reference.
weight : list, numpy 1-D array or None, optional (default=None)
......@@ -597,7 +594,6 @@ class Dataset(object):
self.handle = None
self.data = data
self.label = label
self.max_bin = max_bin
self.reference = reference
self.weight = weight
self.group = group
......@@ -620,7 +616,7 @@ class Dataset(object):
_safe_call(_LIB.LGBM_DatasetFree(self.handle))
self.handle = None
def _lazy_init(self, data, label=None, max_bin=None, reference=None,
def _lazy_init(self, data, label=None, reference=None,
weight=None, group=None, init_score=None, predictor=None,
silent=False, feature_name='auto',
categorical_feature='auto', params=None):
......@@ -640,12 +636,7 @@ class Dataset(object):
if key in args_names:
warnings.warn('{0} keyword has been found in `params` and will be ignored. '
'Please use {0} argument of the Dataset constructor to pass this parameter.'.format(key))
self.max_bin = max_bin
self.predictor = predictor
if self.max_bin is not None:
params["max_bin"] = self.max_bin
warnings.warn('The `max_bin` parameter is deprecated and will be removed in 2.0.12 version. '
'Please use `params` to pass this parameter.', LGBMDeprecationWarning)
if "verbosity" in params:
params.setdefault("verbose", params.pop("verbosity"))
if silent:
......@@ -821,7 +812,7 @@ class Dataset(object):
if self.reference is not None:
if self.used_indices is None:
# create valid
self._lazy_init(self.data, label=self.label, max_bin=self.max_bin, reference=self.reference,
self._lazy_init(self.data, label=self.label, reference=self.reference,
weight=self.weight, group=self.group, init_score=self.init_score, predictor=self._predictor,
silent=self.silent, feature_name=self.feature_name, params=self.params)
else:
......@@ -839,7 +830,7 @@ class Dataset(object):
raise ValueError("Label should not be None.")
else:
# create train
self._lazy_init(self.data, label=self.label, max_bin=self.max_bin,
self._lazy_init(self.data, label=self.label,
weight=self.weight, group=self.group, init_score=self.init_score,
predictor=self._predictor, silent=self.silent, feature_name=self.feature_name,
categorical_feature=self.categorical_feature, params=self.params)
......@@ -874,7 +865,7 @@ class Dataset(object):
self : Dataset
Returns self.
"""
ret = Dataset(data, label=label, max_bin=self.max_bin, reference=self,
ret = Dataset(data, label=label, reference=self,
weight=weight, group=group, init_score=init_score,
silent=silent, params=params, free_raw_data=self.free_raw_data)
ret._predictor = self._predictor
......
......@@ -133,7 +133,7 @@ class LGBMModel(_LGBMModelBase):
"""Implementation of the scikit-learn API for LightGBM."""
def __init__(self, boosting_type="gbdt", num_leaves=31, max_depth=-1,
learning_rate=0.1, n_estimators=10, max_bin=255,
learning_rate=0.1, n_estimators=10,
subsample_for_bin=200000, objective=None,
min_split_gain=0., min_child_weight=1e-3, min_child_samples=20,
subsample=1., subsample_freq=1, colsample_bytree=1.,
......@@ -156,8 +156,6 @@ class LGBMModel(_LGBMModelBase):
Boosting learning rate.
n_estimators : int, optional (default=10)
Number of boosted trees to fit.
max_bin : int, optional (default=255)
Number of bucketed bins for feature values.
subsample_for_bin : int, optional (default=50000)
Number of samples for constructing bins.
objective : string, callable or None, optional (default=None)
......@@ -246,7 +244,6 @@ class LGBMModel(_LGBMModelBase):
self.max_depth = max_depth
self.learning_rate = learning_rate
self.n_estimators = n_estimators
self.max_bin = max_bin
self.subsample_for_bin = subsample_for_bin
self.min_split_gain = min_split_gain
self.min_child_weight = min_child_weight
......@@ -410,7 +407,7 @@ class LGBMModel(_LGBMModelBase):
self._n_features = X.shape[1]
def _construct_dataset(X, y, sample_weight, init_score, group, params):
ret = Dataset(X, label=y, max_bin=self.max_bin, weight=sample_weight, group=group, params=params)
ret = Dataset(X, label=y, weight=sample_weight, group=group, params=params)
ret.set_init_score(init_score)
return ret
......
......@@ -180,7 +180,6 @@ void Application::InitTrain() {
// create boosting
boosting_.reset(
Boosting::CreateBoosting(config_.boosting_type,
config_.io_config.model_format.c_str(),
config_.io_config.input_model.c_str()));
// create objective function
objective_fun_.reset(
......@@ -204,26 +203,7 @@ void Application::InitTrain() {
void Application::Train() {
Log::Info("Started training...");
boosting_->Train(config_.io_config.snapshot_freq, config_.io_config.output_model);
std::vector<std::string> model_formats = Common::Split(config_.io_config.model_format.c_str(), ',');
bool save_with_multiple_format = (model_formats.size() > 1);
for (auto model_format: model_formats) {
std::string save_file_name = config_.io_config.output_model;
if (save_with_multiple_format) {
// use suffix to distinguish different model format
save_file_name += "." + model_format;
}
if (model_format == std::string("text")) {
boosting_->SaveModelToFile(-1, save_file_name.c_str());
} else if (model_format == std::string("proto")) {
#ifdef USE_PROTO
boosting_->SaveModelToProto(-1, save_file_name.c_str());
#else
Log::Fatal("Please cmake with -DUSE_PROTO=ON to use protobuf.");
#endif // USE_PROTO
} else {
Log::Fatal("Unknown model format during saving: %s", model_format.c_str());
}
}
boosting_->SaveModelToFile(-1, config_.io_config.output_model.c_str());
// convert model to if-else statement code
if (config_.convert_model_language == std::string("cpp")) {
boosting_->SaveModelToIfElse(-1, config_.io_config.convert_model.c_str());
......@@ -244,16 +224,13 @@ void Application::Predict() {
void Application::InitPredict() {
boosting_.reset(
Boosting::CreateBoosting("gbdt", config_.io_config.model_format.c_str(),
config_.io_config.input_model.c_str()));
Boosting::CreateBoosting("gbdt", config_.io_config.input_model.c_str()));
Log::Info("Finished initializing prediction");
}
void Application::ConvertModel() {
boosting_.reset(
Boosting::CreateBoosting(config_.boosting_type,
config_.io_config.model_format.c_str(),
config_.io_config.input_model.c_str()));
Boosting::CreateBoosting(config_.boosting_type, config_.io_config.input_model.c_str()));
boosting_->SaveModelToIfElse(-1, config_.io_config.convert_model.c_str());
}
......
......@@ -12,34 +12,22 @@ std::string GetBoostingTypeFromModelFile(const char* filename) {
return type;
}
bool Boosting::LoadFileToBoosting(Boosting* boosting, const std::string& format, const char* filename) {
bool Boosting::LoadFileToBoosting(Boosting* boosting, const char* filename) {
auto start_time = std::chrono::steady_clock::now();
if (boosting != nullptr) {
if (format == std::string("text")) {
TextReader<size_t> model_reader(filename, true);
model_reader.ReadAllLines();
std::stringstream str_buf;
for (auto& line : model_reader.Lines()) {
str_buf << line << '\n';
}
if (!boosting->LoadModelFromString(str_buf.str())) {
return false;
}
} else if (format == std::string("proto")) {
#ifdef USE_PROTO
if (!boosting->LoadModelFromProto(filename)) {
size_t buffer_len = 0;
auto buffer = model_reader.ReadContent(&buffer_len);
if (!boosting->LoadModelFromString(buffer.data(), buffer_len)) {
return false;
}
#else
Log::Fatal("Please cmake with -DUSE_PROTO=ON to use protobuf.");
#endif // USE_PROTO
} else {
Log::Fatal("Unknown model format during loading: %s", format.c_str());
}
}
std::chrono::duration<double, std::milli> delta = (std::chrono::steady_clock::now() - start_time);
Log::Info("time for loading model: %f seconds", 1e-3*delta);
return true;
}
Boosting* Boosting::CreateBoosting(const std::string& type, const std::string& format, const char* filename) {
Boosting* Boosting::CreateBoosting(const std::string& type, const char* filename) {
if (filename == nullptr || filename[0] == '\0') {
if (type == std::string("gbdt")) {
return new GBDT();
......@@ -54,7 +42,7 @@ Boosting* Boosting::CreateBoosting(const std::string& type, const std::string& f
}
} else {
std::unique_ptr<Boosting> ret;
if (format == std::string("proto") || GetBoostingTypeFromModelFile(filename) == std::string("tree")) {
if (GetBoostingTypeFromModelFile(filename) == std::string("tree")) {
if (type == std::string("gbdt")) {
ret.reset(new GBDT());
} else if (type == std::string("dart")) {
......@@ -66,7 +54,7 @@ Boosting* Boosting::CreateBoosting(const std::string& type, const std::string& f
} else {
Log::Fatal("unknown boosting type %s", type.c_str());
}
LoadFileToBoosting(ret.get(), format, filename);
LoadFileToBoosting(ret.get(), filename);
} else {
Log::Fatal("unknown model format or submodel type in model file %s", filename);
}
......
......@@ -588,7 +588,7 @@ std::string GBDT::OutputMetric(int iter) {
<< " : " << scores[k];
Log::Info(tmp_buf.str().c_str());
if (early_stopping_round_ > 0) {
msg_buf << tmp_buf.str() << std::endl;
msg_buf << tmp_buf.str() << '\n';
}
}
}
......@@ -608,7 +608,7 @@ std::string GBDT::OutputMetric(int iter) {
Log::Info(tmp_buf.str().c_str());
}
if (early_stopping_round_ > 0) {
msg_buf << tmp_buf.str() << std::endl;
msg_buf << tmp_buf.str() << '\n';
}
}
if (ret.empty() && early_stopping_round_ > 0) {
......
......@@ -241,25 +241,9 @@ public:
virtual std::string SaveModelToString(int num_iterations) const override;
/*!
* \brief Restore from a serialized string
* \brief Restore from a serialized buffer
*/
bool LoadModelFromString(const std::string& model_str) override;
#ifdef USE_PROTO
/*!
* \brief Save model with protobuf
* \param num_iterations Number of model that want to save, -1 means save all
* \param filename Filename that want to save to
*/
void SaveModelToProto(int num_iteration, const char* filename) const override;
/*!
* \brief Restore from a serialized protobuf file
* \param filename Filename that want to restore from
* \return true if succeeded
*/
bool LoadModelFromProto(const char* filename) override;
#endif // USE_PROTO
bool LoadModelFromString(const char* buffer, size_t len) override;
/*!
* \brief Calculate feature importances
......
......@@ -10,19 +10,22 @@
namespace LightGBM {
const std::string kModelVersion = "v2";
std::string GBDT::DumpModel(int num_iteration) const {
std::stringstream str_buf;
str_buf << "{";
str_buf << "\"name\":\"" << SubModelName() << "\"," << std::endl;
str_buf << "\"num_class\":" << num_class_ << "," << std::endl;
str_buf << "\"num_tree_per_iteration\":" << num_tree_per_iteration_ << "," << std::endl;
str_buf << "\"label_index\":" << label_idx_ << "," << std::endl;
str_buf << "\"max_feature_idx\":" << max_feature_idx_ << "," << std::endl;
str_buf << "\"name\":\"" << SubModelName() << "\"," << '\n';
str_buf << "\"version\":\"" << kModelVersion << "\"," << '\n';
str_buf << "\"num_class\":" << num_class_ << "," << '\n';
str_buf << "\"num_tree_per_iteration\":" << num_tree_per_iteration_ << "," << '\n';
str_buf << "\"label_index\":" << label_idx_ << "," << '\n';
str_buf << "\"max_feature_idx\":" << max_feature_idx_ << "," << '\n';
str_buf << "\"feature_names\":[\""
<< Common::Join(feature_names_, "\",\"") << "\"],"
<< std::endl;
<< '\n';
str_buf << "\"tree_info\":[";
int num_used_model = static_cast<int>(models_.size());
......@@ -38,9 +41,9 @@ std::string GBDT::DumpModel(int num_iteration) const {
str_buf << models_[i]->ToJSON();
str_buf << "}";
}
str_buf << "]" << std::endl;
str_buf << "]" << '\n';
str_buf << "}" << std::endl;
str_buf << "}" << '\n';
return str_buf.str();
}
......@@ -48,18 +51,18 @@ std::string GBDT::DumpModel(int num_iteration) const {
std::string GBDT::ModelToIfElse(int num_iteration) const {
std::stringstream str_buf;
str_buf << "#include \"gbdt.h\"" << std::endl;
str_buf << "#include <LightGBM/utils/common.h>" << std::endl;
str_buf << "#include <LightGBM/objective_function.h>" << std::endl;
str_buf << "#include <LightGBM/metric.h>" << std::endl;
str_buf << "#include <LightGBM/prediction_early_stop.h>" << std::endl;
str_buf << "#include <ctime>" << std::endl;
str_buf << "#include <sstream>" << std::endl;
str_buf << "#include <chrono>" << std::endl;
str_buf << "#include <string>" << std::endl;
str_buf << "#include <vector>" << std::endl;
str_buf << "#include <utility>" << std::endl;
str_buf << "namespace LightGBM {" << std::endl;
str_buf << "#include \"gbdt.h\"" << '\n';
str_buf << "#include <LightGBM/utils/common.h>" << '\n';
str_buf << "#include <LightGBM/objective_function.h>" << '\n';
str_buf << "#include <LightGBM/metric.h>" << '\n';
str_buf << "#include <LightGBM/prediction_early_stop.h>" << '\n';
str_buf << "#include <ctime>" << '\n';
str_buf << "#include <sstream>" << '\n';
str_buf << "#include <chrono>" << '\n';
str_buf << "#include <string>" << '\n';
str_buf << "#include <vector>" << '\n';
str_buf << "#include <utility>" << '\n';
str_buf << "namespace LightGBM {" << '\n';
int num_used_model = static_cast<int>(models_.size());
if (num_iteration > 0) {
......@@ -68,7 +71,7 @@ std::string GBDT::ModelToIfElse(int num_iteration) const {
// PredictRaw
for (int i = 0; i < num_used_model; ++i) {
str_buf << models_[i]->ToIfElse(i, false) << std::endl;
str_buf << models_[i]->ToIfElse(i, false) << '\n';
}
str_buf << "double (*PredictTreePtr[])(const double*) = { ";
......@@ -78,28 +81,28 @@ std::string GBDT::ModelToIfElse(int num_iteration) const {
}
str_buf << "PredictTree" << i;
}
str_buf << " };" << std::endl << std::endl;
str_buf << " };" << '\n' << '\n';
std::stringstream pred_str_buf;
pred_str_buf << "\t" << "int early_stop_round_counter = 0;" << std::endl;
pred_str_buf << "\t" << "std::memset(output, 0, sizeof(double) * num_tree_per_iteration_);" << std::endl;
pred_str_buf << "\t" << "for (int i = 0; i < num_iteration_for_pred_; ++i) {" << std::endl;
pred_str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << std::endl;
pred_str_buf << "\t\t\t" << "output[k] += (*PredictTreePtr[i * num_tree_per_iteration_ + k])(features);" << std::endl;
pred_str_buf << "\t\t" << "}" << std::endl;
pred_str_buf << "\t\t" << "++early_stop_round_counter;" << std::endl;
pred_str_buf << "\t\t" << "if (early_stop->round_period == early_stop_round_counter) {" << std::endl;
pred_str_buf << "\t\t\t" << "if (early_stop->callback_function(output, num_tree_per_iteration_))" << std::endl;
pred_str_buf << "\t\t\t\t" << "return;" << std::endl;
pred_str_buf << "\t\t\t" << "early_stop_round_counter = 0;" << std::endl;
pred_str_buf << "\t\t" << "}" << std::endl;
pred_str_buf << "\t" << "}" << std::endl;
str_buf << "void GBDT::PredictRaw(const double* features, double *output, const PredictionEarlyStopInstance* early_stop) const {" << std::endl;
pred_str_buf << "\t" << "int early_stop_round_counter = 0;" << '\n';
pred_str_buf << "\t" << "std::memset(output, 0, sizeof(double) * num_tree_per_iteration_);" << '\n';
pred_str_buf << "\t" << "for (int i = 0; i < num_iteration_for_pred_; ++i) {" << '\n';
pred_str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << '\n';
pred_str_buf << "\t\t\t" << "output[k] += (*PredictTreePtr[i * num_tree_per_iteration_ + k])(features);" << '\n';
pred_str_buf << "\t\t" << "}" << '\n';
pred_str_buf << "\t\t" << "++early_stop_round_counter;" << '\n';
pred_str_buf << "\t\t" << "if (early_stop->round_period == early_stop_round_counter) {" << '\n';
pred_str_buf << "\t\t\t" << "if (early_stop->callback_function(output, num_tree_per_iteration_))" << '\n';
pred_str_buf << "\t\t\t\t" << "return;" << '\n';
pred_str_buf << "\t\t\t" << "early_stop_round_counter = 0;" << '\n';
pred_str_buf << "\t\t" << "}" << '\n';
pred_str_buf << "\t" << "}" << '\n';
str_buf << "void GBDT::PredictRaw(const double* features, double *output, const PredictionEarlyStopInstance* early_stop) const {" << '\n';
str_buf << pred_str_buf.str();
str_buf << "}" << std::endl;
str_buf << std::endl;
str_buf << "}" << '\n';
str_buf << '\n';
// PredictRawByMap
str_buf << "double (*PredictTreeByMapPtr[])(const std::unordered_map<int, double>&) = { ";
......@@ -109,61 +112,61 @@ std::string GBDT::ModelToIfElse(int num_iteration) const {
}
str_buf << "PredictTree" << i << "ByMap";
}
str_buf << " };" << std::endl << std::endl;
str_buf << " };" << '\n' << '\n';
std::stringstream pred_str_buf_map;
pred_str_buf_map << "\t" << "int early_stop_round_counter = 0;" << std::endl;
pred_str_buf_map << "\t" << "std::memset(output, 0, sizeof(double) * num_tree_per_iteration_);" << std::endl;
pred_str_buf_map << "\t" << "for (int i = 0; i < num_iteration_for_pred_; ++i) {" << std::endl;
pred_str_buf_map << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << std::endl;
pred_str_buf_map << "\t\t\t" << "output[k] += (*PredictTreeByMapPtr[i * num_tree_per_iteration_ + k])(features);" << std::endl;
pred_str_buf_map << "\t\t" << "}" << std::endl;
pred_str_buf_map << "\t\t" << "++early_stop_round_counter;" << std::endl;
pred_str_buf_map << "\t\t" << "if (early_stop->round_period == early_stop_round_counter) {" << std::endl;
pred_str_buf_map << "\t\t\t" << "if (early_stop->callback_function(output, num_tree_per_iteration_))" << std::endl;
pred_str_buf_map << "\t\t\t\t" << "return;" << std::endl;
pred_str_buf_map << "\t\t\t" << "early_stop_round_counter = 0;" << std::endl;
pred_str_buf_map << "\t\t" << "}" << std::endl;
pred_str_buf_map << "\t" << "}" << std::endl;
str_buf << "void GBDT::PredictRawByMap(const std::unordered_map<int, double>& features, double* output, const PredictionEarlyStopInstance* early_stop) const {" << std::endl;
pred_str_buf_map << "\t" << "int early_stop_round_counter = 0;" << '\n';
pred_str_buf_map << "\t" << "std::memset(output, 0, sizeof(double) * num_tree_per_iteration_);" << '\n';
pred_str_buf_map << "\t" << "for (int i = 0; i < num_iteration_for_pred_; ++i) {" << '\n';
pred_str_buf_map << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << '\n';
pred_str_buf_map << "\t\t\t" << "output[k] += (*PredictTreeByMapPtr[i * num_tree_per_iteration_ + k])(features);" << '\n';
pred_str_buf_map << "\t\t" << "}" << '\n';
pred_str_buf_map << "\t\t" << "++early_stop_round_counter;" << '\n';
pred_str_buf_map << "\t\t" << "if (early_stop->round_period == early_stop_round_counter) {" << '\n';
pred_str_buf_map << "\t\t\t" << "if (early_stop->callback_function(output, num_tree_per_iteration_))" << '\n';
pred_str_buf_map << "\t\t\t\t" << "return;" << '\n';
pred_str_buf_map << "\t\t\t" << "early_stop_round_counter = 0;" << '\n';
pred_str_buf_map << "\t\t" << "}" << '\n';
pred_str_buf_map << "\t" << "}" << '\n';
str_buf << "void GBDT::PredictRawByMap(const std::unordered_map<int, double>& features, double* output, const PredictionEarlyStopInstance* early_stop) const {" << '\n';
str_buf << pred_str_buf_map.str();
str_buf << "}" << std::endl;
str_buf << std::endl;
str_buf << "}" << '\n';
str_buf << '\n';
// Predict
str_buf << "void GBDT::Predict(const double* features, double *output, const PredictionEarlyStopInstance* early_stop) const {" << std::endl;
str_buf << "\t" << "PredictRaw(features, output, early_stop);" << std::endl;
str_buf << "\t" << "if (average_output_) {" << std::endl;
str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << std::endl;
str_buf << "\t\t\t" << "output[k] /= num_iteration_for_pred_;" << std::endl;
str_buf << "\t\t" << "}" << std::endl;
str_buf << "\t" << "}" << std::endl;
str_buf << "\t" << "else if (objective_function_ != nullptr) {" << std::endl;
str_buf << "\t\t" << "objective_function_->ConvertOutput(output, output);" << std::endl;
str_buf << "\t" << "}" << std::endl;
str_buf << "}" << std::endl;
str_buf << std::endl;
str_buf << "void GBDT::Predict(const double* features, double *output, const PredictionEarlyStopInstance* early_stop) const {" << '\n';
str_buf << "\t" << "PredictRaw(features, output, early_stop);" << '\n';
str_buf << "\t" << "if (average_output_) {" << '\n';
str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << '\n';
str_buf << "\t\t\t" << "output[k] /= num_iteration_for_pred_;" << '\n';
str_buf << "\t\t" << "}" << '\n';
str_buf << "\t" << "}" << '\n';
str_buf << "\t" << "else if (objective_function_ != nullptr) {" << '\n';
str_buf << "\t\t" << "objective_function_->ConvertOutput(output, output);" << '\n';
str_buf << "\t" << "}" << '\n';
str_buf << "}" << '\n';
str_buf << '\n';
// PredictByMap
str_buf << "void GBDT::PredictByMap(const std::unordered_map<int, double>& features, double* output, const PredictionEarlyStopInstance* early_stop) const {" << std::endl;
str_buf << "\t" << "PredictRawByMap(features, output, early_stop);" << std::endl;
str_buf << "\t" << "if (average_output_) {" << std::endl;
str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << std::endl;
str_buf << "\t\t\t" << "output[k] /= num_iteration_for_pred_;" << std::endl;
str_buf << "\t\t" << "}" << std::endl;
str_buf << "\t" << "}" << std::endl;
str_buf << "\t" << "else if (objective_function_ != nullptr) {" << std::endl;
str_buf << "\t\t" << "objective_function_->ConvertOutput(output, output);" << std::endl;
str_buf << "\t" << "}" << std::endl;
str_buf << "}" << std::endl;
str_buf << std::endl;
str_buf << "void GBDT::PredictByMap(const std::unordered_map<int, double>& features, double* output, const PredictionEarlyStopInstance* early_stop) const {" << '\n';
str_buf << "\t" << "PredictRawByMap(features, output, early_stop);" << '\n';
str_buf << "\t" << "if (average_output_) {" << '\n';
str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << '\n';
str_buf << "\t\t\t" << "output[k] /= num_iteration_for_pred_;" << '\n';
str_buf << "\t\t" << "}" << '\n';
str_buf << "\t" << "}" << '\n';
str_buf << "\t" << "else if (objective_function_ != nullptr) {" << '\n';
str_buf << "\t\t" << "objective_function_->ConvertOutput(output, output);" << '\n';
str_buf << "\t" << "}" << '\n';
str_buf << "}" << '\n';
str_buf << '\n';
// PredictLeafIndex
for (int i = 0; i < num_used_model; ++i) {
str_buf << models_[i]->ToIfElse(i, true) << std::endl;
str_buf << models_[i]->ToIfElse(i, true) << '\n';
}
str_buf << "double (*PredictTreeLeafPtr[])(const double*) = { ";
......@@ -173,14 +176,14 @@ std::string GBDT::ModelToIfElse(int num_iteration) const {
}
str_buf << "PredictTree" << i << "Leaf";
}
str_buf << " };" << std::endl << std::endl;
str_buf << " };" << '\n' << '\n';
str_buf << "void GBDT::PredictLeafIndex(const double* features, double *output) const {" << std::endl;
str_buf << "\t" << "int total_tree = num_iteration_for_pred_ * num_tree_per_iteration_;" << std::endl;
str_buf << "\t" << "for (int i = 0; i < total_tree; ++i) {" << std::endl;
str_buf << "\t\t" << "output[i] = (*PredictTreeLeafPtr[i])(features);" << std::endl;
str_buf << "\t" << "}" << std::endl;
str_buf << "}" << std::endl;
str_buf << "void GBDT::PredictLeafIndex(const double* features, double *output) const {" << '\n';
str_buf << "\t" << "int total_tree = num_iteration_for_pred_ * num_tree_per_iteration_;" << '\n';
str_buf << "\t" << "for (int i = 0; i < total_tree; ++i) {" << '\n';
str_buf << "\t\t" << "output[i] = (*PredictTreeLeafPtr[i])(features);" << '\n';
str_buf << "\t" << "}" << '\n';
str_buf << "}" << '\n';
//PredictLeafIndexByMap
str_buf << "double (*PredictTreeLeafByMapPtr[])(const std::unordered_map<int, double>&) = { ";
......@@ -190,16 +193,16 @@ std::string GBDT::ModelToIfElse(int num_iteration) const {
}
str_buf << "PredictTree" << i << "LeafByMap";
}
str_buf << " };" << std::endl << std::endl;
str_buf << " };" << '\n' << '\n';
str_buf << "void GBDT::PredictLeafIndexByMap(const std::unordered_map<int, double>& features, double* output) const {" << std::endl;
str_buf << "\t" << "int total_tree = num_iteration_for_pred_ * num_tree_per_iteration_;" << std::endl;
str_buf << "\t" << "for (int i = 0; i < total_tree; ++i) {" << std::endl;
str_buf << "\t\t" << "output[i] = (*PredictTreeLeafByMapPtr[i])(features);" << std::endl;
str_buf << "\t" << "}" << std::endl;
str_buf << "}" << std::endl;
str_buf << "void GBDT::PredictLeafIndexByMap(const std::unordered_map<int, double>& features, double* output) const {" << '\n';
str_buf << "\t" << "int total_tree = num_iteration_for_pred_ * num_tree_per_iteration_;" << '\n';
str_buf << "\t" << "for (int i = 0; i < total_tree; ++i) {" << '\n';
str_buf << "\t\t" << "output[i] = (*PredictTreeLeafByMapPtr[i])(features);" << '\n';
str_buf << "\t" << "}" << '\n';
str_buf << "}" << '\n';
str_buf << "} // namespace LightGBM" << std::endl;
str_buf << "} // namespace LightGBM" << '\n';
return str_buf.str();
}
......@@ -212,12 +215,12 @@ bool GBDT::SaveModelToIfElse(int num_iteration, const char* filename) const {
std::string origin((std::istreambuf_iterator<char>(ifs)),
(std::istreambuf_iterator<char>()));
output_file.open(filename);
output_file << "#define USE_HARD_CODE 0" << std::endl;
output_file << "#ifndef USE_HARD_CODE" << std::endl;
output_file << origin << std::endl;
output_file << "#else" << std::endl;
output_file << "#define USE_HARD_CODE 0" << '\n';
output_file << "#ifndef USE_HARD_CODE" << '\n';
output_file << origin << '\n';
output_file << "#else" << '\n';
output_file << ModelToIfElse(num_iteration);
output_file << "#endif" << std::endl;
output_file << "#endif" << '\n';
} else {
output_file.open(filename);
output_file << ModelToIfElse(num_iteration);
......@@ -233,40 +236,52 @@ std::string GBDT::SaveModelToString(int num_iteration) const {
std::stringstream ss;
// output model type
ss << SubModelName() << std::endl;
ss << SubModelName() << '\n';
ss << "version=" << kModelVersion << '\n';
// output number of class
ss << "num_class=" << num_class_ << std::endl;
ss << "num_tree_per_iteration=" << num_tree_per_iteration_ << std::endl;
ss << "num_class=" << num_class_ << '\n';
ss << "num_tree_per_iteration=" << num_tree_per_iteration_ << '\n';
// output label index
ss << "label_index=" << label_idx_ << std::endl;
ss << "label_index=" << label_idx_ << '\n';
// output max_feature_idx
ss << "max_feature_idx=" << max_feature_idx_ << std::endl;
ss << "max_feature_idx=" << max_feature_idx_ << '\n';
// output objective
if (objective_function_ != nullptr) {
ss << "objective=" << objective_function_->ToString() << std::endl;
ss << "objective=" << objective_function_->ToString() << '\n';
}
if (average_output_) {
ss << "average_output" << std::endl;
ss << "average_output" << '\n';
}
ss << "feature_names=" << Common::Join(feature_names_, " ") << std::endl;
ss << "feature_infos=" << Common::Join(feature_infos_, " ") << std::endl;
ss << "feature_names=" << Common::Join(feature_names_, " ") << '\n';
std::vector<double> feature_importances = FeatureImportance(num_iteration, 0);
ss << "feature_infos=" << Common::Join(feature_infos_, " ") << '\n';
ss << std::endl;
int num_used_model = static_cast<int>(models_.size());
if (num_iteration > 0) {
num_used_model = std::min(num_iteration * num_tree_per_iteration_, num_used_model);
}
std::vector<std::string> tree_strs(num_used_model);
std::vector<size_t> tree_sizes(num_used_model);
// output tree models
#pragma omp parallel for schedule(static)
for (int i = 0; i < num_used_model; ++i) {
tree_strs[i] = "Tree=" + std::to_string(i) + '\n';
tree_strs[i] += models_[i]->ToString() + '\n';
tree_sizes[i] = tree_strs[i].size();
}
ss << "tree_sizes=" << Common::Join(tree_sizes, " ") << '\n';
ss << '\n';
for (int i = 0; i < num_used_model; ++i) {
ss << "Tree=" << i << std::endl;
ss << models_[i]->ToString() << std::endl;
ss << tree_strs[i];
tree_strs[i].clear();
}
std::vector<double> feature_importances = FeatureImportance(num_iteration, 0);
// store the importance first
std::vector<std::pair<size_t, std::string>> pairs;
for (size_t i = 0; i < feature_importances.size(); ++i) {
......@@ -281,72 +296,93 @@ std::string GBDT::SaveModelToString(int num_iteration) const {
const std::pair<size_t, std::string>& rhs) {
return lhs.first > rhs.first;
});
ss << std::endl << "feature importances:" << std::endl;
ss << '\n' << "feature importances:" << '\n';
for (size_t i = 0; i < pairs.size(); ++i) {
ss << pairs[i].second << "=" << std::to_string(pairs[i].first) << std::endl;
ss << pairs[i].second << "=" << std::to_string(pairs[i].first) << '\n';
}
return ss.str();
}
bool GBDT::SaveModelToFile(int num_iteration, const char* filename) const {
/*! \brief File to write models */
std::ofstream output_file;
output_file.open(filename);
output_file << SaveModelToString(num_iteration);
output_file.open(filename, std::ios::out | std::ios::binary);
std::string str_to_write = SaveModelToString(num_iteration);
output_file.write(str_to_write.c_str(), str_to_write.size());
output_file.close();
return (bool)output_file;
}
bool GBDT::LoadModelFromString(const std::string& model_str) {
bool GBDT::LoadModelFromString(const char* buffer, size_t len) {
// use serialized string to restore this object
models_.clear();
std::vector<std::string> lines = Common::SplitLines(model_str.c_str());
auto c_str = buffer;
auto p = c_str;
auto end = p + len;
std::unordered_map<std::string, std::string> key_vals;
while (p < end) {
auto line_len = Common::GetLine(p);
std::string cur_line(p, line_len);
if (line_len > 0) {
if (!Common::StartsWith(cur_line, "Tree=")) {
auto strs = Common::Split(cur_line.c_str(), '=');
if (strs.size() == 1) {
key_vals[strs[0]] = "";
}
else if (strs.size() == 2) {
key_vals[strs[0]] = strs[1];
}
else if (strs.size() > 2) {
Log::Fatal("Wrong line at model file: %s", cur_line.c_str());
}
}
else {
break;
}
}
p += line_len;
p = Common::SkipNewLine(p);
}
// get number of classes
auto line = Common::FindFromLines(lines, "num_class=");
if (line.size() > 0) {
Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &num_class_);
if (key_vals.count("num_class")) {
Common::Atoi(key_vals["num_class"].c_str(), &num_class_);
} else {
Log::Fatal("Model file doesn't specify the number of classes");
return false;
}
line = Common::FindFromLines(lines, "num_tree_per_iteration=");
if (line.size() > 0) {
Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &num_tree_per_iteration_);
if (key_vals.count("num_tree_per_iteration")) {
Common::Atoi(key_vals["num_tree_per_iteration"].c_str(), &num_tree_per_iteration_);
} else {
num_tree_per_iteration_ = num_class_;
}
// get index of label
line = Common::FindFromLines(lines, "label_index=");
if (line.size() > 0) {
Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &label_idx_);
if (key_vals.count("label_index")) {
Common::Atoi(key_vals["label_index"].c_str(), &label_idx_);
} else {
Log::Fatal("Model file doesn't specify the label index");
return false;
}
// get max_feature_idx first
line = Common::FindFromLines(lines, "max_feature_idx=");
if (line.size() > 0) {
Common::Atoi(Common::Split(line.c_str(), '=')[1].c_str(), &max_feature_idx_);
if (key_vals.count("max_feature_idx")) {
Common::Atoi(key_vals["max_feature_idx"].c_str(), &max_feature_idx_);
} else {
Log::Fatal("Model file doesn't specify max_feature_idx");
return false;
}
// get average_output
line = Common::FindFromLines(lines, "average_output");
if (line.size() > 0) {
if (key_vals.count("average_output")) {
average_output_ = true;
}
// get feature names
line = Common::FindFromLines(lines, "feature_names=");
if (line.size() > 0) {
feature_names_ = Common::Split(line.substr(std::strlen("feature_names=")).c_str(), ' ');
if (key_vals.count("feature_names")) {
feature_names_ = Common::Split(key_vals["feature_names"].c_str(), ' ');
if (feature_names_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
Log::Fatal("Wrong size of feature_names");
return false;
......@@ -356,9 +392,8 @@ bool GBDT::LoadModelFromString(const std::string& model_str) {
return false;
}
line = Common::FindFromLines(lines, "feature_infos=");
if (line.size() > 0) {
feature_infos_ = Common::Split(line.substr(std::strlen("feature_infos=")).c_str(), ' ');
if (key_vals.count("feature_infos")) {
feature_infos_ = Common::Split(key_vals["feature_infos"].c_str(), ' ');
if (feature_infos_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
Log::Fatal("Wrong size of feature_infos");
return false;
......@@ -368,29 +403,57 @@ bool GBDT::LoadModelFromString(const std::string& model_str) {
return false;
}
line = Common::FindFromLines(lines, "objective=");
if (line.size() > 0) {
auto str = Common::Split(line.c_str(), '=')[1];
if (key_vals.count("objective")) {
auto str = key_vals["objective"];
loaded_objective_.reset(ObjectiveFunction::CreateObjectiveFunction(str));
objective_function_ = loaded_objective_.get();
}
// get tree models
size_t i = 0;
while (i < lines.size()) {
size_t find_pos = lines[i].find("Tree=");
if (find_pos != std::string::npos) {
++i;
int start = static_cast<int>(i);
while (i < lines.size() && lines[i].find("Tree=") == std::string::npos) { ++i; }
int end = static_cast<int>(i);
std::string tree_str = Common::Join<std::string>(lines, start, end, "\n");
models_.emplace_back(new Tree(tree_str));
if (!key_vals.count("tree_sizes")) {
while (p < end) {
auto line_len = Common::GetLine(p);
std::string cur_line(p, line_len);
if (line_len > 0) {
if (Common::StartsWith(cur_line, "Tree=")) {
p += line_len;
p = Common::SkipNewLine(p);
size_t used_len = 0;
models_.emplace_back(new Tree(p, &used_len));
p += used_len;
}
else {
break;
}
}
p = Common::SkipNewLine(p);
}
} else {
std::vector<size_t> tree_sizes = Common::StringToArray<size_t>(key_vals["tree_sizes"].c_str(), ' ');
std::vector<size_t> tree_boundries(tree_sizes.size() + 1, 0);
int num_trees = static_cast<int>(tree_sizes.size());
for (int i = 0; i < num_trees; ++i) {
tree_boundries[i + 1] = tree_boundries[i] + tree_sizes[i];
models_.emplace_back(nullptr);
}
OMP_INIT_EX();
#pragma omp parallel for schedule(static)
for (int i = 0; i < num_trees; ++i) {
OMP_LOOP_EX_BEGIN();
auto cur_p = p + tree_boundries[i];
auto line_len = Common::GetLine(cur_p);
std::string cur_line(cur_p, line_len);
if (Common::StartsWith(cur_line, "Tree=")) {
cur_p += line_len;
cur_p = Common::SkipNewLine(cur_p);
size_t used_len = 0;
models_[i].reset(new Tree(cur_p, &used_len));
} else {
++i;
Log::Fatal("Model format error, expect a tree here. met %s", cur_line.c_str());
}
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
}
Log::Info("Finished loading %d models", models_.size());
num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_tree_per_iteration_;
num_init_iteration_ = num_iteration_for_pred_;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment