__init__.py 901 Bytes
Newer Older
Jiashi Li's avatar
Jiashi Li committed
1
2
3
4
5
__version__ = "1.0.0"

from flash_mla.flash_mla_interface import (
    get_mla_metadata,
    flash_mla_with_kvcache,
zhanghj2's avatar
zhanghj2 committed
6
    flash_mla_sparse_fwd,
zhanghj2's avatar
zhanghj2 committed
7
8
9
10
11
12
    get_mla_decoding_metadata_dense_fp8,
    flash_mla_with_kvcache_quantization,
    flash_mla_with_kvcache_q_nope_pe,
    flash_mla_with_kvcache_quantization_q_nope_pe,
    flash_mla_with_kvcache_fp8,
    flash_mla_with_kvcache_fp8_with_cat
Jiashi Li's avatar
Jiashi Li committed
13
)
14
15
16
17

# Public API of the package: these names are re-exported from
# flash_mla.flash_mla_interface and are what `from flash_mla import *` binds.
# Keep the trailing comma on the last entry so that appending a new name can
# never silently concatenate two adjacent string literals.
__all__ = [
    "get_mla_metadata",
    "flash_mla_with_kvcache",
    "flash_mla_sparse_fwd",
    "get_mla_decoding_metadata_dense_fp8",
    "flash_mla_with_kvcache_quantization",
    "flash_mla_with_kvcache_q_nope_pe",
    "flash_mla_with_kvcache_quantization_q_nope_pe",
    "flash_mla_with_kvcache_fp8",
    "flash_mla_with_kvcache_fp8_with_cat",
]
zhanghj2's avatar
zhanghj2 committed
26
27
28
29
30

import os

# Absolute path of the directory that contains this package's __init__.py.
FLASH_MLA_ROOT_DIR = os.path.dirname(os.path.abspath(__file__))

# Publish "<package dir>/asm/" (with a trailing separator) via the process
# environment. The value is overwritten unconditionally, even if the caller
# already set FLASH_MLA_ROOT_DIR.
# NOTE(review): presumably consumed by the kernel loader to locate files under
# asm/ -- confirm. This assignment runs after flash_mla.flash_mla_interface has
# been imported earlier in this module, so anything that reads the variable
# while that module is being imported will not see this value -- confirm the
# ordering is intended.
os.environ["FLASH_MLA_ROOT_DIR"] = os.path.join(FLASH_MLA_ROOT_DIR, "asm", "")