utils.py 6.84 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import time
import sys
import IPython

import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader

13
14
15
16
17
18
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
    AutoConfig,
)
19
20
21
22
23
from transformers import DataCollatorForLanguageModeling
from datasets import load_dataset
from accelerate import Accelerator
from accelerate.utils.dataclasses import FP8RecipeKwargs

24

25
26
27
class HyperParameters:
    """Settings read by the data, model and training helpers in this file.

    ``model_name`` is intentionally not assigned here. Set it on the instance
    (it must point at the model weights) before calling any helper that reads
    ``hyperparams.model_name``.
    """

    def __init__(self):
        # Passed to the Accelerator as `mixed_precision`; the value "fp8"
        # additionally enables the FP8 recipe handler in wrap_with_accelerator.
        self.mixed_precision = "bf16"
        # self.model_name = "" # <== Add model weight location here

        # Dataset and batching settings.
        self.dataset_name = "timdettmers/openassistant-guanaco"
        self.dataset_text_field = "text"
        self.learning_rate = 1.41e-5
        self.batch_size = 8
        self.max_seq_length = 256
        self.gradient_accumulation_steps = 1

        # finetune_model runs `num_warmup_steps` untimed iterations first and
        # then times `num_training_steps` iterations.
        self.num_warmup_steps = 5
        self.num_training_steps = 10

38
39
40

# Module-level default settings instance. HyperParameters.__init__ does not
# set `model_name`, so assign `hyperparams.model_name` before using it.
hyperparams = HyperParameters()

41
42

def get_dataloaders(accelerator: Accelerator, hyperparams):
    """Build the training dataloader over the tokenized "train" split.

    Args:
        accelerator: Only its ``main_process_first()`` context manager is used,
            to wrap the tokenization ``map`` call.
        hyperparams: Settings object. Reads ``dataset_name``, ``model_name``,
            ``max_seq_length`` and ``batch_size``; ``dataset_text_field`` is
            used as the text column name when present, otherwise ``"text"``.

    Returns:
        A ``torch.utils.data.DataLoader`` over the tokenized dataset that
        collates with ``DataCollatorForLanguageModeling(mlm=False)`` and drops
        the last incomplete batch.
    """
    dataset = load_dataset(hyperparams.dataset_name, split="train")
    tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name)
    # The collator needs a pad token; fall back to EOS when none is defined.
    if getattr(tokenizer, "pad_token", None) is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Honour the configured text column instead of hard-coding "text".
    text_field = getattr(hyperparams, "dataset_text_field", "text")

    def tokenize(element):
        # Truncate to max_seq_length without padding; padding is applied per
        # batch by the collator below.
        outputs = tokenizer(
            element[text_field],
            truncation=True,
            padding=False,
            max_length=hyperparams.max_seq_length,
            return_overflowing_tokens=False,
            return_length=False,
        )
        return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"]}

    # NOTE(review): presumably lets the main process tokenize first so other
    # processes can reuse its cached result - confirm against accelerate docs.
    with accelerator.main_process_first():
        dataset = dataset.map(tokenize, batched=True, remove_columns=dataset.column_names)

    # Simply pad to the multiple of 16 for both FP8 and BF16 precision
    pad_to_multiple_of = 16
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
        pad_to_multiple_of=pad_to_multiple_of,
    )

    dataloader_params = {
        "batch_size": hyperparams.batch_size,
        "collate_fn": data_collator,
        "drop_last": True,
    }
    train_dataloader = DataLoader(dataset, **dataloader_params)
    return train_dataloader

78

79
80
81
82
83
84
85
86
87
88
def init_baseline_model(hyperparams):
    """Load the Hugging Face causal-LM baseline in bfloat16 and move it to CUDA.

    Args:
        hyperparams: Settings object; ``model_name`` is passed to both
            ``AutoConfig.from_pretrained`` and
            ``AutoModelForCausalLM.from_pretrained``.

    Returns:
        The loaded model after ``.cuda()``, with ``config.use_cache`` set to
        ``False``.
    """
    # Init the model
    config = AutoConfig.from_pretrained(hyperparams.model_name)
    # make sure to use flash_attention to do iso comparison with TELlamaModel
    config._attn_implementation = "flash_attention_2"
    model = AutoModelForCausalLM.from_pretrained(
        hyperparams.model_name,
        config=config,
        torch_dtype=torch.bfloat16,
    )
    model = model.cuda()
    # Needed for the cases when using TELlamaForCausalLM. So adding here for 1:1 comparison
    model.config.use_cache = False

    return model

95

96
97
98
def init_te_llama_model(hyperparams):
    """Load the Transformer Engine Llama model in bfloat16 and move it to CUDA.

    Args:
        hyperparams: Settings object; ``model_name`` is passed to both
            ``AutoConfig.from_pretrained`` and
            ``TELlamaForCausalLM.from_pretrained_local``.

    Returns:
        The loaded model after ``.cuda()``, with ``config.use_cache`` set to
        ``False``.
    """
    # Imported here so that `te_llama` is only required when this function
    # is actually called.
    from te_llama import TELlamaForCausalLM

    # Init the model
    config = AutoConfig.from_pretrained(hyperparams.model_name)
    config._attn_implementation = "flash_attention_2"
    model = TELlamaForCausalLM.from_pretrained_local(
        hyperparams.model_name,
        config=config,
        torch_dtype=torch.bfloat16,
    )
    model = model.cuda()
    # Needed for the cases when using TELlamaForCausalLM
    model.config.use_cache = False

    return model

113

114
115
def wrap_with_accelerator(model, hyperparams):
    """Create the Accelerator, dataloader, optimizer and LR scheduler for ``model``.

    Args:
        model: Model whose parameters are optimized; it is passed through
            ``accelerator.prepare``.
        hyperparams: Settings object. Reads ``mixed_precision``,
            ``gradient_accumulation_steps``, ``learning_rate`` and
            ``num_training_steps`` here, plus whatever ``get_dataloaders``
            reads.

    Returns:
        Tuple ``(accelerator, model, optimizer, train_dataloader,
        lr_scheduler)`` where the last four are the objects returned by
        ``accelerator.prepare``.
    """
    # Create FP8 kwarg handler if required
    fp8_kwarg_handler = (
        [FP8RecipeKwargs(backend="te")] if hyperparams.mixed_precision == "fp8" else None
    )

    # Init HF accelerator that's used for training
    accelerator = Accelerator(
        log_with="wandb",
        gradient_accumulation_steps=hyperparams.gradient_accumulation_steps,
        mixed_precision=hyperparams.mixed_precision,
        kwargs_handlers=fp8_kwarg_handler,
    )

    train_dataloader = get_dataloaders(accelerator, hyperparams)

    # Wrap model, optimizer/scheduler, dataloaders in accelerate
    optimizer = AdamW(params=model.parameters(), lr=hyperparams.learning_rate, fused=True)
    # NOTE: the scheduler warmup length is the fixed value 100; it is not
    # `hyperparams.num_warmup_steps`, which finetune_model uses as the number
    # of untimed iterations.
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=100,
        num_training_steps=hyperparams.num_training_steps,
    )
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    return accelerator, model, optimizer, train_dataloader, lr_scheduler

143

144
145
146
147
148
149
def _run_training_steps(model, accelerator, batches, optimizer, lr_scheduler, num_steps):
    """Run ``num_steps`` forward/backward/optimizer iterations, one batch each."""
    for _ in range(num_steps):
        batch = next(batches)
        with accelerator.accumulate(model):
            outputs = model(**batch)
            accelerator.backward(outputs.loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()


def finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler):
    """Run warmup iterations, then time the fine-tuning iterations and report.

    Args:
        model: Model called as ``model(**batch)``; its output must expose
            ``.loss``.
        hyperparams: Settings object; reads ``num_warmup_steps`` (untimed
            iterations) and ``num_training_steps`` (timed iterations).
        accelerator: Provides ``accumulate``, ``backward`` and
            ``end_training``.
        train_dataloader: Iterable of batches; one batch is consumed per
            iteration, so it must yield at least
            ``num_warmup_steps + num_training_steps`` batches.
        optimizer: Optimizer stepped and zeroed every iteration.
        lr_scheduler: Scheduler stepped every iteration.

    Returns:
        None. Prints the average time per timed step in milliseconds.
    """
    model.train()
    optimizer.zero_grad()
    batches = iter(train_dataloader)

    # Warmup iters (not included in the timing below)
    _run_training_steps(
        model, accelerator, batches, optimizer, lr_scheduler, hyperparams.num_warmup_steps
    )

    # Get the timers ready
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()

    start.record()
    # Training iters
    _run_training_steps(
        model, accelerator, batches, optimizer, lr_scheduler, hyperparams.num_training_steps
    )
    torch.cuda.synchronize()
    end.record()
    # Make sure the `end` event has completed before elapsed_time() is read.
    end.synchronize()

    accelerator.end_training()

    print(
        f"{hyperparams.num_training_steps} finetuning steps complete!\nAverage time taken per step:"
        f" {(start.elapsed_time(end)/hyperparams.num_training_steps):.0f} milliseconds"
    )

188
189
190
191
192
193
194
195

def restart_jupyter_notebook():
    """Try to restart the Jupyter kernel so that GPU memory is released.

    First asks the running IPython application's kernel to shut down with
    ``do_shutdown(True)``. If ``torch.cuda.memory_allocated()`` is still
    non-zero afterwards, it warns and displays a small script that calls
    ``Jupyter.notebook.kernel.restart()`` in the notebook front end; if memory
    is still allocated after that, it prints a request to restart manually.
    Finally, unless warning options were given on the command line, it
    silences Python and torch warnings.
    """
    import warnings

    # Try restarting the Jupyter kernel
    IPython.Application.instance().kernel.do_shutdown(True)

    # Check whether the device memory has been flushed
    if torch.cuda.memory_allocated() != 0:
        warnings.warn("The device memory hasn't been flushed, trying with a second method!")

        # Try restarting the Jupyter kernel another way
        from IPython.display import HTML, display

        # The HTML object must be displayed for the front end to run the
        # script; constructing it alone has no effect inside a function.
        display(HTML("<script>Jupyter.notebook.kernel.restart()</script>"))

        if torch.cuda.memory_allocated() != 0:
            print(
                "The device memory hasn't been flushed, try manually restarting the Jupyter kernel!"
            )

    # Suppress the warnings
    if not sys.warnoptions:
        warnings.simplefilter("ignore")
        torch.set_warn_always(False)