# arguments.py
from dataclasses import dataclass, field
from typing import Optional

import transformers

from ovis.util.utils import rank0_print


@dataclass
class ModelArguments:
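    """Model-side options: paths to the base LLM and ViT checkpoints, the visual
    vocabulary size, and ViT patching/attention settings."""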
    llm_name_or_path: Optional[str] = field(default=None)
    vit_name_or_path: Optional[str] = field(default=None)
    visual_vocab_size: int = field(default=65536)
    conversation_formatter_class: Optional[str] = field(default=None)
    attn_implementation: Optional[str] = field(default=None)
    accepts_loss_kwargs: bool = field(default=True)
    vit_hidden_stride: int = field(default=2)
    vit_window_size: int = field(default=112)
    vit_temporal_patch_size: int = field(default=1)
    vit_fullatt_block_indexes: Optional[str] = field(default=None)
    vit_preserve_original_pe: Optional[bool] = field(default=True)
    vit_use_rope: Optional[bool] = field(default=True)


@dataclass
class TrainingArguments(transformers.TrainingArguments):
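    """Extends transformers.TrainingArguments with data selection, training stage,
    frame/pixel limits, and multimodal sequence-length options."""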
    data_info_version: Optional[str] = field(default=None)
    data_name: Optional[str] = field(default=None)  # a|b|c
    data_type: Optional[str] = field(default=None)  # caption, conversation
    ovis_pretrained_path: Optional[str] = field(default=None)
    stage: Optional[int] = field(default=None)
    train_modules: Optional[str] = field(default=None)
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    save_safetensors: bool = field(default=True)
    monitor_step: int = field(default=100)
    model_init_seed: int = field(default=0)
    multimodal_max_length: int = field(default=4096)
    text_max_length: Optional[int] = field(default=4096)
    min_frames: int = field(default=8)
    max_frames: int = field(default=8)
    overall_ratio: Optional[str] = field(default=None)
    mix_data_name: Optional[str] = field(default=None)
    mix_ratio: Optional[float] = field(default=None)
    min_lr_rate: Optional[float] = field(default=None)
    single_image_min_pixels: int = field(default=448*448)
    single_image_max_pixels: int = field(default=1792*1344)
    multiple_image_min_pixels: int = field(default=448*448)
    multiple_image_max_pixels: int = field(default=448*448)
    video_min_pixels: int = field(default=448*448)
    video_max_pixels: int = field(default=448*448)

    def __post_init__(self):
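        # Pass the minimum LR ratio through lr_scheduler_kwargs so schedulers
        # that support it (e.g. cosine_with_min_lr) can honor it.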
        if self.min_lr_rate is not None:
            self.lr_scheduler_kwargs = {
                "min_lr_rate": self.min_lr_rate
            }
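        # Select PyTorch's non-reentrant activation checkpointing implementation.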
        if self.gradient_checkpointing:
            self.gradient_checkpointing_kwargs = {"use_reentrant": False}
        # Earlier training stages skip safetensors serialization.
        if self.stage is not None and self.stage < 3:
            self.save_safetensors = False
        super().__post_init__()
        assert self.model_init_seed != self.seed, "`model_init_seed` should be different from `seed`"
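

# A minimal usage sketch (assumption: these dataclasses are consumed by a training
# entry point, not part of this file, via transformers.HfArgumentParser):
#
#   parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
#   model_args, training_args = parser.parse_args_into_dataclasses()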