cal_lr.py 3.64 KB
Newer Older
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
1
# coding=utf-8
chenych's avatar
chenych committed
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright 2024 imoneoi and the LlamaFactory team.
#
# This code is inspired by the imoneoi's OpenChat library.
# https://github.com/imoneoi/openchat/blob/3.6.0/ochat/training_deepspeed/train.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
18
19

import math
chenych's avatar
chenych committed
20
from typing import Literal
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
21
22
23
24
25
26
27

import fire
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq

chenych's avatar
chenych committed
28
29
30
31
from llamafactory.data import get_dataset
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.hparams import get_train_args
from llamafactory.model import load_tokenizer
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
32
33
34
35
36
37
38
39
40


# Reference point taken from LLaMA's hyper-parameters: peak learning rate BASE_LR at a
# batch of BASE_BS tokens. For 30B-70B models use 1.5e-4 as BASE_LR instead.
BASE_LR: float = 3.0e-4
BASE_BS: int = 4 * 10**6  # tokens per batch, from the llama paper


def calculate_lr(
    model_name_or_path: str,
    batch_size: int,  # total batch size, namely (batch size * gradient accumulation * world size)
    stage: Literal["pt", "sft"] = "sft",
    dataset: str = "alpaca_en",
    dataset_dir: str = "data",
    template: str = "default",
    cutoff_len: int = 1024,  # i.e. maximum input length during training
    is_mistral: bool = False,  # mistral model uses a smaller learning rate
    packing: bool = False,
):
    r"""
    Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.

    The training set is tokenized and collated exactly as for training. The fraction of label
    tokens that are not IGNORE_INDEX gives the effective number of tokens per batch. BASE_LR is
    then scaled by sqrt(effective tokens per batch / BASE_BS).

    Usage: python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en --cutoff_len 1024 --batch_size 16

    Args:
        model_name_or_path: model path forwarded to ``get_train_args``.
        batch_size: total batch size, namely batch size * gradient accumulation * world size.
        stage: ``"pt"`` (language-model collator) or ``"sft"`` (seq2seq collator).
        dataset, dataset_dir, template, cutoff_len, packing: forwarded to ``get_train_args``.
        is_mistral: if True, the resulting learning rate is divided by 6.

    Raises:
        NotImplementedError: if ``stage`` is neither ``"pt"`` nor ``"sft"``.
        ValueError: if the training set produces no label tokens at all.
    """
    # Fail fast: reject an unsupported stage before loading the tokenizer and the dataset.
    if stage not in ("pt", "sft"):
        raise NotImplementedError("Stage does not supported: {}.".format(stage))

    model_args, data_args, training_args, _, _ = get_train_args(
        dict(
            stage=stage,
            model_name_or_path=model_name_or_path,
            dataset=dataset,
            dataset_dir=dataset_dir,
            template=template,
            cutoff_len=cutoff_len,
            packing=packing,
            output_dir="dummy_dir",
            overwrite_cache=True,
            do_train=True,
        )
    )
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    trainset = get_dataset(model_args, data_args, training_args, stage, **tokenizer_module)["train_dataset"]
    if stage == "pt":
        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    else:  # stage == "sft" (validated above)
        data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)

    # Batches are only reduced on the CPU and never moved to an accelerator, so pinned memory
    # would only add copies (and a warning on CPU-only hosts).
    dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator)
    valid_tokens, total_tokens = 0, 0
    for batch in tqdm(dataloader):
        labels = batch["labels"]
        valid_tokens += torch.sum(labels != IGNORE_INDEX).item()  # tokens that contribute to the loss
        total_tokens += torch.numel(labels)  # includes padding up to the batch's longest sequence

    if total_tokens == 0:
        raise ValueError("The training set yields no label tokens, cannot estimate the learning rate.")

    batch_max_len = cutoff_len * batch_size  # max tokens in a batch
    valid_ratio = valid_tokens / total_tokens
    batch_valid_len = batch_max_len * valid_ratio
    lr = BASE_LR * math.sqrt(batch_valid_len / BASE_BS)  # lr ~ sqrt(batch_size)
    lr = lr / 6.0 if is_mistral else lr
    print(
        "Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective batch size {:.2f}".format(
            lr, valid_ratio * 100, batch_valid_len
        )
    )


if __name__ == "__main__":
    # Expose calculate_lr as a command-line tool; python-fire maps CLI flags to its parameters.
    fire.Fire(component=calculate_lr)