envs.py 1.01 KB
Newer Older
1
import os
helloyongyang's avatar
helloyongyang committed
2
from functools import lru_cache
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch

# Lookup table from accepted dtype spellings to torch dtypes. Each supported
# dtype can be named by its upper-case short name, its lower-case short name,
# or its fully-qualified ``torch.<name>`` string.
DTYPE_MAP = dict(
    [
        # Upper-case short names.
        ("BF16", torch.bfloat16),
        ("FP16", torch.float16),
        ("FP32", torch.float32),
        # Lower-case short names.
        ("bf16", torch.bfloat16),
        ("fp16", torch.float16),
        ("fp32", torch.float32),
        # Fully-qualified torch names.
        ("torch.bfloat16", torch.bfloat16),
        ("torch.float16", torch.float16),
        ("torch.float32", torch.float32),
    ]
)

18

helloyongyang's avatar
helloyongyang committed
19
@lru_cache(maxsize=None)
def CHECK_PROFILING_DEBUG_LEVEL(target_level):
    """Return whether the configured profiling debug level reaches ``target_level``.

    The configured level is read from the ``PROFILING_DEBUG_LEVEL`` environment
    variable. An unset or empty variable counts as level ``0``. Results are
    memoized per ``target_level``, so the environment is consulted only the
    first time each level is checked.

    Args:
        target_level: Minimum level required for the check to pass.

    Returns:
        ``True`` if the configured level is greater than or equal to
        ``target_level``, otherwise ``False``.

    Raises:
        ValueError: If ``PROFILING_DEBUG_LEVEL`` is set to a non-integer value.
    """
    raw_level = os.getenv("PROFILING_DEBUG_LEVEL", "0").strip()
    # An exported-but-empty variable (``PROFILING_DEBUG_LEVEL=``) means "not
    # configured", the same as unset, instead of an opaque int('') failure.
    if not raw_level:
        raw_level = "0"
    try:
        current_level = int(raw_level)
    except ValueError as err:
        raise ValueError(
            f"PROFILING_DEBUG_LEVEL must be an integer, got {raw_level!r}"
        ) from err
    return current_level >= target_level
23

helloyongyang's avatar
helloyongyang committed
24

25
26
27
28
@lru_cache(maxsize=None)
def GET_RUNNING_FLAG():
    """Return the ``RUNNING_FLAG`` environment variable, or ``"infer"`` if unset.

    The value is read on the first call and memoized afterwards.
    """
    return os.environ.get("RUNNING_FLAG", "infer")
gushiqiao's avatar
gushiqiao committed
29
30
31
32


@lru_cache(maxsize=None)
def GET_DTYPE():
    """Return the torch dtype selected by the ``DTYPE`` environment variable.

    Only ``"BF16"`` (the default when unset) and ``"FP16"`` are accepted. A
    successful result is memoized for the life of the process.

    Returns:
        The matching ``torch.dtype`` from ``DTYPE_MAP``.

    Raises:
        ValueError: If ``DTYPE`` is set to any other value.
    """
    dtype_name = os.getenv("DTYPE", "BF16")
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let e.g. "FP32" through silently.
    if dtype_name not in ("BF16", "FP16"):
        raise ValueError(f"DTYPE must be 'BF16' or 'FP16', got {dtype_name!r}")
    return DTYPE_MAP[dtype_name]


@lru_cache(maxsize=None)
def GET_SENSITIVE_DTYPE():
    """Return the dtype configured for sensitive layers.

    Reads ``SENSITIVE_LAYER_DTYPE``. When it is unset, or set to the literal
    string ``"None"``, the result of ``GET_DTYPE()`` is returned instead. Any
    other value is looked up in ``DTYPE_MAP``, so an unknown name raises
    ``KeyError``. The result is memoized after the first successful call.
    """
    sensitive_name = os.getenv("SENSITIVE_LAYER_DTYPE", "None")
    if sensitive_name != "None":
        return DTYPE_MAP[sensitive_name]
    # No explicit override: fall back to the global dtype.
    return GET_DTYPE()