Unverified Commit 9f79e840 authored by Guolin Ke, committed by GitHub

[python] [R-package] refine the parameters for Dataset (#2594)



* reset

* fix a bug

* fix test

* Update c_api.h

* support not filtering features by min_data

* add warning in reset config

* refine warnings for overriding dataset parameters

* some cleans

* clean code

* clean code

* refine C API function doxygen comments

* refined new param description

* refined doxygen comments for R API function

* removed stuff related to int8

* break long line in warning message

* removed tests whose results cannot be validated anymore

* added test for warnings about unchangeable params

* write parameter from dataset to booster

* consider free_raw_data.

* fix params

* fix bug

* implementing R

* fix typo

* filter params in R

* fix R

* not min_data

* refined tests

* fixed linting

* refine

* pylint

* add docstring

* fix docstring

* R lint

* updated description for C API function

* use param aliases in Python

* fixed typo

* fixed typo

* added more params to test

* removed debug print

* fix dataset construct place

* fix merge bug

* Update feature_histogram.hpp

* add is_sparse back

* remove unused parameters

* fix lint

* add data random seed

* update

* [R-package] centralized Dataset parameter aliases and added tests on Dataset parameter updating (#2767)
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
Co-authored-by: James Lamb <jaylamb20@gmail.com>
parent fed09d33
# Central location for parameter aliases.
# See https://lightgbm.readthedocs.io/en/latest/Parameters.html#core-parameters
# [description] List of respected parameter aliases specific to lgb.Dataset. Wrapped in a function to
# take advantage of lazy evaluation (so it doesn't matter what order
# R sources files during installation).
# [return] A named list, where each key is a parameter relevant to lgb.Dataset and each value is a character
# vector of corresponding aliases.
.DATASET_PARAMETERS <- function() {
return(list(
"bin_construct_sample_cnt" = c(
"bin_construct_sample_cnt"
, "subsample_for_bin"
)
, "categorical_feature" = c(
"categorical_feature"
, "cat_feature"
, "categorical_column"
, "cat_column"
)
, "seed" = c(
"seed"
, "data_random_seed"
, "feature_fraction_seed"
)
, "enable_bundle" = c(
"enable_bundle"
, "is_endable_bundle"
, "bundle"
)
, "enable_sparse" = c(
"enable_sparse"
, "is_sparse"
, "sparse"
)
, "feature_pre_filter" = "feature_pre_filter"
, "forcedbins_filename" = "forcedbins_filename"
, "group_column" = c(
"group_column"
, "group_id"
, "query_column"
, "query"
, "query_id"
)
, "header" = c(
"header"
, "has_header"
)
, "ignore_column" = c(
"ignore_column"
, "ignore_feature"
, "blacklist"
)
, "label_column" = c(
"label_column"
, "label"
)
, "max_bin" = "max_bin"
, "max_bin_by_feature" = "max_bin_by_feature"
, "pre_partition" = c(
"pre_parition"
, "is_pre_partition"
)
, "two_round" = c(
"two_round"
, "two_round_loading"
, "use_two_round_loading"
)
, "use_missing" = "use_missing"
, "weight_column" = c(
"weight_column"
, "weight"
)
, "zero_as_missing" = "zero_as_missing"
))
}
# [description] List of respected parameter aliases. Wrapped in a function to take advantage of
# lazy evaluation (so it doesn't matter what order R sources files during installation).
# [return] A named list, where each key is a main LightGBM parameter and each value is a character
# vector of corresponding aliases.
.PARAMETER_ALIASES <- function() {
learning_params <- list(
"boosting" = c(
"boosting"
, "boost"
......@@ -29,5 +103,6 @@
, "num_boost_round"
, "n_estimators"
)
)
return(c(learning_params, .DATASET_PARAMETERS()))
}
......@@ -31,7 +31,6 @@ Booster <- R6::R6Class(
# Create parameters and handle
params <- append(params, list(...))
params_str <- lgb.params2str(params)
handle <- 0.0
# Attempts to create a handle for the dataset
......@@ -39,17 +38,18 @@ Booster <- R6::R6Class(
# Check if training dataset is not null
if (!is.null(train_set)) {
# Check if training dataset is lgb.Dataset or not
if (!lgb.check.r6.class(train_set, "lgb.Dataset")) {
stop("lgb.Booster: Can only use lgb.Dataset as training data")
}
train_set_handle <- train_set$.__enclos_env__$private$get_handle()
params <- modifyList(params, train_set$get_params())
params_str <- lgb.params2str(params)
# Store booster handle
handle <- lgb.call(
"LGBM_BoosterCreate_R"
, ret = handle
, train_set_handle
, params_str
)
......
......@@ -530,22 +530,48 @@ Dataset <- R6::R6Class(
# Update parameters
update_params = function(params) {
# Parameter updating
if (length(params) == 0L) {
  return(invisible(self))
}
if (lgb.is.null.handle(private$handle)) {
  private$params <- modifyList(private$params, params)
} else {
  call_state <- 0L
  call_state <- .Call(
    "LGBM_DatasetUpdateParamChecking_R"
    , lgb.params2str(private$params)
    , lgb.params2str(params)
    , call_state
    , PACKAGE = "lib_lightgbm"
  )
  call_state <- as.integer(call_state)
  if (call_state != 0L) {
    # raise error if raw data is freed
    if (is.null(private$raw_data)) {
      lgb.last_error()
    }
    # Overwrite params
    private$params <- modifyList(private$params, params)
    self$finalize()
  }
}
return(invisible(self))
},
get_params = function() {
dataset_params <- unname(unlist(.DATASET_PARAMETERS()))
ret <- list()
for (param_key in names(private$params)) {
if (param_key %in% dataset_params) {
ret[[param_key]] <- private$params[[param_key]]
}
}
return(ret)
},
# Set categorical feature parameter
set_categorical_feature = function(categorical_feature) {
......
......@@ -19,6 +19,36 @@ lgb.encode.char <- function(arr, len) {
}
lgb.last_error <- function() {
# Perform text error buffering
buf_len <- 200L
act_len <- 0L
err_msg <- raw(buf_len)
err_msg <- .Call(
"LGBM_GetLastError_R"
, buf_len
, act_len
, err_msg
, PACKAGE = "lib_lightgbm"
)
# Check error buffer
if (act_len > buf_len) {
buf_len <- act_len
err_msg <- raw(buf_len)
err_msg <- .Call(
"LGBM_GetLastError_R"
, buf_len
, act_len
, err_msg
, PACKAGE = "lib_lightgbm"
)
}
# Return error
stop("api error: ", lgb.encode.char(err_msg, act_len))
}
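For comparison, the C API exposes the same message directly through LGBM_GetLastError(), which returns a const char*. A minimal ctypes sketch of reading it from Python (the shared-library filename is platform-specific and an assumption here):
import ctypes

lib = ctypes.cdll.LoadLibrary("lib_lightgbm.so")  # assumed library name/path
lib.LGBM_GetLastError.restype = ctypes.c_char_p   # the C API returns const char*
print(lib.LGBM_GetLastError().decode("utf-8"))    # last error message, if any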
lgb.call <- function(fun_name, ret, ...) {
# Set call state to a zero value
call_state <- 0L
......@@ -43,35 +73,7 @@ lgb.call <- function(fun_name, ret, ...) {
call_state <- as.integer(call_state)
# Check for call state value post call
if (call_state != 0L) {
lgb.last_error()
}
return(ret)
......
......@@ -126,3 +126,80 @@ test_that("Dataset$new() should throw an error if 'predictor' is provided but of
)
}, regexp = "predictor must be a", fixed = TRUE)
})
test_that("Dataset$get_params() successfully returns parameters if you passed them", {
# note that this list uses one "main" parameter (feature_pre_filter) and one that
# is an alias (is_sparse), to check that aliases are handled correctly
params <- list(
"feature_pre_filter" = TRUE
, "is_sparse" = FALSE
)
ds <- lgb.Dataset(
test_data
, label = test_label
, params = params
)
returned_params <- ds$get_params()
expect_true(methods::is(returned_params, "list"))
expect_identical(length(params), length(returned_params))
expect_identical(sort(names(params)), sort(names(returned_params)))
for (param_name in names(params)) {
expect_identical(params[[param_name]], returned_params[[param_name]])
}
})
test_that("Dataset$get_params() ignores irrelevant parameters", {
params <- list(
"feature_pre_filter" = TRUE
, "is_sparse" = FALSE
, "nonsense_parameter" = c(1.0, 2.0, 5.0)
)
ds <- lgb.Dataset(
test_data
, label = test_label
, params = params
)
returned_params <- ds$get_params()
expect_false("nonsense_parameter" %in% names(returned_params))
})
test_that("Dataset$update_parameters() does nothing for empty inputs", {
ds <- lgb.Dataset(
test_data
, label = test_label
)
initial_params <- ds$get_params()
expect_identical(initial_params, list())
# update_params() should return "self" so it can be chained
res <- ds$update_params(
params = list()
)
expect_true(lgb.is.Dataset(res))
new_params <- ds$get_params()
expect_identical(new_params, initial_params)
})
test_that("Dataset$update_params() works correctly for recognized Dataset parameters", {
ds <- lgb.Dataset(
test_data
, label = test_label
)
initial_params <- ds$get_params()
expect_identical(initial_params, list())
new_params <- list(
"data_random_seed" = 708L
, "enable_bundle" = FALSE
)
res <- ds$update_params(
params = new_params
)
expect_true(lgb.is.Dataset(res))
updated_params <- ds$get_params()
for (param_name in names(new_params)) {
expect_identical(new_params[[param_name]], updated_params[[param_name]])
}
})
......@@ -44,13 +44,18 @@ test_that("Feature penalties work properly", {
expect_length(var_gain[[length(var_gain)]], 0L)
})
test_that(".PARAMETER_ALIASES() returns a named list", {
context("parameter aliases")
test_that(".PARAMETER_ALIASES() returns a named list of character vectors, where names are unique", {
param_aliases <- .PARAMETER_ALIASES()
expect_true(is.list(param_aliases))
expect_true(is.character(names(param_aliases)))
expect_true(is.character(param_aliases[["boosting"]]))
expect_true(is.character(param_aliases[["early_stopping_round"]]))
expect_true(is.character(param_aliases[["num_iterations"]]))
expect_true(length(names(param_aliases)) == length(param_aliases))
expect_true(all(sapply(param_aliases, is.character)))
expect_true(length(unique(names(param_aliases))) == length(param_aliases))
})
test_that("training should warn if you use 'dart' boosting, specified with 'boosting' or aliases", {
......
......@@ -535,6 +535,14 @@ IO Parameters
- use this to avoid one-data-one-bin (potential over-fitting)
- ``feature_pre_filter`` :raw-html:`<a id="feature_pre_filter" title="Permalink to this parameter" href="#feature_pre_filter">&#x1F517;&#xFE0E;</a>`, default = ``true``, type = bool
- set this to ``true`` to pre-filter the unsplittable features by ``min_data_in_leaf``
- since the Dataset object is initialized only once and cannot be changed afterwards, you may need to set this to ``false`` when tuning ``min_data_in_leaf``; otherwise, features are pre-filtered using the first ``min_data_in_leaf`` value unless you reconstruct the Dataset object
- **Note**: setting this to ``false`` may slow down the training
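Following the note above, a minimal Python sketch (synthetic data) of reusing one Dataset while tuning ``min_data_in_leaf``:
import numpy as np
import lightgbm as lgb

X = np.random.rand(500, 10)
y = np.random.rand(500)
# disable pre-filtering so min_data_in_leaf can vary without rebuilding the Dataset
ds = lgb.Dataset(X, label=y, params={"feature_pre_filter": False})
for min_data in (5, 20, 100):
    lgb.train({"objective": "regression", "min_data_in_leaf": min_data, "verbose": -1}, ds)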
- ``bin_construct_sample_cnt`` :raw-html:`<a id="bin_construct_sample_cnt" title="Permalink to this parameter" href="#bin_construct_sample_cnt">&#x1F517;&#xFE0E;</a>`, default = ``200000``, type = int, aliases: ``subsample_for_bin``, constraints: ``bin_construct_sample_cnt > 0``
- number of data points sampled to construct histogram bins
......
......@@ -142,12 +142,13 @@ class BinMapper {
* \param max_bin The maximal number of bin
* \param min_data_in_bin min number of data in one bin
* \param min_split_data
* \param pre_filter True to pre-filter unsplittable features
* \param bin_type Type of this bin
* \param use_missing True to enable missing value handle
* \param zero_as_missing True to use zero as missing value
* \param forced_upper_bounds Vector of split points that must be used (if this has size less than max_bin, remaining splits are found by the algorithm)
*/
void FindBin(double* values, int num_values, size_t total_sample_cnt, int max_bin, int min_data_in_bin, int min_split_data, bool pre_filter, BinType bin_type,
bool use_missing, bool zero_as_missing, const std::vector<double>& forced_upper_bounds);
/*!
......
......@@ -26,7 +26,6 @@ typedef void* BoosterHandle; /*!< \brief Handle of booster. */
#define C_API_DTYPE_FLOAT64 (1) /*!< \brief float64 (double precision float). */
#define C_API_DTYPE_INT32 (2) /*!< \brief int32. */
#define C_API_DTYPE_INT64 (3) /*!< \brief int64. */
#define C_API_DTYPE_INT8 (4) /*!< \brief int8. */
#define C_API_PREDICT_NORMAL (0) /*!< \brief Normal prediction, with transform (if needed). */
#define C_API_PREDICT_RAW_SCORE (1) /*!< \brief Predict raw score. */
......@@ -331,7 +330,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle,
* \param field_name Field name
* \param[out] out_len Used to set result length
* \param[out] out_ptr Pointer to the result
* \param[out] out_type Type of result pointer, can be ``C_API_DTYPE_INT32``, ``C_API_DTYPE_FLOAT32`` or ``C_API_DTYPE_FLOAT64``
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetGetField(DatasetHandle handle,
......@@ -341,12 +340,13 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetGetField(DatasetHandle handle,
int* out_type);
/*!
* \brief Raise errors for attempts to update dataset parameters.
* \param old_parameters Current dataset parameters
* \param new_parameters New dataset parameters
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetUpdateParamChecking(const char* old_parameters,
const char* new_parameters);
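Since this function only compares two parameter strings, it can be exercised directly; a hedged ctypes sketch (shared-library name is an assumption):
import ctypes

lib = ctypes.cdll.LoadLibrary("lib_lightgbm.so")  # assumed library name/path
ret = lib.LGBM_DatasetUpdateParamChecking(b"max_bin=255", b"max_bin=63")
print(ret)  # -1: max_bin cannot change once a Dataset has been constructed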
/*!
* \brief Get number of data points.
......
......@@ -505,6 +505,11 @@ struct Config {
// desc = use this to avoid one-data-one-bin (potential over-fitting)
int min_data_in_bin = 3;
// desc = set this to ``true`` to pre-filter the unsplittable features by ``min_data_in_leaf``
// desc = since the Dataset object is initialized only once and cannot be changed afterwards, you may need to set this to ``false`` when tuning ``min_data_in_leaf``; otherwise, features are pre-filtered using the first ``min_data_in_leaf`` value unless you reconstruct the Dataset object
// desc = **Note**: setting this to ``false`` may slow down the training
bool feature_pre_filter = true;
// alias = subsample_for_bin
// check = >0
// desc = number of data points sampled to construct histogram bins
......
......@@ -465,8 +465,6 @@ class Dataset {
LIGHTGBM_EXPORT bool GetIntField(const char* field_name, data_size_t* out_len, const int** out_ptr);
LIGHTGBM_EXPORT bool GetInt8Field(const char* field_name, data_size_t* out_len, const int8_t** out_ptr);
/*!
* \brief Save current dataset into binary file, will save to "filename.bin"
*/
......@@ -524,35 +522,6 @@ class Dataset {
return feature_groups_[group]->bin_mappers_[sub_feature]->num_bin();
}
inline int8_t FeatureMonotone(int i) const {
if (monotone_types_.empty()) {
return 0;
} else {
return monotone_types_[i];
}
}
inline double FeaturePenalte(int i) const {
if (feature_penalty_.empty()) {
return 1;
} else {
return feature_penalty_[i];
}
}
bool HasMonotone() const {
if (monotone_types_.empty()) {
return false;
} else {
for (size_t i = 0; i < monotone_types_.size(); ++i) {
if (monotone_types_[i] != 0) {
return true;
}
}
return false;
}
}
inline int FeatureGroupNumBin(int group) const {
return feature_groups_[group]->num_total_bin_;
}
......@@ -660,8 +629,6 @@ class Dataset {
return bufs;
}
void ResetConfig(const char* parameters);
/*! \brief Get Number of data */
inline data_size_t num_data() const { return num_data_; }
......@@ -699,8 +666,6 @@ class Dataset {
std::vector<uint64_t> group_bin_boundaries_;
std::vector<int> group_feature_start_;
std::vector<int> group_feature_cnt_;
std::vector<int8_t> monotone_types_;
std::vector<double> feature_penalty_;
bool is_finish_load_;
int max_bin_;
std::vector<int32_t> max_bin_by_feature_;
......
......@@ -175,13 +175,13 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_DatasetGetField_R(LGBM_SE handle,
LGBM_SE call_state);
/*!
* \brief Raise errors for attempts to update dataset parameters
* \param old_params Current dataset parameters
* \param new_params New dataset parameters
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT LGBM_SE LGBM_DatasetUpdateParamChecking_R(LGBM_SE old_params,
LGBM_SE new_params,
LGBM_SE call_state);
/*!
......
......@@ -109,14 +109,6 @@ def cint32_array_to_numpy(cptr, length):
raise RuntimeError('Expected int pointer')
def cint8_array_to_numpy(cptr, length):
"""Convert a ctypes int pointer array to a numpy array."""
if isinstance(cptr, ctypes.POINTER(ctypes.c_int8)):
return np.fromiter(cptr, dtype=np.int8, count=length)
else:
raise RuntimeError('Expected int pointer')
def c_str(string):
"""Convert a Python string to C string."""
return ctypes.c_char_p(string.encode('utf-8'))
......@@ -170,24 +162,46 @@ class LightGBMError(Exception):
class _ConfigAliases(object):
aliases = {"boosting": {"boosting",
aliases = {"bin_construct_sample_cnt": {"bin_construct_sample_cnt",
"subsample_for_bin"},
"boosting": {"boosting",
"boosting_type",
"boost"},
"categorical_feature": {"categorical_feature",
"cat_feature",
"categorical_column",
"cat_column"},
"data_random_seed": {"data_random_seed",
"data_seed"},
"early_stopping_round": {"early_stopping_round",
"early_stopping_rounds",
"early_stopping",
"n_iter_no_change"},
"enable_bundle": {"enable_bundle",
"is_enable_bundle",
"bundle"},
"eval_at": {"eval_at",
"ndcg_eval_at",
"ndcg_at",
"map_eval_at",
"map_at"},
"group_column": {"group_column",
"group",
"group_id",
"query_column",
"query",
"query_id"},
"header": {"header",
"has_header"},
"ignore_column": {"ignore_column",
"ignore_feature",
"blacklist"},
"is_enable_sparse": {"is_enable_sparse",
"is_sparse",
"enable_sparse",
"sparse"},
"label_column": {"label_column",
"label"},
"machines": {"machines",
"workers",
"nodes"},
......@@ -209,14 +223,21 @@ class _ConfigAliases(object):
"objective_type",
"app",
"application"},
"pre_partition": {"pre_partition",
"is_pre_partition"},
"two_round": {"two_round",
"two_round_loading",
"use_two_round_loading"},
"verbosity": {"verbosity",
"verbose"}}
"verbose"},
"weight_column": {"weight_column",
"weight"}}
@classmethod
def get(cls, *args):
ret = set()
for i in args:
ret |= cls.aliases.get(i, {i})
return ret
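The {i} fallback means a queried key with no registered alias group now resolves to itself instead of vanishing from the result. A standalone illustration of just that lookup logic (toy alias table, not the full one):
aliases = {"two_round": {"two_round", "two_round_loading", "use_two_round_loading"}}

def get_aliases(*names):
    ret = set()
    for name in names:
        ret |= aliases.get(name, {name})  # unknown keys fall back to themselves
    return ret

print(get_aliases("two_round"))         # all three registered aliases
print(get_aliases("feature_fraction"))  # {'feature_fraction'}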
......@@ -227,7 +248,6 @@ C_API_DTYPE_FLOAT32 = 0
C_API_DTYPE_FLOAT64 = 1
C_API_DTYPE_INT32 = 2
C_API_DTYPE_INT64 = 3
C_API_DTYPE_INT8 = 4
"""Matrix is row major in Python"""
C_API_IS_ROW_MAJOR = 1
......@@ -242,9 +262,7 @@ C_API_PREDICT_CONTRIB = 3
FIELD_TYPE_MAPPER = {"label": C_API_DTYPE_FLOAT32,
"weight": C_API_DTYPE_FLOAT32,
"init_score": C_API_DTYPE_FLOAT64,
"group": C_API_DTYPE_INT32,
"feature_penalty": C_API_DTYPE_FLOAT64,
"monotone_constraints": C_API_DTYPE_INT8}
"group": C_API_DTYPE_INT32}
def convert_from_sliced_object(data):
......@@ -779,6 +797,37 @@ class Dataset(object):
except AttributeError:
pass
def get_params(self):
"""Get the used parameters in the Dataset.
Returns
-------
params : dict or None
The used parameters in this Dataset object.
"""
if self.params is not None:
# no min_data, nthreads and verbose in this function
dataset_params = _ConfigAliases.get("bin_construct_sample_cnt",
"categorical_feature",
"data_random_seed",
"enable_bundle",
"feature_pre_filter",
"forcedbins_filename",
"group_column",
"header",
"ignore_column",
"is_enable_sparse",
"label_column",
"max_bin",
"max_bin_by_feature",
"min_data_in_bin",
"pre_partition",
"two_round",
"use_missing",
"weight_column",
"zero_as_missing")
return {k: v for k, v in self.params.items() if k in dataset_params}
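A short usage sketch of the new method (synthetic data; note that non-Dataset keys such as verbose are filtered out):
import numpy as np
import lightgbm as lgb

X = np.random.rand(100, 5)
y = np.random.rand(100)
ds = lgb.Dataset(X, label=y, params={"max_bin": 63, "verbose": -1})
print(ds.get_params())  # {'max_bin': 63}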
def _free_handle(self):
if self.handle is not None:
_safe_call(_LIB.LGBM_DatasetFree(self.handle))
......@@ -867,6 +916,7 @@ class Dataset(object):
params['categorical_column'] = sorted(categorical_indices)
params_str = param_dict_to_str(params)
self.params = params
# process for reference dataset
ref_dataset = None
if isinstance(reference, Dataset):
......@@ -1172,20 +1222,34 @@ class Dataset(object):
return self
def _update_params(self, params):
params = copy.deepcopy(params)
def update():
if not self.params:
self.params = params
else:
self.params_back_up = copy.deepcopy(self.params)
self.params.update(params)
if self.handle is None:
update()
elif params is not None:
ret = _LIB.LGBM_DatasetUpdateParamChecking(
c_str(param_dict_to_str(self.params)),
c_str(param_dict_to_str(params)))
if ret != 0:
# could be updated if data is not freed
if self.data is not None:
update()
self._free_handle()
else:
raise LightGBMError(decode_string(_LIB.LGBM_GetLastError()))
return self
def _reverse_update_params(self):
if self.handle is None:
self.params = copy.deepcopy(self.params_back_up)
self.params_back_up = None
return self
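The user-visible effect of this checking path, in a hedged sketch with synthetic data: when the raw data is still held (free_raw_data=False), a conflicting parameter frees the handle so the Dataset is silently reconstructed; once the raw data has been freed, the same conflict raises LightGBMError instead.
import numpy as np
import lightgbm as lgb

X = np.random.rand(100, 5)
y = np.random.rand(100)
ds = lgb.Dataset(X, label=y, params={"max_bin": 63}, free_raw_data=False).construct()
# max_bin conflicts with the constructed handle; the raw data was kept, so the
# handle is freed and rebuilt with max_bin=127 on the next construct()
lgb.train({"objective": "regression", "max_bin": 127, "verbose": -1}, ds)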
def set_field(self, field_name, data):
......@@ -1271,8 +1335,6 @@ class Dataset(object):
return cfloat32_array_to_numpy(ctypes.cast(ret, ctypes.POINTER(ctypes.c_float)), tmp_out_len.value)
elif out_type.value == C_API_DTYPE_FLOAT64:
return cfloat64_array_to_numpy(ctypes.cast(ret, ctypes.POINTER(ctypes.c_double)), tmp_out_len.value)
elif out_type.value == C_API_DTYPE_INT8:
return cint8_array_to_numpy(ctypes.cast(ret, ctypes.POINTER(ctypes.c_int8)), tmp_out_len.value)
else:
raise TypeError("Unknown type")
......@@ -1481,30 +1543,6 @@ class Dataset(object):
self.weight = self.get_field('weight')
return self.weight
def get_feature_penalty(self):
"""Get the feature penalty of the Dataset.
Returns
-------
feature_penalty : numpy array or None
Feature penalty for each feature in the Dataset.
"""
if self.feature_penalty is None:
self.feature_penalty = self.get_field('feature_penalty')
return self.feature_penalty
def get_monotone_constraints(self):
"""Get the monotone constraints of the Dataset.
Returns
-------
monotone_constraints : numpy array or None
Monotone constraints: -1, 0 or 1, for each feature in the Dataset.
"""
if self.monotone_constraints is None:
self.monotone_constraints = self.get_field('monotone_constraints')
return self.monotone_constraints
def get_init_score(self):
"""Get the initial score of the Dataset.
......@@ -1699,7 +1737,6 @@ class Booster(object):
if not isinstance(train_set, Dataset):
raise TypeError('Training data should be Dataset instance, met {}'
.format(type(train_set).__name__))
# set network if necessary
for alias in _ConfigAliases.get("machines"):
if alias in params:
......@@ -1717,9 +1754,13 @@ class Booster(object):
num_machines=params.get("num_machines", num_machines))
break
# construct booster object
train_set.construct()
# copy the parameters from train_set
params.update(train_set.get_params())
params_str = param_dict_to_str(params)
self.handle = ctypes.c_void_p()
_safe_call(_LIB.LGBM_BoosterCreate(
train_set.handle,
c_str(params_str),
ctypes.byref(self.handle)))
# save reference to data
......
......@@ -520,16 +520,17 @@ def cv(params, train_set, num_boost_round=100,
predictor = init_model._to_predictor(dict(init_model.params, **params))
else:
predictor = None
if metrics is not None:
for metric_alias in _ConfigAliases.get("metric"):
params.pop(metric_alias, None)
params['metric'] = metrics
train_set._update_params(params) \
._set_predictor(predictor) \
.set_feature_name(feature_name) \
.set_categorical_feature(categorical_feature)
results = collections.defaultdict(list)
cvfolds = _make_n_folds(train_set, folds=folds, nfold=nfold,
params=params, seed=seed, fpreproc=fpreproc,
......
......@@ -46,6 +46,12 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective
const std::vector<const Metric*>& training_metrics) {
CHECK(train_data != nullptr);
train_data_ = train_data;
if (!config->monotone_constraints.empty()) {
CHECK(static_cast<size_t>(train_data_->num_total_features()) == config->monotone_constraints.size());
}
if (!config->feature_contri.empty()) {
CHECK(static_cast<size_t>(train_data_->num_total_features()) == config->feature_contri.size());
}
iter_ = 0;
num_iteration_for_pred_ = 0;
max_feature_idx_ = 0;
......@@ -717,6 +723,12 @@ void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction*
void GBDT::ResetConfig(const Config* config) {
auto new_config = std::unique_ptr<Config>(new Config(*config));
if (!config->monotone_constraints.empty()) {
CHECK(static_cast<size_t>(train_data_->num_total_features()) == config->monotone_constraints.size());
}
if (!config->feature_contri.empty()) {
CHECK(static_cast<size_t>(train_data_->num_total_features()) == config->feature_contri.size());
}
early_stopping_round_ = new_config->early_stopping_round;
shrinkage_rate_ = new_config->learning_rate;
if (tree_learner_ != nullptr) {
......
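These size checks surface in Python as an error whenever a constraint vector does not match the feature count; a hedged sketch of the correct shape:
import numpy as np
import lightgbm as lgb

X = np.random.rand(100, 5)
y = np.random.rand(100)
ds = lgb.Dataset(X, label=y)
params = {"objective": "regression",
          "monotone_constraints": [1, -1, 0, 0, 1],  # exactly one entry per feature
          "verbose": -1}
booster = lgb.train(params, ds)  # a 4-entry vector here would fail the CHECK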
......@@ -173,6 +173,114 @@ class Booster {
}
}
static void CheckDatasetResetConfig(
const Config& old_config,
const std::unordered_map<std::string, std::string>& new_param) {
Config new_config;
new_config.Set(new_param);
if (new_param.count("data_random_seed") &&
new_config.data_random_seed != old_config.data_random_seed) {
Log::Fatal("Cannot change data_random_seed after constructed Dataset handle.");
}
if (new_param.count("max_bin") &&
new_config.max_bin != old_config.max_bin) {
Log::Fatal("Cannot change max_bin after constructed Dataset handle.");
}
if (new_param.count("max_bin_by_feature") &&
new_config.max_bin_by_feature != old_config.max_bin_by_feature) {
Log::Fatal(
"Cannot change max_bin_by_feature after constructed Dataset handle.");
}
if (new_param.count("bin_construct_sample_cnt") &&
new_config.bin_construct_sample_cnt !=
old_config.bin_construct_sample_cnt) {
Log::Fatal(
"Cannot change bin_construct_sample_cnt after constructed Dataset "
"handle.");
}
if (new_param.count("min_data_in_bin") &&
new_config.min_data_in_bin != old_config.min_data_in_bin) {
Log::Fatal(
"Cannot change min_data_in_bin after constructed Dataset handle.");
}
if (new_param.count("use_missing") &&
new_config.use_missing != old_config.use_missing) {
Log::Fatal("Cannot change use_missing after constructed Dataset handle.");
}
if (new_param.count("zero_as_missing") &&
new_config.zero_as_missing != old_config.zero_as_missing) {
Log::Fatal(
"Cannot change zero_as_missing after constructed Dataset handle.");
}
if (new_param.count("categorical_feature") &&
new_config.categorical_feature != old_config.categorical_feature) {
Log::Fatal(
"Cannot change categorical_feature after constructed Dataset "
"handle.");
}
if (new_param.count("feature_pre_filter") &&
new_config.feature_pre_filter != old_config.feature_pre_filter) {
Log::Fatal(
"Cannot change feature_pre_filter after constructed Dataset handle.");
}
if (new_param.count("is_enable_sparse") &&
new_config.is_enable_sparse != old_config.is_enable_sparse) {
Log::Fatal(
"Cannot change is_enable_sparse after constructed Dataset handle.");
}
if (new_param.count("pre_partition") &&
new_config.pre_partition != old_config.pre_partition) {
Log::Fatal(
"Cannot change pre_partition after constructed Dataset handle.");
}
if (new_param.count("enable_bundle") &&
new_config.enable_bundle != old_config.enable_bundle) {
Log::Fatal(
"Cannot change enable_bundle after constructed Dataset handle.");
}
if (new_param.count("header") && new_config.header != old_config.header) {
Log::Fatal("Cannot change header after constructed Dataset handle.");
}
if (new_param.count("two_round") &&
new_config.two_round != old_config.two_round) {
Log::Fatal("Cannot change two_round after constructed Dataset handle.");
}
if (new_param.count("label_column") &&
new_config.label_column != old_config.label_column) {
Log::Fatal(
"Cannot change label_column after constructed Dataset handle.");
}
if (new_param.count("weight_column") &&
new_config.weight_column != old_config.weight_column) {
Log::Fatal(
"Cannot change weight_column after constructed Dataset handle.");
}
if (new_param.count("group_column") &&
new_config.group_column != old_config.group_column) {
Log::Fatal(
"Cannot change group_column after constructed Dataset handle.");
}
if (new_param.count("ignore_column") &&
new_config.ignore_column != old_config.ignore_column) {
Log::Fatal(
"Cannot change ignore_column after constructed Dataset handle.");
}
if (new_param.count("forcedbins_filename")) {
Log::Fatal("Cannot change forced bins after constructed Dataset handle.");
}
if (new_param.count("min_data_in_leaf") &&
new_config.min_data_in_leaf < old_config.min_data_in_leaf &&
old_config.feature_pre_filter) {
Log::Fatal(
"Reducing `min_data_in_leaf` with `feature_pre_filter=true` may "
"cause unexpected behaviour "
"for features that were pre-filtered by the larger "
"`min_data_in_leaf`.\n"
"You need to set `feature_pre_filter=false` to dynamically change "
"the `min_data_in_leaf`.");
}
}
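End to end, the min_data_in_leaf guard behaves like this from Python (hedged sketch, synthetic data): the first training run fixes the pre-filtering, and a later, smaller value is rejected with a pointer to feature_pre_filter=false.
import numpy as np
import lightgbm as lgb

X = np.random.rand(500, 10)
y = np.random.rand(500)
ds = lgb.Dataset(X, label=y)  # feature_pre_filter defaults to true
lgb.train({"objective": "regression", "min_data_in_leaf": 100, "verbose": -1}, ds)
try:
    lgb.train({"objective": "regression", "min_data_in_leaf": 5, "verbose": -1}, ds)
except lgb.basic.LightGBMError as err:
    print(err)  # suggests setting feature_pre_filter=false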
void ResetConfig(const char* parameters) {
std::lock_guard<std::mutex> lock(mutex_);
auto param = Config::Str2Map(parameters);
......@@ -186,7 +294,10 @@ class Booster {
Log::Fatal("Cannot change metric during training");
}
CheckDatasetResetConfig(config_, param);
config_.Set(param);
if (config_.num_threads > 0) {
omp_set_num_threads(config_.num_threads);
}
......@@ -1028,19 +1139,19 @@ int LGBM_DatasetGetField(DatasetHandle handle,
} else if (dataset->GetDoubleField(field_name, out_len, reinterpret_cast<const double**>(out_ptr))) {
*out_type = C_API_DTYPE_FLOAT64;
is_success = true;
} else if (dataset->GetInt8Field(field_name, out_len, reinterpret_cast<const int8_t**>(out_ptr))) {
*out_type = C_API_DTYPE_INT8;
is_success = true;
}
}
if (!is_success) { throw std::runtime_error("Field not found"); }
if (*out_ptr == nullptr) { *out_len = 0; }
API_END();
}
int LGBM_DatasetUpdateParamChecking(const char* old_parameters, const char* new_parameters) {
API_BEGIN();
auto old_param = Config::Str2Map(old_parameters);
Config old_config;
old_config.Set(old_param);
auto new_param = Config::Str2Map(new_parameters);
Booster::CheckDatasetResetConfig(old_config, new_param);
API_END();
}
......
......@@ -324,7 +324,7 @@ namespace LightGBM {
}
void BinMapper::FindBin(double* values, int num_sample_values, size_t total_sample_cnt,
int max_bin, int min_data_in_bin, int min_split_data, bool pre_filter, BinType bin_type,
bool use_missing, bool zero_as_missing,
const std::vector<double>& forced_upper_bounds) {
int na_cnt = 0;
......@@ -504,7 +504,7 @@ namespace LightGBM {
is_trivial_ = false;
}
// check useless bin
if (!is_trivial_ && pre_filter && NeedFilter(cnt_in_bin, static_cast<int>(total_sample_cnt), min_split_data, bin_type_)) {
is_trivial_ = true;
}
......
......@@ -234,6 +234,7 @@ const std::unordered_set<std::string>& Config::parameter_set() {
"is_enable_sparse",
"max_bin_by_feature",
"min_data_in_bin",
"feature_pre_filter",
"bin_construct_sample_cnt",
"histogram_pool_size",
"data_random_seed",
......@@ -460,6 +461,8 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
GetInt(params, "min_data_in_bin", &min_data_in_bin);
CHECK(min_data_in_bin >0);
GetBool(params, "feature_pre_filter", &feature_pre_filter);
GetInt(params, "bin_construct_sample_cnt", &bin_construct_sample_cnt);
CHECK(bin_construct_sample_cnt >0);
......@@ -657,6 +660,7 @@ std::string Config::SaveMembersToString() const {
str_buf << "[is_enable_sparse: " << is_enable_sparse << "]\n";
str_buf << "[max_bin_by_feature: " << Common::Join(max_bin_by_feature, ",") << "]\n";
str_buf << "[min_data_in_bin: " << min_data_in_bin << "]\n";
str_buf << "[feature_pre_filter: " << feature_pre_filter << "]\n";
str_buf << "[bin_construct_sample_cnt: " << bin_construct_sample_cnt << "]\n";
str_buf << "[histogram_pool_size: " << histogram_pool_size << "]\n";
str_buf << "[data_random_seed: " << data_random_seed << "]\n";
......
......@@ -377,33 +377,6 @@ void Dataset::Construct(
last_group = group;
}
}
if (!io_config.monotone_constraints.empty()) {
CHECK(static_cast<size_t>(num_total_features_) == io_config.monotone_constraints.size());
monotone_types_.resize(num_features_);
for (int i = 0; i < num_total_features_; ++i) {
int inner_fidx = InnerFeatureIndex(i);
if (inner_fidx >= 0) {
monotone_types_[inner_fidx] = io_config.monotone_constraints[i];
}
}
if (ArrayArgs<int8_t>::CheckAllZero(monotone_types_)) {
monotone_types_.clear();
}
}
if (!io_config.feature_contri.empty()) {
CHECK(static_cast<size_t>(num_total_features_) == io_config.feature_contri.size());
feature_penalty_.resize(num_features_);
for (int i = 0; i < num_total_features_; ++i) {
int inner_fidx = InnerFeatureIndex(i);
if (inner_fidx >= 0) {
feature_penalty_[inner_fidx] = std::max(0.0, io_config.feature_contri[i]);
}
}
if (ArrayArgs<double>::CheckAll(feature_penalty_, 1.0)) {
feature_penalty_.clear();
}
}
if (!io_config.max_bin_by_feature.empty()) {
CHECK(static_cast<size_t>(num_total_features_) == io_config.max_bin_by_feature.size());
CHECK(*(std::min_element(io_config.max_bin_by_feature.begin(), io_config.max_bin_by_feature.end())) > 1);
......@@ -418,60 +391,6 @@ void Dataset::Construct(
zero_as_missing_ = io_config.zero_as_missing;
}
void Dataset::ResetConfig(const char* parameters) {
auto param = Config::Str2Map(parameters);
Config io_config;
io_config.Set(param);
if (param.count("max_bin") && io_config.max_bin != max_bin_) {
Log::Warning("Cannot change max_bin after constructed Dataset handle.");
}
if (param.count("max_bin_by_feature") && io_config.max_bin_by_feature != max_bin_by_feature_) {
Log::Warning("Cannot change max_bin_by_feature after constructed Dataset handle.");
}
if (param.count("bin_construct_sample_cnt") && io_config.bin_construct_sample_cnt != bin_construct_sample_cnt_) {
Log::Warning("Cannot change bin_construct_sample_cnt after constructed Dataset handle.");
}
if (param.count("min_data_in_bin") && io_config.min_data_in_bin != min_data_in_bin_) {
Log::Warning("Cannot change min_data_in_bin after constructed Dataset handle.");
}
if (param.count("use_missing") && io_config.use_missing != use_missing_) {
Log::Warning("Cannot change use_missing after constructed Dataset handle.");
}
if (param.count("zero_as_missing") && io_config.zero_as_missing != zero_as_missing_) {
Log::Warning("Cannot change zero_as_missing after constructed Dataset handle.");
}
if (param.count("forcedbins_filename")) {
Log::Warning("Cannot change forced bins after constructed Dataset handle.");
}
if (!io_config.monotone_constraints.empty()) {
CHECK(static_cast<size_t>(num_total_features_) == io_config.monotone_constraints.size());
monotone_types_.resize(num_features_);
for (int i = 0; i < num_total_features_; ++i) {
int inner_fidx = InnerFeatureIndex(i);
if (inner_fidx >= 0) {
monotone_types_[inner_fidx] = io_config.monotone_constraints[i];
}
}
if (ArrayArgs<int8_t>::CheckAllZero(monotone_types_)) {
monotone_types_.clear();
}
}
if (!io_config.feature_contri.empty()) {
CHECK(static_cast<size_t>(num_total_features_) == io_config.feature_contri.size());
feature_penalty_.resize(num_features_);
for (int i = 0; i < num_total_features_; ++i) {
int inner_fidx = InnerFeatureIndex(i);
if (inner_fidx >= 0) {
feature_penalty_[inner_fidx] = std::max(0.0, io_config.feature_contri[i]);
}
}
if (ArrayArgs<double>::CheckAll(feature_penalty_, 1.0)) {
feature_penalty_.clear();
}
}
}
void Dataset::FinishLoad() {
if (is_finish_load_) { return; }
if (num_groups_ > 0) {
......@@ -764,8 +683,6 @@ void Dataset::CopyFeatureMapperFrom(const Dataset* dataset) {
group_bin_boundaries_ = dataset->group_bin_boundaries_;
group_feature_start_ = dataset->group_feature_start_;
group_feature_cnt_ = dataset->group_feature_cnt_;
monotone_types_ = dataset->monotone_types_;
feature_penalty_ = dataset->feature_penalty_;
forced_bin_bounds_ = dataset->forced_bin_bounds_;
feature_need_push_zeros_ = dataset->feature_need_push_zeros_;
}
......@@ -817,8 +734,6 @@ void Dataset::CreateValid(const Dataset* dataset) {
last_group = group;
}
}
monotone_types_ = dataset->monotone_types_;
feature_penalty_ = dataset->feature_penalty_;
forced_bin_bounds_ = dataset->forced_bin_bounds_;
}
......@@ -924,9 +839,6 @@ bool Dataset::GetDoubleField(const char* field_name, data_size_t* out_len, const
if (name == std::string("init_score")) {
*out_ptr = metadata_.init_score();
*out_len = static_cast<data_size_t>(metadata_.num_init_score());
} else if (name == std::string("feature_penalty")) {
*out_ptr = feature_penalty_.data();
*out_len = static_cast<data_size_t>(feature_penalty_.size());
} else {
return false;
}
......@@ -945,18 +857,6 @@ bool Dataset::GetIntField(const char* field_name, data_size_t* out_len, const in
return true;
}
bool Dataset::GetInt8Field(const char* field_name, data_size_t* out_len, const int8_t** out_ptr) {
std::string name(field_name);
name = Common::Trim(name);
if (name == std::string("monotone_constraints")) {
*out_ptr = monotone_types_.data();
*out_len = static_cast<data_size_t>(monotone_types_.size());
} else {
return false;
}
return true;
}
void Dataset::SaveBinaryFile(const char* bin_filename) {
if (bin_filename != nullptr
&& std::string(bin_filename) == data_filename_) {
......@@ -987,8 +887,8 @@ void Dataset::SaveBinaryFile(const char* bin_filename) {
// get size of header
size_t size_of_header = sizeof(num_data_) + sizeof(num_features_) + sizeof(num_total_features_)
+ sizeof(int) * num_total_features_ + sizeof(label_idx_) + sizeof(num_groups_)
+ 3 * sizeof(int) * num_features_ + sizeof(uint64_t) * (num_groups_ + 1) + 2 * sizeof(int) * num_groups_
+ sizeof(int32_t) * num_total_features_ + sizeof(int) * 3 + sizeof(bool) * 2;
// size of feature names
for (int i = 0; i < num_total_features_; ++i) {
size_of_header += feature_names_[i].size() + sizeof(int);
......@@ -1016,20 +916,6 @@ void Dataset::SaveBinaryFile(const char* bin_filename) {
writer->Write(group_bin_boundaries_.data(), sizeof(uint64_t) * (num_groups_ + 1));
writer->Write(group_feature_start_.data(), sizeof(int) * num_groups_);
writer->Write(group_feature_cnt_.data(), sizeof(int) * num_groups_);
if (monotone_types_.empty()) {
ArrayArgs<int8_t>::Assign(&monotone_types_, 0, num_features_);
}
writer->Write(monotone_types_.data(), sizeof(int8_t) * num_features_);
if (ArrayArgs<int8_t>::CheckAllZero(monotone_types_)) {
monotone_types_.clear();
}
if (feature_penalty_.empty()) {
ArrayArgs<double>::Assign(&feature_penalty_, 1.0, num_features_);
}
writer->Write(feature_penalty_.data(), sizeof(double) * num_features_);
if (ArrayArgs<double>::CheckAll(feature_penalty_, 1.0)) {
feature_penalty_.clear();
}
if (max_bin_by_feature_.empty()) {
ArrayArgs<int32_t>::Assign(&max_bin_by_feature_, -1, num_total_features_);
}
......@@ -1086,14 +972,6 @@ void Dataset::DumpTextFile(const char* text_filename) {
for (auto n : feature_names_) {
fprintf(file, "%s, ", n.c_str());
}
fprintf(file, "\nmonotone_constraints: ");
for (auto i : monotone_types_) {
fprintf(file, "%d, ", i);
}
fprintf(file, "\nfeature_penalty: ");
for (auto i : feature_penalty_) {
fprintf(file, "%lf, ", i);
}
fprintf(file, "\nmax_bin_by_feature: ");
for (auto i : max_bin_by_feature_) {
fprintf(file, "%d, ", i);
......@@ -1595,9 +1473,6 @@ void Dataset::AddFeaturesFrom(Dataset* other) {
group_bin_boundaries_.push_back(*i + bin_offset);
}
PushOffset(&group_feature_start_, other->group_feature_start_, num_features_);
PushClearIfEmpty(&monotone_types_, num_total_features_, other->monotone_types_, other->num_total_features_, (int8_t)0);
PushClearIfEmpty(&feature_penalty_, num_total_features_, other->feature_penalty_, other->num_total_features_, 1.0);
PushClearIfEmpty(&max_bin_by_feature_, num_total_features_, other->max_bin_by_feature_, other->num_total_features_, -1);
num_features_ += other->num_features_;
......
......@@ -389,50 +389,6 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
}
mem_ptr += sizeof(int) * (dataset->num_groups_);
if (!config_.monotone_constraints.empty()) {
CHECK(static_cast<size_t>(dataset->num_total_features_) == config_.monotone_constraints.size());
dataset->monotone_types_.resize(dataset->num_features_);
for (int i = 0; i < dataset->num_total_features_; ++i) {
int inner_fidx = dataset->InnerFeatureIndex(i);
if (inner_fidx >= 0) {
dataset->monotone_types_[inner_fidx] = config_.monotone_constraints[i];
}
}
} else {
const int8_t* tmp_ptr_monotone_type = reinterpret_cast<const int8_t*>(mem_ptr);
dataset->monotone_types_.clear();
for (int i = 0; i < dataset->num_features_; ++i) {
dataset->monotone_types_.push_back(tmp_ptr_monotone_type[i]);
}
}
mem_ptr += sizeof(int8_t) * (dataset->num_features_);
if (ArrayArgs<int8_t>::CheckAllZero(dataset->monotone_types_)) {
dataset->monotone_types_.clear();
}
if (!config_.feature_contri.empty()) {
CHECK(static_cast<size_t>(dataset->num_total_features_) == config_.feature_contri.size());
dataset->feature_penalty_.resize(dataset->num_features_);
for (int i = 0; i < dataset->num_total_features_; ++i) {
int inner_fidx = dataset->InnerFeatureIndex(i);
if (inner_fidx >= 0) {
dataset->feature_penalty_[inner_fidx] = config_.feature_contri[i];
}
}
} else {
const double* tmp_ptr_feature_penalty = reinterpret_cast<const double*>(mem_ptr);
dataset->feature_penalty_.clear();
for (int i = 0; i < dataset->num_features_; ++i) {
dataset->feature_penalty_.push_back(tmp_ptr_feature_penalty[i]);
}
}
mem_ptr += sizeof(double) * (dataset->num_features_);
if (ArrayArgs<double>::CheckAll(dataset->feature_penalty_, 1)) {
dataset->feature_penalty_.clear();
}
if (!config_.max_bin_by_feature.empty()) {
CHECK(static_cast<size_t>(dataset->num_total_features_) == config_.max_bin_by_feature.size());
CHECK(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())) > 1);
......@@ -617,13 +573,13 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values,
bin_mappers[i].reset(new BinMapper());
if (config_.max_bin_by_feature.empty()) {
bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
config_.max_bin, config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter,
bin_type, config_.use_missing, config_.zero_as_missing,
forced_bin_bounds[i]);
} else {
bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
config_.max_bin_by_feature[i], config_.min_data_in_bin,
filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing,
config_.zero_as_missing, forced_bin_bounds[i]);
}
OMP_LOOP_EX_END();
......@@ -665,12 +621,12 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values,
if (config_.max_bin_by_feature.empty()) {
bin_mappers[i]->FindBin(sample_values[start[rank] + i], num_per_col[start[rank] + i],
total_sample_size, config_.max_bin, config_.min_data_in_bin,
filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing, config_.zero_as_missing,
forced_bin_bounds[i]);
} else {
bin_mappers[i]->FindBin(sample_values[start[rank] + i], num_per_col[start[rank] + i],
total_sample_size, config_.max_bin_by_feature[start[rank] + i],
config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing,
config_.zero_as_missing, forced_bin_bounds[i]);
}
OMP_LOOP_EX_END();
......@@ -943,12 +899,12 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
if (config_.max_bin_by_feature.empty()) {
bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
sample_data.size(), config_.max_bin, config_.min_data_in_bin,
filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing, config_.zero_as_missing,
forced_bin_bounds[i]);
} else {
bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
sample_data.size(), config_.max_bin_by_feature[i],
config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing,
config_.zero_as_missing, forced_bin_bounds[i]);
}
OMP_LOOP_EX_END();
......@@ -987,13 +943,13 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
static_cast<int>(sample_values[start[rank] + i].size()),
sample_data.size(), config_.max_bin, config_.min_data_in_bin,
filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing, config_.zero_as_missing,
forced_bin_bounds[i]);
} else {
bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
static_cast<int>(sample_values[start[rank] + i].size()),
sample_data.size(), config_.max_bin_by_feature[i],
config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter, bin_type,
config_.use_missing, config_.zero_as_missing, forced_bin_bounds[i]);
}
OMP_LOOP_EX_END();
......