Unverified Commit 3f5b3135 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

[python] avoid to set all weight to 1. (#1152)

* avoid to set all weight to 1.

* fix valid_class_weight
parent 42710827
......@@ -1105,6 +1105,8 @@ class Dataset(object):
weight : list, numpy array or None
Weight to be set for each data point.
"""
if weight is not None and np.all(weight == 1):
weight = None
self.weight = weight
if self.handle is not None and weight is not None:
weight = list_to_1d_numpy(weight, name='weight')
......
......@@ -418,11 +418,12 @@ class LGBMModel(_LGBMModelBase):
X, y = _LGBMCheckXY(X, y, accept_sparse=True, force_all_finite=False, ensure_min_samples=2)
_LGBMCheckConsistentLength(X, y, sample_weight)
class_sample_weight = _LGBMComputeSampleWeight(self.class_weight, y)
if sample_weight is None or len(sample_weight) == 0:
sample_weight = class_sample_weight
else:
sample_weight = np.multiply(sample_weight, class_sample_weight)
if self.class_weight is not None:
class_sample_weight = _LGBMComputeSampleWeight(self.class_weight, y)
if sample_weight is None or len(sample_weight) == 0:
sample_weight = class_sample_weight
else:
sample_weight = np.multiply(sample_weight, class_sample_weight)
self._n_features = X.shape[1]
......@@ -452,11 +453,12 @@ class LGBMModel(_LGBMModelBase):
else:
raise TypeError('eval_sample_weight, eval_class_weight, eval_init_score, and eval_group should be dict or list')
valid_weight = get_meta_data(eval_sample_weight, i)
valid_class_sample_weight = _LGBMComputeSampleWeight(get_meta_data(eval_class_weight, i), valid_data[1])
if valid_weight is None or len(valid_weight) == 0:
valid_weight = valid_class_sample_weight
else:
valid_weight = np.multiply(valid_weight, valid_class_sample_weight)
if get_meta_data(eval_class_weight, i) is not None:
valid_class_sample_weight = _LGBMComputeSampleWeight(get_meta_data(eval_class_weight, i), valid_data[1])
if valid_weight is None or len(valid_weight) == 0:
valid_weight = valid_class_sample_weight
else:
valid_weight = np.multiply(valid_weight, valid_class_sample_weight)
valid_init_score = get_meta_data(eval_init_score, i)
valid_group = get_meta_data(eval_group, i)
valid_set = _construct_dataset(valid_data[0], valid_data[1], valid_weight, valid_init_score, valid_group, params)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment