Commit 0bf5e500 authored by Tri Dao

Release training code

parent 9bc63d1e
import torch
import torch.nn as nn
import xentropy_cuda_lib
# https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py
class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
@staticmethod
def forward(ctx, logits, labels, smoothing=0.0, padding_idx=0, inplace_backward=False):
losses, max_log_sum_exp = xentropy_cuda_lib.forward(
logits, labels, smoothing)
losses.masked_fill_(labels==padding_idx, 0)
ctx.save_for_backward(logits, max_log_sum_exp, labels)
ctx.smoothing = smoothing
ctx.padding_idx = padding_idx
ctx.inplace_backward = inplace_backward
return losses
@staticmethod
def backward(ctx, grad_loss):
logits, max_log_sum_exp, labels = ctx.saved_tensors
if not grad_loss.is_contiguous():
grad_loss = grad_loss.contiguous()
grad_loss.masked_fill_(labels==ctx.padding_idx, 0)
grad_logits = xentropy_cuda_lib.backward(grad_loss, logits, max_log_sum_exp, labels,
ctx.smoothing, ctx.inplace_backward)
return grad_logits, None, None, None, None
class CrossEntropyLossApex(nn.Module):
def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0,
inplace_backward=False):
super().__init__()
if reduction not in ['mean', 'none']:
raise NotImplementedError("Only support reduction = 'mean' or 'none'")
self.ignore_index = ignore_index
self.reduction = reduction
self.label_smoothing = label_smoothing
self.inplace_backward = inplace_backward
def forward(self, input, target):
assert input.is_cuda and target.is_cuda
# SoftmaxCrossEntropyLoss implicitly casts to float
loss = SoftmaxCrossEntropyLossFn.apply(input, target, self.label_smoothing,
self.ignore_index, self.inplace_backward)
if self.reduction == 'mean':
return loss.sum() / (target != self.ignore_index).sum()
else:
return loss
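# A minimal usage sketch (illustrative only, not part of the original file): CrossEntropyLossApex
# is meant as a drop-in replacement for torch.nn.CrossEntropyLoss on CUDA tensors, assuming the
# xentropy_cuda_lib extension is built. Shapes and tolerances below are illustrative.
def _cross_entropy_apex_sketch():
    import torch
    logits = torch.randn(8, 32000, device='cuda', requires_grad=True)  # (batch, vocab)
    labels = torch.randint(0, 32000, (8,), device='cuda')
    loss_fn = CrossEntropyLossApex(inplace_backward=True)  # reuse the logits buffer for the grad
    ref_fn = torch.nn.CrossEntropyLoss()
    loss = loss_fn(logits, labels)
    # Should match the reference implementation up to numerical precision.
    assert torch.allclose(loss, ref_fn(logits.float(), labels), atol=1e-4)
    loss.backward()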
# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py
# But we make it much faster: we compute the local loss and the LSE, and by exchanging the LSE and
# the losses we can get the global loss. There's no need to do it step by step
# (compute local max, exchange, compute exp, compute local sum, exchange, etc.)
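# A single-process sketch (illustrative only, not part of the original file) of the trick above:
# the cross-entropy kernel on each rank returns (lse_local - logit_of_label); the global LSE is
# just logsumexp over the per-rank LSEs, so the global loss is recovered by adding (lse - lse_local).
def _lse_exchange_sketch():
    import torch
    batch, vocab, world_size = 4, 16, 2
    logits = torch.randn(batch, vocab)
    labels = torch.randint(0, vocab, (batch,))
    # Split the vocab the way tensor parallelism would.
    shards = logits.chunk(world_size, dim=-1)
    lse_local = torch.stack([s.logsumexp(dim=-1) for s in shards])   # (world_size, batch)
    lse = torch.logsumexp(lse_local, dim=0)                          # global LSE
    loss = lse - logits[torch.arange(batch), labels]                 # lse - predicted logit
    ref = torch.nn.functional.cross_entropy(logits, labels, reduction='none')
    assert torch.allclose(loss, ref, atol=1e-5)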
import torch
import torch.nn as nn
import xentropy_cuda_lib
from apex.transformer.parallel_state import get_tensor_model_parallel_group
from apex.transformer.parallel_state import get_tensor_model_parallel_rank
from apex.transformer.parallel_state import get_tensor_model_parallel_world_size
from apex.transformer.tensor_parallel.utils import VocabUtility
# `all_gather_into_tensor` and `reduce_scatter_tensor` are the new names for
# `_all_gather_base` and `_reduce_scatter_base`. They require a recent version of
# PyTorch. The following 4 lines are for backward compatibility with older
# PyTorch versions.
if "all_gather_into_tensor" not in dir(torch.distributed):
torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
if "reduce_scatter_tensor" not in dir(torch.distributed):
torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base
class SoftmaxCrossEntropyLossParallelFn(torch.autograd.Function):
@staticmethod
def forward(ctx, logits_parallel, labels, smoothing=0.0, ignored_index=-100,
inplace_backward=False):
"""
logits_parallel: (batch, vocab_size / world_size)
labels: (batch,)
"""
assert smoothing == 0.0, 'smoothing != 0.0 is not yet implemented, file an issue if you need it'
batch, partition_vocab_size = logits_parallel.shape
assert labels.shape == (batch,)
rank = get_tensor_model_parallel_rank()
world_size = get_tensor_model_parallel_world_size()
vocab_start_index, vocab_end_index = VocabUtility.vocab_range_from_per_partition_vocab_size(
partition_vocab_size, get_tensor_model_parallel_rank(),
get_tensor_model_parallel_world_size()
)
# Create a mask for labels that fall outside this partition's vocab range (True means masked out).
labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index)
ignored_mask = labels == ignored_index
labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index)
masked_labels = labels_local.clone()
masked_labels[labels_mask] = ignored_index
losses, lse_local = xentropy_cuda_lib.forward(logits_parallel, masked_labels, smoothing)
assert lse_local.shape == (batch,)
assert losses.shape == (batch,)
losses.masked_fill_(masked_labels==ignored_index, 0)
if world_size > 1:
lse_allgather = torch.empty(world_size, batch, dtype=lse_local.dtype,
device=lse_local.device)
torch.distributed.all_gather_into_tensor(lse_allgather, lse_local.contiguous(),
group=get_tensor_model_parallel_group())
lse = torch.logsumexp(lse_allgather, dim=0)
torch.distributed.all_reduce(losses, op=torch.distributed.ReduceOp.SUM,
group=get_tensor_model_parallel_group())
# The losses are currently lse_local - predicted_logit; we just have to subtract
# lse_local and add the global lse.
rank_per_sample = labels // partition_vocab_size
lse_local = lse_allgather[rank_per_sample,
torch.arange(batch, device=lse_allgather.device)]
losses += lse - lse_local
losses.masked_fill_(ignored_mask, 0)
else:
lse = lse_local
ctx.save_for_backward(logits_parallel, lse, labels_local)
ctx.smoothing = smoothing
ctx.ignored_index = ignored_index
ctx.inplace_backward = inplace_backward
return losses
@staticmethod
def backward(ctx, grad_loss):
logits_parallel, lse, labels = ctx.saved_tensors
if not grad_loss.is_contiguous():
grad_loss = grad_loss.contiguous()
grad_loss.masked_fill_(labels==ctx.ignored_index, 0)
grad_logits = xentropy_cuda_lib.backward(grad_loss, logits_parallel, lse, labels,
ctx.smoothing, ctx.inplace_backward)
return grad_logits, None, None, None, None, None
class CrossEntropyLossParallel(nn.Module):
def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0,
inplace_backward=False):
super().__init__()
if reduction not in ['mean', 'none']:
raise NotImplementedError("Only support reduction = 'mean' or 'none'")
self.ignore_index = ignore_index
self.reduction = reduction
self.label_smoothing = label_smoothing
self.inplace_backward = inplace_backward
def forward(self, input, target):
assert input.is_cuda and target.is_cuda
# SoftmaxCrossEntropyLoss implicitly casts to float
loss = SoftmaxCrossEntropyLossParallelFn.apply(
input, target, self.label_smoothing, self.ignore_index, self.inplace_backward
)
if self.reduction == 'mean':
return loss.sum() / (target != self.ignore_index).sum()
else:
return loss
import torch
from torch import Tensor
from torchmetrics import Metric, Accuracy
class AccuracyMine(Accuracy):
"""Wrap torchmetrics.Accuracy to take argmax of y in case of Mixup.
"""
def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
super().update(preds, target.argmax(dim=-1) if target.is_floating_point() else target)
from typing import Any, Dict, Optional
import torch
from torch import Tensor
from torchmetrics import Metric
class NumTokens(Metric):
"""Keep track of how many tokens we've seen.
"""
# TODO: how do we prevent the reset between epochs? The reset happens on the 1st batch
# of the next epoch.
# Right now the hack is that we override reset() to preserve the count, which would mess up
# the forward method. We then override _forward_reduce_state_update to do the right thing.
is_differentiable = False
higher_is_better = False
full_state_update = False
count: Tensor
def __init__(self, **kwargs: Dict[str, Any]):
super().__init__(**kwargs)
self.add_state("count", default=torch.tensor(0, dtype=torch.int64), dist_reduce_fx="sum",
persistent=True) # We want the count to be saved to state-dict
def update(self, preds: Tensor, target: Tensor, loss: Optional[Tensor] = None) -> None: # type: ignore
self.count += target.numel()
def compute(self) -> Tensor:
return self.count
def reset(self):
count = self.count
super().reset()
self.count = count
# Adapted from https://github.com/Lightning-AI/metrics/blob/master/src/torchmetrics/metric.py
def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any:
"""forward computation using single call to `update` to calculate the metric value on the current batch and
accumulate global state.
This can be done when the global metric state is a sinple reduction of batch states.
"""
self.update(*args, **kwargs)
return self.compute()
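# A minimal sketch (illustrative only, not part of the original file): because reset() preserves
# `count`, the token count keeps accumulating across epochs instead of being wiped.
def _num_tokens_sketch():
    import torch
    metric = NumTokens()
    metric.update(preds=None, target=torch.zeros(2, 8, dtype=torch.long))  # preds is ignored
    metric.reset()                                                         # count survives the reset
    metric.update(preds=None, target=torch.zeros(2, 8, dtype=torch.long))
    assert metric.compute().item() == 32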
# Inspired by https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/metrics/perplexity.py
# But we compute the perplexity correctly: exp(average(nll)), not average(exp(nll))
# Also adapted from https://github.com/Lightning-AI/metrics/blob/master/src/torchmetrics/text/perplexity.py
# But we pass in the loss to avoid recomputation
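# A tiny numeric sketch (illustrative only, not part of the original file) of why the order of
# exp and mean matters: exp(mean(nll)) is the correct perplexity, while mean(exp(nll))
# overweights the hardest tokens.
def _perplexity_order_sketch():
    import torch
    nll = torch.tensor([0.5, 4.0])        # per-token negative log-likelihoods
    correct = torch.exp(nll.mean())       # exp(2.25) ~= 9.49
    wrong = torch.exp(nll).mean()         # (e^0.5 + e^4.0) / 2 ~= 28.12
    assert correct < wrong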
from typing import Any, Dict, Optional
import torch
import torch.nn.functional as F
from torch import Tensor
from torchmetrics import Metric
try:
from src.losses.cross_entropy_apex import CrossEntropyLossApex as CrossEntropyLoss
except ImportError:
CrossEntropyLoss = torch.nn.CrossEntropyLoss
__all__ = ['Perplexity']
class Perplexity(Metric):
r"""
Perplexity measures how well a language model predicts a text sample. It is computed as the
exponential of the average negative log-likelihood per token, so lower is better.
Args:
kwargs:
Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Examples:
>>> import torch
>>> preds = torch.rand(2, 8, 5, generator=torch.manual_seed(22))
>>> target = torch.randint(5, (2, 8), generator=torch.manual_seed(22))
>>> target[0, 6:] = -100
>>> metric = Perplexity(ignore_index=-100)
>>> metric(preds, target)
tensor(5.2545)
"""
is_differentiable = True
higher_is_better = False
full_state_update = False
total_log_probs: Tensor
count: Tensor
def __init__(self, **kwargs: Dict[str, Any]):
super().__init__(**kwargs)
self.add_state("total_log_probs", default=torch.tensor(0.0, dtype=torch.float64),
dist_reduce_fx="sum")
self.add_state("count", default=torch.tensor(0, dtype=torch.int64), dist_reduce_fx="sum")
self.loss_fn = CrossEntropyLoss()
def update(self, preds: Tensor, target: Tensor, loss: Optional[Tensor] = None) -> None: # type: ignore
"""Compute and store intermediate statistics for Perplexity.
Args:
preds:
Probabilities assigned to each token in a sequence with shape [batch_size, seq_len, vocab_size].
target:
Ground truth values with a shape [batch_size, seq_len].
"""
count = target.numel()
if loss is None:
loss = self.loss_fn(preds, target)
self.total_log_probs += loss.double() * count
self.count += count
def compute(self) -> Tensor:
"""Compute the Perplexity.
Returns:
Perplexity
"""
return torch.exp(self.total_log_probs / self.count)
import math
from functools import partial
from collections import namedtuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair
import hydra
from einops import reduce, rearrange
def pooling(x, pooling_mode='CLS', key_padding_mask=None, batch_first=True):
if pooling_mode not in ['MEAN', 'SUM', 'CLS', 'LAST', 'FLATTEN']:
raise NotImplementedError('pooling_mode must be MEAN, SUM, CLS, LAST, FLATTEN')
if pooling_mode in ['MEAN', 'SUM']:
if key_padding_mask is not None:
mask = rearrange(~key_padding_mask.bool_matrix,
'b s -> b s 1' if batch_first else 'b s -> s b 1')
x = x.masked_fill(mask, 0)
s = reduce(x, 'b s ... -> b ...' if batch_first else 's b ... -> b ...', 'sum')
if pooling_mode == 'SUM':
return s
else:
if key_padding_mask is None:
return s / x.shape[1 if batch_first else 0]
else:
lengths = rearrange(key_padding_mask._lengths, 'b -> b 1')
return s / lengths
elif pooling_mode == 'CLS':
return x[:, 0] if batch_first else x[0]
elif pooling_mode == 'LAST':
if key_padding_mask is None:
return x[:, -1] if batch_first else x[-1]
else:
lengths = key_padding_mask._lengths
if batch_first:
batch_size = x.shape[0]
return x[torch.arange(batch_size, device=x.device), lengths - 1]
else:
batch_size = x.shape[1]
return x[lengths - 1, torch.arange(batch_size, device=x.device)]
elif pooling_mode == 'FLATTEN':
return rearrange(x, 'b ... -> b (...)' if batch_first else 's b ... -> b (s ...)')
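# A minimal usage sketch (illustrative only, not part of the original file). The key_padding_mask
# path assumes a custom mask object exposing `bool_matrix` and `_lengths`, so this sketch only
# exercises the unmasked path.
def _pooling_sketch():
    import torch
    x = torch.randn(2, 5, 8)                                  # (batch, seq, dim), batch_first=True
    assert pooling(x, 'MEAN', batch_first=True).shape == (2, 8)
    assert pooling(x, 'CLS', batch_first=True).shape == (2, 8)
    assert pooling(x, 'LAST', batch_first=True).shape == (2, 8)
    assert pooling(x, 'FLATTEN', batch_first=True).shape == (2, 40)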
class ClassificationHeadLinear(nn.Module):
"""Head for sentence-level classification tasks."""
def __init__(self, d_model, num_classes, pooling_mode='MEAN',
batch_first=False, **kwargs):
super().__init__()
assert pooling_mode in ['MEAN', 'SUM', 'CLS', 'LAST', 'FLATTEN'], 'pooling_mode not supported'
self.pooling_mode = pooling_mode
self.batch_first = batch_first
self.out_proj = nn.Linear(d_model, num_classes)
def forward(self, hidden_states, key_padding_mask=None, **kwargs):
"""
hidden_states: (B, S, D) if batch_first else (S, B, D)
"""
hidden_states = pooling(hidden_states, pooling_mode=self.pooling_mode,
key_padding_mask=key_padding_mask, batch_first=self.batch_first)
hidden_states = self.out_proj(hidden_states)
return hidden_states
# Adapted from https://github.com/huggingface/transformers/blob/master/src/transformers/models/reformer/modeling_reformer.py
class ClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
def __init__(self, d_model, d_inner, num_classes, dropout=0.0, pooling_mode='MEAN',
batch_first=False):
super().__init__()
assert pooling_mode in ['MEAN', 'SUM', 'CLS', 'LAST', 'FLATTEN'], 'pooling_mode not supported'
self.pooling_mode = pooling_mode
self.batch_first = batch_first
self.dense = nn.Linear(d_model, d_inner)
self.dropout = nn.Dropout(dropout)
self.out_proj = nn.Linear(d_inner, num_classes)
def forward(self, hidden_states, key_padding_mask=None, **kwargs):
"""
hidden_states: (B, S, D) if batch_first else (S, B, D)
"""
hidden_states = pooling(hidden_states, pooling_mode=self.pooling_mode,
key_padding_mask=key_padding_mask, batch_first=self.batch_first)
hidden_states = self.dropout(hidden_states)
hidden_states = self.dense(hidden_states)
# Huggingface uses tanh instead of relu
hidden_states = torch.relu(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.out_proj(hidden_states)
return hidden_states
class ClassificationHeadDual(nn.Module):
"""Head for sentence-level classification tasks."""
def __init__(self, d_model, d_inner, num_classes, dropout=0.0, pooling_mode='MEAN',
batch_first=False, interaction='NLI'):
super().__init__()
assert pooling_mode in ['MEAN', 'SUM', 'CLS'], 'pooling_mode not supported'
assert interaction in [None, 'NLI'], 'interaction not supported'
self.pooling_mode = pooling_mode
self.batch_first = batch_first
self.interaction = interaction
self.dense = nn.Linear(d_model * (4 if self.interaction == 'NLI' else 2), d_inner)
self.dropout = nn.Dropout(dropout)
self.out_proj = nn.Linear(d_inner, num_classes)
def forward(self, hidden_states1, hidden_states2,
key_padding_mask1=None, key_padding_mask2=None, **kwargs):
"""
hidden_states: (B, S, D) if batch_first else (S, B, D)
"""
x1 = pooling(hidden_states1, pooling_mode=self.pooling_mode,
key_padding_mask=key_padding_mask1, batch_first=self.batch_first)
x2 = pooling(hidden_states2, pooling_mode=self.pooling_mode,
key_padding_mask=key_padding_mask2, batch_first=self.batch_first)
hidden_states = (torch.cat([x1, x2, x1 * x2, x1 - x2], dim=-1) if self.interaction == 'NLI'
else torch.cat([x1, x2], dim=-1))
hidden_states = self.dropout(hidden_states)
hidden_states = self.dense(hidden_states)
# Huggingface uses tanh instead of relu
hidden_states = torch.relu(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.out_proj(hidden_states)
return hidden_states
class LMHead(nn.Module):
def __init__(self, d_model, num_classes, batch_first=True, bias=True):
super().__init__()
self.lm_head = nn.Linear(d_model, num_classes, bias=bias)
def forward(self, hidden_states, **kwargs):
"""
hidden_states: (B, S, D) if batch_first else (S, B, D)
"""
CausalLMOutput = namedtuple('CausalLMOutput', ['logits'])
return CausalLMOutput(self.lm_head(hidden_states))
def sinusoidal_init_(tensor):
"""
tensor: (max_len, d_model)
"""
max_len, d_model = tensor.shape
position = rearrange(torch.arange(0.0, max_len), 's -> s 1')
div_term = torch.exp(-math.log(10000.0) * torch.arange(0.0, d_model, 2.0) / d_model)
tensor[:, 0::2] = torch.sin(position * div_term)
tensor[:, 1::2] = torch.cos(position * div_term)
return tensor
# Adapted from https://github.com/pytorch/examples/blob/master/word_language_model/model.py
class PositionalEncoding(nn.Module):
r"""Inject some information about the relative or absolute position of the tokens
in the sequence. The positional encodings have the same dimension as
the embeddings, so that the two can be summed. Here, we use sine and cosine
functions of different frequencies.
.. math::
\text{PosEncoder}(pos, 2i) = \sin(pos / 10000^{2i / d_{\text{model}}})
\text{PosEncoder}(pos, 2i+1) = \cos(pos / 10000^{2i / d_{\text{model}}})
\text{where pos is the word position and i is the embedding index.}
Args:
d_model: the embed dim (required).
dropout: the dropout value (default=0.1).
max_len: the max. length of the incoming sequence (default=5000).
Examples:
>>> pos_encoder = PositionalEncoding(d_model)
"""
def __init__(self, d_model, dropout=0.1, max_len=5000, batch_first=False, initializer=None):
super().__init__()
self.batch_first = batch_first
self.dropout = nn.Dropout(p=dropout)
pe = torch.empty(max_len, d_model)
if initializer is None:
sinusoidal_init_(pe)
pe = rearrange(pe, 's d -> 1 s d' if self.batch_first else 's d -> s 1 d')
self.register_buffer('pe', pe)
else:
hydra.utils.call(initializer, pe)
pe = rearrange(pe, 's d -> 1 s d' if self.batch_first else 's d -> s 1 d')
self.pe = nn.Parameter(pe)
def forward(self, x):
r"""Inputs of forward function
Args:
x: the sequence fed to the positional encoder model (required).
Shape:
x: [sequence length, batch size, embed dim] if not batch_first else [B, S, D]
output: [sequence length, batch size, embed dim] if not batch_first else [B, S, D]
Examples:
>>> output = pos_encoder(x)
"""
x = x + (self.pe[:, :x.size(1)] if self.batch_first else self.pe[:x.size(0)])
return self.dropout(x)
# Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/mlp.py
class Mlp(nn.Module):
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
act_fn=None, drop=0., device=None, dtype=None):
"""TD [2021-10-27] act_fn takes precedence over act_layer if set.
This is to support Pytorch 1.10 Transformer interface that construct the activation
*function*, not the activation *layer*.
"""
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
drop_probs = _pair(drop)
self.fc1 = nn.Linear(in_features, hidden_features, **factory_kwargs)
self.act = act_layer() if act_fn is None else act_fn
self.drop1 = nn.Dropout(drop_probs[0])
self.fc2 = nn.Linear(hidden_features, out_features, **factory_kwargs)
self.drop2 = nn.Dropout(drop_probs[1])
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.fc2(x)
x = self.drop2(x)
return x
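# A minimal sketch (illustrative only, not part of the original file) of the act_fn / act_layer
# precedence described above: passing act_fn=F.relu overrides act_layer, matching the PyTorch
# 1.10 Transformer convention of passing an activation *function*.
def _mlp_act_sketch():
    import torch
    import torch.nn.functional as F
    mlp_layer = Mlp(16, hidden_features=32, act_layer=torch.nn.ReLU)
    mlp_fn = Mlp(16, hidden_features=32, act_fn=F.relu)   # act_fn takes precedence
    x = torch.randn(4, 16)
    assert mlp_layer(x).shape == mlp_fn(x).shape == (4, 16)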
class MlpBig(nn.Module):
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
act_fn=None, drop=0., device=None, dtype=None):
"""Copied from Mlp above. If num_layers > 2, add more Mlp layers, doubling each time.
"""
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
cur_hidden_features = hidden_features
layers = []
for _ in range(4):
layers.append(nn.Linear(in_features, cur_hidden_features, **factory_kwargs))
layers.append(act_layer())
layers.append(nn.Dropout(drop))
in_features = cur_hidden_features
cur_hidden_features *= 2
layers.append(nn.Linear(in_features, out_features, **factory_kwargs))
layers.append(nn.Dropout(drop))
self.fwd = nn.Sequential(*layers)
def forward(self, x):
return self.fwd(x)
class GluMlp(nn.Module):
""" MLP w/ GLU style gating
See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
assert hidden_features % 2 == 0
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features // 2, out_features)
self.drop = nn.Dropout(drop)
def init_weights(self):
# override init of fc1 w/ gate portion set to weight near zero, bias=1
fc1_mid = self.fc1.bias.shape[0] // 2
nn.init.ones_(self.fc1.bias[fc1_mid:])
nn.init.normal_(self.fc1.weight[fc1_mid:], std=1e-6)
def forward(self, x):
x = self.fc1(x)
x, gates = x.chunk(2, dim=-1)
x = x * self.act(gates)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
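# A minimal sketch (illustrative only, not part of the original file): fc1 produces
# hidden_features channels that are split in half, one half gating the other, so hidden_features
# must be even.
def _glu_mlp_sketch():
    import torch
    glu = GluMlp(16, hidden_features=32)   # sigmoid gating by default
    glu.init_weights()                     # gate half of fc1: near-zero weights, bias = 1
    assert glu(torch.randn(4, 16)).shape == (4, 16)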
class GatedMlp(nn.Module):
""" MLP as used in gMLP
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
gate_layer=None, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
if gate_layer is not None:
assert hidden_features % 2 == 0
self.gate = gate_layer(hidden_features)
hidden_features = hidden_features // 2 # FIXME base reduction on gate property?
else:
self.gate = nn.Identity()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.gate(x)
x = self.fc2(x)
x = self.drop(x)
return x
class ConvMlp(nn.Module):
""" MLP using 1x1 convs that keeps spatial dims
"""
def __init__(
self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, norm_layer=None, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=True)
self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity()
self.act = act_layer()
self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=True)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.norm(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
return x
import inspect
import torch.nn as nn
import hydra
try:
from apex.contrib.layer_norm import FastLayerNorm
except ImportError:
FastLayerNorm = None
from src.models.modules.seq_common import PositionalEncoding
def group_parameters_for_optimizer(model, optimizer_cfg, bias_weight_decay=False,
normalization_weight_decay=False):
"""Set weight_decay=0.0 for parameters in model.no_weight_decay, for parameters with
attribute _no_weight_decay==True, for bias parameters if bias_weight_decay==False, for
normalization parameters if normalization_weight_decay==False
"""
# Get the weight decay from the config, or from the default value of the optimizer constructor
# if it's not specified in the config.
if 'weight_decay' in optimizer_cfg:
weight_decay = optimizer_cfg.weight_decay
else:
# https://stackoverflow.com/questions/12627118/get-a-function-arguments-default-value
signature = inspect.signature(hydra.utils.get_class(optimizer_cfg._target_))
if 'weight_decay' in signature.parameters:
weight_decay = signature.parameters['weight_decay'].default
if weight_decay is inspect.Parameter.empty:
weight_decay = 0.0
else:
weight_decay = 0.0
# If no parameter would get weight decay anyway and no parameter carries special
# optimization hyperparameters, just return model.parameters().
if weight_decay == 0.0 and not any(hasattr(p, '_optim') for p in model.parameters()):
return model.parameters()
skip = model.no_weight_decay() if hasattr(model, 'no_weight_decay') else set()
skip_keywords = (model.no_weight_decay_keywords() if hasattr(model, 'no_weight_decay_keywords')
else set())
# Adapted from https://github.com/karpathy/minGPT/blob/master/mingpt/model.py#L134
"""
This long function is unfortunately doing something very simple and is being very defensive:
We are separating out all parameters of the model into two buckets: those that will experience
weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
We are then returning the PyTorch optimizer object.
"""
# separate out all parameters to those that will and won't experience regularizing weight decay
decay = set()
no_decay = set()
special = set()
whitelist_weight_modules = (nn.Linear, )
blacklist_weight_modules = (nn.Embedding, PositionalEncoding)
if not normalization_weight_decay:
blacklist_weight_modules += (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
nn.LazyBatchNorm1d, nn.LazyBatchNorm2d, nn.LazyBatchNorm3d,
nn.GroupNorm, nn.SyncBatchNorm,
nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d,
nn.LayerNorm, nn.LocalResponseNorm)
if FastLayerNorm is not None:
blacklist_weight_modules += (FastLayerNorm,)
param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad}
for mn, m in model.named_modules():
for pn, p in m.named_parameters():
fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
# In case of parameter sharing, some parameters show up here but are not in
# param_dict.keys()
if not p.requires_grad or fpn not in param_dict:
continue # frozen weights
if hasattr(p, '_optim'):
special.add(fpn)
elif fpn in skip or any(skip_keyword in fpn for skip_keyword in skip_keywords):
no_decay.add(fpn)
elif getattr(p, '_no_weight_decay', False):
no_decay.add(fpn)
elif not bias_weight_decay and pn.endswith('bias'):
no_decay.add(fpn)
elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
# weights of whitelist modules will be weight decayed
decay.add(fpn)
elif isinstance(m, blacklist_weight_modules):
# weights of blacklist modules will NOT be weight decayed
no_decay.add(fpn)
decay |= (param_dict.keys() - no_decay - special)
# validate that we considered every parameter
inter_params = decay & no_decay
union_params = decay | no_decay
assert len(inter_params) == 0, f"Parameters {str(inter_params)} made it into both decay/no_decay sets!"
assert len(param_dict.keys() - special - union_params) == 0, f"parameters {str(param_dict.keys() - union_params)} were not separated into either decay/no_decay set!"
if weight_decay == 0.0 or not no_decay:
param_groups = [{"params": [param_dict[pn] for pn in sorted(list(no_decay | decay))],
"weight_decay": weight_decay}]
else:
# We need sorted(list()) so that the order is deterministic. Otherwise when we resume
# the order could change and resume will fail. [H/t Albert]
param_groups = [
{"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay},
{"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
]
# Add parameters with special hyperparameters
# Unique dicts
hps = [dict(s) for s in set(frozenset(param_dict[pn]._optim.items()) for pn in special)]
for hp in hps:
params = [param_dict[pn] for pn in sorted(list(special)) if param_dict[pn]._optim == hp]
param_groups.append({"params": params, **hp})
return param_groups
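# A minimal usage sketch (illustrative only, not part of the original file): build parameter
# groups for a toy model from a hand-written optimizer config and instantiate the optimizer.
# In the training code, cfg.train.optimizer would normally supply this config.
def _param_grouping_sketch():
    from omegaconf import OmegaConf
    model = nn.Sequential(nn.Linear(16, 16), nn.LayerNorm(16))
    optimizer_cfg = OmegaConf.create({'_target_': 'torch.optim.AdamW',
                                      'lr': 1e-3, 'weight_decay': 0.1})
    param_groups = group_parameters_for_optimizer(model, optimizer_cfg)
    optimizer = hydra.utils.instantiate(optimizer_cfg, param_groups)
    # Linear weights get weight decay; biases and LayerNorm parameters do not.
    assert any(g['weight_decay'] == 0.0 for g in optimizer.param_groups)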
import torch
from torch.optim import Optimizer
from timm.scheduler import CosineLRScheduler
# We need to subclass torch.optim.lr_scheduler._LRScheduler, or Pytorch-lightning will complain
class TimmCosineLRScheduler(CosineLRScheduler, torch.optim.lr_scheduler._LRScheduler):
""" Wrap timm.scheduler.CosineLRScheduler so we can call scheduler.step() without passing in epoch.
It supports resuming as well.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._last_epoch = -1
self.step(epoch=0)
def step(self, epoch=None):
if epoch is None:
self._last_epoch += 1
else:
self._last_epoch = epoch
# We call either step or step_update, depending on whether we're using the scheduler every
# epoch or every step.
# Otherwise, lightning will always call step (i.e., meant for each epoch), and if we set
# scheduler interval to "step", then the learning rate update will be wrong.
if self.t_in_epochs:
super().step(epoch=self._last_epoch)
else:
super().step_update(num_updates=self._last_epoch)
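# A minimal usage sketch (illustrative only, not part of the original file), assuming a recent
# timm: with t_in_epochs=False the wrapper routes step() to step_update(), so Lightning can call
# scheduler.step() every training step with interval='step'.
def _timm_scheduler_sketch():
    model = torch.nn.Linear(8, 8)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = TimmCosineLRScheduler(optimizer, t_initial=1000, lr_min=1e-5,
                                      warmup_t=100, warmup_lr_init=1e-6, t_in_epochs=False)
    for _ in range(5):
        optimizer.step()
        scheduler.step()   # no epoch argument needed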
from typing import Any, List
import inspect
import torch
import hydra
from pytorch_lightning import LightningModule, LightningDataModule
from torchmetrics import MetricCollection
from einops import rearrange
from omegaconf import OmegaConf
from src.utils.utils import get_logger
from src.optim.param_grouping import group_parameters_for_optimizer
from src.utils.checkpoint import load_checkpoint
logger = get_logger(__name__)
class SequenceModel(LightningModule):
def __init__(self, cfg, model_cfg=None):
"""If model_cfg is passed, it will take precedence over cfg.model
"""
super().__init__()
# this line ensures params passed to LightningModule will be saved to ckpt
# it also allows to access params with 'self.hparams' attribute
self.save_hyperparameters(cfg)
self.cfg = cfg
self.model_cfg = model_cfg or self.cfg.model
self.instantiate_datamodule()
self.instantiate_model()
self.warmstart()
self.instantiate_loss()
self.instantiate_metrics()
def instantiate_datamodule(self):
logger.info(f"Instantiating datamodule <{self.cfg.datamodule._target_}>")
# Naming this attribute self.datamodule would clash with PL, which assigns self.datamodule itself
self._datamodule: LightningDataModule = hydra.utils.instantiate(self.cfg.datamodule)
self._datamodule.prepare_data()
self._datamodule.setup()
OmegaConf.clear_resolver('datamodule')
OmegaConf.register_new_resolver('datamodule', lambda attr: getattr(self._datamodule, attr))
def instantiate_model(self):
# if hasattr(self._datamodule, 'num_classes'):
# self.model_cfg.num_classes = self._datamodule.num_classes
# if (hasattr(self._datamodule, 'vocab_size')
# and self.model_cfg.get('embedding_cfg', None) is not None
# and self.model_cfg.embedding_cfg._target_ == "torch.nn.Embedding"):
# self.model_cfg.embedding_cfg.num_embeddings = self._datamodule.vocab_size
logger.info(f"Instantiating model <{self.model_cfg._target_}>")
recursive = getattr(self.model_cfg, '_recursive_', False)
self.model = hydra.utils.instantiate(self.model_cfg, _recursive_=recursive)
def instantiate_loss(self):
loss_fn_cfg = self.cfg.train.get('loss_fn')
if loss_fn_cfg is None:
loss_fn_cfg = {'_target_': 'torch.nn.CrossEntropyLoss'}
self.loss_fn = hydra.utils.instantiate(loss_fn_cfg)
loss_fn_val_cfg = self.cfg.train.get('loss_fn_val', loss_fn_cfg)
self.loss_fn_val = hydra.utils.instantiate(loss_fn_val_cfg)
def instantiate_metrics(self):
# use separate metric instance for train, val and test step
# to ensure a proper reduction over the epoch
if 'eval' in self.cfg and 'metrics' in self.cfg.eval:
metrics_cfg = self.cfg.eval.metrics
else:
metrics_cfg = {'acc': {'_target_': 'torchmetrics.Accuracy'}}
metrics = MetricCollection({name: hydra.utils.instantiate(cfg)
for name, cfg in metrics_cfg.items()})
self.train_metrics = metrics.clone(prefix='train/')
self.val_metrics = metrics.clone(prefix='val/')
self.test_metrics = metrics.clone(prefix='test/')
def warmstart(self):
if self.cfg.train.get('warmstart', None) is not None:
logger.info(f"Warm-starting with weights from {self.cfg.train.warmstart.path}")
strict = self.cfg.train.warmstart.get('strict', True)
state_dict = load_checkpoint(self.cfg.train.warmstart.path)
if self.cfg.train.warmstart.get('post_process', None) is not None:
state_dict = hydra.utils.instantiate(self.cfg.train.warmstart.post_process,
state_dict)
load_return = self.model.load_state_dict(state_dict, strict=strict)
logger.info(load_return)
def forward(self, *args, **kwargs):
return self.model(*args, **kwargs)
def step(self, batch: Any, is_train=True):
try:
x, y, lengths = batch
except ValueError:
x, y = batch
lengths = None
output = self.forward(x) if lengths is None else self.forward(x, lengths=lengths)
loss = self.loss_fn(output, y) if is_train else self.loss_fn_val(output, y)
return loss, output, y
def shared_step(self, batch: Any, batch_idx: int, phase='train'):
loss, output, targets = self.step(batch, is_train=(phase == 'train'))
metrics = getattr(self, f'{phase}_metrics')
metrics(output, targets)
log_on_step = 'eval' in self.cfg and self.cfg.eval.get('log_on_step', False) and phase == 'train'
self.log(f"{phase}/loss", loss, on_step=log_on_step, on_epoch=True,
prog_bar=False, sync_dist=True)
# https://pytorch-lightning.readthedocs.io/en/stable/visualize/logging_advanced.html#enable-metrics-for-distributed-training
# We need to log the Metrics object, not the metric result, since otherwise
# pytorch-lightning will use torch.mean to reduce it.
# This would be wrong for perplexity, for example.
self.log_dict(metrics, on_step=log_on_step, on_epoch=True, prog_bar=True, sync_dist=True)
return {"loss": loss, "output": output, "targets": targets}
def training_step(self, batch: Any, batch_idx: int):
return self.shared_step(batch, batch_idx, phase='train')
def validation_step(self, batch: Any, batch_idx: int):
return self.shared_step(batch, batch_idx, phase='val')
def test_step(self, batch: Any, batch_idx: int):
return self.shared_step(batch, batch_idx, phase='test')
def configure_optimizers(self):
if 'optimizer_param_grouping' in self.cfg.train: # Set zero weight decay for some params
parameters = group_parameters_for_optimizer(self.model, self.cfg.train.optimizer,
**self.cfg.train.optimizer_param_grouping)
else:
# parameters = self.model.parameters()
parameters = self.parameters() # [21-09-08] AG: this will train task specific parameters such as Retrieval head for AAN
optimizer = hydra.utils.instantiate(self.cfg.train.optimizer, parameters)
# Log optimizer info
for i, g in enumerate(optimizer.param_groups):
ntensors = len(g['params'])
nparams = sum(p.numel() for p in g['params'])
hparams = {k: v for k, v in g.items() if k != 'params'}
logger.info(f'Optimizer group {i}: {ntensors} tensors, {nparams} parameters, {hparams}')
if 'scheduler' not in self.cfg.train:
return optimizer
else:
# lr_scheduler should be called either every step (default) or every epoch
lr_scheduler = hydra.utils.instantiate(self.cfg.train.scheduler, optimizer)
return [optimizer], {'scheduler': lr_scheduler,
'interval': self.cfg.train.get('scheduler_interval', 'step'),
'monitor': self.cfg.train.get('scheduler_monitor', 'val/loss')}
def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx):
# https://pytorch-lightning.readthedocs.io/en/latest/guides/speed.html#set-grads-to-none
# TD [2022-04-30]: DeepSpeed optimizer uses the kwarg set_grad_to_none instead of set_to_none
if 'set_to_none' in inspect.signature(optimizer.zero_grad).parameters:
optimizer.zero_grad(set_to_none=True)
else:
optimizer.zero_grad()
def on_save_checkpoint(self, checkpoint):
# TD [2022-08-07] ['epoch_loop.batch_progress']['total']['completed'] is 1 iteration
# behind, so we're using the optimizer's progress.
checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['total']['completed'] = checkpoint['loops']['fit_loop']['epoch_loop.batch_loop.optimizer_loop.optim_progress']['optimizer']['step']['total']['completed'] * self.trainer.accumulate_grad_batches
checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['completed'] = checkpoint['loops']['fit_loop']['epoch_loop.batch_loop.optimizer_loop.optim_progress']['optimizer']['step']['current']['completed'] * self.trainer.accumulate_grad_batches
# _batches_that_stepped tracks the number of global steps, not the number
# of local steps, so we don't multiply with self.trainer.accumulate_grad_batches here.
checkpoint['loops']['fit_loop']['epoch_loop.state_dict']['_batches_that_stepped'] = checkpoint['loops']['fit_loop']['epoch_loop.batch_loop.optimizer_loop.optim_progress']['optimizer']['step']['total']['completed']
class SequenceLMModel(SequenceModel):
def step(self, batch: Any, is_train=True):
x, y = batch
output = self.forward(x).logits
output = rearrange(output, '... C -> (...) C')
y = rearrange(y, '... -> (...)')
loss = self.loss_fn(output, y) if is_train else self.loss_fn_val(output, y)
return loss, output, y
def shared_step(self, batch: Any, batch_idx: int, phase='train'):
loss, output, targets = self.step(batch, is_train=(phase == 'train'))
# Passing the loss to the perplexity metrics to avoid recomputation
metrics = getattr(self, f'{phase}_metrics')
metrics(output, targets, loss=loss)
log_on_step = 'eval' in self.cfg and self.cfg.eval.get('log_on_step', False) and phase == 'train'
self.log(f"{phase}/loss", loss, on_step=log_on_step, on_epoch=True,
prog_bar=False, sync_dist=True)
# https://pytorch-lightning.readthedocs.io/en/stable/visualize/logging_advanced.html#enable-metrics-for-distributed-training
# We need to log the Metrics object, not the metric result, since otherwise
# pytorch-lightning will use torch.mean to reduce it.
# This would be wrong for perplexity, for example.
self.log_dict(metrics, on_step=log_on_step, on_epoch=True, prog_bar=True, sync_dist=True)
return {"loss": loss, "output": output, "targets": targets}
from typing import List, Optional, Sequence
from pathlib import Path
import hydra
from omegaconf import OmegaConf, DictConfig
from pytorch_lightning import (
Callback,
LightningDataModule,
LightningModule,
Trainer,
seed_everything,
)
from pytorch_lightning.loggers import LightningLoggerBase
from src.utils import utils
log = utils.get_logger(__name__)
def last_modification_time(path):
"""Including files / directory 1-level below the path
"""
path = Path(path)
if path.is_file():
return path.stat().st_mtime
elif path.is_dir():
return max(child.stat().st_mtime for child in path.iterdir())
else:
return None
def train(config: DictConfig) -> Optional[float]:
"""Contains training pipeline.
Instantiates all PyTorch Lightning objects from config.
Args:
config (DictConfig): Configuration composed by Hydra.
Returns:
Optional[float]: Metric score for hyperparameter optimization.
"""
# Set seed for random number generators in pytorch, numpy and python.random
if config.get("seed"):
seed_everything(config.seed, workers=True)
# We want to add fields to config so need to call OmegaConf.set_struct
OmegaConf.set_struct(config, False)
# Init lightning model
model: LightningModule = hydra.utils.instantiate(config.task, cfg=config, _recursive_=False)
datamodule: LightningDataModule = model._datamodule
# Init lightning callbacks
callbacks: List[Callback] = []
if "callbacks" in config:
for _, cb_conf in config.callbacks.items():
if cb_conf is not None and "_target_" in cb_conf:
log.info(f"Instantiating callback <{cb_conf._target_}>")
callbacks.append(hydra.utils.instantiate(cb_conf))
# Init lightning loggers
logger: List[LightningLoggerBase] = []
if "logger" in config:
for _, lg_conf in config.logger.items():
if lg_conf is not None and "_target_" in lg_conf:
log.info(f"Instantiating logger <{lg_conf._target_}>")
logger.append(hydra.utils.instantiate(lg_conf))
ckpt_cfg = {}
if config.get('resume'):
try:
checkpoint_path = Path(config.callbacks.model_checkpoint.dirpath)
if checkpoint_path.is_dir():
last_ckpt = checkpoint_path / 'last.ckpt'
autosave_ckpt = checkpoint_path / '.pl_auto_save.ckpt'
if not (last_ckpt.exists() or autosave_ckpt.exists()):
raise FileNotFoundError("Resume requires either last.ckpt or .pl_autosave.ckpt")
if ((not last_ckpt.exists())
or (autosave_ckpt.exists()
and last_modification_time(autosave_ckpt) > last_modification_time(last_ckpt))):
# autosave_ckpt = autosave_ckpt.replace(autosave_ckpt.with_name('.pl_auto_save_loaded.ckpt'))
checkpoint_path = autosave_ckpt
else:
checkpoint_path = last_ckpt
# DeepSpeed's checkpoint is a directory, not a file
if checkpoint_path.is_file() or checkpoint_path.is_dir():
ckpt_cfg = {'ckpt_path': str(checkpoint_path)}
else:
log.info(f'Checkpoint file {str(checkpoint_path)} not found. Will start training from scratch')
except (KeyError, FileNotFoundError):
pass
# Configure ddp automatically
n_devices = config.trainer.get('devices', 1)
if isinstance(n_devices, Sequence): # trainer.devices could be [1, 3] for example
n_devices = len(n_devices)
if n_devices > 1 and config.trainer.get('strategy', None) is None:
config.trainer.strategy = dict(
_target_='pytorch_lightning.strategies.DDPStrategy',
find_unused_parameters=False,
gradient_as_bucket_view=True, # https://pytorch-lightning.readthedocs.io/en/stable/advanced/advanced_gpu.html#ddp-optimizations
)
# Init lightning trainer
log.info(f"Instantiating trainer <{config.trainer._target_}>")
trainer: Trainer = hydra.utils.instantiate(
config.trainer, callbacks=callbacks, logger=logger)
# Train the model
log.info("Starting training!")
trainer.fit(model=model, datamodule=datamodule, **ckpt_cfg)
# Evaluate model on test set, using the best model achieved during training
if config.get("test_after_training") and not config.trainer.get("fast_dev_run"):
log.info("Starting testing!")
trainer.test(model=model, datamodule=datamodule)
# Make sure everything closed properly
log.info("Finalizing!")
utils.finish(
config=config,
model=model,
datamodule=datamodule,
trainer=trainer,
callbacks=callbacks,
logger=logger,
)
# Print path to best checkpoint
if not config.trainer.get("fast_dev_run"):
log.info(f"Best model ckpt: {trainer.checkpoint_callback.best_model_path}")
# Return metric score for hyperparameter optimization
optimized_metric = config.get("optimized_metric")
if optimized_metric:
return trainer.callback_metrics[optimized_metric]
import re
from pathlib import Path
import torch
import math
from einops import rearrange
def load_checkpoint(path, device='cpu'):
path = Path(path).expanduser()
is_deepspeed = False
if path.is_dir(): # DeepSpeed checkpoint
is_deepspeed = True
latest_path = path / 'latest'
if latest_path.is_file():
with open(latest_path, 'r') as fd:
tag = fd.read().strip()
else:
raise ValueError(f"Unable to find 'latest' file at {latest_path}")
path /= f'{tag}/mp_rank_00_model_states.pt'
state_dict = torch.load(path, map_location=device)
if is_deepspeed:
state_dict = state_dict['module']
# Replace the names of some of the submodules
def key_mapping(key):
return re.sub(r'^module.model.', '', key)
state_dict = {key_mapping(k): v for k, v in state_dict.items()}
return state_dict
def blockdiag_to_dense_mlp_bert(state_dict):
from src.ops.blockdiag_multiply import blockdiag_weight_to_dense_weight
names = {name for name in state_dict
if re.match(r'bert.encoder.layer.(\d+).(mlp.fc(1|2)|(intermediate|output).dense).weight',
name)}
for name in names:
state_dict[name] = blockdiag_weight_to_dense_weight(state_dict[name])
return state_dict
def interpolate_pos_embedding(state_dict, out_seqlen, pos_embedding_name='model.pos_encoder.pe', interleave=False):
orig_emb = state_dict['state_dict'][pos_embedding_name]
assert (out_seqlen % orig_emb.shape[1]) == 0, 'out_seqlen must be a multiple of the original sequence length'
reps = [1 for i in orig_emb.shape]
reps[1] = out_seqlen // orig_emb.shape[1]
if interleave:
assert math.isqrt(orig_emb.shape[1]) ** 2 == orig_emb.shape[1], 'interleave only works for square lengths'
assert math.isqrt(out_seqlen) ** 2 == out_seqlen, 'interleave only works for square lengths'
assert math.isqrt(reps[1]) ** 2 == reps[1], 'out_seqlen / seqlen must be a perfect square'
emb_square = rearrange(orig_emb, 'b (h w) d -> b h w d', h = math.isqrt(orig_emb.shape[1]))
emb_square_expanded = emb_square.repeat_interleave(math.isqrt(reps[1]), axis=1).repeat_interleave(math.isqrt(reps[1]), axis=2)
new_emb = rearrange(emb_square_expanded, 'b h w d -> b (h w) d')
state_dict['state_dict'][pos_embedding_name] = new_emb
else:
state_dict['state_dict'][pos_embedding_name] = orig_emb.repeat(*reps)
ret = remove_model_prefix(state_dict)
# HACK: this is a hack for block-sparse flash attention: drop the inner_attn.layout buffers
ret = {
k: v
for k, v in ret.items()
if not k.endswith('inner_attn.layout')
}
return ret
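# A minimal shape sketch (illustrative only, not part of the original file) of the interleaved
# interpolation: a (1, 64, d) positional embedding laid out on an 8x8 grid is expanded to 16x16
# by repeating every grid position 2x along each axis, giving a (1, 256, d) embedding.
def _interpolate_pos_embedding_sketch():
    import torch
    ckpt = {'state_dict': {'model.pos_encoder.pe': torch.randn(1, 64, 16)}}
    new_sd = interpolate_pos_embedding(ckpt, out_seqlen=256, interleave=True)
    assert new_sd['pos_encoder.pe'].shape == (1, 256, 16)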
def remove_model_prefix(state_dict):
# HACK: this is a hack to get the model to load properly, get rid of 'model.' prefix
for key in list(state_dict['state_dict'].keys()):
if key.startswith('model.'):
new_key = key[len('model.'):]
state_dict['state_dict'][new_key] = state_dict['state_dict'].pop(key)
# HACK: something is wrong with the state dict being loaded...
return state_dict['state_dict']
# Meant to work with Pytorch's ZeroRedundancyOptimizer
from typing import Any, Callable, Dict, List, Optional, Union
from pathlib import Path
import torch
from torch.optim.optimizer import Optimizer
from torch.distributed.optim import ZeroRedundancyOptimizer
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.utilities.types import _PATH
# from lightning_lite.utilities.types import _PATH
# Copied from Pytorch's ZeroRedundancyOptimizer's state_dict method, but we only get
# the local state dict to avoid synchronization across GPUs.
# https://github.com/pytorch/pytorch/blob/0c7ca2d97ba5980a2af7dcd6b8106dc915e591cd/torch/distributed/optim/zero_redundancy_optimizer.py#L1131
def get_zero_optimizer_state_dict_local(optimizer, global_rank):
optimizer._check_overlap_initialized()
# Sync the exposed `param_groups` attributes to the local optimizer in
# case they have been updated
optimizer._sync_param_groups(optimizer.param_groups, optimizer.optim.param_groups)
local_state_dict = optimizer.optim.state_dict()
state_dict = super(ZeroRedundancyOptimizer, optimizer).state_dict()
# Update the global optimizer state with local state information,
# factoring in the translation from local to global indexing
rank = global_rank
# TODO: recursive copy to device
local_param_groups = local_state_dict["param_groups"]
global_param_groups = optimizer._partition_parameters()[rank]
assert len(local_param_groups) == len(global_param_groups), \
"Mismatch between number of local and global parameter groups"
for local_param_group, global_param_group in zip(local_param_groups, global_param_groups):
# `local_param_group` stores local indices, while
# `global_param_group` stores the tensors directly
local_param_indices = local_param_group["params"]
global_params = global_param_group["params"]
assert len(local_param_indices) == len(global_params), \
"Mismatch between number of local and global parameters in parameter group"
for local_param_index, global_param in zip(local_param_indices, global_params):
# Update the global parameter state, if any
if local_param_index in local_state_dict["state"]:
global_param_index = optimizer._param_to_index[global_param]
state_dict["state"][global_param_index] = local_state_dict["state"][local_param_index]
# Sort the parameters in the state
state_dict["state"] = dict(sorted(state_dict["state"].items()))
return state_dict
class DDPStrategyZero1(DDPStrategy):
"""To use ZeroRedundancyOptimizer, we need to shard the optimizer states when
saving/loading checkpoints.
"""
strategy_name = "ddp_zero1"
def optimizer_state(self, optimizer: Optimizer) -> Optional[dict]:
if isinstance(optimizer, LightningOptimizer):
optimizer = optimizer._optimizer
if isinstance(optimizer, ZeroRedundancyOptimizer):
return get_zero_optimizer_state_dict_local(optimizer, self.global_rank)
else:
return optimizer.state_dict()
def save_checkpoint(
self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
) -> None:
"""Save model/training states as a checkpoint file through state-dump and file-write.
Args:
checkpoint: dict containing model and trainer state
filepath: write-target file's path
storage_options: parameter for how to save to storage, passed to ``CheckpointIO`` plugin
"""
filepath = Path(filepath)
filepath.mkdir(parents=True, exist_ok=True)
local_optimizer_states = checkpoint.pop('optimizer_states')
if self.is_global_zero:
self.checkpoint_io.save_checkpoint(checkpoint, filepath / 'model_states.pt',
storage_options=storage_options)
self.checkpoint_io.save_checkpoint(local_optimizer_states,
filepath / f'{self.global_rank:03d}_optim_states.pt',
storage_options=storage_options)
def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
torch.cuda.empty_cache()
checkpoint_path = Path(checkpoint_path)
if checkpoint_path.is_file():
return super().load_checkpoint(str(checkpoint_path))
else:
assert checkpoint_path.is_dir()
global_states = self.checkpoint_io.load_checkpoint(checkpoint_path / 'model_states.pt')
local_optimizer_states = self.checkpoint_io.load_checkpoint(checkpoint_path / f'{self.global_rank:03d}_optim_states.pt')
global_states['optimizer_states'] = local_optimizer_states
return global_states
# Meant to work with Apex's DistributeFusedAdam
from typing import Any, Callable, Dict, List, Optional, Union
from pathlib import Path
import types
import torch
from torch.optim.optimizer import Optimizer
from torch.optim import LBFGS
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam
import pytorch_lightning as pl
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.plugins.precision import PrecisionPlugin, NativeMixedPrecisionPlugin
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.utilities.types import _PATH
# from lightning_lite.utilities.types import _PATH
from pytorch_lightning.utilities.exceptions import MisconfigurationException
class DistAdamNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin):
def optimizer_step( # type: ignore[override]
self,
model: "pl.LightningModule",
optimizer,
optimizer_idx: int,
closure: Callable[[], Any],
**kwargs: Any,
) -> Any:
if self.scaler is None:
# skip scaler logic, as bfloat16 does not require scaler
return NativeMixedPrecisionPlugin.optimizer_step(
self, optimizer, model=model, optimizer_idx=optimizer_idx, closure=closure, **kwargs
)
if isinstance(optimizer, LBFGS):
raise MisconfigurationException(
f"Native AMP and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})."
)
closure_result = closure()
# HACK: we don't call self.scaler.unscale_ here. This is because DistributedFusedAdam
# optimizer internally takes the scale into account.
# If we call unscale_ here, it would be equivalent to unscaling the gradients twice.
# Not unscaling has the side-effect that the NormMonitor callback will report the
# gradient norm to be much larger than reality.
# # `unscale` after the closure is executed but before the `on_before_optimizer_step` hook.
# self.scaler.unscale_(optimizer)
# This will call gradient clipping
self._after_closure(model, optimizer, optimizer_idx)
skipped_backward = closure_result is None
# in manual optimization, the closure does not return a value
if not model.automatic_optimization or not skipped_backward:
# note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
step_output = self.scaler.step(optimizer, **kwargs)
self.scaler.update()
return step_output
return closure_result
def clip_grad_by_norm(self, optimizer: DistributedFusedAdam, clip_val: Union[int, float]) -> None:
"""Clip gradients by norm."""
# DistributedFusedAdam wants list, not generator
# Gradients have not been unscaled, so we need to scale up the clip_val accordingly
if self.scaler is not None:
clip_val *= self.scaler.get_scale()
return optimizer.clip_grad_norm(clip_val)
class DDPStrategyZero2(DDPStrategy):
"""To use Apex's DistributedFusedAdam, we need to shard the optimizer states when
saving/loading checkpoints.
"""
strategy_name = "ddp_zero2"
def __init__(
self,
*args,
precision_plugin: Optional[PrecisionPlugin] = DistAdamNativeMixedPrecisionPlugin,
# precision_plugin: Optional[PrecisionPlugin] = None,
**kwargs: Union[Any, Dict[str, Any]],
) -> None:
super().__init__(
*args, precision_plugin=precision_plugin, **kwargs
)
@property
def precision_plugin(self) -> PrecisionPlugin:
return self._precision_plugin if self._precision_plugin is not None else PrecisionPlugin()
@precision_plugin.setter
def precision_plugin(self, precision_plugin: Optional[PrecisionPlugin]) -> None:
self._precision_plugin = precision_plugin
# https://stackoverflow.com/questions/972/adding-a-method-to-an-existing-object-instance
self._precision_plugin.optimizer_step = types.MethodType(
DistAdamNativeMixedPrecisionPlugin.optimizer_step, self._precision_plugin
)
self._precision_plugin.clip_grad_by_norm = types.MethodType(
DistAdamNativeMixedPrecisionPlugin.clip_grad_by_norm, self._precision_plugin
)
def optimizer_state(self, optimizer: Optimizer) -> Optional[dict]:
if isinstance(optimizer, LightningOptimizer):
optimizer = optimizer._optimizer
if isinstance(optimizer, DistributedFusedAdam):
return optimizer.state_dict(gather_on_root=False)
else:
return optimizer.state_dict()
def save_checkpoint(
self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
) -> None:
"""Save model/training states as a checkpoint file through state-dump and file-write.
Args:
checkpoint: dict containing model and trainer state
filepath: write-target file's path
storage_options: parameter for how to save to storage, passed to ``CheckpointIO`` plugin
"""
filepath = Path(filepath)
filepath.mkdir(parents=True, exist_ok=True)
local_optimizer_states = checkpoint.pop('optimizer_states')
if self.is_global_zero:
self.checkpoint_io.save_checkpoint(checkpoint, filepath / 'model_states.pt',
storage_options=storage_options)
self.checkpoint_io.save_checkpoint(local_optimizer_states,
filepath / f'{self.global_rank:03d}_optim_states.pt',
storage_options=storage_options)
def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
torch.cuda.empty_cache()
checkpoint_path = Path(checkpoint_path)
if checkpoint_path.is_file():
return super().load_checkpoint(str(checkpoint_path))
else:
assert checkpoint_path.is_dir()
global_states = self.checkpoint_io.load_checkpoint(checkpoint_path / 'model_states.pt')
local_optimizer_states = self.checkpoint_io.load_checkpoint(
checkpoint_path / f'{self.global_rank:03d}_optim_states.pt',
map_location='cuda'
)
global_states['optimizer_states'] = local_optimizer_states
return global_states
# Copied from https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/utils/distributed.py
# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from contextlib import contextmanager
import torch
def init_distributed(cuda):
"""
Initializes distributed backend.
:param cuda: (bool) if True initializes nccl backend, if False initializes
gloo backend
"""
world_size = int(os.environ.get('WORLD_SIZE', 1))
distributed = (world_size > 1)
if distributed:
backend = 'nccl' if cuda else 'gloo'
torch.distributed.init_process_group(backend=backend,
init_method='env://')
assert torch.distributed.is_initialized()
return distributed
def barrier():
"""
Call torch.distributed.barrier() if distributed is in use.
"""
if torch.distributed.is_available() and torch.distributed.is_initialized():
torch.distributed.barrier()
def get_rank():
"""
Gets distributed rank or returns zero if distributed is not initialized.
"""
if torch.distributed.is_available() and torch.distributed.is_initialized():
rank = torch.distributed.get_rank()
else:
rank = 0
return rank
def get_world_size():
"""
Gets total number of distributed workers or returns one if distributed is
not initialized.
"""
if torch.distributed.is_available() and torch.distributed.is_initialized():
world_size = torch.distributed.get_world_size()
else:
world_size = 1
return world_size
def all_reduce_item(value, op='sum'):
"""
All-reduces single scalar value if distributed is in use
"""
if torch.distributed.is_available() and torch.distributed.is_initialized():
if op == 'sum' or op == 'mean':
dop = torch.distributed.ReduceOp.SUM
elif op == 'min':
dop = torch.distributed.ReduceOp.MIN
elif op == 'max':
dop = torch.distributed.ReduceOp.MAX
elif op == 'product':
dop = torch.distributed.ReduceOp.PRODUCT
else:
raise RuntimeError('Unsupported reduce op')
backend = torch.distributed.get_backend()
if backend == torch.distributed.Backend.NCCL:
device = torch.device('cuda')
elif backend == torch.distributed.Backend.GLOO:
device = torch.device('cpu')
else:
raise RuntimeError('Unsupported distributed backend')
tensor = torch.tensor(value, device=device)
torch.distributed.all_reduce(tensor, dop)
if op == 'mean':
tensor /= get_world_size()
ret = tensor.item()
else:
ret = value
return ret
@contextmanager
def sync_workers():
"""
Yields distributed rank and synchronizes all workers on exit.
"""
rank = get_rank()
yield rank
barrier()
# Copied from https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py
from __future__ import division
from __future__ import unicode_literals
from typing import Iterable, Optional
import weakref
import copy
import contextlib
import torch
def to_float_maybe(x):
return x.float() if x.dtype in [torch.float16, torch.bfloat16] else x
# Partially based on:
# https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py
class ExponentialMovingAverage:
"""
Maintains (exponential) moving average of a set of parameters.
Args:
parameters: Iterable of `torch.nn.Parameter` (typically from
`model.parameters()`).
decay: The exponential decay.
use_num_updates: Whether to use number of updates when computing
averages.
"""
def __init__(
self,
parameters: Iterable[torch.nn.Parameter],
decay: float,
use_num_updates: bool = True
):
if decay < 0.0 or decay > 1.0:
raise ValueError('Decay must be between 0 and 1')
self.decay = decay
self.num_updates = 0 if use_num_updates else None
parameters = list(parameters)
self.shadow_params = [to_float_maybe(p.clone().detach())
for p in parameters if p.requires_grad]
self.collected_params = None
# By maintaining only a weakref to each parameter,
# we maintain the old GC behaviour of ExponentialMovingAverage:
# if the model goes out of scope but the ExponentialMovingAverage
# is kept, no references to the model or its parameters will be
# maintained, and the model will be cleaned up.
self._params_refs = [weakref.ref(p) for p in parameters]
def _get_parameters(
self,
parameters: Optional[Iterable[torch.nn.Parameter]]
) -> Iterable[torch.nn.Parameter]:
if parameters is None:
parameters = [p() for p in self._params_refs]
if any(p is None for p in parameters):
raise ValueError(
"(One of) the parameters with which this "
"ExponentialMovingAverage "
"was initialized no longer exists (was garbage collected);"
" please either provide `parameters` explicitly or keep "
"the model to which they belong from being garbage "
"collected."
)
return parameters
else:
parameters = list(parameters)
if len(parameters) != len(self.shadow_params):
raise ValueError(
"Number of parameters passed as argument is different "
"from number of shadow parameters maintained by this "
"ExponentialMovingAverage"
)
return parameters
def update(
self,
parameters: Optional[Iterable[torch.nn.Parameter]] = None
) -> None:
"""
Update currently maintained parameters.
        Call this every time the parameters are updated, for example right
        after each `optimizer.step()` call.
Args:
parameters: Iterable of `torch.nn.Parameter`; usually the same set of
parameters used to initialize this object. If `None`, the
parameters with which this `ExponentialMovingAverage` was
initialized will be used.
"""
parameters = self._get_parameters(parameters)
decay = self.decay
if self.num_updates is not None:
self.num_updates += 1
decay = min(
decay,
(1 + self.num_updates) / (10 + self.num_updates)
)
one_minus_decay = 1.0 - decay
if parameters[0].device != self.shadow_params[0].device:
self.to(device=parameters[0].device)
with torch.no_grad():
parameters = [p for p in parameters if p.requires_grad]
for s_param, param in zip(self.shadow_params, parameters):
torch.lerp(s_param, param.to(dtype=s_param.dtype), one_minus_decay, out=s_param)
def copy_to(
self,
parameters: Optional[Iterable[torch.nn.Parameter]] = None
) -> None:
"""
Copy current averaged parameters into given collection of parameters.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored moving averages. If `None`, the
parameters with which this `ExponentialMovingAverage` was
initialized will be used.
"""
parameters = self._get_parameters(parameters)
for s_param, param in zip(self.shadow_params, parameters):
if param.requires_grad:
param.data.copy_(s_param.data)
def store(
self,
parameters: Optional[Iterable[torch.nn.Parameter]] = None
) -> None:
"""
Save the current parameters for restoring later.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored. If `None`, the parameters with which this
`ExponentialMovingAverage` was initialized will be used.
"""
parameters = self._get_parameters(parameters)
self.collected_params = [
param.clone()
for param in parameters
if param.requires_grad
]
def restore(
self,
parameters: Optional[Iterable[torch.nn.Parameter]] = None
) -> None:
"""
Restore the parameters stored with the `store` method.
Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before calling
        `copy_to`. After validation (or model saving), use this to restore
        the former parameters.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored parameters. If `None`, the
parameters with which this `ExponentialMovingAverage` was
initialized will be used.
"""
if self.collected_params is None:
raise RuntimeError(
"This ExponentialMovingAverage has no `store()`ed weights "
"to `restore()`"
)
parameters = self._get_parameters(parameters)
for c_param, param in zip(self.collected_params, parameters):
if param.requires_grad:
param.data.copy_(c_param.data)
@contextlib.contextmanager
def average_parameters(
self,
parameters: Optional[Iterable[torch.nn.Parameter]] = None
):
r"""
Context manager for validation/inference with averaged parameters.
Equivalent to:
ema.store()
ema.copy_to()
try:
...
finally:
ema.restore()
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored parameters. If `None`, the
parameters with which this `ExponentialMovingAverage` was
initialized will be used.
"""
parameters = self._get_parameters(parameters)
self.store(parameters)
self.copy_to(parameters)
try:
yield
finally:
self.restore(parameters)
def to(self, device=None, dtype=None) -> None:
r"""Move internal buffers of the ExponentialMovingAverage to `device`.
Args:
            device: like `device` argument to `torch.Tensor.to`
            dtype: like `dtype` argument to `torch.Tensor.to`; only applied
                to floating-point buffers
"""
# .to() on the tensors handles None correctly
self.shadow_params = [
p.to(device=device, dtype=dtype)
if p.is_floating_point()
else p.to(device=device)
for p in self.shadow_params
]
if self.collected_params is not None:
self.collected_params = [
p.to(device=device, dtype=dtype)
if p.is_floating_point()
else p.to(device=device)
for p in self.collected_params
]
return
def state_dict(self) -> dict:
r"""Returns the state of the ExponentialMovingAverage as a dict."""
# Following PyTorch conventions, references to tensors are returned:
# "returns a reference to the state and not its copy!" -
# https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
return {
"decay": self.decay,
"num_updates": self.num_updates,
"shadow_params": self.shadow_params,
"collected_params": self.collected_params
}
def load_state_dict(self, state_dict: dict) -> None:
r"""Loads the ExponentialMovingAverage state.
Args:
state_dict (dict): EMA state. Should be an object returned
from a call to :meth:`state_dict`.
"""
# deepcopy, to be consistent with module API
state_dict = copy.deepcopy(state_dict)
self.decay = state_dict["decay"]
if self.decay < 0.0 or self.decay > 1.0:
raise ValueError('Decay must be between 0 and 1')
self.num_updates = state_dict["num_updates"]
assert self.num_updates is None or isinstance(self.num_updates, int), \
"Invalid num_updates"
self.shadow_params = state_dict["shadow_params"]
assert isinstance(self.shadow_params, list), \
"shadow_params must be a list"
assert all(
isinstance(p, torch.Tensor) for p in self.shadow_params
), "shadow_params must all be Tensors"
self.collected_params = state_dict["collected_params"]
if self.collected_params is not None:
assert isinstance(self.collected_params, list), \
"collected_params must be a list"
assert all(
isinstance(p, torch.Tensor) for p in self.collected_params
), "collected_params must all be Tensors"
assert len(self.collected_params) == len(self.shadow_params), \
"collected_params and shadow_params had different lengths"
if len(self.shadow_params) == len(self._params_refs):
            # Consistent with torch.optim.Optimizer, cast the buffers to the
            # same device and dtype as the parameters
params = [p() for p in self._params_refs]
# If parameters have been garbage collected, just load the state
# we were given without change.
if not any(p is None for p in params):
# ^ parameter references are still good
for i, p in enumerate(params):
self.shadow_params[i] = to_float_maybe(self.shadow_params[i].to(
device=p.device, dtype=p.dtype
))
if self.collected_params is not None:
self.collected_params[i] = self.collected_params[i].to(
device=p.device, dtype=p.dtype
)
else:
raise ValueError(
"Tried to `load_state_dict()` with the wrong number of "
"parameters in the saved state."
)
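# Usage sketch for ExponentialMovingAverage. Illustrative only: `model`, `optimizer`
# and `val_loader` are hypothetical arguments, and this helper is not part of the
# original file.
def _example_ema(model, optimizer, val_loader):
    ema = ExponentialMovingAverage(model.parameters(), decay=0.999)
    # after each optimizer step, fold the live weights into the shadow weights
    optimizer.step()
    ema.update()
    # evaluate with the averaged weights, then restore the live weights on exit
    with ema.average_parameters():
        with torch.no_grad():
            for batch in val_loader:
                model(batch)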
# Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/benchmark.py
import torch
try:
from deepspeed.profiling.flops_profiler import get_model_profile
has_deepspeed_profiling = True
except ImportError as e:
has_deepspeed_profiling = False
try:
from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count_table
from fvcore.nn import ActivationCountAnalysis
has_fvcore_profiling = True
except ImportError as e:
FlopCountAnalysis = None
ActivationCountAnalysis = None
has_fvcore_profiling = False
def profile_deepspeed(model, input_size=(3, 224, 224), input_dtype=torch.float32,
batch_size=1, detailed=False):
device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
flops, macs, params = get_model_profile(
model=model,
args=torch.zeros((batch_size,) + input_size, device=device, dtype=input_dtype),
print_profile=detailed, # prints the model graph with the measured profile attached to each module
detailed=detailed, # print the detailed profile
warm_up=10, # the number of warm-ups before measuring the time of each module
as_string=False, # print raw numbers (e.g. 1000) or as human-readable strings (e.g. 1k)
output_file=None, # path to the output file. If None, the profiler prints to stdout.
ignore_modules=None) # the list of modules to ignore in the profiling
return macs, 0 # no activation count in DS
def profile_fvcore(model, input_size=(3, 224, 224), input_dtype=torch.float32, max_depth=4,
batch_size=1, detailed=False, force_cpu=False):
if force_cpu:
model = model.to('cpu')
device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
example_input = torch.zeros((batch_size,) + input_size, device=device, dtype=input_dtype)
fca = FlopCountAnalysis(model, example_input)
aca = ActivationCountAnalysis(model, example_input)
if detailed:
print(flop_count_table(fca, max_depth=max_depth))
return fca, fca.total(), aca, aca.total()
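# Usage sketch for the profiling helpers above. Illustrative only: torchvision and
# resnet18 are assumptions used as a stand-in model; this helper is not part of the
# original file.
def _example_profile():
    import torchvision
    model = torchvision.models.resnet18().eval()
    if has_fvcore_profiling:
        _, total_flops, _, total_acts = profile_fvcore(
            model, input_size=(3, 224, 224), detailed=False, force_cpu=True)
        print(f'fvcore: {total_flops / 1e9:.2f} G ops, {total_acts / 1e6:.2f}M activations')
    elif has_deepspeed_profiling:
        macs, _ = profile_deepspeed(model, input_size=(3, 224, 224))
        print(f'deepspeed: {macs / 1e9:.2f} GMACs')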
import collections
import math
import os
import pathlib
import re
import pynvml
pynvml.nvmlInit()
def systemGetDriverVersion():
return pynvml.nvmlSystemGetDriverVersion()
def deviceGetCount():
return pynvml.nvmlDeviceGetCount()
class device:
# assume nvml returns list of 64 bit ints
_nvml_affinity_elements = math.ceil(os.cpu_count() / 64)
def __init__(self, device_idx):
super().__init__()
self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx)
def getName(self):
return pynvml.nvmlDeviceGetName(self.handle)
def getCpuAffinity(self):
affinity_string = ''
for j in pynvml.nvmlDeviceGetCpuAffinity(
self.handle, device._nvml_affinity_elements
):
# assume nvml returns list of 64 bit ints
affinity_string = '{:064b}'.format(j) + affinity_string
affinity_list = [int(x) for x in affinity_string]
affinity_list.reverse() # so core 0 is in 0th element of list
ret = [i for i, e in enumerate(affinity_list) if e != 0]
return ret
def set_socket_affinity(gpu_id):
dev = device(gpu_id)
affinity = dev.getCpuAffinity()
os.sched_setaffinity(0, affinity)
def set_single_affinity(gpu_id):
dev = device(gpu_id)
affinity = dev.getCpuAffinity()
os.sched_setaffinity(0, affinity[:1])
def set_single_unique_affinity(gpu_id, nproc_per_node):
devices = [device(i) for i in range(nproc_per_node)]
socket_affinities = [dev.getCpuAffinity() for dev in devices]
siblings_list = get_thread_siblings_list()
siblings_dict = dict(siblings_list)
# remove siblings
for idx, socket_affinity in enumerate(socket_affinities):
socket_affinities[idx] = list(set(socket_affinity) - set(siblings_dict.values()))
affinities = []
assigned = []
for socket_affinity in socket_affinities:
for core in socket_affinity:
if core not in assigned:
affinities.append([core])
assigned.append(core)
break
os.sched_setaffinity(0, affinities[gpu_id])
def set_socket_unique_affinity(gpu_id, nproc_per_node, mode):
device_ids = [device(i) for i in range(nproc_per_node)]
socket_affinities = [dev.getCpuAffinity() for dev in device_ids]
siblings_list = get_thread_siblings_list()
siblings_dict = dict(siblings_list)
# remove siblings
for idx, socket_affinity in enumerate(socket_affinities):
socket_affinities[idx] = list(set(socket_affinity) - set(siblings_dict.values()))
socket_affinities_to_device_ids = collections.defaultdict(list)
for idx, socket_affinity in enumerate(socket_affinities):
socket_affinities_to_device_ids[tuple(socket_affinity)].append(idx)
for socket_affinity, device_ids in socket_affinities_to_device_ids.items():
devices_per_group = len(device_ids)
cores_per_device = len(socket_affinity) // devices_per_group
for group_id, device_id in enumerate(device_ids):
if device_id == gpu_id:
if mode == 'interleaved':
affinity = list(socket_affinity[group_id::devices_per_group])
elif mode == 'continuous':
affinity = list(socket_affinity[group_id*cores_per_device:(group_id+1)*cores_per_device])
else:
raise RuntimeError('Unknown set_socket_unique_affinity mode')
# reintroduce siblings
affinity += [siblings_dict[aff] for aff in affinity if aff in siblings_dict]
os.sched_setaffinity(0, affinity)
def get_thread_siblings_list():
path = '/sys/devices/system/cpu/cpu*/topology/thread_siblings_list'
thread_siblings_list = []
pattern = re.compile(r'(\d+)\D(\d+)')
for fname in pathlib.Path(path[0]).glob(path[1:]):
with open(fname) as f:
content = f.read().strip()
res = pattern.findall(content)
if res:
pair = tuple(map(int, res[0]))
thread_siblings_list.append(pair)
return thread_siblings_list
def set_affinity(gpu_id, nproc_per_node, mode='socket'):
if mode == 'socket':
set_socket_affinity(gpu_id)
elif mode == 'single':
set_single_affinity(gpu_id)
elif mode == 'single_unique':
set_single_unique_affinity(gpu_id, nproc_per_node)
elif mode == 'socket_unique_interleaved':
set_socket_unique_affinity(gpu_id, nproc_per_node, 'interleaved')
elif mode == 'socket_unique_continuous':
set_socket_unique_affinity(gpu_id, nproc_per_node, 'continuous')
else:
raise RuntimeError('Unknown affinity mode')
affinity = os.sched_getaffinity(0)
return affinity
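# Usage sketch for set_affinity. Illustrative only: in a typical multi-GPU launch this
# is called once per process, early in main(), with that process's local GPU index;
# this helper is not part of the original file.
def _example_set_affinity(local_rank: int, gpus_per_node: int):
    affinity = set_affinity(local_rank, gpus_per_node, mode='socket_unique_continuous')
    print(f'process for GPU {local_rank} bound to CPU cores: {sorted(affinity)}')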
import logging
import warnings
from typing import List, Sequence
import pytorch_lightning as pl
import rich.syntax
import rich.tree
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.utilities import rank_zero_only
# Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging
class LoggingContext:
def __init__(self, logger, level=None, handler=None, close=True):
self.logger = logger
self.level = level
self.handler = handler
self.close = close
def __enter__(self):
if self.level is not None:
self.old_level = self.logger.level
self.logger.setLevel(self.level)
if self.handler:
self.logger.addHandler(self.handler)
def __exit__(self, et, ev, tb):
if self.level is not None:
self.logger.setLevel(self.old_level)
if self.handler:
self.logger.removeHandler(self.handler)
if self.handler and self.close:
self.handler.close()
# implicit return of None => don't swallow exceptions
def get_logger(name=__name__) -> logging.Logger:
"""Initializes multi-GPU-friendly python logger."""
logger = logging.getLogger(name)
# this ensures all logging levels get marked with the rank zero decorator
# otherwise logs would get multiplied for each GPU process in multi-GPU setup
for level in ("debug", "info", "warning", "error", "exception", "fatal", "critical"):
setattr(logger, level, rank_zero_only(getattr(logger, level)))
return logger
def extras(config: DictConfig) -> None:
"""A couple of optional utilities, controlled by main config file:
- disabling warnings
- forcing debug friendly configuration
- verifying experiment name is set when running in experiment mode
Modifies DictConfig in place.
Args:
config (DictConfig): Configuration composed by Hydra.
"""
log = get_logger(__name__)
# disable python warnings if <config.ignore_warnings=True>
if config.get("ignore_warnings"):
log.info("Disabling python warnings! <config.ignore_warnings=True>")
warnings.filterwarnings("ignore")
# verify experiment name is set when running in experiment mode
if config.get("experiment_mode") and not config.get("name"):
log.info(
"Running in experiment mode without the experiment name specified! "
"Use `python run.py mode=exp name=experiment_name`"
)
log.info("Exiting...")
exit()
# force debugger friendly configuration if <config.trainer.fast_dev_run=True>
# debuggers don't like GPUs and multiprocessing
if config.trainer.get("fast_dev_run"):
log.info("Forcing debugger friendly configuration! <config.trainer.fast_dev_run=True>")
if config.trainer.get("gpus"):
config.trainer.gpus = 0
if config.datamodule.get("pin_memory"):
config.datamodule.pin_memory = False
if config.datamodule.get("num_workers"):
config.datamodule.num_workers = 0
@rank_zero_only
def print_config(
config: DictConfig,
fields: Sequence[str] = (
"trainer",
"model",
"datamodule",
"train",
"eval",
"callbacks",
"logger",
"seed",
"name",
),
resolve: bool = True,
) -> None:
"""Prints content of DictConfig using Rich library and its tree structure.
Args:
config (DictConfig): Configuration composed by Hydra.
fields (Sequence[str], optional): Determines which main fields from config will
be printed and in what order.
resolve (bool, optional): Whether to resolve reference fields of DictConfig.
"""
style = "dim"
tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
for field in fields:
branch = tree.add(field, style=style, guide_style=style)
config_section = config.get(field)
branch_content = str(config_section)
if isinstance(config_section, DictConfig):
branch_content = OmegaConf.to_yaml(config_section, resolve=resolve)
branch.add(rich.syntax.Syntax(branch_content, "yaml"))
rich.print(tree)
with open("config_tree.txt", "w") as fp:
rich.print(tree, file=fp)
def finish(
config: DictConfig,
model: pl.LightningModule,
datamodule: pl.LightningDataModule,
trainer: pl.Trainer,
callbacks: List[pl.Callback],
logger: List[pl.loggers.LightningLoggerBase],
) -> None:
"""Makes sure everything closed properly."""
    # without this, sweeps with the wandb logger might crash!
for lg in logger:
if isinstance(lg, pl.loggers.wandb.WandbLogger):
import wandb
wandb.finish()
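# Usage sketch for the Hydra/Lightning utilities above. Illustrative only: the
# objects passed in would normally be created by the training entry point; this
# helper is not part of the original file.
def _example_run(config: DictConfig, model, datamodule, trainer, callbacks, logger):
    log = get_logger(__name__)
    extras(config)                      # apply the optional config-driven tweaks
    print_config(config, resolve=True)  # pretty-print the composed config (rank 0 only)
    log.info('training would run here')
    finish(config, model, datamodule, trainer, callbacks, logger)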
import os
from pathlib import Path
current_dir = Path(__file__).parent.absolute()
import pytest
import torch
import dotenv
from src.datamodules.language_modeling_hf import LMDataModule
# load environment variables from `.env` file if it exists
# recursively searches for `.env` in all folders starting from work dir
dotenv.load_dotenv(override=True)
def div_up(x: int, y: int) -> int:
return (x + y - 1) // y
# https://stackoverflow.com/questions/1006289/how-to-find-out-the-number-of-cpus-using-python/55423170#55423170
def num_cpu_cores():
try:
import psutil
return psutil.cpu_count(logical=False)
except ImportError:
return len(os.sched_getaffinity(0))
class TestLMDataModule:
def test_wikitext2(self):
batch_size = 7
dataset_name = 'wikitext'
dataset_config_name = 'wikitext-2-raw-v1'
data_dir = Path(os.getenv('DATA_DIR', current_dir.parent.parent / 'data'))
cache_dir = data_dir / 'wikitext-2' / 'cache'
max_length = 1024
datamodule = LMDataModule(dataset_name, tokenizer_name='gpt2',
dataset_config_name=dataset_config_name,
max_length=max_length, cache_dir=cache_dir,
add_eos=False, batch_size=batch_size, num_workers=4)
datamodule.prepare_data()
datamodule.setup(stage='fit')
train_loader = datamodule.train_dataloader()
val_loader = datamodule.val_dataloader()
datamodule.setup(stage='test')
test_loader = datamodule.test_dataloader()
train_len = 2391884
val_len = 247289
test_len = 283287
assert len(train_loader) == div_up((train_len - 1) // max_length, batch_size)
assert len(val_loader) == div_up((val_len - 1) // max_length, batch_size)
assert len(test_loader) == div_up((test_len - 1) // max_length, batch_size)
for loader in [train_loader, val_loader, test_loader]:
x, y = next(iter(loader))
assert x.dim() == 2
assert x.shape == (batch_size, max_length)
assert x.dtype == torch.long
assert torch.allclose(x[:, 1:], y[:, :-1])
def test_wikitext103(self):
batch_size = 7
dataset_name = 'wikitext'
dataset_config_name = 'wikitext-103-raw-v1'
data_dir = Path(os.getenv('DATA_DIR', current_dir.parent.parent / 'data'))
cache_dir = data_dir / 'wikitext-103' / 'cache'
max_length = 1024
datamodule = LMDataModule(dataset_name, tokenizer_name='gpt2',
dataset_config_name=dataset_config_name,
max_length=max_length, cache_dir=cache_dir,
add_eos=False, batch_size=batch_size, num_workers=4)
datamodule.prepare_data()
datamodule.setup(stage='fit')
train_loader = datamodule.train_dataloader()
val_loader = datamodule.val_dataloader()
datamodule.setup(stage='test')
test_loader = datamodule.test_dataloader()
train_len = 117920140
val_len = 247289
test_len = 283287
assert len(train_loader) == div_up((train_len - 1) // max_length, batch_size)
assert len(val_loader) == div_up((val_len - 1) // max_length, batch_size)
assert len(test_loader) == div_up((test_len - 1) // max_length, batch_size)
for loader in [train_loader, val_loader, test_loader]:
x, y = next(iter(loader))
assert x.dim() == 2
assert x.shape == (batch_size, max_length)
assert x.dtype == torch.long
assert torch.allclose(x[:, 1:], y[:, :-1])
def test_openwebtext(self):
batch_size = 8
dataset_name = 'openwebtext'
dataset_config_name = None
data_dir = Path(os.getenv('DATA_DIR', current_dir.parent.parent / 'data'))
cache_dir = data_dir / 'openwebtext' / 'cache'
max_length = 1024
datamodule = LMDataModule(dataset_name, tokenizer_name='gpt2',
dataset_config_name=dataset_config_name,
max_length=max_length, cache_dir=cache_dir,
add_eos=True, batch_size=batch_size,
num_workers=num_cpu_cores() // 2)
datamodule.prepare_data()
datamodule.setup(stage='fit')
train_loader = datamodule.train_dataloader()
val_loader = datamodule.val_dataloader()
datamodule.setup(stage='test')
test_loader = datamodule.test_dataloader()
train_len = 9035582198
val_len = 4434897
test_len = 4434897
assert len(train_loader) == div_up((train_len - 1) // max_length, batch_size)
assert len(val_loader) == div_up((val_len - 1) // max_length, batch_size)
assert len(test_loader) == div_up((test_len - 1) // max_length, batch_size)
for loader in [train_loader, val_loader, test_loader]:
x, y = next(iter(loader))
assert x.dim() == 2
assert x.shape == (batch_size, max_length)
assert x.dtype == torch.long
assert torch.allclose(x[:, 1:], y[:, :-1])
def test_lambada(self):
batch_size = 8
dataset_name = 'lambada'
dataset_config_name = None
data_dir = Path(os.getenv('DATA_DIR', current_dir.parent.parent / 'data'))
cache_dir = data_dir / 'lambada' / 'cache'
max_length = 1024
datamodule = LMDataModule(dataset_name, tokenizer_name='gpt2',
dataset_config_name=dataset_config_name,
max_length=max_length, cache_dir=cache_dir,
add_eos=True, batch_size=batch_size,
num_workers=64)
datamodule.prepare_data()
datamodule.setup(stage='fit')
train_loader = datamodule.train_dataloader()
val_loader = datamodule.val_dataloader()
datamodule.setup(stage='test')
test_loader = datamodule.test_dataloader()
train_len = 9035582198
val_len = 4434897
test_len = 4434897
assert len(train_loader) == div_up((train_len - 1) // max_length, batch_size)
assert len(val_loader) == div_up((val_len - 1) // max_length, batch_size)
assert len(test_loader) == div_up((test_len - 1) // max_length, batch_size)
for loader in [train_loader, val_loader, test_loader]:
x, y = next(iter(loader))
assert x.dim() == 2
assert x.shape == (batch_size, max_length)
assert x.dtype == torch.long
assert torch.allclose(x[:, 1:], y[:, :-1])
def test_the_pile(self):
batch_size = 8
dataset_name = 'the_pile'
dataset_config_name = None
data_dir = Path(os.getenv('DATA_DIR', current_dir.parent.parent / 'data'))
cache_dir = data_dir / 'the_pile' / 'cache'
max_length = 2048
# Dataset is too large to fit into memory, need to use disk for concatenation
datamodule = LMDataModule(dataset_name, tokenizer_name='gpt2',
dataset_config_name=dataset_config_name,
max_length=max_length, cache_dir=cache_dir,
add_eos=True, batch_size=batch_size,
num_workers=num_cpu_cores() // 2, use_shmem=False)
datamodule.prepare_data()
datamodule.setup(stage='fit')
train_loader = datamodule.train_dataloader()
val_loader = datamodule.val_dataloader()
datamodule.setup(stage='test')
test_loader = datamodule.test_dataloader()
train_len = 374337375694
val_len = 383326395
test_len = 373297018
assert len(train_loader) == div_up((train_len - 1) // max_length, batch_size)
assert len(val_loader) == div_up((val_len - 1) // max_length, batch_size)
assert len(test_loader) == div_up((test_len - 1) // max_length, batch_size)
for loader in [train_loader, val_loader, test_loader]:
x, y = next(iter(loader))
assert x.dim() == 2
assert x.shape == (batch_size, max_length)
assert x.dtype == torch.long
assert torch.allclose(x[:, 1:], y[:, :-1])
def test_pg19(self):
batch_size = 8
dataset_name = 'pg19'
dataset_config_name = None
data_dir = Path(os.getenv('DATA_DIR', current_dir.parent.parent / 'data'))
cache_dir = data_dir / 'pg19' / 'cache'
max_length = 2048
# Dataset is too large to fit into memory, need to use disk for concatenation
datamodule = LMDataModule(dataset_name, tokenizer_name='gpt2',
dataset_config_name=dataset_config_name,
max_length=max_length, cache_dir=cache_dir,
add_eos=True, batch_size=batch_size,
num_workers=num_cpu_cores() // 2)
datamodule.prepare_data()
datamodule.setup(stage='fit')
train_loader = datamodule.train_dataloader()
val_loader = datamodule.val_dataloader()
datamodule.setup(stage='test')
test_loader = datamodule.test_dataloader()
train_len = 3066544128
val_len = 4653056
test_len = 10584064
assert len(train_loader) == div_up((train_len - 1) // max_length, batch_size)
assert len(val_loader) == div_up((val_len - 1) // max_length, batch_size)
assert len(test_loader) == div_up((test_len - 1) // max_length, batch_size)
for loader in [train_loader, val_loader, test_loader]:
x, y = next(iter(loader))
assert x.dim() == 2
assert x.shape == (batch_size, max_length)
assert x.dtype == torch.long
assert torch.allclose(x[:, 1:], y[:, :-1])
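# The tests above can be run individually with pytest, assuming the datasets have been
# prepared and DATA_DIR points at them; the file path below is illustrative:
#   pytest -q -k test_wikitext2 tests/datamodules/test_language_modeling_hf.py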