"classification/ops_dcnv3/modules/dcnv3.py" did not exist on "c4552f794aab15e56a00ccb06747e3fa6b8bec38"
Commit 8a915496 authored by Guolin Ke's avatar Guolin Ke
Browse files

first commit

parent 5cf7df97
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn.functional as F
from unicore import metrics
from unicore.losses import UnicoreLoss, register_loss
@register_loss("masked_lm")
class MaskedLMLoss(UnicoreLoss):
    """Masked language modeling loss: summed NLL over masked positions only."""

    def __init__(self, task):
        super().__init__(task)
        self.padding_idx = task.dictionary.pad()

    def forward(self, model, sample, reduce=True):
        """Compute the masked-LM loss for *sample*.

        Returns a ``(loss, sample_size, logging_output)`` tuple where
        ``sample_size`` is the number of masked tokens.
        """
        targets = sample["target"]
        masked_tokens = targets.ne(self.padding_idx)
        sample_size = masked_tokens.int().sum()
        # If no position is masked in this batch, keep a single dummy True so
        # downstream indexing never yields an empty tensor.
        masked_tokens = torch.where(
            masked_tokens.any(), masked_tokens, masked_tokens.new([True])
        )
        logits = model(**sample["net_input"], masked_tokens=masked_tokens)
        targets = targets[masked_tokens]
        # Softmax in float32 for numerical stability, sum-reduced NLL.
        loss = F.nll_loss(
            F.log_softmax(logits, dim=-1, dtype=torch.float32),
            targets,
            ignore_index=self.padding_idx,
            reduction='sum',
        )
        batch_size = sample["target"].size(0)
        logging_output = {
            "loss": loss.data,
            "bsz": batch_size,
            "sample_size": sample_size,
            "seq_len": sample["target"].size(1) * batch_size,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs, split='valid') -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(out.get("loss", 0) for out in logging_outputs)
        bsz = sum(out.get("bsz", 0) for out in logging_outputs)
        sample_size = sum(out.get("sample_size", 0) for out in logging_outputs)
        seq_len = sum(out.get("seq_len", 0) for out in logging_outputs)
        # we divide by log(2) to convert the loss from base e to base 2
        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        metrics.log_scalar("seq_len", seq_len / bsz, 1, round=3)

    @staticmethod
    def logging_outputs_can_be_summed(is_train) -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improves distributed training speed.
        """
        return True
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import inspect
from typing import Any, Dict, List
from unicore import metrics, utils
from torch.nn.modules.loss import _Loss
class UnicoreLoss(_Loss):
    """Base class for unicore losses.

    Subclasses implement :meth:`forward`; :meth:`build_loss` constructs an
    instance by matching the subclass ``__init__`` parameters against the
    command-line args.
    """

    def __init__(self, task):
        super().__init__()
        self.task = task
        if task is not None:
            self.args = task.args
            if hasattr(task, "target_dictionary"):
                tgt_dict = task.target_dictionary
                # -100 matches the conventional "ignore" index when the task
                # provides no target dictionary.
                self.padding_idx = tgt_dict.pad() if tgt_dict is not None else -100

    @classmethod
    def add_args(cls, parser):
        """Add loss-specific arguments to the parser (no-op by default)."""
        pass

    @classmethod
    def build_loss(cls, args, task):
        """Construct a loss from command-line args."""
        # Infer the values for the parameters declared in cls.__init__.
        kwargs = {}
        for param in inspect.signature(cls).parameters.values():
            unsupported = (
                param.kind == param.POSITIONAL_ONLY
                or param.kind == param.VAR_POSITIONAL
                or param.kind == param.VAR_KEYWORD
            )
            if unsupported:
                # we haven't implemented inference for these argument types,
                # but PRs welcome :)
                raise NotImplementedError("{} not supported".format(param.kind))
            assert param.kind in {param.POSITIONAL_OR_KEYWORD, param.KEYWORD_ONLY}
            if param.name == "task":
                kwargs["task"] = task
            elif param.name == "args":
                kwargs["args"] = args
            elif hasattr(args, param.name):
                kwargs[param.name] = getattr(args, param.name)
            elif param.default != param.empty:
                # fall back to the declared default value
                pass
            else:
                raise NotImplementedError(
                    "Unable to infer Loss arguments, please implement "
                    "{}.build_loss".format(cls.__name__)
                )
        return cls(**kwargs)

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.
        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        raise NotImplementedError

    @staticmethod
    def logging_outputs_can_be_summed(is_train: bool) -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improves distributed training speed.
        """
        return False
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""
import argparse
import importlib
import os
from .distributed_unicore_model import DistributedUnicoreModel
from .unicore_model import (
BaseUnicoreModel,
)
# Model registries, populated by the register_model / register_model_architecture
# decorators defined below.
MODEL_REGISTRY = {}  # model name -> model class
ARCH_MODEL_REGISTRY = {}  # architecture name -> model class
# NOTE(review): ARCH_MODEL_NAME_REGISTRY is declared but never written in this
# module — presumably reserved for future use; confirm before relying on it.
ARCH_MODEL_NAME_REGISTRY = {}
ARCH_MODEL_INV_REGISTRY = {}  # model name -> list of its architecture names
ARCH_CONFIG_REGISTRY = {}  # architecture name -> args-mutating function
__all__ = [
    "BaseUnicoreModel",
    "DistributedUnicoreModel",
]
def build_model(args, task):
    """Instantiate the model class registered for ``args.arch``."""
    model_cls = ARCH_MODEL_REGISTRY[args.arch]
    return model_cls.build_model(args, task)
def register_model(name):
    """
    New model types can be added to unicore with the :func:`register_model`
    function decorator.
    For example::
        @register_model("lstm")
        class LSTM(UnicoreEncoderDecoderModel):
            (...)
    .. note:: All models must implement the :class:`BaseUnicoreModel` interface.
    Args:
        name (str): the name of the model
    """

    def _register(cls):
        # Reject duplicate names and classes outside the model hierarchy.
        if name in MODEL_REGISTRY:
            raise ValueError("Cannot register duplicate model ({})".format(name))
        if not issubclass(cls, BaseUnicoreModel):
            raise ValueError(
                "Model ({}: {}) must extend BaseUnicoreModel".format(name, cls.__name__)
            )
        MODEL_REGISTRY[name] = cls
        return cls

    return _register
def register_model_architecture(model_name, arch_name):
    """
    New model architectures can be added to unicore with the
    :func:`register_model_architecture` function decorator. After registration,
    model architectures can be selected with the ``--arch`` command-line
    argument.
    For example::
        @register_model_architecture("lstm", "lstm_luong_wmt_en_de")
        def lstm_luong_wmt_en_de(args):
            args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1000)
            (...)
    The decorated function should take a single argument *args*
    (:class:`argparse.Namespace`) and modify it in-place to match the
    desired architecture.
    Args:
        model_name (str): the name of the Model (must already be registered)
        arch_name (str): the name of the model architecture (``--arch``)
    """

    def _register(fn):
        # The parent model must exist and the arch name must be fresh.
        if model_name not in MODEL_REGISTRY:
            raise ValueError("Cannot register model architecture for unknown model type ({})".format(model_name))
        if arch_name in ARCH_MODEL_REGISTRY:
            raise ValueError("Cannot register duplicate model architecture ({})".format(arch_name))
        if not callable(fn):
            raise ValueError("Model architecture must be callable ({})".format(arch_name))
        ARCH_MODEL_REGISTRY[arch_name] = MODEL_REGISTRY[model_name]
        ARCH_MODEL_INV_REGISTRY.setdefault(model_name, []).append(arch_name)
        ARCH_CONFIG_REGISTRY[arch_name] = fn
        return fn

    return _register
# automatically import any Python files in the models/ directory
models_dir = os.path.dirname(__file__)
for file in os.listdir(models_dir):
    path = os.path.join(models_dir, file)
    # Import every non-hidden, non-private .py module or package directory so
    # their @register_model decorators run and fill the registries.
    if not file.startswith("_") and not file.startswith(".") and (file.endswith(".py") or os.path.isdir(path)):
        # NOTE(review): uses the FIRST ".py" occurrence to strip the suffix; a
        # filename containing ".py" twice would truncate early — presumably
        # module names are simple; confirm if unusual names appear.
        model_name = file[:file.find(".py")] if file.endswith(".py") else file
        module = importlib.import_module("unicore.models." + model_name)
        # extra `model_parser` for sphinx
        if model_name in MODEL_REGISTRY:
            parser = argparse.ArgumentParser(add_help=False)
            group_archs = parser.add_argument_group("Named architectures")
            group_archs.add_argument("--arch", choices=ARCH_MODEL_INV_REGISTRY[model_name])
            group_args = parser.add_argument_group("Additional command-line arguments")
            MODEL_REGISTRY[model_name].add_args(group_args)
            # Expose e.g. `transformer_parser` at module level for the docs build.
            globals()[model_name + "_parser"] = parser
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
from unicore.distributed import (
ModuleProxyWrapper, LegacyDistributedDataParallel
)
logger = logging.getLogger(__name__)
def DistributedUnicoreModel(args, model, process_group, device):
    """
    Wrap a *model* to support distributed data parallel training.
    This is similar to the built-in DistributedDataParallel, but allows
    additional configuration of the DistributedDataParallel class to
    use, and also provides easier access to the wrapped model by
    forwarding requests for missing attributes to the wrapped model.
    Args:
        args (argparse.Namespace): unicore args
        model (BaseUnicoreModel): model to wrap
        process_group: the c10d process group to be used for distributed data
            parallel all-reduction.
        device: device to move model to
    """
    assert isinstance(model, nn.Module)
    backend = args.ddp_backend
    if backend in ("c10d", "pytorch_ddp"):
        ddp = DistributedDataParallel(
            module=model.to(device),
            device_ids=[args.device_id],
            output_device=args.device_id,
            broadcast_buffers=args.broadcast_buffers,
            bucket_cap_mb=args.bucket_cap_mb,
            process_group=process_group,
            find_unused_parameters=args.find_unused_parameters,
        )
    elif backend == "apex":
        # Imported lazily: apex is an optional dependency.
        import apex
        ddp = apex.parallel.DistributedDataParallel(
            module=model.to(device)
        )
    elif backend in ("no_c10d", "legacy_ddp"):
        ddp = LegacyDistributedDataParallel(
            module=model.to(device),
            buffer_size=2 ** 28,
            process_group=process_group,
        )
    else:
        raise ValueError("Unknown --ddp-backend: " + args.ddp_backend)
    # forward missing getattr and state_dict/load_state_dict to orig model
    return ModuleProxyWrapper(ddp)
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Base classes for various unicore models.
"""
import logging
import torch
import torch.nn as nn
logger = logging.getLogger(__name__)
class BaseUnicoreModel(nn.Module):
    """Base class for unicore models."""

    def __init__(self):
        super().__init__()

    @classmethod
    def add_args(cls, parser):
        """Add model-specific arguments to the parser."""
        pass

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        raise NotImplementedError("Model must implement the build_model method")

    def extract_features(self, *args, **kwargs):
        """Similar to *forward* but only return features."""
        return self(*args, **kwargs)

    def load_state_dict(self, state_dict, strict=True, model_args=None):
        """Copies parameters and buffers from *state_dict* into this module and
        its descendants.
        Overrides the method in :class:`nn.Module`; ``model_args`` is accepted
        for interface compatibility but unused here.
        """
        return super().load_state_dict(state_dict, strict)

    def set_num_updates(self, num_updates):
        """State from trainer to pass along to model at every update."""

        def _visit(module):
            # Skip *self* to avoid infinite recursion; forward the update
            # count to any submodule implementing the hook.
            if hasattr(module, "set_num_updates") and module != self:
                module.set_num_updates(num_updates)

        self.apply(_visit)
"""isort:skip_file"""
from .layer_norm import LayerNorm
from .softmax_dropout import softmax_dropout
from .multihead_attention import SelfMultiheadAttention, CrossMultiheadAttention
from .transformer_encoder_layer import TransformerEncoderLayer
from .transformer_encoder import TransformerEncoder, init_bert_params, relative_position_bucket
from .transformer_decoder_layer import TransformerDecoderLayer
from .transformer_decoder import TransformerDecoder
# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import numbers
from torch.nn.parameter import Parameter
from torch.nn import init
from torch.nn import functional as F
import unicore_fused_layernorm
import unicore_fused_layernorm_backward_gamma_beta
class FusedLayerNormFastFunction(torch.autograd.Function):
    """Autograd wrapper around the fused CUDA layer-norm extension.

    Forward and backward delegate to the compiled ``unicore_fused_layernorm``
    kernels; this class only marshals tensors and saved state between passes.
    """

    @staticmethod
    def forward(ctx, input, weight, bias, normalized_shape, eps):
        # Non-tensor config is stashed directly on ctx; tensors go through
        # save_for_backward.
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        # The fused kernels require contiguous memory.
        input_ = input.contiguous()
        weight_ = weight.contiguous()
        bias_ = bias.contiguous()
        output, mean, invvar = unicore_fused_layernorm.forward(
            input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
        # mean/invvar are the per-row statistics the backward kernels reuse.
        ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, weight_, bias_, mean, invvar = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        # Input gradient and gamma/beta gradients come from two separate kernels.
        grad_input = unicore_fused_layernorm.backward(
            grad_output.contiguous(), mean, invvar,
            input_, ctx.normalized_shape,
            weight_, bias_, ctx.eps)
        grad_weight, grad_bias = unicore_fused_layernorm_backward_gamma_beta.backward(
            grad_output.contiguous(), mean, invvar,
            input_, ctx.normalized_shape,
            weight_, bias_, ctx.eps)
        # One gradient per forward() argument; None for the non-tensor args.
        return grad_input, grad_weight, grad_bias, None, None
class LayerNorm(torch.nn.Module):
    """Layer normalization with a fused CUDA fast path.

    CPU tensors fall back to ``F.layer_norm``; CUDA tensors go through
    :class:`FusedLayerNormFastFunction`. Only the elementwise-affine
    variant (learned weight and bias) is supported.
    """

    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super(LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = torch.Size(normalized_shape)
        self.eps = eps
        # The fused kernel always applies weight and bias.
        assert elementwise_affine
        # torch.empty replaces the legacy torch.Tensor(*shape) constructor;
        # values are immediately overwritten by reset_parameters().
        self.weight = Parameter(torch.empty(*normalized_shape))
        self.bias = Parameter(torch.empty(*normalized_shape))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize to the identity transform (weight=1, bias=0)."""
        init.ones_(self.weight)
        init.zeros_(self.bias)

    def forward(self, input):
        if not input.is_cuda:
            # Portable fallback for CPU tensors.
            return F.layer_norm(
                input, self.normalized_shape, self.weight, self.bias, self.eps)
        return FusedLayerNormFastFunction.apply(
            input, self.weight, self.bias, self.normalized_shape, self.eps)

    def extra_repr(self):
        return '{normalized_shape}, eps={eps}, ' \
            'elementwise_affine=True'.format(**self.__dict__)
# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, Optional
import torch
from torch import Tensor, nn
from .softmax_dropout import softmax_dropout
class SelfMultiheadAttention(nn.Module):
    """Multi-head self-attention using a single packed QKV projection.

    Returns the attended output, or ``(output, weights, probs)`` when
    ``return_attn`` is set.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.1,
        bias=True,
        scaling_factor=1,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        self.scaling = (self.head_dim * scaling_factor) ** -0.5
        # Packed projection producing Q, K and V in one matmul.
        self.in_proj = nn.Linear(embed_dim, embed_dim * 3, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def forward(
        self,
        query,
        key_padding_mask: Optional[Tensor] = None,
        attn_bias: Optional[Tensor] = None,
        return_attn: bool = False,
    ) -> Tensor:
        bsz, tgt_len, embed_dim = query.size()
        assert embed_dim == self.embed_dim

        def _split_heads(t):
            # (bsz, len, embed) -> (bsz * heads, len, head_dim)
            return (
                t.view(bsz, -1, self.num_heads, self.head_dim)
                .transpose(1, 2)
                .contiguous()
                .view(bsz * self.num_heads, -1, self.head_dim)
            )

        q, k, v = self.in_proj(query).chunk(3, dim=-1)
        q = _split_heads(q) * self.scaling
        k = _split_heads(k)
        v = _split_heads(v)
        src_len = k.size(1)

        # A zero-dim key_padding_mask is the fork/join-parallelism placeholder
        # for "no mask".
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None
        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights.masked_fill_(
                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                float("-inf")
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
        if attn_bias is not None:
            attn_weights += attn_bias

        attn_probs = softmax_dropout(attn_weights, self.dropout, self.training)
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        attn = (
            attn.view(bsz, self.num_heads, tgt_len, self.head_dim)
            .transpose(1, 2)
            .contiguous()
            .view(bsz, tgt_len, embed_dim)
        )
        attn = self.out_proj(attn)
        if return_attn:
            return attn, attn_weights, attn_probs
        return attn
class CrossMultiheadAttention(nn.Module):
    """Multi-head attention with separate query/key/value inputs
    (e.g. a decoder attending over encoder output)."""

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.1,
        bias=True,
        scaling_factor=1,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        self.scaling = (self.head_dim * scaling_factor) ** -0.5
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def forward(
        self,
        query,
        key,
        value,
        key_padding_mask: Optional[Tensor] = None,
        attn_bias: Optional[Tensor] = None,
    ) -> Tensor:
        bsz, tgt_len, embed_dim = query.size()
        assert embed_dim == self.embed_dim

        def _split_heads(t):
            # (bsz, len, embed) -> (bsz * heads, len, head_dim)
            return (
                t.view(bsz, -1, self.num_heads, self.head_dim)
                .transpose(1, 2)
                .contiguous()
                .view(bsz * self.num_heads, -1, self.head_dim)
            )

        q = _split_heads(self.q_proj(query)) * self.scaling
        k = _split_heads(self.k_proj(key))
        v = _split_heads(self.v_proj(value))
        src_len = k.size(1)

        # A zero-dim key_padding_mask is the fork/join-parallelism placeholder
        # for "no mask".
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None
        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights.masked_fill_(
                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                float("-inf")
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
        if attn_bias is not None:
            attn_weights += attn_bias

        attn_probs = softmax_dropout(attn_weights, self.dropout, self.training)
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        attn = (
            attn.view(bsz, self.num_heads, tgt_len, self.head_dim)
            .transpose(1, 2)
            .contiguous()
            .view(bsz, tgt_len, embed_dim)
        )
        return self.out_proj(attn)
\ No newline at end of file
# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import unicore_fused_softmax_dropout
import torch.nn.functional as F
class SoftmaxDropoutFast(torch.autograd.Function):
    """Autograd wrapper around the fused CUDA softmax+dropout kernel.

    Takes ``(is_training, inputs, dropout_prob)`` and returns the
    dropout-of-softmax result computed by ``unicore_fused_softmax_dropout``.
    """

    @staticmethod
    def forward(ctx, is_training, inputs, dropout_prob):
        # don't use ctx.save_for_backward to save dropout_prob
        # allocating space for a tensor is time-consuming
        dropout_results, dropout_mask, softmax_results = unicore_fused_softmax_dropout.forward(is_training,
            inputs, dropout_prob, None)
        if is_training:
            # Backward state is only needed in training; ctx.dropout_prob is
            # therefore set only on this path (backward only runs then).
            ctx.dropout_prob = dropout_prob
            ctx.save_for_backward(softmax_results, dropout_mask)
        return dropout_results

    @staticmethod
    def backward(ctx, grad_output):
        softmax_results, dropout_mask = ctx.saved_tensors
        dropout_prob = ctx.dropout_prob
        grad_output = grad_output.contiguous()
        grad_input = unicore_fused_softmax_dropout.backward(grad_output, softmax_results,
            dropout_mask, dropout_prob)
        # No gradients for the is_training / dropout_prob arguments.
        return None, grad_input, None
def softmax_dropout(input, dropout_prob, is_training=True):
    """Apply softmax over the last dim, then dropout.

    Uses the fused CUDA kernel when the input is on GPU with last dim
    <= 2048; otherwise falls back to ``F.softmax`` + ``F.dropout``.
    The input's original shape is restored on the output.
    """
    input = input.contiguous()
    original_size = input.size()
    # Flatten leading dims: the fused kernel expects a 3-D layout.
    input = input.view(-1, original_size[-2], original_size[-1])
    fast_path = input.is_cuda and input.shape[-1] <= 2048
    if fast_path:
        out = SoftmaxDropoutFast.apply(is_training, input, dropout_prob)
    else:
        out = F.dropout(F.softmax(input, dim=-1), p=dropout_prob, training=is_training)
    return out.view(*original_size)
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from . import TransformerDecoderLayer, LayerNorm
from .transformer_encoder import relative_position_bucket
def fill_with_neg_inf(t):
    """Fill *t* with -inf in place and return it."""
    neg_inf = float("-inf")
    return t.fill_(neg_inf)
def bulid_future_mask(seq_len):
    """Return a (seq_len, seq_len) causal mask: 0 on and below the diagonal,
    -inf strictly above it.

    (The name keeps the original "bulid" spelling for caller compatibility.)
    """
    mask = torch.zeros([seq_len, seq_len]).fill_(float("-inf"))
    return torch.triu(mask, 1)
class TransformerDecoder(nn.Module):
    """Stack of :class:`TransformerDecoderLayer`s with learned T5-style
    relative-position attention bias and an optional causal future mask
    for auto-regressive decoding.
    """

    def __init__(
        self,
        decoder_layers: int = 6,
        embed_dim: int = 768,
        ffn_embed_dim: int = 3072,
        attention_heads: int = 8,
        emb_dropout: float = 0.1,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.0,
        max_seq_len: int = 256,
        activation_fn: str = "gelu",
        rel_pos: bool = True,
        rel_pos_bins: int = 32,
        max_rel_pos: int = 128,
        post_ln: bool = False,
        auto_regressive: bool = True,
    ) -> None:
        super().__init__()
        self.emb_dropout = emb_dropout
        self.max_seq_len = max_seq_len
        self.embed_dim = embed_dim
        self.attention_heads = attention_heads
        self.emb_layer_norm = LayerNorm(self.embed_dim)
        self.auto_regressive = auto_regressive
        # Precompute the causal mask once for the maximum sequence length.
        if self.auto_regressive:
            self._future_mask = bulid_future_mask(self.max_seq_len)
        else:
            self._future_mask = None
        # Pre-LN keeps a final norm; post-LN normalizes inside each layer.
        if not post_ln:
            self.final_layer_norm = LayerNorm(self.embed_dim)
        else:
            self.final_layer_norm = None
        self.layers = nn.ModuleList(
            [
                TransformerDecoderLayer(
                    embed_dim=self.embed_dim,
                    ffn_embed_dim=ffn_embed_dim,
                    attention_heads=attention_heads,
                    dropout=dropout,
                    attention_dropout=attention_dropout,
                    activation_dropout=activation_dropout,
                    activation_fn=activation_fn,
                    post_ln=post_ln,
                )
                for _ in range(decoder_layers)
            ]
        )
        self.rel_pos = rel_pos
        if self.rel_pos:
            assert rel_pos_bins % 2 == 0
            self.rel_pos_bins = rel_pos_bins
            self.max_rel_pos = max_rel_pos
            self.relative_attention_bias = nn.Embedding(self.rel_pos_bins, self.attention_heads)
            seq_len = self.max_seq_len
            context_position = torch.arange(seq_len, dtype=torch.long)[:, None]
            memory_position = torch.arange(seq_len, dtype=torch.long)[None, :]
            relative_position = memory_position - context_position
            # NOTE(review): rp_bucket (and _future_mask) are plain attributes,
            # not registered buffers, so module .to(device)/state_dict will not
            # move or save them; devices are handled manually below.
            self.rp_bucket = relative_position_bucket(
                relative_position,
                num_buckets=self.rel_pos_bins,
                max_distance=self.max_rel_pos
            )
            # Shift buckets to be >= 0 so they are valid embedding indices.
            self.rp_bucket -= self.rp_bucket.min()

    def get_rel_pos_bias(self, x):
        """Return the (heads, seq_len, seq_len) relative-position bias for x."""
        # Assume the input is ordered. If your input token is permuted, you may need to update this accordingly
        if self.rp_bucket.device != x.device:
            self.rp_bucket = self.rp_bucket.to(x.device)
        seq_len = x.size(1)
        rp_bucket = self.rp_bucket[:seq_len, :seq_len]
        values = F.embedding(rp_bucket, self.relative_attention_bias.weight)
        values = values.permute([2, 0, 1])
        return values.contiguous()

    def get_future_mask(self, x, attn_mask):
        """Combine *attn_mask* with the causal mask (additive -inf above the
        diagonal). Returns attn_mask unchanged when not auto-regressive."""
        if not self.auto_regressive:
            return attn_mask
        # Lazily move/cast the cached mask to match the activations.
        if self._future_mask.device != x.device:
            self._future_mask = self._future_mask.to(x.device)
        if self._future_mask.dtype != x.dtype:
            self._future_mask = self._future_mask.type_as(x)
        if attn_mask is None:
            ret = self._future_mask[:x.size(1), :x.size(1)]
            ret = ret.contiguous().unsqueeze(0).repeat(x.size(0)*self.attention_heads, 1, 1)
            # BUG FIX: the computed mask was previously never returned, so this
            # branch fell through and yielded None.
            return ret
        else:
            assert list(attn_mask.size()) == [x.size(0) * self.attention_heads, x.size(1), x.size(1)]
            return attn_mask + self._future_mask[:x.size(1), :x.size(1)]

    def forward(
        self,
        emb,
        encoder_out: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
        encoder_padding_mask: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
        encoder_attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Decode *emb* (bsz, seq_len, embed_dim), optionally attending over
        *encoder_out*; masks are additive biases / boolean pad markers."""
        seq_len = emb.size(1)
        x = self.emb_layer_norm(emb)
        x = F.dropout(x, p=self.emb_dropout, training=self.training)
        # account for padding while computing the representation
        if padding_mask is not None:
            x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))
        rel_pos_bias = self.get_rel_pos_bias(x).repeat(x.size(0), 1, 1) if self.rel_pos else None
        if attn_mask is None:
            attn_mask = rel_pos_bias
        elif rel_pos_bias is not None:
            # NOTE(review): in-place add mutates the caller's attn_mask tensor;
            # presumably callers pass a fresh mask each call — confirm before
            # reusing masks across forward passes.
            attn_mask += rel_pos_bias
        if self.auto_regressive:
            attn_mask = self.get_future_mask(x, attn_mask)
        if attn_mask is not None and padding_mask is not None:
            # merge key_padding_mask and attn_mask
            attn_mask = attn_mask.view(x.size(0), -1, seq_len, seq_len)
            attn_mask.masked_fill_(
                padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                float("-inf")
            )
            attn_mask = attn_mask.view(-1, seq_len, seq_len)
            padding_mask = None
        for layer in self.layers:
            x = layer(x, encoder_out=encoder_out, padding_mask=padding_mask, attn_bias=attn_mask,
                      encoder_padding_mask=encoder_padding_mask, encoder_attn_bias=encoder_attn_mask)
        # Identity check (`is not None`) instead of `!= None` equality.
        if self.final_layer_norm is not None:
            x = self.final_layer_norm(x)
        return x
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, Optional
import torch
import torch.nn.functional as F
from unicore import utils
from torch import nn
from . import LayerNorm, SelfMultiheadAttention, CrossMultiheadAttention
class TransformerDecoderLayer(nn.Module):
    """Transformer decoder layer: self-attention, optional cross-attention
    over encoder output, then a feed-forward block — each with a residual
    connection and pre- or post-LayerNorm.
    """

    def __init__(
        self,
        embed_dim: int = 768,
        ffn_embed_dim: int = 3072,
        attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.0,
        activation_fn: str = "gelu",
        post_ln = False,
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.attention_heads = attention_heads
        self.attention_dropout = attention_dropout
        self.dropout = dropout
        self.activation_dropout = activation_dropout
        self.activation_fn = utils.get_activation_fn(activation_fn)

        # Self-attention sub-block and its norm.
        self.self_attn = SelfMultiheadAttention(
            self.embed_dim,
            attention_heads,
            dropout=attention_dropout,
        )
        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
        # Cross-attention sub-block and its norm.
        self.encoder_attn = CrossMultiheadAttention(
            self.embed_dim,
            attention_heads,
            dropout=attention_dropout,
        )
        self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)
        # Position-wise feed-forward sub-block.
        self.fc1 = nn.Linear(self.embed_dim, ffn_embed_dim)
        self.fc2 = nn.Linear(ffn_embed_dim, self.embed_dim)
        self.final_layer_norm = LayerNorm(self.embed_dim)
        self.post_ln = post_ln

    def forward(
        self,
        x: torch.Tensor,
        encoder_out: torch.Tensor = None,
        attn_bias: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
        encoder_attn_bias: Optional[torch.Tensor] = None,
        encoder_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        LayerNorm is applied either before (pre-LN) or after (post-LN) each
        sub-block, mirroring the original Transformer implementation.
        """
        # --- self-attention ---
        shortcut = x
        if not self.post_ln:
            x = self.self_attn_layer_norm(x)
        x = self.self_attn(
            query=x,
            key_padding_mask=padding_mask,
            attn_bias=attn_bias,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = shortcut + x
        if self.post_ln:
            x = self.self_attn_layer_norm(x)

        # --- cross-attention (skipped when no encoder output is given) ---
        if encoder_out is not None:
            shortcut = x
            if not self.post_ln:
                x = self.encoder_attn_layer_norm(x)
            x = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                attn_bias=encoder_attn_bias,
            )
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = shortcut + x
            if self.post_ln:
                x = self.encoder_attn_layer_norm(x)

        # --- feed-forward ---
        shortcut = x
        if not self.post_ln:
            x = self.final_layer_norm(x)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = shortcut + x
        if self.post_ln:
            x = self.final_layer_norm(x)
        return x
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from . import TransformerEncoderLayer, LayerNorm
def init_bert_params(module):
    """BERT-style init: N(0, 0.02) for Linear/Embedding weights, zeroed
    biases and padding-index embedding rows.

    Modules can opt out by setting ``can_global_init = False``.
    """
    if not getattr(module, 'can_global_init', True):
        return

    def _normal_init(tensor):
        # Sampled on CPU then copied back — presumably for consistent RNG
        # behavior across devices.
        tensor.copy_(tensor.cpu().normal_(mean=0.0, std=0.02).to(tensor.device))

    if isinstance(module, nn.Linear):
        _normal_init(module.weight.data)
        if module.bias is not None:
            module.bias.data.zero_()
    if isinstance(module, nn.Embedding):
        _normal_init(module.weight.data)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
def relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
    """Map signed relative positions to signed bucket indices (T5-style).

    Half the buckets cover each direction (sign is reapplied at the end);
    within a direction, the first half are exact offsets and the rest grow
    logarithmically up to *max_distance*.
    """
    sign = torch.sign(relative_position)
    num_buckets //= 2  # per-direction bucket count
    distance = torch.abs(relative_position)
    # Buckets [0, exact_threshold): one bucket per exact offset.
    exact_threshold = num_buckets // 2
    is_exact = distance < exact_threshold
    log_range = num_buckets - 1 - exact_threshold
    # Remaining buckets spread logarithmically over [exact_threshold, max_distance).
    scaled = torch.log(distance.float() / exact_threshold) / math.log(
        (max_distance - 1) / exact_threshold
    )
    log_bucket = exact_threshold + torch.ceil(scaled * log_range).long()
    log_bucket = torch.clamp(log_bucket, max=num_buckets - 1)
    return torch.where(is_exact, distance, log_bucket) * sign
class TransformerEncoder(nn.Module):
    """Stack of :class:`TransformerEncoderLayer`s with optional learned
    T5-style relative-position attention bias.
    """

    def __init__(
        self,
        encoder_layers: int = 6,
        embed_dim: int = 768,
        ffn_embed_dim: int = 3072,
        attention_heads: int = 8,
        emb_dropout: float = 0.1,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.0,
        max_seq_len: int = 256,
        activation_fn: str = "gelu",
        rel_pos: bool = True,
        rel_pos_bins: int = 32,
        max_rel_pos: int = 128,
        post_ln: bool = False,
    ) -> None:
        super().__init__()
        self.emb_dropout = emb_dropout
        self.max_seq_len = max_seq_len
        self.embed_dim = embed_dim
        self.attention_heads = attention_heads
        self.emb_layer_norm = LayerNorm(self.embed_dim)
        # Pre-LN keeps a final norm; post-LN normalizes inside each layer.
        if not post_ln:
            self.final_layer_norm = LayerNorm(self.embed_dim)
        else:
            self.final_layer_norm = None
        self.layers = nn.ModuleList(
            [
                TransformerEncoderLayer(
                    embed_dim=self.embed_dim,
                    ffn_embed_dim=ffn_embed_dim,
                    attention_heads=attention_heads,
                    dropout=dropout,
                    attention_dropout=attention_dropout,
                    activation_dropout=activation_dropout,
                    activation_fn=activation_fn,
                    post_ln=post_ln,
                )
                for _ in range(encoder_layers)
            ]
        )
        self.rel_pos = rel_pos
        if self.rel_pos:
            assert rel_pos_bins % 2 == 0
            self.rel_pos_bins = rel_pos_bins
            self.max_rel_pos = max_rel_pos
            self.relative_attention_bias = nn.Embedding(self.rel_pos_bins, self.attention_heads)
            seq_len = self.max_seq_len
            context_position = torch.arange(seq_len, dtype=torch.long)[:, None]
            memory_position = torch.arange(seq_len, dtype=torch.long)[None, :]
            relative_position = memory_position - context_position
            # NOTE(review): rp_bucket is a plain attribute, not a registered
            # buffer, so module .to(device)/state_dict will not move or save
            # it; the device is handled manually in get_rel_pos_bias.
            self.rp_bucket = relative_position_bucket(
                relative_position,
                num_buckets=self.rel_pos_bins,
                max_distance=self.max_rel_pos
            )
            # Shift buckets to be >= 0 so they are valid embedding indices.
            self.rp_bucket -= self.rp_bucket.min()

    def get_rel_pos_bias(self, x):
        """Return the (heads, seq_len, seq_len) relative-position bias for x."""
        # Assume the input is ordered. If your input token is permuted, you may need to update this accordingly
        if self.rp_bucket.device != x.device:
            self.rp_bucket = self.rp_bucket.to(x.device)
        seq_len = x.size(1)
        rp_bucket = self.rp_bucket[:seq_len, :seq_len]
        values = F.embedding(rp_bucket, self.relative_attention_bias.weight)
        values = values.permute([2, 0, 1])
        return values.contiguous()

    def forward(
        self,
        emb: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode *emb* (bsz, seq_len, embed_dim). *padding_mask* marks pad
        positions; *attn_mask* is an additive (bsz*heads, seq, seq) bias."""
        seq_len = emb.size(1)
        x = self.emb_layer_norm(emb)
        x = F.dropout(x, p=self.emb_dropout, training=self.training)
        # account for padding while computing the representation
        if padding_mask is not None:
            x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))
        rel_pos_bias = self.get_rel_pos_bias(x).repeat(x.size(0), 1, 1) if self.rel_pos else None
        if attn_mask is None:
            attn_mask = rel_pos_bias
        elif rel_pos_bias is not None:
            # NOTE(review): in-place add mutates the caller's attn_mask tensor;
            # presumably callers pass a fresh mask each call — confirm before
            # reusing masks across forward passes.
            attn_mask += rel_pos_bias
        if attn_mask is not None and padding_mask is not None:
            # merge key_padding_mask and attn_mask
            attn_mask = attn_mask.view(x.size(0), -1, seq_len, seq_len)
            attn_mask.masked_fill_(
                padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                float("-inf")
            )
            attn_mask = attn_mask.view(-1, seq_len, seq_len)
            padding_mask = None
        for layer in self.layers:
            x = layer(x, padding_mask=padding_mask, attn_bias=attn_mask)
        # Identity check (`is not None`) instead of `!= None` equality.
        if self.final_layer_norm is not None:
            x = self.final_layer_norm(x)
        return x
\ No newline at end of file
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, Optional
import torch
import torch.nn.functional as F
from unicore import utils
from torch import nn
from . import LayerNorm, SelfMultiheadAttention
class TransformerEncoderLayer(nn.Module):
    """
    A single BERT/XLM-style Transformer encoder layer: multi-head
    self-attention followed by a position-wise feed-forward network, each
    wrapped in a residual connection with LayerNorm applied either before
    (pre-LN, the default) or after (``post_ln=True``) the sub-layer.
    """
    def __init__(
        self,
        embed_dim: int = 768,
        ffn_embed_dim: int = 3072,
        attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.0,
        activation_fn: str = "gelu",
        post_ln = False,
    ) -> None:
        super().__init__()
        # Hyper-parameters kept on the instance for introspection.
        self.embed_dim = embed_dim
        self.attention_heads = attention_heads
        self.attention_dropout = attention_dropout
        self.dropout = dropout
        self.activation_dropout = activation_dropout
        self.activation_fn = utils.get_activation_fn(activation_fn)
        # Self-attention sub-layer and its LayerNorm.
        self.self_attn = SelfMultiheadAttention(
            self.embed_dim,
            attention_heads,
            dropout=attention_dropout,
        )
        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
        # Position-wise feed-forward sub-layer and its LayerNorm.
        self.fc1 = nn.Linear(self.embed_dim, ffn_embed_dim)
        self.fc2 = nn.Linear(ffn_embed_dim, self.embed_dim)
        self.final_layer_norm = LayerNorm(self.embed_dim)
        self.post_ln = post_ln
    def forward(
        self,
        x: torch.Tensor,
        attn_bias: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
        return_attn: bool=False,
    ) -> torch.Tensor:
        """
        Apply self-attention then the feed-forward network, with LayerNorm
        placed before or after each sub-layer depending on ``post_ln``.
        Returns ``(x, attn_weights, attn_probs)`` when ``return_attn`` is
        set, otherwise just ``x``.
        """
        attn_weights = attn_probs = None
        # --- self-attention sub-layer ---
        shortcut = x
        if not self.post_ln:
            x = self.self_attn_layer_norm(x)
        attn_out = self.self_attn(
            query=x,
            key_padding_mask=padding_mask,
            attn_bias=attn_bias,
            return_attn=return_attn,
        )
        if return_attn:
            attn_out, attn_weights, attn_probs = attn_out
        x = shortcut + F.dropout(attn_out, p=self.dropout, training=self.training)
        if self.post_ln:
            x = self.self_attn_layer_norm(x)
        # --- feed-forward sub-layer ---
        shortcut = x
        if not self.post_ln:
            x = self.final_layer_norm(x)
        hidden = self.activation_fn(self.fc1(x))
        hidden = F.dropout(hidden, p=self.activation_dropout, training=self.training)
        x = F.dropout(self.fc2(hidden), p=self.dropout, training=self.training)
        x = shortcut + x
        if self.post_ln:
            x = self.final_layer_norm(x)
        return (x, attn_weights, attn_probs) if return_attn else x
\ No newline at end of file
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import torch
logger = logging.getLogger(__name__)
class NanDetector:
    """
    Detects the first NaN or Inf in forward and/or backward pass and logs, together with the module name
    """
    def __init__(self, model, forward=True, backward=True):
        # Hook handles, kept so close() can remove them.
        self.bhooks = []
        self.fhooks = []
        self.forward = forward
        self.backward = backward
        self.named_parameters = list(model.named_parameters())
        self.reset()
        for name, mod in model.named_modules():
            # NOTE: because this assignment happens inside the class body,
            # the attribute name is mangled to _NanDetector__module_name;
            # _apply reads it through the same mangled name, so set/get
            # remain consistent.
            mod.__module_name = name
            self.add_hooks(mod)
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Dump out all model gnorms to enable better debugging
        norm = {}
        gradients = {}
        for name, param in self.named_parameters:
            if param.grad is not None:
                grad_norm = torch.norm(param.grad.data, p=2, dtype=torch.float32)
                norm[name] = grad_norm.item()
                if torch.isnan(grad_norm).any() or torch.isinf(grad_norm).any():
                    gradients[name] = param.grad.data
        # Only dump when at least one gradient norm is NaN/Inf.
        if len(gradients) > 0:
            logger.info("Detected nan/inf grad norm, dumping norms...")
            logger.info(f"norms: {norm}")
            logger.info(f"gradients: {gradients}")
        self.close()
    def add_hooks(self, module):
        # Register forward and/or backward hooks as configured.
        if self.forward:
            self.fhooks.append(module.register_forward_hook(self.fhook_fn))
        if self.backward:
            self.bhooks.append(module.register_backward_hook(self.bhook_fn))
    def reset(self):
        # Only the first NaN/Inf per direction is reported (see fhook_fn/bhook_fn).
        self.has_printed_f = False
        self.has_printed_b = False
    def _detect(self, tensor, name, backward):
        """Return a warning message if ``tensor`` contains NaN/Inf, else None."""
        err = None
        if (
            torch.is_floating_point(tensor)
            # single value tensors (like the loss) will not provide much info
            and tensor.numel() >= 2
        ):
            with torch.no_grad():
                if torch.isnan(tensor).any():
                    err = "NaN"
                elif torch.isinf(tensor).any():
                    err = "Inf"
        if err is not None:
            err = f"{err} detected in output of {name}, shape: {tensor.shape}, {'backward' if backward else 'forward'}"
        return err
    def _apply(self, module, inp, x, backward):
        """Recursively inspect a hook output ``x`` (tensor, dict, list, or tuple)."""
        if torch.is_tensor(x):
            # Hooks pass inputs as a tuple; look at the first one only.
            if isinstance(inp, tuple) and len(inp) > 0:
                inp = inp[0]
            err = self._detect(x, module.__module_name, backward)
            if err is not None:
                if torch.is_tensor(inp) and not backward:
                    err += (
                        f" input max: {inp.max().item()}, input min: {inp.min().item()}"
                    )
                has_printed_attr = "has_printed_b" if backward else "has_printed_f"
                logger.warning(err)
                setattr(self, has_printed_attr, True)
        elif isinstance(x, dict):
            for v in x.values():
                self._apply(module, inp, v, backward)
        elif isinstance(x, list) or isinstance(x, tuple):
            for v in x:
                self._apply(module, inp, v, backward)
    def fhook_fn(self, module, inp, output):
        # Forward hook: stop inspecting after the first reported issue.
        if not self.has_printed_f:
            self._apply(module, inp, output, backward=False)
    def bhook_fn(self, module, inp, output):
        # Backward hook: stop inspecting after the first reported issue.
        if not self.has_printed_b:
            self._apply(module, inp, output, backward=True)
    def close(self):
        """Remove all registered hooks."""
        for hook in self.fhooks + self.bhooks:
            hook.remove()
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""
import importlib
import os
from unicore import registry
from unicore.optim.unicore_optimizer import ( # noqa
UnicoreOptimizer,
)
from unicore.optim.fp16_optimizer import FP16Optimizer
__all__ = [
"UnicoreOptimizer",
"FP16Optimizer",
]
# Set up the optimizer registry: ``--optimizer`` selects the class on the
# command line, ``register_optimizer`` is the registration decorator, and
# Adam is the default choice.
(
    _build_optimizer,
    register_optimizer,
    OPTIMIZER_REGISTRY
) = registry.setup_registry("--optimizer", base_class=UnicoreOptimizer, default='adam')
def build_optimizer(args, params, *extra_args, **extra_kwargs):
    """Build the optimizer selected by ``args``.

    ``params`` may be a list of dicts (named parameter groups); in that
    case the dict values are flattened into a single list first.
    Parameters that do not require gradients are dropped before the
    optimizer is constructed.
    """
    if all(isinstance(group, dict) for group in params):
        params = [p for group in params for p in group.values()]
    trainable = [p for p in params if p.requires_grad]
    return _build_optimizer(args, trainable, *extra_args, **extra_kwargs)
# Automatically import any Python files in the optim/ directory so their
# ``@register_optimizer`` decorators run when the package is imported.
for file in os.listdir(os.path.dirname(__file__)):
    if file.endswith(".py") and not file.startswith("_"):
        file_name = file[: file.find(".py")]
        importlib.import_module("unicore.optim." + file_name)
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch.optim
from . import UnicoreOptimizer, register_optimizer
@register_optimizer("adadelta")
class Adadelta(UnicoreOptimizer):
    """Thin wrapper exposing ``torch.optim.Adadelta`` through the registry."""
    def __init__(self, args, params):
        super().__init__(args)
        self._optimizer = torch.optim.Adadelta(params, **self.optimizer_config)
    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--adadelta-rho', type=float, default=0.9, metavar='RHO',
                            help='coefficient used for computing a running average of squared gradients')
        parser.add_argument('--adadelta-eps', type=float, default=1e-6, metavar='EPS',
                            help='term added to the denominator to improve numerical stability')
        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
                            help='weight decay')
        # NOTE(review): --anneal-eps is registered but not consumed by
        # optimizer_config below — verify whether it is read elsewhere.
        parser.add_argument('--anneal-eps', action='store_true', help='flag to anneal eps')
        # fmt: on
    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        return {
            "lr": self.args.lr[0],
            "rho": self.args.adadelta_rho,
            "eps": self.args.adadelta_eps,
            "weight_decay": self.args.weight_decay,
        }
    @property
    def supports_flat_params(self):
        return True
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch.optim
from . import UnicoreOptimizer, register_optimizer
@register_optimizer("adagrad")
class Adagrad(UnicoreOptimizer):
    """Thin wrapper exposing ``torch.optim.Adagrad`` through the registry."""
    def __init__(self, args, params):
        super().__init__(args)
        self._optimizer = torch.optim.Adagrad(params, **self.optimizer_config)
    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
                            help='weight decay')
        # fmt: on
    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        return {
            "lr": self.args.lr[0],
            "weight_decay": self.args.weight_decay,
        }
    @property
    def supports_flat_params(self):
        # Unlike Adadelta above, flattened parameter tensors are not
        # supported by this wrapper.
        return False
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import math
from collections.abc import Collection
from typing import List
import torch
import torch.optim
from unicore.optim import UnicoreOptimizer, register_optimizer
from unicore.optim.fused_adam import get_fused_adam_class
logger = logging.getLogger(__name__)
@register_optimizer("adam")
class UnicoreAdam(UnicoreOptimizer):
    """Adam optimizer for unicore.
    Important note: this optimizer corresponds to the "AdamW" variant of
    Adam in its weight decay behavior. As such, it is most closely
    analogous to torch.optim.AdamW from PyTorch.
    """
    def __init__(self, args, params):
        super().__init__(args)
        # Prefer the fused CUDA kernel when one is available and CUDA is up;
        # --use-old-adam forces the pure-Python implementation below.
        fused_adam_cls = get_fused_adam_class()
        use_fused_adam = (
            not getattr(args, "use_old_adam", False)
            and fused_adam_cls is not None
            and torch.cuda.is_available()
        )
        if use_fused_adam:
            logger.info("using FusedAdam")
            self._optimizer = fused_adam_cls(params, **self.optimizer_config)
        else:
            self._optimizer = Adam(params, **self.optimizer_config)
    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--adam-betas', default='(0.9, 0.999)', metavar='B',
                            help='betas for Adam optimizer')
        parser.add_argument('--adam-eps', type=float, default=1e-8, metavar='D',
                            help='epsilon for Adam optimizer')
        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
                            help='weight decay')
        # fmt: on
    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        return {
            # args.lr may be a per-epoch list or a single scalar.
            "lr": self.args.lr[0]
            if isinstance(self.args.lr, Collection)
            else self.args.lr,
            # NOTE(review): eval() on a CLI-supplied string. Input is trusted
            # in practice, but ast.literal_eval would be safer — confirm no
            # caller relies on non-literal expressions before changing.
            "betas": eval(self.args.adam_betas),
            "eps": self.args.adam_eps,
            "weight_decay": self.args.weight_decay,
        }
class Adam(torch.optim.Optimizer):
    r"""Adam with decoupled (AdamW-style) weight decay.

    Modified from ``torch.optim.Adam`` following
    `Fixed Weight Decay Regularization in Adam`
    (see https://arxiv.org/abs/1711.05101). The base algorithm is from
    `Adam: A Method for Stochastic Optimization`_, optionally using the
    AMSGrad variant from `On the Convergence of Adam and Beyond`_.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        amsgrad=False,
    ):
        super().__init__(
            params,
            dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad),
        )
    @property
    def supports_memory_efficient_fp16(self):
        # fp16/bf16 grads and params are promoted to fp32 inside step().
        return True
    @property
    def supports_flat_params(self):
        return True
    def step(self, closure=None):
        """Performs a single optimization step.
        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = closure() if closure is not None else None
        for group in self.param_groups:
            beta1, beta2 = group["betas"]
            use_amsgrad = group.get("amsgrad", False)
            for param in group["params"]:
                if param.grad is None:
                    continue
                grad = param.grad.data
                # All optimizer math runs in fp32, even for half params.
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()
                if grad.is_sparse:
                    raise RuntimeError(
                        "Adam does not support sparse gradients, please consider SparseAdam instead"
                    )
                p32 = param.data
                if param.data.dtype in {torch.float16, torch.bfloat16}:
                    p32 = p32.float()
                state = self.state[param]
                if len(state) == 0:
                    # Lazily create fp32 optimizer state on first use.
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p32)
                    state["exp_avg_sq"] = torch.zeros_like(p32)
                    if use_amsgrad:
                        state["max_exp_avg_sq"] = torch.zeros_like(p32)
                else:
                    # Keep loaded state on the fp32 copy's device/dtype.
                    state["exp_avg"] = state["exp_avg"].to(p32)
                    state["exp_avg_sq"] = state["exp_avg_sq"].to(p32)
                    if use_amsgrad:
                        state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to(p32)
                state["step"] += 1
                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                # Update biased first and second moment estimates.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if use_amsgrad:
                    # Normalize by the running maximum of the second moment.
                    max_exp_avg_sq = state["max_exp_avg_sq"]
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    denom = max_exp_avg_sq.sqrt().add_(group["eps"])
                else:
                    denom = exp_avg_sq.sqrt().add_(group["eps"])
                bias_corr1 = 1 - beta1 ** state["step"]
                bias_corr2 = 1 - beta2 ** state["step"]
                step_size = group["lr"] * math.sqrt(bias_corr2) / bias_corr1
                # Decoupled weight decay: applied to the parameter directly
                # rather than folded into the gradient (AdamW behavior).
                if group["weight_decay"] != 0:
                    p32.add_(
                        p32, alpha=-group["weight_decay"] * group["lr"]
                    )
                p32.addcdiv_(exp_avg, denom, value=-step_size)
                # Write the fp32 result back into a half-precision param.
                if param.data.dtype in {torch.float16, torch.bfloat16}:
                    param.data.copy_(p32)
        return loss
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
class DynamicLossScaler(object):
    """Dynamic loss scaler for mixed-precision training.

    The scale is multiplied by ``scale_factor`` after every
    ``scale_window`` overflow-free updates and divided by it whenever a
    gradient overflow (inf/nan norm) is detected.
    """
    def __init__(
        self,
        init_scale=2.0 ** 15,
        scale_factor=2.0,
        scale_window=2000,
        tolerance=0.0,
        threshold=None,
        min_loss_scale=1e-4,
    ):
        self.loss_scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self.tolerance = tolerance
        self.threshold = threshold
        # Bookkeeping for how often overflows happen between rescales.
        self._iter = 0
        self._last_overflow_iter = -1
        self._last_rescale_iter = -1
        self._overflows_since_rescale = 0
        self.min_loss_scale = min_loss_scale
    def scale(self, outputs):
        """Return ``outputs`` (typically the loss) times the current scale."""
        return self.loss_scale * outputs
    def update(self):
        """Record one overflow-free update; grow the scale every window."""
        steps_since_overflow = self._iter - self._last_overflow_iter
        if steps_since_overflow % self.scale_window == 0:
            self.loss_scale *= self.scale_factor
            self._last_rescale_iter = self._iter
        self._iter += 1
    def _decrease_loss_scale(self):
        # Shrink the scale, clamped from below by ``threshold`` when set.
        self.loss_scale /= self.scale_factor
        if self.threshold is not None:
            self.loss_scale = max(self.loss_scale, self.threshold)
    def check_overflow(self, grad_norm):
        """Raise OverflowError (after shrinking the scale) on inf/nan norms.

        Raises FloatingPointError instead once the scale cannot be lowered
        below ``min_loss_scale``, signalling that training should stop.
        """
        # nan != nan, so the second comparison catches NaN.
        if not (grad_norm == float("inf") or grad_norm != grad_norm):
            return
        prev_scale = self.loss_scale
        iter_since_rescale = self._iter - self._last_rescale_iter
        self._last_overflow_iter = self._iter
        self._overflows_since_rescale += 1
        pct_overflow = self._overflows_since_rescale / float(iter_since_rescale)
        if pct_overflow >= self.tolerance:
            self._decrease_loss_scale()
            self._last_rescale_iter = self._iter
            self._overflows_since_rescale = 0
        if self.loss_scale <= self.min_loss_scale:
            # Use FloatingPointError as an uncommon error that parent
            # functions can safely catch to stop training.
            self.loss_scale = prev_scale
            raise FloatingPointError(
                (
                    "Minimum loss scale reached ({}). Your loss is probably exploding. "
                    "Try lowering the learning rate, using gradient clipping or "
                    "increasing the batch size."
                ).format(self.min_loss_scale)
            )
        self._iter += 1
        raise OverflowError("setting loss scale to: " + str(self.loss_scale))
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import defaultdict
import torch
from unicore import optim
from unicore import utils
from .dynamic_loss_scaler import DynamicLossScaler
class _FP16OptimizerMixin(object):
    """Mixin implementing FP16/BF16 training on top of a wrapped FP32 optimizer.

    Model parameters and gradients stay in half precision; a flat FP32
    master copy (one per device, see ``build_fp32_params``) is what the
    wrapped ``fp32_optimizer`` actually updates. ``_multiply_factor``
    accumulates pending gradient scaling (inverse loss scale, clipping,
    multiply_grads) and is applied lazily in ``_unscale_grads``.
    """
    def __init__(self, args, **kwargs):
        # forward __init__ call to the next class in mro(method resolution order)
        super().__init__(args, **kwargs)
        # Pending scalar to multiply into the grads before the next step.
        self._multiply_factor = 1.0
        # Use stochastic rounding when copying fp32 weights back to bf16.
        self.bf16_sr = getattr(args, "bf16_sr", False)
    @classmethod
    def build_fp32_params(cls, args, params):
        """Build one flat fp32 master Parameter (with grad buffer) per device.

        Returns a dict mapping device index -> flat ``torch.nn.Parameter``
        containing a copy of all ``params`` concatenated in order.
        """
        # create FP32 copy of parameters and grads
        total_param_size = sum(p.data.numel() for p in params)
        # NOTE(review): only the current CUDA device is considered; all
        # parameters are assumed to live on a single device — confirm.
        devices = [torch.cuda.current_device()]
        fp32_params = {}
        for device in devices:
            device_param_size = total_param_size
            device_params = params
            fp32_params[device] = (
                device_params[0].new(0).float().new(device_param_size)
            )
            offset = 0
            for p in device_params:
                numel = p.data.numel()
                fp32_params[device][offset : offset + numel].copy_(p.data.view(-1))
                offset += numel
            fp32_params[device] = torch.nn.Parameter(fp32_params[device])
            fp32_params[device].grad = fp32_params[device].data.new(
                device_param_size
            )
        return fp32_params
    def state_dict(self):
        """Return the optimizer's state dict."""
        state_dict = self.fp32_optimizer.state_dict()
        # Persist the dynamic loss scale so resumed runs keep it.
        if self.scaler is not None:
            state_dict["loss_scale"] = self.scaler.loss_scale
        return state_dict
    def load_state_dict(self, state_dict, optimizer_overrides=None):
        """Load an optimizer state dict.
        In general we should prefer the configuration of the existing optimizer
        instance (e.g., learning rate) over that found in the state_dict. This
        allows us to resume training from a checkpoint using a new set of
        optimizer args.
        """
        if "loss_scale" in state_dict and self.scaler is not None:
            self.scaler.loss_scale = state_dict["loss_scale"]
        self.fp32_optimizer.load_state_dict(state_dict, optimizer_overrides)
    def backward(self, loss):
        """Computes the sum of gradients of the given tensor w.r.t. graph leaves.
        Compared to :func:`unicore.optim.UnicoreOptimizer.backward`, this
        function additionally dynamically scales the loss to avoid gradient
        underflow.
        """
        if self.scaler is not None:
            loss = self.scaler.scale(loss)
        loss.backward()
        # Half-precision grads have not been copied into the flat fp32
        # buffer yet; _sync_fp16_grads_to_fp32 will do it lazily.
        self._needs_sync = True
    def _sync_fp16_grads_to_fp32(self):
        # Copy (overwrite) the fp16/bf16 grads into the flat fp32 grad
        # buffer at the matching offsets; no-op if already synced.
        with torch.no_grad():
            if self._needs_sync:
                devices = list(self.fp32_params.keys())
                device_params_dict = defaultdict(list)
                for p in self.fp16_params:
                    if p.requires_grad:
                        device_params_dict[p.device.index].append(p)
                for device in devices:
                    device_params = device_params_dict[device]
                    offset = 0
                    for p in device_params:
                        numel = p.numel()
                        if p.grad is not None:
                            self.fp32_params[device].grad.data[
                                offset : offset + numel
                            ].copy_(p.grad.data.view(-1))
                        offset += numel
                self._needs_sync = False
    def _add_fp16_grads_to_fp32(self, mul=0.0):
        # Accumulate ``mul`` times the fp16 grads into the fp32 grad buffer
        # and free the half-precision grads (used by per-sample clipping).
        with torch.no_grad():
            devices = list(self.fp32_params.keys())
            device_params_dict = defaultdict(list)
            for p in self.fp16_params:
                if p.requires_grad:
                    device_params_dict[p.device.index].append(p)
            for device in devices:
                device_params = device_params_dict[device]
                offset = 0
                for p in device_params:
                    numel = p.numel()
                    if p.grad is not None:
                        self.fp32_params[device].grad.data[
                            offset : offset + numel
                        ] += mul * p.grad.data.float().view(-1)
                    p.grad = None
                    offset += numel
            self._needs_sync = False
    def _sync_fp32_params_to_fp16(self):
        # copy FP32 params back into FP16 model
        devices = list(self.fp32_params.keys())
        device_params_dict = defaultdict(list)
        for p in self.fp16_params:
            device_params_dict[p.device.index].append(p)
        for device in devices:
            device_params = device_params_dict[device]
            offset = 0
            for p in device_params:
                numel = p.data.numel()
                u = self.fp32_params[device].data[offset : offset + numel].view_as(p.data)
                if self.bf16_sr and p.dtype == torch.bfloat16:
                    # Stochastic rounding when narrowing fp32 -> bf16.
                    utils.fp32_to_bf16_sr(u, p)
                else:
                    p.data.copy_(u)
                offset += numel
    def _unscale_grads(self):
        # Apply the pending multiply factor (loss-scale undo, clipping) to
        # the fp32 grads exactly once.
        self._sync_fp16_grads_to_fp32()
        if (
            # Skip the multiplication if it's a no-op (i.e., if _multiply_factor
            # is 1.0). At the same time, we want to avoid the device-to-host
            # transfer by comparing it to 1.0. Since _multiply_factor starts as
            # a Python float, we roughly assume that if it's a tensor then it's
            # probably not =1.0 anymore and we do the multiplication. Otherwise
            # we can safely check the value without a D2H transfer.
            torch.is_tensor(self._multiply_factor)
            or self._multiply_factor != 1.0
        ):
            self.fp32_optimizer.multiply_grads(self._multiply_factor)
            self._multiply_factor = 1.0
    def multiply_grads(self, c):
        """Multiplies grads by a constant ``c``."""
        if self._needs_sync:
            # Defer: fold ``c`` into the factor applied at sync time.
            self._multiply_factor *= c
        else:
            # gradients already synced to fp32 parameters, update it directly
            self.fp32_optimizer.multiply_grads(c)
    def per_sample_clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
        """Clips gradient norm."""
        if max_norm <= 0.0:
            return 0.0
        # Norm of the raw fp16 grads, corrected by the pending factor.
        grad_norm = self._multiply_factor * utils.clip_grad_norm_(self.fp16_params, 0, aggregate_norm_fn)
        # grad_norm = 1.0
        if grad_norm > max_norm > 0.0:
            clip_coef = max_norm / (grad_norm + 1e-6)
        else:
            clip_coef = 1.0
        # Accumulate the clipped fp16 grads into the fp32 buffer.
        self._add_fp16_grads_to_fp32(mul=clip_coef)
    def clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
        """Clips gradient norm and updates dynamic loss scaler."""
        self._sync_fp16_grads_to_fp32()
        # The fp32 grads are still loss-scaled; correct by the pending factor.
        grad_norm = self._multiply_factor * self.fp32_optimizer.clip_grad_norm(
            0, aggregate_norm_fn=aggregate_norm_fn,
        )
        if self.scaler is not None:
            if grad_norm > max_norm > 0.0:
                self._multiply_factor *= max_norm / grad_norm
            # Raises OverflowError on inf/nan so the caller can skip the step.
            self.scaler.check_overflow(grad_norm)
        elif max_norm > 0.0:
            clip_coef = (max_norm / (grad_norm + 1e-6)).clamp_(max=1)
            self._multiply_factor *= clip_coef
        return grad_norm
    def step(self, closure=None, groups=None):
        """Performs a single optimization step."""
        self._sync_fp16_grads_to_fp32()
        if getattr(self, "supports_step_with_scale", False):
            # The wrapped optimizer can apply the scale itself.
            self.fp32_optimizer.step(closure, scale=(1.0 / self._multiply_factor), groups=groups)
        else:
            self._unscale_grads()
            self.fp32_optimizer.step(closure, groups=groups)
        if self.scaler is not None:
            self.scaler.update()
        # Push the updated fp32 master weights back into the fp16 model.
        self._sync_fp32_params_to_fp16()
    def zero_grad(self):
        """Clears the gradients of all optimized parameters."""
        for p in self.fp16_params:
            p.grad = None
        if torch.is_tensor(self.fp32_params):
            self.fp32_params.grad.zero_()
        elif isinstance(self.fp32_params, dict):
            for fp32_params in self.fp32_params.values():
                fp32_params.grad.zero_()
        else:
            raise RuntimeError("self.fp32_params must be a tensor or dict")
        self._needs_sync = False
        if self.scaler is not None:
            # Pre-load the inverse loss scale into the pending factor so the
            # next grad sync automatically unscales.
            self._multiply_factor = 1.0 / float(self.scaler.loss_scale)
        else:
            self._multiply_factor = 1.0
