Unverified Commit 18dbd65e authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] consolidate pandas-to-numpy conversion code (#6156)

parent e63e54ac
......@@ -758,6 +758,23 @@ def _check_for_bad_pandas_dtypes(pandas_dtypes_series: pd_Series) -> None:
f'Fields with bad pandas dtypes: {", ".join(bad_pandas_dtypes)}')
def _pandas_to_numpy(
data: pd_DataFrame,
target_dtype: "np.typing.DTypeLike"
) -> np.ndarray:
_check_for_bad_pandas_dtypes(data.dtypes)
try:
# most common case (no nullable dtypes)
return data.to_numpy(dtype=target_dtype, copy=False)
except TypeError:
# 1.0 <= pd version < 1.1 and nullable dtypes, least common case
# raises error because array is casted to type(pd.NA) and there's no na_value argument
return data.astype(target_dtype, copy=False).values
except ValueError:
# data has nullable dtypes, but we can specify na_value argument and copy will be made
return data.to_numpy(dtype=target_dtype, na_value=np.nan)
def _data_from_pandas(
data: pd_DataFrame,
feature_name: _LGBM_FeatureNameConfiguration,
......@@ -790,22 +807,17 @@ def _data_from_pandas(
else: # use cat cols specified by user
categorical_feature = list(categorical_feature) # type: ignore[assignment]
# get numpy representation of the data
_check_for_bad_pandas_dtypes(data.dtypes)
df_dtypes = [dtype.type for dtype in data.dtypes]
df_dtypes.append(np.float32) # so that the target dtype considers floats
# so that the target dtype considers floats
df_dtypes.append(np.float32)
target_dtype = np.result_type(*df_dtypes)
try:
# most common case (no nullable dtypes)
data = data.to_numpy(dtype=target_dtype, copy=False)
except TypeError:
# 1.0 <= pd version < 1.1 and nullable dtypes, least common case
# raises error because array is casted to type(pd.NA) and there's no na_value argument
data = data.astype(target_dtype, copy=False).values
except ValueError:
# data has nullable dtypes, but we can specify na_value argument and copy will be made
data = data.to_numpy(dtype=target_dtype, na_value=np.nan)
return data, feature_name, categorical_feature, pandas_categorical
return (
_pandas_to_numpy(data, target_dtype=target_dtype),
feature_name,
categorical_feature,
pandas_categorical
)
def _dump_pandas_categorical(
......@@ -2805,18 +2817,7 @@ class Dataset:
if isinstance(label, pd_DataFrame):
if len(label.columns) > 1:
raise ValueError('DataFrame for label cannot have multiple columns')
_check_for_bad_pandas_dtypes(label.dtypes)
try:
# most common case (no nullable dtypes)
label = label.to_numpy(dtype=np.float32, copy=False)
except TypeError:
# 1.0 <= pd version < 1.1 and nullable dtypes, least common case
# raises error because array is casted to type(pd.NA) and there's no na_value argument
label = label.astype(np.float32, copy=False).values
except ValueError:
# data has nullable dtypes, but we can specify na_value argument and copy will be made
label = label.to_numpy(dtype=np.float32, na_value=np.nan)
label_array = np.ravel(label)
label_array = np.ravel(_pandas_to_numpy(label, target_dtype=np.float32))
elif _is_pyarrow_array(label):
label_array = label
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment