from ..utils import get_logger from .import_utils import is_kernels_available logger = get_logger(__name__) _DEFAULT_HUB_ID_FA3 = "kernels-community/flash-attn3" def _get_fa3_from_hub(): if not is_kernels_available(): return None else: from kernels import get_kernel try: # TODO: temporary revision for now. Remove when merged upstream into `main`. flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3, revision="fake-ops-return-probs") return flash_attn_3_hub except Exception as e: logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}") raise