Unverified Commit 82886ba6 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

[python] [R-package] Use the same address when updated label/weight/query (#2662)

* Update metadata.cpp

* add version for training set, to efficiently update label/weight/... during training.

* Update lgb.Booster.R
parent 350d56d5
...@@ -55,6 +55,7 @@ Booster <- R6::R6Class( ...@@ -55,6 +55,7 @@ Booster <- R6::R6Class(
# Create private booster information # Create private booster information
private$train_set <- train_set private$train_set <- train_set
private$train_set_version <- train_set$.__enclos_env__$private$version
private$num_dataset <- 1L private$num_dataset <- 1L
private$init_predictor <- train_set$.__enclos_env__$private$predictor private$init_predictor <- train_set$.__enclos_env__$private$predictor
...@@ -207,6 +208,12 @@ Booster <- R6::R6Class( ...@@ -207,6 +208,12 @@ Booster <- R6::R6Class(
# Perform boosting update iteration # Perform boosting update iteration
update = function(train_set = NULL, fobj = NULL) { update = function(train_set = NULL, fobj = NULL) {
if (is.null(train_set)) {
if (private$train_set$.__enclos_env__$private$version != private$train_set_version) {
train_set <- private$train_set
}
}
# Check if training set is not null # Check if training set is not null
if (!is.null(train_set)) { if (!is.null(train_set)) {
...@@ -230,6 +237,7 @@ Booster <- R6::R6Class( ...@@ -230,6 +237,7 @@ Booster <- R6::R6Class(
# Store private train set # Store private train set
private$train_set <- train_set private$train_set <- train_set
private$train_set_version <- train_set$.__enclos_env__$private$version
} }
...@@ -497,6 +505,7 @@ Booster <- R6::R6Class( ...@@ -497,6 +505,7 @@ Booster <- R6::R6Class(
eval_names = NULL, eval_names = NULL,
higher_better_inner_eval = NULL, higher_better_inner_eval = NULL,
set_objective_to_none = FALSE, set_objective_to_none = FALSE,
train_set_version = 0L,
# Predict data # Predict data
inner_predict = function(idx) { inner_predict = function(idx) {
......
...@@ -89,6 +89,7 @@ Dataset <- R6::R6Class( ...@@ -89,6 +89,7 @@ Dataset <- R6::R6Class(
private$free_raw_data <- free_raw_data private$free_raw_data <- free_raw_data
private$used_indices <- sort(used_indices, decreasing = FALSE) private$used_indices <- sort(used_indices, decreasing = FALSE)
private$info <- info private$info <- info
private$version <- 0L
}, },
...@@ -503,6 +504,8 @@ Dataset <- R6::R6Class( ...@@ -503,6 +504,8 @@ Dataset <- R6::R6Class(
, length(info) , length(info)
) )
private$version <- private$version + 1L
} }
} }
...@@ -638,6 +641,7 @@ Dataset <- R6::R6Class( ...@@ -638,6 +641,7 @@ Dataset <- R6::R6Class(
free_raw_data = TRUE, free_raw_data = TRUE,
used_indices = NULL, used_indices = NULL,
info = NULL, info = NULL,
version = 0L,
# Get handle # Get handle
get_handle = function() { get_handle = function() {
......
...@@ -771,6 +771,7 @@ class Dataset(object): ...@@ -771,6 +771,7 @@ class Dataset(object):
self.params_back_up = None self.params_back_up = None
self.feature_penalty = None self.feature_penalty = None
self.monotone_constraints = None self.monotone_constraints = None
self.version = 0
def __del__(self): def __del__(self):
try: try:
...@@ -1233,6 +1234,7 @@ class Dataset(object): ...@@ -1233,6 +1234,7 @@ class Dataset(object):
ptr_data, ptr_data,
ctypes.c_int(len(data)), ctypes.c_int(len(data)),
ctypes.c_int(type_data))) ctypes.c_int(type_data)))
self.version += 1
return self return self
def get_field(self, field_name): def get_field(self, field_name):
...@@ -1740,6 +1742,7 @@ class Booster(object): ...@@ -1740,6 +1742,7 @@ class Booster(object):
self.__is_predicted_cur_iter = [False] self.__is_predicted_cur_iter = [False]
self.__get_eval_info() self.__get_eval_info()
self.pandas_categorical = train_set.pandas_categorical self.pandas_categorical = train_set.pandas_categorical
self.train_set_version = train_set.version
elif model_file is not None: elif model_file is not None:
# Prediction task # Prediction task
out_num_iterations = ctypes.c_int(0) out_num_iterations = ctypes.c_int(0)
...@@ -2076,7 +2079,12 @@ class Booster(object): ...@@ -2076,7 +2079,12 @@ class Booster(object):
Whether the update was successfully finished. Whether the update was successfully finished.
""" """
# need reset training data # need reset training data
if train_set is not None and train_set is not self.train_set: if train_set is None and self.train_set_version != self.train_set.version:
train_set = self.train_set
is_the_same_train_set = False
else:
is_the_same_train_set = train_set is self.train_set and self.train_set_version == train_set.version
if train_set is not None and not is_the_same_train_set:
if not isinstance(train_set, Dataset): if not isinstance(train_set, Dataset):
raise TypeError('Training data should be Dataset instance, met {}' raise TypeError('Training data should be Dataset instance, met {}'
.format(type(train_set).__name__)) .format(type(train_set).__name__))
...@@ -2088,6 +2096,7 @@ class Booster(object): ...@@ -2088,6 +2096,7 @@ class Booster(object):
self.handle, self.handle,
self.train_set.construct().handle)) self.train_set.construct().handle))
self.__inner_predict_buffer[0] = None self.__inner_predict_buffer[0] = None
self.train_set_version = self.train_set.version
is_finished = ctypes.c_int(0) is_finished = ctypes.c_int(0)
if fobj is None: if fobj is None:
if self.__set_objective_to_none: if self.__set_objective_to_none:
......
...@@ -290,9 +290,9 @@ void Metadata::SetInitScore(const double* init_score, data_size_t len) { ...@@ -290,9 +290,9 @@ void Metadata::SetInitScore(const double* init_score, data_size_t len) {
if ((len % num_data_) != 0) { if ((len % num_data_) != 0) {
Log::Fatal("Initial score size doesn't match data size"); Log::Fatal("Initial score size doesn't match data size");
} }
if (!init_score_.empty()) { init_score_.clear(); } if (init_score_.empty()) { init_score_.resize(len); }
num_init_score_ = len; num_init_score_ = len;
init_score_ = std::vector<double>(len);
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (int64_t i = 0; i < num_init_score_; ++i) { for (int64_t i = 0; i < num_init_score_; ++i) {
init_score_[i] = Common::AvoidInf(init_score[i]); init_score_[i] = Common::AvoidInf(init_score[i]);
...@@ -308,8 +308,8 @@ void Metadata::SetLabel(const label_t* label, data_size_t len) { ...@@ -308,8 +308,8 @@ void Metadata::SetLabel(const label_t* label, data_size_t len) {
if (num_data_ != len) { if (num_data_ != len) {
Log::Fatal("Length of label is not same with #data"); Log::Fatal("Length of label is not same with #data");
} }
if (!label_.empty()) { label_.clear(); } if (label_.empty()) { label_.resize(num_data_); }
label_ = std::vector<label_t>(num_data_);
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (data_size_t i = 0; i < num_data_; ++i) { for (data_size_t i = 0; i < num_data_; ++i) {
label_[i] = Common::AvoidInf(label[i]); label_[i] = Common::AvoidInf(label[i]);
...@@ -327,9 +327,9 @@ void Metadata::SetWeights(const label_t* weights, data_size_t len) { ...@@ -327,9 +327,9 @@ void Metadata::SetWeights(const label_t* weights, data_size_t len) {
if (num_data_ != len) { if (num_data_ != len) {
Log::Fatal("Length of weights is not same with #data"); Log::Fatal("Length of weights is not same with #data");
} }
if (!weights_.empty()) { weights_.clear(); } if (weights_.empty()) { weights_.resize(num_data_); }
num_weights_ = num_data_; num_weights_ = num_data_;
weights_ = std::vector<label_t>(num_weights_);
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (data_size_t i = 0; i < num_weights_; ++i) { for (data_size_t i = 0; i < num_weights_; ++i) {
weights_[i] = Common::AvoidInf(weights[i]); weights_[i] = Common::AvoidInf(weights[i]);
...@@ -354,9 +354,8 @@ void Metadata::SetQuery(const data_size_t* query, data_size_t len) { ...@@ -354,9 +354,8 @@ void Metadata::SetQuery(const data_size_t* query, data_size_t len) {
if (num_data_ != sum) { if (num_data_ != sum) {
Log::Fatal("Sum of query counts is not same with #data"); Log::Fatal("Sum of query counts is not same with #data");
} }
if (!query_boundaries_.empty()) { query_boundaries_.clear(); }
num_queries_ = len; num_queries_ = len;
query_boundaries_ = std::vector<data_size_t>(num_queries_ + 1); query_boundaries_.resize(num_queries_ + 1);
query_boundaries_[0] = 0; query_boundaries_[0] = 0;
for (data_size_t i = 0; i < num_queries_; ++i) { for (data_size_t i = 0; i < num_queries_; ++i) {
query_boundaries_[i + 1] = query_boundaries_[i] + query[i]; query_boundaries_[i + 1] = query_boundaries_[i] + query[i];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment