misc.py 11.1 KB
Newer Older
chenych's avatar
chenych committed
1
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
chenych's avatar
chenych committed
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
#
# This code is inspired by the HuggingFace's PEFT library.
# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/peft_model.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
18
19
import gc
import os
chenych's avatar
chenych committed
20
21
import socket
from typing import TYPE_CHECKING, Any, Literal, Union
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
22
23

import torch
luopl's avatar
luopl committed
24
import torch.distributed as dist
chenych's avatar
chenych committed
25
import transformers.dynamic_module_utils
chenych's avatar
chenych committed
26
from huggingface_hub.utils import WeakFileLock
chenych's avatar
chenych committed
27
28
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
from transformers.dynamic_module_utils import get_relative_imports
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
29
30
31
32
33
34
35
36
37
from transformers.utils import (
    is_torch_bf16_gpu_available,
    is_torch_cuda_available,
    is_torch_mps_available,
    is_torch_npu_available,
    is_torch_xpu_available,
)
from transformers.utils.versions import require_version

luopl's avatar
luopl committed
38
from . import logging
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
39
40
41
42


# fp16 is treated as usable whenever an Ascend NPU or a CUDA device is present.
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()

# bf16 probing is best-effort: any exception raised while querying the backends
# (e.g. by the NPU query) is treated as "bf16 not available".
try:
    if is_torch_bf16_gpu_available():
        _is_bf16_available = True
    else:
        _is_bf16_available = is_torch_npu_available() and torch.npu.is_bf16_supported()
except Exception:
    _is_bf16_available = False


if TYPE_CHECKING:
chenych's avatar
chenych committed
49
    from numpy.typing import NDArray
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
50

chenych's avatar
chenych committed
51
    from ..hparams import ModelArguments
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
52
53


luopl's avatar
luopl committed
54
# Module logger from the package-local ``logging`` helper; ``check_version`` relies on its
# ``warning_rank0_once`` method.
logger = logging.get_logger(__name__)
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
55
56
57


