Unverified commit 87d46489, authored by Guolin Ke, committed by GitHub
Browse files

feature importance type in saved model file (#3220)



* feature importance type in saved model file

* fix nullptr

* fixed formatting

* fix python/R

* Update src/c_api.cpp

* Apply suggestions from code review
Co-authored-by: James Lamb <jaylamb20@gmail.com>

* fix c_api test

* fix swig

* minor docs improvements and added defines for importance types
Co-authored-by: StrikerRUS <nekit94-12@hotmail.com>
Co-authored-by: James Lamb <jaylamb20@gmail.com>
parent 7b8b5151
...@@ -424,7 +424,7 @@ Booster <- R6::R6Class( ...@@ -424,7 +424,7 @@ Booster <- R6::R6Class(
}, },
# Save model # Save model
save_model = function(filename, num_iteration = NULL) { save_model = function(filename, num_iteration = NULL, feature_importance_type = 0L) {
# Check if number of iteration is non existent # Check if number of iteration is non existent
if (is.null(num_iteration)) { if (is.null(num_iteration)) {
...@@ -437,6 +437,7 @@ Booster <- R6::R6Class( ...@@ -437,6 +437,7 @@ Booster <- R6::R6Class(
, ret = NULL , ret = NULL
, private$handle , private$handle
, as.integer(num_iteration) , as.integer(num_iteration)
, as.integer(feature_importance_type)
, lgb.c_str(filename) , lgb.c_str(filename)
) )
...@@ -445,7 +446,7 @@ Booster <- R6::R6Class( ...@@ -445,7 +446,7 @@ Booster <- R6::R6Class(
}, },
# Save model to string # Save model to string
save_model_to_string = function(num_iteration = NULL) { save_model_to_string = function(num_iteration = NULL, feature_importance_type = 0L) {
# Check if number of iteration is non existent # Check if number of iteration is non existent
if (is.null(num_iteration)) { if (is.null(num_iteration)) {
...@@ -457,12 +458,13 @@ Booster <- R6::R6Class( ...@@ -457,12 +458,13 @@ Booster <- R6::R6Class(
"LGBM_BoosterSaveModelToString_R" "LGBM_BoosterSaveModelToString_R"
, private$handle , private$handle
, as.integer(num_iteration) , as.integer(num_iteration)
, as.integer(feature_importance_type)
)) ))
}, },
# Dump model in memory # Dump model in memory
dump_model = function(num_iteration = NULL) { dump_model = function(num_iteration = NULL, feature_importance_type = 0L) {
# Check if number of iteration is non existent # Check if number of iteration is non existent
if (is.null(num_iteration)) { if (is.null(num_iteration)) {
...@@ -474,6 +476,7 @@ Booster <- R6::R6Class( ...@@ -474,6 +476,7 @@ Booster <- R6::R6Class(
"LGBM_BoosterDumpModel_R" "LGBM_BoosterDumpModel_R"
, private$handle , private$handle
, as.integer(num_iteration) , as.integer(num_iteration)
, as.integer(feature_importance_type)
) )
}, },
......
...@@ -632,15 +632,17 @@ LGBM_SE LGBM_BoosterPredictForMat_R(LGBM_SE handle, ...@@ -632,15 +632,17 @@ LGBM_SE LGBM_BoosterPredictForMat_R(LGBM_SE handle,
LGBM_SE LGBM_BoosterSaveModel_R(LGBM_SE handle, LGBM_SE LGBM_BoosterSaveModel_R(LGBM_SE handle,
LGBM_SE num_iteration, LGBM_SE num_iteration,
LGBM_SE feature_importance_type,
LGBM_SE filename, LGBM_SE filename,
LGBM_SE call_state) { LGBM_SE call_state) {
R_API_BEGIN(); R_API_BEGIN();
CHECK_CALL(LGBM_BoosterSaveModel(R_GET_PTR(handle), 0, R_AS_INT(num_iteration), R_CHAR_PTR(filename))); CHECK_CALL(LGBM_BoosterSaveModel(R_GET_PTR(handle), 0, R_AS_INT(num_iteration), R_AS_INT(feature_importance_type), R_CHAR_PTR(filename)));
R_API_END(); R_API_END();
} }
LGBM_SE LGBM_BoosterSaveModelToString_R(LGBM_SE handle, LGBM_SE LGBM_BoosterSaveModelToString_R(LGBM_SE handle,
LGBM_SE num_iteration, LGBM_SE num_iteration,
LGBM_SE feature_importance_type,
LGBM_SE buffer_len, LGBM_SE buffer_len,
LGBM_SE actual_len, LGBM_SE actual_len,
LGBM_SE out_str, LGBM_SE out_str,
...@@ -648,13 +650,14 @@ LGBM_SE LGBM_BoosterSaveModelToString_R(LGBM_SE handle, ...@@ -648,13 +650,14 @@ LGBM_SE LGBM_BoosterSaveModelToString_R(LGBM_SE handle,
R_API_BEGIN(); R_API_BEGIN();
int64_t out_len = 0; int64_t out_len = 0;
std::vector<char> inner_char_buf(R_AS_INT(buffer_len)); std::vector<char> inner_char_buf(R_AS_INT(buffer_len));
CHECK_CALL(LGBM_BoosterSaveModelToString(R_GET_PTR(handle), 0, R_AS_INT(num_iteration), R_AS_INT(buffer_len), &out_len, inner_char_buf.data())); CHECK_CALL(LGBM_BoosterSaveModelToString(R_GET_PTR(handle), 0, R_AS_INT(num_iteration), R_AS_INT(feature_importance_type), R_AS_INT(buffer_len), &out_len, inner_char_buf.data()));
EncodeChar(out_str, inner_char_buf.data(), buffer_len, actual_len, static_cast<size_t>(out_len)); EncodeChar(out_str, inner_char_buf.data(), buffer_len, actual_len, static_cast<size_t>(out_len));
R_API_END(); R_API_END();
} }
LGBM_SE LGBM_BoosterDumpModel_R(LGBM_SE handle, LGBM_SE LGBM_BoosterDumpModel_R(LGBM_SE handle,
LGBM_SE num_iteration, LGBM_SE num_iteration,
LGBM_SE feature_importance_type,
LGBM_SE buffer_len, LGBM_SE buffer_len,
LGBM_SE actual_len, LGBM_SE actual_len,
LGBM_SE out_str, LGBM_SE out_str,
...@@ -662,7 +665,7 @@ LGBM_SE LGBM_BoosterDumpModel_R(LGBM_SE handle, ...@@ -662,7 +665,7 @@ LGBM_SE LGBM_BoosterDumpModel_R(LGBM_SE handle,
R_API_BEGIN(); R_API_BEGIN();
int64_t out_len = 0; int64_t out_len = 0;
std::vector<char> inner_char_buf(R_AS_INT(buffer_len)); std::vector<char> inner_char_buf(R_AS_INT(buffer_len));
CHECK_CALL(LGBM_BoosterDumpModel(R_GET_PTR(handle), 0, R_AS_INT(num_iteration), R_AS_INT(buffer_len), &out_len, inner_char_buf.data())); CHECK_CALL(LGBM_BoosterDumpModel(R_GET_PTR(handle), 0, R_AS_INT(num_iteration), R_AS_INT(feature_importance_type), R_AS_INT(buffer_len), &out_len, inner_char_buf.data()));
EncodeChar(out_str, inner_char_buf.data(), buffer_len, actual_len, static_cast<size_t>(out_len)); EncodeChar(out_str, inner_char_buf.data(), buffer_len, actual_len, static_cast<size_t>(out_len));
R_API_END(); R_API_END();
} }
...@@ -707,9 +710,9 @@ static const R_CallMethodDef CallEntries[] = { ...@@ -707,9 +710,9 @@ static const R_CallMethodDef CallEntries[] = {
{"LGBM_BoosterCalcNumPredict_R" , (DL_FUNC) &LGBM_BoosterCalcNumPredict_R , 8}, {"LGBM_BoosterCalcNumPredict_R" , (DL_FUNC) &LGBM_BoosterCalcNumPredict_R , 8},
{"LGBM_BoosterPredictForCSC_R" , (DL_FUNC) &LGBM_BoosterPredictForCSC_R , 14}, {"LGBM_BoosterPredictForCSC_R" , (DL_FUNC) &LGBM_BoosterPredictForCSC_R , 14},
{"LGBM_BoosterPredictForMat_R" , (DL_FUNC) &LGBM_BoosterPredictForMat_R , 11}, {"LGBM_BoosterPredictForMat_R" , (DL_FUNC) &LGBM_BoosterPredictForMat_R , 11},
{"LGBM_BoosterSaveModel_R" , (DL_FUNC) &LGBM_BoosterSaveModel_R , 4}, {"LGBM_BoosterSaveModel_R" , (DL_FUNC) &LGBM_BoosterSaveModel_R , 5},
{"LGBM_BoosterSaveModelToString_R" , (DL_FUNC) &LGBM_BoosterSaveModelToString_R , 6}, {"LGBM_BoosterSaveModelToString_R" , (DL_FUNC) &LGBM_BoosterSaveModelToString_R , 7},
{"LGBM_BoosterDumpModel_R" , (DL_FUNC) &LGBM_BoosterDumpModel_R , 6}, {"LGBM_BoosterDumpModel_R" , (DL_FUNC) &LGBM_BoosterDumpModel_R , 7},
{NULL, NULL, 0} {NULL, NULL, 0}
}; };
......
...@@ -590,6 +590,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterPredictForMat_R( ...@@ -590,6 +590,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterPredictForMat_R(
LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterSaveModel_R( LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterSaveModel_R(
LGBM_SE handle, LGBM_SE handle,
LGBM_SE num_iteration, LGBM_SE num_iteration,
LGBM_SE feature_importance_type,
LGBM_SE filename, LGBM_SE filename,
LGBM_SE call_state LGBM_SE call_state
); );
...@@ -604,6 +605,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterSaveModel_R( ...@@ -604,6 +605,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterSaveModel_R(
LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterSaveModelToString_R( LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterSaveModelToString_R(
LGBM_SE handle, LGBM_SE handle,
LGBM_SE num_iteration, LGBM_SE num_iteration,
LGBM_SE feature_importance_type,
LGBM_SE buffer_len, LGBM_SE buffer_len,
LGBM_SE actual_len, LGBM_SE actual_len,
LGBM_SE out_str, LGBM_SE out_str,
...@@ -620,6 +622,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterSaveModelToString_R( ...@@ -620,6 +622,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterSaveModelToString_R(
LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterDumpModel_R( LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterDumpModel_R(
LGBM_SE handle, LGBM_SE handle,
LGBM_SE num_iteration, LGBM_SE num_iteration,
LGBM_SE feature_importance_type,
LGBM_SE buffer_len, LGBM_SE buffer_len,
LGBM_SE actual_len, LGBM_SE actual_len,
LGBM_SE out_str, LGBM_SE out_str,
......
...@@ -574,6 +574,14 @@ Learning Control Parameters ...@@ -574,6 +574,14 @@ Learning Control Parameters
- **Note**: can be used only in CLI version - **Note**: can be used only in CLI version
- ``saved_feature_importance_type`` :raw-html:`<a id="saved_feature_importance_type" title="Permalink to this parameter" href="#saved_feature_importance_type">&#x1F517;&#xFE0E;</a>`, default = ``0``, type = int
- the feature importance type in the saved model file
- ``0``: count-based feature importance (numbers of splits are counted); ``1``: gain-based feature importance (values of gain are counted)
- **Note**: can be used only in CLI version
- ``snapshot_freq`` :raw-html:`<a id="snapshot_freq" title="Permalink to this parameter" href="#snapshot_freq">&#x1F517;&#xFE0E;</a>`, default = ``-1``, type = int, aliases: ``save_period`` - ``snapshot_freq`` :raw-html:`<a id="snapshot_freq" title="Permalink to this parameter" href="#snapshot_freq">&#x1F517;&#xFE0E;</a>`, default = ``-1``, type = int, aliases: ``save_period``
- frequency of saving model file snapshot - frequency of saving model file snapshot
......
...@@ -176,9 +176,10 @@ class LIGHTGBM_EXPORT Boosting { ...@@ -176,9 +176,10 @@ class LIGHTGBM_EXPORT Boosting {
* \brief Dump model to json format string * \brief Dump model to json format string
* \param start_iteration The model will be saved start from * \param start_iteration The model will be saved start from
* \param num_iteration Number of iterations that want to dump, -1 means dump all * \param num_iteration Number of iterations that want to dump, -1 means dump all
* \param feature_importance_type Type of feature importance, 0: split, 1: gain
* \return Json format string of model * \return Json format string of model
*/ */
virtual std::string DumpModel(int start_iteration, int num_iteration) const = 0; virtual std::string DumpModel(int start_iteration, int num_iteration, int feature_importance_type) const = 0;
/*! /*!
* \brief Translate model to if-else statement * \brief Translate model to if-else statement
...@@ -199,19 +200,20 @@ class LIGHTGBM_EXPORT Boosting { ...@@ -199,19 +200,20 @@ class LIGHTGBM_EXPORT Boosting {
* \brief Save model to file * \brief Save model to file
* \param start_iteration The model will be saved start from * \param start_iteration The model will be saved start from
* \param num_iterations Number of model that want to save, -1 means save all * \param num_iterations Number of model that want to save, -1 means save all
* \param is_finish Is training finished or not * \param feature_importance_type Type of feature importance, 0: split, 1: gain
* \param filename Filename that want to save to * \param filename Filename that want to save to
* \return true if succeeded * \return true if succeeded
*/ */
virtual bool SaveModelToFile(int start_iteration, int num_iterations, const char* filename) const = 0; virtual bool SaveModelToFile(int start_iteration, int num_iterations, int feature_importance_type, const char* filename) const = 0;
/*! /*!
* \brief Save model to string * \brief Save model to string
* \param start_iteration The model will be saved start from * \param start_iteration The model will be saved start from
* \param num_iterations Number of model that want to save, -1 means save all * \param num_iterations Number of model that want to save, -1 means save all
* \param feature_importance_type Type of feature importance, 0: split, 1: gain
* \return Non-empty string if succeeded * \return Non-empty string if succeeded
*/ */
virtual std::string SaveModelToString(int start_iteration, int num_iterations) const = 0; virtual std::string SaveModelToString(int start_iteration, int num_iterations, int feature_importance_type) const = 0;
/*! /*!
* \brief Restore from a serialized string * \brief Restore from a serialized string
......
...@@ -36,6 +36,9 @@ typedef void* BoosterHandle; /*!< \brief Handle of booster. */ ...@@ -36,6 +36,9 @@ typedef void* BoosterHandle; /*!< \brief Handle of booster. */
#define C_API_MATRIX_TYPE_CSR (0) /*!< \brief CSR sparse matrix type. */ #define C_API_MATRIX_TYPE_CSR (0) /*!< \brief CSR sparse matrix type. */
#define C_API_MATRIX_TYPE_CSC (1) /*!< \brief CSC sparse matrix type. */ #define C_API_MATRIX_TYPE_CSC (1) /*!< \brief CSC sparse matrix type. */
#define C_API_FEATURE_IMPORTANCE_SPLIT (0) /*!< \brief Split type of feature importance. */
#define C_API_FEATURE_IMPORTANCE_GAIN (1) /*!< \brief Gain type of feature importance. */
/*! /*!
* \brief Get string message of the last error. * \brief Get string message of the last error.
* \return Error information * \return Error information
...@@ -996,12 +999,14 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMats(BoosterHandle handle, ...@@ -996,12 +999,14 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMats(BoosterHandle handle,
* \param handle Handle of booster * \param handle Handle of booster
* \param start_iteration Start index of the iteration that should be saved * \param start_iteration Start index of the iteration that should be saved
* \param num_iteration Index of the iteration that should be saved, <= 0 means save all * \param num_iteration Index of the iteration that should be saved, <= 0 means save all
* \param feature_importance_type Type of feature importance, can be ``C_API_FEATURE_IMPORTANCE_SPLIT`` or ``C_API_FEATURE_IMPORTANCE_GAIN``
* \param filename The name of the file * \param filename The name of the file
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle,
int start_iteration, int start_iteration,
int num_iteration, int num_iteration,
int feature_importance_type,
const char* filename); const char* filename);
/*! /*!
...@@ -1009,6 +1014,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle, ...@@ -1009,6 +1014,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle,
* \param handle Handle of booster * \param handle Handle of booster
* \param start_iteration Start index of the iteration that should be saved * \param start_iteration Start index of the iteration that should be saved
* \param num_iteration Index of the iteration that should be saved, <= 0 means save all * \param num_iteration Index of the iteration that should be saved, <= 0 means save all
* \param feature_importance_type Type of feature importance, can be ``C_API_FEATURE_IMPORTANCE_SPLIT`` or ``C_API_FEATURE_IMPORTANCE_GAIN``
* \param buffer_len String buffer length, if ``buffer_len < out_len``, you should re-allocate buffer * \param buffer_len String buffer length, if ``buffer_len < out_len``, you should re-allocate buffer
* \param[out] out_len Actual output length * \param[out] out_len Actual output length
* \param[out] out_str String of model, should pre-allocate memory * \param[out] out_str String of model, should pre-allocate memory
...@@ -1017,6 +1023,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle, ...@@ -1017,6 +1023,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle,
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle,
int start_iteration, int start_iteration,
int num_iteration, int num_iteration,
int feature_importance_type,
int64_t buffer_len, int64_t buffer_len,
int64_t* out_len, int64_t* out_len,
char* out_str); char* out_str);
...@@ -1026,6 +1033,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle, ...@@ -1026,6 +1033,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle,
* \param handle Handle of booster * \param handle Handle of booster
* \param start_iteration Start index of the iteration that should be dumped * \param start_iteration Start index of the iteration that should be dumped
* \param num_iteration Index of the iteration that should be dumped, <= 0 means dump all * \param num_iteration Index of the iteration that should be dumped, <= 0 means dump all
* \param feature_importance_type Type of feature importance, can be ``C_API_FEATURE_IMPORTANCE_SPLIT`` or ``C_API_FEATURE_IMPORTANCE_GAIN``
* \param buffer_len String buffer length, if ``buffer_len < out_len``, you should re-allocate buffer * \param buffer_len String buffer length, if ``buffer_len < out_len``, you should re-allocate buffer
* \param[out] out_len Actual output length * \param[out] out_len Actual output length
* \param[out] out_str JSON format string of model, should pre-allocate memory * \param[out] out_str JSON format string of model, should pre-allocate memory
...@@ -1034,6 +1042,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle, ...@@ -1034,6 +1042,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle,
LIGHTGBM_C_EXPORT int LGBM_BoosterDumpModel(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterDumpModel(BoosterHandle handle,
int start_iteration, int start_iteration,
int num_iteration, int num_iteration,
int feature_importance_type,
int64_t buffer_len, int64_t buffer_len,
int64_t* out_len, int64_t* out_len,
char* out_str); char* out_str);
...@@ -1069,8 +1078,8 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSetLeafValue(BoosterHandle handle, ...@@ -1069,8 +1078,8 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSetLeafValue(BoosterHandle handle,
* \param handle Handle of booster * \param handle Handle of booster
* \param num_iteration Number of iterations for which feature importance is calculated, <= 0 means use all * \param num_iteration Number of iterations for which feature importance is calculated, <= 0 means use all
* \param importance_type Method of importance calculation: * \param importance_type Method of importance calculation:
* - 0 for split, result contains numbers of times the feature is used in a model; * - ``C_API_FEATURE_IMPORTANCE_SPLIT``: result contains numbers of times the feature is used in a model;
* - 1 for gain, result contains total gains of splits which use the feature * - ``C_API_FEATURE_IMPORTANCE_GAIN``: result contains total gains of splits which use the feature
* \param[out] out_results Result array with feature importance * \param[out] out_results Result array with feature importance
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
......
...@@ -532,6 +532,11 @@ struct Config { ...@@ -532,6 +532,11 @@ struct Config {
// desc = **Note**: can be used only in CLI version // desc = **Note**: can be used only in CLI version
std::string output_model = "LightGBM_model.txt"; std::string output_model = "LightGBM_model.txt";
// desc = the feature importance type in the saved model file
// desc = ``0``: count-based feature importance (numbers of splits are counted); ``1``: gain-based feature importance (values of gain are counted)
// desc = **Note**: can be used only in CLI version
int saved_feature_importance_type = 0;
// [no-save] // [no-save]
// alias = save_period // alias = save_period
// desc = frequency of saving model file snapshot // desc = frequency of saving model file snapshot
......
...@@ -284,12 +284,20 @@ C_API_PREDICT_CONTRIB = 3 ...@@ -284,12 +284,20 @@ C_API_PREDICT_CONTRIB = 3
C_API_MATRIX_TYPE_CSR = 0 C_API_MATRIX_TYPE_CSR = 0
C_API_MATRIX_TYPE_CSC = 1 C_API_MATRIX_TYPE_CSC = 1
"""Macro definition of feature importance type"""
C_API_FEATURE_IMPORTANCE_SPLIT = 0
C_API_FEATURE_IMPORTANCE_GAIN = 1
"""Data type of data field""" """Data type of data field"""
FIELD_TYPE_MAPPER = {"label": C_API_DTYPE_FLOAT32, FIELD_TYPE_MAPPER = {"label": C_API_DTYPE_FLOAT32,
"weight": C_API_DTYPE_FLOAT32, "weight": C_API_DTYPE_FLOAT32,
"init_score": C_API_DTYPE_FLOAT64, "init_score": C_API_DTYPE_FLOAT64,
"group": C_API_DTYPE_INT32} "group": C_API_DTYPE_INT32}
"""String name to int feature importance type mapper"""
FEATURE_IMPORTANCE_TYPE_MAPPER = {"split": C_API_FEATURE_IMPORTANCE_SPLIT,
"gain": C_API_FEATURE_IMPORTANCE_GAIN}
def convert_from_sliced_object(data): def convert_from_sliced_object(data):
"""Fix the memory of multi-dimensional sliced object.""" """Fix the memory of multi-dimensional sliced object."""
...@@ -2600,7 +2608,7 @@ class Booster(object): ...@@ -2600,7 +2608,7 @@ class Booster(object):
return [item for i in range_(1, self.__num_dataset) return [item for i in range_(1, self.__num_dataset)
for item in self.__inner_eval(self.name_valid_sets[i - 1], i, feval)] for item in self.__inner_eval(self.name_valid_sets[i - 1], i, feval)]
def save_model(self, filename, num_iteration=None, start_iteration=0): def save_model(self, filename, num_iteration=None, start_iteration=0, importance_type='split'):
"""Save Booster to file. """Save Booster to file.
Parameters Parameters
...@@ -2613,6 +2621,10 @@ class Booster(object): ...@@ -2613,6 +2621,10 @@ class Booster(object):
If <= 0, all iterations are saved. If <= 0, all iterations are saved.
start_iteration : int, optional (default=0) start_iteration : int, optional (default=0)
Start index of the iteration that should be saved. Start index of the iteration that should be saved.
importance_type : string, optional (default="split")
What type of feature importance should be saved.
If "split", result contains numbers of times the feature is used in a model.
If "gain", result contains total gains of splits which use the feature.
Returns Returns
------- -------
...@@ -2621,10 +2633,12 @@ class Booster(object): ...@@ -2621,10 +2633,12 @@ class Booster(object):
""" """
if num_iteration is None: if num_iteration is None:
num_iteration = self.best_iteration num_iteration = self.best_iteration
importance_type_int = FEATURE_IMPORTANCE_TYPE_MAPPER[importance_type]
_safe_call(_LIB.LGBM_BoosterSaveModel( _safe_call(_LIB.LGBM_BoosterSaveModel(
self.handle, self.handle,
ctypes.c_int(start_iteration), ctypes.c_int(start_iteration),
ctypes.c_int(num_iteration), ctypes.c_int(num_iteration),
ctypes.c_int(importance_type_int),
c_str(filename))) c_str(filename)))
_dump_pandas_categorical(self.pandas_categorical, filename) _dump_pandas_categorical(self.pandas_categorical, filename)
return self return self
...@@ -2685,7 +2699,7 @@ class Booster(object): ...@@ -2685,7 +2699,7 @@ class Booster(object):
self.pandas_categorical = _load_pandas_categorical(model_str=model_str) self.pandas_categorical = _load_pandas_categorical(model_str=model_str)
return self return self
def model_to_string(self, num_iteration=None, start_iteration=0): def model_to_string(self, num_iteration=None, start_iteration=0, importance_type='split'):
"""Save Booster to string. """Save Booster to string.
Parameters Parameters
...@@ -2696,6 +2710,10 @@ class Booster(object): ...@@ -2696,6 +2710,10 @@ class Booster(object):
If <= 0, all iterations are saved. If <= 0, all iterations are saved.
start_iteration : int, optional (default=0) start_iteration : int, optional (default=0)
Start index of the iteration that should be saved. Start index of the iteration that should be saved.
importance_type : string, optional (default="split")
What type of feature importance should be saved.
If "split", result contains numbers of times the feature is used in a model.
If "gain", result contains total gains of splits which use the feature.
Returns Returns
------- -------
...@@ -2704,6 +2722,7 @@ class Booster(object): ...@@ -2704,6 +2722,7 @@ class Booster(object):
""" """
if num_iteration is None: if num_iteration is None:
num_iteration = self.best_iteration num_iteration = self.best_iteration
importance_type_int = FEATURE_IMPORTANCE_TYPE_MAPPER[importance_type]
buffer_len = 1 << 20 buffer_len = 1 << 20
tmp_out_len = ctypes.c_int64(0) tmp_out_len = ctypes.c_int64(0)
string_buffer = ctypes.create_string_buffer(buffer_len) string_buffer = ctypes.create_string_buffer(buffer_len)
...@@ -2712,6 +2731,7 @@ class Booster(object): ...@@ -2712,6 +2731,7 @@ class Booster(object):
self.handle, self.handle,
ctypes.c_int(start_iteration), ctypes.c_int(start_iteration),
ctypes.c_int(num_iteration), ctypes.c_int(num_iteration),
ctypes.c_int(importance_type_int),
ctypes.c_int64(buffer_len), ctypes.c_int64(buffer_len),
ctypes.byref(tmp_out_len), ctypes.byref(tmp_out_len),
ptr_string_buffer)) ptr_string_buffer))
...@@ -2724,6 +2744,7 @@ class Booster(object): ...@@ -2724,6 +2744,7 @@ class Booster(object):
self.handle, self.handle,
ctypes.c_int(start_iteration), ctypes.c_int(start_iteration),
ctypes.c_int(num_iteration), ctypes.c_int(num_iteration),
ctypes.c_int(importance_type_int),
ctypes.c_int64(actual_len), ctypes.c_int64(actual_len),
ctypes.byref(tmp_out_len), ctypes.byref(tmp_out_len),
ptr_string_buffer)) ptr_string_buffer))
...@@ -2731,7 +2752,7 @@ class Booster(object): ...@@ -2731,7 +2752,7 @@ class Booster(object):
ret += _dump_pandas_categorical(self.pandas_categorical) ret += _dump_pandas_categorical(self.pandas_categorical)
return ret return ret
def dump_model(self, num_iteration=None, start_iteration=0): def dump_model(self, num_iteration=None, start_iteration=0, importance_type='split'):
"""Dump Booster to JSON format. """Dump Booster to JSON format.
Parameters Parameters
...@@ -2742,6 +2763,10 @@ class Booster(object): ...@@ -2742,6 +2763,10 @@ class Booster(object):
If <= 0, all iterations are dumped. If <= 0, all iterations are dumped.
start_iteration : int, optional (default=0) start_iteration : int, optional (default=0)
Start index of the iteration that should be dumped. Start index of the iteration that should be dumped.
importance_type : string, optional (default="split")
What type of feature importance should be dumped.
If "split", result contains numbers of times the feature is used in a model.
If "gain", result contains total gains of splits which use the feature.
Returns Returns
------- -------
...@@ -2750,6 +2775,7 @@ class Booster(object): ...@@ -2750,6 +2775,7 @@ class Booster(object):
""" """
if num_iteration is None: if num_iteration is None:
num_iteration = self.best_iteration num_iteration = self.best_iteration
importance_type_int = FEATURE_IMPORTANCE_TYPE_MAPPER[importance_type]
buffer_len = 1 << 20 buffer_len = 1 << 20
tmp_out_len = ctypes.c_int64(0) tmp_out_len = ctypes.c_int64(0)
string_buffer = ctypes.create_string_buffer(buffer_len) string_buffer = ctypes.create_string_buffer(buffer_len)
...@@ -2758,6 +2784,7 @@ class Booster(object): ...@@ -2758,6 +2784,7 @@ class Booster(object):
self.handle, self.handle,
ctypes.c_int(start_iteration), ctypes.c_int(start_iteration),
ctypes.c_int(num_iteration), ctypes.c_int(num_iteration),
ctypes.c_int(importance_type_int),
ctypes.c_int64(buffer_len), ctypes.c_int64(buffer_len),
ctypes.byref(tmp_out_len), ctypes.byref(tmp_out_len),
ptr_string_buffer)) ptr_string_buffer))
...@@ -2770,6 +2797,7 @@ class Booster(object): ...@@ -2770,6 +2797,7 @@ class Booster(object):
self.handle, self.handle,
ctypes.c_int(start_iteration), ctypes.c_int(start_iteration),
ctypes.c_int(num_iteration), ctypes.c_int(num_iteration),
ctypes.c_int(importance_type_int),
ctypes.c_int64(actual_len), ctypes.c_int64(actual_len),
ctypes.byref(tmp_out_len), ctypes.byref(tmp_out_len),
ptr_string_buffer)) ptr_string_buffer))
...@@ -2969,12 +2997,7 @@ class Booster(object): ...@@ -2969,12 +2997,7 @@ class Booster(object):
""" """
if iteration is None: if iteration is None:
iteration = self.best_iteration iteration = self.best_iteration
if importance_type == "split": importance_type_int = FEATURE_IMPORTANCE_TYPE_MAPPER[importance_type]
importance_type_int = 0
elif importance_type == "gain":
importance_type_int = 1
else:
importance_type_int = -1
result = np.zeros(self.num_feature(), dtype=np.float64) result = np.zeros(self.num_feature(), dtype=np.float64)
_safe_call(_LIB.LGBM_BoosterFeatureImportance( _safe_call(_LIB.LGBM_BoosterFeatureImportance(
self.handle, self.handle,
......
...@@ -201,7 +201,8 @@ void Application::InitTrain() { ...@@ -201,7 +201,8 @@ void Application::InitTrain() {
void Application::Train() { void Application::Train() {
Log::Info("Started training..."); Log::Info("Started training...");
boosting_->Train(config_.snapshot_freq, config_.output_model); boosting_->Train(config_.snapshot_freq, config_.output_model);
boosting_->SaveModelToFile(0, -1, config_.output_model.c_str()); boosting_->SaveModelToFile(0, -1, config_.saved_feature_importance_type,
config_.output_model.c_str());
// convert model to if-else statement code // convert model to if-else statement code
if (config_.convert_model_language == std::string("cpp")) { if (config_.convert_model_language == std::string("cpp")) {
boosting_->SaveModelToIfElse(-1, config_.convert_model.c_str()); boosting_->SaveModelToIfElse(-1, config_.convert_model.c_str());
...@@ -233,7 +234,8 @@ void Application::Predict() { ...@@ -233,7 +234,8 @@ void Application::Predict() {
boosting_->Init(&config_, train_data_.get(), objective_fun_.get(), boosting_->Init(&config_, train_data_.get(), objective_fun_.get(),
Common::ConstPtrInVectorWrapper<Metric>(train_metric_)); Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
boosting_->RefitTree(pred_leaf); boosting_->RefitTree(pred_leaf);
boosting_->SaveModelToFile(0, -1, config_.output_model.c_str()); boosting_->SaveModelToFile(0, -1, config_.saved_feature_importance_type,
config_.output_model.c_str());
Log::Info("Finished RefitTree"); Log::Info("Finished RefitTree");
} else { } else {
// create predictor // create predictor
......
...@@ -258,7 +258,7 @@ void GBDT::Train(int snapshot_freq, const std::string& model_output_path) { ...@@ -258,7 +258,7 @@ void GBDT::Train(int snapshot_freq, const std::string& model_output_path) {
if (snapshot_freq > 0 if (snapshot_freq > 0
&& (iter + 1) % snapshot_freq == 0) { && (iter + 1) % snapshot_freq == 0) {
std::string snapshot_out = model_output_path + ".snapshot_iter_" + std::to_string(iter + 1); std::string snapshot_out = model_output_path + ".snapshot_iter_" + std::to_string(iter + 1);
SaveModelToFile(0, -1, snapshot_out.c_str()); SaveModelToFile(0, -1, config_->saved_feature_importance_type, snapshot_out.c_str());
} }
} }
} }
......
...@@ -249,9 +249,11 @@ class GBDT : public GBDTBase { ...@@ -249,9 +249,11 @@ class GBDT : public GBDTBase {
* \brief Dump model to json format string * \brief Dump model to json format string
* \param start_iteration The model will be saved start from * \param start_iteration The model will be saved start from
* \param num_iteration Number of iterations that want to dump, -1 means dump all * \param num_iteration Number of iterations that want to dump, -1 means dump all
* \param feature_importance_type Type of feature importance, 0: split, 1: gain
* \return Json format string of model * \return Json format string of model
*/ */
std::string DumpModel(int start_iteration, int num_iteration) const override; std::string DumpModel(int start_iteration, int num_iteration,
int feature_importance_type) const override;
/*! /*!
* \brief Translate model to if-else statement * \brief Translate model to if-else statement
...@@ -272,18 +274,22 @@ class GBDT : public GBDTBase { ...@@ -272,18 +274,22 @@ class GBDT : public GBDTBase {
* \brief Save model to file * \brief Save model to file
* \param start_iteration The model will be saved start from * \param start_iteration The model will be saved start from
* \param num_iterations Number of model that want to save, -1 means save all * \param num_iterations Number of model that want to save, -1 means save all
* \param feature_importance_type Type of feature importance, 0: split, 1: gain
* \param filename Filename that want to save to * \param filename Filename that want to save to
* \return is_finish Is training finished or not * \return is_finish Is training finished or not
*/ */
bool SaveModelToFile(int start_iteration, int num_iterations, const char* filename) const override; bool SaveModelToFile(int start_iteration, int num_iterations,
int feature_importance_type,
const char* filename) const override;
/*! /*!
* \brief Save model to string * \brief Save model to string
* \param start_iteration The model will be saved start from * \param start_iteration The model will be saved start from
* \param num_iterations Number of model that want to save, -1 means save all * \param num_iterations Number of model that want to save, -1 means save all
* \param feature_importance_type Type of feature importance, 0: split, 1: gain
* \return Non-empty string if succeeded * \return Non-empty string if succeeded
*/ */
std::string SaveModelToString(int start_iteration, int num_iterations) const override; std::string SaveModelToString(int start_iteration, int num_iterations, int feature_importance_type) const override;
/*! /*!
* \brief Restore from a serialized buffer * \brief Restore from a serialized buffer
......
...@@ -18,7 +18,7 @@ namespace LightGBM { ...@@ -18,7 +18,7 @@ namespace LightGBM {
const char* kModelVersion = "v3"; const char* kModelVersion = "v3";
std::string GBDT::DumpModel(int start_iteration, int num_iteration) const { std::string GBDT::DumpModel(int start_iteration, int num_iteration, int feature_importance_type) const {
std::stringstream str_buf; std::stringstream str_buf;
str_buf << "{"; str_buf << "{";
...@@ -95,7 +95,8 @@ std::string GBDT::DumpModel(int start_iteration, int num_iteration) const { ...@@ -95,7 +95,8 @@ std::string GBDT::DumpModel(int start_iteration, int num_iteration) const {
} }
str_buf << "]," << '\n'; str_buf << "]," << '\n';
std::vector<double> feature_importances = FeatureImportance(num_iteration, 0); std::vector<double> feature_importances = FeatureImportance(
num_iteration, feature_importance_type);
// store the importance first // store the importance first
std::vector<std::pair<size_t, std::string>> pairs; std::vector<std::pair<size_t, std::string>> pairs;
for (size_t i = 0; i < feature_importances.size(); ++i) { for (size_t i = 0; i < feature_importances.size(); ++i) {
...@@ -302,7 +303,7 @@ bool GBDT::SaveModelToIfElse(int num_iteration, const char* filename) const { ...@@ -302,7 +303,7 @@ bool GBDT::SaveModelToIfElse(int num_iteration, const char* filename) const {
return static_cast<bool>(output_file); return static_cast<bool>(output_file);
} }
std::string GBDT::SaveModelToString(int start_iteration, int num_iteration) const { std::string GBDT::SaveModelToString(int start_iteration, int num_iteration, int feature_importance_type) const {
std::stringstream ss; std::stringstream ss;
// output model type // output model type
...@@ -363,8 +364,8 @@ std::string GBDT::SaveModelToString(int start_iteration, int num_iteration) cons ...@@ -363,8 +364,8 @@ std::string GBDT::SaveModelToString(int start_iteration, int num_iteration) cons
tree_strs[i].clear(); tree_strs[i].clear();
} }
ss << "end of trees" << "\n"; ss << "end of trees" << "\n";
std::vector<double> feature_importances = FeatureImportance(
std::vector<double> feature_importances = FeatureImportance(num_iteration, 0); num_iteration, feature_importance_type);
// store the importance first // store the importance first
std::vector<std::pair<size_t, std::string>> pairs; std::vector<std::pair<size_t, std::string>> pairs;
for (size_t i = 0; i < feature_importances.size(); ++i) { for (size_t i = 0; i < feature_importances.size(); ++i) {
...@@ -395,11 +396,11 @@ std::string GBDT::SaveModelToString(int start_iteration, int num_iteration) cons ...@@ -395,11 +396,11 @@ std::string GBDT::SaveModelToString(int start_iteration, int num_iteration) cons
return ss.str(); return ss.str();
} }
bool GBDT::SaveModelToFile(int start_iteration, int num_iteration, const char* filename) const { bool GBDT::SaveModelToFile(int start_iteration, int num_iteration, int feature_importance_type, const char* filename) const {
/*! \brief File to write models */ /*! \brief File to write models */
std::ofstream output_file; std::ofstream output_file;
output_file.open(filename, std::ios::out | std::ios::binary); output_file.open(filename, std::ios::out | std::ios::binary);
std::string str_to_write = SaveModelToString(start_iteration, num_iteration); std::string str_to_write = SaveModelToString(start_iteration, num_iteration, feature_importance_type);
output_file.write(str_to_write.c_str(), str_to_write.size()); output_file.write(str_to_write.c_str(), str_to_write.size());
output_file.close(); output_file.close();
......
...@@ -689,8 +689,8 @@ class Booster { ...@@ -689,8 +689,8 @@ class Booster {
boosting_->GetPredictAt(data_idx, out_result, out_len); boosting_->GetPredictAt(data_idx, out_result, out_len);
} }
void SaveModelToFile(int start_iteration, int num_iteration, const char* filename) { void SaveModelToFile(int start_iteration, int num_iteration, int feature_importance_type, const char* filename) {
boosting_->SaveModelToFile(start_iteration, num_iteration, filename); boosting_->SaveModelToFile(start_iteration, num_iteration, feature_importance_type, filename);
} }
void LoadModelFromString(const char* model_str) { void LoadModelFromString(const char* model_str) {
...@@ -698,12 +698,16 @@ class Booster { ...@@ -698,12 +698,16 @@ class Booster {
boosting_->LoadModelFromString(model_str, len); boosting_->LoadModelFromString(model_str, len);
} }
std::string SaveModelToString(int start_iteration, int num_iteration) { std::string SaveModelToString(int start_iteration, int num_iteration,
return boosting_->SaveModelToString(start_iteration, num_iteration); int feature_importance_type) {
return boosting_->SaveModelToString(start_iteration,
num_iteration, feature_importance_type);
} }
std::string DumpModel(int start_iteration, int num_iteration) { std::string DumpModel(int start_iteration, int num_iteration,
return boosting_->DumpModel(start_iteration, num_iteration); int feature_importance_type) {
return boosting_->DumpModel(start_iteration, num_iteration,
feature_importance_type);
} }
std::vector<double> FeatureImportance(int num_iteration, int importance_type) { std::vector<double> FeatureImportance(int num_iteration, int importance_type) {
...@@ -2010,22 +2014,26 @@ int LGBM_BoosterPredictForMats(BoosterHandle handle, ...@@ -2010,22 +2014,26 @@ int LGBM_BoosterPredictForMats(BoosterHandle handle,
int LGBM_BoosterSaveModel(BoosterHandle handle, int LGBM_BoosterSaveModel(BoosterHandle handle,
int start_iteration, int start_iteration,
int num_iteration, int num_iteration,
int feature_importance_type,
const char* filename) { const char* filename) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->SaveModelToFile(start_iteration, num_iteration, filename); ref_booster->SaveModelToFile(start_iteration, num_iteration,
feature_importance_type, filename);
API_END(); API_END();
} }
int LGBM_BoosterSaveModelToString(BoosterHandle handle, int LGBM_BoosterSaveModelToString(BoosterHandle handle,
int start_iteration, int start_iteration,
int num_iteration, int num_iteration,
int feature_importance_type,
int64_t buffer_len, int64_t buffer_len,
int64_t* out_len, int64_t* out_len,
char* out_str) { char* out_str) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
std::string model = ref_booster->SaveModelToString(start_iteration, num_iteration); std::string model = ref_booster->SaveModelToString(
start_iteration, num_iteration, feature_importance_type);
*out_len = static_cast<int64_t>(model.size()) + 1; *out_len = static_cast<int64_t>(model.size()) + 1;
if (*out_len <= buffer_len) { if (*out_len <= buffer_len) {
std::memcpy(out_str, model.c_str(), *out_len); std::memcpy(out_str, model.c_str(), *out_len);
...@@ -2036,12 +2044,14 @@ int LGBM_BoosterSaveModelToString(BoosterHandle handle, ...@@ -2036,12 +2044,14 @@ int LGBM_BoosterSaveModelToString(BoosterHandle handle,
int LGBM_BoosterDumpModel(BoosterHandle handle, int LGBM_BoosterDumpModel(BoosterHandle handle,
int start_iteration, int start_iteration,
int num_iteration, int num_iteration,
int feature_importance_type,
int64_t buffer_len, int64_t buffer_len,
int64_t* out_len, int64_t* out_len,
char* out_str) { char* out_str) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
std::string model = ref_booster->DumpModel(start_iteration, num_iteration); std::string model = ref_booster->DumpModel(start_iteration, num_iteration,
feature_importance_type);
*out_len = static_cast<int64_t>(model.size()) + 1; *out_len = static_cast<int64_t>(model.size()) + 1;
if (*out_len <= buffer_len) { if (*out_len <= buffer_len) {
std::memcpy(out_str, model.c_str(), *out_len); std::memcpy(out_str, model.c_str(), *out_len);
......
...@@ -234,6 +234,7 @@ const std::unordered_set<std::string>& Config::parameter_set() { ...@@ -234,6 +234,7 @@ const std::unordered_set<std::string>& Config::parameter_set() {
"verbosity", "verbosity",
"input_model", "input_model",
"output_model", "output_model",
"saved_feature_importance_type",
"snapshot_freq", "snapshot_freq",
"max_bin", "max_bin",
"max_bin_by_feature", "max_bin_by_feature",
...@@ -463,6 +464,8 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str ...@@ -463,6 +464,8 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
GetString(params, "output_model", &output_model); GetString(params, "output_model", &output_model);
GetInt(params, "saved_feature_importance_type", &saved_feature_importance_type);
GetInt(params, "snapshot_freq", &snapshot_freq); GetInt(params, "snapshot_freq", &snapshot_freq);
GetInt(params, "max_bin", &max_bin); GetInt(params, "max_bin", &max_bin);
...@@ -664,6 +667,7 @@ std::string Config::SaveMembersToString() const { ...@@ -664,6 +667,7 @@ std::string Config::SaveMembersToString() const {
str_buf << "[path_smooth: " << path_smooth << "]\n"; str_buf << "[path_smooth: " << path_smooth << "]\n";
str_buf << "[interaction_constraints: " << interaction_constraints << "]\n"; str_buf << "[interaction_constraints: " << interaction_constraints << "]\n";
str_buf << "[verbosity: " << verbosity << "]\n"; str_buf << "[verbosity: " << verbosity << "]\n";
str_buf << "[saved_feature_importance_type: " << saved_feature_importance_type << "]\n";
str_buf << "[max_bin: " << max_bin << "]\n"; str_buf << "[max_bin: " << max_bin << "]\n";
str_buf << "[max_bin_by_feature: " << Common::Join(max_bin_by_feature, ",") << "]\n"; str_buf << "[max_bin_by_feature: " << Common::Join(max_bin_by_feature, ",") << "]\n";
str_buf << "[min_data_in_bin: " << min_data_in_bin << "]\n"; str_buf << "[min_data_in_bin: " << min_data_in_bin << "]\n";
......
...@@ -37,16 +37,17 @@ ...@@ -37,16 +37,17 @@
char * LGBM_BoosterSaveModelToStringSWIG(BoosterHandle handle, char * LGBM_BoosterSaveModelToStringSWIG(BoosterHandle handle,
int start_iteration, int start_iteration,
int num_iteration, int num_iteration,
int feature_importance_type,
int64_t buffer_len, int64_t buffer_len,
int64_t* out_len) { int64_t* out_len) {
char* dst = new char[buffer_len]; char* dst = new char[buffer_len];
int result = LGBM_BoosterSaveModelToString(handle, start_iteration, num_iteration, buffer_len, out_len, dst); int result = LGBM_BoosterSaveModelToString(handle, start_iteration, num_iteration, feature_importance_type, buffer_len, out_len, dst);
// Reallocate to use larger length // Reallocate to use larger length
if (*out_len > buffer_len) { if (*out_len > buffer_len) {
delete [] dst; delete [] dst;
int64_t realloc_len = *out_len; int64_t realloc_len = *out_len;
dst = new char[realloc_len]; dst = new char[realloc_len];
result = LGBM_BoosterSaveModelToString(handle, start_iteration, num_iteration, realloc_len, out_len, dst); result = LGBM_BoosterSaveModelToString(handle, start_iteration, num_iteration, feature_importance_type, realloc_len, out_len, dst);
} }
if (result != 0) { if (result != 0) {
return nullptr; return nullptr;
...@@ -57,16 +58,17 @@ ...@@ -57,16 +58,17 @@
char * LGBM_BoosterDumpModelSWIG(BoosterHandle handle, char * LGBM_BoosterDumpModelSWIG(BoosterHandle handle,
int start_iteration, int start_iteration,
int num_iteration, int num_iteration,
int feature_importance_type,
int64_t buffer_len, int64_t buffer_len,
int64_t* out_len) { int64_t* out_len) {
char* dst = new char[buffer_len]; char* dst = new char[buffer_len];
int result = LGBM_BoosterDumpModel(handle, start_iteration, num_iteration, buffer_len, out_len, dst); int result = LGBM_BoosterDumpModel(handle, start_iteration, num_iteration, feature_importance_type, buffer_len, out_len, dst);
// Reallocate to use larger length // Reallocate to use larger length
if (*out_len > buffer_len) { if (*out_len > buffer_len) {
delete [] dst; delete [] dst;
int64_t realloc_len = *out_len; int64_t realloc_len = *out_len;
dst = new char[realloc_len]; dst = new char[realloc_len];
result = LGBM_BoosterDumpModel(handle, start_iteration, num_iteration, realloc_len, out_len, dst); result = LGBM_BoosterDumpModel(handle, start_iteration, num_iteration, feature_importance_type, realloc_len, out_len, dst);
} }
if (result != 0) { if (result != 0) {
return nullptr; return nullptr;
......
...@@ -236,7 +236,7 @@ def test_booster(): ...@@ -236,7 +236,7 @@ def test_booster():
result.ctypes.data_as(ctypes.POINTER(ctypes.c_double))) result.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))
if i % 10 == 0: if i % 10 == 0:
print('%d iteration test AUC %f' % (i, result[0])) print('%d iteration test AUC %f' % (i, result[0]))
LIB.LGBM_BoosterSaveModel(booster, 0, -1, c_str('model.txt')) LIB.LGBM_BoosterSaveModel(booster, 0, -1, 0, c_str('model.txt'))
LIB.LGBM_BoosterFree(booster) LIB.LGBM_BoosterFree(booster)
free_dataset(train) free_dataset(train)
free_dataset(test) free_dataset(test)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment