parser.py 19 KB
Newer Older
chenych's avatar
chenych committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

luopl's avatar
luopl committed
18
import json
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
19
20
import os
import sys
luopl's avatar
luopl committed
21
22
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
23
24
25

import torch
import transformers
luopl's avatar
luopl committed
26
27
import yaml
from transformers import HfArgumentParser
chenych's avatar
chenych committed
28
from transformers.integrations import is_deepspeed_zero3_enabled
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
29
from transformers.trainer_utils import get_last_checkpoint
chenych's avatar
chenych committed
30
from transformers.training_args import ParallelMode
luopl's avatar
luopl committed
31
from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
32

luopl's avatar
luopl committed
33
from ..extras import logging
chenych's avatar
chenych committed
34
from ..extras.constants import CHECKPOINT_NAMES
chenych's avatar
chenych committed
35
from ..extras.misc import check_dependencies, check_version, get_current_device, is_env_enabled
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
36
37
38
39
40
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
from .generating_args import GeneratingArguments
from .model_args import ModelArguments
luopl's avatar
luopl committed
41
from .training_args import RayArguments, TrainingArguments
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
42
43


luopl's avatar
luopl committed
44
# Rank-aware logger shared by every parsing helper in this module.
logger = logging.get_logger(__name__)


# Fail fast at import time if required core dependencies are missing or mismatched.
check_dependencies()


luopl's avatar
luopl committed
49
50
# Dataclass groups consumed by each entry point (the lists feed HfArgumentParser),
# paired with tuple-type aliases describing what the corresponding parser returns.
_TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
_TRAIN_CLS = Tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
_EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]


luopl's avatar
luopl committed
57
def read_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> Union[Dict[str, Any], List[str]]:
    r"""
    Resolves the arguments to parse: an explicitly given object, a single
    YAML/JSON config-file path on the command line, or the raw CLI argv.
    """
    if args is not None:
        return args

    # A lone positional argument may be a config file; dispatch on its suffix.
    if len(sys.argv) == 2:
        candidate = sys.argv[1]
        if candidate.endswith((".yaml", ".yml")):
            return yaml.safe_load(Path(candidate).absolute().read_text())

        if candidate.endswith(".json"):
            return json.loads(Path(candidate).absolute().read_text())

    # Fall back to the plain command-line arguments.
    return sys.argv[1:]
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
70
71


luopl's avatar
luopl committed
72
73
74
75
76
77
def _parse_args(
    parser: "HfArgumentParser", args: Optional[Union[Dict[str, Any], List[str]]] = None, allow_extra_keys: bool = False
) -> Tuple[Any]:
    r"""
    Parses ``args`` (or the process argv) into the parser's dataclasses.

    Unknown keys/flags raise a ValueError unless ``allow_extra_keys`` is set.
    """
    resolved = read_args(args)

    # Dict input takes the dedicated parse_dict path.
    if isinstance(resolved, dict):
        return parser.parse_dict(resolved, allow_extra_keys=allow_extra_keys)

    *parsed, unknown_args = parser.parse_args_into_dataclasses(args=resolved, return_remaining_strings=True)
    if unknown_args and not allow_extra_keys:
        # Show usage before failing so the offending flags are easy to spot.
        print(parser.format_help())
        print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
        raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")

    return tuple(parsed)
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
87
88


luopl's avatar
luopl committed
89
def _set_transformers_logging() -> None:
    r"""
    Enables INFO-level transformers logging unless LLAMAFACTORY_VERBOSITY lowers it.
    """
    verbosity = os.getenv("LLAMAFACTORY_VERBOSITY", "INFO")
    if verbosity in ("DEBUG", "INFO"):
        hf_logging = transformers.utils.logging
        hf_logging.set_verbosity_info()
        hf_logging.enable_default_handler()
        hf_logging.enable_explicit_format()
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
94
95


chenych's avatar
chenych committed
96
97
98
99
100
def _verify_model_args(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    finetuning_args: "FinetuningArguments",
) -> None:
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
101
102
103
104
105
106
107
    if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora":
        raise ValueError("Adapter is only valid for the LoRA method.")

    if model_args.quantization_bit is not None:
        if finetuning_args.finetuning_type != "lora":
            raise ValueError("Quantization is only compatible with the LoRA method.")

chenych's avatar
chenych committed
108
109
110
111
112
113
        if finetuning_args.pissa_init:
            raise ValueError("Please use scripts/pissa_init.py to initialize PiSSA for a quantized model.")

        if model_args.resize_vocab:
            raise ValueError("Cannot resize embedding layers of a quantized model.")

Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
114
115
116
117
118
119
        if model_args.adapter_name_or_path is not None and finetuning_args.create_new_adapter:
            raise ValueError("Cannot create new adapter upon a quantized model.")

        if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
            raise ValueError("Quantized model only accepts a single adapter. Merge them first.")

