Commit b6f65783 authored by Nikita Titov's avatar Nikita Titov Committed by Qiwei Ye
Browse files

[python] fix class_weight (#2199)

* fixed class_weight

* fixed lint

* added test

* hotfix
parent 7d03ced3
...@@ -10,7 +10,7 @@ from .compat import (SKLEARN_INSTALLED, _LGBMClassifierBase, ...@@ -10,7 +10,7 @@ from .compat import (SKLEARN_INSTALLED, _LGBMClassifierBase,
LGBMNotFittedError, _LGBMLabelEncoder, _LGBMModelBase, LGBMNotFittedError, _LGBMLabelEncoder, _LGBMModelBase,
_LGBMRegressorBase, _LGBMCheckXY, _LGBMCheckArray, _LGBMCheckConsistentLength, _LGBMRegressorBase, _LGBMCheckXY, _LGBMCheckArray, _LGBMCheckConsistentLength,
_LGBMAssertAllFinite, _LGBMCheckClassificationTargets, _LGBMComputeSampleWeight, _LGBMAssertAllFinite, _LGBMCheckClassificationTargets, _LGBMComputeSampleWeight,
argc_, range_, string_type, DataFrame, DataTable) argc_, range_, zip_, string_type, DataFrame, DataTable)
from .engine import train from .engine import train
...@@ -320,6 +320,8 @@ class LGBMModel(_LGBMModelBase): ...@@ -320,6 +320,8 @@ class LGBMModel(_LGBMModelBase):
self._other_params = {} self._other_params = {}
self._objective = objective self._objective = objective
self.class_weight = class_weight self.class_weight = class_weight
self._class_weight = None
self._class_map = None
self._n_features = None self._n_features = None
self._classes = None self._classes = None
self._n_classes = None self._n_classes = None
...@@ -529,8 +531,10 @@ class LGBMModel(_LGBMModelBase): ...@@ -529,8 +531,10 @@ class LGBMModel(_LGBMModelBase):
else: else:
_X, _y = X, y _X, _y = X, y
if self.class_weight is not None: if self._class_weight is None:
class_sample_weight = _LGBMComputeSampleWeight(self.class_weight, y) self._class_weight = self.class_weight
if self._class_weight is not None:
class_sample_weight = _LGBMComputeSampleWeight(self._class_weight, y)
if sample_weight is None or len(sample_weight) == 0: if sample_weight is None or len(sample_weight) == 0:
sample_weight = class_sample_weight sample_weight = class_sample_weight
else: else:
...@@ -547,7 +551,7 @@ class LGBMModel(_LGBMModelBase): ...@@ -547,7 +551,7 @@ class LGBMModel(_LGBMModelBase):
valid_sets = [] valid_sets = []
if eval_set is not None: if eval_set is not None:
def _get_meta_data(collection, i): def _get_meta_data(collection, name, i):
if collection is None: if collection is None:
return None return None
elif isinstance(collection, list): elif isinstance(collection, list):
...@@ -555,8 +559,7 @@ class LGBMModel(_LGBMModelBase): ...@@ -555,8 +559,7 @@ class LGBMModel(_LGBMModelBase):
elif isinstance(collection, dict): elif isinstance(collection, dict):
return collection.get(i, None) return collection.get(i, None)
else: else:
raise TypeError('eval_sample_weight, eval_class_weight, eval_init_score, and eval_group ' raise TypeError('{} should be dict or list'.format(name))
'should be dict or list')
if isinstance(eval_set, tuple): if isinstance(eval_set, tuple):
eval_set = [eval_set] eval_set = [eval_set]
...@@ -565,16 +568,18 @@ class LGBMModel(_LGBMModelBase): ...@@ -565,16 +568,18 @@ class LGBMModel(_LGBMModelBase):
if valid_data[0] is X and valid_data[1] is y: if valid_data[0] is X and valid_data[1] is y:
valid_set = train_set valid_set = train_set
else: else:
valid_weight = _get_meta_data(eval_sample_weight, i) valid_weight = _get_meta_data(eval_sample_weight, 'eval_sample_weight', i)
if _get_meta_data(eval_class_weight, i) is not None: valid_class_weight = _get_meta_data(eval_class_weight, 'eval_class_weight', i)
valid_class_sample_weight = _LGBMComputeSampleWeight(_get_meta_data(eval_class_weight, i), if valid_class_weight is not None:
valid_data[1]) if isinstance(valid_class_weight, dict) and self._class_map is not None:
valid_class_weight = {self._class_map[k]: v for k, v in valid_class_weight.items()}
valid_class_sample_weight = _LGBMComputeSampleWeight(valid_class_weight, valid_data[1])
if valid_weight is None or len(valid_weight) == 0: if valid_weight is None or len(valid_weight) == 0:
valid_weight = valid_class_sample_weight valid_weight = valid_class_sample_weight
else: else:
valid_weight = np.multiply(valid_weight, valid_class_sample_weight) valid_weight = np.multiply(valid_weight, valid_class_sample_weight)
valid_init_score = _get_meta_data(eval_init_score, i) valid_init_score = _get_meta_data(eval_init_score, 'eval_init_score', i)
valid_group = _get_meta_data(eval_group, i) valid_group = _get_meta_data(eval_group, 'eval_group', i)
valid_set = _construct_dataset(valid_data[0], valid_data[1], valid_set = _construct_dataset(valid_data[0], valid_data[1],
valid_weight, valid_init_score, valid_group, params) valid_weight, valid_init_score, valid_group, params)
valid_sets.append(valid_set) valid_sets.append(valid_set)
...@@ -750,6 +755,9 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase): ...@@ -750,6 +755,9 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase):
_LGBMCheckClassificationTargets(y) _LGBMCheckClassificationTargets(y)
self._le = _LGBMLabelEncoder().fit(y) self._le = _LGBMLabelEncoder().fit(y)
_y = self._le.transform(y) _y = self._le.transform(y)
self._class_map = dict(zip_(self._le.classes_, self._le.transform(self._le.classes_)))
if isinstance(self.class_weight, dict):
self._class_weight = {self._class_map[k]: v for k, v in self.class_weight.items()}
self._classes = self._le.classes_ self._classes = self._le.classes_
self._n_classes = len(self._classes) self._n_classes = len(self._classes)
......
# coding: utf-8 # coding: utf-8
# pylint: skip-file # pylint: skip-file
import itertools
import math import math
import os import os
import unittest import unittest
...@@ -615,3 +616,37 @@ class TestSklearn(unittest.TestCase): ...@@ -615,3 +616,37 @@ class TestSklearn(unittest.TestCase):
'verbose': False, 'early_stopping_rounds': 5} 'verbose': False, 'early_stopping_rounds': 5}
gbm = lgb.LGBMRegressor(**params).fit(**params_fit) gbm = lgb.LGBMRegressor(**params).fit(**params_fit)
np.testing.assert_array_equal(gbm.evals_result_['training']['l2'], np.nan) np.testing.assert_array_equal(gbm.evals_result_['training']['l2'], np.nan)
def test_class_weight(self):
    """Check that class_weight influences training and that integer and
    string labels yield identical results under equivalent weightings."""
    X, y = load_digits(10, True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    # String-typed copies of the targets to exercise the label-encoding path.
    y_train_str = y_train.astype('str')
    y_test_str = y_test.astype('str')

    def check_all_pairs_differ(evals_result):
        # Each eval set was fitted with a different class weighting, so every
        # pair of metric curves is expected to differ (assert_allclose raises).
        for name_a, name_b in itertools.combinations(evals_result.keys(), 2):
            for metric in evals_result[name_a]:
                np.testing.assert_raises(AssertionError,
                                         np.testing.assert_allclose,
                                         evals_result[name_a][metric],
                                         evals_result[name_b][metric])

    gbm = lgb.LGBMClassifier(n_estimators=10, class_weight='balanced', silent=True)
    gbm.fit(X_train, y_train,
            eval_set=[(X_train, y_train), (X_test, y_test), (X_test, y_test),
                      (X_test, y_test), (X_test, y_test)],
            eval_class_weight=['balanced', None, 'balanced', {1: 10, 4: 20}, {5: 30, 2: 40}],
            verbose=False)
    check_all_pairs_differ(gbm.evals_result_)

    gbm_str = lgb.LGBMClassifier(n_estimators=10, class_weight='balanced', silent=True)
    gbm_str.fit(X_train, y_train_str,
                eval_set=[(X_train, y_train_str), (X_test, y_test_str),
                          (X_test, y_test_str), (X_test, y_test_str), (X_test, y_test_str)],
                eval_class_weight=['balanced', None, 'balanced', {'1': 10, '4': 20}, {'5': 30, '2': 40}],
                verbose=False)
    check_all_pairs_differ(gbm_str.evals_result_)

    # With matching weights, integer-labeled and string-labeled runs must agree.
    for eval_name in gbm.evals_result_:
        for metric in gbm.evals_result_[eval_name]:
            np.testing.assert_allclose(gbm.evals_result_[eval_name][metric],
                                       gbm_str.evals_result_[eval_name][metric])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment