Unverified Commit d9184131 authored by Loubna Ben Allal, committed by GitHub

New features for CodeParrot training script (#16851)



* add tflops logging and fix grad accumulation

* add accelerate tracking and checkpointing

* scale loss of last batch correctly

* fix typo

* compress loss computation
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* add resume from checkpoint argument

* add accelerate load_state from checkpoint, register lr scheduler and add tflops function

* reformat code

* reformat code

* add condition on path for resume checkpoint

* combine if conditions
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* add source for tflops formula
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
parent eef2422e
@@ -82,7 +82,7 @@ Now that the dataset, tokenizer, and model are ready we can start training the model
 First you need to configure `accelerate` and login to Weights & Biases:
 ```bash
-acclerate config
+accelerate config
 wandb login
 ```
...
@@ -49,6 +49,10 @@ class TrainingArguments:
         default=1024,
         metadata={"help": "Interval to save checkpoints. Measured as number of forward passes not training steps."},
     )
+    resume_from_checkpoint: Optional[str] = field(
+        default=None,
+        metadata={"help": "States path if the training should continue from a checkpoint folder."},
+    )


 @dataclass
...
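The new `resume_from_checkpoint` field above is a plain dataclass field, so it is picked up by the script's existing `HfArgumentParser` setup. A hypothetical sketch of how it surfaces on the command line (the checkpoint path in the comment is an illustrative placeholder, not a real default):

```python
# Illustrative only: mirrors how codeparrot_training.py parses its TrainingArguments.
from transformers import HfArgumentParser
from arguments import TrainingArguments

parser = HfArgumentParser(TrainingArguments)
# e.g. accelerate launch codeparrot_training.py --resume_from_checkpoint ./save_dir/step_50000
args = parser.parse_args_into_dataclasses()[0]
if args.resume_from_checkpoint:
    print(f"Training will resume from {args.resume_from_checkpoint}")
```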
 import logging
+import os
+import time
 from argparse import Namespace
 from pathlib import Path
@@ -7,11 +9,9 @@ import torch
 from datasets import load_dataset
 from torch.utils.data import IterableDataset
 from torch.utils.data.dataloader import DataLoader
-from torch.utils.tensorboard import SummaryWriter

 import transformers
-import wandb
-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedType
 from arguments import TrainingArguments
 from huggingface_hub import Repository
 from transformers import AdamW, AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, get_scheduler, set_seed
@@ -39,6 +39,7 @@ class ConstantLengthDataset(IterableDataset):
         self.input_characters = seq_length * chars_per_token * num_of_sequences
         self.epoch = 0
         self.infinite = infinite
+        self.current_size = 0

     def __iter__(self):
         iterator = iter(self.dataset)
@@ -66,6 +67,7 @@ class ConstantLengthDataset(IterableDataset):
             for i in range(0, len(all_token_ids), self.seq_length):
                 input_ids = all_token_ids[i : i + self.seq_length]
                 if len(input_ids) == self.seq_length:
+                    self.current_size += 1
                     yield torch.tensor(input_ids)
@@ -82,20 +84,17 @@ def setup_logging(args):
         handlers=[logging.FileHandler(log_dir / filename), logging.StreamHandler()],
     )
     if accelerator.is_main_process:  # we only want to setup logging once
-        wandb.init(project=project_name, config=args)
-        run_name = wandb.run.name
-        tb_writer = SummaryWriter()
-        tb_writer.add_hparams(vars(args), {"0": 0})
+        accelerator.init_trackers(project_name, vars(args))
+        run_name = accelerator.trackers[0].run.name
         logger.setLevel(logging.INFO)
         datasets.utils.logging.set_verbosity_info()
         transformers.utils.logging.set_verbosity_info()
     else:
-        tb_writer = None
         run_name = ""
         logger.setLevel(logging.ERROR)
         datasets.utils.logging.set_verbosity_error()
         transformers.utils.logging.set_verbosity_error()
-    return logger, tb_writer, run_name
+    return logger, run_name


 def create_dataloaders(args):
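For readers unfamiliar with the accelerate tracking API that replaces the direct `wandb`/`SummaryWriter` calls above, a minimal standalone sketch (assuming `wandb` is logged in and `tensorboard` is installed; the project name and metric values are illustrative, and some accelerate versions also expect a logging directory):

```python
# Minimal sketch of accelerate experiment tracking, not the training script itself.
from accelerate import Accelerator

accelerator = Accelerator(log_with=["wandb", "tensorboard"])
accelerator.init_trackers("codeparrot-ds-demo", config={"lr": 5e-4, "seq_length": 1024})
for step in range(1, 4):
    accelerator.log({"loss/train": 1.0 / step}, step)  # one call fans out to every tracker
accelerator.end_training()  # flush and close the trackers
```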
@@ -126,8 +125,22 @@ def get_grouped_params(model, args, no_decay=["bias", "LayerNorm.weight"]):
 def log_metrics(step, metrics):
     logger.info(f"Step {step}: {metrics}")
     if accelerator.is_main_process:
-        wandb.log(metrics)
-        [tb_writer.add_scalar(k, v, step) for k, v in metrics.items()]
+        accelerator.log(metrics, step)
+
+
+def compute_tflops(elapsed_time, accelerator, args):
+    # TFLOPs formula (from Equation 3 in Section 5.1 of https://arxiv.org/pdf/2104.04473.pdf).
+    config_model = accelerator.unwrap_model(model).config
+    checkpoint_factor = 4 if args.gradient_checkpointing else 3
+    batch_size = args.train_batch_size * accelerator.state.num_processes * args.gradient_accumulation_steps
+    factor = 24 * checkpoint_factor * batch_size * args.seq_length * config_model.n_layer * (config_model.n_embd**2)
+    flops_per_iteration = factor * (
+        1.0
+        + (args.seq_length / (6.0 * config_model.n_embd))
+        + (tokenizer.vocab_size / (16.0 * config_model.n_layer * config_model.n_embd))
+    )
+    tflops = flops_per_iteration / (elapsed_time * accelerator.state.num_processes * (10**12))
+    return tflops


 def evaluate(args):
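As a rough standalone sanity check of the formula above (not part of the diff; every number below is an illustrative assumption, not a CodeParrot default), the same arithmetic can be run directly:

```python
# Megatron-LM TFLOPs estimate (Eq. 3, Sec. 5.1 of https://arxiv.org/abs/2104.04473) with toy numbers.
n_layer, n_embd, vocab_size = 12, 768, 32768   # hypothetical GPT-2-small-like config
seq_length = 1024
global_batch_size = 192                        # per-device batch * processes * grad accumulation
checkpoint_factor = 4                          # 4 with gradient checkpointing, 3 without
elapsed_time, num_processes = 1.0, 8           # seconds per optimizer step, number of GPUs

factor = 24 * checkpoint_factor * global_batch_size * seq_length * n_layer * n_embd**2
flops_per_iteration = factor * (
    1.0 + seq_length / (6.0 * n_embd) + vocab_size / (16.0 * n_layer * n_embd)
)
print(flops_per_iteration / (elapsed_time * num_processes * 10**12), "TFLOPs per GPU")
```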
@@ -140,7 +153,8 @@ def evaluate(args):
         losses.append(accelerator.gather(loss))
         if args.max_eval_steps > 0 and step >= args.max_eval_steps:
             break
-    loss = torch.mean(torch.cat(losses))
+    losses = torch.cat(losses)
+    loss = losses[: eval_dataloader.dataset.current_size].mean()
     try:
         perplexity = torch.exp(loss)
     except OverflowError:
@@ -149,7 +163,7 @@
 # Accelerator
-accelerator = Accelerator()
+accelerator = Accelerator(log_with=["wandb", "tensorboard"])
 acc_state = {str(k): str(v) for k, v in accelerator.state.__dict__.items()}

 # Settings
@@ -165,7 +179,7 @@ if accelerator.is_main_process:
     hf_repo = Repository(args.save_dir, clone_from=args.model_ckpt)

 # Logging
-logger, tb_writer, run_name = setup_logging(args)
+logger, run_name = setup_logging(args)
 logger.info(accelerator.state)

 # Checkout new branch on repo
@@ -189,6 +203,7 @@ lr_scheduler = get_scheduler(
     num_warmup_steps=args.num_warmup_steps,
     num_training_steps=args.max_train_steps,
 )
+accelerator.register_for_checkpointing(lr_scheduler)


 def get_lr():
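`register_for_checkpointing` is what lets the LR scheduler travel with `save_state`/`load_state` alongside the model, optimizer and RNG states. A minimal sketch of that round trip with toy objects (the path is illustrative):

```python
# Minimal accelerate checkpointing round trip, independent of the CodeParrot script.
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

model, optimizer = accelerator.prepare(model, optimizer)
accelerator.register_for_checkpointing(lr_scheduler)  # scheduler is not prepared, so register it explicitly

accelerator.save_state("ckpt/step_100")  # writes model, optimizer, registered objects and RNG states
accelerator.load_state("ckpt/step_100")  # restores everything saved above
```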
@@ -200,29 +215,58 @@ model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
     model, optimizer, train_dataloader, eval_dataloader
 )

+# load in the weights and states from a previous save
+if args.resume_from_checkpoint:
+    if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
+        accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
+        accelerator.load_state(args.resume_from_checkpoint)
+        path = os.path.basename(args.resume_from_checkpoint)
+    else:
+        # Get the most recent checkpoint
+        dirs = [f.name for f in os.scandir(args.save_dir) if f.is_dir() and "step" in str(f)]
+        dirs.sort(key=os.path.getctime)
+        path = dirs[-1]  # Sorts folders by date modified, most recent checkpoint is the last
+    # Extract the step of the checkpoint to continue from there
+    training_difference = os.path.splitext(path)[0]
+    resume_step = int(training_difference.replace("step_", ""))
+
 # Train model
 model.train()
 completed_steps = 0
+t_start = time.time()
 for step, batch in enumerate(train_dataloader, start=1):
+    if args.resume_from_checkpoint and step < resume_step:
+        continue  # we need to skip steps until we reach the resumed step
     loss = model(batch, labels=batch, use_cache=False).loss
     log_metrics(
         step, {"lr": get_lr(), "samples": step * samples_per_step, "steps": completed_steps, "loss/train": loss.item()}
     )
     loss = loss / args.gradient_accumulation_steps
-    accelerator.backward(loss)
-    if step % args.gradient_accumulation_steps == 0:
+    if step % args.gradient_accumulation_steps != 0:
+        # Prevent backward from doing gradient all_reduce in every step
+        if accelerator.distributed_type == DistributedType.MULTI_GPU:
+            with model.no_sync():
+                accelerator.backward(loss)
+        else:
+            accelerator.backward(loss)
+    else:
+        accelerator.backward(loss)
         accelerator.clip_grad_norm_(model.parameters(), 1.0)
         optimizer.step()
         lr_scheduler.step()
         optimizer.zero_grad()
         completed_steps += 1
+        elapsed_time = time.time() - t_start
+        tflops = compute_tflops(elapsed_time, accelerator, args)
+        log_metrics(step, {"steps": completed_steps, "tflops": tflops, "time_per_iteration": elapsed_time})
+        t_start = time.time()
     if step % args.save_checkpoint_steps == 0:
         logger.info("Evaluating and saving model checkpoint")
         eval_loss, perplexity = evaluate(args)
         log_metrics(step, {"loss/eval": eval_loss, "perplexity": perplexity})
         accelerator.wait_for_everyone()
-        unwrapped_model = accelerator.unwrap_model(model)
-        unwrapped_model.save_pretrained(args.save_dir, save_function=accelerator.save)
+        save_dir = os.path.join(args.save_dir, f"step_{step}")
+        accelerator.save_state(save_dir)
         if accelerator.is_main_process:
             hf_repo.push_to_hub(commit_message=f"step {step}")
         model.train()
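The reworked accumulation branch above avoids a gradient all_reduce on every micro-batch by calling `backward` under DDP's `no_sync` context except on the update step. A stripped-down sketch of the same pattern with a toy model (module, data and loss are placeholders; without DDP the `no_sync` branch is simply skipped):

```python
import torch
from torch import nn

def train_with_accumulation(model, batches, optimizer, accumulation_steps):
    for step, batch in enumerate(batches, start=1):
        loss = model(batch).pow(2).mean() / accumulation_steps   # scale each micro-batch loss
        if step % accumulation_steps != 0 and hasattr(model, "no_sync"):
            with model.no_sync():       # DDP only: defer the gradient all_reduce
                loss.backward()
        else:
            loss.backward()             # update step: gradients are synchronized as usual
        if step % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

# toy single-process usage (plain nn.Module has no no_sync, so the else branch always runs)
model = nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
train_with_accumulation(model, [torch.randn(4, 8) for _ in range(6)], optimizer, accumulation_steps=3)
```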
@@ -236,5 +280,7 @@ log_metrics(step, {"loss/eval": eval_loss, "perplexity": perplexity})
 accelerator.wait_for_everyone()
 unwrapped_model = accelerator.unwrap_model(model)
 unwrapped_model.save_pretrained(args.save_dir, save_function=accelerator.save)
+save_dir = os.path.join(args.save_dir, f"step_{step}")
+accelerator.save_state(save_dir)
 if accelerator.is_main_process:
     hf_repo.push_to_hub(commit_message="final model")