Commit 72f5785f authored by huaerkl

v1.0

import glob
import os
from typing import Callable, List, Optional, Tuple
import logging
import numpy as np
import torchvision.transforms.functional as TF
import PIL
from PIL import Image
from torchvision.datasets import VisionDataset
logger = logging.getLogger(__name__)
class PathDataset(VisionDataset):
def __init__(
self,
root: List[str],
        loader: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        extra_transform: Optional[Callable] = None,
mean: Optional[List[float]] = None,
std: Optional[List[float]] = None,
):
super().__init__(root=root)
PIL.Image.MAX_IMAGE_PIXELS = 256000001
self.files = []
for folder in self.root:
self.files.extend(
sorted(glob.glob(os.path.join(folder, "**", "*.jpg"), recursive=True))
)
self.files.extend(
sorted(glob.glob(os.path.join(folder, "**", "*.png"), recursive=True))
)
self.transform = transform
self.extra_transform = extra_transform
self.mean = mean
self.std = std
self.loader = loader
logger.info(f"loaded {len(self.files)} samples from {root}")
assert (mean is None) == (std is None)
def __len__(self) -> int:
return len(self.files)
def __getitem__(self, idx) -> Tuple[np.ndarray, np.ndarray]:
path = self.files[idx]
if self.loader is not None:
return self.loader(path), None
img = Image.open(path).convert("RGB")
if self.transform is not None:
img = self.transform(img)
img = TF.to_tensor(img)
if self.mean is not None and self.std is not None:
img = TF.normalize(img, self.mean, self.std)
return img, None
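# A minimal usage sketch (assumption: the directory and normalization values
# below are placeholders, not part of this repository):
#
#   dataset = PathDataset(
#       root=["/path/to/images"],
#       mean=[0.485, 0.456, 0.406],
#       std=[0.229, 0.224, 0.225],
#   )
#   img, _ = dataset[0]  # normalized CHW float tensor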
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import torch
from omegaconf import OmegaConf
from fairseq.criterions.model_criterion import ModelCriterionConfig
from fairseq.dataclass.configs import FairseqConfig
from tasks import ImageClassificationConfig, ImagePretrainingConfig
from models.data2vec_image_classification import (
Data2VecImageClassificationConfig,
Data2VecImageClassificationModel,
)
from models.data2vec_vision import Data2VecVisionConfig, Data2VecVisionModel
def get_parser():
parser = argparse.ArgumentParser(
        description="convert a BEiT checkpoint into a data2vec vision checkpoint"
)
# fmt: off
parser.add_argument('checkpoint', help='checkpoint to convert')
parser.add_argument('--output', required=True, metavar='PATH', help='where to output converted checkpoint')
parser.add_argument('--type', type=str, choices=['vision', 'image_classification'], default='image_classification', help='type of model to upgrade')
parser.add_argument('--inception_norms', action='store_true', default=False)
# fmt: on
return parser
def update_checkpoint(model_dict, prefix, is_nested):
replace_paths = {
"cls_token": "model.cls_emb" if is_nested else "cls_emb",
"patch_embed": "model.patch_embed" if is_nested else "patch_embed",
"mask_token": "mask_emb",
}
starts_with = {
"patch_embed.proj": "model.patch_embed.conv"
if is_nested
else "patch_embed.conv",
"lm_head": "final_proj",
"fc_norm": "fc_norm",
"head": "head",
}
partial = {
"mlp.fc1": "mlp.0",
"mlp.fc2": "mlp.2",
}
for k in list(model_dict.keys()):
for sw, r in starts_with.items():
if k.startswith(sw):
replace_paths[k] = k.replace(sw, r)
for p, r in partial.items():
if p in k:
replace_paths[k] = prefix + k.replace(p, r)
if prefix != "":
for k in list(model_dict.keys()):
if k not in replace_paths:
replace_paths[k] = prefix + k
for k in list(model_dict.keys()):
if k in replace_paths:
model_dict[replace_paths[k]] = model_dict[k]
if k != replace_paths[k]:
del model_dict[k]
return model_dict
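# Illustrative renamings produced above (sketch; the keys are hypothetical examples):
# with prefix="model.encoder." and is_nested=True,
#   "blocks.0.mlp.fc1.weight" -> "model.encoder.blocks.0.mlp.0.weight"
#   "patch_embed.proj.weight" -> "model.patch_embed.conv.weight"
#   "cls_token"               -> "model.cls_emb"
# and any key with no match is simply prefixed with "model.encoder.".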
def main():
parser = get_parser()
args = parser.parse_args()
cp = torch.load(args.checkpoint, map_location="cpu")
cfg = FairseqConfig(
criterion=ModelCriterionConfig(_name="model", log_keys=["correct"]),
)
if args.type == "image_classification":
cfg.task = ImageClassificationConfig(
_name="image_classification",
data=".",
)
if args.inception_norms:
cfg.task.normalization_mean = [0.5, 0.5, 0.5]
cfg.task.normalization_std = [0.5, 0.5, 0.5]
cfg.model = Data2VecImageClassificationConfig(
_name="data2vec_image_classification",
)
cfg.model.pretrained_model_args = FairseqConfig(
model=Data2VecVisionConfig(
_name="data2vec_vision", shared_rel_pos_bias=False
),
task=ImagePretrainingConfig(
_name="image_pretraining",
),
)
cfg = OmegaConf.create(cfg)
state = {
"cfg": OmegaConf.to_container(cfg, resolve=True, enum_to_str=True),
"model": cp["module"],
"best_loss": None,
"optimizer": None,
"extra_state": {},
}
model = Data2VecImageClassificationModel(cfg.model)
model.load_state_dict(
update_checkpoint(state["model"], prefix="model.encoder.", is_nested=True),
strict=True,
)
elif args.type == "vision":
cfg.task = ImagePretrainingConfig(
_name="image_pretraining",
data=".",
)
if args.inception_norms:
cfg.task.normalization_mean = [0.5, 0.5, 0.5]
cfg.task.normalization_std = [0.5, 0.5, 0.5]
cfg.model = Data2VecVisionConfig(
_name="data2vec_vision",
)
cfg = OmegaConf.create(cfg)
state = {
"cfg": OmegaConf.to_container(cfg, resolve=True, enum_to_str=True),
"model": cp["model"],
"best_loss": None,
"optimizer": None,
"extra_state": {},
}
model = Data2VecVisionModel(cfg.model)
model.load_state_dict(
update_checkpoint(state["model"], prefix="encoder.", is_nested=False),
strict=True,
)
else:
raise Exception("unsupported type " + args.type)
print(state["cfg"], state.keys())
torch.save(state, args.output)
if __name__ == "__main__":
main()
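# Example invocation (sketch; the script and checkpoint file names are hypothetical):
#   python convert_beit_to_data2vec.py /path/to/beit.pt \
#       --output /path/to/data2vec_image_classification.pt \
#       --type image_classification --inception_norms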
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import logging
import re
from dataclasses import dataclass, field
from typing import Any, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from omegaconf import II, MISSING, open_dict
from fairseq import checkpoint_utils, tasks
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.models import (
BaseFairseqModel,
register_model,
)
from fairseq.models.wav2vec.wav2vec2 import MASKING_DISTRIBUTION_CHOICES
from fairseq.modules import TransposeLast
from fairseq.tasks import FairseqTask
logger = logging.getLogger(__name__)
@dataclass
class AudioClassificationConfig(FairseqDataclass):
model_path: str = field(
default=MISSING, metadata={"help": "path to wav2vec 2.0 model"}
)
no_pretrained_weights: bool = field(
default=False, metadata={"help": "if true, does not load pretrained weights"}
)
dropout_input: float = field(
default=0.0,
metadata={"help": "dropout to apply to the input (after feat extr)"},
)
final_dropout: float = field(
default=0.0,
metadata={"help": "dropout after transformer and before final projection"},
)
dropout: float = field(
default=0.0, metadata={"help": "dropout probability inside wav2vec 2.0 model"}
)
attention_dropout: float = field(
default=0.0,
metadata={
"help": "dropout probability for attention weights inside wav2vec 2.0 model"
},
)
activation_dropout: float = field(
default=0.0,
metadata={
"help": "dropout probability after activation in FFN inside wav2vec 2.0 model"
},
)
# masking
apply_mask: bool = field(
default=False, metadata={"help": "apply masking during fine-tuning"}
)
mask_length: int = field(
        default=10, metadata={"help": "length of each masked span"}
)
mask_prob: float = field(
default=0.5,
metadata={
"help": "probability of replacing a token with mask (normalized by length)"
},
)
mask_selection: MASKING_DISTRIBUTION_CHOICES = field(
default="static", metadata={"help": "how to choose masks"}
)
mask_other: float = field(
default=0,
metadata={
"help": "secondary mask argument (used for more complex distributions), "
"see help in compute_mask_indices"
},
)
no_mask_overlap: bool = field(
default=False, metadata={"help": "whether to allow masks to overlap"}
)
mask_min_space: Optional[int] = field(
default=1,
metadata={"help": "min space between spans (if no overlap is enabled)"},
)
require_same_masks: bool = field(
default=True,
        metadata={
            "help": "whether the number of masked timesteps must be the same across all "
"examples in a batch"
},
)
mask_dropout: float = field(
default=0.0,
metadata={"help": "percent of masks to unmask for each sample"},
)
# channel masking
mask_channel_length: int = field(
default=10, metadata={"help": "length of the mask for features (channels)"}
)
mask_channel_prob: float = field(
default=0.0, metadata={"help": "probability of replacing a feature with 0"}
)
mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field(
default="static",
metadata={"help": "how to choose mask length for channel masking"},
)
mask_channel_other: float = field(
default=0,
        metadata={
            "help": "secondary mask argument (used for more complex distributions), "
            "see help in compute_mask_indices"
},
)
no_mask_channel_overlap: bool = field(
default=False, metadata={"help": "whether to allow channel masks to overlap"}
)
freeze_finetune_updates: int = field(
        default=0, metadata={"help": "don't fine-tune wav2vec for this many updates"}
)
feature_grad_mult: float = field(
default=0.0, metadata={"help": "reset feature grad mult in wav2vec 2.0 to this"}
)
layerdrop: float = field(
default=0.0, metadata={"help": "probability of dropping a layer in wav2vec 2.0"}
)
mask_channel_min_space: Optional[int] = field(
default=1,
metadata={"help": "min space between spans (if no overlap is enabled)"},
)
mask_channel_before: bool = False
normalize: bool = II("task.normalize")
data: str = II("task.data")
# this holds the loaded wav2vec args
d2v_args: Any = None
offload_activations: bool = field(
default=False, metadata={"help": "offload_activations"}
)
min_params_to_wrap: int = field(
default=int(1e8),
metadata={
"help": "minimum number of params for a layer to be wrapped with FSDP() when "
"training with --ddp-backend=fully_sharded. Smaller values will "
"improve memory efficiency, but may make torch.distributed "
"communication less efficient due to smaller input sizes. This option "
"is set to 0 (i.e., always wrap) when --checkpoint-activations or "
"--offload-activations are passed."
},
)
checkpoint_activations: bool = field(
default=False,
metadata={"help": "recompute activations and save memory for extra compute"},
)
ddp_backend: str = II("distributed_training.ddp_backend")
prediction_mode: str = "lin_softmax"
eval_prediction_mode: Optional[str] = None
conv_kernel: int = -1
conv_stride: int = 1
two_convs: bool = False
extreme_factor: float = 1.0
conv_feature_layers: Optional[str] = field(
default=None,
        metadata={
            "help": "string describing convolutional feature extraction layers in the form of a python list that contains "
"[(dim, kernel_size, stride), ...]"
},
)
mixup_prob: float = 1.0
source_mixup: float = -1
same_mixup: bool = True
label_mixup: bool = False
gain_mode: str = "none"
@register_model("audio_classification", dataclass=AudioClassificationConfig)
class AudioClassificationModel(BaseFairseqModel):
def __init__(self, cfg: AudioClassificationConfig, num_classes):
super().__init__()
self.apply_mask = cfg.apply_mask
self.cfg = cfg
arg_overrides = {
"dropout": cfg.dropout,
"activation_dropout": cfg.activation_dropout,
"dropout_input": cfg.dropout_input,
"attention_dropout": cfg.attention_dropout,
"mask_length": cfg.mask_length,
"mask_prob": cfg.mask_prob,
"require_same_masks": getattr(cfg, "require_same_masks", True),
"mask_dropout": getattr(cfg, "mask_dropout", 0),
"mask_selection": cfg.mask_selection,
"mask_other": cfg.mask_other,
"no_mask_overlap": cfg.no_mask_overlap,
"mask_channel_length": cfg.mask_channel_length,
"mask_channel_prob": cfg.mask_channel_prob,
"mask_channel_before": cfg.mask_channel_before,
"mask_channel_selection": cfg.mask_channel_selection,
"mask_channel_other": cfg.mask_channel_other,
"no_mask_channel_overlap": cfg.no_mask_channel_overlap,
"encoder_layerdrop": cfg.layerdrop,
"feature_grad_mult": cfg.feature_grad_mult,
"checkpoint_activations": cfg.checkpoint_activations,
"offload_activations": cfg.offload_activations,
"min_params_to_wrap": cfg.min_params_to_wrap,
"mixup": -1,
}
if cfg.conv_feature_layers is not None:
arg_overrides["conv_feature_layers"] = cfg.conv_feature_layers
if cfg.d2v_args is None:
state = checkpoint_utils.load_checkpoint_to_cpu(
cfg.model_path, arg_overrides
)
d2v_args = state.get("cfg", None)
if d2v_args is None:
d2v_args = convert_namespace_to_omegaconf(state["args"])
d2v_args.criterion = None
d2v_args.lr_scheduler = None
cfg.d2v_args = d2v_args
logger.info(d2v_args)
else:
state = None
d2v_args = cfg.d2v_args
model_normalized = d2v_args.task.get(
"normalize", d2v_args.model.get("normalize", False)
)
assert cfg.normalize == model_normalized, (
"Fine-tuning works best when data normalization is the same. "
"Please check that --normalize is set or unset for both pre-training and here"
)
if hasattr(cfg, "checkpoint_activations") and cfg.checkpoint_activations:
with open_dict(d2v_args):
d2v_args.model.checkpoint_activations = cfg.checkpoint_activations
d2v_args.task.data = cfg.data
task = tasks.setup_task(d2v_args.task)
model = task.build_model(d2v_args.model, from_checkpoint=True)
model.remove_pretraining_modules()
if state is not None and not cfg.no_pretrained_weights:
self.load_model_weights(state, model, cfg)
d = d2v_args.model.encoder_embed_dim
self.d2v_model = model
self.final_dropout = nn.Dropout(cfg.final_dropout)
self.freeze_finetune_updates = cfg.freeze_finetune_updates
self.num_updates = 0
for p in self.parameters():
p.param_group = "pretrained"
if cfg.prediction_mode == "proj_avg_proj":
self.proj = nn.Linear(d, d * 2)
self.proj2 = nn.Linear(d * 2, num_classes)
for p in self.proj.parameters():
p.param_group = "projection"
for p in self.proj2.parameters():
p.param_group = "projection"
elif self.cfg.prediction_mode == "summary_proj":
self.proj = nn.Linear(d // 3, num_classes)
for p in self.proj.parameters():
p.param_group = "projection"
elif self.cfg.conv_kernel > 1 and not self.cfg.two_convs:
self.proj = nn.Sequential(
TransposeLast(),
nn.Conv1d(d, num_classes, kernel_size=self.cfg.conv_kernel, stride=self.cfg.conv_stride),
TransposeLast(),
)
for p in self.proj.parameters():
p.param_group = "projection"
elif self.cfg.conv_kernel > 0 and self.cfg.two_convs:
self.proj = nn.Sequential(
TransposeLast(),
nn.Conv1d(d, d, kernel_size=self.cfg.conv_kernel, stride=self.cfg.conv_stride),
TransposeLast(),
nn.GELU(),
nn.Linear(d, num_classes),
)
for p in self.proj.parameters():
p.param_group = "projection"
else:
self.proj = nn.Linear(d, num_classes)
for p in self.proj.parameters():
p.param_group = "projection"
def upgrade_state_dict_named(self, state_dict, name):
super().upgrade_state_dict_named(state_dict, name)
return state_dict
@classmethod
def build_model(cls, cfg: AudioClassificationConfig, task: FairseqTask):
"""Build a new model instance."""
assert hasattr(task, "labels"), f"Task {task} must have an attribute 'labels'"
return cls(cfg, len(task.labels))
def load_model_weights(self, state, model, cfg):
if cfg.ddp_backend == "fully_sharded":
from fairseq.distributed import FullyShardedDataParallel
for name, module in model.named_modules():
if "encoder.layers" in name and len(name.split(".")) == 3:
                    # Only for layers, we do special handling and load the weights one by one.
                    # We don't load all weights together, as that wouldn't be memory efficient
                    # and may cause OOM.
new_dict = {
k.replace(name + ".", ""): v
for (k, v) in state["model"].items()
if name + "." in k
}
assert isinstance(module, FullyShardedDataParallel)
with module.summon_full_params():
module.load_state_dict(new_dict, strict=True)
module._reset_lazy_init()
# Once layers are loaded, filter them out and load everything else.
            r = re.compile(r"encoder\.layers\.\d+\.")
filtered_list = list(filter(r.match, state["model"].keys()))
new_big_dict = {
k: v for (k, v) in state["model"].items() if k not in filtered_list
}
model.load_state_dict(new_big_dict, strict=False)
else:
if "_ema" in state["model"]:
del state["model"]["_ema"]
model.load_state_dict(state["model"], strict=False)
def set_num_updates(self, num_updates):
"""Set the number of parameters updates."""
super().set_num_updates(num_updates)
self.num_updates = num_updates
def compute_gain(self, sound, fs=16_000, min_db=-80.0, mode="A_weighting"):
if fs == 16000:
n_fft = 2048
elif fs == 44100:
n_fft = 4096
else:
raise Exception("Invalid fs {}".format(fs))
stride = n_fft // 2
def a_weight(fs, n_fft, min_db=-80.0):
freq = np.linspace(0, fs // 2, n_fft // 2 + 1)
freq_sq = np.power(freq, 2)
freq_sq[0] = 1.0
weight = 2.0 + 20.0 * (
2 * np.log10(12194)
+ 2 * np.log10(freq_sq)
- np.log10(freq_sq + 12194 ** 2)
- np.log10(freq_sq + 20.6 ** 2)
- 0.5 * np.log10(freq_sq + 107.7 ** 2)
- 0.5 * np.log10(freq_sq + 737.9 ** 2)
)
weight = np.maximum(weight, min_db)
return weight
gain = []
for i in range(0, len(sound) - n_fft + 1, stride):
if mode == "RMSE":
g = np.mean(sound[i : i + n_fft] ** 2)
elif mode == "A_weighting":
spec = np.fft.rfft(np.hanning(n_fft + 1)[:-1] * sound[i : i + n_fft])
power_spec = np.abs(spec) ** 2
a_weighted_spec = power_spec * np.power(10, a_weight(fs, n_fft) / 10)
g = np.sum(a_weighted_spec)
else:
raise Exception("Invalid mode {}".format(mode))
gain.append(g)
gain = np.array(gain)
gain = np.maximum(gain, np.power(10, min_db / 10))
gain_db = 10 * np.log10(gain)
return gain_db
# adapted from https://github.com/mil-tokyo/bc_learning_sound/blob/master/utils.py
def compute_gain_torch(self, sound, fs=16_000, min_db=-80.0, mode="A_weighting"):
if fs == 16000:
n_fft = 2048
elif fs == 44100:
n_fft = 4096
else:
raise Exception("Invalid fs {}".format(fs))
if mode == "A_weighting":
if not hasattr(self, f"a_weight"):
self.a_weight = {}
if fs not in self.a_weight:
def a_weight(fs, n_fft, min_db=-80.0):
freq = np.linspace(0, fs // 2, n_fft // 2 + 1)
freq_sq = freq ** 2
freq_sq[0] = 1.0
weight = 2.0 + 20.0 * (
2 * np.log10(12194)
+ 2 * np.log10(freq_sq)
- np.log10(freq_sq + 12194 ** 2)
- np.log10(freq_sq + 20.6 ** 2)
- 0.5 * np.log10(freq_sq + 107.7 ** 2)
- 0.5 * np.log10(freq_sq + 737.9 ** 2)
)
weight = np.maximum(weight, min_db)
return weight
self.a_weight[fs] = torch.from_numpy(
np.power(10, a_weight(fs, n_fft, min_db) / 10)
).to(device=sound.device)
sound = sound.unfold(-1, n_fft, n_fft // 2)
if mode == "RMSE":
sound = sound ** 2
g = sound.mean(-1)
elif mode == "A_weighting":
w = torch.hann_window(n_fft, device=sound.device) * sound
spec = torch.fft.rfft(w)
power_spec = spec.abs() ** 2
a_weighted_spec = power_spec * self.a_weight[fs]
g = a_weighted_spec.sum(-1)
else:
raise Exception("Invalid mode {}".format(mode))
gain = torch.maximum(g, torch.tensor(10 ** (min_db / 10), device=g.device))
gain_db = 10 * torch.log10(gain)
return gain_db
def forward(self, source, padding_mask, label=None, **kwargs):
if self.cfg.source_mixup >= 0 and self.training and self.cfg.mixup_prob > 0:
with torch.no_grad():
mixed_source = source
mix_mask = None
if self.cfg.mixup_prob < 1:
mix_mask = (
torch.empty((source.size(0),), device=source.device)
.bernoulli_(self.cfg.mixup_prob)
.bool()
)
mixed_source = source[mix_mask]
r = (
torch.FloatTensor(
1 if self.cfg.same_mixup else mixed_source.size(0)
)
.uniform_(max(1e-6, self.cfg.source_mixup), 1)
.to(dtype=source.dtype, device=source.device)
)
mixup_perm = torch.randperm(source.size(0))
s2 = source[mixup_perm]
if self.cfg.gain_mode == "none":
p = r.unsqueeze(-1)
if mix_mask is not None:
s2 = s2[mix_mask]
else:
if self.cfg.gain_mode == "naive_rms":
G1 = source.pow(2).mean(dim=-1).sqrt()
else:
G1, _ = self.compute_gain_torch(
source, mode=self.cfg.gain_mode
).max(-1)
G1 = G1.to(dtype=source.dtype)
G2 = G1[mixup_perm]
if mix_mask is not None:
G1 = G1[mix_mask]
G2 = G2[mix_mask]
s2 = s2[mix_mask]
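                        # BC-learning-style mixing coefficient (see the adapted gain helpers
                        # above): p is chosen so that p*10^(G1/20) : (1-p)*10^(G2/20) = r : (1-r),
                        # i.e. the mix ratio r is applied in the loudness (gain) domain.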
p = 1 / (1 + 10 ** ((G1 - G2) / 20) * (1 - r) / r)
p = p.unsqueeze(-1)
mixed = (p * mixed_source) + (1 - p) * s2
if mix_mask is None:
source = mixed / torch.sqrt(p ** 2 + (1 - p) ** 2)
else:
source[mix_mask] = mixed / torch.sqrt(p ** 2 + (1 - p) ** 2)
if label is not None and self.cfg.label_mixup:
r = r.unsqueeze(-1)
if mix_mask is None:
label = label * r + (1 - r) * label[mixup_perm]
else:
label[mix_mask] = (
label[mix_mask] * r + (1 - r) * label[mixup_perm][mix_mask]
)
d2v_args = {
"source": source,
"padding_mask": padding_mask,
"mask": self.apply_mask and self.training,
}
ft = self.freeze_finetune_updates <= self.num_updates
with torch.no_grad() if not ft else contextlib.ExitStack():
res = self.d2v_model.extract_features(**d2v_args)
x = res["x"]
padding_mask = res["padding_mask"]
if padding_mask is not None:
x[padding_mask] = 0
x = self.final_dropout(x)
if self.training or (
self.cfg.eval_prediction_mode is None or self.cfg.eval_prediction_mode == ""
):
prediction_mode = self.cfg.prediction_mode
else:
prediction_mode = self.cfg.eval_prediction_mode
if prediction_mode == "average_before":
x = x.mean(dim=1)
if prediction_mode != "summary_mha" and prediction_mode != "summary_proj" and prediction_mode != "cls":
x = self.proj(x)
logits = True
if prediction_mode == "lin_softmax":
x = F.logsigmoid(x.float())
x = torch.logsumexp(x + x, dim=1) - torch.logsumexp(x, dim=1)
x = x.clamp(max=0)
x = x - torch.log(-(torch.expm1(x)))
elif prediction_mode == "extremized_odds":
x = x.float().sum(dim=1)
x = x * self.cfg.extreme_factor
elif prediction_mode == "average_before":
x = x.float()
elif prediction_mode == "average":
x = x.float().mean(dim=1)
elif prediction_mode == "average_sigmoid":
x = torch.sigmoid(x.float())
x = x.mean(dim=1)
logits = False
elif prediction_mode == "max":
x, _ = x.float().max(dim=1)
elif prediction_mode == "max_sigmoid":
x = torch.sigmoid(x.float())
x, _ = x.float().max(dim=1)
logits = False
elif prediction_mode == "proj_avg_proj":
x = x.mean(dim=1)
x = self.proj2(x)
elif prediction_mode == "summary_mha" or prediction_mode == "summary_proj":
x = self.d2v_model.summary(
x, padding_mask, proj=prediction_mode == "summary_proj"
)
x = x.type_as(source)
x = self.proj(x)
elif prediction_mode == "cls":
x = x[:,0]
x = self.proj(x)
else:
raise Exception(f"unknown prediction mode {prediction_mode}")
if label is None:
return torch.sigmoid(x) if logits else x
x = torch.nan_to_num(x)
if logits:
loss = F.binary_cross_entropy_with_logits(
x, label.float(), reduction="none"
)
else:
loss = F.binary_cross_entropy(x, label.float(), reduction="none")
result = {
"losses": {
"main": loss,
},
"sample_size": label.sum(),
}
if not self.training:
result["_predictions"] = torch.sigmoid(x) if logits else x
result["_targets"] = label
return result
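# Shape sketch (illustrative, not taken from the original source): forward() takes a
# waveform batch `source` of shape (B, T) and a boolean `padding_mask` of the same
# shape; with the default prediction modes it returns per-class scores of shape
# (B, num_classes) when `label` is None, and otherwise a dict with "losses",
# "sample_size" and, at eval time, "_predictions"/"_targets".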
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import math
from dataclasses import dataclass, field
from typing import Optional, Callable
from functools import partial
import numpy as np
from omegaconf import II
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from fairseq.modules import EMAModule, EMAModuleConfig
from fairseq.dataclass import FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model
from examples.data2vec.data.modality import Modality
from examples.data2vec.models.modalities.base import (
MaskSeed,
D2vModalityConfig,
ModalitySpecificEncoder,
get_annealed_rate,
)
from examples.data2vec.models.modalities.modules import (
D2vDecoderConfig,
AltBlock,
Decoder1d,
)
from examples.data2vec.models.modalities.audio import (
D2vAudioConfig,
AudioEncoder,
)
from examples.data2vec.models.modalities.images import (
D2vImageConfig,
ImageEncoder,
)
from examples.data2vec.models.modalities.text import (
D2vTextConfig,
TextEncoder,
)
logger = logging.getLogger(__name__)
@dataclass
class D2vModalitiesConfig(FairseqDataclass):
audio: D2vAudioConfig = D2vAudioConfig()
image: D2vImageConfig = D2vImageConfig()
text: D2vTextConfig = D2vTextConfig()
@dataclass
class Data2VecMultiConfig(FairseqDataclass):
loss_beta: float = field(
default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"}
)
loss_scale: Optional[float] = field(
default=None,
metadata={
"help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
},
)
depth: int = 8
start_drop_path_rate: float = 0
end_drop_path_rate: float = 0
num_heads: int = 12
norm_eps: float = 1e-6
norm_affine: bool = True
encoder_dropout: float = 0.1
post_mlp_drop: float = 0.1
attention_dropout: float = 0.1
activation_dropout: float = 0.0
dropout_input: float = 0.0
layerdrop: float = 0.0
embed_dim: int = 768
mlp_ratio: float = 4
layer_norm_first: bool = False
average_top_k_layers: int = field(
default=8, metadata={"help": "how many layers to average"}
)
end_of_block_targets: bool = False
clone_batch: int = 1
layer_norm_target_layer: bool = False
batch_norm_target_layer: bool = False
instance_norm_target_layer: bool = False
instance_norm_targets: bool = False
layer_norm_targets: bool = False
ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"})
ema_same_dtype: bool = True
log_norms: bool = True
ema_end_decay: float = field(
default=0.9999, metadata={"help": "final ema decay rate"}
)
# when to finish annealing ema decay rate
ema_anneal_end_step: int = II("optimization.max_update")
ema_encoder_only: bool = field(
default=True,
metadata={
"help": "whether to momentum update only the shared transformer encoder"
},
)
max_update: int = II("optimization.max_update")
modalities: D2vModalitiesConfig = D2vModalitiesConfig()
shared_decoder: Optional[D2vDecoderConfig] = None
min_target_var: float = field(
default=0.1, metadata={"help": "stop training if target var falls below this"}
)
min_pred_var: float = field(
default=0.01,
metadata={"help": "stop training if prediction var falls below this"},
)
supported_modality: Optional[Modality] = None
mae_init: bool = False
seed: int = II("common.seed")
skip_ema: bool = False
cls_loss: float = 0
recon_loss: float = 0
d2v_loss: float = 1
decoder_group: bool = False
@register_model("data2vec_multi", dataclass=Data2VecMultiConfig)
class Data2VecMultiModel(BaseFairseqModel):
def make_modality_encoder(
self,
cfg: D2vModalityConfig,
embed_dim: int,
make_block: Callable[[float], nn.ModuleList],
norm_layer: Callable[[int], nn.LayerNorm],
layer_norm_first: bool,
alibi_biases,
task,
) -> ModalitySpecificEncoder:
if cfg.type == Modality.AUDIO:
enc_cls = AudioEncoder
elif cfg.type == Modality.IMAGE:
enc_cls = ImageEncoder
elif cfg.type == Modality.TEXT:
enc_cls = TextEncoder
if hasattr(task, "text_task"):
task = task.text_task
else:
raise Exception(f"unsupported modality {cfg.type}")
return enc_cls(
cfg,
embed_dim,
make_block,
norm_layer,
layer_norm_first,
alibi_biases,
task,
)
def __init__(self, cfg: Data2VecMultiConfig, modalities, skip_ema=False, task=None):
super().__init__()
self.cfg = cfg
self.modalities = modalities
self.task = task
make_layer_norm = partial(
nn.LayerNorm, eps=cfg.norm_eps, elementwise_affine=cfg.norm_affine
)
def make_block(drop_path, dim=None, heads=None):
return AltBlock(
cfg.embed_dim if dim is None else dim,
cfg.num_heads if heads is None else heads,
cfg.mlp_ratio,
qkv_bias=True,
drop=cfg.encoder_dropout,
attn_drop=cfg.attention_dropout,
mlp_drop=cfg.activation_dropout,
post_mlp_drop=cfg.post_mlp_drop,
drop_path=drop_path,
norm_layer=make_layer_norm,
layer_norm_first=cfg.layer_norm_first,
ffn_targets=not cfg.end_of_block_targets,
)
self.alibi_biases = {}
self.modality_encoders = nn.ModuleDict()
for mod in self.modalities:
mod_cfg = getattr(cfg.modalities, mod.name.lower())
enc = self.make_modality_encoder(
mod_cfg,
cfg.embed_dim,
make_block,
make_layer_norm,
cfg.layer_norm_first,
self.alibi_biases,
task,
)
self.modality_encoders[mod.name] = enc
self.ema = None
self.average_top_k_layers = cfg.average_top_k_layers
self.loss_beta = cfg.loss_beta
self.loss_scale = cfg.loss_scale
self.dropout_input = nn.Dropout(cfg.dropout_input)
dpr = np.linspace(cfg.start_drop_path_rate, cfg.end_drop_path_rate, cfg.depth)
self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.depth)])
self.norm = None
if cfg.layer_norm_first:
self.norm = make_layer_norm(cfg.embed_dim)
if self.cfg.mae_init:
self.apply(self._init_weights)
else:
from fairseq.modules.transformer_sentence_encoder import init_bert_params
self.apply(init_bert_params)
for mod_enc in self.modality_encoders.values():
mod_enc.reset_parameters()
if not skip_ema:
self.ema = self.make_ema_teacher(cfg.ema_decay)
self.shared_decoder = (
Decoder1d(cfg.shared_decoder, cfg.embed_dim)
if self.cfg.shared_decoder is not None
else None
)
if self.shared_decoder is not None:
self.shared_decoder.apply(self._init_weights)
self.recon_proj = None
if cfg.recon_loss > 0:
self.recon_proj = nn.Linear(cfg.embed_dim, cfg.embed_dim)
for pn, p in self.named_parameters():
if len(p.shape) == 1 or pn.endswith(".bias") or "alibi_scale" in pn:
p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}}
if cfg.decoder_group and "decoder" in pn:
p.param_group = "decoder"
self.num_updates = 0
def _init_weights(self, m):
try:
from apex.normalization import FusedLayerNorm
fn = FusedLayerNorm
        except ImportError:
fn = nn.LayerNorm
if isinstance(m, nn.Linear):
torch.nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm) or isinstance(m, fn):
if m.bias is not None:
nn.init.constant_(m.bias, 0)
if m.weight is not None:
nn.init.constant_(m.weight, 1.0)
@torch.no_grad()
def make_ema_teacher(self, ema_decay):
ema_config = EMAModuleConfig(
ema_decay=ema_decay,
ema_fp32=True,
log_norms=self.cfg.log_norms,
add_missing_params=False,
)
model_copy = self.make_target_model()
return EMAModule(
model_copy,
ema_config,
copy_model=False,
)
def make_target_model(self):
logger.info("making target model")
model_copy = Data2VecMultiModel(
self.cfg, self.modalities, skip_ema=True, task=self.task
)
if self.cfg.ema_encoder_only:
model_copy = model_copy.blocks
for p_s, p_t in zip(self.blocks.parameters(), model_copy.parameters()):
p_t.data.copy_(p_s.data)
else:
for p_s, p_t in zip(self.parameters(), model_copy.parameters()):
p_t.data.copy_(p_s.data)
for mod_enc in model_copy.modality_encoders.values():
mod_enc.decoder = None
if not mod_enc.modality_cfg.ema_local_encoder:
mod_enc.local_encoder = None
mod_enc.project_features = None
model_copy.requires_grad_(False)
return model_copy
def set_num_updates(self, num_updates):
super().set_num_updates(num_updates)
if self.ema is not None and (
(self.num_updates == 0 and num_updates > 1)
or self.num_updates >= num_updates
):
pass
elif self.training and self.ema is not None:
ema_weight_decay = None
if self.cfg.ema_decay != self.cfg.ema_end_decay:
if num_updates >= self.cfg.ema_anneal_end_step:
decay = self.cfg.ema_end_decay
else:
decay = get_annealed_rate(
self.cfg.ema_decay,
self.cfg.ema_end_decay,
num_updates,
self.cfg.ema_anneal_end_step,
)
self.ema.set_decay(decay, weight_decay=ema_weight_decay)
if self.ema.get_decay() < 1:
self.ema.step(self.blocks if self.cfg.ema_encoder_only else self)
self.num_updates = num_updates
def state_dict(self, destination=None, prefix="", keep_vars=False):
state = super().state_dict(destination, prefix, keep_vars)
if self.ema is not None:
state[prefix + "_ema"] = self.ema.fp32_params
return state
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
k = prefix + "_ema"
if self.ema is not None:
assert k in state_dict
self.ema.restore(state_dict[k], True)
del state_dict[k]
elif k in state_dict:
del state_dict[k]
return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
@classmethod
def build_model(cls, cfg: Data2VecMultiConfig, task=None):
"""Build a new model instance."""
if task is None or not hasattr(task, "supported_modalities"):
modalities = (
[cfg.supported_modality]
if cfg.supported_modality is not None
else [
Modality.AUDIO,
Modality.IMAGE,
Modality.TEXT,
]
)
else:
modalities = task.supported_modalities
return cls(cfg, modalities, task=task, skip_ema=cfg.skip_ema)
def forward(
self,
source,
target=None,
id=None,
mode=None,
padding_mask=None,
mask=True,
features_only=False,
force_remove_masked=False,
remove_extra_tokens=True,
precomputed_mask=None,
):
if mode is None:
assert self.cfg.supported_modality is not None
mode = self.cfg.supported_modality
if isinstance(mode, Modality):
mode = mode.name
feature_extractor = self.modality_encoders[mode]
mask_seeds = None
if id is not None:
mask_seeds = MaskSeed(seed=self.cfg.seed, update=self.num_updates, ids=id)
extractor_out = feature_extractor(
source,
padding_mask,
mask,
remove_masked=not features_only or force_remove_masked,
clone_batch=self.cfg.clone_batch if not features_only else 1,
mask_seeds=mask_seeds,
precomputed_mask=precomputed_mask,
)
x = extractor_out["x"]
encoder_mask = extractor_out["encoder_mask"]
masked_padding_mask = extractor_out["padding_mask"]
masked_alibi_bias = extractor_out.get("alibi_bias", None)
alibi_scale = extractor_out.get("alibi_scale", None)
if self.dropout_input is not None:
x = self.dropout_input(x)
layer_results = []
for i, blk in enumerate(self.blocks):
if (
not self.training
or self.cfg.layerdrop == 0
or (np.random.random() > self.cfg.layerdrop)
):
ab = masked_alibi_bias
if ab is not None and alibi_scale is not None:
scale = (
alibi_scale[i]
if alibi_scale.size(0) > 1
else alibi_scale.squeeze(0)
)
ab = ab * scale.type_as(ab)
x, lr = blk(
x,
padding_mask=masked_padding_mask,
alibi_bias=ab,
)
if features_only:
layer_results.append(lr)
if self.norm is not None:
x = self.norm(x)
if features_only:
if remove_extra_tokens:
x = x[:, feature_extractor.modality_cfg.num_extra_tokens :]
if masked_padding_mask is not None:
masked_padding_mask = masked_padding_mask[
:, feature_extractor.modality_cfg.num_extra_tokens :
]
return {
"x": x,
"padding_mask": masked_padding_mask,
"layer_results": layer_results,
"mask": encoder_mask,
}
xs = []
if self.shared_decoder is not None:
dx = self.forward_decoder(
x,
feature_extractor,
self.shared_decoder,
encoder_mask,
)
xs.append(dx)
if feature_extractor.decoder is not None:
dx = self.forward_decoder(
x,
feature_extractor,
feature_extractor.decoder,
encoder_mask,
)
xs.append(dx)
orig_x = x
assert len(xs) > 0
p = next(self.ema.model.parameters())
device = x.device
dtype = x.dtype
ema_device = p.device
ema_dtype = p.dtype
if not self.cfg.ema_same_dtype:
dtype = ema_dtype
if ema_device != device or ema_dtype != dtype:
logger.info(f"adjusting ema dtype to {dtype} and device to {device}")
self.ema.model = self.ema.model.to(dtype=dtype, device=device)
ema_dtype = dtype
def to_device(d):
for k, p in d.items():
if isinstance(d[k], dict):
to_device(d[k])
else:
d[k] = p.to(device=device)
to_device(self.ema.fp32_params)
tm = self.ema.model
with torch.no_grad():
tm.eval()
if self.cfg.ema_encoder_only:
assert target is None
ema_input = extractor_out["local_features"]
ema_input = feature_extractor.contextualized_features(
ema_input.to(dtype=ema_dtype),
padding_mask,
mask=False,
remove_masked=False,
)
ema_blocks = tm
else:
ema_blocks = tm.blocks
if feature_extractor.modality_cfg.ema_local_encoder:
inp = (
target.to(dtype=ema_dtype)
if target is not None
else source.to(dtype=ema_dtype)
)
ema_input = tm.modality_encoders[mode](
inp,
padding_mask,
mask=False,
remove_masked=False,
)
else:
assert target is None
ema_input = extractor_out["local_features"]
ema_feature_enc = tm.modality_encoders[mode]
ema_input = ema_feature_enc.contextualized_features(
ema_input.to(dtype=ema_dtype),
padding_mask,
mask=False,
remove_masked=False,
)
ema_padding_mask = ema_input["padding_mask"]
ema_alibi_bias = ema_input.get("alibi_bias", None)
ema_alibi_scale = ema_input.get("alibi_scale", None)
ema_input = ema_input["x"]
y = []
ema_x = []
extra_tokens = feature_extractor.modality_cfg.num_extra_tokens
for i, blk in enumerate(ema_blocks):
ab = ema_alibi_bias
if ab is not None and alibi_scale is not None:
scale = (
ema_alibi_scale[i]
if ema_alibi_scale.size(0) > 1
else ema_alibi_scale.squeeze(0)
)
ab = ab * scale.type_as(ab)
ema_input, lr = blk(
ema_input,
padding_mask=ema_padding_mask,
alibi_bias=ab,
)
y.append(lr[:, extra_tokens:])
ema_x.append(ema_input[:, extra_tokens:])
y = self.make_targets(y, self.average_top_k_layers)
orig_targets = y
if self.cfg.clone_batch > 1:
y = y.repeat_interleave(self.cfg.clone_batch, 0)
masked = encoder_mask.mask.unsqueeze(-1)
masked_b = encoder_mask.mask.bool()
y = y[masked_b]
if xs[0].size(1) == masked_b.size(1):
xs = [x[masked_b] for x in xs]
else:
xs = [x.reshape(-1, x.size(-1)) for x in xs]
sample_size = masked.sum().long()
result = {
"losses": {},
"sample_size": sample_size,
}
sample_size = result["sample_size"]
if self.cfg.cls_loss > 0:
assert extra_tokens > 0
cls_target = orig_targets.mean(dim=1)
if self.cfg.clone_batch > 1:
cls_target = cls_target.repeat_interleave(self.cfg.clone_batch, 0)
cls_pred = x[:, extra_tokens - 1]
result["losses"]["cls"] = self.d2v_loss(cls_pred, cls_target) * (
self.cfg.cls_loss * sample_size
)
if self.cfg.recon_loss > 0:
with torch.no_grad():
target = feature_extractor.patchify(source)
mean = target.mean(dim=-1, keepdim=True)
var = target.var(dim=-1, keepdim=True)
target = (target - mean) / (var + 1.0e-6) ** 0.5
if self.cfg.clone_batch > 1:
target = target.repeat_interleave(self.cfg.clone_batch, 0)
if masked_b is not None:
target = target[masked_b]
recon = xs[0]
if self.recon_proj is not None:
recon = self.recon_proj(recon)
result["losses"]["recon"] = (
self.d2v_loss(recon, target.float()) * self.cfg.recon_loss
)
if self.cfg.d2v_loss > 0:
for i, x in enumerate(xs):
reg_loss = self.d2v_loss(x, y)
n = f"{mode}_regression_{i}" if len(xs) > 1 else f"{mode}_regression"
result["losses"][n] = reg_loss * self.cfg.d2v_loss
suffix = "" if len(self.modalities) == 1 else f"_{mode}"
with torch.no_grad():
if encoder_mask is not None:
result["masked_pct"] = 1 - (
encoder_mask.ids_keep.size(1) / encoder_mask.ids_restore.size(1)
)
for i, x in enumerate(xs):
n = f"pred_var{suffix}_{i}" if len(xs) > 1 else f"pred_var{suffix}"
result[n] = self.compute_var(x.float())
if self.ema is not None:
for k, v in self.ema.logs.items():
result[k] = v
y = y.float()
result[f"target_var{suffix}"] = self.compute_var(y)
if self.num_updates > 5000:
if result[f"target_var{suffix}"] < self.cfg.min_target_var:
logger.error(
f"target var is {result[f'target_var{suffix}'].item()} < {self.cfg.min_target_var}, exiting ({mode})"
)
raise Exception(
f"target var is {result[f'target_var{suffix}'].item()} < {self.cfg.min_target_var}, exiting ({mode})"
)
for k in result.keys():
if k.startswith("pred_var") and result[k] < self.cfg.min_pred_var:
logger.error(
f"{k} is {result[k].item()} < {self.cfg.min_pred_var}, exiting ({mode})"
)
raise Exception(
f"{k} is {result[k].item()} < {self.cfg.min_pred_var}, exiting ({mode})"
)
result["ema_decay"] = self.ema.get_decay() * 1000
return result
def forward_decoder(
self,
x,
feature_extractor,
decoder,
mask_info,
):
x = feature_extractor.decoder_input(x, mask_info)
x = decoder(*x)
return x
def d2v_loss(self, x, y):
x = x.view(-1, x.size(-1)).float()
y = y.view(-1, x.size(-1))
if self.loss_beta == 0:
loss = F.mse_loss(x, y, reduction="none")
else:
loss = F.smooth_l1_loss(x, y, reduction="none", beta=self.loss_beta)
if self.loss_scale is not None:
scale = self.loss_scale
else:
scale = 1 / math.sqrt(x.size(-1))
reg_loss = loss * scale
return reg_loss
def make_targets(self, y, num_layers):
with torch.no_grad():
target_layer_results = y[-num_layers:]
permuted = False
if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
target_layer_results = [
tl.transpose(1, 2) for tl in target_layer_results # BTC -> BCT
]
permuted = True
if self.cfg.batch_norm_target_layer:
target_layer_results = [
F.batch_norm(
tl.float(), running_mean=None, running_var=None, training=True
)
for tl in target_layer_results
]
if self.cfg.instance_norm_target_layer:
target_layer_results = [
F.instance_norm(tl.float()) for tl in target_layer_results
]
if permuted:
target_layer_results = [
tl.transpose(1, 2) for tl in target_layer_results # BCT -> BTC
]
if self.cfg.layer_norm_target_layer:
target_layer_results = [
F.layer_norm(tl.float(), tl.shape[-1:])
for tl in target_layer_results
]
y = target_layer_results[0].float()
for tl in target_layer_results[1:]:
y.add_(tl.float())
y = y.div_(len(target_layer_results))
if self.cfg.layer_norm_targets:
y = F.layer_norm(y, y.shape[-1:])
if self.cfg.instance_norm_targets:
y = F.instance_norm(y.transpose(1, 2)).transpose(1, 2)
return y
@staticmethod
def compute_var(y):
y = y.view(-1, y.size(-1))
if dist.is_initialized():
zc = torch.tensor(y.size(0)).cuda()
zs = y.sum(dim=0)
zss = (y**2).sum(dim=0)
dist.all_reduce(zc)
dist.all_reduce(zs)
dist.all_reduce(zss)
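            # Identity used below (unbiased variance computed from the all-reduced sums):
            #   var = (sum(y^2) - sum(y)^2 / n) / (n - 1)
            #       = zss / (zc - 1) - zs^2 / (zc * (zc - 1))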
var = zss / (zc - 1) - (zs**2) / (zc * (zc - 1))
return torch.sqrt(var + 1e-6).mean()
else:
return torch.sqrt(y.var(dim=0) + 1e-6).mean()
def extract_features(
self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=True
):
res = self.forward(
source,
mode=mode,
padding_mask=padding_mask,
mask=mask,
features_only=True,
remove_extra_tokens=remove_extra_tokens,
)
return res
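    # Usage sketch (assumption: `model` is a trained Data2VecMultiModel and `images`
    # is a (B, C, H, W) batch for the image modality):
    #   out = model.extract_features(images, mode=Modality.IMAGE, mask=False)
    #   features = out["x"]  # (B, seq_len, embed_dim), extra tokens removed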
def remove_pretraining_modules(self, modality=None, keep_decoder=False):
self.ema = None
self.cfg.clone_batch = 1
self.recon_proj = None
if not keep_decoder:
self.shared_decoder = None
modality = modality.lower() if modality is not None else None
for k in list(self.modality_encoders.keys()):
if modality is not None and k.lower() != modality:
del self.modality_encoders[k]
else:
self.modality_encoders[k].remove_pretraining_modules(
keep_decoder=keep_decoder
)
if not keep_decoder:
self.modality_encoders[k].decoder = None
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import math
from dataclasses import dataclass, field
from typing import Optional
from omegaconf import II
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from fairseq.modules import EMAModule, EMAModuleConfig
from fairseq.data.data_utils import compute_mask_indices
from fairseq.models import BaseFairseqModel, register_model
from fairseq.models.wav2vec import (
ConvFeatureExtractionModel,
Wav2Vec2Config,
TransformerEncoder,
)
from fairseq.modules import (
GradMultiply,
LayerNorm,
)
from fairseq.utils import index_put
logger = logging.getLogger(__name__)
@dataclass
class Data2VecAudioConfig(Wav2Vec2Config):
loss_beta: float = field(
default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"}
)
loss_scale: Optional[float] = field(
default=None,
metadata={
"help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
},
)
average_top_k_layers: int = field(
default=8, metadata={"help": "how many layers to average"}
)
layer_norm_target_layer: bool = False
instance_norm_target_layer: bool = False
instance_norm_targets: bool = False
layer_norm_targets: bool = False
batch_norm_target_layer: bool = False
group_norm_target_layer: bool = False
ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"})
ema_end_decay: float = field(
default=0.9999, metadata={"help": "final ema decay rate"}
)
# when to finish annealing ema decay rate
ema_anneal_end_step: int = II("optimization.max_update")
ema_transformer_only: bool = field(
default=True,
metadata={"help": "whether to momentum update only the transformer"},
)
ema_layers_only: bool = field(
default=True,
metadata={"help": "whether to momentum update only the transformer layers"},
)
max_update: int = II("optimization.max_update")
min_target_var: float = field(
default=0.1, metadata={"help": "stop training if target var falls below this"}
)
min_pred_var: float = field(
default=0.01,
metadata={"help": "stop training if prediction var falls below this"},
)
def get_annealed_rate(start, end, curr_step, total_steps):
r = end - start
pct_remaining = 1 - curr_step / total_steps
return end - r * pct_remaining
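# Linear anneal from `start` to `end` over `total_steps`; worked example:
# get_annealed_rate(0.999, 0.9999, 50_000, 100_000) -> 0.99945 (halfway through the anneal).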
@register_model("data2vec_audio", dataclass=Data2VecAudioConfig)
class Data2VecAudioModel(BaseFairseqModel):
def __init__(self, cfg: Data2VecAudioConfig):
super().__init__()
self.cfg = cfg
feature_enc_layers = eval(cfg.conv_feature_layers)
self.extractor_embed = feature_enc_layers[-1][0]
self.ema = None
self.embed = cfg.encoder_embed_dim
self.average_top_k_layers = cfg.average_top_k_layers
self.loss_beta = cfg.loss_beta
self.loss_scale = cfg.loss_scale
self.feature_extractor = ConvFeatureExtractionModel(
conv_layers=feature_enc_layers,
dropout=0.0,
mode=cfg.extractor_mode,
conv_bias=cfg.conv_bias,
)
self.post_extract_proj = nn.Linear(self.extractor_embed, cfg.encoder_embed_dim)
self.mask_prob = cfg.mask_prob
self.mask_selection = cfg.mask_selection
self.mask_other = cfg.mask_other
self.mask_length = cfg.mask_length
self.no_mask_overlap = cfg.no_mask_overlap
self.mask_min_space = cfg.mask_min_space
self.mask_channel_prob = cfg.mask_channel_prob
self.mask_channel_before = cfg.mask_channel_before
self.mask_channel_selection = cfg.mask_channel_selection
self.mask_channel_other = cfg.mask_channel_other
self.mask_channel_length = cfg.mask_channel_length
self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
self.mask_channel_min_space = cfg.mask_channel_min_space
self.dropout_input = nn.Dropout(cfg.dropout_input)
self.dropout_features = nn.Dropout(cfg.dropout_features)
self.feature_grad_mult = cfg.feature_grad_mult
self.mask_emb = nn.Parameter(
torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
)
self.encoder = TransformerEncoder(cfg)
self.layer_norm = LayerNorm(self.extractor_embed)
self.final_proj = nn.Linear(self.embed, self.embed)
self.num_updates = 0
def make_ema_teacher(self):
ema_config = EMAModuleConfig(
ema_decay=self.cfg.ema_decay,
ema_fp32=True,
)
skip_keys = set()
if self.cfg.ema_layers_only:
self.cfg.ema_transformer_only = True
for k, _ in self.encoder.pos_conv.named_parameters():
skip_keys.add(f"pos_conv.{k}")
self.ema = EMAModule(
self.encoder if self.cfg.ema_transformer_only else self,
ema_config,
skip_keys=skip_keys,
)
def set_num_updates(self, num_updates):
super().set_num_updates(num_updates)
if self.ema is None and self.final_proj is not None:
logger.info(f"making ema teacher")
self.make_ema_teacher()
elif self.training and self.ema is not None:
if self.cfg.ema_decay != self.cfg.ema_end_decay:
if num_updates >= self.cfg.ema_anneal_end_step:
decay = self.cfg.ema_end_decay
else:
decay = get_annealed_rate(
self.cfg.ema_decay,
self.cfg.ema_end_decay,
num_updates,
self.cfg.ema_anneal_end_step,
)
self.ema.set_decay(decay)
if self.ema.get_decay() < 1:
self.ema.step(self.encoder if self.cfg.ema_transformer_only else self)
self.num_updates = num_updates
def state_dict(self, destination=None, prefix="", keep_vars=False):
state = super().state_dict(destination, prefix, keep_vars)
if self.ema is not None:
state[prefix + "_ema"] = self.ema.fp32_params
return state
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
if self.ema is not None:
k = prefix + "_ema"
assert k in state_dict
self.ema.restore(state_dict[k], True)
del state_dict[k]
return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
@classmethod
def build_model(cls, cfg: Data2VecAudioConfig, task=None):
"""Build a new model instance."""
return cls(cfg)
def apply_mask(
self,
x,
padding_mask,
mask_indices=None,
mask_channel_indices=None,
):
B, T, C = x.shape
if self.mask_channel_prob > 0 and self.mask_channel_before:
mask_channel_indices = compute_mask_indices(
(B, C),
None,
self.mask_channel_prob,
self.mask_channel_length,
self.mask_channel_selection,
self.mask_channel_other,
no_overlap=self.no_mask_channel_overlap,
min_space=self.mask_channel_min_space,
)
mask_channel_indices = (
torch.from_numpy(mask_channel_indices)
.to(x.device)
.unsqueeze(1)
.expand(-1, T, -1)
)
x[mask_channel_indices] = 0
if self.mask_prob > 0:
if mask_indices is None:
mask_indices = compute_mask_indices(
(B, T),
padding_mask,
self.mask_prob,
self.mask_length,
self.mask_selection,
self.mask_other,
min_masks=1,
no_overlap=self.no_mask_overlap,
min_space=self.mask_min_space,
require_same_masks=self.cfg.require_same_masks,
mask_dropout=self.cfg.mask_dropout,
)
mask_indices = torch.from_numpy(mask_indices).to(x.device)
x = index_put(x, mask_indices, self.mask_emb)
else:
mask_indices = None
if self.mask_channel_prob > 0 and not self.mask_channel_before:
if mask_channel_indices is None:
mask_channel_indices = compute_mask_indices(
(B, C),
None,
self.mask_channel_prob,
self.mask_channel_length,
self.mask_channel_selection,
self.mask_channel_other,
no_overlap=self.no_mask_channel_overlap,
min_space=self.mask_channel_min_space,
)
mask_channel_indices = (
torch.from_numpy(mask_channel_indices)
.to(x.device)
.unsqueeze(1)
.expand(-1, T, -1)
)
x = index_put(x, mask_channel_indices, 0)
return x, mask_indices
def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
"""
Computes the output length of the convolutional layers
"""
def _conv_out_length(input_length, kernel_size, stride):
return torch.floor((input_length - kernel_size) / stride + 1)
conv_cfg_list = eval(self.cfg.conv_feature_layers)
for i in range(len(conv_cfg_list)):
input_lengths = _conv_out_length(
input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2]
)
return input_lengths.to(torch.long)
def forward(
self,
source,
padding_mask=None,
mask=True,
features_only=False,
layer=None,
mask_indices=None,
mask_channel_indices=None,
padding_count=None,
):
features = source
if self.feature_grad_mult > 0:
features = self.feature_extractor(features)
if self.feature_grad_mult != 1.0:
features = GradMultiply.apply(features, self.feature_grad_mult)
else:
with torch.no_grad():
features = self.feature_extractor(features)
features = features.transpose(1, 2)
features = self.layer_norm(features)
orig_padding_mask = padding_mask
if padding_mask is not None and padding_mask.any():
input_lengths = (1 - padding_mask.long()).sum(-1)
# apply conv formula to get real output_lengths
output_lengths = self._get_feat_extract_output_lengths(input_lengths)
padding_mask = torch.zeros(
features.shape[:2], dtype=features.dtype, device=features.device
)
            # these two operations make sure that all values
# before the output lengths indices are attended to
padding_mask[
(
torch.arange(padding_mask.shape[0], device=padding_mask.device),
output_lengths - 1,
)
] = 1
padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool()
else:
padding_mask = None
if self.post_extract_proj is not None:
features = self.post_extract_proj(features)
pre_encoder_features = None
if self.cfg.ema_transformer_only:
pre_encoder_features = features.clone()
features = self.dropout_input(features)
if mask:
x, mask_indices = self.apply_mask(
features,
padding_mask,
mask_indices=mask_indices,
mask_channel_indices=mask_channel_indices,
)
else:
x = features
mask_indices = None
x, layer_results = self.encoder(
x,
padding_mask=padding_mask,
layer=layer,
)
if features_only:
return {
"x": x,
"padding_mask": padding_mask,
"layer_results": layer_results,
}
result = {
"losses": {},
}
with torch.no_grad():
self.ema.model.eval()
if self.cfg.ema_transformer_only:
y, layer_results = self.ema.model.extract_features(
pre_encoder_features,
padding_mask=padding_mask,
min_layer=self.cfg.encoder_layers - self.average_top_k_layers,
)
y = {
"x": y,
"padding_mask": padding_mask,
"layer_results": layer_results,
}
else:
y = self.ema.model.extract_features(
source=source,
padding_mask=orig_padding_mask,
mask=False,
)
target_layer_results = [l[2] for l in y["layer_results"]]
permuted = False
if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
target_layer_results = [
tl.permute(1, 2, 0) for tl in target_layer_results # TBC -> BCT
]
permuted = True
if self.cfg.batch_norm_target_layer:
target_layer_results = [
F.batch_norm(
tl.float(), running_mean=None, running_var=None, training=True
)
for tl in target_layer_results
]
if self.cfg.instance_norm_target_layer:
target_layer_results = [
F.instance_norm(tl.float()) for tl in target_layer_results
]
if permuted:
target_layer_results = [
tl.transpose(1, 2) for tl in target_layer_results # BCT -> BTC
]
if self.cfg.group_norm_target_layer:
target_layer_results = [
F.layer_norm(tl.float(), tl.shape[-2:])
for tl in target_layer_results
]
if self.cfg.layer_norm_target_layer:
target_layer_results = [
F.layer_norm(tl.float(), tl.shape[-1:])
for tl in target_layer_results
]
y = sum(target_layer_results) / len(target_layer_results)
if self.cfg.layer_norm_targets:
y = F.layer_norm(y.float(), y.shape[-1:])
if self.cfg.instance_norm_targets:
y = F.instance_norm(y.float().transpose(1, 2)).transpose(1, 2)
if not permuted:
y = y.transpose(0, 1)
y = y[mask_indices]
x = x[mask_indices]
x = self.final_proj(x)
sz = x.size(-1)
if self.loss_beta == 0:
loss = F.mse_loss(x.float(), y.float(), reduction="none").sum(dim=-1)
else:
loss = F.smooth_l1_loss(
x.float(), y.float(), reduction="none", beta=self.loss_beta
).sum(dim=-1)
if self.loss_scale is not None:
scale = self.loss_scale
else:
scale = 1 / math.sqrt(sz)
result["losses"]["regression"] = loss.sum() * scale
if "sample_size" not in result:
result["sample_size"] = loss.numel()
with torch.no_grad():
result["target_var"] = self.compute_var(y)
result["pred_var"] = self.compute_var(x.float())
if self.num_updates > 5000 and result["target_var"] < self.cfg.min_target_var:
logger.error(
f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting"
)
raise Exception(
f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting"
)
if self.num_updates > 5000 and result["pred_var"] < self.cfg.min_pred_var:
logger.error(
f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting"
)
raise Exception(
f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting"
)
if self.ema is not None:
result["ema_decay"] = self.ema.get_decay() * 1000
return result
@staticmethod
def compute_var(y):
y = y.view(-1, y.size(-1))
if dist.is_initialized():
zc = torch.tensor(y.size(0)).cuda()
zs = y.sum(dim=0)
zss = (y ** 2).sum(dim=0)
dist.all_reduce(zc)
dist.all_reduce(zs)
dist.all_reduce(zss)
var = zss / (zc - 1) - (zs ** 2) / (zc * (zc - 1))
return torch.sqrt(var + 1e-6).mean()
else:
return torch.sqrt(y.var(dim=0) + 1e-6).mean()
def extract_features(
self, source, padding_mask, mask=False, layer=None
):
res = self.forward(
source,
padding_mask,
mask=mask,
features_only=True,
layer=layer,
)
return res
def remove_pretraining_modules(self, last_layer=None):
self.final_proj = None
self.ema = None
if last_layer is not None:
self.encoder.layers = nn.ModuleList(
l for i, l in enumerate(self.encoder.layers) if i <= last_layer
)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# The code in this file is adapted from the BeiT implementation which can be found here:
# https://github.com/microsoft/unilm/tree/master/beit
import logging
from dataclasses import dataclass
from typing import Any
from omegaconf import II, MISSING
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import checkpoint_utils, tasks
from fairseq.dataclass import FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model
logger = logging.getLogger(__name__)
@dataclass
class Data2VecImageClassificationConfig(FairseqDataclass):
model_path: str = MISSING
no_pretrained_weights: bool = False
num_classes: int = 1000
mixup: float = 0.8
cutmix: float = 1.0
label_smoothing: float = 0.1
pretrained_model_args: Any = None
data: str = II("task.data")
@register_model(
"data2vec_image_classification", dataclass=Data2VecImageClassificationConfig
)
class Data2VecImageClassificationModel(BaseFairseqModel):
def __init__(self, cfg: Data2VecImageClassificationConfig):
super().__init__()
self.cfg = cfg
if cfg.pretrained_model_args is None:
state = checkpoint_utils.load_checkpoint_to_cpu(cfg.model_path, {})
pretrained_args = state.get("cfg", None)
pretrained_args.criterion = None
pretrained_args.lr_scheduler = None
cfg.pretrained_model_args = pretrained_args
logger.info(pretrained_args)
else:
state = None
pretrained_args = cfg.pretrained_model_args
pretrained_args.task.data = cfg.data
task = tasks.setup_task(pretrained_args.task)
model = task.build_model(pretrained_args.model, from_checkpoint=True)
model.remove_pretraining_modules()
self.model = model
if state is not None and not cfg.no_pretrained_weights:
self.load_model_weights(state, model, cfg)
self.fc_norm = nn.LayerNorm(pretrained_args.model.embed_dim)
self.head = nn.Linear(pretrained_args.model.embed_dim, cfg.num_classes)
self.head.weight.data.mul_(1e-3)
self.head.bias.data.mul_(1e-3)
self.mixup_fn = None
if cfg.mixup > 0 or cfg.cutmix > 0:
from timm.data import Mixup
self.mixup_fn = Mixup(
mixup_alpha=cfg.mixup,
cutmix_alpha=cfg.cutmix,
cutmix_minmax=None,
prob=1.0,
switch_prob=0.5,
mode="batch",
label_smoothing=cfg.label_smoothing,
num_classes=cfg.num_classes,
)
def load_model_weights(self, state, model, cfg):
if "_ema" in state["model"]:
del state["model"]["_ema"]
model.load_state_dict(state["model"], strict=True)
@classmethod
def build_model(cls, cfg: Data2VecImageClassificationConfig, task=None):
"""Build a new model instance."""
return cls(cfg)
def forward(
self,
img,
label=None,
):
if self.training and self.mixup_fn is not None and label is not None:
img, label = self.mixup_fn(img, label)
x = self.model(img, mask=False)
x = x[:, 1:]
x = self.fc_norm(x.mean(1))
x = self.head(x)
if label is None:
return x
if self.training and self.mixup_fn is not None:
loss = -label * F.log_softmax(x.float(), dim=-1)
else:
loss = F.cross_entropy(
x.float(),
label,
label_smoothing=self.cfg.label_smoothing if self.training else 0,
reduction="none",
)
result = {
"losses": {"regression": loss},
"sample_size": img.size(0),
}
if not self.training:
with torch.no_grad():
pred = x.argmax(-1)
correct = (pred == label).sum()
result["correct"] = correct
return result
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
from typing import Optional
import logging
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import II
from fairseq.dataclass import FairseqDataclass
from fairseq.modules import EMAModule, EMAModuleConfig
from fairseq.models import (
FairseqEncoder,
FairseqEncoderModel,
register_model,
)
from fairseq.models.roberta.model import RobertaLMHead, RobertaClassificationHead
from fairseq.models.transformer import TransformerEncoder, TransformerConfig
from fairseq.modules.transformer_sentence_encoder import init_bert_params
logger = logging.getLogger(__name__)
@dataclass
class Data2VecTextConfig(FairseqDataclass):
max_positions: int = II("task.tokens_per_sample")
head_layers: int = 1
transformer: TransformerConfig = TransformerConfig()
load_checkpoint_heads: bool = field(
default=False,
metadata={"help": "(re-)register and load heads when loading checkpoints"},
)
loss_beta: float = field(
default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"}
)
loss_scale: Optional[float] = field(
default=None,
metadata={
"help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
},
)
average_top_k_layers: int = field(
default=8, metadata={"help": "how many layers to average"}
)
layer_norm_target_layer: bool = False
instance_norm_target_layer: bool = False
batch_norm_target_layer: bool = False
instance_norm_targets: bool = False
layer_norm_targets: bool = False
ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"})
ema_end_decay: float = field(
default=0.9999, metadata={"help": "final ema decay rate"}
)
# when to finish annealing ema decay rate
ema_anneal_end_step: int = II("optimization.max_update")
ema_transformer_layers_only: bool = field(
default=True,
metadata={"help": "whether to momentum update only the transformer layers"},
)
def get_annealed_rate(start, end, curr_step, total_steps):
r = end - start
pct_remaining = 1 - curr_step / total_steps
return end - r * pct_remaining
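# Illustrative note (not from the original source): get_annealed_rate linearly
# interpolates between `start` and `end`. For example, with start=0.999,
# end=0.9999 and total_steps=100000, curr_step=50000 gives
# 0.9999 - 0.0009 * 0.5 = 0.99945, the midpoint of the two decay rates.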
@register_model("data2vec_text", dataclass=Data2VecTextConfig)
class Data2VecTextModel(FairseqEncoderModel):
def __init__(self, cfg: Data2VecTextConfig, encoder):
super().__init__(encoder)
self.cfg = cfg
# We follow BERT's random weight initialization
self.apply(init_bert_params)
self.classification_heads = nn.ModuleDict()
@classmethod
def build_model(cls, cfg, task):
"""Build a new model instance."""
encoder = Data2VecTextEncoder(cfg, task.source_dictionary, task.cfg.data)
return cls(cfg, encoder)
def forward(
self,
src_tokens,
target_tokens=None,
features_only=False,
return_all_hiddens=False,
classification_head_name=None,
**kwargs,
):
if classification_head_name is not None:
features_only = True
res = self.encoder(
src_tokens, target_tokens, features_only, return_all_hiddens, **kwargs
)
if isinstance(res, tuple):
x, extra = res
else:
return res
if classification_head_name is not None:
x = self.classification_heads[classification_head_name](x)
return x, extra
def get_normalized_probs(self, net_output, log_probs, sample=None):
"""Get normalized probabilities (or log probs) from a net's output."""
logits = net_output[0].float()
if log_probs:
return F.log_softmax(logits, dim=-1)
else:
return F.softmax(logits, dim=-1)
def register_classification_head(
self, name, num_classes=None, inner_dim=None, **kwargs
):
"""Register a classification head."""
if name in self.classification_heads:
prev_num_classes = self.classification_heads[name].out_proj.out_features
prev_inner_dim = self.classification_heads[name].dense.out_features
if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
logger.warning(
're-registering head "{}" with num_classes {} (prev: {}) '
"and inner_dim {} (prev: {})".format(
name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
)
)
self.classification_heads[name] = RobertaClassificationHead(
input_dim=self.cfg.transformer.encoder.embed_dim,
inner_dim=inner_dim or self.cfg.transformer.encoder.embed_dim,
num_classes=num_classes,
activation_fn="tanh",
pooler_dropout=0,
)
@property
def supported_targets(self):
return {"self"}
def upgrade_state_dict_named(self, state_dict, name):
prefix = name + "." if name != "" else ""
# rename decoder -> encoder before upgrading children modules
for k in list(state_dict.keys()):
if k.startswith(prefix + "decoder"):
new_k = prefix + "encoder" + k[len(prefix + "decoder") :]
state_dict[new_k] = state_dict[k]
del state_dict[k]
# rename emb_layer_norm -> layernorm_embedding
for k in list(state_dict.keys()):
if ".emb_layer_norm." in k:
new_k = k.replace(".emb_layer_norm.", ".layernorm_embedding.")
state_dict[new_k] = state_dict[k]
del state_dict[k]
if self.encoder.regression_head is not None:
if ".lm_head." in k:
new_k = k.replace(".lm_head.", ".regression_head.")
state_dict[new_k] = state_dict[k]
del state_dict[k]
else:
if ".regression_head." in k:
del state_dict[k]
# upgrade children modules
super().upgrade_state_dict_named(state_dict, name)
# Handle new classification heads present in the state dict.
current_head_names = (
[]
if not hasattr(self, "classification_heads")
or self.classification_heads is None
else self.classification_heads.keys()
)
keys_to_delete = []
for k in state_dict.keys():
if not k.startswith(prefix + "classification_heads."):
continue
head_name = k[len(prefix + "classification_heads.") :].split(".")[0]
num_classes = state_dict[
prefix + "classification_heads." + head_name + ".out_proj.weight"
].size(0)
inner_dim = state_dict[
prefix + "classification_heads." + head_name + ".dense.weight"
].size(0)
if self.cfg.load_checkpoint_heads:
if head_name not in current_head_names:
self.register_classification_head(head_name, num_classes, inner_dim)
else:
if head_name not in current_head_names:
logger.warning(
"deleting classification head ({}) from checkpoint "
"not present in current model: {}".format(head_name, k)
)
keys_to_delete.append(k)
elif (
num_classes
!= self.classification_heads[head_name].out_proj.out_features
or inner_dim
!= self.classification_heads[head_name].dense.out_features
):
logger.warning(
"deleting classification head ({}) from checkpoint "
"with different dimensions than current model: {}".format(
head_name, k
)
)
keys_to_delete.append(k)
for k in keys_to_delete:
del state_dict[k]
# Copy any newly-added classification heads into the state dict
# with their current weights.
if (
hasattr(self, "classification_heads")
and self.classification_heads is not None
and len(self.classification_heads) > 0
):
cur_state = self.classification_heads.state_dict()
for k, v in cur_state.items():
if prefix + "classification_heads." + k not in state_dict:
logger.info("Overwriting " + prefix + "classification_heads." + k)
state_dict[prefix + "classification_heads." + k] = v
for k in list(state_dict.keys()):
if k.startswith(prefix + "encoder.lm_head.") or k.startswith(
prefix + "encoder.emb_head."
):
del state_dict[k]
self.encoder.lm_head = None
if self.encoder.target_model is None:
for k in list(state_dict.keys()):
if k.startswith(prefix + "encoder.target_model."):
del state_dict[k]
if (self.encoder.ema is None) and (prefix + "encoder._ema" in state_dict):
del state_dict[prefix + "encoder._ema"]
def remove_pretraining_modules(self, last_layer=None):
self.encoder.lm_head = None
self.encoder.regression_head = None
self.encoder.ema = None
self.classification_heads = None
if last_layer is not None:
self.encoder.sentence_encoder.layers = nn.ModuleList(
l
for i, l in enumerate(self.encoder.sentence_encoder.layers)
if i <= last_layer
)
self.encoder.sentence_encoder.layer_norm = None
class Data2VecTextEncoder(FairseqEncoder):
def __init__(self, cfg: Data2VecTextConfig, dictionary, task_data):
super().__init__(dictionary)
self.cfg = cfg
embed_tokens = self.build_embedding(
len(dictionary), cfg.transformer.encoder.embed_dim, dictionary.pad()
)
self.sentence_encoder = self.build_encoder(cfg, dictionary, embed_tokens)
self.mask_idx = dictionary.index("<mask>")
assert self.mask_idx != dictionary.unk(), dictionary.symbols
self.ema = None
self.average_top_k_layers = cfg.average_top_k_layers
self.loss_scale = cfg.loss_scale
assert self.cfg.head_layers >= 1
embed_dim = cfg.transformer.encoder.embed_dim
curr_dim = embed_dim
projs = []
for i in range(self.cfg.head_layers - 1):
next_dim = embed_dim * 2 if i == 0 else curr_dim
projs.append(nn.Linear(curr_dim, next_dim))
projs.append(nn.GELU())
curr_dim = next_dim
projs.append(nn.Linear(curr_dim, embed_dim))
self.regression_head = nn.Sequential(*projs)
self.num_updates = 0
def build_embedding(self, vocab_size, embedding_dim, padding_idx):
return nn.Embedding(vocab_size, embedding_dim, padding_idx)
def build_encoder(self, cfg, dictionary, embed_tokens):
encoder = TransformerEncoder(cfg.transformer, dictionary, embed_tokens, return_fc=True)
encoder.apply(init_bert_params)
return encoder
def build_lm_head(self, embed_dim, output_dim, activation_fn, weight):
return RobertaLMHead(embed_dim, output_dim, activation_fn, weight)
def make_ema_teacher(self):
ema_config = EMAModuleConfig(
ema_decay=self.cfg.ema_decay,
ema_fp32=True,
)
skip_keys = set()
if self.cfg.ema_transformer_layers_only:
for k, _ in self.sentence_encoder.embed_positions.named_parameters():
skip_keys.add(f"embed_tokens.{k}")
for k, _ in self.sentence_encoder.embed_positions.named_parameters():
skip_keys.add(f"embed_positions.{k}")
if self.sentence_encoder.layernorm_embedding is not None:
for (
k,
_,
) in self.sentence_encoder.layernorm_embedding.named_parameters():
skip_keys.add(f"layernorm_embedding.{k}")
if self.sentence_encoder.layer_norm is not None:
for k, _ in self.sentence_encoder.layer_norm.named_parameters():
skip_keys.add(f"layernorm_embedding.{k}")
self.ema = EMAModule(
self.sentence_encoder,
ema_config,
skip_keys=skip_keys,
)
def set_num_updates(self, num_updates):
super().set_num_updates(num_updates)
if self.ema is None and self.regression_head is not None:
logger.info(f"making ema teacher")
self.make_ema_teacher()
elif self.training and self.ema is not None:
if self.cfg.ema_decay != self.cfg.ema_end_decay:
if num_updates >= self.cfg.ema_anneal_end_step:
decay = self.cfg.ema_end_decay
else:
decay = get_annealed_rate(
self.cfg.ema_decay,
self.cfg.ema_end_decay,
num_updates,
self.cfg.ema_anneal_end_step,
)
self.ema.set_decay(decay)
if self.ema.get_decay() < 1:
self.ema.step(self.sentence_encoder)
def state_dict(self, destination=None, prefix="", keep_vars=False):
state = super().state_dict(destination, prefix, keep_vars)
if self.ema is not None:
state[prefix + "_ema"] = self.ema.fp32_params
return state
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
if self.ema is not None:
k = prefix + "_ema"
assert k in state_dict
self.ema.restore(state_dict[k], True)
del state_dict[k]
return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
def forward(
self,
src_tokens,
target_tokens=None,
features_only=False,
return_all_hiddens=False,
masked_tokens=None,
**unused,
):
"""
Args:
src_tokens (LongTensor): input tokens of shape `(batch, src_len)`
features_only (bool, optional): skip LM head and just return
features. If True, the output will be of shape
`(batch, src_len, embed_dim)`.
return_all_hiddens (bool, optional): also return all of the
intermediate hidden states (default: False).
Returns:
tuple:
- the LM output of shape `(batch, src_len, vocab)`
- a dictionary of additional data, where 'inner_states'
is a list of hidden states. Note that the hidden
states have shape `(src_len, batch, vocab)`.
"""
x, extra = self.extract_features(
src_tokens, return_all_hiddens=return_all_hiddens
)
if features_only:
return x, extra
assert target_tokens is not None
with torch.no_grad():
# use EMA parameter as the teacher
self.ema.model.eval()
encoder_out = self.ema.model(
target_tokens,
return_all_hiddens=True,
)
y = encoder_out["fc_results"]
y = y[-self.average_top_k_layers :]
permuted = False
if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
y = [tl.permute(1, 2, 0) for tl in y] # TBC -> BCT
permuted = True
if self.cfg.batch_norm_target_layer:
y = [
F.batch_norm(
tl.float(), running_mean=None, running_var=None, training=True
)
for tl in y
]
if self.cfg.instance_norm_target_layer:
y = [F.instance_norm(tl.float()) for tl in y]
if permuted:
y = [tl.transpose(1, 2) for tl in y] # BCT -> BTC
if self.cfg.layer_norm_target_layer:
y = [F.layer_norm(tl.float(), tl.shape[-1:]) for tl in y]
y = sum(y) / len(y)
if not permuted:
y = y.transpose(0, 1)
if self.cfg.layer_norm_targets:
y = F.layer_norm(y.float(), y.shape[-1:])
if self.cfg.instance_norm_targets:
y = F.instance_norm(y.transpose(1, 2)).transpose(1, 2)
masked_indices = src_tokens.eq(self.mask_idx)
x = x[masked_indices]
y = y[masked_indices]
x = self.regression_head(x)
sz = x.size(-1)
if self.cfg.loss_beta == 0:
loss = F.mse_loss(x.float(), y.float(), reduction="none").sum(dim=-1)
else:
loss = F.smooth_l1_loss(
x.float(), y.float(), reduction="none", beta=self.cfg.loss_beta
).sum(dim=-1)
result = {
"losses": {
"main": loss.sum() / math.sqrt(sz)
if self.loss_scale <= 0
else loss.sum() * self.loss_scale,
},
"sample_size": loss.numel(),
}
# logging other values
other_logs = {
"ema_decay": self.ema.get_decay() * 1000
}
result["logs"] = other_logs
return result
def extract_features(self, src_tokens, return_all_hiddens=False, **kwargs):
encoder_out = self.sentence_encoder(
src_tokens,
return_all_hiddens=return_all_hiddens,
token_embeddings=kwargs.get("token_embeddings", None),
)
# T x B x C -> B x T x C
features = encoder_out["encoder_out"][0].transpose(0, 1)
inner_states = encoder_out["encoder_states"] if return_all_hiddens else None
return features, {
"inner_states": inner_states,
"encoder_embedding": encoder_out["encoder_embedding"][0],
}
def output_layer(self, features, masked_tokens=None, **unused):
return self.lm_head(features, masked_tokens)
def max_positions(self):
"""Maximum output length supported by the encoder."""
return self.cfg.max_positions
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# The code in this file is adapted from the BeiT implementation which can be found here:
# https://github.com/microsoft/unilm/tree/master/beit
import logging
from dataclasses import dataclass
from typing import Any
from omegaconf import II, MISSING
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import checkpoint_utils, tasks
from fairseq.dataclass import FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model
from fairseq.models.roberta.model import RobertaClassificationHead
from examples.data2vec.data.modality import Modality
logger = logging.getLogger(__name__)
@dataclass
class Data2VecTextClassificationConfig(FairseqDataclass):
pooler_dropout: float = 0.0
pooler_activation_fn: str = "tanh"
quant_noise_pq: int = 0
quant_noise_pq_block_size: int = 8
spectral_norm_classification_head: bool = False
model_path: str = MISSING
no_pretrained_weights: bool = False
pretrained_model_args: Any = None
@register_model(
"data2vec_text_classification", dataclass=Data2VecTextClassificationConfig
)
class Data2VecTextClassificationModel(BaseFairseqModel):
def __init__(self, cfg: Data2VecTextClassificationConfig):
super().__init__()
self.cfg = cfg
if cfg.pretrained_model_args is None:
state = checkpoint_utils.load_checkpoint_to_cpu(cfg.model_path, {})
pretrained_args = state.get("cfg", None)
pretrained_args.criterion = None
pretrained_args.lr_scheduler = None
cfg.pretrained_model_args = pretrained_args
logger.info(pretrained_args)
else:
state = None
pretrained_args = cfg.pretrained_model_args
task = tasks.setup_task(pretrained_args.task)
model = task.build_model(pretrained_args.model, from_checkpoint=True)
model.remove_pretraining_modules()
self.model = model
if state is not None and not cfg.no_pretrained_weights:
self.load_model_weights(state, model, cfg)
self.classification_heads = nn.ModuleDict()
def load_model_weights(self, state, model, cfg):
for k in list(state["model"].keys()):
if (
k.startswith("shared_decoder") or
k.startswith("_ema") or
"decoder" in k
):
logger.info(f"Deleting {k} from checkpoint")
del state["model"][k]
model.load_state_dict(state["model"], strict=True)
@classmethod
def build_model(cls, cfg: Data2VecTextClassificationConfig, task=None):
"""Build a new model instance."""
return cls(cfg)
def register_classification_head(
self, name, num_classes=None, inner_dim=None, **kwargs
):
"""Register a classification head."""
if name in self.classification_heads:
prev_num_classes = self.classification_heads[name].out_proj.out_features
prev_inner_dim = self.classification_heads[name].dense.out_features
if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
logger.warning(
're-registering head "{}" with num_classes {} (prev: {}) '
"and inner_dim {} (prev: {})".format(
name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
)
)
embed_dim = self.cfg.pretrained_model_args.model.embed_dim
self.classification_heads[name] = RobertaClassificationHead(
input_dim=embed_dim,
inner_dim=inner_dim or embed_dim,
num_classes=num_classes,
activation_fn=self.cfg.pooler_activation_fn,
pooler_dropout=self.cfg.pooler_dropout,
q_noise=self.cfg.quant_noise_pq,
qn_block_size=self.cfg.quant_noise_pq_block_size,
do_spectral_norm=self.cfg.spectral_norm_classification_head,
)
def forward(
self,
source,
id,
padding_mask,
features_only=True,
remove_extra_tokens=True,
classification_head_name=None,
):
encoder_out = self.model(
source,
id=id,
mode=Modality.TEXT,
padding_mask=padding_mask,
mask=False,
features_only=features_only,
remove_extra_tokens=remove_extra_tokens
)
logits = self.classification_heads[classification_head_name](encoder_out["x"])
return logits, encoder_out
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# The code in this file is adapted from the BeiT implementation which can be found here:
# https://github.com/microsoft/unilm/tree/master/beit
import logging
import math
import numpy as np
import random
from dataclasses import dataclass, field
from typing import Optional
from omegaconf import II
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from fairseq.modules import EMAModule, EMAModuleConfig
from fairseq.dataclass import FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model
logger = logging.getLogger(__name__)
@dataclass
class Data2VecVisionConfig(FairseqDataclass):
layer_scale_init_value: float = field(
default=1e-4, metadata={"help": "rescale layer outputs, 0 to disable"}
)
num_mask_patches: int = field(
default=75,
metadata={"help": "number of the visual tokens/patches need be masked"},
)
min_mask_patches_per_block: int = 16
max_mask_patches_per_block: int = 196
image_size: int = 224
patch_size: int = 16
in_channels: int = 3
shared_rel_pos_bias: bool = True
drop_path: float = 0.1
attention_dropout: float = 0.0
depth: int = 12
embed_dim: int = 768
num_heads: int = 12
mlp_ratio: int = 4
loss_beta: float = field(
default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"}
)
loss_scale: Optional[float] = field(
default=None,
metadata={
"help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
},
)
average_top_k_layers: int = field(
default=8, metadata={"help": "how many layers to average"}
)
end_of_block_targets: bool = True
layer_norm_target_layer: bool = False
instance_norm_target_layer: bool = False
batch_norm_target_layer: bool = False
instance_norm_targets: bool = False
layer_norm_targets: bool = False
ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"})
ema_end_decay: float = field(
default=0.9999, metadata={"help": "final ema decay rate"}
)
# when to finish annealing ema decay rate
ema_anneal_end_step: int = II("optimization.max_update")
ema_transformer_only: bool = field(
default=True,
metadata={"help": "whether to momentum update only the transformer layers"},
)
def get_annealed_rate(start, end, curr_step, total_steps):
r = end - start
pct_remaining = 1 - curr_step / total_steps
return end - r * pct_remaining
@register_model("data2vec_vision", dataclass=Data2VecVisionConfig)
class Data2VecVisionModel(BaseFairseqModel):
def __init__(self, cfg: Data2VecVisionConfig):
super().__init__()
self.cfg = cfg
self.ema = None
self.average_top_k_layers = cfg.average_top_k_layers
self.loss_beta = cfg.loss_beta
self.loss_scale = (
cfg.loss_scale
if cfg.loss_scale is not None
else 1 / math.sqrt(cfg.embed_dim)
)
self.patch_embed = PatchEmbed(
img_size=cfg.image_size,
patch_size=cfg.patch_size,
in_chans=cfg.in_channels,
embed_dim=cfg.embed_dim,
)
patch_size = self.patch_embed.patch_size
self.window_size = (
cfg.image_size // patch_size[0],
cfg.image_size // patch_size[1],
)
self.cls_emb = nn.Parameter(torch.FloatTensor(1, 1, cfg.embed_dim))
self.mask_emb = nn.Parameter(torch.FloatTensor(1, 1, cfg.embed_dim))
nn.init.trunc_normal_(self.cls_emb, 0.02)
nn.init.trunc_normal_(self.mask_emb, 0.02)
self.encoder = TransformerEncoder(cfg, self.patch_embed.patch_shape)
self.final_proj = nn.Linear(cfg.embed_dim, cfg.embed_dim)
self.num_updates = 0
def make_ema_teacher(self):
ema_config = EMAModuleConfig(
ema_decay=self.cfg.ema_decay,
ema_fp32=True,
)
self.ema = EMAModule(
self.encoder if self.cfg.ema_transformer_only else self,
ema_config,
)
def set_num_updates(self, num_updates):
super().set_num_updates(num_updates)
if self.ema is None and self.final_proj is not None:
logger.info(f"making ema teacher")
self.make_ema_teacher()
elif self.training and self.ema is not None:
if self.cfg.ema_decay != self.cfg.ema_end_decay:
if num_updates >= self.cfg.ema_anneal_end_step:
decay = self.cfg.ema_end_decay
else:
decay = get_annealed_rate(
self.cfg.ema_decay,
self.cfg.ema_end_decay,
num_updates,
self.cfg.ema_anneal_end_step,
)
self.ema.set_decay(decay)
if self.ema.get_decay() < 1:
self.ema.step(self.encoder if self.cfg.ema_transformer_only else self)
self.num_updates = num_updates
def state_dict(self, destination=None, prefix="", keep_vars=False):
state = super().state_dict(destination, prefix, keep_vars)
if self.ema is not None:
state[prefix + "_ema"] = self.ema.fp32_params
return state
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
if self.ema is not None:
k = prefix + "_ema"
assert k in state_dict
self.ema.restore(state_dict[k], True)
del state_dict[k]
return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
@classmethod
def build_model(cls, cfg: Data2VecVisionConfig, task=None):
"""Build a new model instance."""
return cls(cfg)
def make_mask(self, bsz, num_masks, min_masks, max_masks):
height, width = self.window_size
        masks = np.zeros(shape=(bsz, height, width), dtype=int)  # np.int alias removed in newer NumPy
for i in range(bsz):
mask = masks[i]
mask_count = 0
min_aspect = 0.3
max_aspect = 1 / min_aspect
log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
def _mask(mask, max_mask_patches):
delta = 0
for attempt in range(10):
target_area = random.uniform(min_masks, max_mask_patches)
aspect_ratio = math.exp(random.uniform(*log_aspect_ratio))
h = int(round(math.sqrt(target_area * aspect_ratio)))
w = int(round(math.sqrt(target_area / aspect_ratio)))
if w < width and h < height:
top = random.randint(0, height - h)
left = random.randint(0, width - w)
num_masked = mask[top : top + h, left : left + w].sum()
# Overlap
if 0 < h * w - num_masked <= max_mask_patches:
for i in range(top, top + h):
for j in range(left, left + w):
if mask[i, j] == 0:
mask[i, j] = 1
delta += 1
if delta > 0:
break
return delta
while mask_count < num_masks:
max_mask_patches = min(num_masks - mask_count, max_masks)
delta = _mask(mask, max_mask_patches)
if delta == 0:
break
else:
mask_count += delta
return torch.from_numpy(masks)
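    # Usage sketch (illustrative, assuming the default 224x224 / 16x16 patch
    # setup, so self.window_size == (14, 14)):
    #   masks = self.make_mask(bsz=2, num_masks=75, min_masks=16, max_masks=196)
    #   masks.shape == (2, 14, 14), entries in {0, 1}, and masks[i].sum() is at
    #   most 75: BEiT-style rectangular blocks with aspect ratio in
    #   [0.3, 1/0.3] are drawn until the budget is reached or a draw adds
    #   nothing new.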
def forward(
self,
img,
mask: bool = True,
layer_results: bool = False,
):
x = self.patch_embed(img)
batch_size, seq_len, _ = x.size()
if mask:
mask_indices = self.make_mask(
img.size(0),
self.cfg.num_mask_patches,
self.cfg.min_mask_patches_per_block,
self.cfg.max_mask_patches_per_block,
)
bool_mask = mask_indices.view(mask_indices.size(0), -1).bool()
else:
mask_indices = bool_mask = None
cls_tokens = self.cls_emb.expand(batch_size, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
if self.ema is not None:
with torch.no_grad():
self.ema.model.eval()
if self.cfg.ema_transformer_only:
y = self.ema.model(
x,
layer_results="end" if self.cfg.end_of_block_targets else "fc",
)
else:
y = self.ema.model(
img,
mask=False,
layer_results=True,
)
y = y[-self.cfg.average_top_k_layers :]
permuted = False
if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
y = [tl.transpose(1, 2) for tl in y] # BTC -> BCT
permuted = True
if self.cfg.batch_norm_target_layer:
y = [
F.batch_norm(
tl.float(), running_mean=None, running_var=None, training=True
)
for tl in y
]
if self.cfg.instance_norm_target_layer:
y = [F.instance_norm(tl.float()) for tl in y]
if permuted:
y = [tl.transpose(1, 2) for tl in y] # BCT -> BTC
if self.cfg.layer_norm_target_layer:
y = [F.layer_norm(tl.float(), tl.shape[-1:]) for tl in y]
y = sum(y) / len(y)
if self.cfg.layer_norm_targets:
y = F.layer_norm(y.float(), y.shape[-1:])
if self.cfg.instance_norm_targets:
y = F.instance_norm(y.float().transpose(1, 2)).transpose(1, 2)
y = y[bool_mask].float()
if mask_indices is not None:
mask_token = self.mask_emb.expand(batch_size, seq_len, -1)
w = mask_indices.view(mask_indices.size(0), -1, 1).type_as(mask_token)
x[:, 1:] = x[:, 1:] * (1 - w) + mask_token * w
if layer_results:
enc_layer_results = "end" if self.cfg.end_of_block_targets else "fc"
else:
enc_layer_results = None
x = self.encoder(x, layer_results=enc_layer_results)
if layer_results or mask_indices is None:
return x
x = x[bool_mask].float()
if self.loss_beta == 0:
loss = F.mse_loss(x, y, reduction="none").sum(dim=-1)
else:
loss = F.smooth_l1_loss(x, y, reduction="none", beta=self.loss_beta).sum(
dim=-1
)
if self.loss_scale > 0:
loss = loss * self.loss_scale
result = {
"losses": {"regression": loss.sum()},
"sample_size": loss.numel(),
"target_var": self.compute_var(y),
"pred_var": self.compute_var(x),
"ema_decay": self.ema.get_decay() * 1000,
}
return result
@staticmethod
def compute_var(y):
y = y.view(-1, y.size(-1))
if dist.is_initialized():
zc = torch.tensor(y.size(0)).cuda()
zs = y.sum(dim=0)
zss = (y ** 2).sum(dim=0)
dist.all_reduce(zc)
dist.all_reduce(zs)
dist.all_reduce(zss)
var = zss / (zc - 1) - (zs ** 2) / (zc * (zc - 1))
return torch.sqrt(var + 1e-6).mean()
else:
return torch.sqrt(y.var(dim=0) + 1e-6).mean()
def remove_pretraining_modules(self, last_layer=None):
self.final_proj = None
self.ema = None
self.encoder.norm = nn.Identity()
self.mask_emb = None
if last_layer is not None:
self.encoder.layers = nn.ModuleList(
l for i, l in enumerate(self.encoder.layers) if i <= last_layer
)
class PatchEmbed(nn.Module):
"""Image to Patch Embedding"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
if isinstance(img_size, int):
img_size = img_size, img_size
if isinstance(patch_size, int):
patch_size = patch_size, patch_size
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.conv = nn.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
)
def forward(self, x):
# BCHW -> BTC
x = self.conv(x).flatten(2).transpose(1, 2)
return x
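# Shape sketch (illustrative): with the defaults img_size=224, patch_size=16
# and embed_dim=768, PatchEmbed maps a (B, 3, 224, 224) batch through the
# strided Conv2d to (B, 768, 14, 14) and then to (B, 196, 768), i.e. one
# 768-dim token per 16x16 patch.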
class Attention(nn.Module):
def __init__(
self,
dim,
num_heads=8,
qkv_bias=True,
attn_drop=0.0,
proj_drop=0.0,
window_size=None,
attn_head_dim=None,
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
if attn_head_dim is not None:
head_dim = attn_head_dim
all_head_dim = head_dim * self.num_heads
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
else:
self.q_bias = None
self.v_bias = None
if window_size:
self.window_size = window_size
self.num_relative_distance = (2 * window_size[0] - 1) * (
2 * window_size[1] - 1
) + 3
self.relative_position_bias_table = nn.Parameter(
torch.zeros(self.num_relative_distance, num_heads)
) # 2*Wh-1 * 2*Ww-1, nH
            # cls to token, token to cls, and cls to cls
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = (
coords_flatten[:, :, None] - coords_flatten[:, None, :]
) # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(
1, 2, 0
).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = torch.zeros(
size=(window_size[0] * window_size[1] + 1,) * 2,
dtype=relative_coords.dtype,
)
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = self.num_relative_distance - 3
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1
self.register_buffer("relative_position_index", relative_position_index)
else:
self.window_size = None
self.relative_position_bias_table = None
self.relative_position_index = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(all_head_dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, rel_pos_bias=None):
B, N, C = x.shape
qkv_bias = None
if self.q_bias is not None:
qkv_bias = torch.cat(
(
self.q_bias,
torch.zeros_like(self.v_bias, requires_grad=False),
self.v_bias,
)
)
# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = (
qkv[0],
qkv[1],
qkv[2],
) # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = q @ k.transpose(-2, -1)
if self.relative_position_bias_table is not None:
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)
].view(
self.window_size[0] * self.window_size[1] + 1,
self.window_size[0] * self.window_size[1] + 1,
-1,
) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(
2, 0, 1
).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
print("attn.size() :", attn.size())
print("rel_pos_bias.size() :", rel_pos_bias.size())
if rel_pos_bias is not None:
attn = attn + rel_pos_bias
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
class RelativePositionBias(nn.Module):
def __init__(self, window_size, num_heads):
super().__init__()
self.window_size = window_size
self.num_relative_distance = (2 * window_size[0] - 1) * (
2 * window_size[1] - 1
) + 3
self.relative_position_bias_table = nn.Parameter(
torch.zeros(self.num_relative_distance, num_heads)
)
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = (
coords_flatten[:, :, None] - coords_flatten[:, None, :]
) # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(
1, 2, 0
).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = torch.zeros(
size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
)
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = self.num_relative_distance - 3
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1
self.register_buffer("relative_position_index", relative_position_index)
def forward(self):
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)
].view(
self.window_size[0] * self.window_size[1] + 1,
self.window_size[0] * self.window_size[1] + 1,
-1,
) # Wh*Ww,Wh*Ww,nH
print("self.window_size :", self.window_size)
print("self.num_relative_distance :", self.num_relative_distance)
print("self.relative_position_index :", self.relative_position_index.size(), self.relative_position_index)
print("relative_position_bias.size(), relative_position_bias :",relative_position_bias.size(), relative_position_bias)
print("self.relative_position_bias_table.size(), self.relative_position_bias_table :",self.relative_position_bias_table.size(), self.relative_position_bias_table)
return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
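# Worked example (illustrative, tiny window): for window_size=(2, 2) the bias
# table has (2*2 - 1) * (2*2 - 1) + 3 = 12 rows -- 9 for the distinct (dy, dx)
# offsets in {-1, 0, 1}^2 plus 3 special entries for cls->token, token->cls
# and cls->cls -- relative_position_index has shape (5, 5), and forward()
# returns a (num_heads, 5, 5) bias that is added to the attention logits.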
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
if self.drop_prob == 0.0 or not self.training:
return x
keep_prob = 1 - self.drop_prob
shape = (x.shape[0],) + (1,) * (
x.ndim - 1
) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_()
output = x.div(keep_prob) * random_tensor
return output
def extra_repr(self) -> str:
return "p={}".format(self.drop_prob)
class Block(nn.Module):
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.0,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
init_values=None,
window_size=None,
):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
attn_drop=attn_drop,
proj_drop=drop,
window_size=window_size,
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = nn.LayerNorm(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(dim, mlp_hidden_dim),
nn.GELU(),
nn.Linear(mlp_hidden_dim, dim),
nn.Dropout(drop),
)
        if init_values is not None and init_values > 0:
self.gamma_1 = nn.Parameter(
init_values * torch.ones((dim)), requires_grad=True
)
self.gamma_2 = nn.Parameter(
init_values * torch.ones((dim)), requires_grad=True
)
else:
self.gamma_1, self.gamma_2 = None, None
def forward(self, x, rel_pos_bias=None):
print("inside block :", x.size())
if self.gamma_1 is None:
x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
fc_feature = self.drop_path(self.mlp(self.norm2(x)))
x = x + fc_feature
else:
x = x + self.drop_path(
self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)
)
fc_feature = self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
x = x + fc_feature
return x, fc_feature
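# Note (explanatory): Block.forward returns both the block output x and
# fc_feature, the MLP branch output before the residual add. TransformerEncoder
# below collects one or the other per layer (layer_results == "end" vs "fc"),
# which is what the EMA teacher averages into regression targets.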
class TransformerEncoder(nn.Module):
def __init__(self, cfg: Data2VecVisionConfig, patch_shape):
super().__init__()
self.rel_pos_bias = None
if cfg.shared_rel_pos_bias:
self.rel_pos_bias = RelativePositionBias(
window_size=patch_shape, num_heads=cfg.num_heads
)
dpr = [
x.item() for x in torch.linspace(0, cfg.drop_path, cfg.depth)
] # stochastic depth decay rule
print("TransformerEncoder > patch_shape :", patch_shape)
self.blocks = nn.ModuleList(
Block(
dim=cfg.embed_dim,
num_heads=cfg.num_heads,
attn_drop=cfg.attention_dropout,
drop_path=dpr[i],
init_values=cfg.layer_scale_init_value,
window_size=patch_shape if not cfg.shared_rel_pos_bias else None,
)
for i in range(cfg.depth)
)
self.norm = nn.LayerNorm(cfg.embed_dim)
self.apply(self.init_weights)
self.fix_init_weight()
def init_weights(self, m):
std = 0.02
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=std)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=std)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def fix_init_weight(self):
def rescale(param, layer_id):
param.div_(math.sqrt(2.0 * layer_id))
for layer_id, layer in enumerate(self.blocks):
rescale(layer.attn.proj.weight.data, layer_id + 1)
rescale(layer.mlp[2].weight.data, layer_id + 1)
def extract_features(self, x, layer_results):
rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
z = []
for i, blk in enumerate(self.blocks):
x, fc_feature = blk(x, rel_pos_bias=rel_pos_bias)
if layer_results == "end":
z.append(x)
elif layer_results == "fc":
z.append(fc_feature)
return z if layer_results else self.norm(x)
def forward(self, x, layer_results=None):
x = self.extract_features(x, layer_results=layer_results)
if layer_results:
return [z[:, 1:] for z in x]
x = x[:, 1:]
return x
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# The code in this file is adapted from the BeiT implementation which can be found here:
# https://github.com/microsoft/unilm/tree/master/beit
import logging
from dataclasses import dataclass
from functools import partial
from timm.models.vision_transformer import PatchEmbed, Block
import torch
import torch.nn as nn
import numpy as np
from fairseq.dataclass import FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model
from fairseq.models.wav2vec.wav2vec2 import TransformerSentenceEncoderLayer
try:
from apex.normalization import FusedLayerNorm
except ImportError:
FusedLayerNorm = nn.LayerNorm
import torch.nn.functional as F
logger = logging.getLogger(__name__)
@dataclass
class MaeConfig(FairseqDataclass):
input_size: int = 224
in_chans: int = 3
patch_size: int = 16
embed_dim: int = 768
depth: int = 12
num_heads: int = 12
decoder_embed_dim: int = 512
decoder_depth: int = 8
decoder_num_heads: int = 16
mlp_ratio: int = 4
norm_eps: float = 1e-6
drop_path_rate: float = 0.0
mask_ratio: float = 0.75
norm_pix_loss: bool = True
w2v_block: bool = False
alt_block: bool = False
alt_block2: bool = False
alt_attention: bool = False
block_dropout: float = 0
attention_dropout: float = 0
activation_dropout: float = 0
layer_norm_first: bool = False
fused_ln: bool = True
end_of_block_targets: bool = True
no_decoder_embed: bool = False
no_decoder_pos_embed: bool = False
mask_noise_std: float = 0
single_qkv: bool = False
use_rel_pos_bias: bool = False
no_cls: bool = False
def modify_relative_position_bias(orig_bias, bsz, mask):
if mask is None:
return orig_bias.unsqueeze(0).repeat(
bsz, 1, 1, 1
) # heads x seq_len x seq_len => bsz x heads x seq_len x seq_len
    heads, max_seq_len, _ = orig_bias.shape  # square bias, includes CLS token
mask_for_rel_pos_bias = torch.cat(
(torch.zeros(bsz, 1, dtype=mask.dtype, device=mask.device), mask), dim=1
).bool() # bsz x seqlen (add CLS token)
unmasked_for_rel_pos_bias = ~mask_for_rel_pos_bias
unmasked_for_rel_pos_bias = unmasked_for_rel_pos_bias.unsqueeze(1).repeat(
1, heads, 1
) # bsz x seq_len => bsz x heads x seq_len
b_t_t_rel_pos_bias = orig_bias.unsqueeze(0).repeat(
bsz, 1, 1, 1
) # heads x seq_len x seq_len => bsz x heads x seq_len x seq_len
b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.masked_select(
unmasked_for_rel_pos_bias.unsqueeze(-1)
)
b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.view(bsz, heads, -1, max_seq_len)
new_len = b_t_t_rel_pos_bias.size(-2)
b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.masked_select(
unmasked_for_rel_pos_bias.unsqueeze(-2)
)
b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.view(bsz, heads, new_len, new_len)
return b_t_t_rel_pos_bias
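# Shape walkthrough (illustrative): given orig_bias of shape
# (heads, 1 + T, 1 + T) (CLS included) and a patch mask of shape (bsz, T) with
# K masked positions per row, the two masked_select/view steps drop the masked
# rows and then the masked columns, returning a bias of shape
# (bsz, heads, 1 + T - K, 1 + T - K) aligned with the visible tokens actually
# fed to the encoder.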
class AltBlock(nn.Module):
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.0,
qkv_bias=False,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
layer_norm_first=True,
ffn_targets=False,
use_rel_pos_bias=False,
window_size=None,
alt_attention=False,
):
super().__init__()
self.layer_norm_first = layer_norm_first
self.ffn_targets = ffn_targets
from timm.models.vision_transformer import Attention, DropPath, Mlp
self.norm1 = norm_layer(dim)
self.use_rel_pos_bias = use_rel_pos_bias
if use_rel_pos_bias:
self.attn = AltAttention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
window_size=window_size,
)
else:
if alt_attention:
from .multi.modules import AltAttention as AltAttention2
self.attn = AltAttention2(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
)
else:
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
)
def forward(self, x, rel_pos_bias=None, pos_mask=None):
if self.layer_norm_first:
if self.use_rel_pos_bias:
x = x + self.drop_path(
self.attn(
self.norm1(x), rel_pos_bias=rel_pos_bias, pos_mask=pos_mask
)
)
else:
x = x + self.drop_path(self.attn(self.norm1(x)))
t = self.mlp(self.norm2(x))
x = x + self.drop_path(t)
if not self.ffn_targets:
t = x
return x, t
else:
if self.use_rel_pos_bias:
x = x + self.drop_path(
self.attn(x, rel_pos_bias=rel_pos_bias, pos_mask=pos_mask)
)
else:
x = x + self.drop_path(self.attn(x))
r = x = self.norm1(x)
x = self.mlp(x)
t = x
x = self.norm2(r + self.drop_path(x))
if not self.ffn_targets:
t = x
return x, t
class AltAttention(nn.Module):
def __init__(
self,
dim,
num_heads=8,
qkv_bias=True,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
window_size=None,
attn_head_dim=None,
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
if attn_head_dim is not None:
head_dim = attn_head_dim
all_head_dim = head_dim * self.num_heads
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
else:
self.q_bias = None
self.v_bias = None
if window_size:
self.window_size = window_size
self.num_relative_distance = (2 * window_size[0] - 1) * (
2 * window_size[1] - 1
) + 3
self.relative_position_bias_table = nn.Parameter(
torch.zeros(self.num_relative_distance, num_heads)
) # 2*Wh-1 * 2*Ww-1, nH
            # cls to token, token to cls, and cls to cls
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = (
coords_flatten[:, :, None] - coords_flatten[:, None, :]
) # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(
1, 2, 0
).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = torch.zeros(
size=(window_size[0] * window_size[1] + 1,) * 2,
dtype=relative_coords.dtype,
)
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = self.num_relative_distance - 3
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1
self.register_buffer("relative_position_index", relative_position_index)
else:
self.window_size = None
self.relative_position_bias_table = None
self.relative_position_index = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(all_head_dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, rel_pos_bias=None, pos_mask=None):
B, N, C = x.shape
qkv_bias = None
if self.q_bias is not None:
qkv_bias = torch.cat(
(
self.q_bias,
torch.zeros_like(self.v_bias, requires_grad=False),
self.v_bias,
)
)
# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = (
qkv[0],
qkv[1],
qkv[2],
) # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = q @ k.transpose(-2, -1)
if self.relative_position_bias_table is not None:
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)
].view(
self.window_size[0] * self.window_size[1] + 1,
self.window_size[0] * self.window_size[1] + 1,
-1,
) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(
2, 0, 1
).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + modify_relative_position_bias(
relative_position_bias, x.size(0), pos_mask
)
if rel_pos_bias is not None:
attn = attn + rel_pos_bias
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
class RelativePositionBias(nn.Module):
def __init__(self, window_size, num_heads):
super().__init__()
self.window_size = window_size
self.num_relative_distance = (2 * window_size[0] - 1) * (
2 * window_size[1] - 1
) + 3
self.relative_position_bias_table = nn.Parameter(
torch.zeros(self.num_relative_distance, num_heads)
)
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = (
coords_flatten[:, :, None] - coords_flatten[:, None, :]
) # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(
1, 2, 0
).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = torch.zeros(
size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
)
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = self.num_relative_distance - 3
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1
self.register_buffer("relative_position_index", relative_position_index)
def forward(self):
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)
].view(
self.window_size[0] * self.window_size[1] + 1,
self.window_size[0] * self.window_size[1] + 1,
-1,
) # Wh*Ww,Wh*Ww,nH
return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token:
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.0
omega = 1.0 / 10000 ** omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
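# Usage sketch (illustrative): get_2d_sincos_pos_embed(768, 14, cls_token=True)
# returns a (1 + 14 * 14, 768) = (197, 768) numpy array -- fixed sin/cos
# embeddings with an all-zero row prepended for the CLS token -- which is what
# initialize_weights() below copies into self.pos_embed.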
def interpolate_pos_embed(model, checkpoint_model):
if "pos_embed" in checkpoint_model:
pos_embed_checkpoint = checkpoint_model["pos_embed"]
embedding_size = pos_embed_checkpoint.shape[-1]
num_patches = model.patch_embed.num_patches
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
# height (== width) for the checkpoint position embedding
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
# height (== width) for the new position embedding
new_size = int(num_patches ** 0.5)
# class_token and dist_token are kept unchanged
if orig_size != new_size:
print(
"Position interpolate from %dx%d to %dx%d"
% (orig_size, orig_size, new_size, new_size)
)
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(
-1, orig_size, orig_size, embedding_size
).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
pos_tokens,
size=(new_size, new_size),
mode="bicubic",
align_corners=False,
)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
checkpoint_model["pos_embed"] = new_pos_embed
@register_model("mae", dataclass=MaeConfig)
class MaeModel(BaseFairseqModel):
def __init__(self, cfg: MaeConfig):
super().__init__()
self.cfg = cfg
self.mask_ratio = cfg.mask_ratio
# --------------------------------------------------------------------------
# MAE encoder specifics
self.patch_embed = PatchEmbed(
cfg.input_size, cfg.patch_size, cfg.in_chans, cfg.embed_dim
)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, cfg.embed_dim)) if not cfg.no_cls else None
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + int(not cfg.no_cls), cfg.embed_dim), requires_grad=False
) # fixed sin-cos embedding
norm_layer = partial(nn.LayerNorm, eps=cfg.norm_eps)
dpr = [
x.item() for x in torch.linspace(0, cfg.drop_path_rate, cfg.depth)
] # stochastic depth decay rule
def make_block(drop_path):
if cfg.w2v_block:
return TransformerSentenceEncoderLayer(
embedding_dim=cfg.embed_dim,
ffn_embedding_dim=cfg.embed_dim * cfg.mlp_ratio,
num_attention_heads=cfg.num_heads,
dropout=cfg.block_dropout,
attention_dropout=cfg.attention_dropout,
activation_dropout=cfg.activation_dropout,
activation_fn="gelu",
layer_norm_first=cfg.layer_norm_first,
drop_path=drop_path,
norm_eps=1e-6,
single_qkv=cfg.single_qkv,
fused_ln=cfg.fused_ln,
)
elif cfg.alt_block:
window_size = (
cfg.input_size // self.patch_embed.patch_size[0],
cfg.input_size // self.patch_embed.patch_size[1],
)
return AltBlock(
cfg.embed_dim,
cfg.num_heads,
cfg.mlp_ratio,
qkv_bias=True,
qk_scale=None,
norm_layer=norm_layer,
drop_path=drop_path,
layer_norm_first=cfg.layer_norm_first,
ffn_targets=not cfg.end_of_block_targets,
use_rel_pos_bias=cfg.use_rel_pos_bias,
window_size=window_size
if (self.cfg.use_rel_pos_bias and not self.cfg.shared_rel_pos_bias)
else None,
alt_attention=cfg.alt_attention,
)
elif cfg.alt_block2:
from .multi.modules import AltBlock as AltBlock2
return AltBlock2(
cfg.embed_dim,
cfg.num_heads,
cfg.mlp_ratio,
qkv_bias=True,
qk_scale=None,
norm_layer=norm_layer,
drop_path=drop_path,
layer_norm_first=cfg.layer_norm_first,
ffn_targets=not cfg.end_of_block_targets,
)
else:
return Block(
cfg.embed_dim,
cfg.num_heads,
cfg.mlp_ratio,
qkv_bias=True,
qk_scale=None,
norm_layer=norm_layer,
drop_path=drop_path,
)
self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.depth)])
self.norm = norm_layer(cfg.embed_dim)
# --------------------------------------------------------------------------
# --------------------------------------------------------------------------
# MAE decoder specifics
self.decoder_embed = (
nn.Linear(cfg.embed_dim, cfg.decoder_embed_dim, bias=True)
if not cfg.no_decoder_embed
else None
)
self.mask_token = (
nn.Parameter(
torch.zeros(
1,
1,
cfg.decoder_embed_dim
if not cfg.no_decoder_embed
else cfg.embed_dim,
)
)
if cfg.mask_noise_std <= 0
else None
)
self.decoder_pos_embed = (
nn.Parameter(
torch.zeros(
1,
num_patches + 1,
cfg.decoder_embed_dim
if not cfg.no_decoder_embed
else cfg.embed_dim,
),
requires_grad=False,
)
if not cfg.no_decoder_pos_embed
else None
)
self.decoder_blocks = nn.ModuleList(
[
Block(
cfg.decoder_embed_dim,
cfg.decoder_num_heads,
cfg.mlp_ratio,
qkv_bias=True,
qk_scale=None,
norm_layer=norm_layer,
)
for _ in range(cfg.decoder_depth)
]
)
self.decoder_norm = norm_layer(cfg.decoder_embed_dim)
self.decoder_pred = nn.Linear(
cfg.decoder_embed_dim, cfg.patch_size ** 2 * cfg.in_chans, bias=True
) # decoder to patch
# --------------------------------------------------------------------------
self.norm_pix_loss = cfg.norm_pix_loss
self.initialize_weights()
for pn, p in self.named_parameters():
if len(p.shape) == 1 or pn.endswith(".bias"):
p.param_group = "no_decay"
else:
p.param_group = "with_decay"
def initialize_weights(self):
# initialization
# initialize (and freeze) pos_embed by sin-cos embedding
pos_embed = get_2d_sincos_pos_embed(
self.pos_embed.shape[-1],
int(self.patch_embed.num_patches ** 0.5),
cls_token=not self.cfg.no_cls,
)
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
if self.decoder_pos_embed is not None:
decoder_pos_embed = get_2d_sincos_pos_embed(
self.decoder_pos_embed.shape[-1],
int(self.patch_embed.num_patches ** 0.5),
cls_token=not self.cfg.no_cls,
)
self.decoder_pos_embed.data.copy_(
torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)
)
# initialize patch_embed like nn.Linear (instead of nn.Conv2d)
w = self.patch_embed.proj.weight.data
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
# timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
if self.cls_token is not None:
torch.nn.init.normal_(self.cls_token, std=0.02)
if self.mask_token is not None:
torch.nn.init.normal_(self.mask_token, std=0.02)
# initialize nn.Linear and nn.LayerNorm
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
# we use xavier_uniform following official JAX ViT:
torch.nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm) or isinstance(m, FusedLayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def patchify(self, imgs):
"""
imgs: (N, 3, H, W)
x: (N, L, patch_size**2 *3)
"""
p = self.patch_embed.patch_size[0]
assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
h = w = imgs.shape[2] // p
x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
x = torch.einsum("nchpwq->nhwpqc", x)
x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3))
return x
def unpatchify(self, x):
"""
x: (N, L, patch_size**2 *3)
imgs: (N, 3, H, W)
"""
p = self.patch_embed.patch_size[0]
h = w = int(x.shape[1] ** 0.5)
assert h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
x = torch.einsum("nhwpqc->nchpwq", x)
imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
return imgs
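    # Shape sketch for patchify/unpatchify (illustrative values, not part of the original file):
    #   imgs: (2, 3, 224, 224) with patch size p=16  ->  patchify  ->  (2, 196, 768)
    #         (14*14 = 196 patches, 16*16*3 = 768 values per patch)
    #   unpatchify is the exact inverse: (2, 196, 768) -> (2, 3, 224, 224)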
def random_masking(self, x, mask_ratio):
"""
Perform per-sample random masking by per-sample shuffling.
Per-sample shuffling is done by argsort random noise.
x: [N, L, D], sequence
"""
N, L, D = x.shape # batch, length, dim
len_keep = int(L * (1 - mask_ratio))
noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
# sort noise for each sample
ids_shuffle = torch.argsort(
noise, dim=1
) # ascend: small is keep, large is remove
ids_restore = torch.argsort(ids_shuffle, dim=1)
# keep the first subset
ids_keep = ids_shuffle[:, :len_keep]
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
# generate the binary mask: 0 is keep, 1 is remove
mask = torch.ones([N, L], device=x.device)
mask[:, :len_keep] = 0
# unshuffle to get the binary mask
mask = torch.gather(mask, dim=1, index=ids_restore)
return x_masked, mask, ids_restore # x_masked is actually unmasked x
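    # Example of what random_masking returns (assumed shapes, mask_ratio=0.75):
    #   x:           (4, 196, 768)
    #   x_masked:    (4, 49, 768)   the int(196 * 0.25) = 49 patches kept visible
    #   mask:        (4, 196)       0 = keep, 1 = remove, in the original patch order
    #   ids_restore: (4, 196)       indices used by forward_decoder to undo the shuffle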
@classmethod
def build_model(cls, cfg: MaeConfig, task=None):
"""Build a new model instance."""
return cls(cfg)
def forward_encoder(self, x, mask_ratio):
# embed patches
x = self.patch_embed(x)
# add pos embed w/o cls token
# if self.cls_token is not None:
# x = x + self.pos_embed
# else:
x = x + self.pos_embed[:, 1:, :]
# masking: length -> length * mask_ratio
if mask_ratio > 0:
x, mask, ids_restore = self.random_masking(x, mask_ratio)
else:
mask = ids_restore = None
# append cls token
if self.cls_token is not None:
cls_token = self.cls_token + self.pos_embed[:, :1, :]
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
# apply Transformer blocks
for blk in self.blocks:
x = blk(x)
if self.norm is not None:
x = self.norm(x)
return x, mask, ids_restore
def forward_decoder(self, x, ids_restore):
# embed tokens
x = self.decoder_embed(x)
# append mask tokens to sequence
mask_tokens = self.mask_token.repeat(
x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1
)
if self.cls_token is not None:
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
else:
x_ = torch.cat([x, mask_tokens], dim=1) # no cls token
x_ = torch.gather(
x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
) # unshuffle
if self.cls_token is not None:
x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
# add pos embed
x = x + self.decoder_pos_embed
# apply Transformer blocks
for blk in self.decoder_blocks:
x = blk(x)
x = self.decoder_norm(x)
# predictor projection
x = self.decoder_pred(x)
if self.cls_token is not None:
# remove cls token
x = x[:, 1:, :]
return x
def forward_loss(self, imgs, pred, mask):
"""
imgs: [N, 3, H, W]
pred: [N, L, p*p*3]
mask: [N, L], 0 is keep, 1 is remove,
"""
target = self.patchify(imgs)
if self.norm_pix_loss:
mean = target.mean(dim=-1, keepdim=True)
var = target.var(dim=-1, keepdim=True)
target = (target - mean) / (var + 1.0e-6) ** 0.5
loss = (pred - target) ** 2
loss = loss.mean(dim=-1) # [N, L], mean loss per patch
loss = (loss * mask).sum()
return loss, mask.sum()
def forward(self, imgs, predictions_only=False):
latent, mask, ids_restore = self.forward_encoder(
imgs, self.mask_ratio if not predictions_only else 0
)
if predictions_only:
return latent
pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3]
loss, sample_size = self.forward_loss(imgs, pred, mask)
result = {
"losses": {"regression": loss},
"sample_size": sample_size,
}
return result
def remove_pretraining_modules(self):
self.decoder_embed = None
self.decoder_blocks = None
self.decoder_norm = None
self.decoder_pos_embed = None
self.decoder_pred = None
self.mask_token = None
if self.cfg.layer_norm_first:
self.norm = None
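# Rough usage sketch for the masked autoencoder defined above (schematic; the concrete
# model class name and its MaeConfig/registration appear earlier in this file):
#
#   model = <ModelClass>.build_model(cfg)          # cfg: MaeConfig (mask_ratio, depth, decoder_*, ...)
#   result = model(imgs)                           # imgs: (N, 3, H, W)
#   loss = result["losses"]["regression"] / result["sample_size"]   # mean loss over masked patches
#   feats = model(imgs, predictions_only=True)     # encoder features only: no masking, no decoder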
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# The code in this file is adapted from the BeiT implementation which can be found here:
# https://github.com/microsoft/unilm/tree/master/beit
import logging
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Optional
import numpy as np
from omegaconf import II, MISSING
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import checkpoint_utils, tasks
from omegaconf import open_dict
from fairseq.dataclass import FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model
from .mae import interpolate_pos_embed
logger = logging.getLogger(__name__)
class PredictionMode(Enum):
MEAN_POOLING = auto()
CLS_TOKEN = auto()
LIN_SOFTMAX = auto()
@dataclass
class MaeImageClassificationConfig(FairseqDataclass):
model_path: str = MISSING
no_pretrained_weights: bool = False
linear_classifier: bool = False
num_classes: int = 1000
mixup: float = 0.8
cutmix: float = 1.0
label_smoothing: float = 0.1
drop_path_rate: float = 0.1
layer_decay: float = 0.65
mixup_prob: float = 1.0
mixup_switch_prob: float = 0.5
mixup_mode: str = "batch"
pretrained_model_args: Any = None
data: str = II("task.data")
norm_eps: Optional[float] = None
remove_alibi: bool = False
# regularization overwrites
encoder_dropout: float = 0
post_mlp_drop: float = 0
attention_dropout: float = 0
activation_dropout: float = 0.0
dropout_input: float = 0.0
layerdrop: float = 0.0
prenet_layerdrop: float = 0
prenet_dropout: float = 0
use_fc_norm: bool = True
prediction_mode: PredictionMode = PredictionMode.MEAN_POOLING
no_decay_blocks: bool = True
def get_layer_id_for_vit(name, num_layers):
"""
    Assign each parameter its layer id (used for layer-wise learning-rate decay)
Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
"""
if name in ["cls_token", "pos_embed"]:
return 0
elif name.startswith("patch_embed"):
return 0
elif name.startswith("rel_pos_bias"):
return num_layers - 1
elif name.startswith("blocks"):
return int(name.split(".")[1]) + 1
else:
return num_layers
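# Example ids for a 12-block ViT (num_layers = len(blocks) + 1 = 13):
#   get_layer_id_for_vit("cls_token", 13)                -> 0
#   get_layer_id_for_vit("patch_embed.proj.weight", 13)  -> 0
#   get_layer_id_for_vit("blocks.0.attn.qkv.weight", 13) -> 1
#   get_layer_id_for_vit("blocks.11.mlp.fc2.bias", 13)   -> 12
#   get_layer_id_for_vit("norm.weight", 13)              -> 13  (everything after the blocks)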
@register_model("mae_image_classification", dataclass=MaeImageClassificationConfig)
class MaeImageClassificationModel(BaseFairseqModel):
def __init__(self, cfg: MaeImageClassificationConfig):
super().__init__()
self.cfg = cfg
if cfg.pretrained_model_args is None:
state = checkpoint_utils.load_checkpoint_to_cpu(cfg.model_path, {})
pretrained_args = state.get("cfg", None)
pretrained_args.criterion = None
pretrained_args.lr_scheduler = None
logger.info(pretrained_args.model)
with open_dict(pretrained_args.model):
pretrained_args.model.drop_path_rate = cfg.drop_path_rate
if cfg.norm_eps is not None:
pretrained_args.model.norm_eps = cfg.norm_eps
cfg.pretrained_model_args = pretrained_args
logger.info(pretrained_args)
else:
state = None
pretrained_args = cfg.pretrained_model_args
if "data" in pretrained_args.task:
pretrained_args.task.data = cfg.data
elif "image" in pretrained_args.task:
pretrained_args.task.image.data = cfg.data
if "modalities" in pretrained_args.model:
prenet_blocks = pretrained_args.model["modalities"]["image"]["prenet_depth"]
model_blocks = pretrained_args.model["depth"]
with open_dict(pretrained_args):
dpr = np.linspace(0, cfg.drop_path_rate, model_blocks).tolist()
pretrained_args.model["modalities"]["image"][
"start_drop_path_rate"
] = dpr[0]
pretrained_args.model["modalities"]["image"][
"end_drop_path_rate"
] = max(0, dpr[prenet_blocks - 1])
pretrained_args.model["start_drop_path_rate"] = dpr[prenet_blocks]
pretrained_args.model["end_drop_path_rate"] = dpr[-1]
if "mae_masking" in pretrained_args.model["modalities"]["image"]:
del pretrained_args.model["modalities"]["image"]["mae_masking"]
if cfg.remove_alibi:
pretrained_args.model["modalities"]["image"][
"use_alibi_encoder"
] = False
if (
state is not None
and "modality_encoders.IMAGE.alibi_bias" in state["model"]
):
del state["model"]["modality_encoders.IMAGE.alibi_bias"]
pretrained_args.model["encoder_dropout"] = cfg.encoder_dropout
pretrained_args.model["post_mlp_drop"] = cfg.post_mlp_drop
pretrained_args.model["attention_dropout"] = cfg.attention_dropout
pretrained_args.model["activation_dropout"] = cfg.activation_dropout
pretrained_args.model["dropout_input"] = cfg.dropout_input
pretrained_args.model["layerdrop"] = cfg.layerdrop
pretrained_args.model["modalities"]["image"][
"prenet_layerdrop"
] = cfg.prenet_layerdrop
pretrained_args.model["modalities"]["image"][
"prenet_dropout"
] = cfg.prenet_dropout
else:
# not d2v multi
with open_dict(pretrained_args):
pretrained_args.model["drop_path_rate"] = cfg.drop_path_rate
pretrained_args.model["block_dropout"] = cfg.encoder_dropout
pretrained_args.model["attention_dropout"] = cfg.attention_dropout
pretrained_args.model["activation_dropout"] = cfg.activation_dropout
task = tasks.setup_task(pretrained_args.task)
model = task.build_model(pretrained_args.model, from_checkpoint=True)
self.d2v_multi = "data2vec_multi" in pretrained_args.model._name
self.linear_classifier = cfg.linear_classifier
self.model = model
if state is not None and not cfg.no_pretrained_weights:
interpolate_pos_embed(model, state)
if "modality_encoders.IMAGE.positional_encoder.pos_embed" in state["model"]:
state["model"][
"modality_encoders.IMAGE.positional_encoder.positions"
] = state["model"][
"modality_encoders.IMAGE.positional_encoder.pos_embed"
]
del state["model"][
"modality_encoders.IMAGE.positional_encoder.pos_embed"
]
if "modality_encoders.IMAGE.encoder_mask" in state["model"]:
del state["model"]["modality_encoders.IMAGE.encoder_mask"]
model.load_state_dict(state["model"], strict=True)
if self.d2v_multi:
model.remove_pretraining_modules(modality="image")
else:
model.remove_pretraining_modules()
if self.linear_classifier:
model.requires_grad_(False)
self.fc_norm = None
if self.cfg.use_fc_norm:
self.fc_norm = nn.LayerNorm(pretrained_args.model.embed_dim, eps=1e-6)
nn.init.constant_(self.fc_norm.bias, 0)
nn.init.constant_(self.fc_norm.weight, 1.0)
self.head = nn.Linear(pretrained_args.model.embed_dim, cfg.num_classes)
nn.init.trunc_normal_(self.head.weight, std=0.02)
nn.init.constant_(self.head.bias, 0)
self.mixup_fn = None
if cfg.mixup > 0 or cfg.cutmix > 0:
from timm.data import Mixup
self.mixup_fn = Mixup(
mixup_alpha=cfg.mixup,
cutmix_alpha=cfg.cutmix,
cutmix_minmax=None,
prob=cfg.mixup_prob,
switch_prob=cfg.mixup_switch_prob,
mode=cfg.mixup_mode,
label_smoothing=cfg.label_smoothing,
num_classes=cfg.num_classes,
)
if self.model.norm is not None:
for pn, p in self.model.norm.named_parameters():
if len(p.shape) == 1 or pn.endswith(".bias"):
p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}}
if self.fc_norm is not None:
for pn, p in self.fc_norm.named_parameters():
if len(p.shape) == 1 or pn.endswith(".bias"):
p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}}
for pn, p in self.head.named_parameters():
if len(p.shape) == 1 or pn.endswith(".bias"):
p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}}
if self.d2v_multi:
mod_encs = list(model.modality_encoders.values())
assert len(mod_encs) == 1, len(mod_encs)
blocks = list(mod_encs[0].context_encoder.blocks) + list(model.blocks)
else:
blocks = model.blocks
num_layers = len(blocks) + 1
layer_scales = list(
cfg.layer_decay ** (num_layers - i) for i in range(num_layers + 1)
)
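        # Worked example (assumed defaults: layer_decay=0.65 and a 12-block encoder, so
        # num_layers = 13 and len(layer_scales) = 14):
        #   layer_scales[0]  = 0.65**13 ~ 0.0037  -> patch_embed / cls_token / pos_embed
        #   layer_scales[1]  = 0.65**12           -> blocks.0
        #   ...
        #   layer_scales[12] = 0.65**1  = 0.65    -> blocks.11
        #   layer_scales[13] = 0.65**0  = 1.0     -> everything after the blocks (no shrinking)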
if self.d2v_multi:
for n, p in self.model.named_parameters():
optimizer_override_dict = {}
if len(p.shape) == 1 or n.endswith(".bias"):
optimizer_override_dict["weight_decay_scale"] = 0
p.optim_overrides = {"optimizer": optimizer_override_dict}
if cfg.layer_decay > 0:
for i, b in enumerate(blocks):
lid = i + 1
if layer_scales[lid] == 1.0:
continue
for n, p in b.named_parameters():
optim_override = getattr(p, "optim_overrides", {})
if "optimizer" not in optim_override:
optim_override["optimizer"] = {}
if cfg.no_decay_blocks:
optim_override["optimizer"]["lr_scale"] = layer_scales[lid]
p.optim_overrides = optim_override
else:
optim_override["optimizer"] = {
"lr_scale": layer_scales[lid]
}
p.optim_overrides = optim_override
else:
for n, p in self.model.named_parameters():
optimizer_override_dict = {}
layer_id = get_layer_id_for_vit(n, num_layers)
if len(p.shape) == 1 or n.endswith(".bias"):
optimizer_override_dict["weight_decay_scale"] = 0
if cfg.layer_decay > 0:
optimizer_override_dict["lr_scale"] = layer_scales[layer_id]
p.optim_overrides = {"optimizer": optimizer_override_dict}
@classmethod
def build_model(cls, cfg: MaeImageClassificationConfig, task=None):
"""Build a new model instance."""
return cls(cfg)
def forward(
self,
imgs,
labels=None,
):
if self.training and self.mixup_fn is not None and labels is not None:
imgs, labels = self.mixup_fn(imgs, labels)
if self.linear_classifier:
with torch.no_grad():
x = self.model_forward(imgs)
else:
x = self.model_forward(imgs)
if self.cfg.prediction_mode == PredictionMode.MEAN_POOLING:
x = x.mean(dim=1)
elif self.cfg.prediction_mode == PredictionMode.CLS_TOKEN:
x = x[:, 0]
elif self.cfg.prediction_mode == PredictionMode.LIN_SOFTMAX:
dtype = x.dtype
x = F.logsigmoid(x.float())
x = torch.logsumexp(x + x, dim=1) - torch.logsumexp(x + 1e-6, dim=1)
x = x.clamp(max=0)
x = x - torch.log(-(torch.expm1(x)))
x = torch.nan_to_num(x, nan=0, posinf=0, neginf=0)
x = x.to(dtype=dtype)
else:
raise Exception(f"unknown prediction mode {self.cfg.prediction_mode.name}")
if self.fc_norm is not None:
x = self.fc_norm(x)
x = self.head(x)
if labels is None:
return x
if self.training and self.mixup_fn is not None:
loss = -labels * F.log_softmax(x.float(), dim=-1)
else:
loss = F.cross_entropy(
x.float(),
labels,
label_smoothing=self.cfg.label_smoothing if self.training else 0,
reduction="none",
)
result = {
"losses": {"regression": loss},
"sample_size": imgs.size(0),
}
if not self.training:
with torch.no_grad():
pred = x.argmax(-1)
correct = (pred == labels).sum()
result["correct"] = correct
return result
def model_forward(self, imgs):
if self.d2v_multi:
x = self.model.extract_features(
imgs,
mode="IMAGE",
mask=False,
remove_extra_tokens=(
self.cfg.prediction_mode != PredictionMode.CLS_TOKEN
),
)["x"]
else:
x = self.model(imgs, predictions_only=True)
if (
"no_cls" not in self.model.cfg or not self.model.cfg.no_cls
            ) and self.cfg.prediction_mode != PredictionMode.CLS_TOKEN:
x = x[:, 1:]
return x
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from functools import partial
import torch
import torch.nn as nn
import numpy as np
from dataclasses import dataclass, field
from typing import Callable, Dict, Optional
from fairseq.models.wav2vec import ConvFeatureExtractionModel
from fairseq.modules import (
LayerNorm,
SamePad,
TransposeLast,
)
from fairseq.tasks import FairseqTask
from .base import D2vModalityConfig, ModalitySpecificEncoder, get_alibi_bias
from .modules import BlockEncoder, Decoder1d
from examples.data2vec.data.modality import Modality
@dataclass
class D2vAudioConfig(D2vModalityConfig):
type: Modality = Modality.AUDIO
extractor_mode: str = "layer_norm"
feature_encoder_spec: str = field(
default="[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]",
metadata={
"help": "string describing convolutional feature extraction layers in form of a python list that contains "
"[(dim, kernel_size, stride), ...]"
},
)
conv_pos_width: int = field(
default=95,
metadata={"help": "number of filters for convolutional positional embeddings"},
)
conv_pos_groups: int = field(
default=16,
metadata={"help": "number of groups for convolutional positional embedding"},
)
conv_pos_depth: int = field(
default=5,
metadata={"help": "depth of positional encoder network"},
)
conv_pos_pre_ln: bool = False
class AudioEncoder(ModalitySpecificEncoder):
modality_cfg: D2vAudioConfig
def __init__(
self,
modality_cfg: D2vAudioConfig,
embed_dim: int,
make_block: Callable[[float], nn.ModuleList],
norm_layer: Callable[[int], nn.LayerNorm],
layer_norm_first: bool,
alibi_biases: Dict,
task: Optional[FairseqTask],
):
self.feature_enc_layers = eval(modality_cfg.feature_encoder_spec)
feature_embed_dim = self.feature_enc_layers[-1][0]
local_encoder = ConvFeatureExtractionModel(
conv_layers=self.feature_enc_layers,
dropout=0.0,
mode=modality_cfg.extractor_mode,
conv_bias=False,
)
project_features = nn.Sequential(
TransposeLast(),
nn.LayerNorm(feature_embed_dim),
nn.Linear(feature_embed_dim, embed_dim),
)
num_pos_layers = modality_cfg.conv_pos_depth
k = max(3, modality_cfg.conv_pos_width // num_pos_layers)
positional_encoder = nn.Sequential(
TransposeLast(),
*[
nn.Sequential(
nn.Conv1d(
embed_dim,
embed_dim,
kernel_size=k,
padding=k // 2,
groups=modality_cfg.conv_pos_groups,
),
SamePad(k),
TransposeLast(),
LayerNorm(embed_dim, elementwise_affine=False),
TransposeLast(),
nn.GELU(),
)
for _ in range(num_pos_layers)
],
TransposeLast(),
)
if modality_cfg.conv_pos_pre_ln:
positional_encoder = nn.Sequential(LayerNorm(embed_dim), positional_encoder)
dpr = np.linspace(
modality_cfg.start_drop_path_rate,
modality_cfg.end_drop_path_rate,
modality_cfg.prenet_depth,
)
context_encoder = BlockEncoder(
nn.ModuleList(make_block(dpr[i]) for i in range(modality_cfg.prenet_depth)),
norm_layer(embed_dim) if not layer_norm_first else None,
layer_norm_first,
modality_cfg.prenet_layerdrop,
modality_cfg.prenet_dropout,
)
decoder = (
Decoder1d(modality_cfg.decoder, embed_dim)
if modality_cfg.decoder is not None
else None
)
alibi_bias_fn = partial(get_alibi_bias, alibi_biases=alibi_biases)
super().__init__(
modality_cfg=modality_cfg,
embed_dim=embed_dim,
local_encoder=local_encoder,
project_features=project_features,
fixed_positional_encoder=None,
relative_positional_encoder=positional_encoder,
context_encoder=context_encoder,
decoder=decoder,
get_alibi_bias=alibi_bias_fn,
)
def convert_padding_mask(self, x, padding_mask):
def get_feat_extract_output_lengths(input_lengths: torch.LongTensor):
"""
Computes the output length of the convolutional layers
"""
def _conv_out_length(input_length, kernel_size, stride):
return torch.floor((input_length - kernel_size) / stride + 1)
for i in range(len(self.feature_enc_layers)):
input_lengths = _conv_out_length(
input_lengths,
self.feature_enc_layers[i][1],
self.feature_enc_layers[i][2],
)
return input_lengths.to(torch.long)
if padding_mask is not None:
input_lengths = (1 - padding_mask.long()).sum(-1)
# apply conv formula to get real output_lengths
output_lengths = get_feat_extract_output_lengths(input_lengths)
if padding_mask.any():
padding_mask = torch.zeros(x.shape[:2], dtype=x.dtype, device=x.device)
                # these two operations make sure that all values
                # before the output length indices are attended to
padding_mask[
(
torch.arange(padding_mask.shape[0], device=padding_mask.device),
output_lengths - 1,
)
] = 1
padding_mask = (
1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])
).bool()
else:
padding_mask = torch.zeros(
x.shape[:2], dtype=torch.bool, device=x.device
)
return padding_mask
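    # Worked example for convert_padding_mask with the default feature_encoder_spec
    # "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]" (overall stride 320):
    #   raw lengths (16 kHz):  [16000, 8000]  ->  output lengths: [49, 24]
    #   so a (2, 16000) sample-level padding mask becomes a (2, 49) frame-level mask that is
    #   True only for frames beyond each example's valid output length.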
def reset_parameters(self):
super().reset_parameters()
for mod in self.project_features.children():
if isinstance(mod, nn.Linear):
mod.reset_parameters()
if self.decoder is not None:
self.decoder.reset_parameters()
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import namedtuple
from dataclasses import dataclass
from functools import partial
from omegaconf import MISSING, II
from typing import Optional, Callable
from fairseq.data.data_utils import compute_mask_indices
from fairseq.modules import GradMultiply
from fairseq.utils import index_put
from examples.data2vec.data.modality import Modality
from .modules import D2vDecoderConfig
logger = logging.getLogger(__name__)
@dataclass
class D2vModalityConfig:
type: Modality = MISSING
prenet_depth: int = 4
prenet_layerdrop: float = 0
prenet_dropout: float = 0
start_drop_path_rate: float = 0
end_drop_path_rate: float = 0
num_extra_tokens: int = 0
init_extra_token_zero: bool = True
mask_noise_std: float = 0.01
mask_prob_min: Optional[float] = None
mask_prob: float = 0.7
inverse_mask: bool = False
mask_prob_adjust: float = 0
keep_masked_pct: float = 0
mask_length: int = 5
add_masks: bool = False
remove_masks: bool = False
mask_dropout: float = 0.0
encoder_zero_mask: bool = True
mask_channel_prob: float = 0.0
mask_channel_length: int = 64
ema_local_encoder: bool = False # used in data2vec_multi
local_grad_mult: float = 1.0
use_alibi_encoder: bool = False
alibi_scale: float = 1.0
learned_alibi: bool = False
alibi_max_pos: Optional[int] = None
learned_alibi_scale: bool = False
learned_alibi_scale_per_head: bool = False
learned_alibi_scale_per_layer: bool = False
num_alibi_heads: int = II("model.num_heads")
model_depth: int = II("model.depth")
decoder: Optional[D2vDecoderConfig] = D2vDecoderConfig()
MaskSeed = namedtuple("MaskSeed", ["seed", "update", "ids"])
MaskInfo = namedtuple("MaskInfo", ["x_unmasked", "mask", "ids_restore", "ids_keep"])
class ModalitySpecificEncoder(nn.Module):
def __init__(
self,
modality_cfg: D2vModalityConfig,
embed_dim: int,
local_encoder: nn.Module,
project_features: nn.Module,
fixed_positional_encoder: Optional[nn.Module],
relative_positional_encoder: Optional[nn.Module],
context_encoder: nn.Module,
decoder: nn.Module,
get_alibi_bias: Optional[Callable[[int, int, str, str], torch.Tensor]],
):
super().__init__()
self.modality_cfg = modality_cfg
self.local_encoder = local_encoder
self.project_features = project_features
self.fixed_positional_encoder = fixed_positional_encoder
self.relative_positional_encoder = relative_positional_encoder
self.context_encoder = context_encoder
self.decoder = decoder
self.get_alibi_bias = get_alibi_bias if modality_cfg.use_alibi_encoder else None
self.local_grad_mult = self.modality_cfg.local_grad_mult
self.extra_tokens = None
if modality_cfg.num_extra_tokens > 0:
self.extra_tokens = nn.Parameter(
torch.zeros(1, modality_cfg.num_extra_tokens, embed_dim)
)
if not modality_cfg.init_extra_token_zero:
nn.init.normal_(self.extra_tokens)
elif self.extra_tokens.size(1) > 1:
nn.init.normal_(self.extra_tokens[:, 1:])
self.alibi_scale = None
if self.get_alibi_bias is not None:
self.alibi_scale = nn.Parameter(
torch.full(
(
(modality_cfg.prenet_depth + modality_cfg.model_depth)
if modality_cfg.learned_alibi_scale_per_layer
else 1,
1,
self.modality_cfg.num_alibi_heads
if modality_cfg.learned_alibi_scale_per_head
else 1,
1,
1,
),
modality_cfg.alibi_scale,
dtype=torch.float,
),
requires_grad=modality_cfg.learned_alibi_scale,
)
if modality_cfg.learned_alibi and self.get_alibi_bias is not None:
assert modality_cfg.alibi_max_pos is not None
alibi_bias = self.get_alibi_bias(
batch_size=1,
time_steps=modality_cfg.alibi_max_pos,
heads=modality_cfg.num_alibi_heads,
scale=1.0,
dtype=torch.float,
device="cpu",
)
self.alibi_bias = nn.Parameter(alibi_bias)
self.get_alibi_bias = partial(
_learned_alibi_bias, alibi_bias=self.alibi_bias
)
def upgrade_state_dict_named(self, state_dict, name):
k = f"{name}.alibi_scale"
if k in state_dict and state_dict[k].dim() == 4:
state_dict[k] = state_dict[k].unsqueeze(0)
return state_dict
def convert_padding_mask(self, x, padding_mask):
return padding_mask
def decoder_input(self, x, mask_info: MaskInfo):
inp_drop = self.modality_cfg.decoder.input_dropout
if inp_drop > 0:
x = F.dropout(x, inp_drop, training=self.training, inplace=True)
num_extra = self.modality_cfg.num_extra_tokens
if mask_info is not None:
num_masked = mask_info.ids_restore.shape[1] - x.shape[1] + num_extra
mask_tokens = x.new_empty(
x.size(0),
num_masked,
x.size(-1),
).normal_(0, self.modality_cfg.mask_noise_std)
x_ = torch.cat([x[:, num_extra:], mask_tokens], dim=1)
x = torch.gather(x_, dim=1, index=mask_info.ids_restore)
if self.modality_cfg.decoder.add_positions_masked:
assert self.fixed_positional_encoder is not None
pos = self.fixed_positional_encoder(x, None)
x = x + (pos * mask_info.mask.unsqueeze(-1))
else:
x = x[:, num_extra:]
if self.modality_cfg.decoder.add_positions_all:
assert self.fixed_positional_encoder is not None
x = x + self.fixed_positional_encoder(x, None)
return x, mask_info
def local_features(self, features):
if self.local_grad_mult > 0:
if self.local_grad_mult == 1.0:
x = self.local_encoder(features)
else:
x = GradMultiply.apply(
self.local_encoder(features), self.local_grad_mult
)
else:
with torch.no_grad():
x = self.local_encoder(features)
x = self.project_features(x)
return x
def contextualized_features(
self,
x,
padding_mask,
mask,
remove_masked,
clone_batch: int = 1,
mask_seeds: Optional[torch.Tensor] = None,
precomputed_mask=None,
):
if padding_mask is not None:
padding_mask = self.convert_padding_mask(x, padding_mask)
local_features = x
if mask and clone_batch == 1:
local_features = local_features.clone()
orig_B, orig_T, _ = x.shape
pre_mask_B = orig_B
mask_info = None
x_pos = None
if self.fixed_positional_encoder is not None:
x = x + self.fixed_positional_encoder(x, padding_mask)
if mask:
if clone_batch > 1:
x = x.repeat_interleave(clone_batch, 0)
if mask_seeds is not None:
clone_hash = [
int(hash((mask_seeds.seed, ind)) % 1e10)
for ind in range(clone_batch - 1)
]
clone_hash = torch.tensor([0] + clone_hash).long().view(1, -1)
id = mask_seeds.ids
id = id.repeat_interleave(clone_batch, 0)
id = id.view(-1, clone_batch) + clone_hash.to(id)
id = id.view(-1)
mask_seeds = MaskSeed(
seed=mask_seeds.seed, update=mask_seeds.update, ids=id
)
if padding_mask is not None:
padding_mask = padding_mask.repeat_interleave(clone_batch, 0)
x, mask_info = self.compute_mask(
x,
padding_mask,
mask_seed=mask_seeds,
apply=self.relative_positional_encoder is not None or not remove_masked,
precomputed_mask=precomputed_mask,
)
if self.relative_positional_encoder is not None:
x_pos = self.relative_positional_encoder(x)
masked_padding_mask = padding_mask
if mask and remove_masked:
x = mask_info.x_unmasked
if x_pos is not None:
x = x + gather_unmasked(x_pos, mask_info)
if padding_mask is not None and padding_mask.any():
masked_padding_mask = gather_unmasked_mask(padding_mask, mask_info)
if not masked_padding_mask.any():
masked_padding_mask = None
else:
masked_padding_mask = None
elif x_pos is not None:
x = x + x_pos
alibi_bias = None
alibi_scale = self.alibi_scale
if self.get_alibi_bias is not None:
alibi_bias = self.get_alibi_bias(
batch_size=pre_mask_B,
time_steps=orig_T,
heads=self.modality_cfg.num_alibi_heads,
dtype=torch.float32,
device=x.device,
)
if alibi_scale is not None:
alibi_scale = alibi_scale.clamp_min(0)
if alibi_scale.size(0) == 1:
alibi_bias = alibi_bias * alibi_scale.squeeze(0).type_as(alibi_bias)
alibi_scale = None
if clone_batch > 1:
alibi_bias = alibi_bias.repeat_interleave(clone_batch, 0)
if mask_info is not None and remove_masked:
alibi_bias = masked_alibi(alibi_bias, mask_info)
if self.extra_tokens is not None:
num = self.extra_tokens.size(1)
x = torch.cat([self.extra_tokens.expand(x.size(0), -1, -1), x], dim=1)
if masked_padding_mask is not None:
# B x T
masked_padding_mask = F.pad(masked_padding_mask, (num, 0))
if alibi_bias is not None:
# B x H x T x T
alibi_bias = F.pad(alibi_bias, (num, 0, num, 0))
x = self.context_encoder(
x,
masked_padding_mask,
alibi_bias,
alibi_scale[: self.modality_cfg.prenet_depth]
if alibi_scale is not None
else None,
)
return {
"x": x,
"local_features": local_features,
"padding_mask": masked_padding_mask,
"alibi_bias": alibi_bias,
"alibi_scale": alibi_scale[self.modality_cfg.prenet_depth :]
if alibi_scale is not None and alibi_scale.size(0) > 1
else alibi_scale,
"encoder_mask": mask_info,
}
def forward(
self,
features,
padding_mask,
mask: bool,
remove_masked: bool,
clone_batch: int = 1,
mask_seeds: Optional[torch.Tensor] = None,
precomputed_mask=None,
):
x = self.local_features(features)
return self.contextualized_features(
x,
padding_mask,
mask,
remove_masked,
clone_batch,
mask_seeds,
precomputed_mask,
)
def reset_parameters(self):
pass
def compute_mask(
self,
x,
padding_mask,
mask_seed: Optional[MaskSeed],
apply,
precomputed_mask,
):
if precomputed_mask is not None:
mask = precomputed_mask
mask_info = self.make_maskinfo(x, mask)
else:
B, T, C = x.shape
cfg = self.modality_cfg
mask_prob = cfg.mask_prob
if (
cfg.mask_prob_min is not None
and cfg.mask_prob_min >= 0
and cfg.mask_prob_min < mask_prob
):
mask_prob = np.random.uniform(cfg.mask_prob_min, mask_prob)
if mask_prob > 0:
if cfg.mask_length == 1:
mask_info = random_masking(x, mask_prob, mask_seed)
else:
if self.modality_cfg.inverse_mask:
mask_prob = 1 - mask_prob
mask = compute_mask_indices(
(B, T),
padding_mask,
mask_prob,
cfg.mask_length,
min_masks=1,
require_same_masks=True,
mask_dropout=cfg.mask_dropout,
add_masks=cfg.add_masks,
seed=mask_seed.seed if mask_seed is not None else None,
epoch=mask_seed.update if mask_seed is not None else None,
indices=mask_seed.ids if mask_seed is not None else None,
)
mask = torch.from_numpy(mask).to(device=x.device)
if self.modality_cfg.inverse_mask:
mask = 1 - mask
mask_info = self.make_maskinfo(x, mask)
else:
mask_info = None
if apply:
x = self.apply_mask(x, mask_info)
return x, mask_info
def make_maskinfo(self, x, mask, shape=None):
if shape is None:
B, T, D = x.shape
else:
B, T, D = shape
mask = mask.to(torch.uint8)
ids_shuffle = mask.argsort(dim=1)
ids_restore = ids_shuffle.argsort(dim=1).unsqueeze(-1).expand(-1, -1, D)
len_keep = T - mask[0].sum()
if self.modality_cfg.keep_masked_pct > 0:
len_keep += round((T - int(len_keep)) * self.modality_cfg.keep_masked_pct)
ids_keep = ids_shuffle[:, :len_keep]
if shape is not None:
x_unmasked = None
else:
ids_keep = ids_keep.unsqueeze(-1).expand(-1, -1, D)
x_unmasked = torch.gather(x, dim=1, index=ids_keep)
mask_info = MaskInfo(
x_unmasked=x_unmasked,
mask=mask,
ids_restore=ids_restore,
ids_keep=ids_keep,
)
return mask_info
def apply_mask(self, x, mask_info):
cfg = self.modality_cfg
B, T, C = x.shape
if mask_info is not None:
mask = mask_info.mask
if cfg.encoder_zero_mask:
x = x * (1 - mask.type_as(x).unsqueeze(-1))
else:
num_masks = mask.sum().item()
masks = x.new_empty(num_masks, x.size(-1)).normal_(
0, cfg.mask_noise_std
)
x = index_put(x, mask, masks)
if cfg.mask_channel_prob > 0:
mask_channel = compute_mask_indices(
(B, C),
None,
cfg.mask_channel_prob,
cfg.mask_channel_length,
)
mask_channel = (
torch.from_numpy(mask_channel)
.to(x.device)
.unsqueeze(1)
.expand(-1, T, -1)
)
x = index_put(x, mask_channel, 0)
return x
def remove_pretraining_modules(self, keep_decoder=False):
if not keep_decoder:
self.decoder = None
def get_annealed_rate(start, end, curr_step, total_steps):
if curr_step >= total_steps:
return end
r = end - start
pct_remaining = 1 - curr_step / total_steps
return end - r * pct_remaining
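# Worked example: get_annealed_rate interpolates linearly from `start` to `end`, e.g. for
# an EMA decay schedule (illustrative numbers):
#   get_annealed_rate(0.999, 0.9999, 0,       100_000) -> 0.999
#   get_annealed_rate(0.999, 0.9999, 50_000,  100_000) -> 0.99945
#   get_annealed_rate(0.999, 0.9999, 100_000, 100_000) -> 0.9999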
# adapted from MAE
def random_masking(x, mask_ratio, mask_seed: Optional[MaskSeed]):
N, L, D = x.shape # batch, length, dim
len_keep = int(L * (1 - mask_ratio))
generator = None
if mask_seed is not None:
seed = int(
hash((mask_seed.seed, mask_seed.update, mask_seed.ids.sum().item())) % 1e6
)
generator = torch.Generator(device=x.device)
generator.manual_seed(seed)
noise = torch.rand(N, L, generator=generator, device=x.device) # noise in [0, 1]
# sort noise for each sample
ids_shuffle = noise.argsort(dim=1) # ascend: small is keep, large is remove
ids_restore = ids_shuffle.argsort(dim=1)
# keep the first subset
ids_keep = ids_shuffle[:, :len_keep]
ids_keep = ids_keep.unsqueeze(-1).expand(-1, -1, D)
x_unmasked = torch.gather(x, dim=1, index=ids_keep)
# generate the binary mask: 0 is keep, 1 is remove
mask = torch.ones([N, L], dtype=x.dtype, device=x.device)
mask[:, :len_keep] = 0
# unshuffle to get the binary mask
mask = torch.gather(mask, dim=1, index=ids_restore)
ids_restore = ids_restore.unsqueeze(-1).expand(-1, -1, D)
return MaskInfo(
x_unmasked=x_unmasked, mask=mask, ids_restore=ids_restore, ids_keep=ids_keep
)
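# Sketch of the MaskInfo produced above (assumed shapes, mask_ratio=0.8):
#   x:                (2, 196, 768)
#   info.x_unmasked:  (2, 39, 768)    int(196 * 0.2) = 39 visible tokens per sample
#   info.mask:        (2, 196)        0 = keep, 1 = remove, original order
#   info.ids_keep:    (2, 39, 768)    gather index for the visible tokens
#   info.ids_restore: (2, 196, 768)   gather index that undoes the shuffle in the decoder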
def gather_unmasked(x: torch.Tensor, mask_info: MaskInfo) -> torch.Tensor:
return torch.gather(
x,
dim=1,
index=mask_info.ids_keep,
)
def gather_unmasked_mask(x: torch.Tensor, mask_info: MaskInfo) -> torch.Tensor:
return torch.gather(
x,
dim=1,
index=mask_info.ids_keep[..., 0], # ignore the feature dimension
)
def get_alibi(
max_positions: int,
attention_heads: int,
dims: int = 1,
distance: str = "manhattan",
):
def get_slopes(n):
def get_slopes_power_of_2(n):
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
ratio = start
return [start * ratio**i for i in range(n)]
# In the paper, we only train models that have 2^a heads for some
# a. This function has some good properties that only occur when
# the input is a power of 2. To maintain that even when the number
# of heads is not a power of 2, we use this workaround.
if math.log2(n).is_integer():
return get_slopes_power_of_2(n)
else:
closest_power_of_2 = 2 ** math.floor(math.log2(n))
return (
get_slopes_power_of_2(closest_power_of_2)
+ get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
)
maxpos = max_positions
attn_heads = attention_heads
slopes = torch.Tensor(get_slopes(attn_heads))
if dims == 1:
        # prepare alibi position linear bias. Note that wav2vec2 is a
        # non-autoregressive model, so we want a symmetric mask with 0 on the
        # diagonal and otherwise linearly decreasing values
pos_bias = (
torch.abs(
torch.arange(maxpos).unsqueeze(0) - torch.arange(maxpos).unsqueeze(1)
)
* -1
)
elif dims == 2:
if distance == "manhattan":
df = lambda x1, y1, x2, y2: abs(x1 - x2) + abs(y1 - y2)
elif distance == "euclidean":
df = lambda x1, y1, x2, y2: math.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
n = math.sqrt(max_positions)
assert n.is_integer(), n
n = int(n)
pos_bias = torch.zeros((max_positions, max_positions))
for i in range(n):
for j in range(n):
for k in range(n):
for l in range(n):
new_x = i * n + j
new_y = k * n + l
pos_bias[new_x, new_y] = -df(i, j, k, l)
else:
raise Exception(f"unsupported number of alibi dims: {dims}")
alibi_bias = slopes.unsqueeze(1).unsqueeze(1) * pos_bias.unsqueeze(0).expand(
attn_heads, -1, -1
)
return alibi_bias
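# Shape sketch for get_alibi (illustrative arguments):
#   get_alibi(512, 8)                               -> (8, 512, 512), bias = -slope_h * |i - j|
#   get_alibi(196, 8, dims=2, distance="manhattan") -> (8, 196, 196) for a 14x14 patch grid,
#                                                      bias = -slope_h * grid distance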
def get_alibi_bias(
alibi_biases,
batch_size,
time_steps,
heads,
dtype,
device,
dims=1,
distance="manhattan",
):
cache_key = f"{dims}_{heads}_{distance}"
buffered = alibi_biases.get(cache_key, None)
target_size = heads * batch_size
if (
buffered is None
or buffered.size(0) < target_size
or buffered.size(1) < time_steps
or buffered.dtype != dtype
or buffered.device != device
):
bt = max(time_steps, buffered.size(1) if buffered is not None else 0)
bn = max(target_size, buffered.size(0) if buffered is not None else 0) // heads
buffered = (
get_alibi(bt, heads, dims=dims, distance=distance)
.to(dtype=dtype, device=device)
.repeat(bn, 1, 1)
)
alibi_biases[cache_key] = buffered
b = buffered[:target_size, :time_steps, :time_steps]
b = b.view(batch_size, heads, time_steps, time_steps)
return b
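# Usage sketch (the encoders call this through functools.partial with a shared cache dict):
#   cache = {}
#   bias = get_alibi_bias(cache, batch_size=4, time_steps=196, heads=8,
#                         dtype=torch.float32, device=torch.device("cpu"),
#                         dims=2, distance="manhattan")
#   # bias: (4, 8, 196, 196); the (heads * batch, T, T) buffer is stored under "2_8_manhattan"
#   # and is only rebuilt when a larger batch/time or a different dtype/device is requested.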
def _learned_alibi_bias(
alibi_bias,
batch_size,
time_steps,
heads,
scale,
dtype,
device,
):
assert alibi_bias.size(1) == heads, alibi_bias.shape
assert alibi_bias.dtype == dtype, alibi_bias.dtype
assert alibi_bias.device == device, alibi_bias.device
if alibi_bias.size(-1) < time_steps:
psz = math.ceil((time_steps - alibi_bias.size(-1)) / 2)
alibi_bias = F.pad(alibi_bias, (psz, psz, psz, psz), mode="replicate")
alibi_bias = alibi_bias.expand(batch_size, -1, -1, -1) * scale
return alibi_bias[..., :time_steps, :time_steps]
def masked_alibi(alibi_bias, mask_info):
H = alibi_bias.size(1)
orig_bias = alibi_bias
index = mask_info.ids_keep.unsqueeze(1)[..., 0].unsqueeze(-1)
alibi_bias = torch.gather(
orig_bias,
dim=-2,
index=index.expand(-1, H, -1, mask_info.ids_restore.size(1)),
)
alibi_bias = torch.gather(
alibi_bias,
dim=-1,
index=index.transpose(-1, -2).expand(-1, H, alibi_bias.size(-2), -1),
)
return alibi_bias
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from functools import partial
from dataclasses import dataclass
from typing import Callable, Dict, Optional
from timm.models.layers import to_2tuple
from fairseq.tasks import FairseqTask
from examples.data2vec.models.mae import get_2d_sincos_pos_embed, PatchEmbed
from .base import (
D2vModalityConfig,
ModalitySpecificEncoder,
get_alibi_bias,
MaskSeed,
)
from .modules import (
BlockEncoder,
Decoder2d,
FixedPositionalEncoder,
TransformerDecoder,
EncDecTransformerDecoder,
)
from examples.data2vec.data.modality import Modality
@dataclass
class D2vImageConfig(D2vModalityConfig):
type: Modality = Modality.IMAGE
input_size: int = 224
in_chans: int = 3
patch_size: int = 16
embed_dim: int = 768
alibi_dims: int = 2
alibi_distance: str = "manhattan"
fixed_positions: bool = True
transformer_decoder: bool = False
enc_dec_transformer: bool = False
class ImageEncoder(ModalitySpecificEncoder):
modality_cfg: D2vImageConfig
def __init__(
self,
modality_cfg: D2vImageConfig,
embed_dim: int,
make_block: Callable[[float, Optional[int], Optional[int]], nn.ModuleList],
norm_layer: Callable[[int], nn.LayerNorm],
layer_norm_first: bool,
alibi_biases: Dict,
task: Optional[FairseqTask],
):
img_size = to_2tuple(modality_cfg.input_size)
patch_size = to_2tuple(modality_cfg.patch_size)
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
local_encoder = PatchEmbed(
modality_cfg.input_size,
modality_cfg.patch_size,
modality_cfg.in_chans,
modality_cfg.embed_dim,
)
w = local_encoder.proj.weight.data
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
if modality_cfg.embed_dim != embed_dim:
local_encoder = nn.Sequential(
local_encoder,
nn.Linear(modality_cfg.embed_dim, embed_dim),
)
project_features = nn.Identity()
pos_embed = nn.Parameter(
torch.zeros(1, num_patches, embed_dim), requires_grad=False
)
side_n = int(num_patches ** 0.5)
emb = get_2d_sincos_pos_embed(
pos_embed.shape[-1],
side_n,
cls_token=False,
)
pos_embed.data.copy_(torch.from_numpy(emb).float().unsqueeze(0))
fixed_positional_encoder = (
FixedPositionalEncoder(pos_embed) if modality_cfg.fixed_positions else None
)
dpr = np.linspace(
modality_cfg.start_drop_path_rate,
modality_cfg.end_drop_path_rate,
modality_cfg.prenet_depth,
)
context_encoder = BlockEncoder(
nn.ModuleList(make_block(dpr[i]) for i in range(modality_cfg.prenet_depth)),
norm_layer(embed_dim) if not layer_norm_first else None,
layer_norm_first,
modality_cfg.prenet_layerdrop,
modality_cfg.prenet_dropout,
)
if modality_cfg.transformer_decoder:
if modality_cfg.enc_dec_transformer:
decoder = EncDecTransformerDecoder(modality_cfg.decoder, embed_dim)
else:
dec_enc = BlockEncoder(
nn.ModuleList(
make_block(0, modality_cfg.decoder.decoder_dim, 8)
for _ in range(modality_cfg.decoder.decoder_layers)
),
None,
layer_norm_first,
0,
0,
)
decoder = TransformerDecoder(modality_cfg.decoder, embed_dim, dec_enc)
else:
decoder = (
Decoder2d(modality_cfg.decoder, embed_dim, side_n, side_n)
if modality_cfg.decoder is not None
else None
)
alibi_bias_fn = partial(
get_alibi_bias,
alibi_biases=alibi_biases,
heads=modality_cfg.num_alibi_heads,
dims=modality_cfg.alibi_dims,
distance=modality_cfg.alibi_distance,
)
super().__init__(
modality_cfg=modality_cfg,
embed_dim=embed_dim,
local_encoder=local_encoder,
project_features=project_features,
fixed_positional_encoder=fixed_positional_encoder,
relative_positional_encoder=None,
context_encoder=context_encoder,
decoder=decoder,
get_alibi_bias=alibi_bias_fn,
)
def reset_parameters(self):
super().reset_parameters()
if self.decoder is not None:
self.decoder.reset_parameters()
@torch.no_grad()
def patchify(self, imgs):
"""
imgs: (N, 3, H, W)
x: (N, L, patch_size**2 *3)
"""
p = self.modality_cfg.patch_size
h = w = imgs.shape[2] // p
x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
x = torch.einsum("nchpwq->nhwpqc", x)
x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3))
return x
@torch.no_grad()
def unpatchify(self, x):
"""
x: (N, L, patch_size**2 *3)
imgs: (N, 3, H, W)
"""
p = self.modality_cfg.patch_size
h = w = int(x.shape[1] ** 0.5)
assert h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
x = torch.einsum("nhwpqc->nchpwq", x)
imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
return imgs
def compute_mask(
self,
x,
padding_mask,
mask_seed: Optional[MaskSeed],
apply,
shape=None,
precomputed_mask=None,
):
mlen = self.modality_cfg.mask_length
if mlen <= 1:
return super().compute_mask(
x, padding_mask, mask_seed, apply, precomputed_mask
)
if precomputed_mask is not None:
mask = precomputed_mask
else:
from fairseq.data.data_utils import compute_block_mask_2d
if shape is not None:
B, L, D = shape
else:
B, L, D = x.shape
mask = compute_block_mask_2d(
shape=(B, L),
mask_prob=self.modality_cfg.mask_prob,
mask_length=self.modality_cfg.mask_length,
mask_prob_adjust=self.modality_cfg.mask_prob_adjust,
inverse_mask=self.modality_cfg.inverse_mask,
require_same_masks=True,
mask_dropout=self.modality_cfg.mask_dropout,
)
mask_info = self.make_maskinfo(x, mask, shape)
if apply:
x = self.apply_mask(x, mask_info)
return x, mask_info
def decoder_input(self, x, mask_info):
if (
not self.modality_cfg.transformer_decoder
or not self.modality_cfg.enc_dec_transformer
):
return super().decoder_input(x, mask_info)
inp_drop = self.modality_cfg.decoder.input_dropout
if inp_drop > 0:
x = F.dropout(x, inp_drop, training=self.training, inplace=True)
kv = x[:, self.modality_cfg.num_extra_tokens :]
assert self.fixed_positional_encoder is not None
pos = self.fixed_positional_encoder(x, None).expand(x.size(0), -1, -1)
mask = mask_info.mask.bool()
if self.modality_cfg.decoder.add_positions_all:
kv = kv + pos[~mask].view(kv.shape)
q = pos[mask].view(x.size(0), -1, x.size(-1))
return q, kv
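    # Shape sketch for the enc-dec decoder input above (assumed: 196 patches, ~80% masked,
    # no extra tokens):
    #   x  -> (B, 39, D)   visible-token features coming out of the encoder
    #   kv -> (B, 39, D)   keys/values for cross-attention
    #   q  -> (B, 157, D)  fixed positional embeddings of the masked positions (the queries
    #                      that EncDecTransformerDecoder decodes into target predictions)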
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from dataclasses import dataclass
from fairseq.modules import (
LayerNorm,
SamePad,
SamePad2d,
TransposeLast,
)
@dataclass
class D2vDecoderConfig:
decoder_dim: int = 384
decoder_groups: int = 16
decoder_kernel: int = 5
decoder_layers: int = 5
input_dropout: float = 0.1
add_positions_masked: bool = False
add_positions_all: bool = False
decoder_residual: bool = True
projection_layers: int = 1
projection_ratio: float = 2.0
class FixedPositionalEncoder(nn.Module):
def __init__(self, pos_embed):
super().__init__()
self.positions = pos_embed
def forward(self, x, padding_mask):
return self.positions
class TextFeatPositionalEncoder(nn.Module):
"""
    The original encoder expects a (B, T) long input. This module wraps it so it can take
    the local_encoder output, which is a (B, T, D) float tensor
"""
def __init__(self, pos_encoder):
super().__init__()
self.pos_encoder = pos_encoder
def forward(self, x, padding_mask):
# assume padded token embeddings are 0s
# TODO: consider using padding_mask as input
return self.pos_encoder(x[..., 0])
class BlockEncoder(nn.Module):
def __init__(self, blocks, norm_layer, layer_norm_first, layerdrop, dropout):
super().__init__()
self.blocks = blocks
self.norm = norm_layer
self.layer_norm_first = layer_norm_first
self.layerdrop = layerdrop
self.dropout = nn.Dropout(dropout, inplace=True)
def forward(self, x, padding_mask, alibi_bias, alibi_scale):
if self.norm is not None and not self.layer_norm_first:
x = self.norm(x)
x = self.dropout(x)
for i, blk in enumerate(self.blocks):
if (
not self.training
or self.layerdrop == 0
or (np.random.random() > self.layerdrop)
):
ab = alibi_bias
if ab is not None and alibi_scale is not None:
scale = (
alibi_scale[i]
if alibi_scale.size(0) > 1
else alibi_scale.squeeze(0)
)
ab = ab * scale.type_as(ab)
x, _ = blk(x, padding_mask, ab)
if self.norm is not None and self.layer_norm_first:
x = self.norm(x)
return x
class DecoderBase(nn.Module):
decoder_cfg: D2vDecoderConfig
def __init__(self, cfg: D2vDecoderConfig):
super().__init__()
self.decoder_cfg = cfg
def reset_parameters(self):
for mod in self.proj.modules():
if isinstance(mod, nn.Linear):
mod.reset_parameters()
def add_residual(self, x, residual, i, mask_info):
if (
residual is None
or not self.decoder_cfg.decoder_residual
or residual.size(1) != x.size(1)
):
return x
ret = x + residual
return ret
class Decoder1d(DecoderBase):
def __init__(self, cfg: D2vDecoderConfig, input_dim):
super().__init__(cfg)
def make_block(in_dim):
block = [
nn.Conv1d(
in_dim,
cfg.decoder_dim,
kernel_size=cfg.decoder_kernel,
padding=cfg.decoder_kernel // 2,
groups=cfg.decoder_groups,
),
SamePad(cfg.decoder_kernel),
TransposeLast(),
LayerNorm(cfg.decoder_dim, elementwise_affine=False),
TransposeLast(),
nn.GELU(),
]
return nn.Sequential(*block)
self.blocks = nn.Sequential(
*[
make_block(input_dim if i == 0 else cfg.decoder_dim)
for i in range(cfg.decoder_layers)
]
)
projs = []
curr_dim = cfg.decoder_dim
for i in range(cfg.projection_layers - 1):
next_dim = int(curr_dim * cfg.projection_ratio) if i == 0 else curr_dim
projs.append(nn.Linear(curr_dim, next_dim))
projs.append(nn.GELU())
curr_dim = next_dim
projs.append(nn.Linear(curr_dim, input_dim))
if len(projs) == 1:
self.proj = projs[0]
else:
self.proj = nn.Sequential(*projs)
def forward(self, x, mask_info):
x = x.transpose(1, 2)
residual = x
for i, layer in enumerate(self.blocks):
x = layer(x)
x = self.add_residual(x, residual, i, mask_info)
residual = x
x = x.transpose(1, 2)
x = self.proj(x)
return x
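    # Shape sketch for Decoder1d (assumed config: decoder_dim=384, input_dim=768):
    #   x: (B, T, 768) -> transpose -> (B, 768, T)
    #   each block: grouped Conv1d (768->384 for the first block, then 384->384) + LN + GELU,
    #               with a residual added whenever the channel counts match
    #   transpose back -> (B, T, 384) -> self.proj -> (B, T, 768)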
class Decoder2d(DecoderBase):
def __init__(self, cfg: D2vDecoderConfig, input_dim, h_size, w_size):
super().__init__(cfg)
self.h_size = h_size
self.w_size = w_size
def make_block(in_dim):
block = [
nn.Conv2d(
in_dim,
cfg.decoder_dim,
kernel_size=cfg.decoder_kernel,
padding=cfg.decoder_kernel // 2,
groups=cfg.decoder_groups,
),
SamePad2d(cfg.decoder_kernel),
TransposeLast(tranpose_dim=-3),
LayerNorm(cfg.decoder_dim, elementwise_affine=False),
TransposeLast(tranpose_dim=-3),
nn.GELU(),
]
return nn.Sequential(*block)
self.blocks = nn.Sequential(
*[
make_block(input_dim if i == 0 else cfg.decoder_dim)
for i in range(cfg.decoder_layers)
]
)
self.proj = nn.Linear(cfg.decoder_dim, input_dim)
def forward(self, x, mask_info):
B, T, C = x.shape
x = x.transpose(1, 2).reshape(B, C, self.h_size, self.w_size)
residual = x
for i, layer in enumerate(self.blocks):
x = layer(x)
x = self.add_residual(x, residual, i, mask_info)
residual = x
x = x.reshape(B, -1, T).transpose(1, 2)
x = self.proj(x)
return x
class TransformerDecoder(nn.Module):
decoder_cfg: D2vDecoderConfig
def __init__(self, cfg: D2vDecoderConfig, input_dim, encoder):
super().__init__()
self.decoder_cfg = cfg
self.input_proj = nn.Linear(input_dim, cfg.decoder_dim)
self.encoder = encoder
self.proj = nn.Linear(cfg.decoder_dim, input_dim)
def reset_parameters(self):
from fairseq.modules.transformer_sentence_encoder import init_bert_params
self.apply(init_bert_params)
def forward(self, x, mask_info):
x = self.input_proj(x)
x = self.encoder(x, None, None, 1)
x = self.proj(x)
return x
class AltBlock(nn.Module):
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.0,
qkv_bias=False,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
mlp_drop=0.0,
post_mlp_drop=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
layer_norm_first=True,
ffn_targets=False,
cosine_attention=False,
):
super().__init__()
self.layer_norm_first = layer_norm_first
self.ffn_targets = ffn_targets
from timm.models.vision_transformer import DropPath, Mlp
self.norm1 = norm_layer(dim)
self.attn = AltAttention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
cosine_attention=cosine_attention,
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=mlp_drop,
)
self.post_mlp_dropout = nn.Dropout(post_mlp_drop, inplace=False)
def forward(self, x, padding_mask=None, alibi_bias=None):
if self.layer_norm_first:
x = x + self.drop_path(self.attn(self.norm1(x), padding_mask, alibi_bias))
r = x = self.mlp(self.norm2(x))
t = x
x = r + self.drop_path(self.post_mlp_dropout(x))
if not self.ffn_targets:
t = x
else:
x = x + self.drop_path(self.attn(x, padding_mask, alibi_bias))
r = x = self.norm1(x)
x = self.mlp(x)
t = x
x = self.norm2(r + self.drop_path(self.post_mlp_dropout(x)))
if not self.ffn_targets:
t = x
return x, t
class AltAttention(nn.Module):
def __init__(
self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
cosine_attention=False,
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.cosine_attention = cosine_attention
if cosine_attention:
self.logit_scale = nn.Parameter(
torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True
)
def forward(self, x, padding_mask=None, alibi_bias=None):
B, N, C = x.shape
qkv = (
self.qkv(x)
.reshape(B, N, 3, self.num_heads, C // self.num_heads)
.permute(2, 0, 3, 1, 4) # qkv x B x H x L x D
)
q, k, v = (
qkv[0],
qkv[1],
qkv[2],
) # make torchscript happy (cannot use tensor as tuple)
dtype = q.dtype
if self.cosine_attention:
# cosine attention
attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
logit_scale = torch.clamp(
self.logit_scale, max=torch.log(torch.tensor(1.0 / 0.01))
).exp()
attn = attn * logit_scale
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1)
if alibi_bias is not None:
attn = attn.type_as(alibi_bias)
attn[:, : alibi_bias.size(1)] += alibi_bias
if padding_mask is not None and padding_mask.any():
attn = attn.masked_fill(
padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
float("-inf"),
)
attn = attn.softmax(dim=-1, dtype=torch.float32).to(dtype=dtype)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2) #
x = x.reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class EncDecAttention(nn.Module):
def __init__(
self,
q_dim,
kv_dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
cosine_attention=False,
):
super().__init__()
self.num_heads = num_heads
head_dim = q_dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.q_proj = nn.Linear(q_dim, q_dim, bias=qkv_bias)
self.kv_proj = nn.Linear(kv_dim, 2 * q_dim, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(q_dim, q_dim)
self.proj_drop = nn.Dropout(proj_drop)
self.cosine_attention = cosine_attention
if cosine_attention:
self.logit_scale = nn.Parameter(
torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True
)
def forward(self, q, kv, padding_mask=None, alibi_bias=None):
B, N, C = q.shape
q = (
self.q_proj(q)
.reshape(B, N, self.num_heads, C // self.num_heads)
.permute(0, 2, 1, 3)
) # B x H x L x D
kv = (
self.kv_proj(kv)
.reshape(B, -1, 2, self.num_heads, C // self.num_heads)
.permute(2, 0, 3, 1, 4)
) # kv x B x H x L x D
k, v = (
kv[0],
kv[1],
) # make torchscript happy (cannot use tensor as tuple)
dtype = q.dtype
if self.cosine_attention:
# cosine attention
attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
logit_scale = torch.clamp(
self.logit_scale, max=torch.log(torch.tensor(1.0 / 0.01))
).exp()
attn = attn * logit_scale
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1)
if alibi_bias is not None:
attn = attn.type_as(alibi_bias)
attn[:, : alibi_bias.size(1)] += alibi_bias
if padding_mask is not None and padding_mask.any():
attn = attn.masked_fill(
padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
float("-inf"),
)
attn = attn.softmax(dim=-1, dtype=torch.float32).to(dtype=dtype)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2) #
x = x.reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class EncDecBlock(nn.Module):
def __init__(
self,
q_dim,
kv_dim,
num_heads,
mlp_ratio=4.0,
qkv_bias=False,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
mlp_drop=0.0,
post_mlp_drop=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
layer_norm_first=True,
cosine_attention=False,
first_residual=True,
):
super().__init__()
self.layer_norm_first = layer_norm_first
from timm.models.vision_transformer import DropPath, Mlp
self.norm1 = norm_layer(q_dim)
self.attn = EncDecAttention(
q_dim,
kv_dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
cosine_attention=cosine_attention,
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(q_dim)
mlp_hidden_dim = int(q_dim * mlp_ratio)
self.mlp = Mlp(
in_features=q_dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=mlp_drop,
)
self.post_mlp_dropout = nn.Dropout(post_mlp_drop, inplace=False)
self.first_residual = first_residual
def forward(self, q, kv, padding_mask=None, alibi_bias=None):
r = q if self.first_residual else 0
if self.layer_norm_first:
x = r + self.drop_path(
self.attn(self.norm1(q), kv, padding_mask, alibi_bias)
)
r = x = self.mlp(self.norm2(x))
x = r + self.drop_path(self.post_mlp_dropout(x))
else:
x = r + self.drop_path(self.attn(q, kv, padding_mask, alibi_bias))
r = x = self.norm1(x)
x = self.mlp(x)
x = self.norm2(r + self.drop_path(self.post_mlp_dropout(x)))
return x
class EncDecTransformerDecoder(nn.Module):
def __init__(self, cfg: D2vDecoderConfig, input_dim):
super().__init__()
self.input_proj = nn.Linear(input_dim, cfg.decoder_dim)
self.blocks = nn.Sequential(
*[
EncDecBlock(
q_dim=cfg.decoder_dim,
kv_dim=input_dim,
num_heads=8,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
mlp_drop=0.0,
post_mlp_drop=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
layer_norm_first=False,
cosine_attention=False,
first_residual=i > 0,
)
for i in range(cfg.decoder_layers)
]
)
self.proj = nn.Linear(cfg.decoder_dim, input_dim)
def reset_parameters(self):
from fairseq.modules.transformer_sentence_encoder import init_bert_params
self.apply(init_bert_params)
def forward(self, x, kv):
x = self.input_proj(x)
for i, layer in enumerate(self.blocks):
x = layer(x, kv)
x = self.proj(x)
return x
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass
from functools import partial
from typing import Callable, Dict, Optional
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from fairseq.modules import PositionalEmbedding, FairseqDropout, LayerNorm
from fairseq.tasks import FairseqTask
from .base import D2vModalityConfig, ModalitySpecificEncoder, get_alibi_bias
from .modules import BlockEncoder, Decoder1d
from examples.data2vec.data.modality import Modality
@dataclass
class D2vTextConfig(D2vModalityConfig):
type: Modality = Modality.TEXT
max_source_positions: int = 512
learned_pos: bool = True
dropout: float = 0.1 # used for both local_encoder and contextualized encoder. tied with global transformer in data2vec_text
no_scale_embedding: bool = True
layernorm_embedding: bool = True
no_token_positional_embeddings: bool = False
class TextEncoder(ModalitySpecificEncoder):
modality_cfg: D2vTextConfig
def __init__(
self,
modality_cfg: D2vTextConfig,
embed_dim: int,
make_block: Callable[[float], nn.ModuleList],
norm_layer: Callable[[int], nn.LayerNorm],
layer_norm_first: bool,
alibi_biases: Dict,
task: Optional[FairseqTask],
):
self.pad_idx = task.source_dictionary.pad()
self.vocab_size = len(task.source_dictionary)
local_encoder = TextLocalEncoder(
vocab_size=self.vocab_size,
embed_dim=embed_dim,
max_source_positions=modality_cfg.max_source_positions,
pad_idx=self.pad_idx,
no_scale_embedding=modality_cfg.no_scale_embedding,
layernorm_embedding=modality_cfg.layernorm_embedding,
dropout=modality_cfg.dropout,
no_token_positional_embeddings=modality_cfg.no_token_positional_embeddings,
learned_pos=modality_cfg.learned_pos,
)
dpr = np.linspace(
modality_cfg.start_drop_path_rate,
modality_cfg.end_drop_path_rate,
modality_cfg.prenet_depth,
)
context_encoder = BlockEncoder(
nn.ModuleList(make_block(dpr[i]) for i in range(modality_cfg.prenet_depth)),
norm_layer(embed_dim)
if not layer_norm_first and modality_cfg.prenet_depth > 0
else None,
layer_norm_first,
modality_cfg.prenet_layerdrop,
modality_cfg.prenet_dropout if modality_cfg.prenet_depth > 0 else 0.0,
)
decoder = (
Decoder1d(modality_cfg.decoder, embed_dim)
if modality_cfg.decoder is not None
else None
)
alibi_bias_fn = partial(get_alibi_bias, alibi_biases=alibi_biases)
super().__init__(
modality_cfg=modality_cfg,
embed_dim=embed_dim,
local_encoder=local_encoder,
project_features=nn.Identity(),
fixed_positional_encoder=None,
relative_positional_encoder=None,
context_encoder=context_encoder,
decoder=decoder,
get_alibi_bias=alibi_bias_fn,
)
def reset_parameters(self):
super().reset_parameters()
def convert_padding_mask(self, x, padding_mask):
if padding_mask is None or padding_mask.size(1) == x.size(1):
return padding_mask
diff = self.downsample - padding_mask.size(1) % self.downsample
if 0 < diff < self.downsample:
padding_mask = F.pad(padding_mask, (0, diff), value=True)
padding_mask = padding_mask.view(padding_mask.size(0), -1, self.downsample)
padding_mask = padding_mask.all(-1)
if padding_mask.size(1) > x.size(1):
padding_mask = padding_mask[:, : x.size(1)]
assert x.size(1) == padding_mask.size(
1
), f"{x.size(1), padding_mask.size(1), diff, self.downsample}"
return padding_mask
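# A worked example of the padding-mask downsampling above, assuming a
# downsample factor of 2 (the real value comes from the base
# ModalitySpecificEncoder). A group of timesteps counts as padding only if
# all of its members are padding.
def _padding_mask_downsample_sketch():
    import torch
    import torch.nn.functional as F

    downsample = 2
    padding_mask = torch.tensor([[False, False, False, True, True]])  # True = pad, length 5
    diff = downsample - padding_mask.size(1) % downsample
    if 0 < diff < downsample:
        padding_mask = F.pad(padding_mask, (0, diff), value=True)  # pad to length 6
    padding_mask = padding_mask.view(padding_mask.size(0), -1, downsample).all(-1)
    # groups: [F,F] [F,T] [T,T] -> [False, False, True]
    assert padding_mask.tolist() == [[False, False, True]]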
class TextLocalEncoder(nn.Module):
def __init__(
self,
vocab_size,
embed_dim,
max_source_positions,
pad_idx,
no_scale_embedding,
layernorm_embedding,
dropout,
no_token_positional_embeddings,
learned_pos,
):
super().__init__()
self.pad_idx = pad_idx
self.dropout_module = FairseqDropout(dropout)
self.embed_tokens = nn.Embedding(vocab_size, embed_dim, pad_idx)
self.embed_scale = 1.0 if no_scale_embedding else math.sqrt(embed_dim)
self.embed_positions = (
PositionalEmbedding(
max_source_positions,
embed_dim,
pad_idx,
learned=learned_pos,
)
if not no_token_positional_embeddings
else None
)
self.layernorm_embedding = None
if layernorm_embedding:
self.layernorm_embedding = LayerNorm(embed_dim)
def forward(self, src_tokens):
x = self.embed_scale * self.embed_tokens(src_tokens)
if self.embed_positions is not None:
x = x + self.embed_positions(src_tokens)
if self.layernorm_embedding is not None:
x = self.layernorm_embedding(x)
x = self.dropout_module(x)
return x
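# A minimal usage sketch for TextLocalEncoder in isolation. Vocabulary size,
# pad index, and dimensions are illustrative; in training they come from the
# fairseq task dictionary and D2vTextConfig.
def _text_local_encoder_sketch():
    import torch

    enc = TextLocalEncoder(
        vocab_size=100,
        embed_dim=32,
        max_source_positions=64,
        pad_idx=1,
        no_scale_embedding=True,
        layernorm_embedding=True,
        dropout=0.0,
        no_token_positional_embeddings=False,
        learned_pos=True,
    )
    src_tokens = torch.tensor([[5, 6, 7, 1, 1]])  # one sequence, right-padded with pad_idx
    out = enc(src_tokens)
    assert out.shape == (1, 5, 32)  # (batch, seq_len, embed_dim)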
import math
import torch
def get_alibi(
max_positions: int,
attention_heads: int,
):
def get_slopes(n):
def get_slopes_power_of_2(n):
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
ratio = start
return [start * ratio ** i for i in range(n)]
# In the paper, we only train models that have 2^a heads for some
# a. This function has some good properties that only occur when
# the input is a power of 2. To maintain that even when the number
# of heads is not a power of 2, we use this workaround.
if math.log2(n).is_integer():
return get_slopes_power_of_2(n)
else:
closest_power_of_2 = 2 ** math.floor(math.log2(n))
return (
get_slopes_power_of_2(closest_power_of_2)
+ get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
)
maxpos = max_positions
attn_heads = attention_heads
slopes = torch.Tensor(get_slopes(attn_heads))
    # prepare alibi position linear bias. Note that wav2vec2 is a
    # non-autoregressive model, so we want a symmetric mask with 0 on the
    # diagonal and otherwise linearly decreasing values
pos_bias = (
torch.abs(
torch.arange(maxpos).unsqueeze(0) - torch.arange(maxpos).unsqueeze(1)
)
* -1
)
alibi_bias = slopes.unsqueeze(1).unsqueeze(1) * pos_bias.unsqueeze(0).expand(
attn_heads, -1, -1
)
return alibi_bias
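# A quick numerical sketch of get_alibi on a toy size: one (T x T) matrix per
# head, zeros on the diagonal, symmetric (non-causal), and non-positive
# everywhere, with the penalty growing linearly in |i - j| at a per-head slope.
def _alibi_sketch():
    bias = get_alibi(max_positions=6, attention_heads=4)
    assert bias.shape == (4, 6, 6)
    assert torch.all(bias.diagonal(dim1=-2, dim2=-1) == 0)
    assert torch.equal(bias, bias.transpose(-2, -1))
    assert torch.all(bias <= 0)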
def masked_alibi(alibi_bias, mask_indices, orig_B, orig_T):
alibi_bias = alibi_bias.view(orig_B, -1, orig_T, orig_T)
H = alibi_bias.size(1)
alibi_mask = mask_indices.unsqueeze(1)
alibi_bias = alibi_bias.masked_select(alibi_mask.unsqueeze(-1))
alibi_bias = alibi_bias.view(orig_B, H, -1, orig_T)
M = alibi_bias.size(-2)
alibi_bias = alibi_bias.masked_select(alibi_mask.unsqueeze(-2))
alibi_bias = alibi_bias.view(-1, M, M)
return alibi_bias
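# A minimal sketch of masked_alibi: it restricts the per-head bias to the
# masked timesteps only, so every sequence must mask the same number of
# positions M. Tiling the bias per batch with repeat() is an assumption here,
# made only to get a (B*H, T, T) input of the expected layout.
def _masked_alibi_sketch():
    B, H, T = 2, 4, 6
    bias = get_alibi(T, H).repeat(B, 1, 1)        # (B*H, T, T)
    mask_indices = torch.zeros(B, T, dtype=torch.bool)
    mask_indices[:, 1::2] = True                  # mask the same 3 positions per sequence
    out = masked_alibi(bias, mask_indices, orig_B=B, orig_T=T)
    assert out.shape == (B * H, 3, 3)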
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import os
def get_parser():
parser = argparse.ArgumentParser(description="convert audioset labels")
# fmt: off
parser.add_argument('in_file', help='audioset csv file to convert')
parser.add_argument('--manifest', required=True, metavar='PATH', help='wav2vec-like manifest')
parser.add_argument('--descriptors', required=True, metavar='PATH', help='path to label descriptor file')
parser.add_argument('--output', required=True, metavar='PATH', help='where to output converted labels')
# fmt: on
return parser
def main():
parser = get_parser()
args = parser.parse_args()
label_descriptors = {}
with open(args.descriptors, "r") as ldf:
next(ldf)
for line in ldf:
if line.strip() == "":
continue
items = line.split(",")
assert len(items) > 2, line
idx = items[0]
lbl = items[1]
assert lbl not in label_descriptors, lbl
label_descriptors[lbl] = idx
labels = {}
with open(args.in_file, "r") as ifd:
for line in ifd:
if line.lstrip().startswith("#"):
continue
items = line.rstrip().split(",")
id = items[0].strip()
start = items[1].strip()
end = items[2].strip()
lbls = [label_descriptors[it.strip(' "')] for it in items[3:]]
labels[id] = [start, end, ",".join(lbls)]
with open(args.manifest, "r") as mf, open(args.output, "w") as of:
next(mf)
for line in mf:
path, _ = line.split("\t")
id = os.path.splitext(os.path.basename(path))[0]
lbl = labels[id]
print("\t".join(lbl), file=of)
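# A hand-made sketch of the per-line transformation applied above, assuming
# the usual AudioSet layouts: class_labels_indices.csv rows are
# "index,mid,display_name" and segment rows are
# "YTID, start_seconds, end_seconds, positive_labels". The ids below are made up.
def _audioset_line_sketch():
    label_descriptors = {"/m/09x0r": "0", "/m/05zppz": "1"}  # mid -> index
    line = 'abc123, 30.000, 40.000, "/m/09x0r,/m/05zppz"\n'
    items = line.rstrip().split(",")
    lbls = [label_descriptors[it.strip(' "')] for it in items[3:]]
    out = "\t".join([items[1].strip(), items[2].strip(), ",".join(lbls)])
    assert out == "30.000\t40.000\t0,1"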
if __name__ == "__main__":
main()