class AverageMeter:
    r"""Track the most recent value together with a running weighted average.

    Attributes:
        val: the last value passed to ``update``.
        sum: the weighted sum of all values (each value multiplied by its count ``n``).
        count: the total weight (sum of all ``n``).
        avg: ``sum / count``, refreshed on every ``update``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        r"""Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        r"""Record ``val`` as observed ``n`` times and recompute the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


luopl's avatar
luopl committed
76
def check_version(requirement: str, mandatory: bool = False) -> None:
    r"""Optionally check the package version.

    Args:
        requirement: a pip-style requirement specifier, e.g. ``"peft>=0.14.0,<=0.15.2"``.
        mandatory: when True the check runs even if ``DISABLE_VERSION_CHECK`` is enabled, and the
            error hint does not offer that environment variable as a workaround.

    Raises:
        Whatever transformers' ``require_version`` raises when the installed package does not
        satisfy ``requirement``; the raised message includes an install hint.
    """
    if is_env_enabled("DISABLE_VERSION_CHECK") and not mandatory:
        logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
        return

    # Quote the specifier: unquoted, a shell would interpret ``>`` / ``<`` in it as redirections.
    pip_command = f'pip install "{requirement}"'
    # These packages are installed without build isolation (presumably because they compile
    # against the already-installed torch). The package name is ``gptqmodel``.
    if "gptqmodel" in requirement or "autoawq" in requirement:
        pip_command += " --no-build-isolation"

    if mandatory:
        hint = f"To fix: run `{pip_command}`."
    else:
        hint = f"To fix: run `{pip_command}` or set `DISABLE_VERSION_CHECK=1` to skip this check."

    require_version(requirement, hint)


Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
95
def check_dependencies() -> None:
    r"""Verify that the core third-party packages are within their supported version ranges.

    Each requirement goes through ``check_version`` (non-mandatory), in the order listed.
    """
    for requirement in (
        "transformers>=4.49.0,<=4.52.4,!=4.52.0",
        "datasets>=2.16.0,<=3.6.0",
        "accelerate>=1.3.0,<=1.7.0",
        "peft>=0.14.0,<=0.15.2",
        "trl>=0.8.6,<=0.9.6",
    ):
        check_version(requirement)


chenych's avatar
chenych committed
104
105
def calculate_tps(dataset: list[dict[str, Any]], metrics: dict[str, float], stage: Literal["sft", "rm"]) -> float:
    r"""Calculate effective tokens per second.

    Args:
        dataset: tokenized examples; ``"sft"`` examples carry ``input_ids``, ``"rm"`` examples carry
            ``chosen_input_ids`` and ``rejected_input_ids``.
        metrics: training metrics providing ``epoch`` and ``train_runtime``.
        stage: which key layout to count; any other value contributes zero tokens.

    Returns:
        ``tokens * epoch / train_runtime``, divided by the world size when torch.distributed is initialized.
    """

    def _tokens_in(example: dict[str, Any]) -> int:
        if stage == "sft":
            return len(example["input_ids"])
        if stage == "rm":
            return len(example["chosen_input_ids"]) + len(example["rejected_input_ids"])
        return 0

    total_tokens = sum(_tokens_in(example) for example in dataset)
    tokens_per_second = total_tokens * metrics["epoch"] / metrics["train_runtime"]
    if dist.is_initialized():
        tokens_per_second = tokens_per_second / dist.get_world_size()

    return tokens_per_second
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
115
116


chenych's avatar
chenych committed
117
118
def count_parameters(model: "torch.nn.Module") -> tuple[int, int]:
    r"""Return the number of trainable parameters and number of all parameters in the model.

    Two special cases are handled per parameter:
    - If ``numel()`` is 0 and the tensor has ``ds_numel`` (DeepSpeed ZeRO-3 with empty-initialized
      weights), ``ds_numel`` is used as the size.
    - bitsandbytes ``Params4bit`` tensors have their stored element count multiplied by
      ``2 * bytes_per_element`` (presumably two 4-bit values per byte of storage).

    Returns:
        ``(trainable_params, all_params)``.
    """

    def _bytes_per_element(param) -> int:
        if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"):
            return param.quant_storage.itemsize
        if hasattr(param, "element_size"):  # for older PyTorch versions
            return param.element_size()
        return 1

    def _logical_numel(param) -> int:
        numel = param.numel()
        if numel == 0 and hasattr(param, "ds_numel"):
            numel = param.ds_numel
        if param.__class__.__name__ == "Params4bit":
            numel = numel * 2 * _bytes_per_element(param)
        return numel

    trainable_params = all_param = 0
    for param in model.parameters():
        numel = _logical_numel(param)
        all_param += numel
        if param.requires_grad:
            trainable_params += numel

    return trainable_params, all_param


chenych's avatar
chenych committed
144
def get_current_device() -> "torch.device":
    r"""Get the current available device.

    Backends are probed in the order XPU, NPU, MPS, CUDA. The first one available yields
    ``torch.device("<backend>:<LOCAL_RANK>")``, with ``LOCAL_RANK`` defaulting to ``"0"``.
    Without any accelerator the CPU device is returned.
    """
    if is_torch_xpu_available():
        backend = "xpu"
    elif is_torch_npu_available():
        backend = "npu"
    elif is_torch_mps_available():
        backend = "mps"
    elif is_torch_cuda_available():
        backend = "cuda"
    else:
        return torch.device("cpu")

    local_rank = os.getenv("LOCAL_RANK", "0")
    return torch.device(f"{backend}:{local_rank}")


def get_device_count() -> int:
    r"""Get the number of available devices.

    Queries the first available backend in the order XPU, NPU, MPS, CUDA; returns 0 when none is available.
    """
    if is_torch_xpu_available():
        backend = torch.xpu
    elif is_torch_npu_available():
        backend = torch.npu
    elif is_torch_mps_available():
        backend = torch.mps
    elif is_torch_cuda_available():
        backend = torch.cuda
    else:
        return 0

    return backend.device_count()


def get_logits_processor() -> "LogitsProcessorList":
    r"""Build a logits processor list holding a single ``InfNanRemoveLogitsProcessor``.

    The processor removes NaN and Inf values from the logits during generation.
    """
    return LogitsProcessorList([InfNanRemoveLogitsProcessor()])


chenych's avatar
chenych committed
181
182
183
184
185
186
187
188
189
190
191
192
193
194
def get_current_memory() -> tuple[int, int]:
    r"""Get the available and total memory for the current device (in Bytes).

    Backends are probed in the order XPU, NPU, MPS, CUDA, and the first available one is queried.

    NOTE(review): on MPS the first value is ``current_allocated_memory()`` (memory in use), not
    free memory, and the second is ``recommended_max_memory()``.

    Returns ``(0, -1)`` when no accelerator is available.
    """
    probes = (
        (is_torch_xpu_available, lambda: torch.xpu.mem_get_info()),
        (is_torch_npu_available, lambda: torch.npu.mem_get_info()),
        (
            is_torch_mps_available,
            lambda: (torch.mps.current_allocated_memory(), torch.mps.recommended_max_memory()),
        ),
        (is_torch_cuda_available, lambda: torch.cuda.mem_get_info()),
    )
    for is_available, query in probes:
        if is_available():
            return query()

    return 0, -1


chenych's avatar
chenych committed
195
def get_peak_memory() -> tuple[int, int]:
    r"""Get the peak memory usage (allocated, reserved) for the current device (in Bytes).

    Backends are probed in the order XPU, NPU, MPS, CUDA. On MPS only the current allocation is
    reported (not a peak), and "reserved" is ``-1``. Returns ``(0, -1)`` without an accelerator.
    """
    if is_torch_xpu_available():
        backend = torch.xpu
    elif is_torch_npu_available():
        backend = torch.npu
    elif is_torch_mps_available():
        return torch.mps.current_allocated_memory(), -1
    elif is_torch_cuda_available():
        backend = torch.cuda
    else:
        return 0, -1

    return backend.max_memory_allocated(), backend.max_memory_reserved()
luopl's avatar
luopl committed
207
208


chenych's avatar
chenych committed
209
def has_tokenized_data(path: "os.PathLike") -> bool:
    r"""Return True when ``path`` is an existing directory containing at least one entry."""
    if not os.path.isdir(path):
        return False

    return bool(os.listdir(path))


def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
    r"""Infer the optimal dtype according to the model_dtype and device compatibility.

    Returns bfloat16 only when the model is already bfloat16 and bf16 is supported. Otherwise it
    returns float16 if fp16 is supported, falling back to float32.
    """
    keep_bf16 = _is_bf16_available and model_dtype == torch.bfloat16
    if keep_bf16:
        return torch.bfloat16

    return torch.float16 if _is_fp16_available else torch.float32


chenych's avatar
chenych committed
224
225
226
227
228
def is_accelerator_available() -> bool:
    r"""Return whether any of the XPU, NPU, MPS or CUDA backends is available (probed in that order)."""
    probes = (is_torch_xpu_available, is_torch_npu_available, is_torch_mps_available, is_torch_cuda_available)
    return any(probe() for probe in probes)
chenych's avatar
chenych committed
229
230


chenych's avatar
chenych committed
231
def is_env_enabled(env_var: str, default: str = "0") -> bool:
    r"""Return True if ``env_var`` (or ``default`` when unset) is ``"true"``, ``"y"`` or ``"1"``, case-insensitively."""
    value = os.getenv(env_var, default)
    return value.lower() in {"true", "y", "1"}


chenych's avatar
chenych committed
236
def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
    r"""Cast a torch tensor or a numpy array to a numpy array.

    Tensors are moved to CPU first. bfloat16 tensors are upcast to float32, since NumPy has no
    bfloat16 dtype. Non-tensor inputs are returned unchanged.
    """
    if not isinstance(inputs, torch.Tensor):
        return inputs

    tensor = inputs.cpu()
    if tensor.dtype == torch.bfloat16:
        tensor = tensor.to(torch.float32)

    return tensor.numpy()


def skip_check_imports() -> None:
    r"""Avoid flash attention import error in custom model files.

    Unless ``FORCE_CHECK_IMPORTS`` is enabled, this replaces transformers'
    ``dynamic_module_utils.check_imports`` with ``get_relative_imports``.
    """
    if is_env_enabled("FORCE_CHECK_IMPORTS"):
        return

    transformers.dynamic_module_utils.check_imports = get_relative_imports
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
252
253
254


def torch_gc() -> None:
    r"""Collect the device memory.

    Runs Python's garbage collector, then calls ``empty_cache()`` on the first available backend
    (XPU, NPU, MPS, CUDA, in that order).
    """
    gc.collect()
    # Backend modules are looked up by name only after the availability check passes,
    # so an absent backend (e.g. ``torch.npu``) is never touched.
    for backend_name, is_available in (
        ("xpu", is_torch_xpu_available),
        ("npu", is_torch_npu_available),
        ("mps", is_torch_mps_available),
        ("cuda", is_torch_cuda_available),
    ):
        if is_available():
            getattr(torch, backend_name).empty_cache()
            break


luopl's avatar
luopl committed
267
268
def _download_from_modelscope(model_args: "ModelArguments") -> str:
    r"""Fetch a model snapshot from the ModelScope hub and return its local path."""
    check_version("modelscope>=1.14.0", mandatory=True)
    from modelscope import snapshot_download  # type: ignore
    from modelscope.hub.api import HubApi  # type: ignore

    if model_args.ms_hub_token:
        HubApi().login(model_args.ms_hub_token)

    # The revision "main" is translated to "master" for ModelScope; any other revision is passed through.
    revision = "master" if model_args.model_revision == "main" else model_args.model_revision
    lock_path = os.path.abspath(os.path.expanduser("~/.cache/llamafactory/modelscope.lock"))
    with WeakFileLock(lock_path):
        return snapshot_download(
            model_args.model_name_or_path,
            revision=revision,
            cache_dir=model_args.cache_dir,
        )


def _download_from_openmind(model_args: "ModelArguments") -> str:
    r"""Fetch a model snapshot from the OpenMind hub and return its local path."""
    check_version("openmind>=0.8.0", mandatory=True)
    from openmind.utils.hub import snapshot_download  # type: ignore

    lock_path = os.path.abspath(os.path.expanduser("~/.cache/llamafactory/openmind.lock"))
    with WeakFileLock(lock_path):
        return snapshot_download(
            model_args.model_name_or_path,
            revision=model_args.model_revision,
            cache_dir=model_args.cache_dir,
        )


def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
    r"""Resolve ``model_args.model_name_or_path`` via ModelScope or OpenMind when one of them is enabled.

    Returns ``model_name_or_path`` unchanged if neither hub is enabled or the path already exists
    locally. Otherwise downloads from ModelScope (checked first) or OpenMind, under a file lock in
    ``~/.cache/llamafactory``, and returns the downloaded snapshot's path.
    """
    if not (use_modelscope() or use_openmind()) or os.path.exists(model_args.model_name_or_path):
        return model_args.model_name_or_path

    if use_modelscope():
        return _download_from_modelscope(model_args)

    if use_openmind():
        return _download_from_openmind(model_args)
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
302
303
304


def use_modelscope() -> bool:
    r"""Return whether the ``USE_MODELSCOPE_HUB`` environment flag is enabled."""
    return is_env_enabled(env_var="USE_MODELSCOPE_HUB")
luopl's avatar
luopl committed
306
307
308


def use_openmind() -> bool:
    r"""Return whether the ``USE_OPENMIND_HUB`` environment flag is enabled."""
    return is_env_enabled(env_var="USE_OPENMIND_HUB")
luopl's avatar
luopl committed
310
311


luopl's avatar
luopl committed
312
def use_ray() -> bool:
    r"""Return whether the ``USE_RAY`` environment flag is enabled."""
    return is_env_enabled(env_var="USE_RAY")
chenych's avatar
chenych committed
314
315
316


def find_available_port() -> int:
    r"""Find an available port on the local machine.

    Binds a TCP socket to port 0 so the OS assigns a free port, reads the assigned port number,
    and closes the socket. The port is only known to be free at the moment of the call.

    Returns:
        The port number assigned by the OS.
    """
    # The context manager closes the socket even if bind() raises (the previous version leaked it).
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("", 0))
        return sock.getsockname()[1]


chenych's avatar
chenych committed
325
326
def fix_proxy(ipv6_enabled: bool = False) -> None:
    r"""Fix proxy settings for gradio ui.

    Always sets ``no_proxy`` so local addresses bypass any proxy. When ``ipv6_enabled`` is True,
    it also removes ``http_proxy`` and ``HTTP_PROXY`` from the environment (if present).
    """
    os.environ["no_proxy"] = "localhost,127.0.0.1,0.0.0.0"
    if not ipv6_enabled:
        return

    for proxy_var in ("http_proxy", "HTTP_PROXY"):
        os.environ.pop(proxy_var, None)