import sys as _sys

from cupy._core import ndarray as _ndarray
from cupyx.scipy.sparse._base import spmatrix as _spmatrix


# SciPy is an optional dependency: remember whether it could be imported so
# that the CPU fallback below can fail with a clear message instead of a
# NameError on the unbound ``_scipy`` name.
try:
    import scipy as _scipy
    _scipy_available = True
except ImportError:
    _scipy_available = False


# Handle to this very module, returned when GPU arguments are detected.
_cupyx_scipy = _sys.modules[__name__]


def get_array_module(*args):
    """Returns the array module for arguments.

    This function is used to implement CPU/GPU generic code. If at least one
    of the arguments is a :class:`cupy.ndarray` object or a
    :class:`cupyx.scipy.sparse.spmatrix` object, the :mod:`cupyx.scipy`
    module is returned. Otherwise (including when no arguments are given)
    the :mod:`scipy` module is returned.

    Args:
        args: Values to determine whether SciPy or CuPy's SciPy-compatible
            module should be used.

    Returns:
        module: :mod:`cupyx.scipy` or :mod:`scipy` is returned based on the
        types of the arguments.

    Raises:
        ImportError: If none of the arguments is a GPU array or sparse
            matrix and SciPy is not installed.

    """
    if any(isinstance(arg, (_ndarray, _spmatrix)) for arg in args):
        return _cupyx_scipy
    if not _scipy_available:
        # ``_scipy`` was never bound because the optional import failed;
        # raise a descriptive error rather than an opaque NameError.
        raise ImportError(
            'SciPy is not installed; cupyx.scipy.get_array_module cannot '
            'return the scipy module for non-GPU arguments.')
    return _scipy