# coding=utf-8
# Copyright 2024 imoneoi and the LlamaFactory team.
#
# This code is inspired by the imoneoi's OpenChat library.
# https://github.com/imoneoi/openchat/blob/3.6.0/ochat/training_deepspeed/train.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Literal

import fire
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq

from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.hparams import get_train_args
from llamafactory.model import load_tokenizer


# Reference learning rate that calculate_lr scales from.
# Use 1.5e-4 instead for 30B-70B models.
BASE_LR = 3e-4

# Reference batch size (from the LLaMA paper); calculate_lr compares it against
# the number of valid tokens per batch.
BASE_BS = 4_000_000


def calculate_lr(
    model_name_or_path: str,
    batch_size: int,  # total batch size, namely (batch size * gradient accumulation * world size)
    stage: Literal["pt", "sft"] = "sft",
    dataset: str = "alpaca_en_demo",
    dataset_dir: str = "data",
    template: str = "default",
    cutoff_len: int = 1024,  # i.e. maximum input length during training
    is_mistral_or_gemma: bool = False,  # mistral and gemma models opt for a smaller learning rate,
    packing: bool = False,
) -> None:
    r"""
    Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.

    The training set is iterated once in batches of ``batch_size`` to measure the
    fraction of label positions that are not ``IGNORE_INDEX`` (the "valid ratio").
    The effective batch size is ``cutoff_len * batch_size * valid_ratio`` and the
    learning rate is ``BASE_LR * sqrt(effective_batch_size / BASE_BS)``. The result
    is printed to stdout; nothing is returned.

    Usage:
    python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en_demo --cutoff_len 1024 --batch_size 16

    Args:
        model_name_or_path: Model identifier or path, forwarded to ``get_train_args``.
        batch_size: Total batch size, namely (batch size * gradient accumulation * world size).
        stage: Training stage; ``"pt"`` and ``"sft"`` are supported.
        dataset: Dataset name, forwarded to ``get_train_args``.
        dataset_dir: Dataset directory, forwarded to ``get_train_args``.
        template: Template name, forwarded to ``get_train_args``.
        cutoff_len: Maximum input length during training.
        is_mistral_or_gemma: When true, the computed learning rate is divided by 6.
        packing: Forwarded to ``get_train_args``.

    Raises:
        NotImplementedError: If ``stage`` is neither ``"pt"`` nor ``"sft"``.
        ValueError: If the training set yields no label positions, so the valid
            ratio cannot be computed.
    """
    model_args, data_args, training_args, _, _ = get_train_args(
        dict(
            stage=stage,
            model_name_or_path=model_name_or_path,
            dataset=dataset,
            dataset_dir=dataset_dir,
            template=template,
            cutoff_len=cutoff_len,
            packing=packing,
            output_dir="dummy_dir",
            overwrite_cache=True,
            do_train=True,
        )
    )
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    # Use a separate name so the `template` string parameter is not shadowed.
    template_obj = get_template_and_fix_tokenizer(tokenizer, data_args)
    trainset = get_dataset(template_obj, model_args, data_args, training_args, stage, **tokenizer_module)[
        "train_dataset"
    ]
    if stage == "pt":
        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    elif stage == "sft":
        data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
    else:
        raise NotImplementedError(f"Stage does not supported: {stage}.")

    dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
    valid_tokens, total_tokens = 0, 0
    for batch in tqdm(dataloader):
        labels = batch["labels"]
        valid_tokens += torch.sum(labels != IGNORE_INDEX).item()
        total_tokens += torch.numel(labels)

    # Without any label positions the ratio below would be a bare ZeroDivisionError.
    if total_tokens == 0:
        raise ValueError("Cannot compute the valid token ratio: the training set produced no label tokens.")

    batch_max_len = cutoff_len * batch_size  # max tokens in a batch
    valid_ratio = valid_tokens / total_tokens
    batch_valid_len = batch_max_len * valid_ratio
    lr = BASE_LR * math.sqrt(batch_valid_len / BASE_BS)  # lr ~ sqrt(batch_size)
    if is_mistral_or_gemma:
        lr = lr / 6.0

    print(
        f"Optimal learning rate is {lr:.2e} for valid ratio% {valid_ratio * 100:.2f} "
        f"and effective batch size {batch_valid_len:.2f}"
    )


if __name__ == "__main__":
    # Expose calculate_lr on the command line via Fire; see the "Usage" line in
    # its docstring for an example invocation.
    fire.Fire(calculate_lr)