"""Package-level exports and import-time environment setup.

Re-exports the public functions defined in ``flash_mla.flash_mla_interface``
(listed in ``__all__``) and, as a side effect of being imported, sets the
``FLASH_MLA_ROOT_DIR`` environment variable for the current process.
"""

import os

__version__ = "1.0.0"

from flash_mla.flash_mla_interface import (
    get_mla_metadata,
    flash_mla_with_kvcache,
    flash_mla_sparse_fwd,
    get_mla_decoding_metadata_dense_fp8,
    flash_mla_with_kvcache_quantization,
    flash_mla_with_kvcache_q_nope_pe,
    flash_mla_with_kvcache_quantization_q_nope_pe,
    flash_mla_with_kvcache_fp8,
    flash_mla_with_kvcache_fp8_with_cat,
)

__all__ = [
    "get_mla_metadata",
    "flash_mla_with_kvcache",
    "flash_mla_sparse_fwd",
    "get_mla_decoding_metadata_dense_fp8",
    "flash_mla_with_kvcache_quantization",
    "flash_mla_with_kvcache_q_nope_pe",
    "flash_mla_with_kvcache_quantization_q_nope_pe",
    "flash_mla_with_kvcache_fp8",
    "flash_mla_with_kvcache_fp8_with_cat",
]

# Absolute path of the directory that contains this file.
FLASH_MLA_ROOT_DIR = os.path.dirname(os.path.abspath(__file__))

# NOTE: the environment variable shares its name with the constant above but
# holds a different value: the "asm/" subdirectory, with a trailing slash.
# The string is built by plain concatenation (not os.path.join) so the value
# stays byte-identical for whatever code reads the variable.
# NOTE(review): this runs after the interface import above, so the variable is
# not yet set while flash_mla_interface itself is being imported -- confirm
# that no import-time code depends on it.
os.environ["FLASH_MLA_ROOT_DIR"] = FLASH_MLA_ROOT_DIR + "/asm/"