model_args.py 21.7 KB
Newer Older
shihm's avatar
uodata  
shihm committed
1
# Copyright 2025 HuggingFace Inc., the KVCache.AI team, Approaching AI, and the LlamaFactory team.
chenych's avatar
chenych committed
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

luopl's avatar
luopl committed
18
import json
luopl's avatar
luopl committed
19
from dataclasses import asdict, dataclass, field, fields
shihm's avatar
uodata  
shihm committed
20
from typing import Any, Literal, Self
chenych's avatar
chenych committed
21

luopl's avatar
luopl committed
22
import torch
shihm's avatar
uodata  
shihm committed
23
from omegaconf import OmegaConf
luopl's avatar
luopl committed
24
from transformers.training_args import _convert_str_dict
chenych's avatar
chenych committed
25

chenych's avatar
chenych committed
26
from ..extras.constants import AttentionFunction, EngineName, QuantizationMethod, RopeScaling
shihm's avatar
uodata  
shihm committed
27
28
29
30
from ..extras.logging import get_logger


logger = get_logger(__name__)
chenych's avatar
chenych committed
31
32
33
34


@dataclass
class BaseModelArguments:
    r"""Arguments pertaining to the model.

    Covers model/adapter locations, tokenizer options, attention/kernel switches,
    and hub authentication. Note that `__post_init__` normalizes several
    comma-separated string fields into lists in place.
    """

    # --- model / adapter locations ---
    model_name_or_path: str | None = field(
        default=None,
        metadata={
            "help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."
        },
    )
    adapter_name_or_path: str | None = field(
        default=None,
        metadata={
            "help": (
                "Path to the adapter weight or identifier from huggingface.co/models. "
                "Use commas to separate multiple adapters."
            )
        },
    )
    adapter_folder: str | None = field(
        default=None,
        metadata={"help": "The folder containing the adapter weights to load."},
    )
    cache_dir: str | None = field(
        default=None,
        metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
    )
    # --- tokenizer options ---
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."},
    )
    resize_vocab: bool = field(
        default=False,
        metadata={"help": "Whether or not to resize the tokenizer vocab and the embedding layers."},
    )
    split_special_tokens: bool = field(
        default=False,
        metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
    )
    # Comma-separated string at parse time; converted to list[str] in __post_init__.
    add_tokens: str | None = field(
        default=None,
        metadata={
            "help": "Non-special tokens to be added into the tokenizer. Use commas to separate multiple tokens."
        },
    )
    # Comma-separated string at parse time; converted to list[str] in __post_init__,
    # or overridden by the keys of `new_special_tokens_config` when that is set.
    add_special_tokens: str | None = field(
        default=None,
        metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
    )
    new_special_tokens_config: str | None = field(
        default=None,
        metadata={
            "help": (
                "Path to YAML config with special token descriptions for semantic initialization. "
                "If set, this takes precedence over add_special_tokens. "
                "YAML format: {'<token>': 'description text', ...}"
            )
        },
    )
    init_special_tokens: Literal["noise_init", "desc_init", "desc_init_w_noise"] = field(
        default="noise_init",
        metadata={
            "help": (
                "Initialization method for new special tokens: "
                "'noise_init' (default, random noise around mean), "
                "'desc_init' (semantic initialization from descriptions), "
                "'desc_init_w_noise' (semantic + random noise). "
                "Note: 'desc_init' methods require new_special_tokens_config."
            )
        },
    )
    # --- loading options ---
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    low_cpu_mem_usage: bool = field(
        default=True,
        metadata={"help": "Whether or not to use memory-efficient model loading."},
    )
    # --- attention / architecture switches ---
    rope_scaling: RopeScaling | None = field(
        default=None,
        metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
    )
    flash_attn: AttentionFunction = field(
        default=AttentionFunction.AUTO,
        metadata={"help": "Enable FlashAttention for faster training and inference."},
    )
    shift_attn: bool = field(
        default=False,
        metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
    )
    mixture_of_depths: Literal["convert", "load"] | None = field(
        default=None,
        metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."},
    )
    # --- training accelerations ---
    use_unsloth: bool = field(
        default=False,
        metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
    )
    use_unsloth_gc: bool = field(
        default=False,
        metadata={"help": "Whether or not to use unsloth's gradient checkpointing (no need to install unsloth)."},
    )
    enable_liger_kernel: bool = field(
        default=False,
        metadata={"help": "Whether or not to enable liger kernel for faster training."},
    )
    moe_aux_loss_coef: float | None = field(
        default=None,
        metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
    )
    disable_gradient_checkpointing: bool = field(
        default=False,
        metadata={"help": "Whether or not to disable gradient checkpointing."},
    )
    use_reentrant_gc: bool = field(
        default=True,
        metadata={"help": "Whether or not to use reentrant gradient checkpointing."},
    )
    upcast_layernorm: bool = field(
        default=False,
        metadata={"help": "Whether or not to upcast the layernorm weights in fp32."},
    )
    upcast_lmhead_output: bool = field(
        default=False,
        metadata={"help": "Whether or not to upcast the output of lm_head in fp32."},
    )
    train_from_scratch: bool = field(
        default=False,
        metadata={"help": "Whether or not to randomly initialize the model weights."},
    )
    # --- inference options ---
    infer_backend: EngineName = field(
        default=EngineName.HF,
        metadata={"help": "Backend engine used at inference."},
    )
    offload_folder: str = field(
        default="offload",
        metadata={"help": "Path to offload model weights."},
    )
    use_kv_cache: bool = field(
        default=True,
        metadata={"help": "Whether or not to use KV cache in generation."},
    )
    use_v1_kernels: bool | None = field(
        default=False,
        metadata={"help": "Whether or not to use high-performance kernels in training."},
    )
    infer_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field(
        default="auto",
        metadata={"help": "Data type for model weights and activations at inference."},
    )
    # --- hub authentication (masked by ModelArguments.to_dict) ---
    hf_hub_token: str | None = field(
        default=None,
        metadata={"help": "Auth token to log in with Hugging Face Hub."},
    )
    ms_hub_token: str | None = field(
        default=None,
        metadata={"help": "Auth token to log in with ModelScope Hub."},
    )
    om_hub_token: str | None = field(
        default=None,
        metadata={"help": "Auth token to log in with Modelers Hub."},
    )
    # --- debugging / trust ---
    print_param_status: bool = field(
        default=False,
        metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={"help": "Whether to trust the execution of code from datasets/models defined on the Hub or not."},
    )

    def __post_init__(self):
        """Validate required fields and normalize comma-separated fields into lists.

        Raises:
            ValueError: if `model_name_or_path` is missing, or if
                `split_special_tokens` is combined with a fast tokenizer, or if
                `new_special_tokens_config` does not parse to a dict.
        """
        if self.model_name_or_path is None:
            raise ValueError("Please provide `model_name_or_path`.")

        if self.split_special_tokens and self.use_fast_tokenizer:
            raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")

        if self.adapter_name_or_path is not None:  # support merging multiple lora weights
            self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]

        if self.add_tokens is not None:  # support multiple tokens
            self.add_tokens = [token.strip() for token in self.add_tokens.split(",")]

        # Process special tokens with priority: new_special_tokens_config > add_special_tokens
        if self.new_special_tokens_config is not None:
            # Priority 1: Load from YAML config (extracts both tokens and descriptions)
            try:
                cfg = OmegaConf.load(self.new_special_tokens_config)
                token_descriptions = OmegaConf.to_container(cfg)

                if not isinstance(token_descriptions, dict):
                    raise ValueError(
                        f"YAML config must be a dictionary mapping tokens to descriptions. "
                        f"Got: {type(token_descriptions)}"
                    )

                # Extract token list from config keys
                extracted_tokens = list(token_descriptions.keys())

                # Warn if both are set
                if self.add_special_tokens is not None:
                    logger.warning_rank0(
                        "Both 'new_special_tokens_config' and 'add_special_tokens' are set. "
                        f"Using tokens from config: {extracted_tokens}"
                    )

                # Override add_special_tokens with extracted tokens (as list)
                self.add_special_tokens = extracted_tokens

                # Store descriptions internally for later use (internal attribute)
                self._special_token_descriptions = token_descriptions

                logger.info_rank0(
                    f"Loaded {len(extracted_tokens)} special tokens with descriptions from: "
                    f"{self.new_special_tokens_config}"
                )

            except Exception as e:
                # Log the failure for rank-0 visibility, then re-raise unchanged.
                logger.error_rank0(
                    f"Failed to load special tokens config from '{self.new_special_tokens_config}': {e}"
                )
                raise

        elif self.add_special_tokens is not None:
            # Priority 2: Use simple comma-separated string (no descriptions)
            self.add_special_tokens = [token.strip() for token in self.add_special_tokens.split(",")]
            self._special_token_descriptions = None

        else:
            # No special tokens to add
            self._special_token_descriptions = None

        # Validate init method: description-based initialization is impossible
        # without descriptions, so degrade gracefully instead of failing.
        if self.init_special_tokens in ["desc_init", "desc_init_w_noise"]:
            if self._special_token_descriptions is None:
                logger.warning_rank0(
                    f"init_special_tokens='{self.init_special_tokens}' requires new_special_tokens_config. "
                    "Falling back to 'noise_init'"
                )
                self.init_special_tokens = "noise_init"
chenych's avatar
chenych committed
275

chenych's avatar
chenych committed
276

luopl's avatar
luopl committed
277
278
@dataclass
class QuantizationArguments:
    r"""Arguments pertaining to the quantization method."""

    quantization_method: QuantizationMethod = field(
        default=QuantizationMethod.BNB,
        metadata={"help": "Quantization method to use for on-the-fly quantization."},
    )
    quantization_bit: int | None = field(
        default=None,
        metadata={"help": "The number of bits to quantize the model using on-the-fly quantization."},
    )
    # The two options below apply to bitsandbytes int4 training only.
    quantization_type: Literal["fp4", "nf4"] = field(
        default="nf4",
        metadata={"help": "Quantization data type to use in bitsandbytes int4 training."},
    )
    double_quantization: bool = field(
        default=True,
        metadata={"help": "Whether or not to use double quantization in bitsandbytes int4 training."},
    )
    quantization_device_map: Literal["auto"] | None = field(
        default=None,
        metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
    )
    # --- FP8 mixed-precision training (via HuggingFace Accelerate) ---
    fp8: bool = field(
        default=False,
        metadata={
            "help": "Enable FP8 mixed precision training via HuggingFace Accelerate. "
            "Requires PyTorch 2.7+ and Hopper architecture GPUs."
        },
    )
    fp8_backend: str = field(
        default="auto",
        metadata={
            "help": "FP8 backend to use ('auto', 'torchao', 'te', 'msamp'). 'auto' selects best available backend."
        },
    )
    fp8_enable_fsdp_float8_all_gather: bool = field(
        default=False,
        metadata={"help": "Enable FP8 optimizations for FSDP2 all-gather operations."},
    )
luopl's avatar
luopl committed
318
319
320
321


@dataclass
class ProcessorArguments:
    r"""Arguments pertaining to the image processor.

    Pixel bounds are expressed as total pixel counts (width * height).
    """

    image_max_pixels: int = field(
        default=768 * 768,
        metadata={"help": "The maximum number of pixels of image inputs."},
    )
    image_min_pixels: int = field(
        default=32 * 32,
        metadata={"help": "The minimum number of pixels of image inputs."},
    )
    image_do_pan_and_scan: bool = field(
        default=False,
        metadata={"help": "Use pan and scan to process image for gemma3."},
    )
    crop_to_patches: bool = field(
        default=False,
        metadata={"help": "Whether to crop the image to patches for internvl."},
    )
    video_max_pixels: int = field(
        default=256 * 256,
        metadata={"help": "The maximum number of pixels of video inputs."},
    )
    video_min_pixels: int = field(
        default=16 * 16,
        metadata={"help": "The minimum number of pixels of video inputs."},
    )
    video_fps: float = field(
        default=2.0,
        metadata={"help": "The frames to sample per second for video inputs."},
    )
    video_maxlen: int = field(
        default=128,
        metadata={"help": "The maximum number of sampled frames for video inputs."},
    )
    use_audio_in_video: bool = field(
        default=False,
        metadata={"help": "Whether or not to use audio in video inputs."},
    )
    audio_sampling_rate: int = field(
        default=16000,
        metadata={"help": "The sampling rate of audio inputs."},
    )

    def __post_init__(self):
        """Validate that each max/min pixel pair is consistently ordered.

        Raises:
            ValueError: if a maximum bound is below its matching minimum.
        """
        if self.image_max_pixels < self.image_min_pixels:
            raise ValueError("`image_max_pixels` cannot be smaller than `image_min_pixels`.")

        if self.video_max_pixels < self.video_min_pixels:
            raise ValueError("`video_max_pixels` cannot be smaller than `video_min_pixels`.")
luopl's avatar
luopl committed
371
372
373
374


@dataclass
class ExportArguments:
    r"""Arguments pertaining to the model export."""

    export_dir: str | None = field(
        default=None,
        metadata={"help": "Path to the directory to save the exported model."},
    )
    export_size: int = field(
        default=5,
        metadata={"help": "The file shard size (in GB) of the exported model."},
    )
    export_device: Literal["cpu", "auto"] = field(
        default="cpu",
        metadata={"help": "The device used in model export, use `auto` to accelerate exporting."},
    )
    # --- post-export quantization; requires a calibration dataset ---
    export_quantization_bit: int | None = field(
        default=None,
        metadata={"help": "The number of bits to quantize the exported model."},
    )
    export_quantization_dataset: str | None = field(
        default=None,
        metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."},
    )
    export_quantization_nsamples: int = field(
        default=128,
        metadata={"help": "The number of samples used for quantization."},
    )
    export_quantization_maxlen: int = field(
        default=1024,
        metadata={"help": "The maximum length of the model inputs used for quantization."},
    )
    export_legacy_format: bool = field(
        default=False,
        metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."},
    )
    export_hub_model_id: str | None = field(
        default=None,
        metadata={"help": "The name of the repository if push the model to the Hugging Face hub."},
    )

    def __post_init__(self):
        """Ensure a calibration dataset accompanies quantized export.

        Raises:
            ValueError: if `export_quantization_bit` is set without a dataset.
        """
        if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
            raise ValueError("Quantization dataset is necessary for exporting.")

luopl's avatar
luopl committed
418
419
420

@dataclass
class VllmArguments:
    r"""Arguments pertaining to the vLLM worker."""

    vllm_maxlen: int = field(
        default=4096,
        metadata={"help": "Maximum sequence (prompt + response) length of the vLLM engine."},
    )
    vllm_gpu_util: float = field(
        default=0.7,
        metadata={"help": "The fraction of GPU memory in (0,1) to be used for the vLLM engine."},
    )
    vllm_enforce_eager: bool = field(
        default=False,
        metadata={"help": "Whether or not to disable CUDA graph in the vLLM engine."},
    )
    vllm_max_lora_rank: int = field(
        default=32,
        metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
    )
    vllm_config: dict | str | None = field(
        default=None,
        metadata={"help": "Config to initialize the vllm engine. Please use JSON strings."},
    )

    def __post_init__(self):
        # A JSON-object string (starts with "{") is parsed into a dict, with
        # string values coerced by transformers' `_convert_str_dict`.
        if isinstance(self.vllm_config, str) and self.vllm_config.startswith("{"):
            self.vllm_config = _convert_str_dict(json.loads(self.vllm_config))

Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
448
449

@dataclass
class SGLangArguments:
    r"""Arguments pertaining to the SGLang worker."""

    sglang_maxlen: int = field(
        default=4096,
        metadata={"help": "Maximum sequence (prompt + response) length of the SGLang engine."},
    )
    sglang_mem_fraction: float = field(
        default=0.7,
        metadata={"help": "The memory fraction (0-1) to be used for the SGLang engine."},
    )
    sglang_tp_size: int = field(
        default=-1,
        metadata={"help": "Tensor parallel size for the SGLang engine."},
    )
    sglang_config: dict | str | None = field(
        default=None,
        metadata={"help": "Config to initialize the SGLang engine. Please use JSON strings."},
    )
    sglang_lora_backend: Literal["triton", "flashinfer"] = field(
        default="triton",
        metadata={
            "help": "The backend of running GEMM kernels for Lora modules. Recommend using the Triton LoRA backend for better performance and stability."
        },
    )

    def __post_init__(self):
        # A JSON-object string (starts with "{") is parsed into a dict, with
        # string values coerced by transformers' `_convert_str_dict`.
        if isinstance(self.sglang_config, str) and self.sglang_config.startswith("{"):
            self.sglang_config = _convert_str_dict(json.loads(self.sglang_config))


shihm's avatar
uodata  
shihm committed
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
@dataclass
class KTransformersArguments:
    r"""Arguments pertaining to the KT training."""

    use_kt: bool = field(
        default=False,
        metadata={"help": "Whether To Use KTransformers Optimizations For LoRA Training."},
    )
    kt_optimize_rule: str | None = field(
        default=None,
        metadata={
            "help": "Path To The KTransformers Optimize Rule; See https://github.com/kvcache-ai/ktransformers/."
        },
    )
    cpu_infer: int | None = field(
        default=32,
        metadata={"help": "Number Of CPU Cores Used For Computation."},
    )
    chunk_size: int | None = field(
        default=8192,
        metadata={"help": "Chunk Size Used For CPU Compute In KTransformers."},
    )
    mode: str | None = field(
        default="normal",
        metadata={"help": "Normal Or Long_Context For Llama Models."},
    )

    kt_maxlen: int = field(
        default=4096,
        metadata={"help": "Maximum Sequence (Prompt + Response) Length Of The KT Engine."},
    )
    kt_use_cuda_graph: bool = field(
        default=True,
        metadata={"help": "Whether To Use CUDA Graphs For The KT Engine."},
    )
    kt_mode: str = field(
        default="normal",
        metadata={"help": "Normal Or Long_Context Mode For The KT Engine."},
    )
    kt_force_think: bool = field(
        default=False,
        metadata={"help": "Force-Think Toggle For The KT Engine."},
    )


chenych's avatar
chenych committed
526
527
@dataclass
class ModelArguments(
    SGLangArguments,
    VllmArguments,
    KTransformersArguments,
    ExportArguments,
    ProcessorArguments,
    QuantizationArguments,
    BaseModelArguments,
):
    r"""Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.

    The class on the most right will be displayed first.
    """

    # The four fields below are derived from other arguments at setup time
    # (init=False), so they cannot be passed to the constructor directly.
    compute_dtype: torch.dtype | None = field(
        default=None,
        init=False,
        metadata={"help": "Torch data type for computing model outputs, derived from `fp/bf16`. Do not specify it."},
    )
    device_map: str | dict[str, Any] | None = field(
        default=None,
        init=False,
        metadata={"help": "Device map for model placement, derived from training stage. Do not specify it."},
    )
    model_max_length: int | None = field(
        default=None,
        init=False,
        metadata={"help": "The maximum input length for model, derived from `cutoff_len`. Do not specify it."},
    )
    block_diag_attn: bool = field(
        default=False,
        init=False,
        metadata={"help": "Whether use block diag attention or not, derived from `neat_packing`. Do not specify it."},
    )

    def __post_init__(self):
        # Run each parent's validation/normalization explicitly (only the
        # parents that define a __post_init__ are invoked).
        BaseModelArguments.__post_init__(self)
        ProcessorArguments.__post_init__(self)
        ExportArguments.__post_init__(self)
        VllmArguments.__post_init__(self)
        SGLangArguments.__post_init__(self)

    @classmethod
    def copyfrom(cls, source: "Self", **kwargs) -> "Self":
        """Clone `source`, optionally overriding init arguments via `kwargs`.

        Fields declared with init=False cannot be passed to the constructor,
        so they are copied onto the new instance afterwards with setattr.
        """
        init_args, lazy_args = {}, {}
        for attr in fields(source):
            if attr.init:
                init_args[attr.name] = getattr(source, attr.name)
            else:
                lazy_args[attr.name] = getattr(source, attr.name)

        init_args.update(kwargs)
        result = cls(**init_args)
        for name, value in lazy_args.items():
            setattr(result, name, value)

        return result

    def to_dict(self) -> dict[str, Any]:
        """Return the arguments as a dict, masking fields whose name ends with "token"
        (e.g. hf_hub_token) so secrets do not leak into logs."""
        args = asdict(self)
        args = {k: f"<{k.upper()}>" if k.endswith("token") else v for k, v in args.items()}
        return args