Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
b6f65783
Commit
b6f65783
authored
Jun 04, 2019
by
Nikita Titov
Committed by
Qiwei Ye
Jun 04, 2019
Browse files
[python] fix class_weight (#2199)
* fixed class_weight * fixed lint * added test * hotfix
parent
7d03ced3
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
55 additions
and
12 deletions
+55
-12
python-package/lightgbm/sklearn.py
python-package/lightgbm/sklearn.py
+20
-12
tests/python_package_test/test_sklearn.py
tests/python_package_test/test_sklearn.py
+35
-0
No files found.
python-package/lightgbm/sklearn.py
View file @
b6f65783
...
@@ -10,7 +10,7 @@ from .compat import (SKLEARN_INSTALLED, _LGBMClassifierBase,
...
@@ -10,7 +10,7 @@ from .compat import (SKLEARN_INSTALLED, _LGBMClassifierBase,
LGBMNotFittedError
,
_LGBMLabelEncoder
,
_LGBMModelBase
,
LGBMNotFittedError
,
_LGBMLabelEncoder
,
_LGBMModelBase
,
_LGBMRegressorBase
,
_LGBMCheckXY
,
_LGBMCheckArray
,
_LGBMCheckConsistentLength
,
_LGBMRegressorBase
,
_LGBMCheckXY
,
_LGBMCheckArray
,
_LGBMCheckConsistentLength
,
_LGBMAssertAllFinite
,
_LGBMCheckClassificationTargets
,
_LGBMComputeSampleWeight
,
_LGBMAssertAllFinite
,
_LGBMCheckClassificationTargets
,
_LGBMComputeSampleWeight
,
argc_
,
range_
,
string_type
,
DataFrame
,
DataTable
)
argc_
,
range_
,
zip_
,
string_type
,
DataFrame
,
DataTable
)
from
.engine
import
train
from
.engine
import
train
...
@@ -320,6 +320,8 @@ class LGBMModel(_LGBMModelBase):
...
@@ -320,6 +320,8 @@ class LGBMModel(_LGBMModelBase):
self
.
_other_params
=
{}
self
.
_other_params
=
{}
self
.
_objective
=
objective
self
.
_objective
=
objective
self
.
class_weight
=
class_weight
self
.
class_weight
=
class_weight
self
.
_class_weight
=
None
self
.
_class_map
=
None
self
.
_n_features
=
None
self
.
_n_features
=
None
self
.
_classes
=
None
self
.
_classes
=
None
self
.
_n_classes
=
None
self
.
_n_classes
=
None
...
@@ -529,8 +531,10 @@ class LGBMModel(_LGBMModelBase):
...
@@ -529,8 +531,10 @@ class LGBMModel(_LGBMModelBase):
else
:
else
:
_X
,
_y
=
X
,
y
_X
,
_y
=
X
,
y
if
self
.
class_weight
is
not
None
:
if
self
.
_class_weight
is
None
:
class_sample_weight
=
_LGBMComputeSampleWeight
(
self
.
class_weight
,
y
)
self
.
_class_weight
=
self
.
class_weight
if
self
.
_class_weight
is
not
None
:
class_sample_weight
=
_LGBMComputeSampleWeight
(
self
.
_class_weight
,
y
)
if
sample_weight
is
None
or
len
(
sample_weight
)
==
0
:
if
sample_weight
is
None
or
len
(
sample_weight
)
==
0
:
sample_weight
=
class_sample_weight
sample_weight
=
class_sample_weight
else
:
else
:
...
@@ -547,7 +551,7 @@ class LGBMModel(_LGBMModelBase):
...
@@ -547,7 +551,7 @@ class LGBMModel(_LGBMModelBase):
valid_sets
=
[]
valid_sets
=
[]
if
eval_set
is
not
None
:
if
eval_set
is
not
None
:
def
_get_meta_data
(
collection
,
i
):
def
_get_meta_data
(
collection
,
name
,
i
):
if
collection
is
None
:
if
collection
is
None
:
return
None
return
None
elif
isinstance
(
collection
,
list
):
elif
isinstance
(
collection
,
list
):
...
@@ -555,8 +559,7 @@ class LGBMModel(_LGBMModelBase):
...
@@ -555,8 +559,7 @@ class LGBMModel(_LGBMModelBase):
elif
isinstance
(
collection
,
dict
):
elif
isinstance
(
collection
,
dict
):
return
collection
.
get
(
i
,
None
)
return
collection
.
get
(
i
,
None
)
else
:
else
:
raise
TypeError
(
'eval_sample_weight, eval_class_weight, eval_init_score, and eval_group '
raise
TypeError
(
'{} should be dict or list'
.
format
(
name
))
'should be dict or list'
)
if
isinstance
(
eval_set
,
tuple
):
if
isinstance
(
eval_set
,
tuple
):
eval_set
=
[
eval_set
]
eval_set
=
[
eval_set
]
...
@@ -565,16 +568,18 @@ class LGBMModel(_LGBMModelBase):
...
@@ -565,16 +568,18 @@ class LGBMModel(_LGBMModelBase):
if
valid_data
[
0
]
is
X
and
valid_data
[
1
]
is
y
:
if
valid_data
[
0
]
is
X
and
valid_data
[
1
]
is
y
:
valid_set
=
train_set
valid_set
=
train_set
else
:
else
:
valid_weight
=
_get_meta_data
(
eval_sample_weight
,
i
)
valid_weight
=
_get_meta_data
(
eval_sample_weight
,
'eval_sample_weight'
,
i
)
if
_get_meta_data
(
eval_class_weight
,
i
)
is
not
None
:
valid_class_weight
=
_get_meta_data
(
eval_class_weight
,
'eval_class_weight'
,
i
)
valid_class_sample_weight
=
_LGBMComputeSampleWeight
(
_get_meta_data
(
eval_class_weight
,
i
),
if
valid_class_weight
is
not
None
:
valid_data
[
1
])
if
isinstance
(
valid_class_weight
,
dict
)
and
self
.
_class_map
is
not
None
:
valid_class_weight
=
{
self
.
_class_map
[
k
]:
v
for
k
,
v
in
valid_class_weight
.
items
()}
valid_class_sample_weight
=
_LGBMComputeSampleWeight
(
valid_class_weight
,
valid_data
[
1
])
if
valid_weight
is
None
or
len
(
valid_weight
)
==
0
:
if
valid_weight
is
None
or
len
(
valid_weight
)
==
0
:
valid_weight
=
valid_class_sample_weight
valid_weight
=
valid_class_sample_weight
else
:
else
:
valid_weight
=
np
.
multiply
(
valid_weight
,
valid_class_sample_weight
)
valid_weight
=
np
.
multiply
(
valid_weight
,
valid_class_sample_weight
)
valid_init_score
=
_get_meta_data
(
eval_init_score
,
i
)
valid_init_score
=
_get_meta_data
(
eval_init_score
,
'eval_init_score'
,
i
)
valid_group
=
_get_meta_data
(
eval_group
,
i
)
valid_group
=
_get_meta_data
(
eval_group
,
'eval_group'
,
i
)
valid_set
=
_construct_dataset
(
valid_data
[
0
],
valid_data
[
1
],
valid_set
=
_construct_dataset
(
valid_data
[
0
],
valid_data
[
1
],
valid_weight
,
valid_init_score
,
valid_group
,
params
)
valid_weight
,
valid_init_score
,
valid_group
,
params
)
valid_sets
.
append
(
valid_set
)
valid_sets
.
append
(
valid_set
)
...
@@ -750,6 +755,9 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase):
...
@@ -750,6 +755,9 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase):
_LGBMCheckClassificationTargets
(
y
)
_LGBMCheckClassificationTargets
(
y
)
self
.
_le
=
_LGBMLabelEncoder
().
fit
(
y
)
self
.
_le
=
_LGBMLabelEncoder
().
fit
(
y
)
_y
=
self
.
_le
.
transform
(
y
)
_y
=
self
.
_le
.
transform
(
y
)
self
.
_class_map
=
dict
(
zip_
(
self
.
_le
.
classes_
,
self
.
_le
.
transform
(
self
.
_le
.
classes_
)))
if
isinstance
(
self
.
class_weight
,
dict
):
self
.
_class_weight
=
{
self
.
_class_map
[
k
]:
v
for
k
,
v
in
self
.
class_weight
.
items
()}
self
.
_classes
=
self
.
_le
.
classes_
self
.
_classes
=
self
.
_le
.
classes_
self
.
_n_classes
=
len
(
self
.
_classes
)
self
.
_n_classes
=
len
(
self
.
_classes
)
...
...
tests/python_package_test/test_sklearn.py
View file @
b6f65783
# coding: utf-8
# coding: utf-8
# pylint: skip-file
# pylint: skip-file
import
itertools
import
math
import
math
import
os
import
os
import
unittest
import
unittest
...
@@ -615,3 +616,37 @@ class TestSklearn(unittest.TestCase):
...
@@ -615,3 +616,37 @@ class TestSklearn(unittest.TestCase):
'verbose'
:
False
,
'early_stopping_rounds'
:
5
}
'verbose'
:
False
,
'early_stopping_rounds'
:
5
}
gbm
=
lgb
.
LGBMRegressor
(
**
params
).
fit
(
**
params_fit
)
gbm
=
lgb
.
LGBMRegressor
(
**
params
).
fit
(
**
params_fit
)
np
.
testing
.
assert_array_equal
(
gbm
.
evals_result_
[
'training'
][
'l2'
],
np
.
nan
)
np
.
testing
.
assert_array_equal
(
gbm
.
evals_result_
[
'training'
][
'l2'
],
np
.
nan
)
def test_class_weight(self):
    """Check that class weights change evaluation results and that string
    labels behave the same as integer labels.

    Two classifiers are fitted on the digits data, one with integer targets
    and one with the same targets cast to ``str``. Each is evaluated on five
    eval sets that differ only in data and/or ``eval_class_weight``. The test
    asserts that:

    * within one model, every pair of eval sets yields different metric
      histories;
    * across the two models, the metric histories of each eval set match.
    """
    def assert_all_eval_sets_differ(evals_result):
        # ``assert_allclose`` raising AssertionError means the two metric
        # histories are NOT close, which is what is expected for eval sets
        # that use different data or different class weights.
        for first, second in itertools.combinations(evals_result.keys(), 2):
            for metric in evals_result[first]:
                np.testing.assert_raises(AssertionError,
                                         np.testing.assert_allclose,
                                         evals_result[first][metric],
                                         evals_result[second][metric])

    # Keyword arguments: positional ``n_class`` / ``return_X_y`` are rejected
    # by newer scikit-learn releases, where they are keyword-only.
    X, y = load_digits(n_class=10, return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    y_train_str = y_train.astype('str')
    y_test_str = y_test.astype('str')

    # Model fitted on integer labels; dict class weights are keyed by int.
    gbm = lgb.LGBMClassifier(n_estimators=10, class_weight='balanced', silent=True)
    gbm.fit(X_train, y_train,
            eval_set=[(X_train, y_train),
                      (X_test, y_test),
                      (X_test, y_test),
                      (X_test, y_test),
                      (X_test, y_test)],
            eval_class_weight=['balanced',
                               None,
                               'balanced',
                               {1: 10, 4: 20},
                               {5: 30, 2: 40}],
            verbose=False)
    assert_all_eval_sets_differ(gbm.evals_result_)

    # Same setup with string labels; dict class weights are keyed by the
    # string form of the same classes.
    gbm_str = lgb.LGBMClassifier(n_estimators=10, class_weight='balanced', silent=True)
    gbm_str.fit(X_train, y_train_str,
                eval_set=[(X_train, y_train_str),
                          (X_test, y_test_str),
                          (X_test, y_test_str),
                          (X_test, y_test_str),
                          (X_test, y_test_str)],
                eval_class_weight=['balanced',
                                   None,
                                   'balanced',
                                   {'1': 10, '4': 20},
                                   {'5': 30, '2': 40}],
                verbose=False)
    assert_all_eval_sets_differ(gbm_str.evals_result_)

    # The label dtype must not influence the outcome: every metric history
    # of the integer-label model has to match the string-label model.
    for eval_set in gbm.evals_result_:
        for metric in gbm.evals_result_[eval_set]:
            np.testing.assert_allclose(gbm.evals_result_[eval_set][metric],
                                       gbm_str.evals_result_[eval_set][metric])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment