# envs.py: runtime settings read from environment variables.
import os
from functools import lru_cache
import torch

# Lookup table from user-facing dtype spellings to ``torch.dtype`` objects.
# Each precision is reachable through three spellings: upper-case short
# names ("BF16"), lower-case short names ("bf16") and the ``torch.<name>``
# form ("torch.bfloat16"). Keys are inserted group by group, in that order.
DTYPE_MAP = {
    alias: dtype
    for aliases in (
        ("BF16", "FP16", "FP32"),
        ("bf16", "fp16", "fp32"),
        ("torch.bfloat16", "torch.float16", "torch.float32"),
    )
    for alias, dtype in zip(aliases, (torch.bfloat16, torch.float16, torch.float32))
}

@lru_cache(maxsize=None)
def CHECK_ENABLE_PROFILING_DEBUG():
    """Return whether ``ENABLE_PROFILING_DEBUG`` is set to ``"true"``.

    The comparison is case-insensitive ("TRUE", "True" also count); any other
    value, including "1", yields ``False``, and an unset variable is treated
    as ``"false"``. The result is cached after the first call, so later
    changes to the environment variable are not observed.
    """
    flag_text = os.getenv("ENABLE_PROFILING_DEBUG", "false")
    enabled = flag_text.lower() == "true"
    return enabled

@lru_cache(maxsize=None)
def CHECK_ENABLE_GRAPH_MODE():
    """Report whether ``ENABLE_GRAPH_MODE`` is set to ``"true"``.

    Only the string "true" (in any letter case) enables the flag; an unset
    variable is treated as "false". The answer is memoised on the first
    call, so the environment is read only once per process.
    """
    return os.getenv("ENABLE_GRAPH_MODE", "false").lower() == "true"


@lru_cache(maxsize=None)
def GET_RUNNING_FLAG():
    """Return the raw ``RUNNING_FLAG`` environment value.

    Falls back to ``"infer"`` when the variable is unset. The string is
    returned as-is (no validation or case folding) and is cached after the
    first call.
    """
    return os.getenv("RUNNING_FLAG", "infer")


@lru_cache(maxsize=None)
def GET_DTYPE():
    """Return the global ``torch.dtype`` selected by the ``DTYPE`` env variable.

    ``DTYPE`` defaults to ``"BF16"``. Only the exact spellings ``"BF16"`` and
    ``"FP16"`` are accepted; the value is then resolved through ``DTYPE_MAP``.
    The result is cached after the first successful call, so later changes
    to ``DTYPE`` in the environment are not observed.

    Raises:
        ValueError: if ``DTYPE`` is anything other than "BF16" or "FP16".
    """
    dtype_name = os.getenv("DTYPE", "BF16")
    # Validate explicitly rather than with ``assert``: asserts are stripped
    # under ``python -O``, which would let e.g. "FP32" through unchecked.
    if dtype_name not in ("BF16", "FP16"):
        raise ValueError(f"DTYPE must be 'BF16' or 'FP16', got {dtype_name!r}")
    return DTYPE_MAP[dtype_name]


@lru_cache(maxsize=None)
def GET_SENSITIVE_DTYPE():
    """Return the ``torch.dtype`` named by ``SENSITIVE_LAYER_DTYPE``.

    When the variable is unset, or set to the literal string ``"None"``, this
    falls back to ``GET_DTYPE()``. Otherwise the value is looked up in
    ``DTYPE_MAP`` (so any spelling listed there is accepted), and an unknown
    spelling raises ``KeyError``. The result is cached after the first call.
    """
    sensitive_name = os.getenv("SENSITIVE_LAYER_DTYPE", "None")
    return GET_DTYPE() if sensitive_name == "None" else DTYPE_MAP[sensitive_name]