Commit 87f2acb1 authored by Nikita Titov's avatar Nikita Titov Committed by Guolin Ke
Browse files

[python] updated multiclass objective in sklearn (#1218)

* added comment to not forget

* updated LGBMClassifier according to new aliases
parent 1e61f24f
...@@ -183,7 +183,7 @@ def convert_from_sliced_object(data): ...@@ -183,7 +183,7 @@ def convert_from_sliced_object(data):
"""fix the memory of multi-dimensional sliced object""" """fix the memory of multi-dimensional sliced object"""
if data.base is not None and isinstance(data, np.ndarray) and isinstance(data.base, np.ndarray): if data.base is not None and isinstance(data, np.ndarray) and isinstance(data.base, np.ndarray):
if not data.flags.c_contiguous: if not data.flags.c_contiguous:
warnings.warn("Use subset(sliced data) of np.ndarray is not recommended due to it will double the peak memory cost in LightGBM.") warnings.warn("Usage subset(sliced data) of np.ndarray is not recommended due to it will double the peak memory cost in LightGBM.")
return np.copy(data) return np.copy(data)
return data return data
...@@ -206,7 +206,7 @@ def c_float_array(data): ...@@ -206,7 +206,7 @@ def c_float_array(data):
.format(data.dtype)) .format(data.dtype))
else: else:
raise TypeError("Unknown type({})".format(type(data).__name__)) raise TypeError("Unknown type({})".format(type(data).__name__))
return (ptr_data, type_data, data) return (ptr_data, type_data, data) # return `data` to avoid the temporary copy is freed
def c_int_array(data): def c_int_array(data):
...@@ -227,7 +227,7 @@ def c_int_array(data): ...@@ -227,7 +227,7 @@ def c_int_array(data):
.format(data.dtype)) .format(data.dtype))
else: else:
raise TypeError("Unknown type({})".format(type(data).__name__)) raise TypeError("Unknown type({})".format(type(data).__name__))
return (ptr_data, type_data, data) return (ptr_data, type_data, data) # return `data` to avoid the temporary copy is freed
PANDAS_DTYPE_MAPPER = {'int8': 'int', 'int16': 'int', 'int32': 'int', PANDAS_DTYPE_MAPPER = {'int8': 'int', 'int16': 'int', 'int32': 'int',
......
...@@ -644,7 +644,8 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase): ...@@ -644,7 +644,8 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase):
self._n_classes = len(self._classes) self._n_classes = len(self._classes)
if self._n_classes > 2: if self._n_classes > 2:
# Switch to using a multiclass objective in the underlying LGBM instance # Switch to using a multiclass objective in the underlying LGBM instance
if self._objective != "multiclassova" and not callable(self._objective): ova_aliases = ("multiclassova", "multiclass_ova", "ova", "ovr")
if self._objective not in ova_aliases and not callable(self._objective):
self._objective = "multiclass" self._objective = "multiclass"
if eval_metric == 'logloss' or eval_metric == 'binary_logloss': if eval_metric == 'logloss' or eval_metric == 'binary_logloss':
eval_metric = "multi_logloss" eval_metric = "multi_logloss"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment