# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Actor config
"""

import os
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple


@dataclass
class ModelConfig:
    """Configuration for loading the actor's model and tokenizer.

    ``tokenizer_path`` defaults to ``model_path`` when left unset (resolved
    in :meth:`post_init`).
    """

    model_path: Optional[str] = None  # local dir or hub identifier
    tokenizer_path: Optional[str] = None  # falls back to model_path
    override_config: Dict[str, Any] = field(default_factory=dict)  # hf config overrides
    enable_gradient_checkpointing: bool = True
    trust_remote_code: bool = True
    freeze_vision_tower: bool = False  # skip training the vision encoder, if any

    def post_init(self):
        """Resolve derived fields; intended to be called once after construction.

        Falls back to ``model_path`` when no tokenizer path was given, and
        normalizes any locally existing paths to absolute paths so later
        working-directory changes cannot break model loading. Non-local
        identifiers (e.g. hub names) are left untouched.
        """
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path

        if self.model_path is not None and os.path.exists(self.model_path):
            self.model_path = os.path.abspath(self.model_path)

        if self.tokenizer_path is not None and os.path.exists(self.tokenizer_path):
            self.tokenizer_path = os.path.abspath(self.tokenizer_path)


@dataclass
class OptimConfig:
    """Optimizer and learning-rate-schedule settings for the actor."""

    lr: float = 1e-6
    betas: Tuple[float, float] = (0.9, 0.999)  # Adam moment coefficients
    weight_decay: float = 1e-2
    strategy: str = "adamw"  # optimizer implementation to build
    lr_warmup_ratio: float = 0.0  # fraction of training_steps used for warmup
    min_lr_ratio: Optional[float] = None  # floor for decayed lr as a ratio of lr
    warmup_style: str = "constant"  # lr schedule after warmup
    # --- auto keys: filled in by the trainer at runtime, not user-settable ---
    training_steps: int = field(default=-1, init=False)


@dataclass
class FSDPConfig:
    """Fully Sharded Data Parallel wrapping and mixed-precision settings."""

    enable_full_shard: bool = True  # FULL_SHARD vs. a less aggressive strategy
    enable_cpu_offload: bool = False
    enable_rank0_init: bool = False  # materialize weights on rank 0 only, then shard
    use_orig_params: bool = False
    torch_dtype: Optional[str] = None  # load dtype override, e.g. "bf16"
    fsdp_size: int = -1  # shard group size; -1 means use the full world size
    # Mixed-precision dtypes for parameters, gradient reduction, and buffers.
    mp_param_dtype: str = "bf16"
    mp_reduce_dtype: str = "fp32"
    mp_buffer_dtype: str = "fp32"


@dataclass
class OffloadConfig:
    """Toggles for offloading model state to CPU between training phases."""

    offload_params: bool = False
    offload_optimizer: bool = False


@dataclass
class ActorConfig:
    """Top-level configuration for the PPO actor worker."""

    strategy: str = "fsdp"  # distributed training backend
    global_batch_size: int = 256  # total samples per optimizer update
    # Per-device micro-batch sizes: gradient updates vs. experience (logprob) passes.
    micro_batch_size_per_device_for_update: int = 4
    micro_batch_size_per_device_for_experience: int = 16
    max_grad_norm: float = 1.0  # gradient clipping threshold
    # PPO clipping: asymmetric ratio bounds plus dual-clip constant.
    clip_ratio_low: float = 0.2
    clip_ratio_high: float = 0.3
    clip_ratio_dual: float = 3.0
    ppo_epochs: int = 1  # optimization epochs per rollout batch
    padding_free: bool = False  # pack sequences without padding tokens
    ulysses_sequence_parallel_size: int = 1
    use_torch_compile: bool = True
    model: ModelConfig = field(default_factory=ModelConfig)
    optim: OptimConfig = field(default_factory=OptimConfig)
    fsdp: FSDPConfig = field(default_factory=FSDPConfig)
    offload: OffloadConfig = field(default_factory=OffloadConfig)
    # --- auto keys: filled in by the trainer at runtime, not user-settable ---
    global_batch_size_per_device: int = field(default=-1, init=False)
    disable_kl: bool = field(default=False, init=False)
    use_kl_loss: bool = field(default=False, init=False)
    kl_penalty: str = field(default="kl", init=False)
    kl_coef: float = field(default=0.0, init=False)


@dataclass
class RefConfig:
    """Configuration for the frozen reference-policy worker.

    Several fields are "auto keys" mirrored from :class:`ActorConfig` at
    runtime so the reference model matches the actor's execution settings.
    """

    strategy: str = "fsdp"  # distributed backend
    fsdp: FSDPConfig = field(default_factory=FSDPConfig)
    offload: OffloadConfig = field(default_factory=OffloadConfig)
    # --- auto keys: copied from the actor config at runtime, not user-settable ---
    micro_batch_size_per_device_for_experience: int = field(default=-1, init=False)
    padding_free: bool = field(default=False, init=False)
    ulysses_sequence_parallel_size: int = field(default=1, init=False)
    use_torch_compile: bool = field(default=True, init=False)