"tests/vscode:/vscode.git/clone" did not exist on "b7ccdaf066ad3eebc92616c87b0aeea6acc043f4"
Unverified Commit 6cbb3586 authored by Roman Shaptala's avatar Roman Shaptala Committed by GitHub
Browse files

[python] Faster categorical column names selection (#4787)

* Faster categorical column names selection (#1)

* Faster categorical column names selection

Replace the slow and redundant DataFrame query via select_dtypes with a list comprehension over DataFrame.dtypes

* Update compat with CategoricalDtype

* sort imports

* import CategoricalDtype from pandas.api.types

* add categorical import try/except
parent 3b6ebd79
......@@ -17,7 +17,8 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Un
import numpy as np
import scipy.sparse
from .compat import PANDAS_INSTALLED, concat, dt_DataTable, is_dtype_sparse, pd_DataFrame, pd_Series
from .compat import (PANDAS_INSTALLED, concat, dt_DataTable, is_dtype_sparse, pd_CategoricalDtype, pd_DataFrame,
pd_Series)
from .libpath import find_lib_path
ZERO_THRESHOLD = 1e-35
......@@ -567,7 +568,7 @@ def _data_from_pandas(data, feature_name, categorical_feature, pandas_categorica
raise ValueError('Input data must be 2 dimensional and non empty.')
if feature_name == 'auto' or feature_name is None:
data = data.rename(columns=str)
cat_cols = list(data.select_dtypes(include=['category']).columns)
cat_cols = [col for col, dtype in zip(data.columns, data.dtypes) if isinstance(dtype, pd_CategoricalDtype)]
cat_cols_not_ordered = [col for col in cat_cols if not data[col].cat.ordered]
if pandas_categorical is None: # train dataset
pandas_categorical = [list(data[col].cat.categories) for col in cat_cols]
......
......@@ -7,6 +7,10 @@ try:
from pandas import Series as pd_Series
from pandas import concat
from pandas.api.types import is_sparse as is_dtype_sparse
try:
from pandas import CategoricalDtype as pd_CategoricalDtype
except ImportError:
from pandas.api.types import CategoricalDtype as pd_CategoricalDtype
PANDAS_INSTALLED = True
except ImportError:
PANDAS_INSTALLED = False
......@@ -23,6 +27,12 @@ except ImportError:
def __init__(self, *args, **kwargs):
pass
# Fallback stand-in defined in the ``except ImportError`` branch, i.e. when the
# pandas imports above fail. It keeps the name ``pd_CategoricalDtype``
# importable so that ``isinstance(dtype, pd_CategoricalDtype)`` checks can still
# be evaluated without pandas.
class pd_CategoricalDtype:
"""Dummy class for pandas.CategoricalDtype."""
# Accepts any positional/keyword arguments and ignores them.
def __init__(self, *args, **kwargs):
pass
concat = None
is_dtype_sparse = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment