Unverified Commit 3f5b3135 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

[python] avoid to set all weight to 1. (#1152)

* avoid to set all weight to 1.

* fix valid_class_weight
parent 42710827
...@@ -1105,6 +1105,8 @@ class Dataset(object): ...@@ -1105,6 +1105,8 @@ class Dataset(object):
weight : list, numpy array or None weight : list, numpy array or None
Weight to be set for each data point. Weight to be set for each data point.
""" """
if weight is not None and np.all(weight == 1):
weight = None
self.weight = weight self.weight = weight
if self.handle is not None and weight is not None: if self.handle is not None and weight is not None:
weight = list_to_1d_numpy(weight, name='weight') weight = list_to_1d_numpy(weight, name='weight')
......
...@@ -418,11 +418,12 @@ class LGBMModel(_LGBMModelBase): ...@@ -418,11 +418,12 @@ class LGBMModel(_LGBMModelBase):
X, y = _LGBMCheckXY(X, y, accept_sparse=True, force_all_finite=False, ensure_min_samples=2) X, y = _LGBMCheckXY(X, y, accept_sparse=True, force_all_finite=False, ensure_min_samples=2)
_LGBMCheckConsistentLength(X, y, sample_weight) _LGBMCheckConsistentLength(X, y, sample_weight)
class_sample_weight = _LGBMComputeSampleWeight(self.class_weight, y) if self.class_weight is not None:
if sample_weight is None or len(sample_weight) == 0: class_sample_weight = _LGBMComputeSampleWeight(self.class_weight, y)
sample_weight = class_sample_weight if sample_weight is None or len(sample_weight) == 0:
else: sample_weight = class_sample_weight
sample_weight = np.multiply(sample_weight, class_sample_weight) else:
sample_weight = np.multiply(sample_weight, class_sample_weight)
self._n_features = X.shape[1] self._n_features = X.shape[1]
...@@ -452,11 +453,12 @@ class LGBMModel(_LGBMModelBase): ...@@ -452,11 +453,12 @@ class LGBMModel(_LGBMModelBase):
else: else:
raise TypeError('eval_sample_weight, eval_class_weight, eval_init_score, and eval_group should be dict or list') raise TypeError('eval_sample_weight, eval_class_weight, eval_init_score, and eval_group should be dict or list')
valid_weight = get_meta_data(eval_sample_weight, i) valid_weight = get_meta_data(eval_sample_weight, i)
valid_class_sample_weight = _LGBMComputeSampleWeight(get_meta_data(eval_class_weight, i), valid_data[1]) if get_meta_data(eval_class_weight, i) is not None:
if valid_weight is None or len(valid_weight) == 0: valid_class_sample_weight = _LGBMComputeSampleWeight(get_meta_data(eval_class_weight, i), valid_data[1])
valid_weight = valid_class_sample_weight if valid_weight is None or len(valid_weight) == 0:
else: valid_weight = valid_class_sample_weight
valid_weight = np.multiply(valid_weight, valid_class_sample_weight) else:
valid_weight = np.multiply(valid_weight, valid_class_sample_weight)
valid_init_score = get_meta_data(eval_init_score, i) valid_init_score = get_meta_data(eval_init_score, i)
valid_group = get_meta_data(eval_group, i) valid_group = get_meta_data(eval_group, i)
valid_set = _construct_dataset(valid_data[0], valid_data[1], valid_weight, valid_init_score, valid_group, params) valid_set = _construct_dataset(valid_data[0], valid_data[1], valid_weight, valid_init_score, valid_group, params)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment