configuration.py
import argparse
import json
import logging
import os
import time
from dataclasses import dataclass

logger = logging.getLogger("AllamoConfiguration")

@dataclass
class AllamoConfiguration:
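    """Configuration for Allamo training and inference.

    Default values defined below can be overridden by a JSON file passed via
    --config and by individual command-line arguments; explicitly passed CLI
    arguments take precedence over values from the file.
    """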

    load_configuration: bool = True
    init_from: str = 'scratch'
    checkpoint_path: str = None
    seed: int = 1337
    data_dir: str = 'data'
    out_dir: str = 'out'
    log_checkpoint_md5_on_load: bool = False
    log_checkpoint_md5_on_epoch: bool = False
    ignore_last_checkpoint_backup: bool = False
    checkpoint_interval: int = 1000
    save_optimizer_checkpoint: bool = True
    optimizer_checkpoint_interval: int = None
    save_best_checkpoint: bool = True
    save_checkpoint_on_dataset_reload: bool = False
    distributed_checkpoint: bool = False
    config_override_check_interval: int = None
    config_override_path: str = None
    eval_interval: int = 1000
    eval_iters: int = 200
    eval_only: bool = False
    log_interval: int = 1
    vocab_size: int = 31980
    tiktoken_tokenizer_name: str = None
    hf_tokenizer_path: str = None
    wandb_log: bool = False
    wandb_project: str = 'allamo'
    wandb_run_name: str = 'allamo-run-' + str(time.time())
    gradient_checkpointing: bool = False
    gradient_accumulation_steps: int = 8
    batch_size: int = 64
    block_size: int = 1024
    sliding_window: int = None
    dataset: str = None
    dataset_train_files: str = None
    dataset_validation_files: str = None
    dataset_train_file_prefix: str = 'train.'
    dataset_validation_file_prefix: str = 'val.'
    dataset_train_processed_files_count: int = 0
    dataset_seq_train: bool = True
    dataset_seq_train_start: int = None
    dataset_buffer: bool = False
    batch_size_initial: int = 2
    batch_size_max_iter: int = 2000
    batch_size_schedule: bool = False
    batch_size_max: int = 64
    grad_accum_initial: int = 2
    grad_accum_max_iter: int = 2000
    grad_accum_schedule: bool = False
    grad_accum_max: int = 8
    rope_freq_base: int = 10000
    rope_freq_scale: float = 1.0
    n_layer: int = 12
    n_head: int = 12
    head_size: int = 64
    num_kv_heads: int = None
    n_embd: int = 768
    intermediate_size: int = None
    dropout: float = 0.0 
    bias: bool = False 
    multiple_of: int = 256
    norm_eps: float = 1e-5
    learning_rate: float = 6e-4
    num_train_epochs: int = None
    max_iters: int = 600000
    weight_decay: float = 1e-1
    beta1: float = 0.9
    beta2: float = 0.95
    grad_clip: float = 1.0 
    decay_lr: bool = True
    warmup_iters: int = 2000
    lr_decay_iters: int = 600000
    lr_decay_reset_iters: int = 60000
    min_lr: float = 6e-5 
    backend: str = 'nccl' 
    device: str = 'cuda' 
    dtype: str = 'float16'
    compile: bool = False
    compile_mode: str = 'default'
    mfu_flops_peak: float = -1.0
    ignore_index: int = -100
    pad_token_id: int = -1
    weighted_loss: bool = False
    weighted_loss_method: str = 'allamo'
    adaptive_learning_rate: bool = False
    fsdp_sharding_strategy: str = 'FULL_SHARD'
    epoch_completion_hook_program: str = None
    regular_checkpoint_hook_program: str = None
    dpo_chosen_beta: float = 0.5
    dpo_rejected_beta: float = 0.1
    dpo_penalty_lambda: float = 50.0
    reference_checkpoint_name: str = 'ref_ckpt'
    training_type: str = 'pre'
    attention_implementation: str = 'sdpa'
    tensor_parallel_degree: int = 1
    
    # inference params
    prompt: str = "\n" 
    num_samples: int = 1 
    max_new_tokens: int = 50 
    temperature: float = 0.8 
    top_k: int = 100
    
    def __post_init__(self):
        if self.load_configuration:
            self.load_values()
    
    def load_values(self):
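        """Parse command-line arguments and apply them to this configuration.

        When --config points to a JSON file, its values are applied first;
        any explicitly passed CLI arguments then override them. Boolean flags
        use argparse's built-in bool conversion, so any non-empty value
        (including "False") is interpreted as True.
        """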
        parser = argparse.ArgumentParser(description='Allamo allows you to train and evaluate LLaMA-based models.')
        parser.add_argument('--config', help='Path to a json configuration file')
        parser.add_argument('--init_from', type=str, choices=['scratch', 'resume', 'resume_last'], help='Start from scratch or resume from best or last checkpoint')
        parser.add_argument('--checkpoint_path', type=str, help='Custom input checkpoint path')
        parser.add_argument('--seed', type=int, help='The desired seed for generating random numbers')
        parser.add_argument('--data_dir', type=str, help='Directory where datasets exist')
        parser.add_argument('--out_dir', type=str, help='Output directory for checkpoints')
        parser.add_argument('--log_checkpoint_md5_on_load', type=bool, help='When loading a checkpoint, log its MD5 checksum')
        parser.add_argument('--log_checkpoint_md5_on_epoch', type=bool, help='When saving a checkpoint at the end of an epoch, log its MD5 checksum')
        parser.add_argument('--ignore_last_checkpoint_backup', type=bool, help='Do not preserve a backup copy of the last checkpoint; overwrite it instead')
        parser.add_argument('--checkpoint_interval', type=int, help='Number of iterations between checkpoints where the state of the model is saved')
        parser.add_argument('--save_optimizer_checkpoint', type=bool, help='Enable saving optimizer checkpoint')
        parser.add_argument('--optimizer_checkpoint_interval', type=int, help='Number of iterations between checkpoints where the state of the optimizer is saved. The same as checkpoint_interval, if not specified')
        parser.add_argument('--save_best_checkpoint', type=bool, help='Enable saving the best checkpoint when evaluating model')
        parser.add_argument('--save_checkpoint_on_dataset_reload', type=bool, help='Enable model checkpoint saving on dataset reload')
        parser.add_argument('--distributed_checkpoint', type=bool, help='Use PyTorch Distributed Checkpoint (DCP)')
        parser.add_argument('--config_override_check_interval', type=int, help='Number of iterations for checking override configuration. Feature disabled if not specified.')
        parser.add_argument('--config_override_path', type=str, help='Specifies the location of the configuration override file')
        parser.add_argument('--eval_interval', type=int, help='Number of training iterations between model evaluations')
        parser.add_argument('--eval_iters', type=int, help='Number of iterations to run during each evaluation')
        parser.add_argument('--eval_only', type=bool, help='Exit right after the first evaluation. Indicates no training.')
        parser.add_argument('--log_interval', type=int, help='Number of iterations between training loss log entries')
        parser.add_argument('--vocab_size', type=int, help='Vocabulary size. Might be overwritten by checkpoint')
        parser.add_argument('--tiktoken_tokenizer_name', type=str, help='Tiktoken tokenizer name. Might be overwritten by checkpoint')
        parser.add_argument('--hf_tokenizer_path', type=str, help='HuggingFace tokenizer path. Might be overwritten by checkpoint')
        parser.add_argument('--wandb_log', type=bool, help='Enable logging to wandb')
        parser.add_argument('--wandb_project', type=str, help='Wandb project name')
        parser.add_argument('--wandb_run_name', type=str, help='Wandb run name')
        parser.add_argument('--gradient_checkpointing', type=bool, help='Enable gradient checkpointing')
        parser.add_argument('--gradient_accumulation_steps', type=int, help='Number of gradient accumulation steps; helps simulate larger batch sizes')
        parser.add_argument('--batch_size', type=int, help='Batch size')
        parser.add_argument('--sliding_window', type=int, help='Sliding window attention window size')
        parser.add_argument('--block_size', type=int, help='The maximum sequence length that this model might ever be used with')
        parser.add_argument('--dataset', type=str, help='The name of the dataset directory within the data_dir')
        parser.add_argument('--dataset_train_files', type=str, help='Comma-separated list of training dataset files to use')
        parser.add_argument('--dataset_validation_files', type=str, help='Comma-separated list of validation dataset files to use')
        parser.add_argument('--dataset_train_file_prefix', type=str, help='Custom prefix for training dataset files')
        parser.add_argument('--dataset_validation_file_prefix', type=str, help='Custom prefix for validation dataset files')
        parser.add_argument('--dataset_train_processed_files_count', type=int, help='The number of files already processed in the training dataset')
        parser.add_argument('--dataset_seq_train', type=bool, help='Iterate dataset sequentially')
        parser.add_argument('--dataset_seq_train_start', type=int, help='Position in tokens to start with')
        parser.add_argument('--dataset_buffer', type=bool, help='Enable buffer for dataset samples')
        parser.add_argument('--batch_size_initial', type=int, help='Initial batch_size value')
        parser.add_argument('--batch_size_max_iter', type=int, help='Number of iterations to reach maximum batch_size value')
        parser.add_argument('--batch_size_schedule', type=bool, help='Enable linear batch_size scheduler')
        parser.add_argument('--grad_accum_initial', type=int, help='Initial gradient_accumulation_steps value')
        parser.add_argument('--grad_accum_max_iter', type=int, help='Number of iterations to reach maximum gradient_accumulation_steps value')
        parser.add_argument('--grad_accum_schedule', type=bool, help='Enable linear gradient_accumulation_steps scheduler')
        parser.add_argument('--rope_freq_base', type=int, help='RoPE base frequency')
        parser.add_argument('--rope_freq_scale', type=float, help='RoPE frequency scaling factor')
        parser.add_argument('--n_layer', type=int, help='Number of layers')
        parser.add_argument('--n_head', type=int, help='Number of heads')
        parser.add_argument('--head_size', type=int, help='Often calculated as n_embd/n_head')
        parser.add_argument('--num_kv_heads', type=int, help='Number of key-value heads')
        parser.add_argument('--n_embd', type=int, help='Number of model dimensions')
        parser.add_argument('--intermediate_size', type=int, help='Dimension of the MLP representations')
        parser.add_argument('--dropout', type=float, help='Enable dropouts globally. Disabled when 0')
        parser.add_argument('--bias', type=bool, help='Enable bias globally. Helpful in finetuning process')
        parser.add_argument('--multiple_of', type=int, help='Make SwiGLU hidden layer size multiple of large power of 2. Used only when intermediate_size is not specified')
        parser.add_argument('--norm_eps', type=float, help='RMSNorm normalizing function param')
        parser.add_argument('--learning_rate', type=float, help='Learning rate to start with')
        parser.add_argument('--num_train_epochs', type=int, help='Total number of training epochs to perform')
        parser.add_argument('--max_iters', type=int, help='Total number of training iterations')
        parser.add_argument('--weight_decay', type=float, help='Weight decay coefficient for the AdamW optimizer')
        parser.add_argument('--beta1', type=float, help='Adamw optimizer Beta1 param')
        parser.add_argument('--beta2', type=float, help='Adamw optimizer Beta2 param')
        parser.add_argument('--grad_clip', type=float, help='Clip gradients at this value. Disabled when 0.')
        parser.add_argument('--decay_lr', type=bool, help='Whether to decay the learning rate')
        parser.add_argument('--warmup_iters', type=int, help='Learning rate is calculated linearly for warmup_iters steps')
        parser.add_argument('--lr_decay_iters', type=int, help='Learning rate decay iterations. When exceeded, the min_lr is used')
        parser.add_argument('--lr_decay_reset_iters', type=int, help='Number of iterations for the learning rate decay restart')
        parser.add_argument('--min_lr', type=float, help='Minimum learning rate')
        parser.add_argument('--backend', type=str, help='Specifies one of three built-in backends: nccl, gloo, mpi')
        parser.add_argument('--device', type=str, help='"cpu", "cuda", "cuda:0", "cuda:1" etc., or try "mps" on macbooks')
        parser.add_argument('--dtype', type=str, choices=['float32', 'bfloat16', 'bfloat16-true', 'float16'], help='Type of tensor to be used in the model')
        parser.add_argument('--compile', type=bool, help='Whether to use PyTorch 2.0 to compile the model to be faster')
        parser.add_argument('--compile_mode', type=str, choices=['default', 'reduce-overhead', 'max-autotune'], help='Specifies what the PyTorch compiler should be optimizing while compiling')
        parser.add_argument('--mfu_flops_peak', type=float, help="Peak hardware performance in FLOPS, used to estimate MFU (model FLOPs utilization). The default value of -1 disables MFU estimation")
        parser.add_argument('--ignore_index', type=int, help="Specifies a target value that is ignored and does not contribute to the input gradient")
        parser.add_argument('--pad_token_id', type=int, help="Enables padding and specifies the token id used for padding in sequences")
        parser.add_argument('--weighted_loss', type=bool, help='Whether to use weighted loss if available')
        parser.add_argument('--weighted_loss_method', type=str, choices=['allamo', 'openchat'], help='How weighted loss is calculated')
        parser.add_argument('--adaptive_learning_rate', type=bool, help='Whether to use adaptive learning rate')
        parser.add_argument('--fsdp_sharding_strategy', type=str, choices=['FULL_SHARD', 'HYBRID_SHARD', '_HYBRID_SHARD_ZERO2', 'SHARD_GRAD_OP', 'NO_SHARD'], help='FSDP sharding strategy')
        parser.add_argument('--epoch_completion_hook_program', type=str, help='Path to the program/script to be executed after the epoch ends and the checkpoint is saved')
        parser.add_argument('--regular_checkpoint_hook_program', type=str, help='Path to the program/script to be executed after a regular checkpoint is saved')
        parser.add_argument('--dpo_chosen_beta', type=float, help='Temperature parameter for the chosen part of the DPO loss, typically something in the range of 0.1 to 0.5')
        parser.add_argument('--dpo_rejected_beta', type=float, help='Temperature parameter for the rejected part of the DPO loss, typically something in the range of 0.1 to 0.5')
        parser.add_argument('--dpo_penalty_lambda', type=float, help='Temperature parameter for penalty-positive in the DPO loss, typically in the range of 1 to 100')
        parser.add_argument('--reference_checkpoint_name', type=str, help='Checkpoint name for the reference model')
        parser.add_argument('--training_type', type=str, choices=['pre', 'sft', 'dpo'], help='Specifies the type of training: pre (pre-training), sft (supervised fine-tuning), or dpo (direct preference optimization)')
        parser.add_argument('--attention_implementation', type=str, choices=['sdpa', 'flash_attention_2', 'eager'], help='Specifies attention implementation')
        parser.add_argument('--tensor_parallel_degree', type=int, help='Specifies the degree of tensor parallelism. Activates TP when it is greater than 1')
        parser.add_argument('--prompt', type=str, help='Prompt for generating text. Can also specify a file, use as: "FILE:prompt.txt"')
        parser.add_argument('--num_samples', type=int, help='Number of samples to generate')
        parser.add_argument('--max_new_tokens', type=int, help='Number of tokens to generate in each sample')
        parser.add_argument('--temperature', type=float, help='Temperature value for text generation')
        parser.add_argument('--top_k', type=int, help='Top k most likely tokens to be retained during text generation')

        args = parser.parse_args()
        
        if args.config:
            with open(args.config) as f:
                config = json.load(f)
            self.override_values(config)

        for arg_name, arg_value in vars(args).items():
            if arg_value is not None and hasattr(self, arg_name):
                setattr(self, arg_name, arg_value)

    def override_values(self, config_dict):
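        """Apply matching keys from config_dict to this configuration.

        Returns a dict describing every changed property with its previous
        and new value.
        """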
        modified = {}
        for k, v in config_dict.items():
            if hasattr(self, k) and getattr(self, k) != v:
                modified[k] = {"prev": getattr(self, k), "curr": v}
                setattr(self, k, v)
        return modified
    
    def should_override_config(self, iter_num):
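        """Return True when live config overriding is enabled and iter_num
        falls on a config_override_check_interval boundary."""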
        return self.config_override_check_interval is not None and \
            self.config_override_path is not None and \
            self.config_override_check_interval > 0 and \
            iter_num % self.config_override_check_interval == 0
            
    def override_config_properties(self):
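        """Reload the JSON file at config_override_path and apply any changed
        properties; load errors are logged as warnings and otherwise ignored."""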
        if os.path.exists(self.config_override_path):
            try:
                with open(self.config_override_path, "r", encoding="utf-8") as f:
                    config = json.load(f)
                modified = self.override_values(config)
                if modified:
                    logger.info(f"The following config properties were overridden: {modified}")
            except Exception as err:
                logger.warning(f"Unable to load override config. Error: {err}")