import logging
import os
import time
from argparse import Namespace
from pathlib import Path

import datasets
import torch
from datasets import load_dataset
from torch.utils.data import IterableDataset
from torch.utils.data.dataloader import DataLoader

import transformers
from accelerate import Accelerator, DistributedType
from arguments import TrainingArguments
from huggingface_hub import Repository
from transformers import AdamW, AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, get_scheduler, set_seed


class ConstantLengthDataset(IterableDataset):
    """
    Iterable dataset that returns constant length chunks of tokens from stream of text files.
        Args:
            tokenizer (Tokenizer): The processor used for proccessing the data.
            dataset (dataset.Dataset): Dataset with text files.
            infinite (bool): If True the iterator is reset after dataset reaches end else stops.
            seq_length (int): Length of token sequences to return.
            num_of_sequences: Number of token sequences to keep in buffer.
            chars_per_token: Number of characters per token used to estimate number of tokens in text buffer.
30
            tokenized: If true we use a pretokenized dataset.
31
32
33
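
    Example (illustrative sketch; the "gpt2" checkpoint and "lvwerra/codeparrot-clean-valid" dataset
    are assumptions for demonstration, not requirements):

        >>> from datasets import load_dataset
        >>> from transformers import AutoTokenizer
        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
        >>> data = load_dataset("lvwerra/codeparrot-clean-valid", split="train", streaming=True)
        >>> constant_length_ds = ConstantLengthDataset(tokenizer, data, seq_length=1024)
        >>> next(iter(constant_length_ds)).shape
        torch.Size([1024])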
    """

    def __init__(
        self,
        tokenizer,
        dataset,
        infinite=False,
        seq_length=1024,
        num_of_sequences=1024,
        chars_per_token=3.6,
        tokenized=False,
    ):
        self.tokenizer = tokenizer
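        # The BOS token is used as a separator between concatenated examples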
        self.concat_token_id = tokenizer.bos_token_id
        self.dataset = dataset
        self.seq_length = seq_length
        self.epoch = 0
        self.infinite = infinite
        self.current_size = 0
        self.tokenized = tokenized

        if self.tokenized:
            self.max_buffer_size = seq_length * num_of_sequences
            self.content_field = "input_ids"
        else:
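            # For raw text, estimate the buffer size in characters from the average chars-per-token ratio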
            self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
            self.content_field = "content"

    def __iter__(self):
        iterator = iter(self.dataset)
        more_examples = True
        while more_examples:
            buffer, buffer_len = [], 0
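            # Fill the buffer with raw examples until the estimated token budget is reached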
            while True:
                if buffer_len >= self.max_buffer_size:
                    break
                try:
                    buffer.append(next(iterator)[self.content_field])
                    buffer_len += len(buffer[-1])
                except StopIteration:
                    if self.infinite:
                        iterator = iter(self.dataset)
                        self.epoch += 1
                        logger.info(f"Dataset epoch: {self.epoch}")
                    else:
                        more_examples = False
                        break
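            # Tokenize the whole buffer at once, or pass it through if it is already tokenized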
            if self.tokenized:
                tokenized_inputs = buffer
            else:
                tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
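            # Concatenate all sequences separated by the BOS token and slice them into fixed-length chunks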
            all_token_ids = []
            for tokenized_input in tokenized_inputs:
                all_token_ids.extend(tokenized_input + [self.concat_token_id])
            for i in range(0, len(all_token_ids), self.seq_length):
                input_ids = all_token_ids[i : i + self.seq_length]
                if len(input_ids) == self.seq_length:
                    self.current_size += 1
                    yield torch.tensor(input_ids)


def setup_logging(args):
    project_name = args.model_ckpt.split("/")[-1]
    logger = logging.getLogger(__name__)
    log_dir = Path(args.save_dir) / "log/"
    log_dir.mkdir(exist_ok=True)
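    # One debug log file per process so multi-process runs do not interleave writes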
    filename = f"debug_{accelerator.process_index}.log"
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
        handlers=[logging.FileHandler(log_dir / filename), logging.StreamHandler()],
    )
    if accelerator.is_main_process:  # we only want to set up logging once
        accelerator.init_trackers(project_name, vars(args))
        run_name = accelerator.trackers[0].run.name
        logger.setLevel(logging.INFO)
        datasets.utils.logging.set_verbosity_info()
        transformers.utils.logging.set_verbosity_info()
    else:
        run_name = ""
        logger.setLevel(logging.ERROR)
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
    return logger, run_name


def create_dataloaders(args):
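    # Load both splits in streaming mode so the full dataset is never materialized on disk or in memory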
    ds_kwargs = {"streaming": True}
    train_data = load_dataset(args.dataset_name_train, split="train", **ds_kwargs)
    train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed)
    valid_data = load_dataset(args.dataset_name_valid, split="train", **ds_kwargs)
    train_dataset = ConstantLengthDataset(
        tokenizer, train_data, infinite=True, seq_length=args.seq_length, tokenized=args.tokenized
    )
    valid_dataset = ConstantLengthDataset(
        tokenizer, valid_data, infinite=False, seq_length=args.seq_length, tokenized=args.tokenized
    )
    train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size)
    eval_dataloader = DataLoader(valid_dataset, batch_size=args.valid_batch_size)
    return train_dataloader, eval_dataloader


def get_grouped_params(model, args, no_decay=["bias", "LayerNorm.weight"]):
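    # Exclude biases and LayerNorm weights from weight decay, as is common for transformer training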
    params_with_wd, params_without_wd = [], []
    for n, p in model.named_parameters():
        if any(nd in n for nd in no_decay):
            params_without_wd.append(p)
        else:
            params_with_wd.append(p)
    return [
        {"params": params_with_wd, "weight_decay": args.weight_decay},
        {"params": params_without_wd, "weight_decay": 0.0},
    ]


def log_metrics(step, metrics):
    logger.info(f"Step {step}: {metrics}")
    if accelerator.is_main_process:
        accelerator.log(metrics, step)


def compute_tflops(elapsed_time, accelerator, args):
    # TFLOPs formula (from Equation 3 in Section 5.1 of https://arxiv.org/pdf/2104.04473.pdf).
    config_model = accelerator.unwrap_model(model).config
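    # With gradient checkpointing, activations are recomputed during backward, hence factor 4 instead of 3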
    checkpoint_factor = 4 if args.gradient_checkpointing else 3
    batch_size = args.train_batch_size * accelerator.state.num_processes * args.gradient_accumulation_steps
    factor = 24 * checkpoint_factor * batch_size * args.seq_length * config_model.n_layer * (config_model.n_embd**2)
    flops_per_iteration = factor * (
        1.0
        + (args.seq_length / (6.0 * config_model.n_embd))
        + (tokenizer.vocab_size / (16.0 * config_model.n_layer * config_model.n_embd))
    )
    tflops = flops_per_iteration / (elapsed_time * accelerator.state.num_processes * (10**12))
    return tflops


def evaluate(args):
    model.eval()
    losses = []
    for step, batch in enumerate(eval_dataloader):
        with torch.no_grad():
            outputs = model(batch, labels=batch)
        loss = outputs.loss.repeat(args.valid_batch_size)
        losses.append(accelerator.gather(loss))
        if args.max_eval_steps > 0 and step >= args.max_eval_steps:
            break
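    # Trim to the number of sequences the dataset actually produced so repeated losses don't skew the mean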
    losses = torch.cat(losses)
    loss = losses[: eval_dataloader.dataset.current_size].mean()
    try:
        perplexity = torch.exp(loss)
    except OverflowError:
        perplexity = float("inf")
    return loss.item(), perplexity.item()


# Accelerator
accelerator = Accelerator(log_with=["wandb", "tensorboard"])
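# Collect the accelerator state as plain strings so it can be merged into the logged args below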
acc_state = {str(k): str(v) for k, v in accelerator.state.__dict__.items()}

# Settings
parser = HfArgumentParser(TrainingArguments)
args = parser.parse_args()

args = Namespace(**vars(args), **acc_state)
samples_per_step = accelerator.state.num_processes * args.train_batch_size
set_seed(args.seed)

# Clone model repository
if accelerator.is_main_process:
    hf_repo = Repository(args.save_dir, clone_from=args.model_ckpt)

# Logging
logger, run_name = setup_logging(args)
logger.info(accelerator.state)

# Checkout new branch on repo
if accelerator.is_main_process:
    hf_repo.git_checkout(run_name, create_branch_ok=True)

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(args.save_dir)
if args.gradient_checkpointing:
    model.gradient_checkpointing_enable()
tokenizer = AutoTokenizer.from_pretrained(args.save_dir)

# Load dataset and dataloader
train_dataloader, eval_dataloader = create_dataloaders(args)

# Prepare the optimizer and learning rate scheduler
optimizer = AdamW(get_grouped_params(model, args), lr=args.learning_rate)
lr_scheduler = get_scheduler(
    name=args.lr_scheduler_type,
    optimizer=optimizer,
    num_warmup_steps=args.num_warmup_steps,
    num_training_steps=args.max_train_steps,
)
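# Register the LR scheduler so accelerator.save_state()/load_state() include it in checkpoints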
accelerator.register_for_checkpointing(lr_scheduler)


def get_lr():
    return optimizer.param_groups[0]["lr"]


# Prepare everything with our `accelerator`.
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)

# load in the weights and states from a previous save
if args.resume_from_checkpoint:
    if args.resume_from_checkpoint is not None and args.resume_from_checkpoint != "":
        accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
        accelerator.load_state(args.resume_from_checkpoint)
        path = os.path.basename(args.resume_from_checkpoint)
    else:
        # Otherwise, fall back to the most recent "step_*" checkpoint in the save directory
        dirs = [f.name for f in os.scandir(args.save_dir) if f.is_dir() and "step" in str(f)]
        dirs.sort(key=lambda d: os.path.getctime(os.path.join(args.save_dir, d)))
        path = dirs[-1]  # the most recently created checkpoint is last
    # Extract the step of the checkpoint to continue from there
    training_difference = os.path.splitext(path)[0]
    resume_step = int(training_difference.replace("step_", ""))

# Train model
model.train()
completed_steps = 0
t_start = time.time()
for step, batch in enumerate(train_dataloader, start=1):
    if args.resume_from_checkpoint and step < resume_step:
        continue  # we need to skip steps until we reach the resumed step
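    # use_cache=False: key/value caching is useless during training and conflicts with gradient checkpointing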
    loss = model(batch, labels=batch, use_cache=False).loss
    log_metrics(
        step, {"lr": get_lr(), "samples": step * samples_per_step, "steps": completed_steps, "loss/train": loss.item()}
    )
    loss = loss / args.gradient_accumulation_steps
    if step % args.gradient_accumulation_steps != 0:
        # Prevent backward from doing gradient all_reduce in every step
        if accelerator.distributed_type == DistributedType.MULTI_GPU:
            with model.no_sync():
                accelerator.backward(loss)
        else:
            accelerator.backward(loss)
    else:
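        # On an accumulation boundary: backward with gradient sync, then clip, step, and reset gradients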
        accelerator.backward(loss)
        accelerator.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        completed_steps += 1
        elapsed_time = time.time() - t_start
        tflops = compute_tflops(elapsed_time, accelerator, args)
        log_metrics(step, {"steps": completed_steps, "tflops": tflops, "time_per_iteration": elapsed_time})
        t_start = time.time()
    if step % args.save_checkpoint_steps == 0:
        logger.info("Evaluating and saving model checkpoint")
        eval_loss, perplexity = evaluate(args)
        log_metrics(step, {"loss/eval": eval_loss, "perplexity": perplexity})
        accelerator.wait_for_everyone()
        save_dir = os.path.join(args.save_dir, f"step_{step}")
        accelerator.save_state(save_dir)
        if accelerator.is_main_process:
            hf_repo.push_to_hub(commit_message=f"step {step}")
        model.train()
    if completed_steps >= args.max_train_steps:
        break

# Evaluate and save the last checkpoint
logger.info("Evaluating and saving model after training")
eval_loss, perplexity = evaluate(args)
log_metrics(step, {"loss/eval": eval_loss, "perplexity": perplexity})
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.save_dir, save_function=accelerator.save)
save_dir = os.path.join(args.save_dir, f"step_{step}")
accelerator.save_state(save_dir)
if accelerator.is_main_process:
    hf_repo.push_to_hub(commit_message="final model")