"vscode:/vscode.git/clone" did not exist on "da0eb2bb07c83504058261128a8f8fbaf28ad2f7"
Unverified Commit 3f5b3135 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

[python] avoid to set all weight to 1. (#1152)

* avoid to set all weight to 1.

* fix valid_class_weight
parent 42710827
...@@ -1105,6 +1105,8 @@ class Dataset(object): ...@@ -1105,6 +1105,8 @@ class Dataset(object):
weight : list, numpy array or None weight : list, numpy array or None
Weight to be set for each data point. Weight to be set for each data point.
""" """
if weight is not None and np.all(weight == 1):
weight = None
self.weight = weight self.weight = weight
if self.handle is not None and weight is not None: if self.handle is not None and weight is not None:
weight = list_to_1d_numpy(weight, name='weight') weight = list_to_1d_numpy(weight, name='weight')
......
...@@ -418,6 +418,7 @@ class LGBMModel(_LGBMModelBase): ...@@ -418,6 +418,7 @@ class LGBMModel(_LGBMModelBase):
X, y = _LGBMCheckXY(X, y, accept_sparse=True, force_all_finite=False, ensure_min_samples=2) X, y = _LGBMCheckXY(X, y, accept_sparse=True, force_all_finite=False, ensure_min_samples=2)
_LGBMCheckConsistentLength(X, y, sample_weight) _LGBMCheckConsistentLength(X, y, sample_weight)
if self.class_weight is not None:
class_sample_weight = _LGBMComputeSampleWeight(self.class_weight, y) class_sample_weight = _LGBMComputeSampleWeight(self.class_weight, y)
if sample_weight is None or len(sample_weight) == 0: if sample_weight is None or len(sample_weight) == 0:
sample_weight = class_sample_weight sample_weight = class_sample_weight
...@@ -452,6 +453,7 @@ class LGBMModel(_LGBMModelBase): ...@@ -452,6 +453,7 @@ class LGBMModel(_LGBMModelBase):
else: else:
raise TypeError('eval_sample_weight, eval_class_weight, eval_init_score, and eval_group should be dict or list') raise TypeError('eval_sample_weight, eval_class_weight, eval_init_score, and eval_group should be dict or list')
valid_weight = get_meta_data(eval_sample_weight, i) valid_weight = get_meta_data(eval_sample_weight, i)
if get_meta_data(eval_class_weight, i) is not None:
valid_class_sample_weight = _LGBMComputeSampleWeight(get_meta_data(eval_class_weight, i), valid_data[1]) valid_class_sample_weight = _LGBMComputeSampleWeight(get_meta_data(eval_class_weight, i), valid_data[1])
if valid_weight is None or len(valid_weight) == 0: if valid_weight is None or len(valid_weight) == 0:
valid_weight = valid_class_sample_weight valid_weight = valid_class_sample_weight
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment