Unverified Commit 8b720844 authored by José Morales's avatar José Morales Committed by GitHub
Browse files

[python-package][R-package] load parameters from model file (fixes #2613) (#5424)

parent c134d3d9
...@@ -77,6 +77,7 @@ Booster <- R6::R6Class( ...@@ -77,6 +77,7 @@ Booster <- R6::R6Class(
LGBM_BoosterCreateFromModelfile_R LGBM_BoosterCreateFromModelfile_R
, modelfile , modelfile
) )
params <- private$get_loaded_param(handle)
} else if (!is.null(model_str)) { } else if (!is.null(model_str)) {
...@@ -727,6 +728,20 @@ Booster <- R6::R6Class( ...@@ -727,6 +728,20 @@ Booster <- R6::R6Class(
}, },
# Retrieve the parameters stored in a loaded model file.
# Calls into the native library to get the parameters as a JSON string,
# parses them, and shifts 'interaction_constraints' feature indices from
# the C++ zero-based convention to R's one-based convention.
get_loaded_param = function(handle) {
  json_str <- .Call(
    LGBM_BoosterGetLoadedParam_R
    , handle
  )
  parsed_params <- jsonlite::fromJSON(json_str)
  if ("interaction_constraints" %in% names(parsed_params)) {
    one_based <- lapply(
      parsed_params[["interaction_constraints"]]
      , function(idx) idx + 1L
    )
    parsed_params[["interaction_constraints"]] <- one_based
  }
  return(parsed_params)
},
inner_eval = function(data_name, data_idx, feval = NULL) { inner_eval = function(data_name, data_idx, feval = NULL) {
# Check for unknown dataset (over the maximum provided range) # Check for unknown dataset (over the maximum provided range)
......
...@@ -1183,6 +1183,27 @@ SEXP LGBM_DumpParamAliases_R() { ...@@ -1183,6 +1183,27 @@ SEXP LGBM_DumpParamAliases_R() {
R_API_END(); R_API_END();
} }
// Return the parameters stored in a loaded model as a length-1 R character
// vector containing JSON. Uses a guess-then-retry buffer strategy against
// the C API, which reports the required size via out_len.
SEXP LGBM_BoosterGetLoadedParam_R(SEXP handle) {
SEXP cont_token = PROTECT(R_MakeUnwindCont());
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
SEXP params_str;
int64_t out_len = 0;
// start with a 1 MiB buffer; most parameter dumps fit on the first try
int64_t buf_len = 1024 * 1024;
std::vector<char> inner_char_buf(buf_len);
CHECK_CALL(LGBM_BoosterGetLoadedParam(R_ExternalPtrAddr(handle), buf_len, &out_len, inner_char_buf.data()));
// if the parameters string was larger than the initial buffer, allocate a bigger buffer and try again
if (out_len > buf_len) {
inner_char_buf.resize(out_len);
CHECK_CALL(LGBM_BoosterGetLoadedParam(R_ExternalPtrAddr(handle), out_len, &out_len, inner_char_buf.data()));
}
// two PROTECTs are active here: cont_token and params_str
params_str = PROTECT(safe_R_string(static_cast<R_xlen_t>(1), &cont_token));
SET_STRING_ELT(params_str, 0, safe_R_mkChar(inner_char_buf.data(), &cont_token));
UNPROTECT(2);
return params_str;
R_API_END();
}
// .Call() calls // .Call() calls
static const R_CallMethodDef CallEntries[] = { static const R_CallMethodDef CallEntries[] = {
{"LGBM_HandleIsNull_R" , (DL_FUNC) &LGBM_HandleIsNull_R , 1}, {"LGBM_HandleIsNull_R" , (DL_FUNC) &LGBM_HandleIsNull_R , 1},
...@@ -1211,6 +1232,7 @@ static const R_CallMethodDef CallEntries[] = { ...@@ -1211,6 +1232,7 @@ static const R_CallMethodDef CallEntries[] = {
{"LGBM_BoosterResetParameter_R" , (DL_FUNC) &LGBM_BoosterResetParameter_R , 2}, {"LGBM_BoosterResetParameter_R" , (DL_FUNC) &LGBM_BoosterResetParameter_R , 2},
{"LGBM_BoosterGetNumClasses_R" , (DL_FUNC) &LGBM_BoosterGetNumClasses_R , 2}, {"LGBM_BoosterGetNumClasses_R" , (DL_FUNC) &LGBM_BoosterGetNumClasses_R , 2},
{"LGBM_BoosterGetNumFeature_R" , (DL_FUNC) &LGBM_BoosterGetNumFeature_R , 1}, {"LGBM_BoosterGetNumFeature_R" , (DL_FUNC) &LGBM_BoosterGetNumFeature_R , 1},
{"LGBM_BoosterGetLoadedParam_R" , (DL_FUNC) &LGBM_BoosterGetLoadedParam_R , 1},
{"LGBM_BoosterUpdateOneIter_R" , (DL_FUNC) &LGBM_BoosterUpdateOneIter_R , 1}, {"LGBM_BoosterUpdateOneIter_R" , (DL_FUNC) &LGBM_BoosterUpdateOneIter_R , 1},
{"LGBM_BoosterUpdateOneIterCustom_R" , (DL_FUNC) &LGBM_BoosterUpdateOneIterCustom_R , 4}, {"LGBM_BoosterUpdateOneIterCustom_R" , (DL_FUNC) &LGBM_BoosterUpdateOneIterCustom_R , 4},
{"LGBM_BoosterRollbackOneIter_R" , (DL_FUNC) &LGBM_BoosterRollbackOneIter_R , 1}, {"LGBM_BoosterRollbackOneIter_R" , (DL_FUNC) &LGBM_BoosterRollbackOneIter_R , 1},
......
...@@ -266,6 +266,15 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterLoadModelFromString_R( ...@@ -266,6 +266,15 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterLoadModelFromString_R(
SEXP model_str SEXP model_str
); );
/*!
* \brief Get the parameters that were stored in a loaded model file, as a JSON string.
* \param handle Booster handle
* \return R character vector (length=1) with parameters in JSON format
*/
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetLoadedParam_R(
SEXP handle
);
/*! /*!
* \brief Merge model in two Boosters to first handle * \brief Merge model in two Boosters to first handle
* \param handle handle primary Booster handle, will merge other handle to this * \param handle handle primary Booster handle, will merge other handle to this
......
...@@ -172,15 +172,24 @@ test_that("Loading a Booster from a text file works", { ...@@ -172,15 +172,24 @@ test_that("Loading a Booster from a text file works", {
data(agaricus.test, package = "lightgbm") data(agaricus.test, package = "lightgbm")
train <- agaricus.train train <- agaricus.train
test <- agaricus.test test <- agaricus.test
bst <- lightgbm( params <- list(
data = as.matrix(train$data)
, label = train$label
, params = list(
num_leaves = 4L num_leaves = 4L
, boosting = "rf"
, bagging_fraction = 0.8
, bagging_freq = 1L
, boost_from_average = FALSE
, categorical_feature = c(1L, 2L)
, interaction_constraints = list(c(1L, 2L), 1L)
, feature_contri = rep(0.5, ncol(train$data))
, metric = c("mape", "average_precision")
, learning_rate = 1.0 , learning_rate = 1.0
, objective = "binary" , objective = "binary"
, verbose = VERBOSITY , verbosity = VERBOSITY
) )
bst <- lightgbm(
data = as.matrix(train$data)
, label = train$label
, params = params
, nrounds = 2L , nrounds = 2L
) )
expect_true(lgb.is.Booster(bst)) expect_true(lgb.is.Booster(bst))
...@@ -199,6 +208,9 @@ test_that("Loading a Booster from a text file works", { ...@@ -199,6 +208,9 @@ test_that("Loading a Booster from a text file works", {
) )
pred2 <- predict(bst2, test$data) pred2 <- predict(bst2, test$data)
expect_identical(pred, pred2) expect_identical(pred, pred2)
# check that the parameters are loaded correctly
expect_equal(bst2$params[names(params)], params)
}) })
test_that("boosters with linear models at leaves can be written to text file and re-loaded successfully", { test_that("boosters with linear models at leaves can be written to text file and re-loaded successfully", {
......
...@@ -6,6 +6,7 @@ with list of all parameters, aliases table and other routines ...@@ -6,6 +6,7 @@ with list of all parameters, aliases table and other routines
along with parameters description in LightGBM/docs/Parameters.rst file along with parameters description in LightGBM/docs/Parameters.rst file
from the information in LightGBM/include/LightGBM/config.h file. from the information in LightGBM/include/LightGBM/config.h file.
""" """
import re
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
...@@ -373,6 +374,32 @@ def gen_parameter_code( ...@@ -373,6 +374,32 @@ def gen_parameter_code(
} }
""" """
str_to_write += """const std::unordered_map<std::string, std::string>& Config::ParameterTypes() {
static std::unordered_map<std::string, std::string> map({"""
int_t_pat = re.compile(r'int\d+_t')
# the following are stored as comma separated strings but are arrays in the wrappers
overrides = {
'categorical_feature': 'vector<int>',
'ignore_column': 'vector<int>',
'interaction_constraints': 'vector<vector<int>>',
}
for x in infos:
for y in x:
name = y["name"][0]
if name == 'task':
continue
if name in overrides:
param_type = overrides[name]
else:
param_type = int_t_pat.sub('int', y["inner_type"][0]).replace('std::', '')
str_to_write += '\n {"' + name + '", "' + param_type + '"},'
str_to_write += """
});
return map;
}
"""
str_to_write += "} // namespace LightGBM\n" str_to_write += "} // namespace LightGBM\n"
with open(config_out_cpp, "w") as config_out_cpp_file: with open(config_out_cpp, "w") as config_out_cpp_file:
config_out_cpp_file.write(str_to_write) config_out_cpp_file.write(str_to_write)
......
...@@ -313,6 +313,8 @@ class LIGHTGBM_EXPORT Boosting { ...@@ -313,6 +313,8 @@ class LIGHTGBM_EXPORT Boosting {
*/ */
static Boosting* CreateBoosting(const std::string& type, const char* filename); static Boosting* CreateBoosting(const std::string& type, const char* filename);
virtual std::string GetLoadedParam() const = 0;
virtual bool IsLinear() const { return false; } virtual bool IsLinear() const { return false; }
virtual std::string ParserConfigStr() const = 0; virtual std::string ParserConfigStr() const = 0;
......
...@@ -595,6 +595,20 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterLoadModelFromString(const char* model_str, ...@@ -595,6 +595,20 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterLoadModelFromString(const char* model_str,
int* out_num_iterations, int* out_num_iterations,
BoosterHandle* out); BoosterHandle* out);
/*!
* \brief Get parameters as JSON string.
* \param handle Handle of booster.
* \param buffer_len Allocated space for string.
* \param[out] out_len Actual size of string.
* \param[out] out_str JSON string containing parameters.
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterGetLoadedParam(BoosterHandle handle,
int64_t buffer_len,
int64_t* out_len,
char* out_str);
/*! /*!
* \brief Free space for booster. * \brief Free space for booster.
* \param handle Handle of booster to be freed * \param handle Handle of booster to be freed
......
...@@ -1077,6 +1077,7 @@ struct Config { ...@@ -1077,6 +1077,7 @@ struct Config {
static const std::unordered_set<std::string>& parameter_set(); static const std::unordered_set<std::string>& parameter_set();
std::vector<std::vector<double>> auc_mu_weights_matrix; std::vector<std::vector<double>> auc_mu_weights_matrix;
std::vector<std::vector<int>> interaction_constraints_vector; std::vector<std::vector<int>> interaction_constraints_vector;
static const std::unordered_map<std::string, std::string>& ParameterTypes();
static const std::string DumpAliases(); static const std::string DumpAliases();
private: private:
......
...@@ -2816,6 +2816,9 @@ class Booster: ...@@ -2816,6 +2816,9 @@ class Booster:
ctypes.byref(out_num_class))) ctypes.byref(out_num_class)))
self.__num_class = out_num_class.value self.__num_class = out_num_class.value
self.pandas_categorical = _load_pandas_categorical(file_name=model_file) self.pandas_categorical = _load_pandas_categorical(file_name=model_file)
if params:
_log_warning('Ignoring params argument, using parameters from model file.')
params = self._get_loaded_param()
elif model_str is not None: elif model_str is not None:
self.model_from_string(model_str) self.model_from_string(model_str)
else: else:
...@@ -2864,6 +2867,28 @@ class Booster: ...@@ -2864,6 +2867,28 @@ class Booster:
state['handle'] = handle state['handle'] = handle
self.__dict__.update(state) self.__dict__.update(state)
def _get_loaded_param(self) -> Dict[str, Any]:
    """Fetch the parameters stored in the loaded model file.

    Returns
    -------
    result : dict
        Parameters of the Booster, decoded from the JSON string
        produced by ``LGBM_BoosterGetLoadedParam``.
    """
    initial_size = 1 << 20
    out_len = ctypes.c_int64(0)
    buffer = ctypes.create_string_buffer(initial_size)
    buffer_ptr = ctypes.c_char_p(ctypes.addressof(buffer))
    _safe_call(_LIB.LGBM_BoosterGetLoadedParam(
        self.handle,
        ctypes.c_int64(initial_size),
        ctypes.byref(out_len),
        buffer_ptr))
    required_size = out_len.value
    # the C API reports the required size; retry with a bigger buffer if ours was too small
    if required_size > initial_size:
        buffer = ctypes.create_string_buffer(required_size)
        buffer_ptr = ctypes.c_char_p(ctypes.addressof(buffer))
        _safe_call(_LIB.LGBM_BoosterGetLoadedParam(
            self.handle,
            ctypes.c_int64(required_size),
            ctypes.byref(out_len),
            buffer_ptr))
    return json.loads(buffer.value.decode('utf-8'))
def free_dataset(self) -> "Booster": def free_dataset(self) -> "Booster":
"""Free Booster's Datasets. """Free Booster's Datasets.
......
...@@ -157,6 +157,60 @@ class GBDT : public GBDTBase { ...@@ -157,6 +157,60 @@ class GBDT : public GBDTBase {
*/ */
int GetCurrentIteration() const override { return static_cast<int>(models_.size()) / num_tree_per_iteration_; } int GetCurrentIteration() const override { return static_cast<int>(models_.size()) / num_tree_per_iteration_; }
/*!
* \brief Get parameters as a JSON string
*
* Re-serializes loaded_parameter_ (lines of the form "[name: value]",
* as stored in the model file) into JSON, using Config::ParameterTypes()
* to decide whether each value is emitted as a string, number, boolean,
* or array.
*/
std::string GetLoadedParam() const override {
if (loaded_parameter_.empty()) {
return std::string("{}");
}
const auto param_types = Config::ParameterTypes();
const auto lines = Common::Split(loaded_parameter_.c_str(), "\n");
bool first = true;
std::stringstream str_buf;
str_buf << "{";
for (const auto& line : lines) {
// pair[0] is "[name", pair[1] is " value]"
const auto pair = Common::Split(line.c_str(), ":");
// skip entries with an empty value (pair[1] is just " ]")
if (pair[1] == " ]")
continue;
if (first) {
first = false;
str_buf << "\"";
} else {
str_buf << ",\"";
}
// strip the leading '[' from the name and the surrounding " " / "]" from the value
const auto param = pair[0].substr(1);
const auto value_str = pair[1].substr(1, pair[1].size() - 2);
const auto param_type = param_types.at(param);
str_buf << param << "\": ";
if (param_type == "string") {
str_buf << "\"" << value_str << "\"";
} else if (param_type == "int") {
int value;
Common::Atoi(value_str.c_str(), &value);
str_buf << value;
} else if (param_type == "double") {
double value;
Common::Atof(value_str.c_str(), &value);
str_buf << value;
} else if (param_type == "bool") {
// booleans are stored as "1"/"0" in the model file
bool value = value_str == "1";
str_buf << std::boolalpha << value;
} else if (param_type.substr(0, 6) == "vector") {
str_buf << "[";
// vector<string> elements need individual quoting; numeric vectors
// are already valid comma-separated JSON array contents
if (param_type.substr(7, 6) == "string") {
const auto parts = Common::Split(value_str.c_str(), ",");
str_buf << "\"" << Common::Join(parts, "\",\"") << "\"";
} else {
str_buf << value_str;
}
str_buf << "]";
}
}
str_buf << "}";
return str_buf.str();
}
/*! /*!
* \brief Can use early stopping for prediction or not * \brief Can use early stopping for prediction or not
* \return True if cannot use early stopping for prediction * \return True if cannot use early stopping for prediction
......
...@@ -1748,6 +1748,21 @@ int LGBM_BoosterLoadModelFromString( ...@@ -1748,6 +1748,21 @@ int LGBM_BoosterLoadModelFromString(
API_END(); API_END();
} }
// C API entry point: copy the loaded parameters (as a JSON string) into the
// caller-supplied buffer. *out_len is always set to the required size
// INCLUDING the trailing NUL; the copy only happens when the buffer is large
// enough, so callers can resize and call again.
int LGBM_BoosterGetLoadedParam(
BoosterHandle handle,
int64_t buffer_len,
int64_t* out_len,
char* out_str) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
std::string params = ref_booster->GetBoosting()->GetLoadedParam();
// +1 accounts for the NUL terminator copied by memcpy below
*out_len = static_cast<int64_t>(params.size()) + 1;
if (*out_len <= buffer_len) {
std::memcpy(out_str, params.c_str(), *out_len);
}
API_END();
}
#ifdef _MSC_VER #ifdef _MSC_VER
#pragma warning(disable : 4702) #pragma warning(disable : 4702)
#endif #endif
......
...@@ -894,4 +894,141 @@ const std::unordered_map<std::string, std::vector<std::string>>& Config::paramet ...@@ -894,4 +894,141 @@ const std::unordered_map<std::string, std::vector<std::string>>& Config::paramet
return map; return map;
} }
// Mapping from parameter name to the type tag used when serializing loaded
// parameters to JSON (see GBDT::GetLoadedParam). A few comma-separated
// string parameters (categorical_feature, ignore_column,
// interaction_constraints) are deliberately tagged as int vectors because
// the language wrappers expose them as arrays.
// NOTE: this function is auto-generated from config.h by
// helpers/parameter_generator.py -- do not edit by hand.
const std::unordered_map<std::string, std::string>& Config::ParameterTypes() {
static std::unordered_map<std::string, std::string> map({
{"config", "string"},
{"objective", "string"},
{"boosting", "string"},
{"data", "string"},
{"valid", "vector<string>"},
{"num_iterations", "int"},
{"learning_rate", "double"},
{"num_leaves", "int"},
{"tree_learner", "string"},
{"num_threads", "int"},
{"device_type", "string"},
{"seed", "int"},
{"deterministic", "bool"},
{"force_col_wise", "bool"},
{"force_row_wise", "bool"},
{"histogram_pool_size", "double"},
{"max_depth", "int"},
{"min_data_in_leaf", "int"},
{"min_sum_hessian_in_leaf", "double"},
{"bagging_fraction", "double"},
{"pos_bagging_fraction", "double"},
{"neg_bagging_fraction", "double"},
{"bagging_freq", "int"},
{"bagging_seed", "int"},
{"feature_fraction", "double"},
{"feature_fraction_bynode", "double"},
{"feature_fraction_seed", "int"},
{"extra_trees", "bool"},
{"extra_seed", "int"},
{"early_stopping_round", "int"},
{"first_metric_only", "bool"},
{"max_delta_step", "double"},
{"lambda_l1", "double"},
{"lambda_l2", "double"},
{"linear_lambda", "double"},
{"min_gain_to_split", "double"},
{"drop_rate", "double"},
{"max_drop", "int"},
{"skip_drop", "double"},
{"xgboost_dart_mode", "bool"},
{"uniform_drop", "bool"},
{"drop_seed", "int"},
{"top_rate", "double"},
{"other_rate", "double"},
{"min_data_per_group", "int"},
{"max_cat_threshold", "int"},
{"cat_l2", "double"},
{"cat_smooth", "double"},
{"max_cat_to_onehot", "int"},
{"top_k", "int"},
{"monotone_constraints", "vector<int>"},
{"monotone_constraints_method", "string"},
{"monotone_penalty", "double"},
{"feature_contri", "vector<double>"},
{"forcedsplits_filename", "string"},
{"refit_decay_rate", "double"},
{"cegb_tradeoff", "double"},
{"cegb_penalty_split", "double"},
{"cegb_penalty_feature_lazy", "vector<double>"},
{"cegb_penalty_feature_coupled", "vector<double>"},
{"path_smooth", "double"},
{"interaction_constraints", "vector<vector<int>>"},
{"verbosity", "int"},
{"input_model", "string"},
{"output_model", "string"},
{"saved_feature_importance_type", "int"},
{"snapshot_freq", "int"},
{"linear_tree", "bool"},
{"max_bin", "int"},
{"max_bin_by_feature", "vector<int>"},
{"min_data_in_bin", "int"},
{"bin_construct_sample_cnt", "int"},
{"data_random_seed", "int"},
{"is_enable_sparse", "bool"},
{"enable_bundle", "bool"},
{"use_missing", "bool"},
{"zero_as_missing", "bool"},
{"feature_pre_filter", "bool"},
{"pre_partition", "bool"},
{"two_round", "bool"},
{"header", "bool"},
{"label_column", "string"},
{"weight_column", "string"},
{"group_column", "string"},
{"ignore_column", "vector<int>"},
{"categorical_feature", "vector<int>"},
{"forcedbins_filename", "string"},
{"save_binary", "bool"},
{"precise_float_parser", "bool"},
{"parser_config_file", "string"},
{"start_iteration_predict", "int"},
{"num_iteration_predict", "int"},
{"predict_raw_score", "bool"},
{"predict_leaf_index", "bool"},
{"predict_contrib", "bool"},
{"predict_disable_shape_check", "bool"},
{"pred_early_stop", "bool"},
{"pred_early_stop_freq", "int"},
{"pred_early_stop_margin", "double"},
{"output_result", "string"},
{"convert_model_language", "string"},
{"convert_model", "string"},
{"objective_seed", "int"},
{"num_class", "int"},
{"is_unbalance", "bool"},
{"scale_pos_weight", "double"},
{"sigmoid", "double"},
{"boost_from_average", "bool"},
{"reg_sqrt", "bool"},
{"alpha", "double"},
{"fair_c", "double"},
{"poisson_max_delta_step", "double"},
{"tweedie_variance_power", "double"},
{"lambdarank_truncation_level", "int"},
{"lambdarank_norm", "bool"},
{"label_gain", "vector<double>"},
{"metric", "vector<string>"},
{"metric_freq", "int"},
{"is_provide_training_metric", "bool"},
{"eval_at", "vector<int>"},
{"multi_error_top_k", "int"},
{"auc_mu_weights", "vector<double>"},
{"num_machines", "int"},
{"local_listen_port", "int"},
{"time_out", "int"},
{"machine_list_filename", "string"},
{"machines", "string"},
{"gpu_platform_id", "int"},
{"gpu_device_id", "int"},
{"gpu_use_dp", "bool"},
{"num_gpu", "int"},
});
return map;
}
} // namespace LightGBM } // namespace LightGBM
...@@ -1211,6 +1211,35 @@ def test_feature_name_with_non_ascii(): ...@@ -1211,6 +1211,35 @@ def test_feature_name_with_non_ascii():
assert feature_names == gbm2.feature_name() assert feature_names == gbm2.feature_name()
def test_parameters_are_loaded_from_model_file(tmp_path):
    """Loading a Booster from a text file restores the training parameters."""
    features = np.hstack([np.random.rand(100, 1), np.random.randint(0, 5, (100, 2))])
    target = np.random.rand(100)
    dataset = lgb.Dataset(features, target)
    # a mix of string, int, float, bool, list and nested-list parameters
    params = {
        'bagging_fraction': 0.8,
        'bagging_freq': 2,
        'boosting': 'rf',
        'feature_contri': [0.5, 0.5, 0.5],
        'feature_fraction': 0.7,
        'boost_from_average': False,
        'interaction_constraints': [[0, 1], [0]],
        'metric': ['l2', 'rmse'],
        'num_leaves': 5,
        'num_threads': 1,
    }
    model_file = tmp_path / 'model.txt'
    trained = lgb.train(params, dataset, num_boost_round=1, categorical_feature=[1, 2])
    trained.save_model(model_file)
    booster = lgb.Booster(model_file=model_file)
    # every parameter we set should round-trip through the model file
    loaded_params = {name: booster.params[name] for name in params}
    assert loaded_params == params
    assert booster.params['categorical_feature'] == [1, 2]
    # passing params together with model_file should warn and ignore them
    with pytest.warns(UserWarning, match='Ignoring params argument'):
        booster_again = lgb.Booster(params={'num_leaves': 7}, model_file=model_file)
    assert booster.params == booster_again.params
def test_save_load_copy_pickle(): def test_save_load_copy_pickle():
def train_and_predict(init_model=None, return_model=False): def train_and_predict(init_model=None, return_model=False):
X, y = make_synthetic_regression() X, y = make_synthetic_regression()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment