# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import platform
16
import subprocess
17
18
19
20
21
from argparse import ArgumentParser

import huggingface_hub

from .. import __version__ as version
22
23
24
25
26
27
28
29
30
31
32
33
from ..utils import (
    is_accelerate_available,
    is_bitsandbytes_available,
    is_flax_available,
    is_google_colab,
    is_notebook,
    is_peft_available,
    is_safetensors_available,
    is_torch_available,
    is_transformers_available,
    is_xformers_available,
)
34
from ..utils.testing_utils import get_python_version
35
36
37
38
39
40
41
42
43
from . import BaseDiffusersCLICommand


def info_command_factory(_):
    """Create the handler for the ``env`` subcommand.

    The single positional argument (the parsed CLI arguments) is accepted for interface
    compatibility with ``set_defaults(func=...)`` and is not used.
    """
    command = EnvironmentCommand()
    return command


class EnvironmentCommand(BaseDiffusersCLICommand):
    """Implements ``diffusers-cli env``: report library versions, platform and accelerator details.

    The printed report is meant to be pasted into a GitHub issue; its last two entries are
    ``"<fill in>"`` placeholders that the user completes by hand.
    """

    @staticmethod
    def register_subcommand(parser: ArgumentParser) -> None:
        """Register the ``env`` subcommand and route it to ``info_command_factory``.

        Args:
            parser: object exposing ``add_parser`` — presumably the action returned by
                ``ArgumentParser.add_subparsers()`` rather than a plain ``ArgumentParser``; confirm at the caller.
        """
        env_parser = parser.add_parser("env")
        env_parser.set_defaults(func=info_command_factory)

    def run(self) -> dict:
        """Collect the environment report, print it, and return it.

        Returns:
            dict mapping report labels to their values, in display order.
        """
        hub_version = huggingface_hub.__version__

        # Optional dependencies are imported lazily, only after their availability check, so the
        # command works with any subset of them installed.
        safetensors_version = "not installed"
        if is_safetensors_available():
            import safetensors

            safetensors_version = safetensors.__version__

        pt_version = "not installed"
        pt_cuda_available = "NA"
        if is_torch_available():
            import torch

            pt_version = torch.__version__
            pt_cuda_available = torch.cuda.is_available()

        flax_version = "not installed"
        jax_version = "not installed"
        jaxlib_version = "not installed"
        jax_backend = "NA"
        if is_flax_available():
            import flax
            import jax
            import jaxlib

            flax_version = flax.__version__
            jax_version = jax.__version__
            jaxlib_version = jaxlib.__version__
            jax_backend = jax.lib.xla_bridge.get_backend().platform

        transformers_version = "not installed"
        if is_transformers_available():
            import transformers

            transformers_version = transformers.__version__

        accelerate_version = "not installed"
        if is_accelerate_available():
            import accelerate

            accelerate_version = accelerate.__version__

        peft_version = "not installed"
        if is_peft_available():
            import peft

            peft_version = peft.__version__

        bitsandbytes_version = "not installed"
        if is_bitsandbytes_available():
            import bitsandbytes

            bitsandbytes_version = bitsandbytes.__version__

        xformers_version = "not installed"
        if is_xformers_available():
            import xformers

            xformers_version = xformers.__version__

        platform_info = self._get_platform_info()

        is_notebook_str = "Yes" if is_notebook() else "No"

        is_google_colab_str = "Yes" if is_google_colab() else "No"

        accelerator = self._get_accelerator()

        info = {
            "🤗 Diffusers version": version,
            "Platform": platform_info,
            "Running on a notebook?": is_notebook_str,
            "Running on Google Colab?": is_google_colab_str,
            "Python version": platform.python_version(),
            "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
            "Flax version (CPU?/GPU?/TPU?)": f"{flax_version} ({jax_backend})",
            "Jax version": jax_version,
            "JaxLib version": jaxlib_version,
            "Huggingface_hub version": hub_version,
            "Transformers version": transformers_version,
            "Accelerate version": accelerate_version,
            "PEFT version": peft_version,
            "Bitsandbytes version": bitsandbytes_version,
            "Safetensors version": safetensors_version,
            "xFormers version": xformers_version,
            "Accelerator": accelerator,
            "Using GPU in script?": "<fill in>",
            "Using distributed or parallel set-up in script?": "<fill in>",
        }

        print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
        print(self.format_dict(info))

        return info

    @staticmethod
    def _get_platform_info() -> str:
        """Return ``platform.platform()``, prefixed with the os-release ``PRETTY_NAME`` when available."""
        platform_info = platform.platform()
        if get_python_version() >= (3, 10):
            # `freedesktop_os_release` raises OSError when no os-release file can be read (e.g. on
            # macOS and Windows); that must not abort the whole report.
            try:
                os_release = platform.freedesktop_os_release()
            except OSError:
                pass
            else:
                platform_info = f"{os_release.get('PRETTY_NAME', None)} - {platform_info}"
        return platform_info

    @staticmethod
    def _run_command(command: list) -> str:
        """Run ``command`` and return its stdout decoded as UTF-8 (undecodable bytes replaced).

        Best-effort: returns ``""`` when the executable cannot be started (any ``OSError``, including
        ``FileNotFoundError``) or exits with a non-zero status, since a failing tool's stdout may be an
        error message rather than device information.
        """
        try:
            completed = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        except OSError:
            return ""
        if completed.returncode != 0:
            return ""
        return completed.stdout.decode("utf-8", errors="replace")

    @staticmethod
    def _value_after(text: str, label: str):
        """Return the stripped text following the first ``label`` up to end of line, or ``None`` if absent."""
        start = text.find(label)
        if start == -1:
            return None
        start += len(label)
        end = text.find("\n", start)
        if end == -1:
            # Label is on the last line: take everything to the end instead of slicing to -1.
            end = len(text)
        return text[start:end].strip()

    @classmethod
    def _describe_nvidia_gpus(cls) -> str:
        """Describe GPUs reported by ``nvidia-smi`` as ``"<name>, <memory> VRAM"`` entries, or ``"NA"``."""
        out_str = cls._run_command(
            ["nvidia-smi", "--query-gpu=gpu_name,memory.total", "--format=csv,noheader"]
        )
        gpus = [line.strip() for line in out_str.splitlines() if line.strip()]
        if not gpus:
            return "NA"
        # One CSV row per GPU; keep them on one line so the list produced by `format_dict` stays intact.
        return "; ".join(f"{gpu} VRAM" for gpu in gpus)

    @classmethod
    def _describe_macos_gpu(cls) -> str:
        """Describe the first chipset listed by ``system_profiler SPDisplaysDataType``, or ``"NA"``."""
        out_str = cls._run_command(["system_profiler", "SPDisplaysDataType"])
        chipset = cls._value_after(out_str, "Chipset Model:")
        if chipset is None:
            return "NA"
        accelerator = chipset
        vram = cls._value_after(out_str, "VRAM (Total):")
        if vram is not None:
            accelerator += " VRAM: " + vram
        return accelerator

    @classmethod
    def _get_accelerator(cls) -> str:
        """Best-effort accelerator description for the current OS; ``"NA"`` when nothing is detected."""
        system = platform.system()
        if system in {"Linux", "Windows"}:
            return cls._describe_nvidia_gpus()
        if system == "Darwin":  # Mac OS
            return cls._describe_macos_gpu()
        print("It seems you are running an unusual OS. Could you fill in the accelerator manually?")
        return "NA"

    @staticmethod
    def format_dict(d: dict) -> str:
        """Render ``d`` as ``- key: value`` lines joined by newlines, with a trailing newline."""
        return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"