Commit 0bf5e500 authored by Tri Dao

Release training code

parent 9bc63d1e
from typing import Callable
import dotenv
import hydra
from omegaconf import OmegaConf, DictConfig
# load environment variables from `.env` file if it exists
# recursively searches for `.env` in all folders starting from work dir
dotenv.load_dotenv(override=True)
OmegaConf.register_new_resolver('eval', eval)
OmegaConf.register_new_resolver('div_up', lambda x, y: (x + y - 1) // y)
# Delay the evaluation until we have the datamodule: for now the resolver
# should yield back the same '${datamodule:...}' string unchanged.
OmegaConf.register_new_resolver('datamodule', lambda attr: '${datamodule:' + str(attr) + '}')
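# A minimal sketch (illustration only; these config keys are hypothetical, not
# from this repo's configs) of how the resolvers above behave: 'eval' and
# 'div_up' resolve eagerly, while 'datamodule' yields back the same
# '${datamodule:...}' string so resolution can be deferred until the
# datamodule actually exists.
def _resolver_example():
    cfg = OmegaConf.create({
        'n_layer': 12,
        'd_model': '${eval:"${n_layer} * 64"}',     # -> 768
        'n_chunks': '${div_up:1000,3}',             # -> (1000 + 3 - 1) // 3 = 334
        'vocab_size': '${datamodule:vocab_size}',   # stays unresolved on purpose
    })
    assert cfg.d_model == 768 and cfg.n_chunks == 334
    assert cfg.vocab_size == '${datamodule:vocab_size}'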
# Turn on TensorFloat32
import torch.backends
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
def dictconfig_filter_key(d: DictConfig, fn: Callable) -> DictConfig:
"""Only keep keys where fn(key) is True. Support nested DictConfig.
"""
    # Using d.items_ex(resolve=False) instead of d.items() would keep the
    # ${datamodule:foo} interpolations unresolved for now:
    # return DictConfig({k: dictconfig_filter_key(v, fn) if isinstance(v, DictConfig) else v
    #                    for k, v in d.items_ex(resolve=False) if fn(k)})
    return DictConfig({k: dictconfig_filter_key(v, fn) if isinstance(v, DictConfig) else v
                       for k, v in d.items() if fn(k)})
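# A minimal sketch (hypothetical keys, illustration only) of
# dictconfig_filter_key: '__'-prefixed keys are dropped at every nesting level.
def _filter_key_example():
    cfg = DictConfig({'__base_lr': 1e-3, 'optimizer': {'lr': 6e-4, '__scratch': 0}})
    filtered = dictconfig_filter_key(cfg, lambda k: not k.startswith('__'))
    assert '__base_lr' not in filtered
    assert '__scratch' not in filtered.optimizer and filtered.optimizer.lr == 6e-4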
@hydra.main(config_path="configs/", config_name="config.yaml")
def main(config: DictConfig):
# Remove config keys that start with '__'. These are meant to be used only in computing
# other entries in the config.
config = dictconfig_filter_key(config, lambda k: not k.startswith('__'))
# Imports should be nested inside @hydra.main to optimize tab completion
# Read more here: https://github.com/facebookresearch/hydra/issues/934
from src.train import train
from src.eval import evaluate
from src.utils import utils
# A couple of optional utilities:
# - disabling python warnings
# - forcing debug-friendly configuration
# - verifying experiment name is set when running in experiment mode
# You can safely get rid of this line if you don't want those
utils.extras(config)
# Pretty print config using Rich library
if config.get("print_config"):
utils.print_config(config, resolve=True)
# Train model
mode = config.get('mode', 'train')
if mode not in ['train', 'eval']:
raise NotImplementedError(f'mode {mode} not supported')
if mode == 'train':
return train(config)
elif mode == 'eval':
return evaluate(config)
if __name__ == "__main__":
main()
import pytorch_lightning as pl
from pytorch_lightning import Callback
from pytorch_lightning.utilities import rank_zero_only
import torch
from torch.autograd import grad
class CausalityMonitor(Callback):
r"""Monitor causality of a model by tracking gradient leakage forward in time.
    In a fully causal model, dy[k]/du[s] ~= 0 for all k < s.
Args:
seq_len (int): Length of the sequence to monitor.
input_dim (int): Dimension of the input to monitor. If 0, the callback assumes
the task to be language modeling, and skips the embedding layer. If > 0,
input_dim is interpreted as the input channel dimension, i.e. D with
dummy input of dimension [B, L, D].
Notes:
This callback assumes that `pl_module.model` has a `net` or `s4seq` attribute,
indicating the primary model to monitor. For LMs, `net` or `s4seq` should
be after the embedding layer.
"""
def __init__(self, seq_len: int = 10, input_dim: int = 0):
super().__init__()
self.seq_len = seq_len
self.input_dim = input_dim
@rank_zero_only
def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
model = pl_module.model
with torch.enable_grad():
if self.input_dim == 0:
# [MP] LongTensors cannot have gradients - we start from post
# embedding in the LM case
input_dim = model.d_model
                x = torch.randn((2, self.seq_len, input_dim),
                                requires_grad=True).to(pl_module.device)
# [DF] HACK: we need to get the layer that comes after the embedding
if hasattr(model, 'net'):
y = model.net(x)
else:
y = model.s4seq(x)
else:
                x = torch.randn(1, self.seq_len, self.input_dim,
                                requires_grad=True).to(pl_module.device)
y = model(x)
stats = {}
for i in range(self.seq_len):
# total gradients flowing from y_i to x
                g = grad(y[0, i].mean(), x, retain_graph=True, allow_unused=True)[0]
g = g[0,i+1:,:].abs().mean()
stats[f'stats/causality_{i}'] = g.item()
if trainer.loggers is not None:
for logger in trainer.loggers:
logger.log_metrics(stats, step=trainer.global_step)
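# A minimal sketch of attaching the monitor (illustration only; 'MyLitModel'
# and 'dm' are hypothetical, and pl_module.model must expose `net` or `s4seq`
# as described in the docstring above).
def _causality_monitor_example():
    trainer = pl.Trainer(callbacks=[CausalityMonitor(seq_len=16)], max_epochs=1)
    # trainer.fit(MyLitModel(), datamodule=dm)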
# Inspired by https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/callbacks/stochastic_weight_avg.py
# https://github.com/PyTorchLightning/Lightning-Bolts/blob/master/pl_bolts/callbacks/byol_updates.py
# https://forums.pytorchlightning.ai/t/adopting-exponential-moving-average-ema-for-pl-pipeline/488/2
# https://github.com/PyTorchLightning/pytorch-lightning/issues/8100
from typing import Dict, Any
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.parsing import AttributeDict
from pytorch_lightning.utilities.types import STEP_OUTPUT
from src.utils.ema import ExponentialMovingAverage
class EMACallback(Callback):
"""TD [2021-08-31]: saving and loading from checkpoint should work.
"""
def __init__(self, decay: float, use_num_updates: bool = True):
"""
decay: The exponential decay.
use_num_updates: Whether to use number of updates when computing
averages.
"""
super().__init__()
self.decay = decay
self.use_num_updates = use_num_updates
self.ema = None
def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
# It's possible that we already loaded EMA from the checkpoint
if self.ema is None:
self.ema = ExponentialMovingAverage([p for p in pl_module.parameters() if p.requires_grad],
decay=self.decay, use_num_updates=self.use_num_updates)
# Ideally we want on_after_optimizer_step but pytorch-lightning doesn't have it
# We only want to update when parameters are changing.
# Because of gradient accumulation, this doesn't happen every training step.
# https://github.com/PyTorchLightning/pytorch-lightning/issues/11688
def on_train_batch_end(
self,
trainer: "pl.Trainer",
pl_module: "pl.LightningModule",
outputs: STEP_OUTPUT,
batch: Any,
batch_idx: int,
) -> None:
if (batch_idx + 1) % trainer.accumulate_grad_batches == 0:
self.ema.update()
def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
# During the initial validation we don't have self.ema yet
if self.ema is not None:
self.ema.store()
self.ema.copy_to()
def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if self.ema is not None:
self.ema.restore()
def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if self.ema is not None:
self.ema.store()
self.ema.copy_to()
def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if self.ema is not None:
self.ema.restore()
def on_save_checkpoint(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
) -> Dict[str, Any]:
return self.ema.state_dict()
def on_load_checkpoint(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule",
checkpoint: Dict[str, Any]
) -> None:
if self.ema is None:
self.ema = ExponentialMovingAverage([p for p in pl_module.parameters() if p.requires_grad],
decay=self.decay, use_num_updates=self.use_num_updates)
self.ema.load_state_dict(checkpoint)
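# A minimal usage sketch (illustration only; 'model' and 'dm' are
# hypothetical): with this callback, validation and test metrics are computed
# on the EMA weights (store/copy_to before, restore after), while training
# continues on the raw weights.
def _ema_callback_example():
    import pytorch_lightning as pl
    trainer = pl.Trainer(callbacks=[EMACallback(decay=0.9999)], max_epochs=1)
    # trainer.fit(model, datamodule=dm)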
# Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/benchmark.py
from typing import Any, List, Sequence
import torch
from pytorch_lightning import Callback, Trainer, LightningModule
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.parsing import AttributeDict
from src.utils.flops import has_deepspeed_profiling, has_fvcore_profiling
from src.utils.flops import profile_deepspeed, profile_fvcore
class FlopCount(Callback):
"""Counter the number of FLOPs used by the model
"""
def __init__(self, profilers: List[str] = ['fvcore', 'deepspeed'],
input_size: tuple = (3, 224, 224), input_dtype=torch.float32, device=None):
if not isinstance(profilers, Sequence):
profilers = [profilers]
if any(p not in ['fvcore', 'deepspeed'] for p in profilers):
raise NotImplementedError('Only support fvcore and deepspeed profilers')
if 'fvcore' in profilers and not has_fvcore_profiling:
raise ImportError('fvcore is not installed. Install it by running `pip install fvcore`')
elif 'deepspeed' in profilers and not has_deepspeed_profiling:
raise ImportError('deepspeed is not installed')
super().__init__()
self.profilers = profilers
self.input_size = tuple(input_size)
self.input_dtype = input_dtype
self.device = device
@rank_zero_only
def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
if 'fvcore' in self.profilers:
_, macs, _, acts = profile_fvcore(pl_module.to(self.device), input_size=self.input_size,
input_dtype=self.input_dtype, detailed=True)
trainer.logger.log_hyperparams({'GMACs': macs * 1e-9, 'MActs': acts * 1e-6})
if 'deepspeed' in self.profilers:
            macs, _ = profile_deepspeed(pl_module.to(self.device), input_size=self.input_size,
input_dtype=self.input_dtype, detailed=True)
if 'fvcore' not in self.profilers: # fvcore's MACs seem more accurate
trainer.logger.log_hyperparams({'GMACs': macs * 1e-9})
import torch
from pytorch_lightning import Callback, Trainer, LightningModule
import logging
log = logging.getLogger(__name__) # We want a logger for each process, not just the rank 0
def l2_promote():
import ctypes
_libcudart = ctypes.CDLL('libcudart.so')
# Set device limit on the current device
# cudaLimitMaxL2FetchGranularity = 0x05
pValue = ctypes.cast((ctypes.c_int*1)(), ctypes.POINTER(ctypes.c_int))
_libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
_libcudart.cudaDeviceGetLimit(pValue, ctypes.c_int(0x05))
assert pValue.contents.value == 128
def set_affinity(trainer):
try:
from src.utils.gpu_affinity import set_affinity
nproc_per_node = torch.cuda.device_count()
affinity = set_affinity(trainer.local_rank, nproc_per_node, 'socket_unique_continuous')
log.info(f'{trainer.local_rank}: thread affinity: {affinity}')
# TD [2022-05-07] Somehow calling this causes GPU 0 to allocate extra ~800MB of memory per
# number of GPUs (e.g., 6.4GB of extra memory in a 8-GPU setup). H/t Dan.
# l2_promote()
    except Exception:
        # Setting affinity is best-effort; don't crash training if it fails
        pass
class GpuAffinity(Callback):
"""Set GPU affinity and increase the L2 fetch granularity.
Adapted from https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/LanguageModeling/Transformer-XL
"""
def setup(self, trainer: Trainer, pl_module: LightningModule, stage=None) -> None:
set_affinity(trainer)
# Adapted from https://github.com/Lightning-AI/lightning/blob/master/src/pytorch_lightning/callbacks/lr_monitor.py.
from typing import Any
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.strategies import DeepSpeedStrategy
class LossScaleMonitor(Callback):
"""Monitor the loss scale for AMP (fp16).
"""
# Use on_before_optimizer_step instead of on_train_batch_start since there might be
# gradient accumulation and we only care about the loss scale when it could change (i.e.,
# optimizer.step).
@rank_zero_only
def on_before_optimizer_step(self, trainer: Trainer, *args: Any, **kwargs: Any) -> None:
if not trainer._logger_connector.should_update_logs:
return
stats = {}
if isinstance(trainer.strategy, DeepSpeedStrategy):
            stats = {'scaler/scale': trainer.model.optimizer.loss_scale}
if hasattr(trainer, 'precision_plugin') and hasattr(trainer.precision_plugin, 'scaler'):
scaler = trainer.precision_plugin.scaler
if scaler is not None:
stats = {
'scaler/scale': scaler.get_scale(),
'scaler/growth_tracker': scaler._get_growth_tracker(),
}
if stats and trainer.loggers is not None:
for logger in trainer.loggers:
logger.log_metrics(stats, step=trainer.fit_loop.epoch_loop._batches_that_stepped)
# Adapted from https://github.com/Lightning-AI/lightning/blob/master/src/pytorch_lightning/callbacks/fault_tolerance.py
from typing import Any
from pathlib import Path
import pytorch_lightning as pl
class ModelCheckpointMine(pl.callbacks.model_checkpoint.ModelCheckpoint):
def __init__(self, *args, fault_tolerant=False, **kwargs):
super().__init__(*args, **kwargs)
self.fault_tolerant = fault_tolerant
def on_exception(self, trainer: "pl.Trainer", *_: Any, **__: Any) -> None:
if self.fault_tolerant:
# overwrite if necessary
trainer.save_checkpoint(str(Path(self.dirpath) / '.pl_auto_save.ckpt'))
# def teardown(self, trainer: "pl.Trainer", *_: Any, **__: Any) -> None:
# if self.fault_tolerant:
# trainer.strategy.remove_checkpoint(str(Path(self.dirpath) / '.pl_auto_save.ckpt'))
# TD [2022-07-17] I was trying to make resuming from standard checkpoint fault-tolerant.
# However, when it resumes it's off by 1 iteration. My attempt to fix it in seq.py (below) didn't work.
# So I decided to just copy _FaultToleranceCheckpoint and just save on_exception.
# def on_save_checkpoint(self, checkpoint):
# # TD [2022-07-12] The "completed" counter is off by 1 so when it resumes
# # it's off by 1 iteration. However, the data is still off by 1 iteration, probably
# # because the dataloader_state_dict['counter'] is off by @batch_size, and idk how
# # to fix it cleanly.
# checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['total']['completed'] += 1
# checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['completed'] += 1
# checkpoint['loops']['fit_loop']['epoch_loop.state_dict']['_batches_that_stepped'] += 1
# checkpoint['loops']['fit_loop']['epoch_loop.state_dict']['dataloader_state_dict'][0]['state'][0]['num_batches_fetched'] += 1
# Inspired by https://github.com/Lightning-AI/lightning/blob/master/src/pytorch_lightning/utilities/grads.py
# However, they compute grad at every iteration (I think), and the .item() calls incur a lot of overhead
# (6-7% slow down on GPT-2 small). Instead we only compute for iterations where we need to log, and don't
# call .item() explicitly.
from typing import Any
from collections import OrderedDict
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.strategies import DeepSpeedStrategy
import torch
import torch.nn as nn
try:
from apex.contrib.layer_norm import FastLayerNorm
except ImportError:
FastLayerNorm = None
class NormMonitor(Callback):
"""Monitor the scales of weights and gradients.
"""
def __init__(self, layer_norm_only: bool = False):
super().__init__()
self.layer_norm_only = layer_norm_only
# Use on_before_optimizer_step instead of on_train_batch_start since there might be
# gradient accumulation and we only care about scale when it could change (i.e., optimizer.step).
@rank_zero_only
def on_before_optimizer_step(self, trainer: Trainer, pl_module, *args: Any, **kwargs: Any) -> None:
if not trainer._logger_connector.should_update_logs:
return
model = pl_module.model
named_parameters = {}
if self.layer_norm_only:
ln_modules = (nn.LayerNorm, nn.Embedding)
if FastLayerNorm is not None:
ln_modules += (FastLayerNorm,)
for mn, m in model.named_modules():
if isinstance(m, ln_modules):
for pn, p in m.named_parameters():
fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
named_parameters[fpn] = p
else:
named_parameters = dict(model.named_parameters())
if isinstance(trainer.strategy, DeepSpeedStrategy):
loss_scale = trainer.model.optimizer.loss_scale
else:
loss_scale = 1.0
stats = {}
param_l1_norm, grad_l1_norm = [], []
for param_name, param in named_parameters.items():
param_abs = param.abs()
param_abs_mean = param_abs.mean(dtype=torch.float32)
stats[f'stats/{param_name}_max'] = param_abs.max()
stats[f'stats/{param_name}_mean'] = param_abs_mean
param_l1_norm.append(param_abs_mean * param.numel())
if param.grad is not None:
# If using AMP, gradient is already unscaled by the AMP loss scaler at this point
# https://github.com/Lightning-AI/lightning/pull/9606
# However, if using DeepSpeed, we need to scale it ourselves
param_grad_abs = param.grad.abs()
param_grad_abs_mean = param_grad_abs.mean(dtype=torch.float32) / loss_scale
stats[f'stats/{param_name}_grad_max'] = param_grad_abs.max() / loss_scale
stats[f'stats/{param_name}_grad_mean'] = param_grad_abs_mean
grad_l1_norm.append(param_grad_abs_mean * param.grad.numel())
stats['total_param_l1_norm'] = torch.stack(param_l1_norm).sum()
if grad_l1_norm:
stats['total_grad_l1_norm'] = torch.stack(grad_l1_norm).sum()
# Sort by params name
stats = OrderedDict(sorted(stats.items()))
if trainer.loggers is not None:
for logger in trainer.loggers:
logger.log_metrics(stats, step=trainer.fit_loop.epoch_loop._batches_that_stepped)
from typing import Any
from pytorch_lightning import Callback, Trainer, LightningModule
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.parsing import AttributeDict
class ParamsLog(Callback):
"""Log the number of parameters of the model
"""
def __init__(self, total_params_log: bool = True, trainable_params_log: bool = True,
non_trainable_params_log: bool = True):
super().__init__()
self._log_stats = AttributeDict(
{
'total_params_log': total_params_log,
'trainable_params_log': trainable_params_log,
'non_trainable_params_log': non_trainable_params_log,
}
)
@rank_zero_only
def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
logs = {}
if self._log_stats.total_params_log:
logs["model/params_total"] = sum(p.numel() for p in pl_module.parameters())
if self._log_stats.trainable_params_log:
logs["model/params_trainable"] = sum(p.numel() for p in pl_module.parameters()
if p.requires_grad)
if self._log_stats.non_trainable_params_log:
logs["model/params_not_trainable"] = sum(p.numel() for p in pl_module.parameters()
if not p.requires_grad)
if trainer.logger is not None:
trainer.logger.log_hyperparams(logs)
# Adapted from https://pytorch-lightning.readthedocs.io/en/latest/_modules/pytorch_lightning/callbacks/gpu_stats_monitor.html#GPUStatsMonitor
# We only need the speed monitoring, not the GPU monitoring
import time
from typing import Any
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.parsing import AttributeDict
from pytorch_lightning.utilities.types import STEP_OUTPUT
class SpeedMonitor(Callback):
"""Monitor the speed of each step and each epoch.
"""
def __init__(self, intra_step_time: bool = True, inter_step_time: bool = True,
epoch_time: bool = True, verbose=False):
super().__init__()
self._log_stats = AttributeDict(
{
'intra_step_time': intra_step_time,
'inter_step_time': inter_step_time,
'epoch_time': epoch_time,
}
)
self.verbose = verbose
def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self._snap_epoch_time = None
def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self._snap_intra_step_time = None
self._snap_inter_step_time = None
self._snap_epoch_time = time.time()
def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self._snap_inter_step_time = None
def on_test_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self._snap_inter_step_time = None
@rank_zero_only
def on_train_batch_start(
self,
trainer: "pl.Trainer",
pl_module: "pl.LightningModule",
batch: Any,
batch_idx: int,
) -> None:
if self._log_stats.intra_step_time:
self._snap_intra_step_time = time.time()
if not trainer._logger_connector.should_update_logs:
return
logs = {}
if self._log_stats.inter_step_time and self._snap_inter_step_time:
# First log at beginning of second step
logs["time/inter_step (ms)"] = (time.time() - self._snap_inter_step_time) * 1000
if trainer.logger is not None:
trainer.logger.log_metrics(logs, step=trainer.global_step)
@rank_zero_only
def on_train_batch_end(
self,
trainer: "pl.Trainer",
pl_module: "pl.LightningModule",
outputs: STEP_OUTPUT,
batch: Any,
batch_idx: int,
) -> None:
if self._log_stats.inter_step_time:
self._snap_inter_step_time = time.time()
if self.verbose and self._log_stats.intra_step_time and self._snap_intra_step_time:
pl_module.print(f"time/intra_step (ms): {(time.time() - self._snap_intra_step_time) * 1000}")
if not trainer._logger_connector.should_update_logs:
return
logs = {}
if self._log_stats.intra_step_time and self._snap_intra_step_time:
logs["time/intra_step (ms)"] = (time.time() - self._snap_intra_step_time) * 1000
if trainer.logger is not None:
trainer.logger.log_metrics(logs, step=trainer.global_step)
@rank_zero_only
    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
logs = {}
if self._log_stats.epoch_time and self._snap_epoch_time:
logs["time/epoch (s)"] = time.time() - self._snap_epoch_time
if trainer.logger is not None:
trainer.logger.log_metrics(logs, step=trainer.global_step)
import subprocess
from pathlib import Path
from typing import List, Optional
import matplotlib.pyplot as plt
import seaborn as sn
import torch
import wandb
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.loggers import LoggerCollection, WandbLogger
from pytorch_lightning.utilities import rank_zero_only
from sklearn import metrics
from sklearn.metrics import f1_score, precision_score, recall_score
def get_wandb_logger(trainer: Trainer) -> WandbLogger:
"""Safely get Weights&Biases logger from Trainer."""
if trainer.fast_dev_run:
raise Exception(
"Cannot use wandb callbacks since pytorch lightning disables loggers in `fast_dev_run=true` mode."
)
if isinstance(trainer.logger, WandbLogger):
return trainer.logger
if isinstance(trainer.logger, LoggerCollection):
for logger in trainer.logger:
if isinstance(logger, WandbLogger):
return logger
raise Exception(
"You are using wandb related callback, but WandbLogger was not found for some reason..."
)
class WatchModel(Callback):
"""Make wandb watch model at the beginning of the run."""
def __init__(self, log: str = "gradients", log_freq: int = 100):
self.log = log
self.log_freq = log_freq
@rank_zero_only
def on_train_start(self, trainer, pl_module):
logger = get_wandb_logger(trainer=trainer)
logger.watch(model=trainer.model, log=self.log, log_freq=self.log_freq)
class UploadCodeAsArtifact(Callback):
"""Upload all code files to wandb as an artifact, at the beginning of the run."""
def __init__(self, code_dir: str, use_git: bool = True):
"""
Args:
code_dir: the code directory
use_git: if using git, then upload all files that are not ignored by git.
                if not using git, then upload all '*.py' files
"""
self.code_dir = code_dir
self.use_git = use_git
@rank_zero_only
def on_train_start(self, trainer, pl_module):
logger = get_wandb_logger(trainer=trainer)
experiment = logger.experiment
code = wandb.Artifact("project-source", type="code")
if self.use_git:
# get .git folder
# https://alexwlchan.net/2020/11/a-python-function-to-ignore-a-path-with-git-info-exclude/
git_dir_path = Path(
subprocess.check_output(["git", "rev-parse", "--git-dir"]).strip().decode("utf8")
).resolve()
for path in Path(self.code_dir).resolve().rglob("*"):
if (
path.is_file()
# ignore files in .git
and not str(path).startswith(str(git_dir_path)) # noqa: W503
# ignore files ignored by git
and ( # noqa: W503
subprocess.run(["git", "check-ignore", "-q", str(path)]).returncode == 1
)
):
code.add_file(str(path), name=str(path.relative_to(self.code_dir)))
else:
for path in Path(self.code_dir).resolve().rglob("*.py"):
code.add_file(str(path), name=str(path.relative_to(self.code_dir)))
experiment.log_artifact(code)
class UploadCheckpointsAsArtifact(Callback):
"""Upload checkpoints to wandb as an artifact, at the end of run."""
def __init__(self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = False):
self.ckpt_dir = ckpt_dir
self.upload_best_only = upload_best_only
@rank_zero_only
def on_keyboard_interrupt(self, trainer, pl_module):
self.on_train_end(trainer, pl_module)
@rank_zero_only
def on_train_end(self, trainer, pl_module):
logger = get_wandb_logger(trainer=trainer)
experiment = logger.experiment
ckpts = wandb.Artifact("experiment-ckpts", type="checkpoints")
if self.upload_best_only:
ckpts.add_file(trainer.checkpoint_callback.best_model_path)
else:
for path in Path(self.ckpt_dir).rglob("*.ckpt"):
ckpts.add_file(str(path))
experiment.log_artifact(ckpts)
class LogConfusionMatrix(Callback):
"""Generate confusion matrix every epoch and send it to wandb.
Expects validation step to return predictions and targets.
"""
def __init__(self):
self.preds = []
self.targets = []
self.ready = True
def on_sanity_check_start(self, trainer, pl_module) -> None:
self.ready = False
def on_sanity_check_end(self, trainer, pl_module):
"""Start executing this callback only after all validation sanity checks end."""
self.ready = True
def on_validation_batch_end(
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
):
"""Gather data from single batch."""
if self.ready:
self.preds.append(outputs["preds"])
self.targets.append(outputs["targets"])
def on_validation_epoch_end(self, trainer, pl_module):
"""Generate confusion matrix."""
if self.ready:
logger = get_wandb_logger(trainer)
experiment = logger.experiment
preds = torch.cat(self.preds).cpu().numpy()
targets = torch.cat(self.targets).cpu().numpy()
confusion_matrix = metrics.confusion_matrix(y_true=targets, y_pred=preds)
# set figure size
plt.figure(figsize=(14, 8))
# set labels size
sn.set(font_scale=1.4)
# set font size
sn.heatmap(confusion_matrix, annot=True, annot_kws={"size": 8}, fmt="g")
            # names should be unique or else charts from different experiments in wandb will overlap
            experiment.log({f"confusion_matrix/{experiment.name}": wandb.Image(plt)}, commit=False)
            # according to wandb docs this should also work but it crashes
            # experiment.log({f"confusion_matrix/{experiment.name}": plt})
# reset plot
plt.clf()
self.preds.clear()
self.targets.clear()
class LogF1PrecRecHeatmap(Callback):
"""Generate f1, precision, recall heatmap every epoch and send it to wandb.
Expects validation step to return predictions and targets.
"""
    def __init__(self, class_names: Optional[List[str]] = None):
self.preds = []
self.targets = []
self.ready = True
def on_sanity_check_start(self, trainer, pl_module):
self.ready = False
def on_sanity_check_end(self, trainer, pl_module):
"""Start executing this callback only after all validation sanity checks end."""
self.ready = True
def on_validation_batch_end(
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
):
"""Gather data from single batch."""
if self.ready:
self.preds.append(outputs["preds"])
self.targets.append(outputs["targets"])
def on_validation_epoch_end(self, trainer, pl_module):
"""Generate f1, precision and recall heatmap."""
if self.ready:
logger = get_wandb_logger(trainer=trainer)
experiment = logger.experiment
preds = torch.cat(self.preds).cpu().numpy()
targets = torch.cat(self.targets).cpu().numpy()
f1 = f1_score(targets, preds, average=None)
r = recall_score(targets, preds, average=None)
p = precision_score(targets, preds, average=None)
data = [f1, p, r]
# set figure size
plt.figure(figsize=(14, 3))
# set labels size
sn.set(font_scale=1.2)
# set font size
sn.heatmap(
data,
annot=True,
annot_kws={"size": 10},
fmt=".3f",
yticklabels=["F1", "Precision", "Recall"],
)
            # names should be unique or else charts from different experiments in wandb will overlap
experiment.log({f"f1_p_r_heatmap/{experiment.name}": wandb.Image(plt)}, commit=False)
# reset plot
plt.clf()
self.preds.clear()
self.targets.clear()
class LogImagePredictions(Callback):
"""Logs a validation batch and their predictions to wandb.
Example adapted from:
https://wandb.ai/wandb/wandb-lightning/reports/Image-Classification-using-PyTorch-Lightning--VmlldzoyODk1NzY
"""
def __init__(self, num_samples: int = 8):
super().__init__()
self.num_samples = num_samples
self.ready = True
def on_sanity_check_start(self, trainer, pl_module):
self.ready = False
def on_sanity_check_end(self, trainer, pl_module):
"""Start executing this callback only after all validation sanity checks end."""
self.ready = True
def on_validation_epoch_end(self, trainer, pl_module):
if self.ready:
logger = get_wandb_logger(trainer=trainer)
experiment = logger.experiment
            # get a validation batch from the validation data loader
val_samples = next(iter(trainer.datamodule.val_dataloader()))
val_imgs, val_labels = val_samples
# run the batch through the network
val_imgs = val_imgs.to(device=pl_module.device)
logits = pl_module(val_imgs)
preds = torch.argmax(logits, dim=-1)
# log the images as wandb Image
experiment.log(
{
f"Images/{experiment.name}": [
wandb.Image(x, caption=f"Pred:{pred}, Label:{y}")
for x, pred, y in zip(
val_imgs[: self.num_samples],
preds[: self.num_samples],
val_labels[: self.num_samples],
)
]
}
)
# Copied from https://github.com/stanford-crfm/mistral/blob/main/src/corpora/detokenization.py
# Which was originally from https://github.com/NVIDIA/Megatron-LM/blob/aed2f75e209e525c842aec7c044af7acae2a4614/tasks/zeroshot_gpt/detokenizer.py
"""
Handle detokenization for different datasets for zero-shot LM evaluation.
"""
import re
def wikitext_detokenize(string: str) -> str:
"""
Wikitext is whitespace tokenized and we remove these whitespaces.
Taken from https://github.com/NVIDIA/Megatron-LM/blob/main/tasks/zeroshot_gpt2/detokenizer.py
"""
# Contractions
string = string.replace("s '", "s'")
string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
# Number Separators
string = string.replace(" @-@ ", "-")
string = string.replace(" @,@ ", ",")
string = string.replace(" @.@ ", ".")
# Punctuation
string = string.replace(" : ", ": ")
string = string.replace(" ; ", "; ")
string = string.replace(" . ", ". ")
string = string.replace(" ! ", "! ")
string = string.replace(" ? ", "? ")
string = string.replace(" , ", ", ")
# Double Brackets
string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
# Miscellaneous
string = string.replace("= = = =", "====")
string = string.replace("= = =", "===")
string = string.replace("= =", "==")
string = string.replace(" " + chr(176) + " ", chr(176))
string = string.replace(" \n", "\n")
string = string.replace("\n ", "\n")
string = string.replace(" N ", " 1 ")
string = string.replace(" 's", "'s")
return string
# Set Registry for Various Datasets
DATASET_TOKENIZATION_REGISTRY = {"wikitext": wikitext_detokenize}
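# A quick illustration of the detokenizer (the example sentence is made up):
# '@'-escaped number separators and spaced punctuation are collapsed. Note the
# final ' .' is untouched because that rule requires a trailing space.
def _detokenize_example():
    s = "The game , however , sold 1 @,@ 500 copies in Japan ."
    assert wikitext_detokenize(s) == "The game, however, sold 1,500 copies in Japan ."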
# Inspired by https://github.com/NVIDIA/Megatron-LM/blob/main/tasks/zeroshot_gpt/datasets.py
# Except we don't pad the last block and don't use overlapping eval
# And we return both the input and the target
import math
import numpy as np
import torch
class LMDataset(torch.utils.data.Dataset):
def __init__(self, tokens, seq_len, drop_last=True):
"""tokens should be a numpy array
"""
self.seq_len = seq_len
ntokens = len(tokens)
if drop_last:
ntokens = ((ntokens - 1) // seq_len) * seq_len + 1
self.ntokens = ntokens
# We're careful not to slice tokens, since it could be a memmap'ed array or H5 dataset,
# and slicing would load it to memory.
self.tokens = tokens
self.total_sequences = math.ceil((self.ntokens - 1) / self.seq_len)
def __len__(self):
return self.total_sequences
def __getitem__(self, idx):
start_idx = idx * self.seq_len
seq_len = min(self.seq_len, self.ntokens - 1 - start_idx)
data = torch.as_tensor(self.tokens[start_idx:(start_idx + seq_len + 1)].astype(np.int64))
return data[:-1], data[1:].clone()
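# A minimal sketch of the indexing (made-up tokens): with 10 tokens and
# seq_len=4, drop_last keeps (10 - 1) // 4 * 4 + 1 = 9 tokens, i.e. two full
# sequences, and each item is an (input, target) pair shifted by one token.
def _lm_dataset_example():
    tokens = np.arange(10, dtype=np.uint16)
    ds = LMDataset(tokens, seq_len=4)
    assert len(ds) == 2
    x, y = ds[0]
    assert x.tolist() == [0, 1, 2, 3] and y.tolist() == [1, 2, 3, 4]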
# Adapted from https://github.com/Lightning-AI/lightning/blob/2845e7565dbe6b765ae32870e7d2bc456529c30a/tests/tests_pytorch/utilities/test_auto_restart.py#L1397
from typing import Iterator
import math
import torch
from torch.utils.data import RandomSampler, DistributedSampler
class RandomFaultTolerantSampler(RandomSampler):
def __init__(self, *args, generator=None, **kwargs):
# generator = torch.Generator().manual_seed(seed)
# super().__init__(*args, generator=generator, **kwargs)
# TD [2022-07-17]: We don't force the seed to be zero. We generate random seed,
# which should be reproducible if pl.seed_everything was called before hand.
# This means that changing the seed of the experiment will also change the
# sampling order.
if generator is None:
seed = int(torch.empty((), dtype=torch.int64).random_().item())
generator = torch.Generator().manual_seed(seed)
super().__init__(*args, generator=generator, **kwargs)
self.counter = 0
# self.start_counter = 0
self.restarting = False
def state_dict(self):
return {"random_state": self.state, "counter": self.counter}
def load_state_dict(self, state_dict):
self.generator.set_state(state_dict.get("random_state"))
self.counter = state_dict["counter"]
# self.start_counter = self.counter
self.restarting = True
# TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per
# epoch, and subsequent epoch will have very few batches.
# def __len__(self):
# # We need a separate self.start_counter because PL seems to call len repeatedly.
# # If we use len(self.data_source) - self.counter then PL will think the epoch ends
# # when we're only half way through.
# return len(self.data_source) - self.start_counter
def __iter__(self) -> Iterator[int]:
n = len(self.data_source)
self.state = self.generator.get_state()
indices = torch.randperm(n, generator=self.generator).tolist()
if not self.restarting:
self.counter = 0
else:
indices = indices[self.counter:]
self.restarting = False
# self.start_counter = self.counter
for index in indices:
self.counter += 1
yield index
self.counter = 0
# self.start_counter = self.counter
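# A minimal sketch of mid-epoch resume (illustration only): the saved
# generator state reproduces the same permutation, and the counter skips the
# indices that were already consumed before the "checkpoint".
def _random_fault_tolerant_sampler_example():
    sampler = RandomFaultTolerantSampler(range(8))
    it = iter(sampler)
    first = [next(it) for _ in range(3)]    # consume 3 indices
    state = sampler.state_dict()            # pre-shuffle RNG state + counter == 3
    remaining = list(it)                    # finish the original epoch
    resumed = RandomFaultTolerantSampler(range(8))
    resumed.load_state_dict(state)
    assert list(resumed) == remaining       # picks up right after `first`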
class FaultTolerantDistributedSampler(DistributedSampler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.counter = 0
# self.start_counter = 0
self.restarting = False
def state_dict(self):
return {"epoch": self.epoch, "counter": self.counter}
def load_state_dict(self, state_dict):
self.epoch = state_dict["epoch"]
self.counter = state_dict["counter"]
# self.start_counter = self.counter
self.restarting = True
# TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per
# epoch, and subsequent epoch will have very few batches.
# def __len__(self) -> int:
# return self.num_samples - self.start_counter
def __iter__(self):
if self.shuffle:
# deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type]
else:
indices = list(range(len(self.dataset))) # type: ignore[arg-type]
if not self.drop_last:
# add extra samples to make it evenly divisible
padding_size = self.total_size - len(indices)
if padding_size <= len(indices):
indices += indices[:padding_size]
else:
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
else:
# remove tail of data to make it evenly divisible.
indices = indices[:self.total_size]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
if not self.restarting:
self.counter = 0
else:
indices = indices[self.counter:]
self.restarting = False
# self.start_counter = self.counter
for index in indices:
self.counter += 1
yield index
self.counter = 0
# self.start_counter = self.counter
# Adapted from https://github.com/PyTorchLightning/lightning-bolts/blob/master/pl_bolts/datamodules/imagenet_datamodule.py
import os
from pathlib import Path
from typing import Any, List, Union, Callable, Optional
import torch
from torch.utils.data import Dataset, DataLoader, SequentialSampler
from torch.utils.data.dataloader import default_collate
from torch.utils.data.distributed import DistributedSampler
from pytorch_lightning import LightningDataModule
from torchvision import transforms
from torchvision.datasets import ImageFolder
class DictDataset(Dataset):
def __init__(self, dataset_dict, length=None):
"""dataset_dict: dictionary mapping from index to batch
length is used in the case of DistributedSampler: e.g. the dataset could have size 1k, but
with 8 GPUs the dataset_dict would only have 125 items.
"""
super().__init__()
self.dataset_dict = dataset_dict
self.length = length or len(self.dataset_dict)
def __getitem__(self, index):
return self.dataset_dict[index]
def __len__(self):
return self.length
# From https://github.com/PyTorchLightning/lightning-bolts/blob/2415b49a2b405693cd499e09162c89f807abbdc4/pl_bolts/transforms/dataset_normalizations.py#L10
def imagenet_normalization():
return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
class ImagenetDataModule(LightningDataModule):
"""
.. figure:: https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/wp-content/uploads/2017/08/
Sample-of-Images-from-the-ImageNet-Dataset-used-in-the-ILSVRC-Challenge.png
:width: 400
:alt: Imagenet
Specs:
- 1000 classes
- Each image is (3 x varies x varies) (here we default to 3 x 224 x 224)
    Imagenet train, val and test dataloaders.
    The train set is the imagenet train.
    In this adaptation, the val and test sets both load the official imagenet
    validation split from ``data_dir / 'val'`` (the original bolts datamodule
    instead took `num_imgs_per_val_class` images per class from the train set).
Example::
from pl_bolts.datamodules import ImagenetDataModule
dm = ImagenetDataModule(IMAGENET_PATH)
model = LitModel()
Trainer().fit(model, datamodule=dm)
"""
name = "imagenet"
def __init__(
self,
data_dir: str,
image_size: int = 224,
train_transforms=None,
val_transforms=None,
test_transforms=None,
img_dtype='float32', # Using str since OmegaConf doesn't support non-primitive type
cache_val_dataset=False,
mixup: Optional[Callable] = None,
num_aug_repeats: int = 0,
num_workers: int = 0,
batch_size: int = 32,
batch_size_eval: Optional[int] = None,
shuffle: bool = True,
pin_memory: bool = True,
drop_last: bool = False,
*args: Any,
**kwargs: Any,
) -> None:
"""
Args:
            data_dir: path to the imagenet dataset directory
image_size: final image size
num_workers: how many data workers
batch_size: batch_size
shuffle: If true shuffles the data every epoch
pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before
returning them
drop_last: If true drops the last incomplete batch
"""
super().__init__(*args, **kwargs)
self.image_size = image_size
self.train_transforms = train_transforms
self.val_transforms = val_transforms
self.test_transforms = test_transforms
assert img_dtype in ['float32', 'float16', 'bfloat16']
        self.img_dtype = getattr(torch, img_dtype)
self.cache_val_dataset = cache_val_dataset
self.mixup = mixup
self.num_aug_repeats = num_aug_repeats
self.dims = (3, self.image_size, self.image_size)
self.data_dir = Path(data_dir).expanduser()
self.num_workers = num_workers
self.batch_size = batch_size
self.batch_size_eval = batch_size_eval if batch_size_eval is not None else self.batch_size
self.shuffle = shuffle
self.pin_memory = pin_memory
self.drop_last = drop_last
@property
def num_classes(self) -> int:
"""
Return:
1000
"""
return 1000
def _verify_splits(self, data_dir: str, split: str) -> None:
dirs = os.listdir(data_dir)
if split not in dirs:
raise FileNotFoundError(
f"a {split} Imagenet split was not found in {data_dir},"
f" make sure the folder contains a subfolder named {split}"
)
def prepare_data(self) -> None:
"""This method already assumes you have imagenet2012 downloaded. It validates the data using the meta.bin.
.. warning:: Please download imagenet on your own first.
"""
self._verify_splits(self.data_dir, "train")
self._verify_splits(self.data_dir, "val")
def setup(self, stage: Optional[str] = None) -> None:
"""Creates train, val, and test dataset."""
if stage == "fit" or stage is None:
train_transforms = (self.train_transform() if self.train_transforms is None
else self.train_transforms)
val_transforms = (self.val_transform() if self.val_transforms is None
else self.val_transforms)
if self.img_dtype is not torch.float32:
assert isinstance(train_transforms, transforms.Compose)
assert isinstance(val_transforms, transforms.Compose)
convert_dtype = transforms.Lambda(lambda x: x.to(dtype=self.img_dtype))
train_transforms.transforms.append(convert_dtype)
val_transforms.transforms.append(convert_dtype)
self.dataset_train = ImageFolder(self.data_dir / 'train', transform=train_transforms)
self.dataset_val = ImageFolder(self.data_dir / 'val', transform=val_transforms)
if stage == "test" or stage is None:
test_transforms = (self.val_transform() if self.test_transforms is None
else self.test_transforms)
if self.img_dtype is not torch.float32:
assert isinstance(test_transforms, transforms.Compose)
convert_dtype = transforms.Lambda(lambda x: x.to(dtype=self.img_dtype))
test_transforms.transforms.append(convert_dtype)
self.dataset_test = ImageFolder(self.data_dir / 'val', transform=test_transforms)
def train_transform(self) -> Callable:
"""The standard imagenet transforms.
.. code-block:: python
transforms.Compose([
transforms.RandomResizedCrop(self.image_size),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
),
])
"""
preprocessing = transforms.Compose(
[
transforms.RandomResizedCrop(self.image_size),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
imagenet_normalization(),
]
)
return preprocessing
def val_transform(self) -> Callable:
"""The standard imagenet transforms for validation.
.. code-block:: python
transforms.Compose([
transforms.Resize(self.image_size + 32),
transforms.CenterCrop(self.image_size),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
),
])
"""
preprocessing = transforms.Compose(
[
transforms.Resize(self.image_size + 32),
transforms.CenterCrop(self.image_size),
transforms.ToTensor(),
imagenet_normalization(),
]
)
return preprocessing
def train_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader:
""" The train dataloader """
if self.num_aug_repeats == 0:
shuffle = self.shuffle
sampler = None
else:
shuffle = False
from timm.data.distributed_sampler import RepeatAugSampler
sampler = RepeatAugSampler(self.dataset_train, num_repeats=self.num_aug_repeats)
return self._data_loader(self.dataset_train, batch_size=self.batch_size,
shuffle=shuffle, mixup=self.mixup, sampler=sampler)
def val_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]:
""" The val dataloader """
# If using RepeatAugment, we set trainer.replace_sampler_ddp=False, so we have to
# construct the DistributedSampler ourselves.
if not self.cache_val_dataset:
sampler = (DistributedSampler(self.dataset_val, shuffle=False, drop_last=self.drop_last)
if self.num_aug_repeats != 0 else None)
return self._data_loader(self.dataset_val, batch_size=self.batch_size_eval,
sampler=sampler)
else:
print('Caching val dataset')
sampler = (SequentialSampler(self.dataset_val) if self.trainer.world_size <= 1
else DistributedSampler(self.dataset_val, shuffle=False,
drop_last=self.drop_last))
indices = list(iter(sampler))
loader = DataLoader(self.dataset_val, batch_size=None, shuffle=False, sampler=sampler,
num_workers=self.num_workers, drop_last=self.drop_last)
batches = list(loader)
assert len(batches) == len(indices)
self.dataset_val = DictDataset(dict(zip(indices, batches)),
length=len(self.dataset_val))
sampler = (DistributedSampler(self.dataset_val, shuffle=False, drop_last=self.drop_last)
if self.num_aug_repeats != 0 else None)
return self._data_loader(self.dataset_val, batch_size=self.batch_size_eval,
sampler=sampler)
def test_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]:
""" The test dataloader """
sampler = (DistributedSampler(self.dataset_test, shuffle=False, drop_last=self.drop_last)
if self.num_aug_repeats != 0 else None)
return self._data_loader(self.dataset_test, batch_size=self.batch_size_eval, sampler=sampler)
def _data_loader(self, dataset: Dataset, batch_size: int, shuffle: bool = False,
mixup: Optional[Callable] = None, sampler=None) -> DataLoader:
collate_fn = ((lambda batch: mixup(*default_collate(batch))) if mixup is not None
else default_collate)
return DataLoader(
dataset,
collate_fn=collate_fn,
batch_size=batch_size,
shuffle=shuffle,
sampler=sampler,
num_workers=self.num_workers,
drop_last=self.drop_last,
pin_memory=self.pin_memory,
            persistent_workers=self.num_workers > 0  # persistent workers require num_workers > 0
)
class Imagenet21kPDataModule(ImagenetDataModule):
"""ImageNet-21k (winter 21) processed with https://github.com/Alibaba-MIIL/ImageNet21K
"""
@property
def num_classes(self) -> int:
"""
Return:
10450
"""
return 10450
# Adapted from https://github.com/huggingface/transformers/blob/master/examples/pytorch/language-modeling/run_clm.py
from itertools import chain
from pathlib import Path
import pickle
from typing import Any, List, Union
import subprocess
import mmap
from multiprocessing.shared_memory import SharedMemory
import numpy as np
import torch
from torch.utils.data.dataloader import DataLoader, Dataset
from transformers import AutoTokenizer
from datasets import load_dataset
from pytorch_lightning import LightningDataModule
from src.datamodules.datasets.lm_dataset import LMDataset
from src.datamodules.fault_tolerant_sampler import RandomFaultTolerantSampler
from src.datamodules.fault_tolerant_sampler import FaultTolerantDistributedSampler
from src.datamodules.datasets.detokenizer import DATASET_TOKENIZATION_REGISTRY
from src.utils.utils import get_logger
logger = get_logger()
# https://github.com/numpy/numpy/issues/18294
# Copied from https://numpy.org/doc/stable/user/basics.subclassing.html#slightly-more-realistic-example-attribute-added-to-existing-array
class SHMArray(np.ndarray):
def __new__(cls, input_array, shm=None):
obj = np.asarray(input_array).view(cls)
obj.shm = shm
return obj
def __array_finalize__(self, obj):
if obj is None: return
self.shm = getattr(obj, 'shm', None)
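# A minimal sketch of why SHMArray exists (illustration only): it pins the
# SharedMemory handle to the array, so the backing segment isn't garbage-
# collected while the array is still in use.
def _shmarray_example():
    shm = SharedMemory(create=True, size=8 * np.dtype(np.int64).itemsize)
    arr = np.ndarray((8,), dtype=np.int64, buffer=shm.buf)
    arr[:] = np.arange(8)
    wrapped = SHMArray(arr, shm=shm)   # wrapped.shm keeps the segment alive
    assert wrapped.sum() == 28 and wrapped.shm is shm
    del arr, wrapped                   # drop the buffer views first...
    shm.close()                        # ...then release and remove the segment
    shm.unlink()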
class LMDataModule(LightningDataModule):
def __init__(self, dataset_name, tokenizer_name, dataset_config_name=None, max_length=1024,
cache_dir=None, val_ratio=0.0005, val_split_seed=2357, add_eos=True,
detokenize=False, val_only=False, batch_size=32, batch_size_eval=None, num_workers=1,
shuffle=False, pin_memory=False, drop_last=False, fault_tolerant=False, ddp=False,
fast_forward_epochs=None, fast_forward_batches=None,
use_shmem=True):
super().__init__()
self.dataset_name = dataset_name
self.dataset_config_name = dataset_config_name
self.tokenizer_name = tokenizer_name
self.cache_dir = None if cache_dir is None else Path(cache_dir).expanduser()
self.max_length = max_length
self.val_ratio = val_ratio
self.val_split_seed = val_split_seed
self.val_only = val_only
self.add_eos = add_eos
self.detokenize = detokenize
self.batch_size = batch_size
self.batch_size_eval = batch_size_eval if batch_size_eval is not None else self.batch_size
self.num_workers = num_workers
self.shuffle = shuffle
self.pin_memory = pin_memory
self.drop_last = drop_last
if fault_tolerant:
assert self.shuffle
self.fault_tolerant = fault_tolerant
if ddp:
assert fault_tolerant
self.ddp = ddp
self.fast_forward_epochs = fast_forward_epochs
self.fast_forward_batches = fast_forward_batches
if self.fast_forward_epochs is not None or self.fast_forward_batches is not None:
assert ddp and fault_tolerant
self.use_shmem = use_shmem
if self.use_shmem:
assert cache_dir is not None
def prepare_data(self):
if self.cache_dir is None: # Just download the dataset
load_dataset(self.dataset_name, self.dataset_config_name)
else: # Process the dataset and save it
self.process_dataset()
def setup(self, stage=None):
if stage == 'test' and hasattr(self, 'dataset_test'):
return
concat_ids, self.tokenizer = self.process_dataset()
self.vocab_size = len(self.tokenizer)
# Create all splits
self.dataset_train, self.dataset_val, self.dataset_test = [
LMDataset(concat_ids[split], seq_len=self.max_length)
for split in ['train', 'validation', 'test']
]
def process_dataset(self):
cache_dir = None if self.cache_dir is None else self.cache_dir / self._cache_dir_name
if cache_dir is not None:
if cache_dir.is_dir():
return self._load_from_cache(cache_dir)
raw_datasets = load_dataset(self.dataset_name, self.dataset_config_name)
# https://github.com/stanford-crfm/mistral/blob/main/src/corpora/auto.py
if 'validation' not in raw_datasets:
assert "train" in raw_datasets, "You must have train in raw_datasets to make a validation raw_datasets"
raw_datasets = raw_datasets["train"].train_test_split(
test_size=self.val_ratio, seed=self.val_split_seed,
shuffle=True # Otherwise test will be at the end of the dataset
)
raw_datasets['validation'] = raw_datasets['test']
if self.val_only: # Should only be used for evaluation, not for training
raw_datasets['train'] = raw_datasets['validation']
# [2021-12-25] TD: Running the detokenizer on wikitext-103 makes ppl worse
# (GPT2-small val ppl after 10 epochs ~22 -> ~25)
# However, it's useful for zero-shot transfer from Openwebtext,
# as after detokenization it's closer to Openwebtext's format.
# https://github.com/stanford-crfm/mistral/issues/12
if self.detokenize:
if self.dataset_name in DATASET_TOKENIZATION_REGISTRY:
detokenizer = DATASET_TOKENIZATION_REGISTRY[self.dataset_name]
raw_datasets = raw_datasets.map(
lambda example: {'text': detokenizer(example['text'])},
num_proc=max(self.num_workers, 1),
desc='Running detokenizer on dataset'
)
tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, use_fast=True)
# Preprocessing the datasets.
# First we tokenize all the texts.
column_names = raw_datasets["train"].column_names
text_column_name = "text" if "text" in column_names else column_names[0]
# [2021-12-25] TD: For wikitext, don't need to add the EOS since each example already ends
# with '\n', and there are no other '\n' in the examples.
# assert all([t.count('\n') == 1 for t in raw_datasets['train']['text'] if t])
# Add EOS token to the end of the text if the text is not empty
# https://github.com/stanford-crfm/mistral/issues/91
# https://github.com/stanford-crfm/mistral/pull/98
if self.add_eos:
add_eos = lambda seq: (seq + tokenizer.eos_token) if seq else seq
add_eos_batched = lambda seqs: [add_eos(seq) for seq in seqs]
tokenize = lambda example: tokenizer(add_eos_batched(example[text_column_name]))
else:
tokenize = lambda example: tokenizer(example[text_column_name])
# tokenized_datasets = raw_datasets.map(
# tokenize,
# batched=True,
# num_proc=max(self.num_workers, 1),
# remove_columns=column_names,
# desc="Running tokenizer on dataset",
# )
dtype = np.uint16 if tokenizer.vocab_size < 64 * 1024 else np.int32
def tokenize_concat(examples):
# We just need 'input_ids', not 'attention_mask' (since it's all 1)
input_ids = np.fromiter(chain(*tokenize(examples)['input_ids']), dtype=dtype)
# Need to return a list since we're doing batched processing
return {'input_ids': [input_ids], 'len': [len(input_ids)]}
tokenized_datasets = raw_datasets.map(
tokenize_concat,
batched=True,
num_proc=max(self.num_workers, 1),
remove_columns=column_names,
desc="Running tokenizer on dataset",
)
if self.use_shmem:
# Concatenate all input_ids into an array in shared memory
def write_ids_to_shm(example, shm_name, array_len):
shm = SharedMemory(name=shm_name)
shm_arr = np.ndarray((array_len,), dtype=dtype, buffer=shm.buf)
start_idx = example['len_offset'] - len(example['input_ids'])
shm_arr[start_idx:example['len_offset']] = example['input_ids']
shm.close()
concat_ids = {}
for name, ds in tokenized_datasets.items():
tokenized_datasets[name] = ds.add_column('len_offset', np.cumsum(ds['len']))
array_len = tokenized_datasets[name][-1]['len_offset']
shm = SharedMemory(create=True, size=array_len * np.dtype(dtype).itemsize)
shm_name = shm.name
tokenized_datasets[name].map(
write_ids_to_shm,
fn_kwargs={'shm_name': shm_name, 'array_len': array_len},
batched=False,
num_proc=max(self.num_workers, 1),
desc="Concatenating examples",
)
shm_arr = np.ndarray((array_len,), dtype=dtype, buffer=shm.buf)
# We need to keep a reference to the shared memory, otherwise it gets garbage-collected
# when it goes out of scope, and that memory is gone.
# https://github.com/numpy/numpy/issues/18294
concat_ids[name] = SHMArray(shm_arr, shm=shm)
else:
# Use disk
concat_ids = {}
assert cache_dir is not None
cache_dir.mkdir(parents=True, exist_ok=True)
def write_ids_to_disk(example, filename):
with open(filename, 'r+b') as f:
mm = mmap.mmap(f.fileno(), 0)
start_idx = example['len_offset'] - len(example['input_ids'])
array_len = len(example['input_ids'])
arr = np.ndarray((array_len,), dtype=dtype, buffer=mm,
offset=np.dtype(dtype).itemsize * start_idx)
arr[:] = example['input_ids']
mm.flush()
for name, ds in tokenized_datasets.items():
tokenized_datasets[name] = ds.add_column('len_offset', np.cumsum(ds['len']))
array_len = tokenized_datasets[name][-1]['len_offset']
filename = cache_dir / f'{name}.bin'
# Need to create the file with this specific size first
# https://ostechnix.com/create-files-certain-size-linux/
subprocess.run(['truncate', '-s', str(array_len * np.dtype(dtype).itemsize),
str(filename)], check=True)
tokenized_datasets[name].map(
write_ids_to_disk,
fn_kwargs={'filename': filename},
batched=False,
num_proc=max(self.num_workers, 1),
desc="Concatenating examples",
)
concat_ids[name] = np.memmap(filename, dtype=dtype, mode='r', shape=(array_len,))
if cache_dir is not None:
self._save_to_cache(concat_ids, tokenizer, cache_dir)
if not self.use_shmem:
for name in concat_ids:
Path(cache_dir / f'{name}.bin').unlink()
return concat_ids, tokenizer
def _save_to_cache(self, concat_ids, tokenizer, cache_dir):
cache_dir.mkdir(parents=True, exist_ok=True)
logger.info(f'Saving to cache at {str(cache_dir)}')
for k, v in concat_ids.items():
np.save(cache_dir / f'{k}.npy', v)
with open(cache_dir / 'tokenizer.pkl', 'wb') as f:
pickle.dump(tokenizer, f)
def _load_from_cache(self, cache_dir):
assert cache_dir.is_dir()
logger.info(f'Load from cache at {str(cache_dir)}')
concat_ids = {split: np.load(cache_dir / f'{split}.npy', mmap_mode='r')
for split in ['train', 'validation', 'test']}
with open(cache_dir / 'tokenizer.pkl', 'rb') as f:
tokenizer = pickle.load(f)
return concat_ids, tokenizer
@property
def _cache_dir_name(self):
return f'tokenizer_name-{self.tokenizer_name}-val_ratio-{self.val_ratio}-val_split_seed-{self.val_split_seed}-add_eos-{self.add_eos}-detokenize-{self.detokenize}'
def train_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader:
""" The train dataloader """
if self.shuffle and self.fault_tolerant:
shuffle = False
sampler = (FaultTolerantDistributedSampler(self.dataset_train) if self.ddp
else RandomFaultTolerantSampler(self.dataset_train))
# TD [2022-08-06]: Only the DDP sampler supports fast-forwarding for now
# We assume that it's being resumed with the same number of GPUs
if self.ddp and self.fast_forward_epochs is not None and self.fast_forward_batches is not None:
sampler.load_state_dict({
'epoch': self.fast_forward_epochs,
'counter': self.fast_forward_batches * self.batch_size
})
else:
shuffle = self.shuffle
sampler = None
return self._data_loader(self.dataset_train, batch_size=self.batch_size,
shuffle=shuffle, sampler=sampler)
def val_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]:
""" The val dataloader """
return self._data_loader(self.dataset_val, batch_size=self.batch_size_eval)
def test_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]:
""" The test dataloader """
return self._data_loader(self.dataset_test, batch_size=self.batch_size_eval)
def _data_loader(self, dataset: Dataset, batch_size: int, shuffle: bool = False,
sampler=None) -> DataLoader:
return DataLoader(
dataset,
batch_size=batch_size,
num_workers=1, # Data is already in memory, we don't need many workers
shuffle=shuffle,
sampler=sampler,
drop_last=self.drop_last,
pin_memory=self.pin_memory,
# persistent_workers=True
)
def load_state_dict(self, checkpoint):
if self.fault_tolerant:
self.fast_forward_epochs = checkpoint['loops']['fit_loop']['epoch_progress']['current']['completed']
# TD [2022-08-07] ['epoch_loop.batch_progress']['total']['completed'] is 1 iteration
# behind, so we're using the optimizer's progress. This is set correctly in seq.py.
self.fast_forward_batches = checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['completed']
# At this point the train loader hasn't been constructed yet
import torch
from timm.data import Mixup
from timm.data.mixup import mixup_target
class TimmMixup(Mixup):
""" Wrap timm.data.Mixup that avoids the assert that batch size must be even.
"""
def __call__(self, x, target):
if self.mode == 'elem':
lam = self._mix_elem(x)
elif self.mode == 'pair':
# We move the assert from the beginning of the function to here
assert len(x) % 2 == 0, 'Batch size should be even when using this'
lam = self._mix_pair(x)
else:
lam = self._mix_batch(x)
target = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device)
return x, target
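# A minimal usage sketch (illustration only, assuming the timm version this
# file targets): in 'batch' mode an odd batch size is fine, which is why the
# even-batch assert above is only enforced in the 'pair' branch.
def _timm_mixup_example():
    mixup_fn = TimmMixup(mixup_alpha=0.8, num_classes=10, mode='batch')
    x = torch.randn(3, 3, 224, 224)          # odd batch size
    y = torch.randint(0, 10, (3,))
    x, y_soft = mixup_fn(x, y)
    assert y_soft.shape == (3, 10)           # targets become soft labels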
# Adapted from https://pytorch.org/docs/stable/_modules/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.html
# We divide by world_size first before converting to fp16, so it's safer.
from typing import Any, Callable
import torch
import torch.distributed as dist
def fp16_compress_hook(
process_group: dist.ProcessGroup, bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
"""
This DDP communication hook implements a simple gradient compression
approach that casts ``GradBucket`` tensor to half-precision floating-point format (``torch.float16``)
and then divides it by the process group size.
It allreduces those ``float16`` gradient tensors. Once compressed gradient
tensors are allreduced, the chained callback ``decompress`` casts it back to the input data type (such as ``float32``).
Example::
>>> ddp_model.register_comm_hook(process_group, fp16_compress_hook)
"""
group_to_use = process_group if process_group is not None else dist.group.WORLD
world_size = group_to_use.size()
# Divide first before converting to fp16
# Use out argument to fuse the division and the conversion.
compressed_tensor = torch.div(bucket.buffer(), world_size,
out=torch.empty_like(bucket.buffer(), dtype=torch.float16))
fut = dist.all_reduce(
compressed_tensor, group=group_to_use, async_op=True
).get_future()
def decompress(fut):
decompressed_tensor = bucket.buffer()
# Decompress in place to reduce the peak memory.
# See: https://github.com/pytorch/pytorch/issues/45968
decompressed_tensor.copy_(fut.value()[0])
return decompressed_tensor
# TODO: maybe have a backoff strategy: check if the buffer has inf / NaN, in that case
# resend with fp32?
return fut.then(decompress)
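# A minimal registration sketch (illustration only; assumes a default process
# group is already initialized and the model lives on the right device).
def _fp16_hook_example(model: torch.nn.Module):
    from torch.nn.parallel import DistributedDataParallel as DDP
    ddp_model = DDP(model)
    # state=None makes the hook fall back to dist.group.WORLD above
    ddp_model.register_comm_hook(state=None, hook=fp16_compress_hook)
    return ddp_model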
from typing import List, Optional
from pathlib import Path
import torch
import hydra
from omegaconf import OmegaConf, DictConfig
from pytorch_lightning import (
Callback,
LightningDataModule,
LightningModule,
Trainer,
seed_everything,
)
from pytorch_lightning.loggers import LightningLoggerBase
from src.utils import utils
log = utils.get_logger(__name__)
def remove_prefix(text: str, prefix: str):
    if text.startswith(prefix):
        return text[len(prefix):]
    return text
def load_checkpoint(path, device='cpu'):
path = Path(path).expanduser()
if path.is_dir():
path /= 'last.ckpt'
# dst = f'cuda:{torch.cuda.current_device()}'
log.info(f'Loading checkpoint from {str(path)}')
state_dict = torch.load(path, map_location=device)
# T2T-ViT checkpoint is nested in the key 'state_dict_ema'
if state_dict.keys() == {'state_dict_ema'}:
state_dict = state_dict['state_dict_ema']
# Swin checkpoint is nested in the key 'model'
if state_dict.keys() == {'model'}:
state_dict = state_dict['model']
# Lightning checkpoint contains extra stuff, we only want the model state dict
if 'pytorch-lightning_version' in state_dict:
state_dict = {remove_prefix(k, 'model.'): v for k, v in state_dict['state_dict'].items()}
return state_dict
def evaluate(config: DictConfig) -> None:
    """Evaluate a trained model: instantiate the model (loading weights from a
    checkpoint) and the datamodule from the config, then run validation and/or
    test with a Lightning Trainer.
    """
# load model from checkpoint
# model __init__ parameters will be loaded from ckpt automatically
# you can also pass some parameter explicitly to override it
# We want to add fields to config so need to call OmegaConf.set_struct
OmegaConf.set_struct(config, False)
# load model
checkpoint_type = config.eval.get('checkpoint_type', 'pytorch')
if checkpoint_type not in ['lightning', 'pytorch']:
        raise NotImplementedError(f'checkpoint_type {checkpoint_type} not supported')
if checkpoint_type == 'lightning':
cls = hydra.utils.get_class(config.task._target_)
model = cls.load_from_checkpoint(checkpoint_path=config.eval.ckpt)
elif checkpoint_type == 'pytorch':
model_cfg = config.model_pretrained if 'model_pretrained' in config else None
trained_model: LightningModule = hydra.utils.instantiate(config.task, cfg=config,
model_cfg=model_cfg,
_recursive_=False)
if 'ckpt' in config.eval:
load_return = trained_model.model.load_state_dict(
load_checkpoint(config.eval.ckpt, device=trained_model.device), strict=False
)
log.info(load_return)
if 'model_pretrained' in config:
...
else:
model = trained_model
datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule)
# datamodule: LightningDataModule = model._datamodule
datamodule.prepare_data()
datamodule.setup()
# print model hyperparameters
log.info(f'Model hyperparameters: {model.hparams}')
# Init Lightning callbacks
callbacks: List[Callback] = []
if "callbacks" in config:
for _, cb_conf in config["callbacks"].items():
if cb_conf is not None and "_target_" in cb_conf:
log.info(f"Instantiating callback <{cb_conf._target_}>")
callbacks.append(hydra.utils.instantiate(cb_conf))
# Init Lightning loggers
logger: List[LightningLoggerBase] = []
if "logger" in config:
for _, lg_conf in config["logger"].items():
if lg_conf is not None and "_target_" in lg_conf:
log.info(f"Instantiating logger <{lg_conf._target_}>")
logger.append(hydra.utils.instantiate(lg_conf))
# Init Lightning trainer
log.info(f"Instantiating trainer <{config.trainer._target_}>")
trainer: Trainer = hydra.utils.instantiate(
config.trainer, callbacks=callbacks, logger=logger, _convert_="partial"
)
# Evaluate the model
log.info("Starting evaluation!")
if config.eval.get('run_val', True):
trainer.validate(model=model, datamodule=datamodule)
if config.eval.get('run_test', True):
trainer.test(model=model, datamodule=datamodule)
# Make sure everything closed properly
log.info("Finalizing!")
utils.finish(
config=config,
model=model,
datamodule=datamodule,
trainer=trainer,
callbacks=callbacks,
logger=logger,
)