chenych's avatar
chenych committed
120
    if data_args.template == "yi" and model_args.use_fast_tokenizer:
luopl's avatar
luopl committed
121
        logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
chenych's avatar
chenych committed
122
123
        model_args.use_fast_tokenizer = False

Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
124
125
126
127

def _check_extra_dependencies(
    model_args: "ModelArguments",
    finetuning_args: "FinetuningArguments",
    training_args: Optional["TrainingArguments"] = None,
) -> None:
    r"""
    Verifies that the optional packages backing each enabled feature are installed.
    """

    def require(requirement: str) -> None:
        # Mandatory: a missing/mismatched package aborts instead of merely logging.
        check_version(requirement, mandatory=True)

    if model_args.use_unsloth:
        require("unsloth")

    if model_args.enable_liger_kernel:
        require("liger-kernel")

    if model_args.mixture_of_depths is not None:
        require("mixture-of-depth>=1.1.6")

    if model_args.infer_backend == "vllm":
        check_version("vllm>=0.4.3,<=0.7.3")  # advisory version-range check only
        require("vllm")

    if finetuning_args.use_galore:
        require("galore_torch")

    if finetuning_args.use_apollo:
        require("apollo_torch")

    if finetuning_args.use_badam:
        require("badam>=1.2.1")

    if finetuning_args.use_adam_mini:
        require("adam-mini")

    if finetuning_args.plot_loss:
        require("matplotlib")

    if training_args is not None and training_args.predict_with_generate:
        # Metric packages used when scoring generated predictions.
        require("jieba")
        require("nltk")
        require("rouge_chinese")
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
162
163


luopl's avatar
luopl committed
164
def _parse_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
    r"""
    Parses the five training-related argument dataclasses.
    """
    return _parse_args(
        HfArgumentParser(_TRAIN_ARGS), args, allow_extra_keys=is_env_enabled("ALLOW_EXTRA_ARGS")
    )
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
168
169


luopl's avatar
luopl committed
170
def _parse_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS:
    r"""
    Parses the four inference-related argument dataclasses.
    """
    return _parse_args(
        HfArgumentParser(_INFER_ARGS), args, allow_extra_keys=is_env_enabled("ALLOW_EXTRA_ARGS")
    )
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
174
175


luopl's avatar
luopl committed
176
def _parse_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS:
    r"""
    Parses the four evaluation-related argument dataclasses.
    """
    return _parse_args(
        HfArgumentParser(_EVAL_ARGS), args, allow_extra_keys=is_env_enabled("ALLOW_EXTRA_ARGS")
    )
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
180
181


luopl's avatar
luopl committed
182
183
184
185
186
187
188
def get_ray_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> RayArguments:
    r"""
    Parses only the Ray-related arguments; extra keys are tolerated on purpose
    since the same config usually carries the full training arguments too.
    """
    (ray_args,) = _parse_args(HfArgumentParser(RayArguments), args, allow_extra_keys=True)
    return ray_args


