"vscode:/vscode.git/clone" did not exist on "d8827789d4e2536229bc0dcddd613d06cdd3fd7e"
parser.py 18 KB
Newer Older
chenych's avatar
chenych committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
18
19
20
21
22
23
24
import os
import sys
from typing import Any, Dict, Optional, Tuple

import torch
import transformers
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
chenych's avatar
chenych committed
25
from transformers.integrations import is_deepspeed_zero3_enabled
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
26
from transformers.trainer_utils import get_last_checkpoint
chenych's avatar
chenych committed
27
from transformers.training_args import ParallelMode
luopl's avatar
luopl committed
28
from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
29
30
from transformers.utils.versions import require_version

luopl's avatar
luopl committed
31
from ..extras import logging
chenych's avatar
chenych committed
32
from ..extras.constants import CHECKPOINT_NAMES
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
33
34
35
36
37
38
39
40
from ..extras.misc import check_dependencies, get_current_device
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
from .generating_args import GeneratingArguments
from .model_args import ModelArguments


luopl's avatar
luopl committed
41
# Module-level logger from the project's logging helpers; it provides the rank-aware
# `warning_rank0` / `info_rank0` methods used throughout this module.
logger = logging.get_logger(__name__)
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58


# Run the package-level dependency check once, when this module is imported.
check_dependencies()


# Dataclass types parsed by each entry point (`_*_ARGS`) and the tuple type returned by the matching
# `_parse_*_args` / `get_*_args` functions (`_*_CLS`). Callers unpack the parsed tuple positionally,
# so keep the order of each `_*_ARGS` list in sync with its `_*_CLS` tuple.
_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
_EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]


def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
    if args is not None:
        return parser.parse_dict(args)

luopl's avatar
luopl committed
59
    if len(sys.argv) == 2 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
