import os
import torch
import diffusers
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
from packaging import version

from xfuser.logger import init_logger

logger = init_logger(__name__)

if TYPE_CHECKING:
    MASTER_ADDR: str = ""
    MASTER_PORT: Optional[int] = None
    CUDA_HOME: Optional[str] = None
    LOCAL_RANK: int = 0
    CUDA_VISIBLE_DEVICES: Optional[str] = None
    XDIT_LOGGING_LEVEL: str = "INFO"
    CUDA_VERSION: version.Version
    TORCH_VERSION: version.Version


environment_variables: Dict[str, Callable[[], Any]] = {
    # ================== Runtime Env Vars ==================
    # used in distributed environment to determine the master address
    "MASTER_ADDR": lambda: os.getenv("MASTER_ADDR", ""),
    # used in distributed environment to manually set the communication port
    "MASTER_PORT": lambda: (
        int(os.getenv("MASTER_PORT", "0")) if "MASTER_PORT" in os.environ else None
    ),
    # path to cudatoolkit home directory, under which should be bin, include,
    # and lib directories.
    "CUDA_HOME": lambda: os.environ.get("CUDA_HOME", None),
    # local rank of the process in the distributed setting, used to determine
    # the GPU device id
    "LOCAL_RANK": lambda: int(os.environ.get("LOCAL_RANK", "0")),
    # used to control the visible devices in the distributed setting
    "CUDA_VISIBLE_DEVICES": lambda: os.environ.get("CUDA_VISIBLE_DEVICES", None),
    # this is used for configuring the default logging level
    "XDIT_LOGGING_LEVEL": lambda: os.getenv("XDIT_LOGGING_LEVEL", "INFO"),
}
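
# Illustrative sketch (not part of the module's behaviour): every entry above
# is a zero-argument callable, so values are re-read from os.environ on each
# access. For example, the MASTER_PORT getter yields None when the variable is
# unset and an int otherwise:
#
#     os.environ.pop("MASTER_PORT", None)
#     assert environment_variables["MASTER_PORT"]() is None
#     os.environ["MASTER_PORT"] = "29500"
#     assert environment_variables["MASTER_PORT"]() == 29500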

variables: Dict[str, Callable[[], Any]] = {
    # ================== Other Vars ==================
    # used in version checking
    # "CUDA_VERSION": lambda: version.parse(torch.version.cuda),
    "CUDA_VERSION": "gfx928",
    "TORCH_VERSION": lambda: version.parse(
        version.parse(torch.__version__).base_version
    ),
}


class PackagesEnvChecker:
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(PackagesEnvChecker, cls).__new__(cls)
            cls._instance.initialize()
        return cls._instance

    def initialize(self):
        self.packages_info = {
            "has_flash_attn": self.check_flash_attn(),
            "has_long_ctx_attn": self.check_long_ctx_attn(),
            "diffusers_version": self.check_diffusers_version(),
        }

    def check_flash_attn(self):
        try:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            gpu_name = torch.cuda.get_device_name(device)
            if "Turing" in gpu_name or "Tesla" in gpu_name or "T4" in gpu_name:
                return False
            else:
                from flash_attn import flash_attn_func
                from flash_attn import __version__

                if __version__ < "2.6.0":
                    raise ImportError(f"install flash_attn >= 2.6.0")
                return True
        except ImportError as e:
            logger.warning(
                f'Flash Attention library "flash_attn" not usable ({e}), '
                "using pytorch attention implementation"
            )
            return False

    def check_long_ctx_attn(self):
        try:
            from yunchang import (
                set_seq_parallel_pg,
                ring_flash_attn_func,
                UlyssesAttention,
                LongContextAttention,
                LongContextAttentionQKVPacked,
            )

            return True
        except ImportError:
            logger.warning(
                'Ring Flash Attention library "yunchang" not found, '
                "using pytorch attention implementation"
            )
            return False

    def check_diffusers_version(self):
        diffusers_version = version.parse(
            version.parse(diffusers.__version__).base_version
        )
        if diffusers_version < version.parse("0.30.0"):
            raise RuntimeError(
                f"Diffusers version {diffusers_version} is not supported, "
                "please upgrade to version >= 0.30.0"
            )
        return diffusers_version

    def get_packages_info(self):
        return self.packages_info


PACKAGES_CHECKER = PackagesEnvChecker()
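
# Usage sketch (an assumption about how callers consume this, not a documented
# contract): PackagesEnvChecker is a singleton, so constructing it again
# returns the same instance, and the detection in initialize() runs only once
# at import time:
#
#     assert PackagesEnvChecker() is PACKAGES_CHECKER
#     info = PACKAGES_CHECKER.get_packages_info()
#     if info["has_flash_attn"]:
#         ...  # prefer the flash-attention code path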


def __getattr__(name):
    # lazy evaluation of environment variables
    if name in environment_variables:
        return environment_variables[name]()
    if name in variables:
        return variables[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


def __dir__():
    return list(environment_variables.keys()) + list(variables.keys())
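

if __name__ == "__main__":
    # Minimal, illustrative smoke test (an assumption, not part of the xfuser
    # API): exercise the PEP 562 module-level __getattr__ above by reading a
    # few lazily evaluated values from this module.
    import sys

    _envs = sys.modules[__name__]
    print("LOCAL_RANK =", _envs.LOCAL_RANK)
    print("XDIT_LOGGING_LEVEL =", _envs.XDIT_LOGGING_LEVEL)
    print("TORCH_VERSION =", _envs.TORCH_VERSION)
    print("packages info:", PACKAGES_CHECKER.get_packages_info())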