"src/vscode:/vscode.git/clone" did not exist on "5172b5337863de04dd4080dae4cb694f7bfacbb9"
Unverified Commit c043be1d authored by Henry Sorsky's avatar Henry Sorsky Committed by GitHub
Browse files

[python-package] Better column dtype logging when column has "bad dtype" (#5065)



* better logging of column datatypes

* update to checking function

* fix typo

* Update python-package/lightgbm/basic.py
Co-authored-by: default avatarJosé Morales <jmoralz92@gmail.com>

* Update python-package/lightgbm/basic.py
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
Co-authored-by: default avatarJosé Morales <jmoralz92@gmail.com>
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
parent f834cab4
......@@ -183,8 +183,7 @@ def list_to_1d_numpy(data, dtype=np.float32, name='list'):
elif is_1d_list(data):
return np.array(data, dtype=dtype, copy=False)
elif isinstance(data, pd_Series):
if _get_bad_pandas_dtypes([data.dtypes]):
raise ValueError('Series.dtypes must be int, float or bool')
_check_for_bad_pandas_dtypes(data.to_frame().dtypes)
return np.array(data, dtype=dtype, copy=False) # SparseArray should be supported as well
else:
raise TypeError(f"Wrong type({type(data).__name__}) for {name}.\n"
......@@ -217,8 +216,7 @@ def _data_to_2d_numpy(data: Any, dtype: type = np.float32, name: str = 'list') -
if _is_2d_list(data):
return np.array(data, dtype=dtype)
if isinstance(data, pd_DataFrame):
if _get_bad_pandas_dtypes(data.dtypes):
raise ValueError('DataFrame.dtypes must be int, float or bool')
_check_for_bad_pandas_dtypes(data.dtypes)
return cast_numpy_array_to_dtype(data.values, dtype)
raise TypeError(f"Wrong type({type(data).__name__}) for {name}.\n"
"It should be list of lists, numpy 2-D array or pandas DataFrame")
......@@ -500,7 +498,7 @@ def c_int_array(data):
return (ptr_data, type_data, data) # return `data` to avoid the temporary copy is freed
def _get_bad_pandas_dtypes(dtypes):
def _check_for_bad_pandas_dtypes(pandas_dtypes_series):
float128 = getattr(np, 'float128', type(None))
def is_allowed_numpy_dtype(dtype):
......@@ -509,7 +507,14 @@ def _get_bad_pandas_dtypes(dtypes):
and not issubclass(dtype, (np.timedelta64, float128))
)
return [i for i, dtype in enumerate(dtypes) if not is_allowed_numpy_dtype(dtype.type)]
bad_pandas_dtypes = [
f'{column_name}: {pandas_dtype}'
for column_name, pandas_dtype in pandas_dtypes_series.iteritems()
if not is_allowed_numpy_dtype(pandas_dtype.type)
]
if bad_pandas_dtypes:
raise ValueError('pandas dtypes must be int, float or bool.\n'
f'Fields with bad pandas dtypes: {", ".join(bad_pandas_dtypes)}')
def _data_from_pandas(data, feature_name, categorical_feature, pandas_categorical):
......@@ -540,12 +545,7 @@ def _data_from_pandas(data, feature_name, categorical_feature, pandas_categorica
categorical_feature = list(categorical_feature)
if feature_name == 'auto':
feature_name = list(data.columns)
bad_indices = _get_bad_pandas_dtypes(data.dtypes)
if bad_indices:
bad_index_cols_str = ', '.join(data.columns[bad_indices])
raise ValueError("DataFrame.dtypes for data must be int, float or bool.\n"
"Did not expect the data types in the following fields: "
f"{bad_index_cols_str}")
_check_for_bad_pandas_dtypes(data.dtypes)
df_dtypes = [dtype.type for dtype in data.dtypes]
df_dtypes.append(np.float32) # so that the target dtype considers floats
target_dtype = np.find_common_type(df_dtypes, [])
......@@ -562,8 +562,7 @@ def _label_from_pandas(label):
if isinstance(label, pd_DataFrame):
if len(label.columns) > 1:
raise ValueError('DataFrame for label cannot have multiple columns')
if _get_bad_pandas_dtypes(label.dtypes):
raise ValueError('DataFrame.dtypes for label must be int, float or bool')
_check_for_bad_pandas_dtypes(label.dtypes)
label = np.ravel(label.values.astype(np.float32, copy=False))
return label
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment