Unverified Commit d9184131 authored by Loubna Ben Allal, committed by GitHub

New features for CodeParrot training script (#16851)



* add tflops logging and fix grad accumulation

* add accelerate tracking and checkpointing

* scale loss of last batch correctly

* fix typo

* compress loss computation
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* add resume from checkpoint argument

* add load_state accelerate from checkpoint, register lr scheduler and add tflops function

* reformat code

* reformat code

* add condition on path for resume checkpoint

* combine if conditions
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* add source for tflops formula
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
parent eef2422e
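
Taken together, the tracking, checkpointing, and resume items in the commit message rest on a handful of `accelerate` calls. The sketch below is a minimal, self-contained illustration of that pattern; the project name, model, and folder names are placeholders rather than values from this PR, and it assumes `wandb` and `tensorboard` are installed and `wandb` is logged in:

```python
# Minimal sketch of the Accelerate tracking/checkpointing pattern used by this PR.
# Project name, model, and folder names are placeholders, not taken from the PR.
import os

import torch
from accelerate import Accelerator

# Log through the Accelerator instead of calling wandb/tensorboard directly.
accelerator = Accelerator(log_with=["wandb", "tensorboard"])
accelerator.init_trackers("codeparrot-demo", config={"lr": 1e-3})

model = torch.nn.Linear(8, 2)  # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

model, optimizer = accelerator.prepare(model, optimizer)
# The scheduler is not passed to prepare(), so register it explicitly
# to have its state included in save_state()/load_state().
accelerator.register_for_checkpointing(lr_scheduler)

accelerator.log({"loss/train": 0.5}, step=1)

# Save model, optimizer, registered objects, and RNG states to a step folder ...
save_dir = os.path.join("./save_dir", "step_1")
accelerator.save_state(save_dir)
# ... and restore them later to resume training from that folder.
accelerator.load_state(save_dir)

accelerator.end_training()
```

The diff below applies the same calls to the CodeParrot model, optimizer, and learning-rate scheduler, and reads the run name back from `accelerator.trackers[0].run.name`.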
......@@ -82,7 +82,7 @@ Now that the dataset, tokenizer, and model are ready we can start training the m
 First you need to configure `accelerate` and login to Weights & Biases:
 ```bash
-acclerate config
+accelerate config
 wandb login
 ```
......
......@@ -49,6 +49,10 @@ class TrainingArguments:
         default=1024,
         metadata={"help": "Interval to save checkpoints. Measured as number of forward passes not training steps."},
     )
+    resume_from_checkpoint: Optional[str] = field(
+        default=None,
+        metadata={"help": "States path if the training should continue from a checkpoint folder."},
+    )


 @dataclass
......
 import logging
+import os
+import time
 from argparse import Namespace
 from pathlib import Path
......@@ -7,11 +9,9 @@ import torch
 from datasets import load_dataset
 from torch.utils.data import IterableDataset
 from torch.utils.data.dataloader import DataLoader
-from torch.utils.tensorboard import SummaryWriter

 import transformers
-import wandb
-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedType
 from arguments import TrainingArguments
 from huggingface_hub import Repository
 from transformers import AdamW, AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, get_scheduler, set_seed
......@@ -39,6 +39,7 @@ class ConstantLengthDataset(IterableDataset):
         self.input_characters = seq_length * chars_per_token * num_of_sequences
         self.epoch = 0
         self.infinite = infinite
+        self.current_size = 0

     def __iter__(self):
         iterator = iter(self.dataset)
......@@ -66,6 +67,7 @@ class ConstantLengthDataset(IterableDataset):
             for i in range(0, len(all_token_ids), self.seq_length):
                 input_ids = all_token_ids[i : i + self.seq_length]
                 if len(input_ids) == self.seq_length:
+                    self.current_size += 1
                     yield torch.tensor(input_ids)
......@@ -82,20 +84,17 @@ def setup_logging(args):
         handlers=[logging.FileHandler(log_dir / filename), logging.StreamHandler()],
     )
     if accelerator.is_main_process:  # we only want to setup logging once
-        wandb.init(project=project_name, config=args)
-        run_name = wandb.run.name
-        tb_writer = SummaryWriter()
-        tb_writer.add_hparams(vars(args), {"0": 0})
+        accelerator.init_trackers(project_name, vars(args))
+        run_name = accelerator.trackers[0].run.name
         logger.setLevel(logging.INFO)
         datasets.utils.logging.set_verbosity_info()
         transformers.utils.logging.set_verbosity_info()
     else:
-        tb_writer = None
         run_name = ""
         logger.setLevel(logging.ERROR)
         datasets.utils.logging.set_verbosity_error()
         transformers.utils.logging.set_verbosity_error()
-    return logger, tb_writer, run_name
+    return logger, run_name


 def create_dataloaders(args):
......@@ -126,8 +125,22 @@ def get_grouped_params(model, args, no_decay=["bias", "LayerNorm.weight"]):
 def log_metrics(step, metrics):
     logger.info(f"Step {step}: {metrics}")
     if accelerator.is_main_process:
-        wandb.log(metrics)
-        [tb_writer.add_scalar(k, v, step) for k, v in metrics.items()]
+        accelerator.log(metrics, step)


+def compute_tflops(elapsed_time, accelerator, args):
+    # TFLOPs formula (from Equation 3 in Section 5.1 of https://arxiv.org/pdf/2104.04473.pdf).
+    config_model = accelerator.unwrap_model(model).config
+    checkpoint_factor = 4 if args.gradient_checkpointing else 3
+    batch_size = args.train_batch_size * accelerator.state.num_processes * args.gradient_accumulation_steps
+    factor = 24 * checkpoint_factor * batch_size * args.seq_length * config_model.n_layer * (config_model.n_embd**2)
+    flops_per_iteration = factor * (
+        1.0
+        + (args.seq_length / (6.0 * config_model.n_embd))
+        + (tokenizer.vocab_size / (16.0 * config_model.n_layer * config_model.n_embd))
+    )
+    tflops = flops_per_iteration / (elapsed_time * accelerator.state.num_processes * (10**12))
+    return tflops


 def evaluate(args):
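
For reference, `compute_tflops` above implements Equation 3 from Section 5.1 of the Megatron-LM paper linked in the code comment; written out with the names used in the code (a transcription, not part of the diff):

```latex
% FLOPs for one optimizer step, as computed in compute_tflops:
%   c = 4 with gradient checkpointing, otherwise 3
%   B = train_batch_size * num_processes * gradient_accumulation_steps
%   s = seq_length,  l = n_layer,  h = n_embd,  V = vocab_size
\[
F = 24\,c\,B\,s\,l\,h^{2}\left(1 + \frac{s}{6h} + \frac{V}{16\,l\,h}\right)
\]
% Reported throughput per GPU, in teraFLOPs:
\[
\mathrm{TFLOPs} = \frac{F}{\text{elapsed\_time} \cdot \text{num\_processes} \cdot 10^{12}}
\]
```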
......@@ -140,7 +153,8 @@ def evaluate(args):
         losses.append(accelerator.gather(loss))
         if args.max_eval_steps > 0 and step >= args.max_eval_steps:
             break
-    loss = torch.mean(torch.cat(losses))
+    losses = torch.cat(losses)
+    loss = losses[: eval_dataloader.dataset.current_size].mean()
     try:
         perplexity = torch.exp(loss)
     except OverflowError:
......@@ -149,7 +163,7 @@ def evaluate(args):
 # Accelerator
-accelerator = Accelerator()
+accelerator = Accelerator(log_with=["wandb", "tensorboard"])
 acc_state = {str(k): str(v) for k, v in accelerator.state.__dict__.items()}

 # Settings
......@@ -165,7 +179,7 @@ if accelerator.is_main_process:
     hf_repo = Repository(args.save_dir, clone_from=args.model_ckpt)

 # Logging
-logger, tb_writer, run_name = setup_logging(args)
+logger, run_name = setup_logging(args)
 logger.info(accelerator.state)

 # Checkout new branch on repo
......@@ -189,6 +203,7 @@ lr_scheduler = get_scheduler(
     num_warmup_steps=args.num_warmup_steps,
     num_training_steps=args.max_train_steps,
 )
+accelerator.register_for_checkpointing(lr_scheduler)


 def get_lr():
......@@ -200,29 +215,58 @@ model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
     model, optimizer, train_dataloader, eval_dataloader
 )

+# load in the weights and states from a previous save
+if args.resume_from_checkpoint:
+    if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
+        accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
+        accelerator.load_state(args.resume_from_checkpoint)
+        path = os.path.basename(args.resume_from_checkpoint)
+    else:
+        # Get the most recent checkpoint
+        dirs = [f.name for f in os.scandir(args.save_dir) if f.is_dir() and "step" in str(f)]
+        dirs.sort(key=os.path.getctime)
+        path = dirs[-1]  # Sorts folders by date modified, most recent checkpoint is the last
+    # Extract the step of the checkpoint to continue from there
+    training_difference = os.path.splitext(path)[0]
+    resume_step = int(training_difference.replace("step_", ""))

 # Train model
 model.train()
 completed_steps = 0
+t_start = time.time()
 for step, batch in enumerate(train_dataloader, start=1):
+    if args.resume_from_checkpoint and step < resume_step:
+        continue  # we need to skip steps until we reach the resumed step
     loss = model(batch, labels=batch, use_cache=False).loss
     log_metrics(
         step, {"lr": get_lr(), "samples": step * samples_per_step, "steps": completed_steps, "loss/train": loss.item()}
     )
     loss = loss / args.gradient_accumulation_steps
-    accelerator.backward(loss)
-    if step % args.gradient_accumulation_steps == 0:
+    if step % args.gradient_accumulation_steps != 0:
+        # Prevent backward from doing gradient all_reduce in every step
+        if accelerator.distributed_type == DistributedType.MULTI_GPU:
+            with model.no_sync():
+                accelerator.backward(loss)
+        else:
+            accelerator.backward(loss)
+    else:
+        accelerator.backward(loss)
         accelerator.clip_grad_norm_(model.parameters(), 1.0)
         optimizer.step()
         lr_scheduler.step()
         optimizer.zero_grad()
         completed_steps += 1
+        elapsed_time = time.time() - t_start
+        tflops = compute_tflops(elapsed_time, accelerator, args)
+        log_metrics(step, {"steps": completed_steps, "tflops": tflops, "time_per_iteration": elapsed_time})
+        t_start = time.time()
     if step % args.save_checkpoint_steps == 0:
         logger.info("Evaluating and saving model checkpoint")
         eval_loss, perplexity = evaluate(args)
         log_metrics(step, {"loss/eval": eval_loss, "perplexity": perplexity})
         accelerator.wait_for_everyone()
         unwrapped_model = accelerator.unwrap_model(model)
         unwrapped_model.save_pretrained(args.save_dir, save_function=accelerator.save)
+        save_dir = os.path.join(args.save_dir, f"step_{step}")
+        accelerator.save_state(save_dir)
         if accelerator.is_main_process:
             hf_repo.push_to_hub(commit_message=f"step {step}")
         model.train()
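
The accumulation branch above skips DDP's gradient all_reduce on intermediate micro-batches and only synchronizes, clips, and steps on the last one. Below is a stripped-down, single-process sketch of the same pattern; it substitutes plain PyTorch calls for `accelerator.backward` and for the DDP wrapping that `accelerator.prepare` would normally add, and every name in it is illustrative:

```python
import contextlib

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

# Placeholders; in the script these come from accelerator.prepare() and the dataloader.
model = torch.nn.Linear(16, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
batches = [torch.randn(4, 16) for _ in range(8)]
gradient_accumulation_steps = 4

for step, batch in enumerate(batches, start=1):
    # Scale every micro-batch loss so the accumulated gradient matches the full batch.
    loss = model(batch).pow(2).mean() / gradient_accumulation_steps
    if step % gradient_accumulation_steps != 0:
        # Intermediate micro-batch: if the model were DDP-wrapped, no_sync() would
        # skip the gradient all_reduce; on a plain module we do nothing special.
        ctx = model.no_sync() if isinstance(model, DDP) else contextlib.nullcontext()
        with ctx:
            loss.backward()
    else:
        # Last micro-batch of the window: synchronize gradients, clip, and step.
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()
```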
......@@ -236,5 +280,7 @@ log_metrics(step, {"loss/eval": eval_loss, "perplexity": perplexity})
 accelerator.wait_for_everyone()
 unwrapped_model = accelerator.unwrap_model(model)
 unwrapped_model.save_pretrained(args.save_dir, save_function=accelerator.save)
+save_dir = os.path.join(args.save_dir, f"step_{step}")
+accelerator.save_state(save_dir)
 if accelerator.is_main_process:
     hf_repo.push_to_hub(commit_message="final model")