Commit 72f5785f authored by huaerkl

v1.0

Pipeline #505 canceled with stages
dataset:
bert_name: bert-base-uncased
caption_pkl_path: data/how2/raw_caption_dedup.pkl
use_fast: true
target_dir: data/feat/feat_how2_s3d_shard_small
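
The `dataset` block above drives the pretokenization and sharding scripts that follow. As a rough, stand-alone sketch (assuming an OmegaConf-style loader; the scripts themselves call `recursive_config` from `mmpt.utils`, and `how2.yaml` is a hypothetical file name), this is how the block maps to the tokenized caption path that `main()` below derives:

```python
# Hedged sketch, not part of the pipeline: load the dataset block above and derive
# the tokenized caption path the same way main() below does.
import os
from omegaconf import OmegaConf  # assumption: an OmegaConf-style config loader

config = OmegaConf.load("how2.yaml").dataset  # hypothetical path to the YAML above
out_file = os.path.splitext(config.caption_pkl_path)[0] + "." + config.bert_name + ".pkl"
print(out_file)  # data/how2/raw_caption_dedup.bert-base-uncased.pkl
```
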
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import pickle
import os
import argparse
import numpy as np
from torch.utils.data import Dataset, DataLoader
from mmpt.processors import PKLJSONStrTextProcessor
from mmpt.utils import ShardedTensor, recursive_config
class TokenizerDataset(Dataset):
def __init__(self, config):
self.text_processor = PKLJSONStrTextProcessor(config)
self.video_ids = list(self.text_processor.data.keys())
def __getitem__(self, idx):
video_id = self.video_ids[idx]
return video_id, self.text_processor(video_id)
def __len__(self):
return len(self.video_ids)
def numpify(shard_idx, video_ids, captions, target_dir, split, prefix, max_cap_len=32):
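    # Pack each video's caption clips into float32 [start, end] pairs and into
    # fixed-length (max_cap_len) token-id rows padded with -1, then save both
    # per shard as ShardedTensors.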
startends = []
caps_ids = []
for video_id in video_ids:
caption = captions[video_id]
startend = []
cap_ids = []
for start, end, cap in zip(
caption["start"], caption["end"], caption["cap"]):
startend.append(np.array([start, end]).astype("float32"))
cap_id = np.full((max_cap_len,), -1, dtype=np.int32)
cap = cap[:max_cap_len]
cap_id[:len(cap)] = cap
cap_ids.append(cap_id)
startends.append(np.stack(startend))
caps_ids.append(np.stack(cap_ids))
startends = ShardedTensor.from_list(startends)
target_path = os.path.join(
target_dir,
prefix + split + "_" + str(shard_idx)
)
print("save to", target_path)
startends.save(target_path + ".startends")
caps_ids = ShardedTensor.from_list(caps_ids)
caps_ids.save(target_path + ".caps_ids")
def sharding(config, out_file):
with open(out_file, "rb") as fr:
captions = pickle.load(fr)
target_dir = config.target_dir
prefix = os.path.basename(
os.path.splitext(config.caption_pkl_path)[0]
) + "." + config.bert_name + "."
for split in ["train", "val"]:
target_path = os.path.join(target_dir, split + "_meta")
with open(target_path + ".pkl", "rb") as fr:
meta = pickle.load(fr)
print("load meta", target_path, len(meta))
for shard_id in meta:
numpify(
shard_id, meta[shard_id], captions,
target_dir, split, prefix
)
def tokenize(config, out_file):
def collator(samples):
return samples
dataset = TokenizerDataset(config)
data = {}
for idx, batch in enumerate(
DataLoader(dataset, collate_fn=collator, num_workers=16)):
for video_id, caption in batch:
data[video_id] = caption
if idx % 5000 == 0:
print(idx)
with open(out_file, "wb") as fw:
pickle.dump(data, fw, pickle.HIGHEST_PROTOCOL)
def main(args):
config = recursive_config(args.config).dataset
out_file = os.path.splitext(config.caption_pkl_path)[0] \
+ "." + config.bert_name + ".pkl"
if not os.path.isfile(out_file):
tokenize(config, out_file)
sharding(config, out_file)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="pretokenize (raw_)caption.json into pkl.")
parser.add_argument('config', type=str)
args = parser.parse_args()
main(args)
# Copyright Howto100M authors.
# Copyright (c) Facebook, Inc. All Rights Reserved
import torch as th
import torch.nn.functional as F
import math
import numpy as np
import argparse
from torch.utils.data import DataLoader
from model import get_model
from preprocessing import Preprocessing
from random_sequence_shuffler import RandomSequenceSampler
from tqdm import tqdm
from pathbuilder import PathBuilder
from videoreader import VideoLoader
parser = argparse.ArgumentParser(description='Easy video feature extractor')
parser.add_argument('--vdir', type=str)
parser.add_argument('--fdir', type=str)
parser.add_argument('--hflip', type=int, default=0)
parser.add_argument('--batch_size', type=int, default=64,
help='batch size')
parser.add_argument('--type', type=str, default='2d',
help='CNN type')
parser.add_argument('--half_precision', type=int, default=0,
help='output half precision float')
parser.add_argument('--num_decoding_thread', type=int, default=4,
help='Num parallel thread for video decoding')
parser.add_argument('--l2_normalize', type=int, default=1,
help='l2 normalize feature')
parser.add_argument('--resnext101_model_path', type=str, default='model/resnext101.pth',
help='Resnext model path')
parser.add_argument('--vmz_model_path', type=str, default='model/r2plus1d_34_clip8_ig65m_from_scratch-9bae36ae.pth',
help='vmz model path')
args = parser.parse_args()
# TODO: refactor all args into config. (current code is from different people.)
CONFIGS = {
"2d": {
"fps": 1,
"size": 224,
"centercrop": False,
"shards": 0,
},
"3d": {
"fps": 24,
"size": 112,
"centercrop": True,
"shards": 0,
},
"s3d": {
"fps": 30,
"size": 224,
"centercrop": True,
"shards": 0,
},
"vmz": {
"fps": 24,
"size": 112,
"centercrop": True,
"shards": 0,
},
"vae": {
"fps": 2,
"size": 256,
"centercrop": True,
"shards": 100,
}
}
config = CONFIGS[args.type]
video_dirs = args.vdir
feature_dir = args.fdir
video_dict = PathBuilder.build(video_dirs, feature_dir, ".npy", config["shards"])
dataset = VideoLoader(
video_dict=video_dict,
framerate=config["fps"],
size=config["size"],
centercrop=config["centercrop"],
hflip=args.hflip
)
n_dataset = len(dataset)
sampler = RandomSequenceSampler(n_dataset, 10)
loader = DataLoader(
dataset,
batch_size=1,
shuffle=False,
num_workers=args.num_decoding_thread,
sampler=sampler if n_dataset > 10 else None,
)
preprocess = Preprocessing(args.type)
model = get_model(args)
with th.no_grad():
for k, data in tqdm(enumerate(loader), total=loader.__len__(), ascii=True):
input_file = data['input'][0]
output_file = data['output'][0]
if len(data['video'].shape) > 3:
video = data['video'].squeeze()
if len(video.shape) == 4:
video = preprocess(video)
n_chunk = len(video)
if args.type == 'vmz':
n_chunk = math.ceil(n_chunk/float(3))
features = th.cuda.FloatTensor(n_chunk, 512).fill_(0)
elif args.type == 's3d':
features = th.cuda.FloatTensor(n_chunk, 512).fill_(0)
elif args.type == "vae":
features = th.cuda.LongTensor(n_chunk, 1024).fill_(0)
else:
features = th.cuda.FloatTensor(n_chunk, 2048).fill_(0)
n_iter = int(math.ceil(n_chunk / float(args.batch_size)))
for i in range(n_iter):
factor = 1
if args.type == 'vmz':
factor = 3
min_ind = factor * i * args.batch_size
max_ind = factor * (i + 1) * args.batch_size
video_batch = video[min_ind:max_ind:factor].cuda()
if args.type == '2d':
batch_features = model(video_batch) # (51, 487), (51, 512)
elif args.type == 's3d':
batch_features = model(video_batch)
batch_features = batch_features['video_embedding']
elif args.type == "vae":
# image_code.
batch_features = model(video_batch)
else:
batch_pred, batch_features = model(video_batch) # (51, 487), (51, 512)
if args.l2_normalize:
batch_features = F.normalize(batch_features, dim=1)
features[i*args.batch_size:(i+1)*args.batch_size] = batch_features
features = features.cpu().numpy()
if args.half_precision:
if args.type == "vae":
features = features.astype(np.int16)
else:
features = features.astype('float16')
else:
if args.type == "vae":
features = features.astype(np.int32)
else:
features = features.astype('float32')
np.save(output_file, features)
else:
print('Video {} error.'.format(input_file))
#!/bin/bash
python scripts/video_feature_extractor/extract.py \
--vdir <path_to_video_folder> \
--fdir data/feat/feat_how2_s3d \
--type=s3d --num_decoding_thread=4 \
--batch_size 32 --half_precision 1
# Copyright (c) Howto100M authors and Facebook, Inc. All Rights Reserved
import torch as th
from torch import nn
class GlobalAvgPool(nn.Module):
def __init__(self):
super(GlobalAvgPool, self).__init__()
def forward(self, x):
return th.mean(x, dim=[-2, -1])
def get_model(args):
assert args.type in ['2d', '3d', 'vmz', 's3d', 'vae']
if args.type == '2d':
print('Loading 2D-ResNet-152 ...')
import torchvision.models as models
model = models.resnet152(pretrained=True)
model = nn.Sequential(*list(model.children())[:-2], GlobalAvgPool())
model = model.cuda()
elif args.type == 'vmz':
print('Loading VMZ ...')
from vmz34 import r2plus1d_34
model = r2plus1d_34(pretrained_path=args.vmz_model_path, pretrained_num_classes=487)
model = model.cuda()
elif args.type == 's3d':
        # we reuse a single copy of S3D for feature extraction instead of loading a duplicate.
from mmpt.processors.models.s3dg import S3D
model = S3D('pretrained_models/s3d_dict.npy', 512)
model.load_state_dict(th.load('pretrained_models/s3d_howto100m.pth'))
model = model.cuda()
elif args.type == '3d':
print('Loading 3D-ResneXt-101 ...')
from videocnn.models import resnext
model = resnext.resnet101(
num_classes=400,
shortcut_type='B',
cardinality=32,
sample_size=112,
sample_duration=16,
last_fc=False)
model = model.cuda()
model_data = th.load(args.resnext101_model_path)
model.load_state_dict(model_data)
elif args.type == 'vae':
from openaivae import OpenAIParallelDiscreteVAE
model = OpenAIParallelDiscreteVAE()
model = model.cuda()
else:
raise ValueError("model not supported yet.")
model.eval()
print('loaded')
return model
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import urllib.parse
import json
import pandas as pd
from tqdm import tqdm
# TODO: extend to other datasets.
supported_formats = {}
class PathBuilder(object):
@classmethod
def build(cls, video_dirs, feature_dir, ext, shards=0, split=None):
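        """Build (and cache as meta_plan.json) parallel lists of input video paths
        and target feature paths; optionally bucket outputs into `shards` numbered
        subfolders and slice the plan with a "cur/total" split string."""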
meta_fn = os.path.join(feature_dir, "meta_plan.json")
os.makedirs(feature_dir, exist_ok=True)
if os.path.isfile(meta_fn):
with open(meta_fn) as fr:
meta = json.load(fr)
return meta
print("searching videos...")
video_id_to_path = {}
for video_dir in video_dirs.split(","):
            # TODO: add support for recursive listdir.
if video_dir in supported_formats:
supported_formats[video_dir].load(video_dir, video_id_to_path)
else:
for idx, fn in enumerate(tqdm(os.listdir(video_dir))):
video_fn = os.path.join(video_dir, fn)
if os.path.isfile(video_fn):
video_id = os.path.splitext(fn)[0]
video_id_to_path[video_id] = video_fn
elif os.path.isdir(video_fn):
# shards of folders.
shard_dir = video_fn
for idx, fn in enumerate(os.listdir(shard_dir)):
video_fn = os.path.join(shard_dir, fn)
if os.path.isfile(video_fn):
video_id = os.path.splitext(fn)[0]
video_id_to_path[video_id] = video_fn
video_path, feature_path = [], []
valid_ext = set()
for idx, video_id in enumerate(video_id_to_path):
video_path.append(video_id_to_path[video_id])
            if ext is None:
                # use the original file extension for format compatibility.
                path = urllib.parse.urlparse(video_id_to_path[video_id]).path
                ext = os.path.splitext(path)[1]
if ext not in valid_ext:
valid_ext.add(ext)
print("adding", ext)
if shards:
shard_id = str(idx % shards)
feature_fn = os.path.join(
feature_dir, shard_id, video_id + ext)
else:
feature_fn = os.path.join(
feature_dir, video_id + ext)
feature_path.append(feature_fn)
print("targeting", len(feature_path), "videos")
meta = {
"video_path": video_path, "feature_path": feature_path}
with open(meta_fn, "w") as fw:
json.dump(meta, fw)
if split is not None:
splits = split.split("/")
assert len(splits) == 2
cur, total = int(splits[0]), int(splits[1])
assert cur < total
import math
chunk = math.ceil(len(meta["video_path"]) / total)
start = cur * chunk
end = (cur + 1) * chunk
meta = {
"video_path": meta["video_path"][start:end],
"feature_path": meta["feature_path"][start:end]
}
return meta
# Copyright Howto100m authors.
# Copyright (c) Facebook, Inc. All Rights Reserved
import torch as th
class Normalize(object):
def __init__(self, mean, std):
self.mean = th.FloatTensor(mean).view(1, 3, 1, 1)
self.std = th.FloatTensor(std).view(1, 3, 1, 1)
def __call__(self, tensor):
tensor = (tensor - self.mean) / (self.std + 1e-8)
return tensor
class Preprocessing(object):
def __init__(self, type):
self.type = type
if type == '2d':
self.norm = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
elif type == '3d':
self.norm = Normalize(mean=[110.6, 103.2, 96.3], std=[1.0, 1.0, 1.0])
elif type == 'vmz':
self.norm = Normalize(mean=[110.201, 100.64, 95.997], std=[58.1489, 56.4701, 55.3324])
def _zero_pad(self, tensor, size):
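        # Zero-pad along the frame dimension so the frame count becomes a multiple
        # of `size` (required before reshaping into fixed-length clips below).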
n = size - len(tensor) % size
if n == size:
return tensor
else:
z = th.zeros(n, tensor.shape[1], tensor.shape[2], tensor.shape[3])
return th.cat((tensor, z), 0)
def __call__(self, tensor):
if self.type == '2d':
tensor = tensor / 255.0
tensor = self.norm(tensor)
elif self.type == 'vmz':
#tensor = self._zero_pad(tensor, 8)
tensor = self._zero_pad(tensor, 10)
tensor = self.norm(tensor)
#tensor = tensor.view(-1, 8, 3, 112, 112)
tensor = tensor.view(-1, 10, 3, 112, 112)
tensor = tensor.transpose(1, 2)
elif self.type == '3d':
tensor = self._zero_pad(tensor, 16)
tensor = self.norm(tensor)
tensor = tensor.view(-1, 16, 3, 112, 112)
tensor = tensor.transpose(1, 2)
elif self.type == 's3d':
tensor = tensor / 255.0
tensor = self._zero_pad(tensor, 30)
tensor = tensor.view(-1, 30, 3, 224, 224) # N x 30 x 3 x H x W
tensor = tensor.transpose(1, 2) # N x 3 x 30 x H x W
# for vae do nothing
return tensor
# Copyright (c) Facebook, Inc. All Rights Reserved
import numpy as np
from torch.utils.data.sampler import Sampler
class RandomSequenceSampler(Sampler):
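    """Yields dataset indices in contiguous blocks of `seq_len`, shuffling only the
    order of the blocks; the tail is padded with index 0 when `n_sample` is not a
    multiple of `seq_len`."""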
def __init__(self, n_sample, seq_len):
self.n_sample = n_sample
self.seq_len = seq_len
def _pad_ind(self, ind):
zeros = np.zeros(self.seq_len - self.n_sample % self.seq_len)
ind = np.concatenate((ind, zeros))
return ind
def __iter__(self):
idx = np.arange(self.n_sample)
if self.n_sample % self.seq_len != 0:
idx = self._pad_ind(idx)
idx = np.reshape(idx, (-1, self.seq_len))
np.random.shuffle(idx)
idx = np.reshape(idx, (-1))
return iter(idx.astype(int))
    def __len__(self):
        # __iter__ pads with index 0 only when n_sample is not a multiple of seq_len
        if self.n_sample % self.seq_len == 0:
            return self.n_sample
        return self.n_sample + (self.seq_len - self.n_sample % self.seq_len)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import os
import pickle
from mmpt.utils import ShardedTensor
class Shard(object):
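    """Packs per-video .npy features into fixed-size ShardedTensor shards
    (`shard_size` videos per shard) and writes a {shard_idx: [video_id, ...]}
    meta pickle for each split."""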
def __init__(
self,
vfeat_dir,
tfeat_dir,
target_dir,
file_paths,
shard_size=4096
):
self.vfeat_dir = vfeat_dir
self.tfeat_dir = tfeat_dir
self.target_dir = target_dir
self.video_ids = {}
for split, file_path in zip(["train", "val"], file_paths):
with open(file_path) as fr:
self.video_ids[split] = [
line.strip() for line in fr.readlines()]
self.shard_size = shard_size
def __call__(self, split="train"):
for split in ["train", "val"]:
meta = {}
for shard_idx, shard_offset in enumerate(
range(0, len(self.video_ids[split]), self.shard_size)
):
print(shard_idx)
meta_shard = []
video_shard = []
for video_id in self.video_ids[split][shard_offset:shard_offset+self.shard_size]:
meta_shard.append(video_id)
npy_file = os.path.join(self.vfeat_dir, video_id + ".npy")
video_shard.append(np.load(npy_file))
meta[shard_idx] = meta_shard
video_shard = ShardedTensor.from_list(video_shard)
target_path = os.path.join(
self.target_dir, split + "_" + str(shard_idx))
video_shard.save(target_path)
target_path = os.path.join(self.target_dir, split + "_meta")
with open(target_path + ".pkl", "wb") as fw:
pickle.dump(meta, fw, pickle.HIGHEST_PROTOCOL)
if __name__ == "__main__":
shard = Shard(
"data/feat/feat_how2_s3d",
"data/how2/raw_caption_dedup.bert-base-uncased",
"data/feat/feat_how2_s3d_shard_small",
["data/how2/how2_s3d_train.lst", "data/how2/how2_s3d_val.lst"]
)
shard()
# Copyright Howto100M authors.
# Copyright (c) Facebook, Inc. All Rights Reserved
import torch as th
import pandas as pd
import os
import numpy as np
import ffmpeg
import random
from torch.utils.data import Dataset
class VideoLoader(Dataset):
"""modified from how2's video_feature_extractor."""
def __init__(
self,
csv=None,
video_dict=None,
framerate=1,
size=112,
centercrop=False,
hflip=False,
**kwargs
):
if csv is None and video_dict is None:
            raise ValueError("csv and video_dict cannot both be None.")
if csv is not None:
self.csv = pd.read_csv(csv)
if video_dict is not None:
self.csv = pd.DataFrame.from_dict(video_dict)
self.centercrop = centercrop
self.size = size
self.framerate = framerate
self.hflip = hflip
def __len__(self):
return len(self.csv)
def _get_video_dim(self, video_path):
probe = ffmpeg.probe(video_path)
video_stream = next((stream for stream in probe['streams']
if stream['codec_type'] == 'video'), None)
width = int(video_stream['width'])
height = int(video_stream['height'])
return height, width
def _get_video_info(self, video_path):
probe = ffmpeg.probe(video_path)
video_stream = next((stream for stream in probe['streams']
if stream['codec_type'] == 'video'), None)
return video_stream
def _get_output_dim(self, h, w):
if isinstance(self.size, tuple) and len(self.size) == 2:
return self.size
elif h >= w:
return int(h * self.size / w), self.size
else:
return self.size, int(w * self.size / h)
def __getitem__(self, idx):
video_path = self.csv['video_path'].values[idx]
output_file = self.csv['feature_path'].values[idx]
return self._decode(output_file, video_path)
def _decode(self, output_file, video_path):
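        # Decode the video with ffmpeg at the configured fps/size (optionally hflip
        # and centercrop); returns a dummy th.zeros(1) tensor when the feature file
        # already exists or when probing/decoding fails.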
if not(os.path.isfile(output_file)) and os.path.isfile(video_path):
try:
h, w = self._get_video_dim(video_path)
except Exception:
print('ffprobe failed at: {}'.format(video_path))
return {'video': th.zeros(1), 'input': video_path,
'output': output_file}
try:
os.makedirs(os.path.dirname(output_file), exist_ok=True)
height, width = self._get_output_dim(h, w)
cmd = (
ffmpeg
.input(video_path)
.filter('fps', fps=self.framerate)
.filter('scale', width, height)
)
if self.hflip:
cmd = cmd.filter('hflip')
if self.centercrop:
x = int((width - self.size) / 2.0)
y = int((height - self.size) / 2.0)
cmd = cmd.crop(x, y, self.size, self.size)
video = self._run(cmd, output_file)
except Exception:
video = th.zeros(1)
else:
video = th.zeros(1)
return {'video': video, 'input': video_path, 'output': output_file}
def _run(self, cmd, output_file):
out, _ = (
cmd.output('pipe:', format='rawvideo', pix_fmt='rgb24')
.run(capture_stdout=True, quiet=True)
)
if self.centercrop and isinstance(self.size, int):
height, width = self.size, self.size
video = np.frombuffer(out, np.uint8).reshape([-1, height, width, 3])
video = th.from_numpy(video.astype('float32'))
return video.permute(0, 3, 1, 2)
class VideoVerifier(VideoLoader):
def __getitem__(self, idx):
video_path = self.csv['video_path'].values[idx]
try:
return self._get_video_info(video_path)
except Exception:
# print('ffprobe failed at: {}'.format(video_path))
return None
class VideoCompressor(VideoLoader):
def __init__(
self,
csv=None,
video_dict=None,
framerate=1,
size=112,
centercrop=False,
hflip=False,
crf=32,
**kwargs
):
super().__init__(
csv,
video_dict,
framerate,
size,
centercrop,
hflip
)
self.crf = crf
def _run(self, cmd, output_file):
out, _ = (
cmd.output(filename=output_file, crf=self.crf)
.run(quiet=True)
)
video = None
return video
class VideoDownloader(VideoCompressor):
"""download"""
def __getitem__(self, idx):
video_path = self.csv['video_path'].values[idx]
output_file = self.csv['feature_path'].values[idx]
if not(os.path.isfile(output_file)):
os.makedirs(os.path.dirname(output_file), exist_ok=True)
cmd = "wget -O" + output_file + " " + video_path
# import subprocess
# subprocess.check_output(
# cmd,
# stderr=subprocess.STDOUT, shell=True)
os.system(cmd)
return {'video': None, 'input': video_path, 'output': output_file}
class AvKeyframeVideoCompressor(VideoLoader):
"""extract keyframes from a video and save it as jpg.
TODO: consider to merge with `CodecProcessor`.
"""
def __init__(
self,
csv=None,
video_dict=None,
framerate=1,
size=112,
centercrop=False,
max_num_frames=5,
**kwargs
):
super().__init__(csv, video_dict, framerate, size, centercrop)
self.max_num_frames = max_num_frames
def _get_video_dim(self, video_fn):
"""decord cannot probe the size of a video, we use pyav instead."""
import av
with av.open(video_fn) as container:
height = container.streams.video[0].codec_context.height
width = container.streams.video[0].codec_context.width
return height, width
def _get_output_dim(self, height, width):
"""
        keep the shorter side at `self.size` and stretch the other.
"""
if height >= width:
return int(height * self.size / width), self.size
else:
return self.size, int(width * self.size / height)
def __getitem__(self, idx):
import av
video_path = self.csv['video_path'].values[idx]
output_file = self.csv['feature_path'].values[idx]
if not(os.path.isdir(output_file)) and os.path.isfile(video_path):
try:
h, w = self._get_video_dim(video_path)
except Exception:
print('probe failed at: {}'.format(video_path))
return {'video': th.zeros(1), 'input': video_path,
'output': output_file}
try:
height, width = self._get_output_dim(h, w)
# new for av.
with av.open(video_path) as container:
container.streams.video[0].thread_type = "AUTO"
container.streams.video[0].codec_context.height = height
container.streams.video[0].codec_context.width = width
if self.framerate == 0: # keyframe.
container.streams.video[0].codec_context.skip_frame = 'NONKEY'
frames = []
for frame in container.decode(video=0):
frames.append(frame)
                # guard against clips with fewer decoded frames than max_num_frames
                frames = random.sample(frames, min(self.max_num_frames, len(frames)))
os.makedirs(output_file, exist_ok=True)
for frame in frames:
frame.to_image().save(
os.path.join(
output_file,
"%04d.jpg" % frame.index))
except Exception:
print('extract failed at: {}'.format(video_path))
return {'video': th.zeros(1), 'input': video_path,
'output': output_file}
video = th.zeros(1)
return {'video': video, 'input': video_path, 'output': output_file}
import setuptools
with open("README.md", "r") as fh:
long_description = fh.read()
setuptools.setup(
name="mmpt",
version="0.0.1",
author="Hu Xu, Po-yao Huang",
author_email="huxu@fb.com",
description="A package for multimodal pretraining.",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/pytorch/fairseq/examples/MMPT",
packages=setuptools.find_packages(),
install_requires=[
],
classifiers=[
"Programming Language :: Python :: 3",
"License :: CC-BY-NC",
"Operating System :: OS Independent",
],
python_requires='>=3.6',
)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
try:
from fairseq.version import __version__ # noqa
except ImportError:
pass
# Adaptive Span
Adaptive Span is a novel self-attention mechanism that can learn its optimal
attention span. This allows us to significantly extend the maximum context size
used in Transformers while maintaining control over memory footprint
and computational time. It uses the Truncated BPTT technique for training,
as in [transformerXL](https://github.com/pytorch/fairseq/blob/main/examples/truncated_bptt/README.md).
Adaptive Span was introduced in the paper
[Adaptive Attention Span in Transformers](https://arxiv.org/abs/1905.07799),
which achieved state-of-the-art language modeling results at the time of publication.
We managed to reproduce their results in fairseq while keeping most of the
[original implementation](https://github.com/facebookresearch/adaptive-span) untouched.
You can also refer to their sweep file if any combination of hyperparameters is unclear.
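
At its core, each attention head (or each layer, depending on the configuration) keeps a learnable scalar in `[0, 1]` and multiplies its attention weights by a soft mask that is 1 for recent positions, ramps down linearly, and reaches 0 beyond the learned span. The snippet below is a minimal sketch of that mask, mirroring the `AdaptiveMask` module shipped with this example; `max_size`, `ramp_size`, and `z` are illustrative values, not recommended settings.

```python
import torch

max_size, ramp_size = 8, 4   # maximum span and length of the soft ramp
z = torch.tensor(0.25)       # learnable fraction of the span kept unmasked

# template runs from (1 - max_size) at the oldest position to 0 at the newest
template = torch.linspace(1 - max_size, 0, steps=max_size)
mask = ((template + z * max_size) / ramp_size + 1).clamp(0, 1)
print(mask)  # -> 0, 0, 0.25, 0.5, 0.75, 1, 1, 1: old positions masked out, recent ones kept
```
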
##### 0. Setup
First you need to process the Enwik8 dataset; we use the pre-tokenized dataset
from the [adaptive span paper](https://github.com/facebookresearch/adaptive-span/blob/master/get_data.sh).
You can download the dataset and then run:
```bash
fairseq-preprocess --only-source --trainpref ~/data/enwik8/train.txt \
--validpref ~/data/enwik8/valid.txt --testpref ~/data/enwik8/test.txt \
--destdir ~/data/enwik8/data-bin/ --joined-dictionary --workers 20
```
##### 1. Train an Adaptive Span model on Enwik8
We will train a 12-layer Adaptive Span model following the [hyperparameters
used in the original
paper](https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8.sh).
The following command assumes 4 GPUs, so that the total batch size is 64
sequences (4 x 16). Training should take 2-3 days on 4 V100 GPUs:
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \
--user-dir examples/adaptive_span \
--data ~/data/enwik8/data-bin/ \
--fp16 --fp16-no-flatten-grads --max-update 600000 \
--task truncated_bptt_lm --tokens-per-sample 512 --arch adaptive_span \
--n-layer 12 --d-model 512 --n-head 8 --d-inner 2048 --dropout 0.3 \
--attn-span 8192 --optimizer adagrad_with_grad_clip --adagrad-clip 0.03 \
--validate-interval-updates 1000 \
--lr-scheduler fixed --warmup-updates 32000 --batch-size-valid 32 \
--lr 0.07 --criterion adaptive_span_loss --batch-size 16 --update-freq 1 \
--seed 2 --log-format json --log-interval 25 --aux-loss-scaler 5e-07
```
This should land at around 1.05 bpc on validation and 1.03 on test. You can lower
`--aux-loss-scaler` for better performance (a longer span). It gives a ~0.03 bpc
improvement over the Transformer-XL baseline here.
If training on a single GPU, set `--update-freq=4` to accumulate 4x gradients
and simulate training on 4 GPUs.
You can also reproduce the Transformer-XL result on enwik8 using this code base.
It should land at around 1.06 on test, matching the [original paper](https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/run_enwik8_base.sh).
You can try it with:
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \
--user-dir examples/truncated_bptt \
~/data/enwik8/data-bin/ \
--task truncated_bptt_lm --fp16 --max-update 400000 \
--tokens-per-sample 512 --arch transformer_xl --n-layer 12 \
--d-model 512 --n-head 8 --d-head 64 --d-inner 2048 --dropout 0.1 \
--dropatt 0.0 --mem-len 512 --optimizer adam --clip-norm 0.25 \
--lr-scheduler cosine --warmup-updates 0 \
--lr 0.00025 --batch-size 15 \
--update-freq 1 --seed 2 --log-format json --log-interval 25
```
##### 2. Evaluate
For Adaptive Span:
```bash
fairseq-eval-lm ~/data/enwik8/data-bin/ --path model/checkpoint_best.pt \
--user-dir examples/adaptive_span \
--task truncated_bptt_lm --batch-size 8 --tokens-per-sample 512 --gen-subset test
```
For Transformer-XL evaluation:
```bash
fairseq-eval-lm ~/data/enwik8/data-bin/ --path model/checkpoint_best.pt \
--user-dir examples/truncated_bptt/ --task truncated_bptt_lm --batch-size 8 \
--tokens-per-sample 80 \
--model-overrides '{"mem_len":2100,"clamp_len":820,"same_length":True}' \
--gen-subset valid
```
*Note:* During training the model saw 512 tokens of context
(``--tokens-per-sample=512``), with batch size 8. These settings match the evaluation
settings from [the original
paper](https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8.sh).
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import os
# automatically import any Python files in the current directory
cur_dir = os.path.dirname(__file__)
for file in os.listdir(cur_dir):
path = os.path.join(cur_dir, file)
if (
not file.startswith("_")
and not file.startswith(".")
and (file.endswith(".py") or os.path.isdir(path))
):
mod_name = file[: file.find(".py")] if file.endswith(".py") else file
module = importlib.import_module(__name__ + "." + mod_name)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from torch.optim import Adagrad
from fairseq.optim import LegacyFairseqOptimizer, register_optimizer
@register_optimizer("adagrad_with_grad_clip")
class FairseqAdagradWithGradClip(LegacyFairseqOptimizer):
def __init__(self, args, params):
super().__init__(args)
self._optimizer = AdagradWithGradClip(params, **self.optimizer_config)
@staticmethod
def add_args(parser):
"""Add optimizer-specific arguments to the parser."""
# fmt: off
parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
help='weight decay')
parser.add_argument('--adagrad-clip', default=0.0, type=float, metavar='D',
help='internal grad clip')
# fmt: on
@property
def optimizer_config(self):
"""
Return a kwarg dictionary that will be used to override optimizer
args stored in checkpoints. This allows us to load a checkpoint and
resume training using a different set of optimizer args, e.g., with a
different learning rate.
"""
return {
"lr": self.args.lr[0],
"weight_decay": self.args.weight_decay,
"grad_clip": self.args.adagrad_clip,
}
@property
def supports_flat_params(self):
return False
def _clip_grad(clr, grad, group_grad_clip):
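    # Note: instead of modifying the gradient, this scales the per-step learning
    # rate down whenever the gradient norm exceeds the clip threshold.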
if group_grad_clip > 0:
norm = grad.norm(2).item()
if norm > group_grad_clip:
clr *= group_grad_clip / (norm + 1e-10)
return clr
class AdagradWithGradClip(Adagrad):
"""Adagrad algorithm with custom gradient clipping"""
def __init__(
self,
params,
lr=1e-2,
lr_decay=0,
weight_decay=0,
initial_accumulator_value=0,
grad_clip=0,
):
Adagrad.__init__(
self,
params,
lr=lr,
lr_decay=lr_decay,
weight_decay=weight_decay,
initial_accumulator_value=initial_accumulator_value,
)
self.defaults["grad_clip"] = grad_clip
self.param_groups[0].setdefault("grad_clip", grad_clip)
def step(self, closure=None):
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad.data
state = self.state[p]
state["step"] += 1
if group["weight_decay"] != 0:
if p.grad.data.is_sparse:
raise RuntimeError(
"weight_decay option is "
"not compatible with sparse "
"gradients"
)
grad = grad.add(group["weight_decay"], p.data)
clr = group["lr"] / (1 + (state["step"] - 1) * group["lr_decay"])
# clip
clr = _clip_grad(clr=clr, grad=grad, group_grad_clip=group["grad_clip"])
if grad.is_sparse:
# the update is non-linear so indices must be unique
grad = grad.coalesce()
grad_indices = grad._indices()
grad_values = grad._values()
size = grad.size()
def make_sparse(values):
constructor = grad.new
if grad_indices.dim() == 0 or values.dim() == 0:
return constructor().resize_as_(grad)
return constructor(grad_indices, values, size)
state["sum"].add_(make_sparse(grad_values.pow(2)))
std = state["sum"]._sparse_mask(grad)
std_values = std._values().sqrt_().add_(1e-10)
p.data.add_(-clr, make_sparse(grad_values / std_values))
else:
state["sum"].addcmul_(1, grad, grad)
std = state["sum"].sqrt().add_(1e-10)
p.data.addcdiv_(-clr, grad, std)
return loss
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class AdaptiveMask(nn.Module):
"""Soft masking function for adaptive size.
It masks out the last K values of an input. The masking value
goes from 1 to 0 gradually, so K can be learned with
back-propagation.
Args:
max_size: maximum size (i.e. input dimension)
ramp_size: size of the ramp going from 0 to 1
init_val: initial size proportion not to be masked out
shape: learn multiple sizes independent of each other
"""
def __init__(self, max_size, ramp_size, init_val=0, shape=(1,)):
nn.Module.__init__(self)
self._max_size = max_size
self._ramp_size = ramp_size
self.current_val = nn.Parameter(torch.zeros(*shape) + init_val)
mask_template = torch.linspace(1 - max_size, 0, steps=max_size)
self.register_buffer("mask_template", mask_template)
def forward(self, x):
mask = self.mask_template.float() + self.current_val.float() * self._max_size
mask = mask / self._ramp_size + 1
mask = mask.clamp(0, 1)
if x.size(-1) < self._max_size:
# the input could have been trimmed beforehand to save computation
mask = mask.narrow(-1, self._max_size - x.size(-1), x.size(-1))
x = (x * mask).type_as(x)
return x
def get_current_max_size(self, include_ramp=True):
current_size = math.ceil(self.current_val.max().item() * self._max_size)
if include_ramp:
current_size += self._ramp_size
current_size = max(0, min(self._max_size, current_size))
return current_size
def get_current_avg_size(self, include_ramp=True):
current_size = math.ceil(
self.current_val.float().mean().item() * self._max_size
)
if include_ramp:
current_size += self._ramp_size
current_size = max(0, min(self._max_size, current_size))
return current_size
def clamp_param(self):
"""this need to be called after each update"""
self.current_val.data.clamp_(0, 1)
class AdaptiveSpan(nn.Module):
    """Adaptive attention span for Transformers.
    This module learns an attention span length from data for each
    self-attention head.
    Args:
        attn_span: maximum attention span
        adapt_span_ramp: length of the masking ramp
        adapt_span_init: initial size ratio
        n_head: number of attention heads
        adapt_span_layer: learn a single span per layer instead of one per head
    """
def __init__(
self,
attn_span,
adapt_span_ramp,
adapt_span_init,
n_head,
adapt_span_layer,
**kargs
):
nn.Module.__init__(self)
self._max_span = attn_span
self._n_head = n_head
self._adapt_span_layer = adapt_span_layer
if self._adapt_span_layer:
self._mask = AdaptiveMask(
max_size=self._max_span,
ramp_size=adapt_span_ramp,
init_val=adapt_span_init,
)
else:
self._mask = AdaptiveMask(
max_size=self._max_span,
ramp_size=adapt_span_ramp,
init_val=adapt_span_init,
shape=(n_head, 1, 1),
)
def forward(self, attn, normalize=True):
"""mask attention with the right span"""
# batch and head dimensions are merged together, so separate them first
self.clamp_param()
if self._adapt_span_layer:
attn = self._mask(attn)
else:
B = attn.size(0) # batch size
M = attn.size(1) # block size
attn = attn.reshape(B // self._n_head, self._n_head, M, -1)
attn = self._mask(attn)
attn = attn.view(B, M, -1)
return attn
def get_trim_len(self):
"""how much of memory can be trimmed to reduce computation"""
L = self._max_span
trim_len = min(L - 1, L - self._mask.get_current_max_size())
# too fine granularity might be bad for the memory management
trim_len = math.floor(trim_len / 64) * 64
return trim_len
def trim_memory(self, query, key, value, key_pe):
"""trim out unnecessary memory beforehand to reduce computation"""
trim_len = self.get_trim_len()
cache_size = key.size(1) - query.size(1)
trim_len_cache = trim_len - (self._max_span - cache_size)
if trim_len_cache > 0:
key = key[:, trim_len_cache:, :]
value = value[:, trim_len_cache:, :]
elif trim_len_cache < 0:
# cache is too short! this happens when validation resumes
# after a lot of updates.
key = F.pad(key, [0, 0, -trim_len_cache, 0])
value = F.pad(value, [0, 0, -trim_len_cache, 0])
if trim_len > 0:
if key_pe is not None:
key_pe = key_pe[:, :, trim_len:]
return key, value, key_pe
def get_cache_size(self):
"""determine how long the cache should be"""
trim_len = self.get_trim_len()
# give a buffer of 64 steps since a span might increase
# in future updates
return min(self._max_span, self._max_span - trim_len + 64)
def get_loss(self):
"""a loss term for regularizing the span length"""
return self._max_span * self._mask.current_val.float().mean()
def get_current_max_span(self):
return self._mask.get_current_max_size()
def get_current_avg_span(self):
return self._mask.get_current_avg_size()
def clamp_param(self):
self._mask.clamp_param()
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass
import torch.nn.functional as F
from fairseq import utils
from fairseq.logging import metrics
from fairseq.criterions import register_criterion
from fairseq.criterions.cross_entropy import CrossEntropyCriterion
from fairseq.dataclass import FairseqDataclass
from omegaconf import II
@dataclass
class AdaptiveSpanCriterionConfig(FairseqDataclass):
sentence_avg: bool = II("optimization.sentence_avg")
@register_criterion("adaptive_span_loss", dataclass=AdaptiveSpanCriterionConfig)
class AdaptiveSpanCriterion(CrossEntropyCriterion):
def __init__(self, task, sentence_avg):
super().__init__(task, sentence_avg)
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss here is summed, different from the adaptive span code
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output = model(**sample["net_input"])
loss, aux_loss, avg_span, max_span = self.compute_loss(
model, net_output, sample, reduce=reduce
)
sample_size = (
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
)
loss /= sample_size
total_loss = loss + aux_loss
sample_size = 1
logging_output = {
"loss": loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample["target"].size(0),
"sample_size": sample_size,
"total_loss": total_loss.data,
"avg_span": avg_span * sample_size,
"max_span": max_span * sample_size,
}
return total_loss, sample_size, logging_output
def compute_loss(self, model, net_output, sample, reduce=True):
loss, _ = super().compute_loss(model, net_output, sample, reduce)
aux_loss = model.get_aux_loss()
avg_span = model.get_current_avg_span()
max_span = model.get_current_max_span()
return loss, aux_loss, avg_span, max_span
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
total_loss_sum = sum(log.get("total_loss", 0) for log in logging_outputs)
avg_span_sum = sum(log.get("avg_span", 0) for log in logging_outputs)
max_span_sum = sum(log.get("max_span", 0) for log in logging_outputs)
# we divide by log(2) to convert the loss from base e to base 2
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
metrics.log_scalar("avg_span", avg_span_sum / sample_size, sample_size, round=3)
metrics.log_scalar("max_span", max_span_sum / sample_size, sample_size, round=3)
# total loss contains the L1 norm on adaptive-span
metrics.log_scalar(
"total_loss",
total_loss_sum / sample_size / math.log(2),
sample_size,
round=3,
)
if sample_size != ntokens:
metrics.log_scalar(
"nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
)
else:
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
"""
return True
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq.modules.layer_norm import LayerNorm
from .adaptive_span_attention import AdaptiveSpan
# Size notations:
# B = batch_size, H = d_model, M = block_size, L = attn_span
def _skew(X, pad_value):
"""shift every row 1 step to right"""
# X = B x M x L
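    # e.g. with M=2, L=2: [[a, b], [c, d]] -> [[a, b, *, *], [*, c, d, *]] where *
    # is pad_value; row i ends up shifted right by i positions so that relative
    # offsets line up with absolute key positions.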
B, M, L = X.size()
X = F.pad(X, (0, M + 1), value=pad_value) # B x M x (L+M+1)
X = X.view(B, -1) # B x ML+MM+M
X = X[:, :-M] # B x ML+MM
X = X.view(B, M, M + L) # B x M x L+M
return X
def _unskew(X):
"""reverse _skew operation"""
# X = B x M x L+M
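    # reverses the example in _skew: [[a, b, *, *], [*, c, d, *]] -> [[a, b], [c, d]]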
B, M, L = X.size()
L -= M
X = X.view(B, -1) # B x ML+MM
X = F.pad(X, (0, M)) # B x ML+MM+M
X = X.view(B, M, M + L + 1) # B x M x L+M+1
X = X[:, :, :L] # B x M x L
return X
class SeqAttention(nn.Module):
"""Sequential self-attention layer.
Each token will attend to its previous fixed number of steps.
Note that attention doesn't include the current step itself.
"""
def __init__(self, d_model, n_head, attn_span, dropout, adapt_span_layer, **kargs):
nn.Module.__init__(self)
self.dropout = nn.Dropout(dropout)
self.d_model = d_model # size of a single head
self.attn_span = attn_span
self.adaptive_span = AdaptiveSpan(
attn_span=attn_span,
n_head=n_head,
adapt_span_layer=adapt_span_layer,
**kargs
)
def forward(self, query, key, value, key_pe):
# query size = B x M x H
# key, value sizes = B x (M+L) x H
key, value, key_pe = self.adaptive_span.trim_memory(query, key, value, key_pe)
# compute attention from context
# B x M (dest) x (M+L) (src)
attn_cont = torch.matmul(query, key.transpose(-1, -2))
attn_cont = _unskew(attn_cont) # B x M x L
# compute the effect of position embedding
attn_pos = torch.matmul(query, key_pe) # B x M x L_pos
attn = attn_cont + attn_pos
attn = attn / math.sqrt(self.d_model) # B x M X L_pos
attn = F.softmax(attn.float(), dim=-1).type_as(attn)
# trim attention lengths according to the learned span
attn = self.adaptive_span(attn)
attn = self.dropout(attn) # B x M X L_pos
attn_cont = _skew(attn, 0) # B x M X (L+M)
out = torch.matmul(attn_cont, value) # B x M x H
return out
def get_cache_size(self):
return self.adaptive_span.get_cache_size()
class MultiHeadSeqAttention(nn.Module):
def __init__(self, d_model, n_head, **kargs):
nn.Module.__init__(self)
assert d_model % n_head == 0
self.n_head = n_head
self.head_dim = d_model // n_head
self.attn = SeqAttention(d_model=self.head_dim, n_head=n_head, **kargs)
self.proj_query = nn.Linear(d_model, d_model, bias=False)
nn.init.xavier_normal_(self.proj_query.weight)
self.proj_out = nn.Linear(d_model, d_model, bias=False)
nn.init.xavier_normal_(self.proj_out.weight)
self.proj_val = nn.Linear(d_model, d_model, bias=False)
nn.init.xavier_normal_(self.proj_val.weight)
self.proj_key = nn.Linear(d_model, d_model, bias=False)
nn.init.xavier_normal_(self.proj_key.weight)
def head_reshape(self, x):
K = self.n_head
D = self.head_dim
x = x.view(x.size()[:-1] + (K, D)) # B x (M+L) x K x D
x = x.transpose(1, 2).contiguous() # B x K x (M+L) x D
x = x.view(-1, x.size(-2), x.size(-1)) # B_K x (M+L) x D
return x
def forward(self, query, key, value, key_pe):
B = query.size(0)
K = self.n_head
D = self.head_dim
M = query.size(1)
query = self.proj_query(query)
query = self.head_reshape(query)
value = self.proj_val(value)
value = self.head_reshape(value)
key = self.proj_key(key)
key = self.head_reshape(key)
out = self.attn(query, key, value, key_pe) # B_K x M x D
out = out.view(B, K, M, D) # B x K x M x D
out = out.transpose(1, 2).contiguous() # B x M x K x D
out = out.view(B, M, -1) # B x M x K_D
out = self.proj_out(out)
return out
class FeedForwardLayer(nn.Module):
def __init__(self, d_model, d_inner, dropout, **kargs):
nn.Module.__init__(self)
self.fc1 = nn.Linear(d_model, d_inner)
self.fc2 = nn.Linear(d_inner, d_model)
nn.init.xavier_uniform_(self.fc1.weight)
nn.init.xavier_uniform_(self.fc2.weight)
self.dropout = nn.Dropout(dropout)
def forward(self, h):
h1 = F.relu(self.fc1(h))
h1 = self.dropout(h1)
h2 = self.fc2(h1)
return h2
class TransformerSeqLayer(nn.Module):
def __init__(self, d_model, **kargs):
nn.Module.__init__(self)
self.attn = MultiHeadSeqAttention(d_model=d_model, **kargs)
self.norm1 = LayerNorm(d_model)
self.ff = FeedForwardLayer(d_model=d_model, **kargs)
self.norm2 = LayerNorm(d_model)
def forward(self, h, h_cache, key_pe):
# h = B x M x H
# h_cache = B x L x H
h_all = torch.cat([h_cache, h], dim=1) # B x (M+L) x H
attn_out = self.attn(h, h_all, h_all, key_pe)
h = self.norm1(h + attn_out) # B x M x H
if self.ff is not None:
ff_out = self.ff(h)
out = self.norm2(h + ff_out) # B x M x H
else:
out = h
return out
def get_cache_size(self):
return self.attn.attn.get_cache_size()
class TransformerSeq(nn.Module):
def __init__(
self,
vocab_size,
d_model,
n_head,
n_layer,
attn_span,
emb_dropout,
aux_loss_scaler,
adapt_span_layer,
**kargs
):
nn.Module.__init__(self)
# token embeddings
self.in_emb = nn.Embedding(vocab_size, d_model)
nn.init.normal_(self.in_emb.weight, mean=0, std=d_model ** -0.5)
self.out_emb = nn.Linear(d_model, vocab_size)
self.aux_loss_scaler = aux_loss_scaler
if emb_dropout > 0:
self.emb_dropout = nn.Dropout(emb_dropout)
else:
self.emb_dropout = None
# position embeddings
self.key_pe = nn.Parameter(torch.randn(1, d_model // n_head, attn_span))
self.layers = nn.ModuleList()
self.layers.extend(
TransformerSeqLayer(
d_model=d_model,
n_head=n_head,
attn_span=attn_span,
adapt_span_layer=adapt_span_layer,
**kargs
)
for _ in range(n_layer)
)
def forward(self, x, h_cache, target=None):
# x size = B x M
block_size = x.size(1)
h = self.in_emb(x) # B x M x H
if self.emb_dropout is not None:
h = self.emb_dropout(h)
h_cache_next = []
for l, layer in enumerate(self.layers):
cache_size = layer.attn.attn.get_cache_size()
if cache_size > block_size:
h_cache_next_l = torch.cat(
[h_cache[l][:, -cache_size + block_size :, :], h], dim=1
).detach()
else:
h_cache_next_l = h[:, -cache_size:, :].detach()
h_cache_next.append(h_cache_next_l)
h = layer(h, h_cache[l], self.key_pe) # B x M x H
if self.emb_dropout is not None:
h = self.emb_dropout(h)
out = F.log_softmax(self.out_emb(h).float(), dim=-1).type_as(h)
dummy_loss = None
return out, h_cache_next, dummy_loss
def get_aux_loss(self):
loss = 0.0
for layer in self.layers:
loss += layer.attn.attn.adaptive_span.get_loss()
return self.aux_loss_scaler * loss
def get_current_max_span(self):
max_span = 0.0
for layer in self.layers:
max_span = max(
max_span, layer.attn.attn.adaptive_span.get_current_max_span()
)
return max_span
def get_current_avg_span(self):
avg_span = 0.0
for layer in self.layers:
avg_span += layer.attn.attn.adaptive_span.get_current_avg_span()
return avg_span / len(self.layers)