kernels_utils.py 693 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from ..utils import get_logger
from .import_utils import is_kernels_available


# Module-level logger for this file; used to report failures when fetching kernels from the Hub.
logger = get_logger(__name__)


_DEFAULT_HUB_ID_FA3 = "kernels-community/flash-attn3"
# TODO: temporary revision for now. Remove (fall back to `main`) once merged upstream.
_DEFAULT_HUB_REVISION_FA3 = "fake-ops-return-probs"


def _get_fa3_from_hub(repo_id=_DEFAULT_HUB_ID_FA3, revision=_DEFAULT_HUB_REVISION_FA3):
    """Fetch the FlashAttention-3 kernel from the Hugging Face Hub using the `kernels` package.

    Args:
        repo_id: Hub repository to load the kernel from. Defaults to `_DEFAULT_HUB_ID_FA3`.
        revision: Revision of `repo_id` to load; forwarded unchanged as `revision=` to
            `kernels.get_kernel`. Defaults to `_DEFAULT_HUB_REVISION_FA3`.

    Returns:
        Whatever `kernels.get_kernel` returns, or `None` when `is_kernels_available()` is false.

    Raises:
        Exception: any error raised by `kernels.get_kernel` is logged and then re-raised unchanged.
    """
    if not is_kernels_available():
        return None

    # Imported lazily, only after the availability check above has passed.
    from kernels import get_kernel

    try:
        flash_attn_3_hub = get_kernel(repo_id, revision=revision)
    except Exception as e:
        # Add context to the log, then propagate so callers still see the original failure.
        logger.error("An error occurred while fetching kernel '%s' from the Hub: %s", repo_id, e)
        raise
    return flash_attn_3_hub