envs.py
import os
from functools import lru_cache

import torch

DTYPE_MAP = {
    "BF16": torch.bfloat16,
    "FP16": torch.float16,
    "FP32": torch.float32,
    "bf16": torch.bfloat16,
    "fp16": torch.float16,
    "fp32": torch.float32,
    "torch.bfloat16": torch.bfloat16,
    "torch.float16": torch.float16,
    "torch.float32": torch.float32,
}


@lru_cache(maxsize=None)
def CHECK_PROFILING_DEBUG_LEVEL(target_level):
    """Return True if PROFILING_DEBUG_LEVEL (default 0) is >= target_level."""
    current_level = int(os.getenv("PROFILING_DEBUG_LEVEL", "0"))
    return current_level >= target_level
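

# A minimal usage sketch (hypothetical call site, not part of the original
# module): gate verbose profiling output behind a minimum debug level. The
# helper name and the level threshold of 2 are illustrative assumptions.
def _example_timed_call(fn):
    import time

    start = time.perf_counter()
    result = fn()
    # Only emit timing output when PROFILING_DEBUG_LEVEL is 2 or higher.
    if CHECK_PROFILING_DEBUG_LEVEL(2):
        print(f"[profile] call took {time.perf_counter() - start:.3f}s")
    return result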


@lru_cache(maxsize=None)
def GET_RUNNING_FLAG():
    """Return the RUNNING_FLAG env var (default: "infer")."""
    return os.getenv("RUNNING_FLAG", "infer")


@lru_cache(maxsize=None)
def GET_DTYPE():
    """Return the global compute dtype selected by the DTYPE env var."""
    dtype = os.getenv("DTYPE", "BF16")
    assert dtype in ("BF16", "FP16"), f"Unsupported DTYPE: {dtype}"
    return DTYPE_MAP[dtype]


@lru_cache(maxsize=None)
def GET_SENSITIVE_DTYPE():
    """Return the dtype for precision-sensitive layers.

    Falls back to GET_DTYPE() when SENSITIVE_LAYER_DTYPE is unset.
    """
    dtype = os.getenv("SENSITIVE_LAYER_DTYPE", "None")
    if dtype == "None":
        return GET_DTYPE()
    return DTYPE_MAP[dtype]
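

# A minimal usage sketch (hypothetical call site, not part of the original
# module): run the bulk of a model in the global dtype while keeping
# numerically sensitive layers at SENSITIVE_LAYER_DTYPE. Treating LayerNorm
# as the "sensitive" layer type is an illustrative assumption.
def _example_apply_dtypes(model: torch.nn.Module) -> torch.nn.Module:
    model.to(GET_DTYPE())
    for module in model.modules():
        if isinstance(module, torch.nn.LayerNorm):
            module.to(GET_SENSITIVE_DTYPE())
    return model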


@lru_cache(maxsize=None)
def GET_RECORDER_MODE():
    """Return the RECORDER_MODE env var as an int (default: 0)."""
    return int(os.getenv("RECORDER_MODE", "0"))


@lru_cache(maxsize=None)
def GET_USE_CHANNELS_LAST_3D():
    """
    Check if channels_last_3d memory format optimization is enabled.

    When enabled (USE_CHANNELS_LAST_3D=1), Conv3d weights are converted to
    the channels_last_3d memory format, eliminating NCDHW <-> NDHWC layout
    conversion overhead in cuDNN (~9% faster).

    Returns:
        bool: True if optimization is enabled, False otherwise (default).
    """
    return os.getenv("USE_CHANNELS_LAST_3D", "0") == "1"
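

# A minimal usage sketch (hypothetical helper, not part of the original
# module): convert Conv3d weights up front so cuDNN receives NDHWC-laid-out
# tensors directly. The helper name and conversion strategy are illustrative
# assumptions; Module.to(memory_format=...) rewrites the 5-D weights in place.
def _example_convert_conv3d(model: torch.nn.Module) -> torch.nn.Module:
    if GET_USE_CHANNELS_LAST_3D():
        for module in model.modules():
            if isinstance(module, torch.nn.Conv3d):
                module.to(memory_format=torch.channels_last_3d)
    return model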