Unverified Commit 2db0b25e authored by Malte Londschien's avatar Malte Londschien Committed by GitHub
Browse files

[python-package] Separately check whether `pyarrow` and `cffi` are installed (#6785)


Co-authored-by: default avatarJames Lamb <jaylamb20@gmail.com>
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
parent c9de57b0
......@@ -27,6 +27,7 @@ import numpy as np
import scipy.sparse
from .compat import (
CFFI_INSTALLED,
PANDAS_INSTALLED,
PYARROW_INSTALLED,
arrow_cffi,
......@@ -1706,8 +1707,8 @@ class _InnerPredictor:
predict_type: int,
) -> Tuple[np.ndarray, int]:
"""Predict for a PyArrow table."""
if not PYARROW_INSTALLED:
raise LightGBMError("Cannot predict from Arrow without `pyarrow` installed.")
if not (PYARROW_INSTALLED and CFFI_INSTALLED):
raise LightGBMError("Cannot predict from Arrow without 'pyarrow' and 'cffi' installed.")
# Check that the input is valid: we only handle numbers (for now)
if not all(arrow_is_integer(t) or arrow_is_floating(t) or arrow_is_boolean(t) for t in table.schema.types):
......@@ -2458,8 +2459,8 @@ class Dataset:
ref_dataset: Optional[_DatasetHandle],
) -> "Dataset":
"""Initialize data from a PyArrow table."""
if not PYARROW_INSTALLED:
raise LightGBMError("Cannot init dataframe from Arrow without `pyarrow` installed.")
if not (PYARROW_INSTALLED and CFFI_INSTALLED):
raise LightGBMError("Cannot init Dataset from Arrow without 'pyarrow' and 'cffi' installed.")
# Check that the input is valid: we only handle numbers (for now)
if not all(arrow_is_integer(t) or arrow_is_floating(t) or arrow_is_boolean(t) for t in table.schema.types):
......
......@@ -289,7 +289,6 @@ try:
from pyarrow import ChunkedArray as pa_ChunkedArray
from pyarrow import Table as pa_Table
from pyarrow import chunked_array as pa_chunked_array
from pyarrow.cffi import ffi as arrow_cffi
from pyarrow.types import is_boolean as arrow_is_boolean
from pyarrow.types import is_floating as arrow_is_floating
from pyarrow.types import is_integer as arrow_is_integer
......@@ -316,19 +315,8 @@ except ImportError:
def __init__(self, *args: Any, **kwargs: Any):
pass
class arrow_cffi: # type: ignore
"""Dummy class for pyarrow.cffi.ffi."""
CData = None
addressof = None
cast = None
new = None
def __init__(self, *args: Any, **kwargs: Any):
pass
class pa_compute: # type: ignore
"""Dummy class for pyarrow.compute."""
"""Dummy class for pyarrow.compute module."""
all = None
equal = None
......@@ -338,6 +326,24 @@ except ImportError:
arrow_is_integer = None
arrow_is_floating = None
"""cffi"""
try:
from pyarrow.cffi import ffi as arrow_cffi
CFFI_INSTALLED = True
except ImportError:
CFFI_INSTALLED = False
class arrow_cffi: # type: ignore
"""Dummy class for pyarrow.cffi.ffi."""
CData = None
def __init__(self, *args: Any, **kwargs: Any):
pass
"""cpu_count()"""
try:
from joblib import cpu_count
......
import numpy as np
import pytest
import lightgbm
@pytest.fixture(scope="function")
def missing_module_cffi(monkeypatch):
"""Mock 'cffi' not being importable"""
monkeypatch.setattr(lightgbm.compat, "CFFI_INSTALLED", False)
monkeypatch.setattr(lightgbm.basic, "CFFI_INSTALLED", False)
@pytest.fixture(scope="function")
def rng():
......
......@@ -454,3 +454,32 @@ def test_arrow_feature_name_manual():
)
booster = lgb.train({"num_leaves": 7}, dataset, num_boost_round=5)
assert booster.feature_name() == ["c", "d"]
def test_dataset_construction_from_pa_table_without_cffi_raises_informative_error(missing_module_cffi):
with pytest.raises(
lgb.basic.LightGBMError, match="Cannot init Dataset from Arrow without 'pyarrow' and 'cffi' installed."
):
lgb.Dataset(
generate_dummy_arrow_table(),
label=pa.array([0, 1, 0, 0, 1]),
params=dummy_dataset_params(),
).construct()
def test_predicting_from_pa_table_without_cffi_raises_informative_error(missing_module_cffi):
data = generate_random_arrow_table(num_columns=3, num_datapoints=1_000, seed=42)
labels = generate_random_arrow_array(num_datapoints=data.shape[0], seed=42)
bst = lgb.train(
params={"num_leaves": 7, "verbose": -1},
train_set=lgb.Dataset(
data.to_pandas(),
label=labels.to_pandas(),
),
num_boost_round=2,
)
with pytest.raises(
lgb.basic.LightGBMError, match="Cannot predict from Arrow without 'pyarrow' and 'cffi' installed."
):
bst.predict(data)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment