# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import time
import sys
import IPython

import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
    AutoConfig,
)
from transformers import DataCollatorForLanguageModeling
from datasets import load_dataset
from accelerate import Accelerator
from accelerate.utils.dataclasses import FP8RecipeKwargs


class HyperParameters:
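    """Default hyperparameters for the Llama finetuning utilities in this file."""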
    def __init__(self):
        self.mixed_precision = "bf16"

        # Set to Meta Llama 2 by default.
        self.model_name = "meta-llama/Llama-2-7b-hf"

        self.dataset_name = "timdettmers/openassistant-guanaco"
        self.dataset_text_field = "text"
        self.learning_rate = 1.41e-5
        self.batch_size = 8
        self.max_seq_length = 256
        self.gradient_accumulation_steps = 1
        self.num_warmup_steps = 5
        self.num_training_steps = 10

        # Hugging Face Hub access token, needed to download the gated Llama weights.
        self.hf_access_token = ""

        # This is either provided by the user or it will be set when the
        # model weights are downloaded.
        self.weights_cache_dir = ""


hyperparams = HyperParameters()
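# A driver notebook would typically override these defaults before using the helpers
# below, for example (illustrative values only):
#
#   hyperparams.model_name = "meta-llama/Llama-2-7b-hf"
#   hyperparams.mixed_precision = "fp8"
#   hyperparams.hf_access_token = "<your Hugging Face access token>"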


def get_dataloaders(accelerator: Accelerator, hyperparams):
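    """
    Build the training dataloader.

    Loads the train split of ``hyperparams.dataset_name``, tokenizes it with the tokenizer
    of ``hyperparams.model_name`` (truncated to ``hyperparams.max_seq_length``), and returns
    a ``DataLoader`` that collates causal-LM batches padded to a multiple of 16 tokens.
    """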
    dataset = load_dataset(hyperparams.dataset_name, split="train")
    tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name)
    if getattr(tokenizer, "pad_token", None) is None:
        tokenizer.pad_token = tokenizer.eos_token

    def tokenize(element):
        outputs = tokenizer(
            element["text"],
            truncation=True,
            padding=False,
            max_length=hyperparams.max_seq_length,
            return_overflowing_tokens=False,
            return_length=False,
        )
        return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"]}

    with accelerator.main_process_first():
        dataset = dataset.map(tokenize, batched=True, remove_columns=dataset.column_names)

    # Pad sequence lengths to a multiple of 16 for both FP8 and BF16 precision
    pad_to_multiple_of = 16
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
        pad_to_multiple_of=pad_to_multiple_of,
    )

    dataloader_params = {
        "batch_size": hyperparams.batch_size,
        "collate_fn": data_collator,
        "drop_last": True,
    }
    train_dataloader = DataLoader(dataset, **dataloader_params)
    return train_dataloader


def ensure_model_is_downloaded(hyperparams):
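    """
    Log in to the Hugging Face Hub and download the model weights if needed.

    Only Meta Llama 2 7B and Meta Llama 3 8B are supported. The resolved snapshot path
    is stored back into ``hyperparams.weights_cache_dir``.
    """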
    assert hyperparams.model_name in [
        "meta-llama/Meta-Llama-3-8B",
        "meta-llama/Llama-2-7b-hf",
    ], "Only Meta Llama 2 7B and Meta Llama 3 8B models are supported!"

    # Login using Huggingface Hub API
    from huggingface_hub import login

    try:
        login(hyperparams.hf_access_token)
    except Exception as e:
        if "Invalid token passed!" in str(e):
            print(
                "Please pass a valid HF Access Token! More info at"
                " https://huggingface.co/docs/hub/en/security-tokens."
            )
        else:
            print(f"Exception is {e}")

    # Download the model if it doesn't exist
    from huggingface_hub import snapshot_download

    supplied_cache_dir = (
        hyperparams.weights_cache_dir if hyperparams.weights_cache_dir != "" else None
    )
    hyperparams.weights_cache_dir = snapshot_download(
        repo_id=hyperparams.model_name, cache_dir=supplied_cache_dir
    )

    print(f"Model cache directory : {hyperparams.weights_cache_dir}")


def init_baseline_model(hyperparams):
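    """
    Create the baseline Hugging Face Llama model in BF16 with Flash Attention 2,
    move it to the GPU, and disable the KV cache.
    """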
    # Download and cache the weights
    ensure_model_is_downloaded(hyperparams)

    # Init the model
    config = AutoConfig.from_pretrained(hyperparams.weights_cache_dir)
    # Use Flash Attention 2 so the baseline matches TELlamaModel for an apples-to-apples comparison
    config._attn_implementation = "flash_attention_2"
    model = AutoModelForCausalLM.from_pretrained(
        hyperparams.weights_cache_dir,
        config=config,
        torch_dtype=torch.bfloat16,
    )
    model = model.cuda()
    # Disabling the KV cache is needed when using TELlamaForCausalLM; do the same here for a 1:1 comparison
    model.config.use_cache = False

    return model


def init_te_llama_model(hyperparams):
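    """
    Create the Transformer Engine accelerated ``TELlamaForCausalLM`` model in BF16 with
    Flash Attention 2, move it to the GPU, and disable the KV cache.
    """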
    # Download and cache the weights
    ensure_model_is_downloaded(hyperparams)

    # Init the model
    from te_llama import TELlamaForCausalLM

    config = AutoConfig.from_pretrained(hyperparams.weights_cache_dir)
    config._attn_implementation = "flash_attention_2"
    model = TELlamaForCausalLM.from_pretrained_local(
        hyperparams.weights_cache_dir,
        config=config,
        torch_dtype=torch.bfloat16,
    )
    model = model.cuda()
    # Disabling the KV cache is needed when using TELlamaForCausalLM
    model.config.use_cache = False

    return model


def wrap_with_accelerator(model, hyperparams):
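    """
    Wrap the model, optimizer, dataloader, and LR scheduler with Hugging Face Accelerate.

    When ``hyperparams.mixed_precision == "fp8"``, an FP8 recipe handler backed by
    Transformer Engine is passed to the ``Accelerator``.
    """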
    # Create FP8 kwarg handler if required
    fp8_kwarg_handler = (
        [FP8RecipeKwargs(backend="te")] if hyperparams.mixed_precision == "fp8" else None
    )

    # Init HF accelerator that's used for training
    accelerator = Accelerator(
        log_with="wandb",
        gradient_accumulation_steps=hyperparams.gradient_accumulation_steps,
        mixed_precision=hyperparams.mixed_precision,
        kwargs_handlers=fp8_kwarg_handler,
    )
    # accelerator.print(f'State: {accelerator.state}')
    train_dataloader = get_dataloaders(accelerator, hyperparams)

    # Wrap model, optimizer/scheduler, dataloaders in accelerate
    optimizer = AdamW(params=model.parameters(), lr=hyperparams.learning_rate, fused=True)
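    # Note: the scheduler's 100 learning-rate warmup steps are independent of the
    # timing warmup iterations (hyperparams.num_warmup_steps) used in finetune_model().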
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=100,
        num_training_steps=hyperparams.num_training_steps,
    )
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    return accelerator, model, optimizer, train_dataloader, lr_scheduler


def finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler):
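    """
    Run ``hyperparams.num_warmup_steps`` warmup iterations followed by
    ``hyperparams.num_training_steps`` timed iterations and report the average
    time per training step in milliseconds.
    """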
    model.train()
    total_loss = 0
    optimizer.zero_grad()
    train_dataloader = enumerate(train_dataloader)

    # Warmup iterations (excluded from the timing measurement)
    for _ in range(hyperparams.num_warmup_steps):
        step, batch = next(train_dataloader)
        with accelerator.accumulate(model):
            outputs = model(**batch)
            loss = outputs.loss
            total_loss += loss.detach().float()
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

    # Get the timers ready
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()

    start.record()
    # Timed training iterations
    for _ in range(hyperparams.num_training_steps):
        step, batch = next(train_dataloader)
        with accelerator.accumulate(model):
            outputs = model(**batch)
            loss = outputs.loss
            total_loss += loss.detach().float()
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
    end.record()
    torch.cuda.synchronize()
    accelerator.end_training()

    print(
        f"{hyperparams.num_training_steps} finetuning steps complete!\nAverage time taken per step:"
        f" {(start.elapsed_time(end)/hyperparams.num_training_steps):.0f} milliseconds"
    )

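# A sketch of how these helpers fit together (the accompanying notebook may differ):
#
#   model = init_te_llama_model(hyperparams)  # or init_baseline_model(hyperparams)
#   accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(
#       model, hyperparams
#   )
#   finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)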

def restart_jupyter_notebook():
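    """
    Try to restart the Jupyter kernel so that device memory is freed between runs,
    warning the user if GPU memory is still allocated afterwards.
    """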
    # Try restarting the Jupyter kernel
    IPython.Application.instance().kernel.do_shutdown(True)

    # Check whether the device memory has been flushed
    if torch.cuda.memory_allocated() != 0:
        import warnings

        warnings.warn("The device memory hasn't been flushed, trying with a second method!")

        # Try restarting the Jupyter kernel another way
        # Restart the kernel
        from IPython.display import HTML, display

        display(HTML("<script>Jupyter.notebook.kernel.restart()</script>"))

        if torch.cuda.memory_allocated() != 0:
            print(
                "The device memory hasn't been flushed, try manually restarting the Jupyter kernel!"
            )

    # Suppress the warnings
    if not sys.warnoptions:
        import warnings

        warnings.simplefilter("ignore")
        torch.set_warn_always(False)