class FP16Optimizer(_FP16OptimizerMixin, optim.UnicoreOptimizer):
    """
    Wrap an *optimizer* to support FP16 (mixed precision) training.
    """
    def __init__(self, args, params, fp32_optimizer, fp32_params, **kwargs):
        super().__init__(args)
        # Half-precision model params, the wrapped fp32 optimizer, and the
        # flat fp32 master copy built by build_fp32_params.
        self.fp16_params = params
        self.fp32_optimizer = fp32_optimizer
        self.fp32_params = fp32_params
        self.allreduce_fp32_grad = getattr(args, "allreduce_fp32_grad", False)
        if getattr(args, "fp16_scale_window", None) is None:
            if len(args.update_freq) > 1:
                raise ValueError(
                    "--fp16-scale-window must be given explicitly when using a "
                    "custom --update-freq schedule"
                )
            data_parallel_size = int(
                args.distributed_world_size
            )
            # Default: try to grow the loss scale roughly every 2**14
            # effective samples' worth of updates.
            scale_window = int(
                2 ** 14 / data_parallel_size / args.update_freq[0]
            )
        else:
            scale_window = args.fp16_scale_window
        if not getattr(args, "bf16", False):
            self.scaler = DynamicLossScaler(
                init_scale=args.fp16_init_scale,
                scale_window=scale_window,
                tolerance=args.fp16_scale_tolerance,
                threshold=args.threshold_loss_scale,
                min_loss_scale=args.min_loss_scale,
            )
        else:
            # disable loss scaling for bfloat16
            self.scaler = None
    @classmethod
    def build_optimizer(cls, args, params, **kwargs):
        """
        Args:
            args : unicore args
            params (iterable): iterable of parameters to optimize
        """
        flatten = not getattr(args, "fp16_no_flatten_grads", False)
        # Only the flattened (single flat fp32 buffer) path is implemented.
        assert flatten
        fp32_params = cls.build_fp32_params(args, params)
        fp32_optimizer = optim.build_optimizer(args, [fp32_params])
        return cls(args, params, fp32_optimizer, fp32_params, **kwargs)
    @property
    def optimizer(self):
        # Expose the wrapped optimizer's underlying torch optimizer.
        return self.fp32_optimizer.optimizer
    @optimizer.setter
    def optimizer(self, optimizer):
        self.fp32_optimizer.optimizer = optimizer
    @property
    def lr_scheduler(self):
        return getattr(self.fp32_optimizer, "lr_scheduler", None)
    @property
    def optimizer_config(self):
        return self.fp32_optimizer.optimizer_config
    def get_lr(self):
        return self.fp32_optimizer.get_lr()
    def set_lr(self, lr):
        self.fp32_optimizer.set_lr(lr)
    def all_reduce_grads(self, module):
        # Optionally sync to fp32 first and all-reduce the fp32 grads
        # (better numerical stability); otherwise delegate to the wrapped
        # optimizer's default all-reduce.
        if self.allreduce_fp32_grad and hasattr(module, "all_reduce_params"):
            self._sync_fp16_grads_to_fp32()
            with torch.no_grad():
                params = [p for p in self.fp32_optimizer.params]
                module.all_reduce_params(params)
        else:
            self.fp32_optimizer.all_reduce_grads(module)
    @property
    def supports_flat_params(self):
        return self.fp32_optimizer.supports_flat_params
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment