"csrc/ktransformers_ext/CMakeLists.txt" did not exist on "19fd24f46176ca490e9ea4aaa34be44d64f3510c"
# Copyright 2024 imoneoi and the LlamaFactory team.
#
# This code is inspired by the imoneoi's OpenChat library.
# https://github.com/imoneoi/openchat/blob/3.6.0/ochat/training_deepspeed/train.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Literal

import fire
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq

from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.hparams import get_train_args
from llamafactory.model import load_tokenizer


# Reference learning rate that goes with BASE_BS; calculate_lr rescales it by
# sqrt(effective tokens per batch / BASE_BS).
BASE_LR = 3e-4  # 1.5e-4 for 30B-70B models
# Reference batch size, measured in tokens per batch, that BASE_LR corresponds to.
BASE_BS = 4_000_000  # from llama paper


def calculate_lr(
    model_name_or_path: str,
    batch_size: int,  # total batch size, namely (batch size * gradient accumulation * world size)
    stage: Literal["pt", "sft"] = "sft",
    dataset: str = "alpaca_en_demo",
    dataset_dir: str = "data",
    template: str = "default",
    cutoff_len: int = 1024,  # i.e. maximum input length during training
    is_mistral_or_gemma: bool = False,  # mistral and gemma models opt for a smaller learning rate,
    packing: bool = False,
) -> None:
    r"""
    Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.

    The training set is tokenized and collated exactly once to measure the ratio of
    supervised label positions (labels != IGNORE_INDEX) to all label positions. The
    effective number of tokens per batch is ``cutoff_len * batch_size * valid_ratio`` and
    the learning rate is ``BASE_LR * sqrt(effective_tokens / BASE_BS)``, divided by 6 when
    ``is_mistral_or_gemma`` is set. The result is printed; nothing is returned.

    Usage:
    python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en_demo --cutoff_len 1024 --batch_size 16

    Raises:
        NotImplementedError: if ``stage`` is neither "pt" nor "sft".
        ValueError: if the dataset yields no label tokens, so no valid ratio can be computed.
    """
    # Fail fast: reject an unsupported stage before the expensive argument parsing,
    # tokenizer loading and dataset preprocessing below.
    if stage not in ("pt", "sft"):
        raise NotImplementedError(f"Stage is not supported: {stage}.")

    model_args, data_args, training_args, _, _ = get_train_args(
        dict(
            stage=stage,
            model_name_or_path=model_name_or_path,
            dataset=dataset,
            dataset_dir=dataset_dir,
            template=template,
            cutoff_len=cutoff_len,
            packing=packing,
            output_dir="dummy_dir",  # required by the argument parser; nothing is trained here
            overwrite_cache=True,
            do_train=True,
        )
    )
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    # Use a distinct name so the `template` argument (a template name) is not shadowed.
    template_obj = get_template_and_fix_tokenizer(tokenizer, data_args)
    trainset = get_dataset(template_obj, model_args, data_args, training_args, stage, **tokenizer_module)[
        "train_dataset"
    ]

    if stage == "pt":
        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    else:  # stage == "sft", guaranteed by the check at the top
        data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)

    dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)

    # Count supervised label positions versus all (padded) label positions.
    valid_tokens, total_tokens = 0, 0
    for batch in tqdm(dataloader):
        labels = batch["labels"]
        valid_tokens += torch.sum(labels != IGNORE_INDEX).item()
        total_tokens += torch.numel(labels)

    if total_tokens == 0:
        # An empty dataset would otherwise surface as an opaque ZeroDivisionError.
        raise ValueError("The dataset produced no tokens, cannot estimate the valid token ratio.")

    batch_max_len = cutoff_len * batch_size  # max tokens in a batch
    valid_ratio = valid_tokens / total_tokens
    batch_valid_len = batch_max_len * valid_ratio
    lr = BASE_LR * math.sqrt(batch_valid_len / BASE_BS)  # lr ~ sqrt(batch_size)
    if is_mistral_or_gemma:
        lr /= 6.0

    print(
        "Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective batch size {:.2f}".format(
            lr, valid_ratio * 100, batch_valid_len
        )
    )


if __name__ == "__main__":
    # Expose calculate_lr's parameters as command-line flags (see the usage line in its docstring).
    fire.Fire(calculate_lr)