def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
    r"""
    Parses, validates and post-processes the arguments used for training.

    Returns a tuple of (model_args, data_args, training_args, finetuning_args, generating_args).
    Raises ValueError for unsupported option combinations and emits rank-0
    warnings for combinations that are allowed but discouraged. Several of the
    returned argument objects are mutated in place during post-processing.
    """
    model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)

    # Setup logging
    if training_args.should_log:
        _set_transformers_logging()

    # Check arguments
    # These options only make sense with SFT-style supervision.
    if finetuning_args.stage != "sft":
        if training_args.predict_with_generate:
            raise ValueError("`predict_with_generate` cannot be set as True except SFT.")

        if data_args.neat_packing:
            raise ValueError("`neat_packing` cannot be set as True except SFT.")

        if data_args.train_on_prompt or data_args.mask_history:
            raise ValueError("`train_on_prompt` or `mask_history` cannot be set as True except SFT.")

    if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
        raise ValueError("Please enable `predict_with_generate` to save model predictions.")

    if finetuning_args.stage in ["rm", "ppo"] and training_args.load_best_model_at_end:
        raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.")

    # PPO imposes extra constraints on evaluation, attention, the reward model and loggers.
    if finetuning_args.stage == "ppo":
        if not training_args.do_train:
            raise ValueError("PPO training does not support evaluation, use the SFT stage to evaluate models.")

        if model_args.shift_attn:
            raise ValueError("PPO training is incompatible with S^2-Attn.")

        if finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
            raise ValueError("Unsloth does not support lora reward model.")

        if training_args.report_to and training_args.report_to[0] not in ["wandb", "tensorboard"]:
            raise ValueError("PPO only accepts wandb or tensorboard logger.")

    if training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
        raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.")

    if training_args.deepspeed and training_args.parallel_mode != ParallelMode.DISTRIBUTED:
        raise ValueError("Please use `FORCE_TORCHRUN=1` to launch DeepSpeed training.")

    # A streaming dataset has no known length, so the step count must be explicit.
    if training_args.max_steps == -1 and data_args.streaming:
        raise ValueError("Please specify `max_steps` in streaming mode.")

    if training_args.do_train and data_args.dataset is None:
        raise ValueError("Please specify dataset for training.")

    if (training_args.do_eval or training_args.do_predict) and (
        data_args.eval_dataset is None and data_args.val_size < 1e-6
    ):
        raise ValueError("Please specify dataset for evaluation.")

    if training_args.predict_with_generate:
        if is_deepspeed_zero3_enabled():
            raise ValueError("`predict_with_generate` is incompatible with DeepSpeed ZeRO-3.")

        if data_args.eval_dataset is None:
            raise ValueError("Cannot use `predict_with_generate` if `eval_dataset` is None.")

        if finetuning_args.compute_accuracy:
            raise ValueError("Cannot use `predict_with_generate` and `compute_accuracy` together.")

    if training_args.do_train and model_args.quantization_device_map == "auto":
        raise ValueError("Cannot use device map for quantized models in training.")

    if finetuning_args.pissa_init and is_deepspeed_zero3_enabled():
        raise ValueError("Please use scripts/pissa_init.py to initialize PiSSA in DeepSpeed ZeRO-3.")

    if finetuning_args.pure_bf16:
        # bf16 must be supported either by the GPU or by the NPU backend.
        if not (is_torch_bf16_gpu_available() or (is_torch_npu_available() and torch.npu.is_bf16_supported())):
            raise ValueError("This device does not support `pure_bf16`.")

        if is_deepspeed_zero3_enabled():
            raise ValueError("`pure_bf16` is incompatible with DeepSpeed ZeRO-3.")

    if training_args.parallel_mode == ParallelMode.DISTRIBUTED:
        if finetuning_args.use_galore and finetuning_args.galore_layerwise:
            raise ValueError("Distributed training does not support layer-wise GaLore.")

        if finetuning_args.use_apollo and finetuning_args.apollo_layerwise:
            raise ValueError("Distributed training does not support layer-wise APOLLO.")

        if finetuning_args.use_badam:
            # NOTE(review): "Radio-based" in the message below looks like a typo for "Ratio-based".
            if finetuning_args.badam_mode == "ratio":
                raise ValueError("Radio-based BAdam does not yet support distributed training, use layer-wise BAdam.")
            elif not is_deepspeed_zero3_enabled():
                raise ValueError("Layer-wise BAdam only supports DeepSpeed ZeRO-3 training.")

    if training_args.deepspeed is not None and (finetuning_args.use_galore or finetuning_args.use_apollo):
        raise ValueError("GaLore and APOLLO are incompatible with DeepSpeed yet.")

    if model_args.infer_backend == "vllm":
        raise ValueError("vLLM backend is only available for API, CLI and Web.")

    if model_args.use_unsloth and is_deepspeed_zero3_enabled():
        raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")

    # neat_packing implies packing; fix the setting instead of failing.
    if data_args.neat_packing and not data_args.packing:
        logger.warning_rank0("`neat_packing` requires `packing` is True. Change `packing` to True.")
        data_args.packing = True

    _verify_model_args(model_args, data_args, finetuning_args)
    _check_extra_dependencies(model_args, finetuning_args, training_args)

    if (
        training_args.do_train
        and finetuning_args.finetuning_type == "lora"
        and model_args.quantization_bit is None
        and model_args.resize_vocab
        and finetuning_args.additional_target is None
    ):
        logger.warning_rank0(
            "Remember to add embedding layers to `additional_target` to make the added tokens trainable."
        )

    if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
        logger.warning_rank0("We recommend enable `upcast_layernorm` in quantized training.")

    if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
        logger.warning_rank0("We recommend enable mixed precision training.")

    if (
        training_args.do_train
        and (finetuning_args.use_galore or finetuning_args.use_apollo)
        and not finetuning_args.pure_bf16
    ):
        logger.warning_rank0(
            "Using GaLore or APOLLO with mixed precision training may significantly increases GPU memory usage."
        )

    if (not training_args.do_train) and model_args.quantization_bit is not None:
        logger.warning_rank0("Evaluating model in 4/8-bit mode may cause lower scores.")

    if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None:
        logger.warning_rank0("Specify `ref_model` for computing rewards at evaluation.")

    # Post-process training arguments
    if (
        training_args.parallel_mode == ParallelMode.DISTRIBUTED
        and training_args.ddp_find_unused_parameters is None
        and finetuning_args.finetuning_type == "lora"
    ):
        logger.warning_rank0("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
        training_args.ddp_find_unused_parameters = False

    # Checkpoint resumption is disabled for full/freeze RM and PPO runs.
    if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:
        can_resume_from_checkpoint = False
        if training_args.resume_from_checkpoint is not None:
            logger.warning_rank0("Cannot resume from checkpoint in current stage.")
            training_args.resume_from_checkpoint = None
    else:
        can_resume_from_checkpoint = True

    # Auto-detect the last checkpoint in output_dir when resuming was not requested explicitly.
    if (
        training_args.resume_from_checkpoint is None
        and training_args.do_train
        and os.path.isdir(training_args.output_dir)
        and not training_args.overwrite_output_dir
        and can_resume_from_checkpoint
    ):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and any(
            os.path.isfile(os.path.join(training_args.output_dir, name)) for name in CHECKPOINT_NAMES
        ):
            raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")

        if last_checkpoint is not None:
            training_args.resume_from_checkpoint = last_checkpoint
            logger.info_rank0(f"Resuming training from {training_args.resume_from_checkpoint}.")
            logger.info_rank0("Change `output_dir` or use `overwrite_output_dir` to avoid.")

    if (
        finetuning_args.stage in ["rm", "ppo"]
        and finetuning_args.finetuning_type == "lora"
        and training_args.resume_from_checkpoint is not None
    ):
        logger.warning_rank0(
            "Add {} to `adapter_name_or_path` to resume training from checkpoint.".format(
                training_args.resume_from_checkpoint
            )
        )

    # Post-process model arguments
    if training_args.bf16 or finetuning_args.pure_bf16:
        model_args.compute_dtype = torch.bfloat16
    elif training_args.fp16:
        model_args.compute_dtype = torch.float16

    model_args.device_map = {"": get_current_device()}
    model_args.model_max_length = data_args.cutoff_len
    model_args.block_diag_attn = data_args.neat_packing
    # When packing is left unset, it defaults to True only for pre-training.
    data_args.packing = data_args.packing if data_args.packing is not None else finetuning_args.stage == "pt"

    # Log on each process the small summary
    logger.info(
        "Process rank: {}, world size: {}, device: {}, distributed training: {}, compute dtype: {}".format(
            training_args.process_index,
            training_args.world_size,
            training_args.device,
            training_args.parallel_mode == ParallelMode.DISTRIBUTED,
            str(model_args.compute_dtype),
        )
    )
    transformers.set_seed(training_args.seed)

    return model_args, data_args, training_args, finetuning_args, generating_args


luopl's avatar
luopl committed
398
def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS:
    r"""
    Parses and validates the arguments used for inference (chat, API and export).

    Returns a tuple of (model_args, data_args, finetuning_args, generating_args).
    """
    model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)

    _set_transformers_logging()

    # vLLM supports a narrower feature set than the default HF backend.
    if model_args.infer_backend == "vllm":
        vllm_blockers = [
            (finetuning_args.stage != "sft", "vLLM engine only supports auto-regressive models."),
            (
                model_args.quantization_bit is not None,
                "vLLM engine does not support bnb quantization (GPTQ and AWQ are supported).",
            ),
            (model_args.rope_scaling is not None, "vLLM engine does not support RoPE scaling."),
            (
                model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1,
                "vLLM only accepts a single adapter. Merge them first.",
            ),
        ]
        for blocked, message in vllm_blockers:
            if blocked:
                raise ValueError(message)

    _verify_model_args(model_args, data_args, finetuning_args)
    _check_extra_dependencies(model_args, finetuning_args)

    if model_args.export_dir is not None and model_args.export_device == "cpu":
        # Exporting on CPU: pin every module to the CPU device.
        model_args.device_map = {"": torch.device("cpu")}
        if data_args.cutoff_len != DataArguments().cutoff_len:  # only honor a user-overridden cutoff_len
            model_args.model_max_length = data_args.cutoff_len
    else:
        model_args.device_map = "auto"

    return model_args, data_args, finetuning_args, generating_args


luopl's avatar
luopl committed
429
def get_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS:
    r"""
    Parses and validates the arguments used for offline evaluation.

    Returns a tuple of (model_args, data_args, eval_args, finetuning_args).
    """
    model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)

    _set_transformers_logging()

    # The evaluation entry point only runs on the default HF backend.
    if model_args.infer_backend == "vllm":
        raise ValueError("vLLM backend is only available for API, CLI and Web.")

    _verify_model_args(model_args, data_args, finetuning_args)
    _check_extra_dependencies(model_args, finetuning_args)

    transformers.set_seed(eval_args.seed)
    model_args.device_map = "auto"

    return model_args, data_args, eval_args, finetuning_args