common_utils.py 1.68 KB
Newer Older
Xiaowei.zhang's avatar
Xiaowei.zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# SPDX-License-Identifier: MIT
 
import os
import json
import functools
from typing import List

import torch
import triton
from triton.runtime.cache import default_cache_dir


def prev_power_of_2(x: int) -> int:
    """Round ``x`` down to a power of two.

    Asks ``triton.next_power_of_2`` for the power of two at or above ``x``;
    when that overshoots ``x``, the next lower power of two is returned
    instead. Intended for positive ``x`` (e.g. ``5 -> 4``, ``8 -> 8``).
    """
    ceiling = triton.next_power_of_2(x)
    if ceiling <= x:
        # No overshoot: the ceiling already equals x, nothing to round.
        return ceiling
    return ceiling // 2


# Sequence-length buckets consulted by ``autotune_max_seq_len`` when runtime
# bucketing is off. Empty by default; keep entries in ascending order.
STATIC_MAX_SEQ_LENS: List[int] = list()
# When True, ``autotune_max_seq_len`` derives the bucket from the runtime
# length (via ``prev_power_of_2``) instead of ``STATIC_MAX_SEQ_LENS``.
USE_RUNTIME_MAX_SEQ_LEN: bool = False


def autotune_max_seq_len(runtime_max_seq_len: int) -> int:
    """Map a runtime max sequence length to a coarse autotuning bucket.

    Args:
        runtime_max_seq_len: The max sequence length observed at runtime.

    Returns:
        * If ``USE_RUNTIME_MAX_SEQ_LEN`` is set: ``prev_power_of_2`` of the
          runtime length (i.e. rounded *down* to a power of two).
        * Otherwise, if ``STATIC_MAX_SEQ_LENS`` is empty: ``1``.
        * Otherwise: the smallest bucket in ``STATIC_MAX_SEQ_LENS`` that is
          ``>= runtime_max_seq_len``, or the largest bucket when none fits.
    """
    # Both module-level settings are only read here, so no ``global`` needed.
    if USE_RUNTIME_MAX_SEQ_LEN:
        return prev_power_of_2(runtime_max_seq_len)
    if not STATIC_MAX_SEQ_LENS:
        return 1
    # Do not rely on the bucket list being sorted: choose the tightest bucket
    # that fits, falling back to the largest. Same result as a first-match scan
    # when the list is ascending.
    fitting = [n for n in STATIC_MAX_SEQ_LENS if n >= runtime_max_seq_len]
    return min(fitting) if fitting else max(STATIC_MAX_SEQ_LENS)


def switch_to_contiguous_if_needed(x: torch.Tensor) -> torch.Tensor:
    """Return ``x`` with unit stride along its last dimension.

    Only the last dimension is checked: a tensor whose ``stride(-1)`` is
    already 1 is returned as-is (same object), even if it is not fully
    contiguous. Anything else is copied with ``.contiguous()``.
    """
    last_dim_dense = x.stride(-1) == 1
    return x if last_dim_dense else x.contiguous()


def get_triton_cache_dir():
    """Return the Triton cache directory.

    A non-blank ``TRITON_CACHE_DIR`` environment variable (surrounding
    whitespace stripped) takes precedence; otherwise Triton's
    ``default_cache_dir()`` is used.
    """
    override = os.getenv("TRITON_CACHE_DIR", "").strip()
    if override:
        return override
    return default_cache_dir()


# In-process memo of saved-kernel index files: maps each index file path to
# the dict last loaded from / written to that file.
file_cache = {}


def save_kernel_path(filename: str, config: dict, kernel_path: str):
    """Record ``kernel_path`` for ``config`` in a JSON index file.

    The index lives at ``<triton cache dir>/saved_kernel/<filename>`` and maps
    ``str(config)`` to a kernel path. An entry that already exists for the
    same config is never overwritten, and the file is rewritten only when a new
    entry is added.

    Args:
        filename: Name of the JSON index file inside ``saved_kernel``.
        config: Kernel config; its ``str()`` form is the index key.
        kernel_path: Path to record for ``config``.
    """
    key = str(config)
    path_cache_dir = f"{get_triton_cache_dir()}/saved_kernel"
    os.makedirs(path_cache_dir, exist_ok=True)
    file_path = f"{path_cache_dir}/{filename}"

    data = file_cache.get(file_path)
    if data is None:
        # Seed from the file already on disk (if any) so entries saved by
        # earlier processes survive; otherwise the first write in a new process
        # would replace the whole index with a single entry.
        data = {}
        try:
            with open(file_path) as f:
                loaded = json.load(f)
        except (OSError, ValueError):
            # Missing, unreadable, or malformed file: start a fresh index.
            loaded = None
        if isinstance(loaded, dict):
            data = loaded
        file_cache[file_path] = data

    if key in data:
        return
    data[key] = kernel_path
    # Write to a sibling temp file and swap it in with os.replace, so readers
    # never observe a truncated JSON document.
    tmp_path = f"{file_path}.{os.getpid()}.tmp"
    with open(tmp_path, "w") as f:
        json.dump(data, f, indent=4)
    os.replace(tmp_path, file_path)


@functools.lru_cache
def has_kernel_cache(path):
    """Return True if ``path`` names an existing directory in the Triton cache.

    A falsy ``path`` yields False without touching the filesystem. Results are
    memoized per ``path`` by ``functools.lru_cache``, so a False answer is
    remembered even if the directory appears later in this process.
    """
    if not path:
        return False
    return os.path.isdir(f'{get_triton_cache_dir()}/{path}')