"vscode:/vscode.git/clone" did not exist on "d38f5c8ce93e0c33d5d2a87b1b1a096c2179a9e2"
Commit 0bf5e500 authored by Tri Dao

Release training code

parent 9bc63d1e
from typing import Callable
import dotenv
import hydra
from omegaconf import OmegaConf, DictConfig
# load environment variables from `.env` file if it exists
# recursively searches for `.env` in all folders starting from work dir
dotenv.load_dotenv(override=True)
OmegaConf.register_new_resolver('eval', eval)
OmegaConf.register_new_resolver('div_up', lambda x, y: (x + y - 1) // y)
# Delay evaluation of ${datamodule:...} interpolations until the datamodule exists,
# so for now the resolver simply returns the same interpolation string unchanged.
OmegaConf.register_new_resolver('datamodule', lambda attr: '${datamodule:' + str(attr) + '}')
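# Illustrative sketch of what the resolvers above enable in a config, assuming
# OmegaConf >= 2.1 resolver-argument syntax; the field names below are made up.
def _resolver_example():
    cfg = OmegaConf.create({
        'n_tokens': 1000,
        'block_size': 256,
        # ceiling division via the 'div_up' resolver
        'n_blocks': '${div_up:${n_tokens},${block_size}}',
        # arbitrary arithmetic via the 'eval' resolver (note the quoted expression)
        'n_tokens_x2': '${eval:"2 * ${n_tokens}"}',
    })
    assert cfg.n_blocks == 4 and cfg.n_tokens_x2 == 2000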
# Turn on TensorFloat32
import torch.backends
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
def dictconfig_filter_key(d: DictConfig, fn: Callable) -> DictConfig:
"""Only keep keys where fn(key) is True. Support nested DictConfig.
"""
# d.items_ex(resolve=False) could be used instead of d.items() to keep the
# ${datamodule:foo} interpolations unresolved for now.
return DictConfig({k: dictconfig_filter_key(v, fn) if isinstance(v, DictConfig) else v
# for k, v in d.items_ex(resolve=False) if fn(k)})
for k, v in d.items() if fn(k)})
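# Illustrative sketch of the '__' filtering that main() applies below; the config values are made up.
def _filter_example():
    cfg = OmegaConf.create({'lr': 1e-3, '__base_width': 64,
                            'model': {'d_model': 512, '__scratch': 0}})
    filtered = dictconfig_filter_key(cfg, lambda k: not k.startswith('__'))
    assert OmegaConf.to_container(filtered) == {'lr': 1e-3, 'model': {'d_model': 512}}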
@hydra.main(config_path="configs/", config_name="config.yaml")
def main(config: DictConfig):
# Remove config keys that start with '__'. These are meant to be used only in computing
# other entries in the config.
config = dictconfig_filter_key(config, lambda k: not k.startswith('__'))
# Imports should be nested inside @hydra.main to optimize tab completion
# Read more here: https://github.com/facebookresearch/hydra/issues/934
from src.train import train
from src.eval import evaluate
from src.utils import utils
# A couple of optional utilities:
# - disabling python warnings
# - forcing debug-friendly configuration
# - verifying experiment name is set when running in experiment mode
# You can safely get rid of this line if you don't want those
utils.extras(config)
# Pretty print config using Rich library
if config.get("print_config"):
utils.print_config(config, resolve=True)
# Train model
mode = config.get('mode', 'train')
if mode not in ['train', 'eval']:
raise NotImplementedError(f'mode {mode} not supported')
if mode == 'train':
return train(config)
elif mode == 'eval':
return evaluate(config)
if __name__ == "__main__":
main()
import pytorch_lightning as pl
from pytorch_lightning import Callback
from pytorch_lightning.utilities import rank_zero_only
import torch
from torch.autograd import grad
class CausalityMonitor(Callback):
r"""Monitor causality of a model by tracking gradient leakage forward in time.
In a fully causal model, dy[k]/du[s] ~= 0 for all k < s.
Args:
seq_len (int): Length of the sequence to monitor.
input_dim (int): Dimension of the input to monitor. If 0, the callback assumes
the task to be language modeling, and skips the embedding layer. If > 0,
input_dim is interpreted as the input channel dimension, i.e. D with
dummy input of dimension [B, L, D].
Notes:
This callback assumes that `pl_module.model` has a `net` or `s4seq` attribute,
indicating the primary model to monitor. For LMs, `net` or `s4seq` should
be after the embedding layer.
"""
def __init__(self, seq_len: int = 10, input_dim: int = 0):
super().__init__()
self.seq_len = seq_len
self.input_dim = input_dim
@rank_zero_only
def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
model = pl_module.model
with torch.enable_grad():
if self.input_dim == 0:
# [MP] LongTensors cannot have gradients - we start from post
# embedding in the LM case
input_dim = model.d_model
x = torch.randn((2, self.seq_len, input_dim), \
requires_grad=True).to(pl_module.device)
# [DF] HACK: we need to get the layer that comes after the embedding
if hasattr(model, 'net'):
y = model.net(x)
else:
y = model.s4seq(x)
else:
x = torch.randn(1, self.seq_len, self.input_dim, \
requires_grad=True).to(pl_module.device)
y = model(x)
stats = {}
for i in range(self.seq_len):
# total gradients flowing from y_i to x
g = grad(y[0,0,i].mean(), x, retain_graph=True, allow_unused=True)[0]
g = g[0,i+1:,:].abs().mean()
stats[f'stats/causality_{i}'] = g.item()
if trainer.loggers is not None:
for logger in trainer.loggers:
logger.log_metrics(stats, step=trainer.global_step)
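# Minimal usage sketch, assuming a LightningModule `task` whose `.model` has a `d_model`
# attribute and a `net` (or `s4seq`) submodule as described in the docstring above.
def _causality_monitor_example(task, datamodule):
    trainer = pl.Trainer(max_epochs=1, callbacks=[CausalityMonitor(seq_len=16)])
    trainer.fit(task, datamodule=datamodule)
    # For a causal model, every logged stats/causality_{i} value should be ~0.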
# Inspired by https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/callbacks/stochastic_weight_avg.py
# https://github.com/PyTorchLightning/Lightning-Bolts/blob/master/pl_bolts/callbacks/byol_updates.py
# https://forums.pytorchlightning.ai/t/adopting-exponential-moving-average-ema-for-pl-pipeline/488/2
# https://github.com/PyTorchLightning/pytorch-lightning/issues/8100
from typing import Dict, Any
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.parsing import AttributeDict
from pytorch_lightning.utilities.types import STEP_OUTPUT
from src.utils.ema import ExponentialMovingAverage
class EMACallback(Callback):
"""TD [2021-08-31]: saving and loading from checkpoint should work.
"""
def __init__(self, decay: float, use_num_updates: bool = True):
"""
decay: The exponential decay.
use_num_updates: Whether to use number of updates when computing
averages.
"""
super().__init__()
self.decay = decay
self.use_num_updates = use_num_updates
self.ema = None
def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
# It's possible that we already loaded EMA from the checkpoint
if self.ema is None:
self.ema = ExponentialMovingAverage([p for p in pl_module.parameters() if p.requires_grad],
decay=self.decay, use_num_updates=self.use_num_updates)
# Ideally we want on_after_optimizer_step but pytorch-lightning doesn't have it
# We only want to update when parameters are changing.
# Because of gradient accumulation, this doesn't happen every training step.
# https://github.com/PyTorchLightning/pytorch-lightning/issues/11688
def on_train_batch_end(
self,
trainer: "pl.Trainer",
pl_module: "pl.LightningModule",
outputs: STEP_OUTPUT,
batch: Any,
batch_idx: int,
) -> None:
if (batch_idx + 1) % trainer.accumulate_grad_batches == 0:
self.ema.update()
def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
# During the initial validation we don't have self.ema yet
if self.ema is not None:
self.ema.store()
self.ema.copy_to()
def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if self.ema is not None:
self.ema.restore()
def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if self.ema is not None:
self.ema.store()
self.ema.copy_to()
def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if self.ema is not None:
self.ema.restore()
def on_save_checkpoint(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
) -> Dict[str, Any]:
return self.ema.state_dict()
def on_load_checkpoint(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule",
checkpoint: Dict[str, Any]
) -> None:
if self.ema is None:
self.ema = ExponentialMovingAverage([p for p in pl_module.parameters() if p.requires_grad],
decay=self.decay, use_num_updates=self.use_num_updates)
self.ema.load_state_dict(checkpoint)
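# src.utils.ema is not shown here; the callback above only relies on the interface sketched
# below. This is a minimal stand-in implementing the classic EMA update, not the actual class.
import torch

class _EMASketch:
    def __init__(self, parameters, decay, use_num_updates=True):
        self.params = list(parameters)
        self.decay = decay
        self.num_updates = 0 if use_num_updates else None
        self.shadow_params = [p.detach().clone() for p in self.params]
        self.backup_params = None

    def update(self):
        decay = self.decay
        if self.num_updates is not None:
            self.num_updates += 1
            # Warm up the decay so early averages aren't dominated by the random init
            decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates))
        with torch.no_grad():
            for s, p in zip(self.shadow_params, self.params):
                s.mul_(decay).add_(p, alpha=1 - decay)

    def store(self):
        self.backup_params = [p.detach().clone() for p in self.params]

    def copy_to(self):
        with torch.no_grad():
            for s, p in zip(self.shadow_params, self.params):
                p.copy_(s)

    def restore(self):
        with torch.no_grad():
            for b, p in zip(self.backup_params, self.params):
                p.copy_(b)

    def state_dict(self):
        return {'decay': self.decay, 'num_updates': self.num_updates,
                'shadow_params': self.shadow_params}

    def load_state_dict(self, state_dict):
        self.decay = state_dict['decay']
        self.num_updates = state_dict['num_updates']
        self.shadow_params = state_dict['shadow_params']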
# Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/benchmark.py
from typing import Any, List, Sequence
import torch
from pytorch_lightning import Callback, Trainer, LightningModule
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.parsing import AttributeDict
from src.utils.flops import has_deepspeed_profiling, has_fvcore_profiling
from src.utils.flops import profile_deepspeed, profile_fvcore
class FlopCount(Callback):
"""Counter the number of FLOPs used by the model
"""
def __init__(self, profilers: List[str] = ['fvcore', 'deepspeed'],
input_size: tuple = (3, 224, 224), input_dtype=torch.float32, device=None):
if not isinstance(profilers, Sequence):
profilers = [profilers]
if any(p not in ['fvcore', 'deepspeed'] for p in profilers):
raise NotImplementedError('Only support fvcore and deepspeed profilers')
if 'fvcore' in profilers and not has_fvcore_profiling:
raise ImportError('fvcore is not installed. Install it by running `pip install fvcore`')
elif 'deepspeed' in profilers and not has_deepspeed_profiling:
raise ImportError('deepspeed is not installed')
super().__init__()
self.profilers = profilers
self.input_size = tuple(input_size)
self.input_dtype = input_dtype
self.device = device
@rank_zero_only
def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
if 'fvcore' in self.profilers:
_, macs, _, acts = profile_fvcore(pl_module.to(self.device), input_size=self.input_size,
input_dtype=self.input_dtype, detailed=True)
trainer.logger.log_hyperparams({'GMACs': macs * 1e-9, 'MActs': acts * 1e-6})
if 'deepspeed' in self.profilers:
macs, _ = profile_deepspeed(pl_module.to(self.device), input_size=self.input_size,
input_dtype=self.input_dtype, detailed=True)
if 'fvcore' not in self.profilers: # fvcore's MACs seem more accurate
trainer.logger.log_hyperparams({'GMACs': macs * 1e-9})
import torch
from pytorch_lightning import Callback, Trainer, LightningModule
import logging
log = logging.getLogger(__name__) # We want a logger for each process, not just the rank 0
def l2_promote():
import ctypes
_libcudart = ctypes.CDLL('libcudart.so')
# Set device limit on the current device
# cudaLimitMaxL2FetchGranularity = 0x05
pValue = ctypes.cast((ctypes.c_int*1)(), ctypes.POINTER(ctypes.c_int))
_libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
_libcudart.cudaDeviceGetLimit(pValue, ctypes.c_int(0x05))
assert pValue.contents.value == 128
def set_affinity(trainer):
try:
from src.utils.gpu_affinity import set_affinity
nproc_per_node = torch.cuda.device_count()
affinity = set_affinity(trainer.local_rank, nproc_per_node, 'socket_unique_continuous')
log.info(f'{trainer.local_rank}: thread affinity: {affinity}')
# TD [2022-05-07] Somehow calling this causes GPU 0 to allocate extra ~800MB of memory per
# number of GPUs (e.g., 6.4GB of extra memory in a 8-GPU setup). H/t Dan.
# l2_promote()
except:
pass
class GpuAffinity(Callback):
"""Set GPU affinity and increase the L2 fetch granularity.
Adapted from https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/LanguageModeling/Transformer-XL
"""
def setup(self, trainer: Trainer, pl_module: LightningModule, stage=None) -> None:
set_affinity(trainer)
# Adapted from https://github.com/Lightning-AI/lightning/blob/master/src/pytorch_lightning/callbacks/lr_monitor.py.
from typing import Any
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.strategies import DeepSpeedStrategy
class LossScaleMonitor(Callback):
"""Monitor the loss scale for AMP (fp16).
"""
# Use on_before_optimizer_step instead of on_train_batch_start since there might be
# gradient accumulation and we only care about the loss scale when it could change (i.e.,
# optimizer.step).
@rank_zero_only
def on_before_optimizer_step(self, trainer: Trainer, *args: Any, **kwargs: Any) -> None:
if not trainer._logger_connector.should_update_logs:
return
stats = {}
if isinstance(trainer.strategy, DeepSpeedStrategy):
stats = {'scaler/scale': trainer.model.optimizer.loss_scale}
if hasattr(trainer, 'precision_plugin') and hasattr(trainer.precision_plugin, 'scaler'):
scaler = trainer.precision_plugin.scaler
if scaler is not None:
stats = {
'scaler/scale': scaler.get_scale(),
'scaler/growth_tracker': scaler._get_growth_tracker(),
}
if stats and trainer.loggers is not None:
for logger in trainer.loggers:
logger.log_metrics(stats, step=trainer.fit_loop.epoch_loop._batches_that_stepped)
# Adapted from https://github.com/Lightning-AI/lightning/blob/master/src/pytorch_lightning/callbacks/fault_tolerance.py
from typing import Any
from pathlib import Path
import pytorch_lightning as pl
class ModelCheckpointMine(pl.callbacks.model_checkpoint.ModelCheckpoint):
def __init__(self, *args, fault_tolerant=False, **kwargs):
super().__init__(*args, **kwargs)
self.fault_tolerant = fault_tolerant
def on_exception(self, trainer: "pl.Trainer", *_: Any, **__: Any) -> None:
if self.fault_tolerant:
# overwrite if necessary
trainer.save_checkpoint(str(Path(self.dirpath) / '.pl_auto_save.ckpt'))
# def teardown(self, trainer: "pl.Trainer", *_: Any, **__: Any) -> None:
# if self.fault_tolerant:
# trainer.strategy.remove_checkpoint(str(Path(self.dirpath) / '.pl_auto_save.ckpt'))
# TD [2022-07-17] I was trying to make resuming from standard checkpoint fault-tolerant.
# However, when it resumes it's off by 1 iteration. My attempt to fix it in seq.py (below) didn't work.
# So I decided to just copy _FaultToleranceCheckpoint and just save on_exception.
# def on_save_checkpoint(self, checkpoint):
# # TD [2022-07-12] The "completed" counter is off by 1 so when it resumes
# # it's off by 1 iteration. However, the data is still off by 1 iteration, probably
# # because the dataloader_state_dict['counter'] is off by @batch_size, and idk how
# # to fix it cleanly.
# checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['total']['completed'] += 1
# checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['completed'] += 1
# checkpoint['loops']['fit_loop']['epoch_loop.state_dict']['_batches_that_stepped'] += 1
# checkpoint['loops']['fit_loop']['epoch_loop.state_dict']['dataloader_state_dict'][0]['state'][0]['num_batches_fetched'] += 1
# Inspired by https://github.com/Lightning-AI/lightning/blob/master/src/pytorch_lightning/utilities/grads.py
# However, they compute grad at every iteration (I think), and the .item() calls incur a lot of overhead
# (6-7% slow down on GPT-2 small). Instead we only compute for iterations where we need to log, and don't
# call .item() explicitly.
from typing import Any
from collections import OrderedDict
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.strategies import DeepSpeedStrategy
import torch
import torch.nn as nn
try:
from apex.contrib.layer_norm import FastLayerNorm
except ImportError:
FastLayerNorm = None
class NormMonitor(Callback):
"""Monitor the scales of weights and gradients.
"""
def __init__(self, layer_norm_only: bool = False):
super().__init__()
self.layer_norm_only = layer_norm_only
# Use on_before_optimizer_step instead of on_train_batch_start since there might be
# gradient accumulation and we only care about scale when it could change (i.e., optimizer.step).
@rank_zero_only
def on_before_optimizer_step(self, trainer: Trainer, pl_module, *args: Any, **kwargs: Any) -> None:
if not trainer._logger_connector.should_update_logs:
return
model = pl_module.model
named_parameters = {}
if self.layer_norm_only:
ln_modules = (nn.LayerNorm, nn.Embedding)
if FastLayerNorm is not None:
ln_modules += (FastLayerNorm,)
for mn, m in model.named_modules():
if isinstance(m, ln_modules):
for pn, p in m.named_parameters():
fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
named_parameters[fpn] = p
else:
named_parameters = dict(model.named_parameters())
if isinstance(trainer.strategy, DeepSpeedStrategy):
loss_scale = trainer.model.optimizer.loss_scale
else:
loss_scale = 1.0
stats = {}
param_l1_norm, grad_l1_norm = [], []
for param_name, param in named_parameters.items():
param_abs = param.abs()
param_abs_mean = param_abs.mean(dtype=torch.float32)
stats[f'stats/{param_name}_max'] = param_abs.max()
stats[f'stats/{param_name}_mean'] = param_abs_mean
param_l1_norm.append(param_abs_mean * param.numel())
if param.grad is not None:
# If using AMP, gradient is already unscaled by the AMP loss scaler at this point
# https://github.com/Lightning-AI/lightning/pull/9606
# However, if using DeepSpeed, we need to scale it ourselves
param_grad_abs = param.grad.abs()
param_grad_abs_mean = param_grad_abs.mean(dtype=torch.float32) / loss_scale
stats[f'stats/{param_name}_grad_max'] = param_grad_abs.max() / loss_scale
stats[f'stats/{param_name}_grad_mean'] = param_grad_abs_mean
grad_l1_norm.append(param_grad_abs_mean * param.grad.numel())
stats['total_param_l1_norm'] = torch.stack(param_l1_norm).sum()
if grad_l1_norm:
stats['total_grad_l1_norm'] = torch.stack(grad_l1_norm).sum()
# Sort by params name
stats = OrderedDict(sorted(stats.items()))
if trainer.loggers is not None:
for logger in trainer.loggers:
logger.log_metrics(stats, step=trainer.fit_loop.epoch_loop._batches_that_stepped)
from typing import Any
from pytorch_lightning import Callback, Trainer, LightningModule
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.parsing import AttributeDict
class ParamsLog(Callback):
"""Log the number of parameters of the model
"""
def __init__(self, total_params_log: bool = True, trainable_params_log: bool = True,
non_trainable_params_log: bool = True):
super().__init__()
self._log_stats = AttributeDict(
{
'total_params_log': total_params_log,
'trainable_params_log': trainable_params_log,
'non_trainable_params_log': non_trainable_params_log,
}
)
@rank_zero_only
def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
logs = {}
if self._log_stats.total_params_log:
logs["model/params_total"] = sum(p.numel() for p in pl_module.parameters())
if self._log_stats.trainable_params_log:
logs["model/params_trainable"] = sum(p.numel() for p in pl_module.parameters()
if p.requires_grad)
if self._log_stats.non_trainable_params_log:
logs["model/params_not_trainable"] = sum(p.numel() for p in pl_module.parameters()
if not p.requires_grad)
if trainer.logger is not None:
trainer.logger.log_hyperparams(logs)
# Adapted from https://pytorch-lightning.readthedocs.io/en/latest/_modules/pytorch_lightning/callbacks/gpu_stats_monitor.html#GPUStatsMonitor
# We only need the speed monitoring, not the GPU monitoring
import time
from typing import Any
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.parsing import AttributeDict
from pytorch_lightning.utilities.types import STEP_OUTPUT
class SpeedMonitor(Callback):
"""Monitor the speed of each step and each epoch.
"""
def __init__(self, intra_step_time: bool = True, inter_step_time: bool = True,
epoch_time: bool = True, verbose=False):
super().__init__()
self._log_stats = AttributeDict(
{
'intra_step_time': intra_step_time,
'inter_step_time': inter_step_time,
'epoch_time': epoch_time,
}
)
self.verbose = verbose
def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self._snap_epoch_time = None
def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self._snap_intra_step_time = None
self._snap_inter_step_time = None
self._snap_epoch_time = time.time()
def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self._snap_inter_step_time = None
def on_test_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self._snap_inter_step_time = None
@rank_zero_only
def on_train_batch_start(
self,
trainer: "pl.Trainer",
pl_module: "pl.LightningModule",
batch: Any,
batch_idx: int,
) -> None:
if self._log_stats.intra_step_time:
self._snap_intra_step_time = time.time()
if not trainer._logger_connector.should_update_logs:
return
logs = {}
if self._log_stats.inter_step_time and self._snap_inter_step_time:
# First log at beginning of second step
logs["time/inter_step (ms)"] = (time.time() - self._snap_inter_step_time) * 1000
if trainer.logger is not None:
trainer.logger.log_metrics(logs, step=trainer.global_step)
@rank_zero_only
def on_train_batch_end(
self,
trainer: "pl.Trainer",
pl_module: "pl.LightningModule",
outputs: STEP_OUTPUT,
batch: Any,
batch_idx: int,
) -> None:
if self._log_stats.inter_step_time:
self._snap_inter_step_time = time.time()
if self.verbose and self._log_stats.intra_step_time and self._snap_intra_step_time:
pl_module.print(f"time/intra_step (ms): {(time.time() - self._snap_intra_step_time) * 1000}")
if not trainer._logger_connector.should_update_logs:
return
logs = {}
if self._log_stats.intra_step_time and self._snap_intra_step_time:
logs["time/intra_step (ms)"] = (time.time() - self._snap_intra_step_time) * 1000
if trainer.logger is not None:
trainer.logger.log_metrics(logs, step=trainer.global_step)
@rank_zero_only
def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule",) -> None:
logs = {}
if self._log_stats.epoch_time and self._snap_epoch_time:
logs["time/epoch (s)"] = time.time() - self._snap_epoch_time
if trainer.logger is not None:
trainer.logger.log_metrics(logs, step=trainer.global_step)
import subprocess
from pathlib import Path
from typing import List
import matplotlib.pyplot as plt
import seaborn as sn
import torch
import wandb
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.loggers import LoggerCollection, WandbLogger
from pytorch_lightning.utilities import rank_zero_only
from sklearn import metrics
from sklearn.metrics import f1_score, precision_score, recall_score
def get_wandb_logger(trainer: Trainer) -> WandbLogger:
"""Safely get Weights&Biases logger from Trainer."""
if trainer.fast_dev_run:
raise Exception(
"Cannot use wandb callbacks since pytorch lightning disables loggers in `fast_dev_run=true` mode."
)
if isinstance(trainer.logger, WandbLogger):
return trainer.logger
if isinstance(trainer.logger, LoggerCollection):
for logger in trainer.logger:
if isinstance(logger, WandbLogger):
return logger
raise Exception(
"You are using wandb related callback, but WandbLogger was not found for some reason..."
)
class WatchModel(Callback):
"""Make wandb watch model at the beginning of the run."""
def __init__(self, log: str = "gradients", log_freq: int = 100):
self.log = log
self.log_freq = log_freq
@rank_zero_only
def on_train_start(self, trainer, pl_module):
logger = get_wandb_logger(trainer=trainer)
logger.watch(model=trainer.model, log=self.log, log_freq=self.log_freq)
class UploadCodeAsArtifact(Callback):
"""Upload all code files to wandb as an artifact, at the beginning of the run."""
def __init__(self, code_dir: str, use_git: bool = True):
"""
Args:
code_dir: the code directory
use_git: if using git, then upload all files that are not ignored by git.
if not using git, then upload all '*.py' files
"""
self.code_dir = code_dir
self.use_git = use_git
@rank_zero_only
def on_train_start(self, trainer, pl_module):
logger = get_wandb_logger(trainer=trainer)
experiment = logger.experiment
code = wandb.Artifact("project-source", type="code")
if self.use_git:
# get .git folder
# https://alexwlchan.net/2020/11/a-python-function-to-ignore-a-path-with-git-info-exclude/
git_dir_path = Path(
subprocess.check_output(["git", "rev-parse", "--git-dir"]).strip().decode("utf8")
).resolve()
for path in Path(self.code_dir).resolve().rglob("*"):
if (
path.is_file()
# ignore files in .git
and not str(path).startswith(str(git_dir_path)) # noqa: W503
# ignore files ignored by git
and ( # noqa: W503
subprocess.run(["git", "check-ignore", "-q", str(path)]).returncode == 1
)
):
code.add_file(str(path), name=str(path.relative_to(self.code_dir)))
else:
for path in Path(self.code_dir).resolve().rglob("*.py"):
code.add_file(str(path), name=str(path.relative_to(self.code_dir)))
experiment.log_artifact(code)
class UploadCheckpointsAsArtifact(Callback):
"""Upload checkpoints to wandb as an artifact, at the end of run."""
def __init__(self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = False):
self.ckpt_dir = ckpt_dir
self.upload_best_only = upload_best_only
@rank_zero_only
def on_keyboard_interrupt(self, trainer, pl_module):
self.on_train_end(trainer, pl_module)
@rank_zero_only
def on_train_end(self, trainer, pl_module):
logger = get_wandb_logger(trainer=trainer)
experiment = logger.experiment
ckpts = wandb.Artifact("experiment-ckpts", type="checkpoints")
if self.upload_best_only:
ckpts.add_file(trainer.checkpoint_callback.best_model_path)
else:
for path in Path(self.ckpt_dir).rglob("*.ckpt"):
ckpts.add_file(str(path))
experiment.log_artifact(ckpts)
class LogConfusionMatrix(Callback):
"""Generate confusion matrix every epoch and send it to wandb.
Expects validation step to return predictions and targets.
"""
def __init__(self):
self.preds = []
self.targets = []
self.ready = True
def on_sanity_check_start(self, trainer, pl_module) -> None:
self.ready = False
def on_sanity_check_end(self, trainer, pl_module):
"""Start executing this callback only after all validation sanity checks end."""
self.ready = True
def on_validation_batch_end(
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
):
"""Gather data from single batch."""
if self.ready:
self.preds.append(outputs["preds"])
self.targets.append(outputs["targets"])
def on_validation_epoch_end(self, trainer, pl_module):
"""Generate confusion matrix."""
if self.ready:
logger = get_wandb_logger(trainer)
experiment = logger.experiment
preds = torch.cat(self.preds).cpu().numpy()
targets = torch.cat(self.targets).cpu().numpy()
confusion_matrix = metrics.confusion_matrix(y_true=targets, y_pred=preds)
# set figure size
plt.figure(figsize=(14, 8))
# set labels size
sn.set(font_scale=1.4)
# set font size
sn.heatmap(confusion_matrix, annot=True, annot_kws={"size": 8}, fmt="g")
# names should be unique or else charts from different experiments in wandb will overlap
experiment.log({f"confusion_matrix/{experiment.name}": wandb.Image(plt)}, commit=False)
# according to wandb docs this should also work but it crashes
# experiment.log(f{"confusion_matrix/{experiment.name}": plt})
# reset plot
plt.clf()
self.preds.clear()
self.targets.clear()
class LogF1PrecRecHeatmap(Callback):
"""Generate f1, precision, recall heatmap every epoch and send it to wandb.
Expects validation step to return predictions and targets.
"""
def __init__(self, class_names: List[str] = None):
self.preds = []
self.targets = []
self.ready = True
def on_sanity_check_start(self, trainer, pl_module):
self.ready = False
def on_sanity_check_end(self, trainer, pl_module):
"""Start executing this callback only after all validation sanity checks end."""
self.ready = True
def on_validation_batch_end(
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
):
"""Gather data from single batch."""
if self.ready:
self.preds.append(outputs["preds"])
self.targets.append(outputs["targets"])
def on_validation_epoch_end(self, trainer, pl_module):
"""Generate f1, precision and recall heatmap."""
if self.ready:
logger = get_wandb_logger(trainer=trainer)
experiment = logger.experiment
preds = torch.cat(self.preds).cpu().numpy()
targets = torch.cat(self.targets).cpu().numpy()
f1 = f1_score(targets, preds, average=None)
r = recall_score(targets, preds, average=None)
p = precision_score(targets, preds, average=None)
data = [f1, p, r]
# set figure size
plt.figure(figsize=(14, 3))
# set labels size
sn.set(font_scale=1.2)
# set font size
sn.heatmap(
data,
annot=True,
annot_kws={"size": 10},
fmt=".3f",
yticklabels=["F1", "Precision", "Recall"],
)
# names should be unique or else charts from different experiments in wandb will overlap
experiment.log({f"f1_p_r_heatmap/{experiment.name}": wandb.Image(plt)}, commit=False)
# reset plot
plt.clf()
self.preds.clear()
self.targets.clear()
class LogImagePredictions(Callback):
"""Logs a validation batch and their predictions to wandb.
Example adapted from:
https://wandb.ai/wandb/wandb-lightning/reports/Image-Classification-using-PyTorch-Lightning--VmlldzoyODk1NzY
"""
def __init__(self, num_samples: int = 8):
super().__init__()
self.num_samples = num_samples
self.ready = True
def on_sanity_check_start(self, trainer, pl_module):
self.ready = False
def on_sanity_check_end(self, trainer, pl_module):
"""Start executing this callback only after all validation sanity checks end."""
self.ready = True
def on_validation_epoch_end(self, trainer, pl_module):
if self.ready:
logger = get_wandb_logger(trainer=trainer)
experiment = logger.experiment
# get a validation batch from the validation data loader
val_samples = next(iter(trainer.datamodule.val_dataloader()))
val_imgs, val_labels = val_samples
# run the batch through the network
val_imgs = val_imgs.to(device=pl_module.device)
logits = pl_module(val_imgs)
preds = torch.argmax(logits, dim=-1)
# log the images as wandb Image
experiment.log(
{
f"Images/{experiment.name}": [
wandb.Image(x, caption=f"Pred:{pred}, Label:{y}")
for x, pred, y in zip(
val_imgs[: self.num_samples],
preds[: self.num_samples],
val_labels[: self.num_samples],
)
]
}
)
# Copied from https://github.com/stanford-crfm/mistral/blob/main/src/corpora/detokenization.py
# Which was originally from https://github.com/NVIDIA/Megatron-LM/blob/aed2f75e209e525c842aec7c044af7acae2a4614/tasks/zeroshot_gpt/detokenizer.py
"""
Handle detokenization for different dataset for zero-shot LM evaluation.
"""
import re
def wikitext_detokenize(string: str) -> str:
"""
Wikitext is whitespace tokenized and we remove these whitespaces.
Taken from https://github.com/NVIDIA/Megatron-LM/blob/main/tasks/zeroshot_gpt2/detokenizer.py
"""
# Contractions
string = string.replace("s '", "s'")
string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
# Number Separators
string = string.replace(" @-@ ", "-")
string = string.replace(" @,@ ", ",")
string = string.replace(" @.@ ", ".")
# Punctuation
string = string.replace(" : ", ": ")
string = string.replace(" ; ", "; ")
string = string.replace(" . ", ". ")
string = string.replace(" ! ", "! ")
string = string.replace(" ? ", "? ")
string = string.replace(" , ", ", ")
# Double Brackets
string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
# Miscellaneous
string = string.replace("= = = =", "====")
string = string.replace("= = =", "===")
string = string.replace("= =", "==")
string = string.replace(" " + chr(176) + " ", chr(176))
string = string.replace(" \n", "\n")
string = string.replace("\n ", "\n")
string = string.replace(" N ", " 1 ")
string = string.replace(" 's", "'s")
return string
# Set Registry for Various Datasets
DATASET_TOKENIZATION_REGISTRY = {"wikitext": wikitext_detokenize}
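# A small illustration of the substitutions above (the sentence is made up):
def _detokenize_example():
    s = "a cat @-@ like animal , found in the U.S. ( and Canada )"
    assert wikitext_detokenize(s) == "a cat-like animal, found in the U.S. (and Canada)"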
# Inspired by https://github.com/NVIDIA/Megatron-LM/blob/main/tasks/zeroshot_gpt/datasets.py
# Except we don't pad the last block and don't use overlapping eval
# And we return both the input and the target
import math
import numpy as np
import torch
class LMDataset(torch.utils.data.Dataset):
def __init__(self, tokens, seq_len, drop_last=True):
"""tokens should be a numpy array
"""
self.seq_len = seq_len
ntokens = len(tokens)
if drop_last:
ntokens = ((ntokens - 1) // seq_len) * seq_len + 1
self.ntokens = ntokens
# We're careful not to slice tokens, since it could be a memmap'ed array or H5 dataset,
# and slicing would load it to memory.
self.tokens = tokens
self.total_sequences = math.ceil((self.ntokens - 1) / self.seq_len)
def __len__(self):
return self.total_sequences
def __getitem__(self, idx):
start_idx = idx * self.seq_len
seq_len = min(self.seq_len, self.ntokens - 1 - start_idx)
data = torch.as_tensor(self.tokens[start_idx:(start_idx + seq_len + 1)].astype(np.int64))
return data[:-1], data[1:].clone()
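# Quick sketch of how a flat token stream is chunked into (input, target) pairs:
def _lm_dataset_example():
    tokens = np.arange(10, dtype=np.uint16)   # 10 tokens -> 9 usable with drop_last=True
    ds = LMDataset(tokens, seq_len=4)
    assert len(ds) == 2
    x0, y0 = ds[0]                            # x0 = tokens[0:4], y0 = tokens[1:5]
    assert x0.tolist() == [0, 1, 2, 3] and y0.tolist() == [1, 2, 3, 4]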
# Adapted from https://github.com/Lightning-AI/lightning/blob/2845e7565dbe6b765ae32870e7d2bc456529c30a/tests/tests_pytorch/utilities/test_auto_restart.py#L1397
from typing import Iterator
import math
import torch
from torch.utils.data import RandomSampler, DistributedSampler
class RandomFaultTolerantSampler(RandomSampler):
def __init__(self, *args, generator=None, **kwargs):
# generator = torch.Generator().manual_seed(seed)
# super().__init__(*args, generator=generator, **kwargs)
# TD [2022-07-17]: We don't force the seed to be zero. We generate random seed,
# which should be reproducible if pl.seed_everything was called before hand.
# This means that changing the seed of the experiment will also change the
# sampling order.
if generator is None:
seed = int(torch.empty((), dtype=torch.int64).random_().item())
generator = torch.Generator().manual_seed(seed)
super().__init__(*args, generator=generator, **kwargs)
self.counter = 0
# self.start_counter = 0
self.restarting = False
def state_dict(self):
return {"random_state": self.state, "counter": self.counter}
def load_state_dict(self, state_dict):
self.generator.set_state(state_dict.get("random_state"))
self.counter = state_dict["counter"]
# self.start_counter = self.counter
self.restarting = True
# TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per
# epoch, and subsequent epoch will have very few batches.
# def __len__(self):
# # We need a separate self.start_counter because PL seems to call len repeatedly.
# # If we use len(self.data_source) - self.counter then PL will think the epoch ends
# # when we're only half way through.
# return len(self.data_source) - self.start_counter
def __iter__(self) -> Iterator[int]:
n = len(self.data_source)
self.state = self.generator.get_state()
indices = torch.randperm(n, generator=self.generator).tolist()
if not self.restarting:
self.counter = 0
else:
indices = indices[self.counter:]
self.restarting = False
# self.start_counter = self.counter
for index in indices:
self.counter += 1
yield index
self.counter = 0
# self.start_counter = self.counter
class FaultTolerantDistributedSampler(DistributedSampler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.counter = 0
# self.start_counter = 0
self.restarting = False
def state_dict(self):
return {"epoch": self.epoch, "counter": self.counter}
def load_state_dict(self, state_dict):
self.epoch = state_dict["epoch"]
self.counter = state_dict["counter"]
# self.start_counter = self.counter
self.restarting = True
# TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per
# epoch, and subsequent epoch will have very few batches.
# def __len__(self) -> int:
# return self.num_samples - self.start_counter
def __iter__(self):
if self.shuffle:
# deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type]
else:
indices = list(range(len(self.dataset))) # type: ignore[arg-type]
if not self.drop_last:
# add extra samples to make it evenly divisible
padding_size = self.total_size - len(indices)
if padding_size <= len(indices):
indices += indices[:padding_size]
else:
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
else:
# remove tail of data to make it evenly divisible.
indices = indices[:self.total_size]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
if not self.restarting:
self.counter = 0
else:
indices = indices[self.counter:]
self.restarting = False
# self.start_counter = self.counter
for index in indices:
self.counter += 1
yield index
self.counter = 0
# self.start_counter = self.counter
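# Minimal sketch of the save/resume contract both samplers implement: after reloading the
# state_dict, the next __iter__ skips the indices that were already consumed before the "crash".
def _fault_tolerant_resume_example():
    data = list(range(8))
    sampler = RandomFaultTolerantSampler(data)
    it = iter(sampler)
    seen = [next(it) for _ in range(3)]       # consume 3 indices, then pretend we crashed
    state = sampler.state_dict()              # {'random_state': ..., 'counter': 3}
    resumed = RandomFaultTolerantSampler(data)
    resumed.load_state_dict(state)
    rest = list(iter(resumed))                # the remaining 5 indices, in the original order
    assert len(seen) + len(rest) == len(data)
    assert set(seen + rest) == set(data)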
# Adapted from https://github.com/PyTorchLightning/lightning-bolts/blob/master/pl_bolts/datamodules/imagenet_datamodule.py
import os
from pathlib import Path
from typing import Any, List, Union, Callable, Optional
import torch
from torch.utils.data import Dataset, DataLoader, SequentialSampler
from torch.utils.data.dataloader import default_collate
from torch.utils.data.distributed import DistributedSampler
from pytorch_lightning import LightningDataModule
from torchvision import transforms
from torchvision.datasets import ImageFolder
class DictDataset(Dataset):
def __init__(self, dataset_dict, length=None):
"""dataset_dict: dictionary mapping from index to batch
length is used in the case of DistributedSampler: e.g. the dataset could have size 1k, but
with 8 GPUs the dataset_dict would only have 125 items.
"""
super().__init__()
self.dataset_dict = dataset_dict
self.length = length or len(self.dataset_dict)
def __getitem__(self, index):
return self.dataset_dict[index]
def __len__(self):
return self.length
# From https://github.com/PyTorchLightning/lightning-bolts/blob/2415b49a2b405693cd499e09162c89f807abbdc4/pl_bolts/transforms/dataset_normalizations.py#L10
def imagenet_normalization():
return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
class ImagenetDataModule(LightningDataModule):
"""
.. figure:: https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/wp-content/uploads/2017/08/
Sample-of-Images-from-the-ImageNet-Dataset-used-in-the-ILSVRC-Challenge.png
:width: 400
:alt: Imagenet
Specs:
- 1000 classes
- Each image is (3 x varies x varies) (here we default to 3 x 224 x 224)
Imagenet train, val and test dataloaders.
The train set is the imagenet train split.
In this adaptation, the val and test sets both use the official imagenet validation split.
Example::
from pl_bolts.datamodules import ImagenetDataModule
dm = ImagenetDataModule(IMAGENET_PATH)
model = LitModel()
Trainer().fit(model, datamodule=dm)
"""
name = "imagenet"
def __init__(
self,
data_dir: str,
image_size: int = 224,
train_transforms=None,
val_transforms=None,
test_transforms=None,
img_dtype='float32', # Using str since OmegaConf doesn't support non-primitive type
cache_val_dataset=False,
mixup: Optional[Callable] = None,
num_aug_repeats: int = 0,
num_workers: int = 0,
batch_size: int = 32,
batch_size_eval: Optional[int] = None,
shuffle: bool = True,
pin_memory: bool = True,
drop_last: bool = False,
*args: Any,
**kwargs: Any,
) -> None:
"""
Args:
data_dir: path to the imagenet dataset directory (with `train` and `val` subfolders)
image_size: final image size
num_workers: how many data workers
batch_size: batch_size
shuffle: If true shuffles the data every epoch
pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before
returning them
drop_last: If true drops the last incomplete batch
"""
super().__init__(*args, **kwargs)
self.image_size = image_size
self.train_transforms = train_transforms
self.val_transforms = val_transforms
self.test_transforms = test_transforms
assert img_dtype in ['float32', 'float16', 'bfloat16']
self.img_dtype = getattr(torch, img_dtype)
self.cache_val_dataset = cache_val_dataset
self.mixup = mixup
self.num_aug_repeats = num_aug_repeats
self.dims = (3, self.image_size, self.image_size)
self.data_dir = Path(data_dir).expanduser()
self.num_workers = num_workers
self.batch_size = batch_size
self.batch_size_eval = batch_size_eval if batch_size_eval is not None else self.batch_size
self.shuffle = shuffle
self.pin_memory = pin_memory
self.drop_last = drop_last
@property
def num_classes(self) -> int:
"""
Return:
1000
"""
return 1000
def _verify_splits(self, data_dir: str, split: str) -> None:
dirs = os.listdir(data_dir)
if split not in dirs:
raise FileNotFoundError(
f"a {split} Imagenet split was not found in {data_dir},"
f" make sure the folder contains a subfolder named {split}"
)
def prepare_data(self) -> None:
"""This method already assumes you have imagenet2012 downloaded. It validates the data using the meta.bin.
.. warning:: Please download imagenet on your own first.
"""
self._verify_splits(self.data_dir, "train")
self._verify_splits(self.data_dir, "val")
def setup(self, stage: Optional[str] = None) -> None:
"""Creates train, val, and test dataset."""
if stage == "fit" or stage is None:
train_transforms = (self.train_transform() if self.train_transforms is None
else self.train_transforms)
val_transforms = (self.val_transform() if self.val_transforms is None
else self.val_transforms)
if self.img_dtype is not torch.float32:
assert isinstance(train_transforms, transforms.Compose)
assert isinstance(val_transforms, transforms.Compose)
convert_dtype = transforms.Lambda(lambda x: x.to(dtype=self.img_dtype))
train_transforms.transforms.append(convert_dtype)
val_transforms.transforms.append(convert_dtype)
self.dataset_train = ImageFolder(self.data_dir / 'train', transform=train_transforms)
self.dataset_val = ImageFolder(self.data_dir / 'val', transform=val_transforms)
if stage == "test" or stage is None:
test_transforms = (self.val_transform() if self.test_transforms is None
else self.test_transforms)
if self.img_dtype is not torch.float32:
assert isinstance(test_transforms, transforms.Compose)
convert_dtype = transforms.Lambda(lambda x: x.to(dtype=self.img_dtype))
test_transforms.transforms.append(convert_dtype)
self.dataset_test = ImageFolder(self.data_dir / 'val', transform=test_transforms)
def train_transform(self) -> Callable:
"""The standard imagenet transforms.
.. code-block:: python
transforms.Compose([
transforms.RandomResizedCrop(self.image_size),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
),
])
"""
preprocessing = transforms.Compose(
[
transforms.RandomResizedCrop(self.image_size),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
imagenet_normalization(),
]
)
return preprocessing
def val_transform(self) -> Callable:
"""The standard imagenet transforms for validation.
.. code-block:: python
transforms.Compose([
transforms.Resize(self.image_size + 32),
transforms.CenterCrop(self.image_size),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
),
])
"""
preprocessing = transforms.Compose(
[
transforms.Resize(self.image_size + 32),
transforms.CenterCrop(self.image_size),
transforms.ToTensor(),
imagenet_normalization(),
]
)
return preprocessing
def train_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader:
""" The train dataloader """
if self.num_aug_repeats == 0:
shuffle = self.shuffle
sampler = None
else:
shuffle = False
from timm.data.distributed_sampler import RepeatAugSampler
sampler = RepeatAugSampler(self.dataset_train, num_repeats=self.num_aug_repeats)
return self._data_loader(self.dataset_train, batch_size=self.batch_size,
shuffle=shuffle, mixup=self.mixup, sampler=sampler)
def val_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]:
""" The val dataloader """
# If using RepeatAugment, we set trainer.replace_sampler_ddp=False, so we have to
# construct the DistributedSampler ourselves.
if not self.cache_val_dataset:
sampler = (DistributedSampler(self.dataset_val, shuffle=False, drop_last=self.drop_last)
if self.num_aug_repeats != 0 else None)
return self._data_loader(self.dataset_val, batch_size=self.batch_size_eval,
sampler=sampler)
else:
print('Caching val dataset')
sampler = (SequentialSampler(self.dataset_val) if self.trainer.world_size <= 1
else DistributedSampler(self.dataset_val, shuffle=False,
drop_last=self.drop_last))
indices = list(iter(sampler))
loader = DataLoader(self.dataset_val, batch_size=None, shuffle=False, sampler=sampler,
num_workers=self.num_workers, drop_last=self.drop_last)
batches = list(loader)
assert len(batches) == len(indices)
self.dataset_val = DictDataset(dict(zip(indices, batches)),
length=len(self.dataset_val))
sampler = (DistributedSampler(self.dataset_val, shuffle=False, drop_last=self.drop_last)
if self.num_aug_repeats != 0 else None)
return self._data_loader(self.dataset_val, batch_size=self.batch_size_eval,
sampler=sampler)
def test_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]:
""" The test dataloader """
sampler = (DistributedSampler(self.dataset_test, shuffle=False, drop_last=self.drop_last)
if self.num_aug_repeats != 0 else None)
return self._data_loader(self.dataset_test, batch_size=self.batch_size_eval, sampler=sampler)
def _data_loader(self, dataset: Dataset, batch_size: int, shuffle: bool = False,
mixup: Optional[Callable] = None, sampler=None) -> DataLoader:
collate_fn = ((lambda batch: mixup(*default_collate(batch))) if mixup is not None
else default_collate)
return DataLoader(
dataset,
collate_fn=collate_fn,
batch_size=batch_size,
shuffle=shuffle,
sampler=sampler,
num_workers=self.num_workers,
drop_last=self.drop_last,
pin_memory=self.pin_memory,
persistent_workers=True
)
class Imagenet21kPDataModule(ImagenetDataModule):
"""ImageNet-21k (winter 21) processed with https://github.com/Alibaba-MIIL/ImageNet21K
"""
@property
def num_classes(self) -> int:
"""
Return:
10450
"""
return 10450
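# Minimal usage sketch (the path and sizes are illustrative); expects `train` and `val`
# subfolders in ImageFolder layout under data_dir.
def _imagenet_datamodule_example():
    dm = ImagenetDataModule(data_dir='~/data/imagenet', image_size=224,
                            batch_size=128, num_workers=8)
    dm.prepare_data()     # only checks that the train/val folders exist
    dm.setup('fit')
    return dm.train_dataloader()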
# Adapted from https://github.com/huggingface/transformers/blob/master/examples/pytorch/language-modeling/run_clm.py
from itertools import chain
from pathlib import Path
import pickle
from typing import Any, List, Union
import subprocess
import mmap
from multiprocessing.shared_memory import SharedMemory
import numpy as np
import torch
from torch.utils.data.dataloader import DataLoader, Dataset
from transformers import AutoTokenizer
from datasets import load_dataset
from pytorch_lightning import LightningDataModule
from src.datamodules.datasets.lm_dataset import LMDataset
from src.datamodules.fault_tolerant_sampler import RandomFaultTolerantSampler
from src.datamodules.fault_tolerant_sampler import FaultTolerantDistributedSampler
from src.datamodules.datasets.detokenizer import DATASET_TOKENIZATION_REGISTRY
from src.utils.utils import get_logger
logger = get_logger()
# https://github.com/numpy/numpy/issues/18294
# Copied from https://numpy.org/doc/stable/user/basics.subclassing.html#slightly-more-realistic-example-attribute-added-to-existing-array
class SHMArray(np.ndarray):
def __new__(cls, input_array, shm=None):
obj = np.asarray(input_array).view(cls)
obj.shm = shm
return obj
def __array_finalize__(self, obj):
if obj is None: return
self.shm = getattr(obj, 'shm', None)
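# Small sketch of why SHMArray exists: attach the SharedMemory object to the array that views
# its buffer so the shared-memory block stays alive as long as the array does.
def _shm_array_example(n=1024):
    shm = SharedMemory(create=True, size=n * np.dtype(np.uint16).itemsize)
    arr = SHMArray(np.ndarray((n,), dtype=np.uint16, buffer=shm.buf), shm=shm)
    arr[:] = 0
    return arr  # arr.shm keeps the block referenced; without it, the buffer could be freed early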
class LMDataModule(LightningDataModule):
def __init__(self, dataset_name, tokenizer_name, dataset_config_name=None, max_length=1024,
cache_dir=None, val_ratio=0.0005, val_split_seed=2357, add_eos=True,
detokenize=False, val_only=False, batch_size=32, batch_size_eval=None, num_workers=1,
shuffle=False, pin_memory=False, drop_last=False, fault_tolerant=False, ddp=False,
fast_forward_epochs=None, fast_forward_batches=None,
use_shmem=True):
super().__init__()
self.dataset_name = dataset_name
self.dataset_config_name = dataset_config_name
self.tokenizer_name = tokenizer_name
self.cache_dir = None if cache_dir is None else Path(cache_dir).expanduser()
self.max_length = max_length
self.val_ratio = val_ratio
self.val_split_seed = val_split_seed
self.val_only = val_only
self.add_eos = add_eos
self.detokenize = detokenize
self.batch_size = batch_size
self.batch_size_eval = batch_size_eval if batch_size_eval is not None else self.batch_size
self.num_workers = num_workers
self.shuffle = shuffle
self.pin_memory = pin_memory
self.drop_last = drop_last
if fault_tolerant:
assert self.shuffle
self.fault_tolerant = fault_tolerant
if ddp:
assert fault_tolerant
self.ddp = ddp
self.fast_forward_epochs = fast_forward_epochs
self.fast_forward_batches = fast_forward_batches
if self.fast_forward_epochs is not None or self.fast_forward_batches is not None:
assert ddp and fault_tolerant
self.use_shmem = use_shmem
if self.use_shmem:
assert cache_dir is not None
def prepare_data(self):
if self.cache_dir is None: # Just download the dataset
load_dataset(self.dataset_name, self.dataset_config_name)
else: # Process the dataset and save it
self.process_dataset()
def setup(self, stage=None):
if stage == 'test' and hasattr(self, 'dataset_test'):
return
concat_ids, self.tokenizer = self.process_dataset()
self.vocab_size = len(self.tokenizer)
# Create all splits
self.dataset_train, self.dataset_val, self.dataset_test = [
LMDataset(concat_ids[split], seq_len=self.max_length)
for split in ['train', 'validation', 'test']
]
def process_dataset(self):
cache_dir = None if self.cache_dir is None else self.cache_dir / self._cache_dir_name
if cache_dir is not None:
if cache_dir.is_dir():
return self._load_from_cache(cache_dir)
raw_datasets = load_dataset(self.dataset_name, self.dataset_config_name)
# https://github.com/stanford-crfm/mistral/blob/main/src/corpora/auto.py
if 'validation' not in raw_datasets:
assert "train" in raw_datasets, "You must have train in raw_datasets to make a validation raw_datasets"
raw_datasets = raw_datasets["train"].train_test_split(
test_size=self.val_ratio, seed=self.val_split_seed,
shuffle=True # Otherwise test will be at the end of the dataset
)
raw_datasets['validation'] = raw_datasets['test']
if self.val_only: # Should only be used for evaluation, not for training
raw_datasets['train'] = raw_datasets['validation']
# [2021-12-25] TD: Running the detokenizer on wikitext-103 makes ppl worse
# (GPT2-small val ppl after 10 epochs ~22 -> ~25)
# However, it's useful for zero-shot transfer from Openwebtext,
# as after detokenization it's closer to Openwebtext's format.
# https://github.com/stanford-crfm/mistral/issues/12
if self.detokenize:
if self.dataset_name in DATASET_TOKENIZATION_REGISTRY:
detokenizer = DATASET_TOKENIZATION_REGISTRY[self.dataset_name]
raw_datasets = raw_datasets.map(
lambda example: {'text': detokenizer(example['text'])},
num_proc=max(self.num_workers, 1),
desc='Running detokenizer on dataset'
)
tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, use_fast=True)
# Preprocessing the datasets.
# First we tokenize all the texts.
column_names = raw_datasets["train"].column_names
text_column_name = "text" if "text" in column_names else column_names[0]
# [2021-12-25] TD: For wikitext, don't need to add the EOS since each example already ends
# with '\n', and there are no other '\n' in the examples.
# assert all([t.count('\n') == 1 for t in raw_datasets['train']['text'] if t])
# Add EOS token to the end of the text if the text is not empty
# https://github.com/stanford-crfm/mistral/issues/91
# https://github.com/stanford-crfm/mistral/pull/98
if self.add_eos:
add_eos = lambda seq: (seq + tokenizer.eos_token) if seq else seq
add_eos_batched = lambda seqs: [add_eos(seq) for seq in seqs]
tokenize = lambda example: tokenizer(add_eos_batched(example[text_column_name]))
else:
tokenize = lambda example: tokenizer(example[text_column_name])
# tokenized_datasets = raw_datasets.map(
# tokenize,
# batched=True,
# num_proc=max(self.num_workers, 1),
# remove_columns=column_names,
# desc="Running tokenizer on dataset",
# )
dtype = np.uint16 if tokenizer.vocab_size < 64 * 1024 else np.int32
def tokenize_concat(examples):
# We just need 'input_ids', not 'attention_mask' (since it's all 1)
input_ids = np.fromiter(chain(*tokenize(examples)['input_ids']), dtype=dtype)
# Need to return a list since we're doing batched processing
return {'input_ids': [input_ids], 'len': [len(input_ids)]}
tokenized_datasets = raw_datasets.map(
tokenize_concat,
batched=True,
num_proc=max(self.num_workers, 1),
remove_columns=column_names,
desc="Running tokenizer on dataset",
)
if self.use_shmem:
# Concatenate all input_ids into an array in shared memory
def write_ids_to_shm(example, shm_name, array_len):
shm = SharedMemory(name=shm_name)
shm_arr = np.ndarray((array_len,), dtype=dtype, buffer=shm.buf)
start_idx = example['len_offset'] - len(example['input_ids'])
shm_arr[start_idx:example['len_offset']] = example['input_ids']
shm.close()
concat_ids = {}
for name, ds in tokenized_datasets.items():
tokenized_datasets[name] = ds.add_column('len_offset', np.cumsum(ds['len']))
array_len = tokenized_datasets[name][-1]['len_offset']
shm = SharedMemory(create=True, size=array_len * np.dtype(dtype).itemsize)
shm_name = shm.name
tokenized_datasets[name].map(
write_ids_to_shm,
fn_kwargs={'shm_name': shm_name, 'array_len': array_len},
batched=False,
num_proc=max(self.num_workers, 1),
desc="Concatenating examples",
)
shm_arr = np.ndarray((array_len,), dtype=dtype, buffer=shm.buf)
# We need to keep a reference to the shared memory, otherwise it gets garbage-collected
# when it goes out of scope, and that memory is gone.
# https://github.com/numpy/numpy/issues/18294
concat_ids[name] = SHMArray(shm_arr, shm=shm)
else:
# Use disk
concat_ids = {}
assert cache_dir is not None
cache_dir.mkdir(parents=True, exist_ok=True)
def write_ids_to_disk(example, filename):
with open(filename, 'r+b') as f:
mm = mmap.mmap(f.fileno(), 0)
start_idx = example['len_offset'] - len(example['input_ids'])
array_len = len(example['input_ids'])
arr = np.ndarray((array_len,), dtype=dtype, buffer=mm,
offset=np.dtype(dtype).itemsize * start_idx)
arr[:] = example['input_ids']
mm.flush()
for name, ds in tokenized_datasets.items():
tokenized_datasets[name] = ds.add_column('len_offset', np.cumsum(ds['len']))
array_len = tokenized_datasets[name][-1]['len_offset']
filename = cache_dir / f'{name}.bin'
# Need to create the file with this specific size first
# https://ostechnix.com/create-files-certain-size-linux/
subprocess.run(['truncate', '-s', str(array_len * np.dtype(dtype).itemsize),
str(filename)], check=True)
tokenized_datasets[name].map(
write_ids_to_disk,
fn_kwargs={'filename': filename},
batched=False,
num_proc=max(self.num_workers, 1),
desc="Concatenating examples",
)
concat_ids[name] = np.memmap(filename, dtype=dtype, mode='r', shape=(array_len,))
if cache_dir is not None:
self._save_to_cache(concat_ids, tokenizer, cache_dir)
if not self.use_shmem:
for name in concat_ids:
Path(cache_dir / f'{name}.bin').unlink()
return concat_ids, tokenizer
def _save_to_cache(self, concat_ids, tokenizer, cache_dir):
cache_dir.mkdir(parents=True, exist_ok=True)
logger.info(f'Saving to cache at {str(cache_dir)}')
for k, v in concat_ids.items():
np.save(cache_dir / f'{k}.npy', v)
with open(cache_dir / 'tokenizer.pkl', 'wb') as f:
pickle.dump(tokenizer, f)
def _load_from_cache(self, cache_dir):
assert cache_dir.is_dir()
logger.info(f'Load from cache at {str(cache_dir)}')
concat_ids = {split: np.load(cache_dir / f'{split}.npy', mmap_mode='r')
for split in ['train', 'validation', 'test']}
with open(cache_dir / 'tokenizer.pkl', 'rb') as f:
tokenizer = pickle.load(f)
return concat_ids, tokenizer
@property
def _cache_dir_name(self):
return f'tokenizer_name-{self.tokenizer_name}-val_ratio-{self.val_ratio}-val_split_seed-{self.val_split_seed}-add_eos-{self.add_eos}-detokenize-{self.detokenize}'
def train_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader:
""" The train dataloader """
if self.shuffle and self.fault_tolerant:
shuffle = False
sampler = (FaultTolerantDistributedSampler(self.dataset_train) if self.ddp
else RandomFaultTolerantSampler(self.dataset_train))
# TD [2022-08-06]: Only the DDP sampler supports fast-forwarding for now
# We assume that it's being resumed with the same number of GPUs
if self.ddp and self.fast_forward_epochs is not None and self.fast_forward_batches is not None:
sampler.load_state_dict({
'epoch': self.fast_forward_epochs,
'counter': self.fast_forward_batches * self.batch_size
})
else:
shuffle = self.shuffle
sampler = None
return self._data_loader(self.dataset_train, batch_size=self.batch_size,
shuffle=shuffle, sampler=sampler)
def val_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]:
""" The val dataloader """
return self._data_loader(self.dataset_val, batch_size=self.batch_size_eval)
def test_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]:
""" The test dataloader """
return self._data_loader(self.dataset_test, batch_size=self.batch_size_eval)
def _data_loader(self, dataset: Dataset, batch_size: int, shuffle: bool = False,
sampler=None) -> DataLoader:
return DataLoader(
dataset,
batch_size=batch_size,
num_workers=1, # Data is already in memory, we don't need many workers
shuffle=shuffle,
sampler=sampler,
drop_last=self.drop_last,
pin_memory=self.pin_memory,
# persistent_workers=True
)
def load_state_dict(self, checkpoint):
if self.fault_tolerant:
self.fast_forward_epochs = checkpoint['loops']['fit_loop']['epoch_progress']['current']['completed']
# TD [2022-08-07] ['epoch_loop.batch_progress']['total']['completed'] is 1 iteration
# behind, so we're using the optimizer's progress. This is set correctly in seq.py.
self.fast_forward_batches = checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['completed']
# At this point the train loader hasn't been constructed yet
import torch
from timm.data import Mixup
from timm.data.mixup import mixup_target
class TimmMixup(Mixup):
""" Wrap timm.data.Mixup that avoids the assert that batch size must be even.
"""
def __call__(self, x, target):
if self.mode == 'elem':
lam = self._mix_elem(x)
elif self.mode == 'pair':
# We move the assert from the beginning of the function to here
assert len(x) % 2 == 0, 'Batch size should be even when using this'
lam = self._mix_pair(x)
else:
lam = self._mix_batch(x)
target = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device)
return x, target
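# Minimal usage sketch with random data, assuming the standard timm Mixup constructor arguments:
def _timm_mixup_example():
    mixup = TimmMixup(mixup_alpha=0.8, cutmix_alpha=1.0, num_classes=10)
    x = torch.randn(3, 3, 32, 32)             # odd batch size is fine in the default 'batch' mode
    y = torch.randint(0, 10, (3,))
    x, y_soft = mixup(x, y)                   # y_soft: soft labels of shape (3, 10)
    return x, y_soft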
# Adapted from https://pytorch.org/docs/stable/_modules/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.html
# We divide by world_size before converting to fp16, which is safer against overflow.
from typing import Any, Callable
import torch
import torch.distributed as dist
def fp16_compress_hook(
process_group: dist.ProcessGroup, bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
"""
This DDP communication hook implements a simple gradient compression approach that
divides the ``GradBucket`` tensor by the process group size and then casts it to
half-precision floating-point format (``torch.float16``).
It allreduces those ``float16`` gradient tensors. Once the compressed gradient
tensors are allreduced, the chained callback ``decompress`` casts them back to the input data type (such as ``float32``).
Example::
>>> ddp_model.register_comm_hook(process_group, fp16_compress_hook)
"""
group_to_use = process_group if process_group is not None else dist.group.WORLD
world_size = group_to_use.size()
# Divide first before converting to fp16
# Use out argument to fuse the division and the conversion.
compressed_tensor = torch.div(bucket.buffer(), world_size,
out=torch.empty_like(bucket.buffer(), dtype=torch.float16))
fut = dist.all_reduce(
compressed_tensor, group=group_to_use, async_op=True
).get_future()
def decompress(fut):
decompressed_tensor = bucket.buffer()
# Decompress in place to reduce the peak memory.
# See: https://github.com/pytorch/pytorch/issues/45968
decompressed_tensor.copy_(fut.value()[0])
return decompressed_tensor
# TODO: maybe have a backoff strategy: check if the buffer has inf / NaN, in that case
# resend with fp32?
return fut.then(decompress)
from typing import List, Optional
from pathlib import Path
import torch
import hydra
from omegaconf import OmegaConf, DictConfig
from pytorch_lightning import (
Callback,
LightningDataModule,
LightningModule,
Trainer,
seed_everything,
)
from pytorch_lightning.loggers import LightningLoggerBase
from src.utils import utils
log = utils.get_logger(__name__)
def remove_prefix(text: str, prefix: str):
if text.startswith(prefix):
return text[len(prefix) :]
return text # or whatever
def load_checkpoint(path, device='cpu'):
path = Path(path).expanduser()
if path.is_dir():
path /= 'last.ckpt'
# dst = f'cuda:{torch.cuda.current_device()}'
log.info(f'Loading checkpoint from {str(path)}')
state_dict = torch.load(path, map_location=device)
# T2T-ViT checkpoint is nested in the key 'state_dict_ema'
if state_dict.keys() == {'state_dict_ema'}:
state_dict = state_dict['state_dict_ema']
# Swin checkpoint is nested in the key 'model'
if state_dict.keys() == {'model'}:
state_dict = state_dict['model']
# Lightning checkpoint contains extra stuff, we only want the model state dict
if 'pytorch-lightning_version' in state_dict:
state_dict = {remove_prefix(k, 'model.'): v for k, v in state_dict['state_dict'].items()}
return state_dict
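# Minimal usage sketch (the checkpoint path is illustrative):
def _load_checkpoint_example(model: torch.nn.Module, ckpt_path='checkpoints/last.ckpt'):
    state_dict = load_checkpoint(ckpt_path)
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    return missing, unexpected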
def evaluate(config: DictConfig) -> None:
"""Example of inference with trained model.
It loads trained image classification model from checkpoint.
Then it loads example image and predicts its label.
"""
# load model from checkpoint
# model __init__ parameters will be loaded from ckpt automatically
# you can also pass some parameter explicitly to override it
# We want to add fields to config so need to call OmegaConf.set_struct
OmegaConf.set_struct(config, False)
# load model
checkpoint_type = config.eval.get('checkpoint_type', 'pytorch')
if checkpoint_type not in ['lightning', 'pytorch']:
raise NotImplementedError(f'checkpoint_type {checkpoint_type} not supported')
if checkpoint_type == 'lightning':
cls = hydra.utils.get_class(config.task._target_)
model = cls.load_from_checkpoint(checkpoint_path=config.eval.ckpt)
elif checkpoint_type == 'pytorch':
model_cfg = config.model_pretrained if 'model_pretrained' in config else None
trained_model: LightningModule = hydra.utils.instantiate(config.task, cfg=config,
model_cfg=model_cfg,
_recursive_=False)
if 'ckpt' in config.eval:
load_return = trained_model.model.load_state_dict(
load_checkpoint(config.eval.ckpt, device=trained_model.device), strict=False
)
log.info(load_return)
if 'model_pretrained' in config:
...
else:
model = trained_model
datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule)
# datamodule: LightningDataModule = model._datamodule
datamodule.prepare_data()
datamodule.setup()
# print model hyperparameters
log.info(f'Model hyperparameters: {model.hparams}')
# Init Lightning callbacks
callbacks: List[Callback] = []
if "callbacks" in config:
for _, cb_conf in config["callbacks"].items():
if cb_conf is not None and "_target_" in cb_conf:
log.info(f"Instantiating callback <{cb_conf._target_}>")
callbacks.append(hydra.utils.instantiate(cb_conf))
# Init Lightning loggers
logger: List[LightningLoggerBase] = []
if "logger" in config:
for _, lg_conf in config["logger"].items():
if lg_conf is not None and "_target_" in lg_conf:
log.info(f"Instantiating logger <{lg_conf._target_}>")
logger.append(hydra.utils.instantiate(lg_conf))
# Init Lightning trainer
log.info(f"Instantiating trainer <{config.trainer._target_}>")
trainer: Trainer = hydra.utils.instantiate(
config.trainer, callbacks=callbacks, logger=logger, _convert_="partial"
)
# Evaluate the model
log.info("Starting evaluation!")
if config.eval.get('run_val', True):
trainer.validate(model=model, datamodule=datamodule)
if config.eval.get('run_test', True):
trainer.test(model=model, datamodule=datamodule)
# Make sure everything closed properly
log.info("Finalizing!")
utils.finish(
config=config,
model=model,
datamodule=datamodule,
trainer=trainer,
callbacks=callbacks,
logger=logger,
)