60
61
62
63
64
65
66
67
68
        return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        return parser.parse_json_file(os.path.abspath(sys.argv[1]))

    (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(return_remaining_strings=True)

    if unknown_args:
        print(parser.format_help())
luopl's avatar
luopl committed
69
70
        print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
        raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
71
72
73
74

    return (*parsed_args,)


luopl's avatar
luopl committed
75
76
def _set_transformers_logging() -> None:
    r"""
    Switches the transformers library logger to INFO verbosity, with its default handler
    attached and the explicit message format enabled.
    """
    hf_logging = transformers.utils.logging
    hf_logging.set_verbosity_info()
    hf_logging.enable_default_handler()
    hf_logging.enable_explicit_format()


chenych's avatar
chenych committed
81
82
83
84
85
def _verify_model_args(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    finetuning_args: "FinetuningArguments",
) -> None:
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
86
87
88
89
90
91
92
    if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora":
        raise ValueError("Adapter is only valid for the LoRA method.")

    if model_args.quantization_bit is not None:
        if finetuning_args.finetuning_type != "lora":
            raise ValueError("Quantization is only compatible with the LoRA method.")

chenych's avatar
chenych committed
93
94
95
96
97
98
        if finetuning_args.pissa_init:
            raise ValueError("Please use scripts/pissa_init.py to initialize PiSSA for a quantized model.")

        if model_args.resize_vocab:
            raise ValueError("Cannot resize embedding layers of a quantized model.")

Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
99
100
101
102
103
104
        if model_args.adapter_name_or_path is not None and finetuning_args.create_new_adapter:
            raise ValueError("Cannot create new adapter upon a quantized model.")

        if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
            raise ValueError("Quantized model only accepts a single adapter. Merge them first.")

chenych's avatar
chenych committed
105
    if data_args.template == "yi" and model_args.use_fast_tokenizer:
luopl's avatar
luopl committed
106
        logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
chenych's avatar
chenych committed
107
108
        model_args.use_fast_tokenizer = False

Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
109
110
111
112
113
114
115
116
117

def _check_extra_dependencies(
    model_args: "ModelArguments",
    finetuning_args: "FinetuningArguments",
    training_args: Optional["Seq2SeqTrainingArguments"] = None,
) -> None:
    r"""
    Checks, via ``require_version``, the optional packages needed by the features that are enabled.

    Each enabled feature is checked against its package requirement together with an installation
    hint. The checks run in a fixed order, and if ``require_version`` raises, the remaining checks
    are skipped.

    Args:
        model_args: model arguments (unsloth, liger kernel, mixture-of-depths, inference backend).
        finetuning_args: fine-tuning arguments (GaLore, BAdam, Adam-mini, loss plotting).
        training_args: optional training arguments; enables the ``predict_with_generate`` checks.
    """
    feature_requirements = [
        (model_args.use_unsloth, "unsloth", "Please install unsloth: https://github.com/unslothai/unsloth"),
        (model_args.enable_liger_kernel, "liger-kernel", "To fix: pip install liger-kernel"),
        (
            model_args.mixture_of_depths is not None,
            "mixture-of-depth>=1.1.6",
            "To fix: pip install mixture-of-depth>=1.1.6",
        ),
        (model_args.infer_backend == "vllm", "vllm>=0.4.3,<0.6.4", "To fix: pip install vllm>=0.4.3,<0.6.4"),
        (finetuning_args.use_galore, "galore_torch", "To fix: pip install galore_torch"),
        (finetuning_args.use_badam, "badam>=1.2.1", "To fix: pip install badam>=1.2.1"),
        (finetuning_args.use_adam_mini, "adam-mini", "To fix: pip install adam-mini"),
        (finetuning_args.plot_loss, "matplotlib", "To fix: pip install matplotlib"),
    ]

    if training_args is not None and training_args.predict_with_generate:
        feature_requirements.extend(
            [
                (True, "jieba", "To fix: pip install jieba"),
                (True, "nltk", "To fix: pip install nltk"),
                (True, "rouge_chinese", "To fix: pip install rouge-chinese"),
            ]
        )

    for enabled, requirement, hint in feature_requirements:
        if enabled:
            require_version(requirement, hint)


def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
    r"""Parses the training dataclasses listed in ``_TRAIN_ARGS`` from ``args`` or the command line."""
    return _parse_args(HfArgumentParser(_TRAIN_ARGS), args)


def _parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
    r"""Parses the inference dataclasses listed in ``_INFER_ARGS`` from ``args`` or the command line."""
    return _parse_args(HfArgumentParser(_INFER_ARGS), args)


def _parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
    r"""Parses the evaluation dataclasses listed in ``_EVAL_ARGS`` from ``args`` or the command line."""
    return _parse_args(HfArgumentParser(_EVAL_ARGS), args)


def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
    r"""
    Parses, validates and post-processes the arguments of a training run.

    Arguments come from ``args`` when given, otherwise from the command line or a single YAML/JSON
    config file (see ``_parse_args``). Incompatible combinations raise ``ValueError``. Several
    settings are then adjusted in place (``packing``, ``ddp_find_unused_parameters``,
    ``resume_from_checkpoint``, ``compute_dtype``, ``device_map``, ``model_max_length``,
    ``block_diag_attn``) before the global seed is set.

    Args:
        args: optional mapping of argument names to values; bypasses the command line when given.

    Returns:
        A tuple ``(model_args, data_args, training_args, finetuning_args, generating_args)``.

    Raises:
        ValueError: if required arguments are missing or the arguments are mutually incompatible.
    """
    model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)

    # Setup logging
    if training_args.should_log:
        _set_transformers_logging()

    # Check arguments
    if finetuning_args.stage != "pt" and data_args.template is None:
        raise ValueError("Please specify which `template` to use.")

    if finetuning_args.stage != "sft":
        if training_args.predict_with_generate:
            raise ValueError("`predict_with_generate` cannot be set as True except SFT.")

        if data_args.neat_packing:
            raise ValueError("`neat_packing` cannot be set as True except SFT.")

        if data_args.train_on_prompt or data_args.mask_history:
            raise ValueError("`train_on_prompt` or `mask_history` cannot be set as True except SFT.")

    if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
        raise ValueError("Please enable `predict_with_generate` to save model predictions.")

    if finetuning_args.stage in ["rm", "ppo"] and training_args.load_best_model_at_end:
        raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.")

    if finetuning_args.stage == "ppo":
        if not training_args.do_train:
            raise ValueError("PPO training does not support evaluation, use the SFT stage to evaluate models.")

        if model_args.shift_attn:
            raise ValueError("PPO training is incompatible with S^2-Attn.")

        if finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
            raise ValueError("Unsloth does not support lora reward model.")

        # Only the first configured reporter is checked.
        if training_args.report_to and training_args.report_to[0] not in ["wandb", "tensorboard"]:
            raise ValueError("PPO only accepts wandb or tensorboard logger.")

    # Launch-mode checks.
    if training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
        raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.")

    if training_args.deepspeed and training_args.parallel_mode != ParallelMode.DISTRIBUTED:
        raise ValueError("Please use `FORCE_TORCHRUN=1` to launch DeepSpeed training.")

    # Dataset checks.
    if training_args.max_steps == -1 and data_args.streaming:
        raise ValueError("Please specify `max_steps` in streaming mode.")

    if training_args.do_train and data_args.dataset is None:
        raise ValueError("Please specify dataset for training.")

    if (training_args.do_eval or training_args.do_predict) and (
        data_args.eval_dataset is None and data_args.val_size < 1e-6
    ):
        raise ValueError("Please specify dataset for evaluation.")

    if training_args.predict_with_generate:
        if is_deepspeed_zero3_enabled():
            raise ValueError("`predict_with_generate` is incompatible with DeepSpeed ZeRO-3.")

        if data_args.eval_dataset is None:
            raise ValueError("Cannot use `predict_with_generate` if `eval_dataset` is None.")

        if finetuning_args.compute_accuracy:
            raise ValueError("Cannot use `predict_with_generate` and `compute_accuracy` together.")

    if training_args.do_train and model_args.quantization_device_map == "auto":
        raise ValueError("Cannot use device map for quantized models in training.")

    if finetuning_args.pissa_init and is_deepspeed_zero3_enabled():
        raise ValueError("Please use scripts/pissa_init.py to initialize PiSSA in DeepSpeed ZeRO-3.")

    if finetuning_args.pure_bf16:
        if not (is_torch_bf16_gpu_available() or (is_torch_npu_available() and torch.npu.is_bf16_supported())):
            raise ValueError("This device does not support `pure_bf16`.")

        if is_deepspeed_zero3_enabled():
            raise ValueError("`pure_bf16` is incompatible with DeepSpeed ZeRO-3.")

    # Optimizer compatibility checks.
    if (
        finetuning_args.use_galore
        and finetuning_args.galore_layerwise
        and training_args.parallel_mode == ParallelMode.DISTRIBUTED
    ):
        raise ValueError("Distributed training does not support layer-wise GaLore.")

    if finetuning_args.use_badam and training_args.parallel_mode == ParallelMode.DISTRIBUTED:
        if finetuning_args.badam_mode == "ratio":
            raise ValueError("Ratio-based BAdam does not yet support distributed training, use layer-wise BAdam.")
        elif not is_deepspeed_zero3_enabled():
            raise ValueError("Layer-wise BAdam only supports DeepSpeed ZeRO-3 training.")

    if finetuning_args.use_galore and training_args.deepspeed is not None:
        raise ValueError("GaLore is incompatible with DeepSpeed yet.")

    if model_args.infer_backend == "vllm":
        raise ValueError("vLLM backend is only available for API, CLI and Web.")

    if model_args.use_unsloth and is_deepspeed_zero3_enabled():
        raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")

    if data_args.neat_packing and not data_args.packing:
        logger.warning_rank0("`neat_packing` requires `packing` is True. Change `packing` to True.")
        data_args.packing = True

    _verify_model_args(model_args, data_args, finetuning_args)
    _check_extra_dependencies(model_args, finetuning_args, training_args)

    # Non-fatal recommendations.
    if (
        training_args.do_train
        and finetuning_args.finetuning_type == "lora"
        and model_args.quantization_bit is None
        and model_args.resize_vocab
        and finetuning_args.additional_target is None
    ):
        logger.warning_rank0(
            "Remember to add embedding layers to `additional_target` to make the added tokens trainable."
        )

    if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
        logger.warning_rank0("We recommend enable `upcast_layernorm` in quantized training.")

    if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
        logger.warning_rank0("We recommend enable mixed precision training.")

    if training_args.do_train and finetuning_args.use_galore and not finetuning_args.pure_bf16:
        logger.warning_rank0(
            "Using GaLore with mixed precision training may significantly increases GPU memory usage."
        )

    if (not training_args.do_train) and model_args.quantization_bit is not None:
        logger.warning_rank0("Evaluating model in 4/8-bit mode may cause lower scores.")

    if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None:
        logger.warning_rank0("Specify `ref_model` for computing rewards at evaluation.")

    # Post-process training arguments
    if (
        training_args.parallel_mode == ParallelMode.DISTRIBUTED
        and training_args.ddp_find_unused_parameters is None
        and finetuning_args.finetuning_type == "lora"
    ):
        logger.warning_rank0("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
        training_args.ddp_find_unused_parameters = False

    # RM / PPO with full or freeze tuning cannot resume, so any requested checkpoint is dropped.
    if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:
        can_resume_from_checkpoint = False
        if training_args.resume_from_checkpoint is not None:
            logger.warning_rank0("Cannot resume from checkpoint in current stage.")
            training_args.resume_from_checkpoint = None
    else:
        can_resume_from_checkpoint = True

    # Resume from the latest checkpoint found in an existing `output_dir` unless overwriting is requested.
    if (
        training_args.resume_from_checkpoint is None
        and training_args.do_train
        and os.path.isdir(training_args.output_dir)
        and not training_args.overwrite_output_dir
        and can_resume_from_checkpoint
    ):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        # Saved model files without a checkpoint folder: refuse to silently overwrite them.
        if last_checkpoint is None and any(
            os.path.isfile(os.path.join(training_args.output_dir, name)) for name in CHECKPOINT_NAMES
        ):
            raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")

        if last_checkpoint is not None:
            training_args.resume_from_checkpoint = last_checkpoint
            logger.info_rank0(f"Resuming training from {training_args.resume_from_checkpoint}.")
            logger.info_rank0("Change `output_dir` or use `overwrite_output_dir` to avoid.")

    if (
        finetuning_args.stage in ["rm", "ppo"]
        and finetuning_args.finetuning_type == "lora"
        and training_args.resume_from_checkpoint is not None
    ):
        logger.warning_rank0(
            "Add {} to `adapter_name_or_path` to resume training from checkpoint.".format(
                training_args.resume_from_checkpoint
            )
        )

    # Post-process model arguments
    if training_args.bf16 or finetuning_args.pure_bf16:
        model_args.compute_dtype = torch.bfloat16
    elif training_args.fp16:
        model_args.compute_dtype = torch.float16

    model_args.device_map = {"": get_current_device()}
    model_args.model_max_length = data_args.cutoff_len
    model_args.block_diag_attn = data_args.neat_packing
    # Packing defaults to on only for pre-training when it was not set explicitly.
    data_args.packing = data_args.packing if data_args.packing is not None else finetuning_args.stage == "pt"

    # Log on each process the small summary
    logger.info(
        "Process rank: {}, device: {}, n_gpu: {}, distributed training: {}, compute dtype: {}".format(
            training_args.local_rank,
            training_args.device,
            training_args.n_gpu,
            training_args.parallel_mode == ParallelMode.DISTRIBUTED,
            str(model_args.compute_dtype),
        )
    )

    transformers.set_seed(training_args.seed)

    return model_args, data_args, training_args, finetuning_args, generating_args


def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
    r"""
    Parses and validates the arguments for inference and model export.

    With the vLLM backend, only SFT-stage models without bitsandbytes quantization, without RoPE
    scaling and with at most one adapter are accepted. The device map is set to ``"auto"``, except
    when exporting on CPU, where the model is pinned to the CPU and ``model_max_length`` is taken
    from ``cutoff_len``.

    Args:
        args: optional mapping of argument names to values; bypasses the command line when given.

    Returns:
        A tuple ``(model_args, data_args, finetuning_args, generating_args)``.

    Raises:
        ValueError: if no template is given or the arguments are incompatible.
    """
    model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)

    _set_transformers_logging()

    if data_args.template is None:
        raise ValueError("Please specify which `template` to use.")

    if model_args.infer_backend == "vllm":
        if finetuning_args.stage != "sft":
            raise ValueError("vLLM engine only supports auto-regressive models.")
        if model_args.quantization_bit is not None:
            raise ValueError("vLLM engine does not support bnb quantization (GPTQ and AWQ are supported).")
        if model_args.rope_scaling is not None:
            raise ValueError("vLLM engine does not support RoPE scaling.")
        adapters = model_args.adapter_name_or_path
        if adapters is not None and len(adapters) != 1:
            raise ValueError("vLLM only accepts a single adapter. Merge them first.")

    _verify_model_args(model_args, data_args, finetuning_args)
    _check_extra_dependencies(model_args, finetuning_args)

    exporting_on_cpu = model_args.export_dir is not None and model_args.export_device == "cpu"
    if not exporting_on_cpu:
        model_args.device_map = "auto"
    else:
        model_args.device_map = {"": torch.device("cpu")}
        model_args.model_max_length = data_args.cutoff_len

    return model_args, data_args, finetuning_args, generating_args


def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
    r"""
    Parses and validates the arguments for evaluation.

    A template is required and the vLLM backend is rejected. The device map is set to ``"auto"``
    and the global seed is taken from the evaluation arguments.

    Args:
        args: optional mapping of argument names to values; bypasses the command line when given.

    Returns:
        A tuple ``(model_args, data_args, eval_args, finetuning_args)``.

    Raises:
        ValueError: if no template is given, the vLLM backend is requested, or the model arguments
            are incompatible.
    """
    parsed = _parse_eval_args(args)
    model_args, data_args, eval_args, finetuning_args = parsed

    _set_transformers_logging()

    # (condition that makes the arguments invalid, error message), checked in order.
    invalid_settings = (
        (data_args.template is None, "Please specify which `template` to use."),
        (model_args.infer_backend == "vllm", "vLLM backend is only available for API, CLI and Web."),
    )
    for is_invalid, message in invalid_settings:
        if is_invalid:
            raise ValueError(message)

    _verify_model_args(model_args, data_args, finetuning_args)
    _check_extra_dependencies(model_args, finetuning_args)

    model_args.device_map = "auto"
    transformers.set_seed(eval_args.seed)

    return model_args, data_args, eval_args, finetuning_args