Commit 72f5785f authored by huaerkl's avatar huaerkl
Browse files

v1.0

parents
Pipeline #505 canceled with stages
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
TODO (huxu): a general fairseq criterion for all your pre-defined losses.
"""
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.logging import metrics
@register_criterion("mmloss")
class MMCriterion(FairseqCriterion):
    """Fairseq criterion that delegates the loss computation to ``task.mmtask``."""

    def __init__(self, task):
        super().__init__(task)
        # TODO (huxu): wrap forward call of loss_fn and eval_fn into task.
        self.mmtask = task.mmtask

    def forward(self, model, sample):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        outputs = self.mmtask(model, sample)
        loss = outputs["loss"]
        batch_size = outputs["batch_size"]
        logging_output = {
            "loss": outputs["loss_scalar"],
            "ntokens": outputs["max_len"] * batch_size,  # dummy report.
            "nsentences": batch_size,  # dummy report.
            "sample_size": outputs["sample_size"],
        }
        # The gradient denominator is fixed to 1; the task-reported
        # sample size is only forwarded through the logging output.
        return loss, 1, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training.

        Since we use NCE, our actual batch_size is 1 per GPU.
        Then we take the mean of each worker.
        """
        loss_sum = sum(log.get("loss", 0.0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        # Fall back to a denominator of 1 when no samples were reported
        # (e.g. an empty list of logging outputs) instead of dividing by 0.
        metrics.log_scalar("loss", loss_sum / (sample_size or 1), round=3)

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True
# Copyright (c) Facebook, Inc. All Rights Reserved
import torch
from torch import nn
class Loss(object):
    """Base class for callable losses; subclasses must override ``__call__``."""

    def __call__(self, *args, **kwargs):
        # No default behaviour: calling the base class directly is an error.
        raise NotImplementedError()
# Dummy Loss for testing.
class DummyLoss(Loss):
    """Cross-entropy between ``logits`` and ``targets``; extra kwargs are ignored."""

    def __init__(self):
        self.loss = nn.CrossEntropyLoss()

    def __call__(self, logits, targets, **kwargs):
        loss_value = self.loss(logits, targets)
        return loss_value
class DummyK400Loss(Loss):
    """dummy k400 loss for MViT.

    The ``targets`` argument is ignored; random class ids in ``[0, 400)``
    are drawn on every call, one per row of ``logits``.
    """

    def __init__(self):
        self.loss = nn.CrossEntropyLoss()

    def __call__(self, logits, targets, **kwargs):
        num_rows = logits.size(0)
        random_targets = torch.randint(
            0, 400, (num_rows,), device=logits.device)
        return self.loss(logits, random_targets)
class CrossEntropy(Loss):
    """Cross-entropy over flattened inputs.

    ``logits`` is reshaped to ``(-1, last_dim)`` and ``targets`` to ``(-1,)``
    before the loss is applied.
    """

    def __init__(self):
        self.loss = nn.CrossEntropyLoss()

    def __call__(self, logits, targets, **kwargs):
        num_classes = logits.size(-1)
        flat_logits = logits.reshape(-1, num_classes)
        flat_targets = targets.reshape(-1)
        return self.loss(flat_logits, flat_targets)
class ArgmaxCrossEntropy(Loss):
    """Cross-entropy whose class index is the argmax of ``targets`` along dim 1."""

    def __init__(self):
        self.loss = nn.CrossEntropyLoss()

    def __call__(self, logits, targets, **kwargs):
        class_ids = targets.argmax(dim=1)
        return self.loss(logits, class_ids)
class BCE(Loss):
    """Binary cross-entropy with logits; ``targets`` has dim 0 squeezed first."""

    def __init__(self):
        self.loss = nn.BCEWithLogitsLoss()

    def __call__(self, logits, targets, **kwargs):
        # squeeze(0) only removes dim 0 when it has size 1.
        squeezed_targets = targets.squeeze(0)
        return self.loss(logits, squeezed_targets)
class NLGLoss(Loss):
    """Cross-entropy against the entries of ``text_label`` that are not -100.

    ``logits`` is used as given, so it must already have one row per kept
    label.
    """

    def __init__(self):
        self.loss = nn.CrossEntropyLoss()

    def __call__(self, logits, text_label, **kwargs):
        keep = text_label != -100
        kept_labels = text_label[keep]
        return self.loss(logits, kept_labels)
class MSE(Loss):
    """Mean squared error between ``logits`` and ``targets``."""

    def __init__(self):
        self.loss = nn.MSELoss()

    def __call__(self, logits, targets, **kwargs):
        loss_value = self.loss(logits, targets)
        return loss_value
class L1(Loss):
    """Mean absolute error (``nn.L1Loss``) between ``logits`` and ``targets``."""

    def __init__(self):
        self.loss = nn.L1Loss()

    def __call__(self, logits, targets, **kwargs):
        loss_value = self.loss(logits, targets)
        return loss_value
class SmoothL1(Loss):
    """Smooth L1 loss (``nn.SmoothL1Loss``) between ``logits`` and ``targets``."""

    def __init__(self):
        self.loss = nn.SmoothL1Loss()

    def __call__(self, logits, targets, **kwargs):
        loss_value = self.loss(logits, targets)
        return loss_value
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
softmax-based NCE loss, used by this project.
"""
import torch
from torch import nn
from .loss import Loss
class NCE(Loss):
    """Softmax NCE over alignment scores.

    The first half of the rows of ``align_scores`` are treated as positives
    and the second half as negatives; every positive is contrasted against
    all negatives, with the positive placed at class index 0.
    """

    def __init__(self):
        # TODO (huxu): define temperature.
        self.loss = nn.CrossEntropyLoss()

    def __call__(self, align_scores, **kargs):
        # note: we reuse the same shape as cls head in BERT (batch_size, 2)
        # but NCE only needs one logit, so only the first column is kept.
        first_logit = align_scores[:, :1]
        half = first_logit.size(0) // 2
        positives = first_logit[:half]
        # duplicate negative examples: every row sees all negatives.
        negatives = first_logit[half:].view(1, half).repeat(half, 1)
        scores = torch.cat([positives, negatives], dim=1)
        targets = torch.zeros(
            (half,),
            dtype=torch.long,
            device=first_logit.device)
        return self.loss(scores, targets)
class T2VContraLoss(Loss):
    """NCE for MM joint space, on softmax text2video matrix.

    Row ``i`` of the text-by-video similarity matrix is classified against
    target index ``i``.
    """

    def __init__(self):
        # TODO (huxu): define temperature.
        self.loss = nn.CrossEntropyLoss()

    def __call__(self, pooled_video, pooled_text, **kargs):
        num_pairs = pooled_video.size(0)
        diagonal_targets = torch.arange(
            num_pairs, dtype=torch.long, device=pooled_video.device)
        text_to_video = torch.mm(pooled_text, pooled_video.transpose(1, 0))
        return self.loss(text_to_video, diagonal_targets)
class V2TContraLoss(Loss):
    """NCE for MM joint space, with softmax on video2text matrix.

    Row ``i`` of the video-by-text similarity matrix is classified against
    target index ``i``.
    """

    def __init__(self):
        # TODO (huxu): define temperature.
        self.loss = nn.CrossEntropyLoss()

    def __call__(self, pooled_video, pooled_text, **kargs):
        num_pairs = pooled_video.size(0)
        diagonal_targets = torch.arange(
            num_pairs, dtype=torch.long, device=pooled_video.device)
        video_to_text = torch.mm(pooled_video, pooled_text.transpose(1, 0))
        return self.loss(video_to_text, diagonal_targets)
class MMContraLoss(Loss):
    """Sum of the video-to-text and text-to-video contrastive losses.

    Both directions use the diagonal (index ``i`` for row ``i``) as target.
    """

    def __init__(self):
        self.loss = nn.CrossEntropyLoss()

    def __call__(self, pooled_video, pooled_text, **kwargs):
        diagonal_targets = torch.arange(
            pooled_video.size(0),
            dtype=torch.long,
            device=pooled_video.device)
        loss_video = self.loss(pooled_video @ pooled_text.t(), diagonal_targets)
        loss_text = self.loss(pooled_text @ pooled_video.t(), diagonal_targets)
        return loss_video + loss_text
class MTM(Loss):
    """Combination of MFM and MLM.

    Video and text logits are stacked along dim 0 and scored with a single
    cross-entropy call.
    """

    def __init__(self):
        self.loss = nn.CrossEntropyLoss()

    def __call__(
        self,
        video_logits,
        text_logits,
        video_label,
        text_label,
        **kwargs
    ):
        """Return the joint cross-entropy over video and text logits.

        Args:
            video_logits: 2-D logits for the video rows.
            text_logits: 2-D logits for the text rows; one zero column is
                appended before stacking with ``video_logits``.
            video_label: accepted for interface compatibility but not read;
                the video target is always class index 0.
            text_label: labels for text; entries equal to -100 are dropped.
        """
        # Build the extra column with the dtype of text_logits so the
        # concatenation also works when the logits are not float32.
        pad_column = torch.zeros(
            (text_logits.size(0), 1),
            dtype=text_logits.dtype,
            device=text_logits.device,
        )
        text_logits = torch.cat([text_logits, pad_column], dim=1)
        vt_logits = torch.cat([video_logits, text_logits], dim=0)
        # loss for video.
        video_label = torch.zeros(
            (video_logits.size(0),),
            dtype=torch.long,
            device=video_logits.device
        )
        # loss for text.
        text_label = text_label.reshape(-1)
        labels_mask = text_label != -100
        selected_text_label = text_label[labels_mask]
        vt_label = torch.cat([video_label, selected_text_label], dim=0)
        return self.loss(vt_logits, vt_label)
class MFMMLM(Loss):
    """Combination of MFM and MLM.

    The two cross-entropy losses are computed separately and added.
    """

    def __init__(self):
        self.loss = nn.CrossEntropyLoss()

    def __call__(
        self,
        video_logits,
        text_logits,
        video_label,
        text_label,
        **kwargs
    ):
        # loss for video: the target is always class index 0; the incoming
        # ``video_label`` is not read.
        frame_targets = torch.zeros(
            (video_logits.size(0),),
            dtype=torch.long,
            device=video_logits.device
        )
        masked_frame_loss = self.loss(video_logits, frame_targets)
        # loss for text: keep only labels that are not -100.
        flat_text_label = text_label.reshape(-1)
        kept_text_label = flat_text_label[flat_text_label != -100]
        masked_lm_loss = self.loss(text_logits, kept_text_label)
        return masked_frame_loss + masked_lm_loss
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .mmfusion import *
from .transformermodel import *
from .mmfusionnlg import *
try:
from .fairseqmmmodel import *
except ImportError:
pass
try:
from .expmmfusion import *
except ImportError:
pass
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq.models import (
BaseFairseqModel,
register_model,
register_model_architecture
)
@register_model("mmmodel")
class FairseqMMModel(BaseFairseqModel):
    """a fairseq wrapper of model built by `task`."""

    @classmethod
    def build_model(cls, args, task):
        """Wrap the model already built on ``task.mmtask``."""
        return FairseqMMModel(task.mmtask.model)

    def __init__(self, mmmodel):
        super().__init__()
        self.mmmodel = mmmodel

    def forward(self, *args, **kwargs):
        """Forward all arguments to the wrapped model."""
        return self.mmmodel(*args, **kwargs)

    def upgrade_state_dict_named(self, state_dict, name):
        """Make ``state_dict`` match this model's keys, in place.

        Keys unknown to this model are removed; keys this model defines
        but the checkpoint lacks are filled with the model's current values.
        """
        super().upgrade_state_dict_named(state_dict, name)
        # Build this model's state dict once instead of on every lookup.
        own_state = self.state_dict()
        stale_keys = [key for key in state_dict if key not in own_state]
        for key in stale_keys:
            print("[INFO]", key, "not used anymore.")
            del state_dict[key]
        # copy any newly defined parameters.
        for key, value in own_state.items():
            if key not in state_dict:
                print("[INFO] adding", key)
                state_dict[key] = value
# a dummy arch, we config the model.
@register_model_architecture("mmmodel", "mmarch")
def mmarch(args):
    """Dummy architecture hook: ``args`` is left untouched."""
    return None
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) Facebook, Inc. All Rights Reserved
import torch
from torch import nn
try:
from transformers import AutoConfig, AutoTokenizer
except ImportError:
pass
from . import transformermodel
class MMPTModel(nn.Module):
    """An e2e wrapper of inference model.

    Runs a video encoder on raw frames, pads the resulting features to
    ``config.dataset.max_video_len`` and feeds them with the caption inputs
    to the wrapped multimodal model.
    """

    @classmethod
    def from_pretrained(cls, config, checkpoint="checkpoint_best.pt"):
        """Build ``(model, tokenizer, aligner)`` from a config.

        Args:
            config: passed to ``recursive_config``; the result must expose
                ``eval.save_path`` and the ``dataset`` fields read below.
            checkpoint: file name joined onto ``config.eval.save_path``.
        """
        import os
        from ..utils import recursive_config
        from ..tasks import Task
        config = recursive_config(config)
        mmtask = Task.config_task(config)
        checkpoint_path = os.path.join(config.eval.save_path, checkpoint)
        mmtask.build_model(checkpoint=checkpoint_path)
        # TODO(huxu): make the video encoder configurable.
        from ..processors.models.s3dg import S3D
        video_encoder = S3D('pretrained_models/s3d_dict.npy', 512)
        # Load onto CPU first so the weights can be read on machines without
        # the device they were saved from; load_state_dict copies them into
        # the encoder's own parameters.
        video_encoder.load_state_dict(
            torch.load(
                'pretrained_models/s3d_howto100m.pth', map_location="cpu"))
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            config.dataset.bert_name, use_fast=config.dataset.use_fast
        )
        from ..processors import Aligner
        aligner = Aligner(config.dataset)
        return (
            MMPTModel(config, mmtask.model, video_encoder),
            tokenizer,
            aligner
        )

    def __init__(self, config, model, video_encoder, **kwargs):
        super().__init__()
        self.max_video_len = config.dataset.max_video_len
        self.video_encoder = video_encoder
        self.model = model

    def forward(self, video_frames, caps, cmasks, return_score=False):
        """Encode frames, pad to ``max_video_len`` and run the wrapped model.

        Args:
            video_frames: 6-D tensor whose first two dims are batch (must
                be 1) and number of clips; the last dim is moved to
                position 1 before the video encoder is called.
            caps, cmasks: forwarded unchanged to the wrapped model.
            return_score: when true, return ``{"score": ...}`` holding the
                dot product of ``pooled_video`` and ``pooled_text``.
        """
        bsz = video_frames.size(0)
        assert bsz == 1, "only bsz=1 is supported now."
        seq_len = video_frames.size(1)
        video_frames = video_frames.view(-1, *video_frames.size()[2:])
        vfeats = self.video_encoder(video_frames.permute(0, 4, 1, 2, 3))
        vfeats = vfeats['video_embedding']
        vfeats = vfeats.view(bsz, seq_len, vfeats.size(-1))
        pad_len = self.max_video_len - seq_len
        # Allocate padding and masks on the same device (and dtype) as the
        # features so the concatenations also work off-CPU.
        padding = torch.zeros(
            bsz, pad_len, vfeats.size(-1),
            dtype=vfeats.dtype, device=vfeats.device)
        vfeats = torch.cat([vfeats, padding], dim=1)
        vmasks = torch.cat([
            torch.ones(
                (bsz, seq_len), dtype=torch.bool, device=vfeats.device),
            torch.zeros(
                (bsz, pad_len), dtype=torch.bool, device=vfeats.device)
        ],
            dim=1
        )
        output = self.model(caps, cmasks, vfeats, vmasks)
        if return_score:
            output = {"score": torch.bmm(
                output["pooled_video"][:, None, :],
                output["pooled_text"][:, :, None]
            ).squeeze(-1).squeeze(-1)}
        return output
class MMFusion(nn.Module):
    """a MMPT wrapper class for MMBert style models.

    Builds either one multimodal encoder (``self.mm_encoder``) or a pair of
    video/text encoders (``self.video_encoder`` / ``self.text_encoder``)
    from the config, and provides mask and pooling helpers for subclasses.
    ``forward`` must be implemented by a subclass.

    TODO: move isolated mask to a subclass.
    """

    def __init__(self, config, **kwargs):
        """Create the encoder(s) named in ``config.model``.

        Raises:
            ValueError: if neither ``mm_encoder_cls`` nor both of
                ``video_encoder_cls`` / ``text_encoder_cls`` are set.
        """
        super().__init__()
        transformer_config = AutoConfig.from_pretrained(
            config.dataset.bert_name)
        self.hidden_size = transformer_config.hidden_size
        # Training mode is inferred from the presence of a train path.
        self.is_train = False
        if config.dataset.train_path is not None:
            self.is_train = True
        # 0 means no iso; 1-12 means iso up to that layer.
        self.num_hidden_layers = transformer_config.num_hidden_layers
        self.last_iso_layer = 0
        if config.dataset.num_iso_layer is not None:
            self.last_iso_layer = config.dataset.num_iso_layer - 1 + 1
        if config.model.mm_encoder_cls is not None:
            # Single shared multimodal encoder, looked up by class name.
            mm_encoder_cls = getattr(transformermodel, config.model.mm_encoder_cls)
            model_config = AutoConfig.from_pretrained(config.dataset.bert_name)
            model_config.max_video_len = config.dataset.max_video_len
            # TODO: a general way to add parameter for a model.
            model_config.use_seg_emb = config.model.use_seg_emb
            self.mm_encoder = mm_encoder_cls.from_pretrained(
                config.dataset.bert_name, config=model_config)
        elif config.model.video_encoder_cls is not None\
                and config.model.text_encoder_cls is not None:
            # Two separate backbones: one for video, one for text.
            video_encoder_cls = getattr(transformermodel, config.model.video_encoder_cls)
            model_config = AutoConfig.from_pretrained(config.dataset.bert_name)
            model_config.max_video_len = config.dataset.max_video_len
            # TODO: make each model a set of config class.
            # The layer-count attribute differs between config classes.
            if hasattr(model_config, "num_layers"):
                model_config.num_layers = config.model.num_hidden_video_layers
            else:
                model_config.num_hidden_layers = config.model.num_hidden_video_layers
            self.video_encoder = video_encoder_cls.from_pretrained(
                config.dataset.bert_name, config=model_config)
            # exact same NLP model from Huggingface.
            text_encoder_cls = getattr(transformermodel, config.model.text_encoder_cls)
            self.text_encoder = text_encoder_cls.from_pretrained(
                config.dataset.bert_name)
        else:
            raise ValueError("the encoder must be either MM or two backbones.")

    def forward(
        self,
        caps,
        cmasks,
        vfeats,
        vmasks,
        **kwargs
    ):
        """Abstract; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "Please derive MMFusion module."
        )

    def _mm_on_the_fly(
        self,
        cmasks,
        vmasks,
        attention_mask
    ):
        """helper function for mask, seg_ids and token_type_ids.

        Returns ``(attention_mask, token_type_ids)``. The attention mask is
        built from ``cmasks``/``vmasks`` only when ``attention_mask`` is None.
        ``token_type_ids`` is 0 for the first ``vmasks.size(1) + 2`` columns
        and 1 for the remaining ``cmasks.size(1) - 2`` columns.
        """
        if attention_mask is None:
            attention_mask = self._mm_attention_mask(cmasks, vmasks)
        """
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence | second sequence |
        """
        token_type_ids = torch.cat(
            [
                torch.zeros(
                    (vmasks.size(0), vmasks.size(1) + 2),
                    dtype=torch.long,
                    device=vmasks.device,
                ),
                torch.ones(
                    (cmasks.size(0), cmasks.size(1) - 2),
                    dtype=torch.long,
                    device=cmasks.device,
                ),
            ],
            dim=1,
        )
        return attention_mask, token_type_ids

    def _mm_attention_mask(self, cmasks, vmasks):
        """Build the joint attention mask.

        The base mask is the first column of ``cmasks``, then ``vmasks``,
        then the rest of ``cmasks``. When ``last_iso_layer`` is 0 that 2-D
        mask is returned as is. Otherwise a 4-D mask is returned in which
        dim 1 holds ``last_iso_layer`` copies of the isolated mask followed
        by the expanded joint mask for the remaining hidden layers.
        """
        assert cmasks.size(0) == vmasks.size(0), "{}, {}, {}, {}".format(
            str(cmasks.size()),
            str(vmasks.size()),
            str(cmasks.size(0)),
            str(vmasks.size(0)),
        )
        mm_mask = torch.cat([cmasks[:, :1], vmasks, cmasks[:, 1:]], dim=1)
        if self.last_iso_layer == 0:
            # hard attention mask.
            return mm_mask
        else:
            # a gpu iso mask; 0 : num_iso_layer is isolated;
            # num_iso_layer: are MM-fused.
            # make an iso layer
            batch_size = cmasks.size(0)
            iso_mask = self._make_iso_mask(batch_size, cmasks, vmasks)
            # expand the joint mask so every query row shares the same key mask.
            mm_mask = mm_mask[:, None, :].repeat(1, mm_mask.size(-1), 1)
            iso_mm_masks = []
            # hard attention mask.
            iso_mask = iso_mask[:, None, :, :].repeat(
                1, self.last_iso_layer, 1, 1)
            iso_mm_masks.append(iso_mask)
            if self.last_iso_layer < self.num_hidden_layers:
                mm_mask = mm_mask[:, None, :, :].repeat(
                    1, self.num_hidden_layers - self.last_iso_layer, 1, 1
                )
                iso_mm_masks.append(mm_mask)
            iso_mm_masks = torch.cat(iso_mm_masks, dim=1)
            return iso_mm_masks

    def _make_iso_mask(self, batch_size, cmasks, vmasks):
        """Build the per-row mask used for the isolated layers.

        Rows are stacked along dim 1: one row where only position 0 is
        enabled, then ``vmasks.size(1) + 1`` rows using the video key mask,
        then ``cmasks.size(1) - 2`` rows using the text key mask.
        """
        cls_self_mask = torch.cat(
            [
                torch.ones(
                    (batch_size, 1), dtype=torch.bool, device=cmasks.device),
                torch.zeros(
                    (batch_size, cmasks.size(1) + vmasks.size(1) - 1),
                    dtype=torch.bool, device=cmasks.device)
            ], dim=1)
        iso_video_mask = torch.cat(
            [
                # [CLS] is not used.
                torch.zeros(
                    (batch_size, 1), dtype=torch.bool, device=cmasks.device
                ),
                vmasks,
                # assume to be 1.
                cmasks[:, 1:2],
                # 2 means [CLS] + [SEP]
                torch.zeros(
                    (batch_size, cmasks.size(1) - 2),
                    dtype=torch.bool,
                    device=cmasks.device,
                ),
            ],
            dim=1,
        )
        iso_text_mask = torch.cat(
            [
                torch.zeros(
                    (batch_size, 2 + vmasks.size(1)),
                    dtype=torch.bool,
                    device=cmasks.device,
                ),  # [CLS] is not used.
                cmasks[:, 2:],  # assume to be 1.
            ],
            dim=1,
        )
        cls_self_mask = cls_self_mask[:, None, :]
        iso_video_mask = iso_video_mask[:, None, :].repeat(
            1, vmasks.size(1) + 1, 1)
        iso_text_mask = iso_text_mask[:, None, :].repeat(
            1, cmasks.size(1) - 2, 1)
        return torch.cat([cls_self_mask, iso_video_mask, iso_text_mask], dim=1)

    def _pooling_vt_layer(
        self,
        layered_sequence_output,
        cmasks,
        vmasks
    ):
        """Mask-weighted mean pooling of one layer's hidden states.

        The layer index is ``last_iso_layer`` when it is positive, otherwise
        ``num_hidden_layers``. Returns ``(pooled_video, pooled_text)``.
        NOTE(review): a row whose mask sums to 0 divides by zero here --
        confirm callers always provide at least one valid position.
        """
        layer_idx = self.last_iso_layer \
            if self.last_iso_layer > 0 else self.num_hidden_layers
        hidden_state = layered_sequence_output[layer_idx]
        # also output pooled_video and pooled_text.
        batch_size = cmasks.size(0)
        # pool the modality.
        text_offset = vmasks.size(1) + 2  # [CLS] + [SEP]
        # video tokens + [SEP]
        video_outputs = hidden_state[:, 1:text_offset]
        video_attention_mask = torch.cat(
            [
                vmasks,
                torch.ones(
                    (batch_size, 1), dtype=torch.bool, device=vmasks.device),
            ],
            dim=1,
        )
        assert video_outputs.size(1) == video_attention_mask.size(1)
        pooled_video = torch.sum(
            video_outputs * video_attention_mask.unsqueeze(-1), dim=1
        ) / video_attention_mask.sum(1, keepdim=True)
        # pooled_video = torch.mean(video_outputs[0], dim=1)
        # text tokens + [SEP]
        text_attention_mask = cmasks[:, 2:]
        text_outputs = hidden_state[:, text_offset:]
        assert text_outputs.size(1) == text_attention_mask.size(1)
        pooled_text = torch.sum(
            text_outputs * text_attention_mask.unsqueeze(-1), dim=1
        ) / text_attention_mask.sum(1, keepdim=True)
        return pooled_video, pooled_text
class MMFusionMFMMLM(MMFusion):
    """forward function for MFM and MLM.

    In training mode the video/text logits are returned; otherwise the
    pooled video and text representations are returned.
    """

    def forward(
        self,
        caps,
        cmasks,
        vfeats,
        vmasks,
        attention_mask=None,
        video_label=None,
        text_label=None,
        **kwargs
    ):
        # Hidden states are only needed for pooling outside of training.
        want_hidden_states = not self.is_train

        target_vfeats = None
        non_masked_frame_mask = None
        if video_label is not None:
            feat_dim = vfeats.size(-1)
            # Keep the original features of the masked frames as targets.
            target_vfeats = vfeats.masked_select(
                video_label.unsqueeze(-1)).view(-1, feat_dim)
            # mask video token (in place, on the caller's tensor).
            vfeats[video_label] = 0.0
            non_masked_frame_mask = vmasks.clone()
            non_masked_frame_mask[video_label] = False

        attention_mask, token_type_ids = self._mm_on_the_fly(
            cmasks, vmasks, attention_mask)

        outputs = self.mm_encoder(
            input_ids=caps,
            input_video_embeds=vfeats,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            masked_frame_labels=video_label,
            target_video_hidden_states=target_vfeats,
            non_masked_frame_mask=non_masked_frame_mask,
            masked_lm_labels=text_label,
            output_hidden_states=want_hidden_states,
        )
        video_logits = outputs[0]
        text_logits = outputs[1]

        if self.is_train:  # return earlier for training.
            return {
                "video_logits": video_logits,
                "text_logits": text_logits,
            }

        pooled_video, pooled_text = self._pooling_vt_layer(
            outputs[2], cmasks, vmasks)
        return {"pooled_video": pooled_video, "pooled_text": pooled_text}
class MMFusionMTM(MMFusionMFMMLM):
    """MFM/MLM fusion model whose encoder is replaced by ``MMBertForMTM``."""

    def __init__(self, config, **kwargs):
        super().__init__(config)
        # For reproducibility:
        # self.mm_encoder will be initialized then discarded.
        from .transformermodel import MMBertForMTM

        bert_name = config.dataset.bert_name
        model_config = AutoConfig.from_pretrained(bert_name)
        model_config.max_video_len = config.dataset.max_video_len
        model_config.use_seg_emb = config.model.use_seg_emb
        self.mm_encoder = MMBertForMTM.from_pretrained(
            bert_name, config=model_config)
class MMFusionShare(MMFusion):
    """A retrieval wrapper using mm_encoder as both video/text backbone.
    TODO: move formally.
    """

    def forward(
        self,
        caps,
        cmasks,
        vfeats,
        vmasks,
        attention_mask=None,
        video_label=None,
        text_label=None,
        output_hidden_states=False,
        **kwargs
    ):
        """Encode video and text independently and return both results."""
        video_repr = self.forward_video(
            vfeats, vmasks, caps, cmasks, output_hidden_states)
        text_repr = self.forward_text(caps, cmasks, output_hidden_states)
        return {"pooled_video": video_repr, "pooled_text": text_repr}

    def forward_video(
        self,
        vfeats,
        vmasks,
        caps,
        cmasks,
        output_hidden_states=False,
        **kwargs
    ):
        """Run the shared encoder on the first two caption tokens plus video.

        Returns the encoder's first output when ``output_hidden_states`` is
        true, otherwise its mask-weighted mean over dim 1.
        """
        num_rows, num_frames = vmasks.size(0), vmasks.size(1)
        encoder_mask = torch.cat(
            [cmasks[:, :1], vmasks, cmasks[:, 1:2]], dim=1)
        segment_ids = torch.zeros(
            (num_rows, num_frames + 2),
            dtype=torch.long,
            device=vmasks.device)

        encoded = self.mm_encoder(
            input_ids=caps[:, :2],
            input_video_embeds=vfeats,
            attention_mask=encoder_mask,
            token_type_ids=segment_ids,
            output_hidden_states=True
        )
        video_outputs = encoded[0]
        if output_hidden_states:
            return video_outputs

        batch_size = cmasks.size(0)
        # Pooling mask: first position off, video mask, last position on.
        leading_off = torch.zeros(
            (batch_size, 1), dtype=torch.bool, device=vmasks.device)
        trailing_on = torch.ones(
            (batch_size, 1), dtype=torch.bool, device=vmasks.device)
        pool_mask = torch.cat([leading_off, vmasks, trailing_on], dim=1)
        assert video_outputs.size(1) == pool_mask.size(1)
        # Normalise so the batched matmul below yields a masked mean.
        pool_weights = pool_mask.type(video_outputs.dtype) \
            / pool_mask.sum(1, keepdim=True)
        pooled_video = torch.bmm(
            video_outputs.transpose(2, 1),
            pool_weights.unsqueeze(2)
        ).squeeze(-1)
        return pooled_video

    def forward_text(
        self,
        caps,
        cmasks,
        output_hidden_states=False,
        **kwargs
    ):
        """Run the shared encoder on the caption with column 1 removed.

        Returns the encoder's first output when ``output_hidden_states`` is
        true, otherwise its mask-weighted mean over dim 1.
        """
        # Drop column 1 from both the token ids and the mask.
        text_ids = torch.cat([caps[:, :1], caps[:, 2:]], dim=1)
        encoder_mask = torch.cat([cmasks[:, :1], cmasks[:, 2:]], dim=1)
        # Type 0 for the first position, type 1 for the rest.
        first_segment = torch.zeros(
            (cmasks.size(0), 1),
            dtype=torch.long,
            device=cmasks.device)
        rest_segment = torch.ones(
            (cmasks.size(0), cmasks.size(1) - 2),
            dtype=torch.long,
            device=cmasks.device)
        segment_ids = torch.cat([first_segment, rest_segment], dim=1)

        encoded = self.mm_encoder(
            input_ids=text_ids,
            input_video_embeds=None,
            attention_mask=encoder_mask,
            token_type_ids=segment_ids,
            output_hidden_states=True
        )
        text_outputs = encoded[0]
        if output_hidden_states:
            return text_outputs

        batch_size = caps.size(0)
        # text tokens + [SEP]; the first position is excluded from pooling.
        leading_off = torch.zeros(
            (batch_size, 1), dtype=torch.bool, device=cmasks.device)
        pool_mask = torch.cat([leading_off, cmasks[:, 2:]], dim=1)
        assert text_outputs.size(1) == pool_mask.size(1)
        pool_weights = pool_mask.type(text_outputs.dtype) \
            / pool_mask.sum(1, keepdim=True)
        pooled_text = torch.bmm(
            text_outputs.transpose(2, 1),
            pool_weights.unsqueeze(2)
        ).squeeze(-1)
        return pooled_text
class MMFusionSeparate(MMFusionShare):
    """Variant that uses ``video_encoder`` and ``text_encoder`` separately."""

    def forward_video(
        self,
        vfeats,
        vmasks,
        caps,
        cmasks,
        output_hidden_states=False,
        **kwargs
    ):
        """Run ``self.video_encoder`` on the first two caption tokens plus video.

        Returns the encoder's first output when ``output_hidden_states`` is
        true, otherwise its mask-weighted mean over dim 1.
        """
        num_rows, num_frames = vmasks.size(0), vmasks.size(1)
        encoder_mask = torch.cat(
            [cmasks[:, :1], vmasks, cmasks[:, 1:2]], dim=1)
        segment_ids = torch.zeros(
            (num_rows, num_frames + 2),
            dtype=torch.long,
            device=vmasks.device)

        encoded = self.video_encoder(
            input_ids=caps[:, :2],
            input_video_embeds=vfeats,
            attention_mask=encoder_mask,
            token_type_ids=segment_ids,
            output_hidden_states=True
        )
        video_outputs = encoded[0]
        if output_hidden_states:
            return video_outputs

        batch_size = cmasks.size(0)
        # Pooling mask: first position off, video mask, last position on.
        leading_off = torch.zeros(
            (batch_size, 1), dtype=torch.bool, device=vmasks.device)
        trailing_on = torch.ones(
            (batch_size, 1), dtype=torch.bool, device=vmasks.device)
        pool_mask = torch.cat([leading_off, vmasks, trailing_on], dim=1)
        assert video_outputs.size(1) == pool_mask.size(1)
        pool_weights = pool_mask.type(video_outputs.dtype) \
            / pool_mask.sum(1, keepdim=True)
        pooled_video = torch.bmm(
            video_outputs.transpose(2, 1),
            pool_weights.unsqueeze(2)
        ).squeeze(-1)
        return pooled_video

    def forward_text(
        self,
        caps,
        cmasks,
        output_hidden_states=False,
        **kwargs
    ):
        """Run ``self.text_encoder`` on the caption with column 1 removed.

        Returns the encoder's first output when ``output_hidden_states`` is
        true, otherwise its mask-weighted mean over dim 1.
        """
        text_ids = torch.cat([caps[:, :1], caps[:, 2:]], dim=1)
        encoder_mask = torch.cat([cmasks[:, :1], cmasks[:, 2:]], dim=1)
        # different from sharing, we use all-0 type.
        segment_ids = torch.zeros(
            (cmasks.size(0), cmasks.size(1) - 1),
            dtype=torch.long,
            device=cmasks.device)

        encoded = self.text_encoder(
            input_ids=text_ids,
            attention_mask=encoder_mask,
            token_type_ids=segment_ids,
            output_hidden_states=True
        )
        text_outputs = encoded[0]
        if output_hidden_states:
            return text_outputs

        batch_size = caps.size(0)
        # text tokens + [SEP]; the first position is excluded from pooling.
        leading_off = torch.zeros(
            (batch_size, 1), dtype=torch.bool, device=cmasks.device)
        pool_mask = torch.cat([leading_off, cmasks[:, 2:]], dim=1)
        assert text_outputs.size(1) == pool_mask.size(1)
        pool_weights = pool_mask.type(text_outputs.dtype) \
            / pool_mask.sum(1, keepdim=True)
        pooled_text = torch.bmm(
            text_outputs.transpose(2, 1),
            pool_weights.unsqueeze(2)
        ).squeeze(-1)
        return pooled_text
class MMFusionJoint(MMFusion):
    """fine-tuning wrapper for retrieval task."""

    def forward(
        self,
        caps,
        cmasks,
        vfeats,
        vmasks,
        attention_mask=None,
        video_label=None,
        text_label=None,
        **kwargs
    ):
        # TODO (huxu): other ways to do negative examples; move the following
        # into your criterion forward.
        attention_mask, token_type_ids = self._mm_on_the_fly(
            cmasks, vmasks, attention_mask)

        # Outside training the encoder is given the video/text boundary.
        if self.is_train:
            separate_forward_split = None
        else:
            separate_forward_split = vmasks.size(1) + 2  # [CLS] + [SEP]

        outputs = self.mm_encoder(
            input_ids=caps,
            input_video_embeds=vfeats,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_hidden_states=True,
            separate_forward_split=separate_forward_split,
        )

        pooled_video, pooled_text = self._pooling_vt_layer(
            outputs[2], cmasks, vmasks)
        return {"pooled_video": pooled_video, "pooled_text": pooled_text}
class MMFusionActionSegmentation(MMFusion):
    """Fine-tuning wrapper for action segmentation.
    TODO: rename this for VLM.
    """

    def forward(
        self,
        caps,
        cmasks,
        vfeats,
        vmasks,
        attention_mask=None,
        **kwargs
    ):
        # ActionLocalization assumes batch_size=1; merge the leading dims.
        caps = caps.view(-1, caps.size(-1))
        cmasks = cmasks.view(-1, cmasks.size(-1))
        vfeats = vfeats.view(-1, vfeats.size(2), vfeats.size(3))
        vmasks = vmasks.view(-1, vmasks.size(-1))

        # this may not cover all shapes of attention_mask.
        if attention_mask is not None:
            attention_mask = attention_mask.view(
                -1, attention_mask.size(2), attention_mask.size(3))

        # TODO (huxu): other ways to do negative examples; move the following
        # into your criterion forward.
        # video forwarding, text is dummy; never use attention_mask.
        attention_mask, token_type_ids = self._mm_on_the_fly(
            cmasks, vmasks, attention_mask)

        encoded = self.mm_encoder(
            input_ids=caps,
            input_video_embeds=vfeats,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_hidden_states=True,
        )
        # Keep only the video positions (skip position 0).
        last_video_pos = vmasks.size(1) + 1
        return {"logits": encoded[0][:, 1:last_video_pos]}
class MMFusionActionLocalization(MMFusion):
    """Fine-tuning model for action localization.

    Video and text are encoded in two separate passes of ``mm_encoder``,
    each paired with a dummy input for the other modality.
    """

    def __init__(self, config, **kwargs):
        super().__init__(config)
        tokenizer = AutoTokenizer.from_pretrained(
            config.dataset.bert_name)
        # Special token ids used to build the dummy caption.
        self.cls_token_id = tokenizer.cls_token_id
        self.sep_token_id = tokenizer.sep_token_id
        self.pad_token_id = tokenizer.pad_token_id

    def forward(
        self,
        caps,
        cmasks,
        vfeats,
        vmasks,
        attention_mask=None,
        **kwargs
    ):
        # ActionLocalization assumes batch_size=1; squeeze it.
        caps, cmasks = caps.squeeze(0), cmasks.squeeze(0)
        vfeats, vmasks = vfeats.squeeze(0), vmasks.squeeze(0)
        if attention_mask is not None:
            attention_mask = attention_mask.squeeze(0)

        # TODO (huxu): other ways to do negative examples; move the following
        # into your criterion forward.
        num_video_rows = vfeats.size(0)
        num_text_rows = caps.size(0)

        # --- video forwarding, text is dummy; never use attention_mask. ---
        dummy_caps = torch.tensor(
            [[self.cls_token_id, self.sep_token_id,
              self.pad_token_id, self.sep_token_id]],
            dtype=torch.long, device=caps.device,
        ).repeat(num_video_rows, 1)
        dummy_cmasks = torch.tensor(
            [[0, 1, 0, 1]],  # pad are valid for attention.
            dtype=torch.bool, device=caps.device,
        ).repeat(num_video_rows, 1)

        attention_mask, token_type_ids = self._mm_on_the_fly(
            dummy_cmasks, vmasks, None)
        outputs = self.mm_encoder(
            input_ids=dummy_caps,
            input_video_embeds=vfeats,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_hidden_states=True,
        )

        if self.last_iso_layer > 0:
            layer_idx = self.last_iso_layer
        else:
            layer_idx = self.num_hidden_layers

        # Keep the hidden states of the valid video positions only.
        video_hidden = outputs[2][layer_idx][:, 1:vmasks.size(1) + 1]
        video_seq = video_hidden.masked_select(
            vmasks.unsqueeze(-1)
        ).view(-1, self.hidden_size)

        # --- text forwarding, video is dummy (a len-1 dummy video token). ---
        dummy_vfeats = torch.zeros(
            (num_text_rows, 1, vfeats.size(-1)),
            device=vfeats.device, dtype=vfeats.dtype)
        dummy_vmasks = torch.ones(
            (num_text_rows, 1), dtype=torch.bool,
            device=vfeats.device)

        attention_mask, token_type_ids = self._mm_on_the_fly(
            cmasks, dummy_vmasks, None)
        outputs = self.mm_encoder(
            input_ids=caps,
            input_video_embeds=dummy_vfeats,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_hidden_states=True,
        )
        _, pooled_text = self._pooling_vt_layer(
            outputs[2], cmasks, dummy_vmasks)

        # this line is not right.
        logits = torch.mm(video_seq, pooled_text.transpose(1, 0))
        return {"logits": logits}
# --------------- MMFusionSeparate for end tasks ---------------
class MMFusionSeparateActionSegmentation(MMFusionSeparate):
    """Fine-tuning wrapper for action segmentation."""

    def forward(
        self,
        caps,
        cmasks,
        vfeats,
        vmasks,
        attention_mask=None,
        **kwargs
    ):
        # ActionLocalization assumes batch_size=1; merge the leading dims.
        caps = caps.view(-1, caps.size(-1))
        cmasks = cmasks.view(-1, cmasks.size(-1))
        vfeats = vfeats.view(-1, vfeats.size(2), vfeats.size(3))
        vmasks = vmasks.view(-1, vmasks.size(-1))

        video_states = self.forward_video(
            vfeats,
            vmasks,
            caps,
            cmasks,
            output_hidden_states=True
        )
        # Keep only the video positions (skip position 0).
        last_video_pos = vmasks.size(1) + 1
        return {"logits": video_states[:, 1:last_video_pos]}
class MMFusionSeparateActionLocalization(MMFusionSeparate):
    """Action localization on top of the separate video/text encoders."""

    def __init__(self, config, **kwargs):
        super().__init__(config)
        tokenizer = AutoTokenizer.from_pretrained(
            config.dataset.bert_name)
        # Special token ids used to build the dummy caption.
        self.cls_token_id = tokenizer.cls_token_id
        self.sep_token_id = tokenizer.sep_token_id
        self.pad_token_id = tokenizer.pad_token_id

    def forward(
        self,
        caps,
        cmasks,
        vfeats,
        vmasks,
        **kwargs
    ):
        # ActionLocalization assumes batch_size=1; squeeze it.
        caps, cmasks = caps.squeeze(0), cmasks.squeeze(0)
        vfeats, vmasks = vfeats.squeeze(0), vmasks.squeeze(0)

        # TODO (huxu): other ways to do negative examples; move the following
        # into your criterion forward.
        num_video_rows = vfeats.size(0)
        dummy_caps = torch.tensor(
            [[self.cls_token_id, self.sep_token_id,
              self.pad_token_id, self.sep_token_id]],
            dtype=torch.long, device=caps.device,
        ).repeat(num_video_rows, 1)
        dummy_cmasks = torch.tensor(
            [[0, 1, 0, 1]],  # pad are valid for attention.
            dtype=torch.bool, device=caps.device,
        ).repeat(num_video_rows, 1)

        video_states = self.forward_video(
            vfeats,
            vmasks,
            dummy_caps,
            dummy_cmasks,
            output_hidden_states=True
        )
        # Keep the hidden states of the valid video positions only.
        video_seq = video_states[:, 1:vmasks.size(1) + 1].masked_select(
            vmasks.unsqueeze(-1)
        ).view(-1, self.hidden_size)

        pooled_text = self.forward_text(
            caps,
            cmasks,
            output_hidden_states=False
        )

        # this line is not right.
        logits = torch.mm(video_seq, pooled_text.transpose(1, 0))
        return {"logits": logits}
class MMFusionShareActionLocalization(MMFusionShare):
    """Action localization on top of the shared ``mm_encoder`` backbone."""

    def __init__(self, config, **kwargs):
        super().__init__(config)
        tokenizer = AutoTokenizer.from_pretrained(
            config.dataset.bert_name)
        # Special token ids used to build the dummy caption.
        self.cls_token_id = tokenizer.cls_token_id
        self.sep_token_id = tokenizer.sep_token_id
        self.pad_token_id = tokenizer.pad_token_id

    def forward(
        self,
        caps,
        cmasks,
        vfeats,
        vmasks,
        **kwargs
    ):
        # ActionLocalization assumes batch_size=1; squeeze it.
        caps, cmasks = caps.squeeze(0), cmasks.squeeze(0)
        vfeats, vmasks = vfeats.squeeze(0), vmasks.squeeze(0)

        # TODO (huxu): other ways to do negative examples; move the following
        # into your criterion forward.
        num_video_rows = vfeats.size(0)
        dummy_caps = torch.tensor(
            [[self.cls_token_id, self.sep_token_id,
              self.pad_token_id, self.sep_token_id]],
            dtype=torch.long, device=caps.device,
        ).repeat(num_video_rows, 1)
        dummy_cmasks = torch.tensor(
            [[0, 1, 0, 1]],  # pad are valid for attention.
            dtype=torch.bool, device=caps.device,
        ).repeat(num_video_rows, 1)

        video_states = self.forward_video(
            vfeats,
            vmasks,
            dummy_caps,
            dummy_cmasks,
            output_hidden_states=True
        )
        # Keep the hidden states of the valid video positions only.
        video_seq = video_states[:, 1:vmasks.size(1) + 1].masked_select(
            vmasks.unsqueeze(-1)
        ).view(-1, self.hidden_size)

        pooled_text = self.forward_text(
            caps,
            cmasks,
            output_hidden_states=False
        )

        # this line is not right.
        logits = torch.mm(video_seq, pooled_text.transpose(1, 0))
        return {"logits": logits}
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) Facebook, Inc. All Rights Reserved
import torch
from torch.nn import functional as F
from typing import Optional, Iterable
try:
from transformers import BertPreTrainedModel
from transformers.modeling_bert import BertOnlyMLMHead
from transformers.file_utils import ModelOutput
from transformers.modeling_outputs import CausalLMOutput
from transformers.generation_utils import (
BeamHypotheses,
top_k_top_p_filtering
)
except ImportError:
pass
from .mmfusion import MMFusion
from .transformermodel import MMBertModel
from ..modules import VideoTokenMLP
class MMFusionNLG(MMFusion):
    """MMFusion variant that trains and generates text via ``self.mm_encoder``."""

    def __init__(self, config, **kwargs):
        """Derive the decoding length limit and generation options from config.

        ``max_length`` is ``max_len - max_video_len - 3`` from the dataset
        config, further capped by ``config.model.max_decode_length`` when
        that is set.  ``gen_param`` falls back to an empty dict.
        """
        super().__init__(config)
        decode_cap = config.model.max_decode_length
        text_budget = \
            config.dataset.max_len - config.dataset.max_video_len - 3
        if decode_cap is None:
            self.max_length = text_budget
        else:
            self.max_length = min(decode_cap, text_budget)

        gen_param = config.gen_param
        self.gen_param = {} if gen_param is None else gen_param

    def forward(
        self,
        caps,
        cmasks,
        vfeats,
        vmasks,
        attention_mask,
        video_label=None,
        text_label=None,
        **kwargs
    ):
        """Run the multimodal encoder with the pre-trained LM head.

        ``text_label`` is forwarded as ``masked_lm_labels``; ``video_label``
        is accepted but not used here.

        Returns:
            dict with key ``"logits"`` holding the first encoder output.
        """
        mm_attention_mask, mm_token_type_ids = self._mm_on_the_fly(
            cmasks, vmasks, attention_mask)
        encoder_outputs = self.mm_encoder(
            input_ids=caps,
            input_video_embeds=vfeats,
            attention_mask=mm_attention_mask,
            token_type_ids=mm_token_type_ids,
            masked_lm_labels=text_label,
        )
        return {"logits": encoder_outputs[0]}

    @torch.no_grad()
    def generate(
        self,
        caps, cmasks, vfeats, vmasks,
        attention_mask=None,
        bos_token_id=None,
        eos_token_id=None,
        **kwargs
    ):
        """Generate text by delegating to ``self.mm_encoder.generate``.

        A simplified interface modelled on
        https://huggingface.co/transformers/v3.4.0/_modules/transformers/generation_utils.html#GenerationMixin.generate

        ``caps`` must hold exactly three tokens per row:
        [CLS], [SEP] (for video) and [CLS] (as bos_token).
        Decoding options come from ``self.max_length`` and
        ``self.gen_param``; extra ``kwargs`` are ignored.
        """
        assert caps.size(1) == 3
        mm_attention_mask, mm_token_type_ids = self._mm_on_the_fly(
            cmasks, vmasks, attention_mask)
        return self.mm_encoder.generate(
            input_ids=caps,
            input_video_embeds=vfeats,
            attention_mask=mm_attention_mask,
            token_type_ids=mm_token_type_ids,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            max_length=self.max_length,
            **self.gen_param
        )
class MMBertForNLG(BertPreTrainedModel):
def __init__(self, config):
    """Create the MMBertModel backbone, the VideoTokenMLP and the MLM head."""
    super().__init__(config)
    self.hidden_size = config.hidden_size
    self.bert = MMBertModel(config)
    self.videomlp = VideoTokenMLP(config)
    # `BertGenerationOnlyLMHead` is deliberately not used: keeping the
    # MLM head lets pretraining be reused.
    self.cls = BertOnlyMLMHead(config)
    self.init_weights()
def get_output_embeddings(self):
    """Return the ``decoder`` attribute of the MLM prediction head."""
    prediction_head = self.cls.predictions
    return prediction_head.decoder
def forward(
    self,
    input_ids=None,
    input_video_embeds=None,
    attention_mask=None,
    token_type_ids=None,
    position_ids=None,
    head_mask=None,
    inputs_embeds=None,
    masked_lm_labels=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):
    """Encode video + text and apply the LM head.

    Two return paths exist:

    * ``return_dict`` falsy: returns a tuple whose first item is the LM
      head output on the positions where ``masked_lm_labels != -100``
      (or ``None`` when no labels are given), followed by
      ``outputs[2:]`` of the backbone.
    * ``return_dict`` truthy (generation): returns a ``CausalLMOutput``
      with ``loss=None`` and logits for every position from
      ``input_video_embeds.size(1) + 2`` onwards.
    """
    # similar to MMBertForMFMMLM without MFM.
    encoder_outputs = self.bert(
        input_ids,
        self.videomlp(input_video_embeds),
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    hidden_states = encoder_outputs[0]

    label_scores = None
    if masked_lm_labels is not None:
        first_text_pos = input_video_embeds.size(1) + 1  # [CLS]
        # recover caps format: [CLS] [SEP] text [SEP]
        text_states = torch.cat(
            (hidden_states[:, :1], hidden_states[:, first_text_pos:]),
            dim=1,
        )
        # only the labelled positions go through the head, to speed up
        # training.
        keep = (masked_lm_labels != -100).unsqueeze(-1)
        picked = text_states.masked_select(keep)
        label_scores = self.cls(picked.view(-1, text_states.size(-1)))

    if not return_dict:
        return (label_scores,) + encoder_outputs[2:]

    # for generation.
    first_gen_pos = input_video_embeds.size(1) + 2  # [CLS]
    return CausalLMOutput(
        loss=None,
        logits=self.cls(hidden_states[:, first_gen_pos:]),
    )
def prepare_inputs_for_generation(
    self,
    input_ids,
    input_video_embeds,
    attention_mask=None,
    token_type_ids=None,
    **model_kwargs
):
    """Assemble the keyword arguments for one decoding step.

    ``attention_mask`` and ``token_type_ids`` (when given) are cut down to
    the current total length, i.e. ``input_ids.size(1)`` plus
    ``input_video_embeds.size(1)``.  Any additional ``model_kwargs`` are
    accepted and dropped.

    Returns:
        dict with keys ``input_ids``, ``input_video_embeds``,
        ``attention_mask`` and ``token_type_ids``.
    """
    total_len = input_ids.size(1) + input_video_embeds.size(1)

    if attention_mask is not None:
        rank = attention_mask.dim()
        if rank == 4:
            attention_mask = attention_mask[:, :, :total_len, :total_len]
        elif rank == 3:
            attention_mask = attention_mask[:, :total_len, :total_len]
        else:
            attention_mask = attention_mask[:, :total_len]

    if token_type_ids is not None:
        token_type_ids = token_type_ids[:, :total_len]

    # must return a dictionary.
    step_inputs = {}
    step_inputs["input_ids"] = input_ids
    step_inputs["input_video_embeds"] = input_video_embeds
    step_inputs["attention_mask"] = attention_mask
    step_inputs["token_type_ids"] = token_type_ids
    return step_inputs
@torch.no_grad()
def generate(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    decoder_input_ids: Optional[torch.LongTensor] = None,
    max_length: Optional[int] = None,
    min_length: Optional[int] = None,
    do_sample: Optional[bool] = None,
    early_stopping: Optional[bool] = None,
    num_beams: Optional[int] = None,
    temperature: Optional[float] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    repetition_penalty: Optional[float] = None,
    bad_words_ids: Optional[Iterable[int]] = None,
    bos_token_id: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    length_penalty: Optional[float] = None,
    no_repeat_ngram_size: Optional[int] = None,
    num_return_sequences: Optional[int] = None,
    attention_mask: Optional[torch.LongTensor] = None,
    decoder_start_token_id: Optional[int] = None,
    use_cache: Optional[bool] = None,
    **model_kwargs
) -> torch.LongTensor:
    r"""
    Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
    beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling.

    Adapted in part from `Facebook's XLM beam search code
    <https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529>`__.

    Apart from :obj:`input_ids` and :obj:`attention_mask`, all the arguments below will default to the value of the
    attribute of the same name inside the :class:`~transformers.PretrainedConfig` of the model. The default values
    indicated are the default values of those config.

    Most of these parameters are explained in more detail in `this blog post
    <https://huggingface.co/blog/how-to-generate>`__.

    Note on the multimodal inputs: when ``num_beams > 1`` or ``num_return_sequences > 1``, the entries
    ``input_video_embeds`` and ``token_type_ids`` are popped from ``model_kwargs``, expanded together with
    :obj:`input_ids` and :obj:`attention_mask`, and stored back into ``model_kwargs``. That expansion reads
    ``attention_mask.size(2)``, so it expects a 3-D attention mask.

    Parameters:

        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            The sequence used as a prompt for the generation. If :obj:`None` the method initializes
            it as an empty :obj:`torch.LongTensor` of shape :obj:`(1,)`.
        decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            initial input_ids for the decoder of encoder-decoder type models. If :obj:`None` then only
            decoder_start_token_id is passed as the first token to the decoder.
        max_length (:obj:`int`, `optional`, defaults to 20):
            The maximum length of the sequence to be generated.
        min_length (:obj:`int`, `optional`, defaults to 10):
            The minimum length of the sequence to be generated.
        do_sample (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to use sampling ; use greedy decoding otherwise.
        early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not.
        num_beams (:obj:`int`, `optional`, defaults to 1):
            Number of beams for beam search. 1 means no beam search.
        temperature (:obj:`float`, `optional`, defaults to 1.0):
            The value used to modulate the next token probabilities.
        top_k (:obj:`int`, `optional`, defaults to 50):
            The number of highest probability vocabulary tokens to keep for top-k-filtering.
        top_p (:obj:`float`, `optional`, defaults to 1.0):
            If set to float < 1, only the most probable tokens with probabilities that add up to ``top_p`` or
            higher are kept for generation.
        repetition_penalty (:obj:`float`, `optional`, defaults to 1.0):
            The parameter for repetition penalty. 1.0 means no penalty. See `this paper
            <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
        pad_token_id (:obj:`int`, `optional`):
            The id of the `padding` token.
        bos_token_id (:obj:`int`, `optional`):
            The id of the `beginning-of-sequence` token.
        eos_token_id (:obj:`int`, `optional`):
            The id of the `end-of-sequence` token.
        length_penalty (:obj:`float`, `optional`, defaults to 1.0):
            Exponential penalty to the length. 1.0 means no penalty.

            Set to values < 1.0 in order to encourage the model to generate shorter sequences, to a value > 1.0 in
            order to encourage the model to produce longer sequences.
        no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
            If set to int > 0, all ngrams of that size can only occur once.
        bad_words_ids(:obj:`List[int]`, `optional`):
            List of token ids that are not allowed to be generated. In order to get the tokens of the words that
            should not appear in the generated text, use :obj:`tokenizer.encode(bad_word, add_prefix_space=True)`.
        num_return_sequences(:obj:`int`, `optional`, defaults to 1):
            The number of independently computed returned sequences for each element in the batch.
        attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on padding token indices. Mask values are in ``[0, 1]``, 1 for
            tokens that are not masked, and 0 for masked tokens.

            If not provided, will default to a tensor the same shape as :obj:`input_ids` that masks the pad token.

            `What are attention masks? <../glossary.html#attention-mask>`__
        decoder_start_token_id (:obj:`int`, `optional`):
            If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token.
        use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not the model should use the past last key/values attentions (if applicable to the model) to
            speed up decoding.
        model_kwargs:
            Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.

    Return:

        :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`:
        The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or
        shorter if all batches finished early due to the :obj:`eos_token_id`.

    Examples::

        tokenizer = AutoTokenizer.from_pretrained('distilgpt2')   # Initialize tokenizer
        model = AutoModelWithLMHead.from_pretrained('distilgpt2')    # Download model and configuration from S3 and cache.
        outputs = model.generate(max_length=40)  # do greedy decoding
        print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))

        tokenizer = AutoTokenizer.from_pretrained('openai-gpt')   # Initialize tokenizer
        model = AutoModelWithLMHead.from_pretrained('openai-gpt')    # Download model and configuration from S3 and cache.
        input_context = 'The dog'
        input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
        outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5)  # generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog'
        for i in range(3): #  3 output sequences were generated
            print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))

        tokenizer = AutoTokenizer.from_pretrained('distilgpt2')   # Initialize tokenizer
        model = AutoModelWithLMHead.from_pretrained('distilgpt2')    # Download model and configuration from S3 and cache.
        input_context = 'The dog'
        input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
        outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3, do_sample=True)  # generate 3 candidates using sampling
        for i in range(3): #  3 output sequences were generated
            print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))

        tokenizer = AutoTokenizer.from_pretrained('ctrl')   # Initialize tokenizer
        model = AutoModelWithLMHead.from_pretrained('ctrl')    # Download model and configuration from S3 and cache.
        input_context = 'Legal My neighbor is'  # "Legal" is one of the control codes for ctrl
        input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
        outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2)  # generate sequences
        print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))

        tokenizer = AutoTokenizer.from_pretrained('gpt2')   # Initialize tokenizer
        model = AutoModelWithLMHead.from_pretrained('gpt2')    # Download model and configuration from S3 and cache.
        input_context = 'My cute dog'
        bad_words_ids = [tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ['idiot', 'stupid', 'shut up']]
        input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
        outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids)  # generate sequences without allowing bad_words to be generated
    """

    # We cannot generate if the model does not have a LM head
    if self.get_output_embeddings() is None:
        raise AttributeError(
            "You tried to generate sequences with a model that does not have a LM Head."
            "Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`, `XLMWithLMHeadModel`, `BartForConditionalGeneration` )"
        )

    # Every option left as None falls back to the same-named attribute on self.config.
    max_length = max_length if max_length is not None else self.config.max_length
    min_length = min_length if min_length is not None else self.config.min_length
    do_sample = do_sample if do_sample is not None else self.config.do_sample
    early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    num_beams = num_beams if num_beams is not None else self.config.num_beams
    temperature = temperature if temperature is not None else self.config.temperature
    top_k = top_k if top_k is not None else self.config.top_k
    top_p = top_p if top_p is not None else self.config.top_p
    repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
    bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
    pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
    length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
    no_repeat_ngram_size = (
        no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
    )
    bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
    num_return_sequences = (
        num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
    )
    decoder_start_token_id = (
        decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
    )

    if input_ids is not None:
        batch_size = input_ids.shape[0]  # overridden by the input batch_size
    else:
        batch_size = 1

    # Argument validation. NOTE(review): these are plain asserts, so they
    # are skipped when Python runs with -O.
    assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
    assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
    assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
    assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
    assert isinstance(use_cache, bool), "`use_cache` should be a boolean."
    assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
    assert temperature > 0, "`temperature` should be strictly positive."
    assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
    assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
    assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
    assert input_ids is not None or (
        isinstance(bos_token_id, int) and bos_token_id >= 0
    ), "If input_ids is not defined, `bos_token_id` should be a positive integer."
    assert pad_token_id is None or (
        isinstance(pad_token_id, int) and (pad_token_id >= 0)
    ), "`pad_token_id` should be a positive integer."
    assert (eos_token_id is None) or (
        isinstance(eos_token_id, int) and (eos_token_id >= 0)
    ), "`eos_token_id` should be a positive integer."
    assert length_penalty > 0, "`length_penalty` should be strictly positive."
    assert (
        isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0
    ), "`no_repeat_ngram_size` should be a positive integer."
    assert (
        isinstance(num_return_sequences, int) and num_return_sequences > 0
    ), "`num_return_sequences` should be a strictly positive integer."
    assert (
        bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
    ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"

    # Without a prompt, start every sequence from a single bos token.
    if input_ids is None:
        assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
            "you should either supply a context to complete as `input_ids` input "
            "or a `bos_token_id` (integer >= 0) as a first token to start the generation."
        )
        input_ids = torch.full(
            (batch_size, 1),
            bos_token_id,
            dtype=torch.long,
            device=next(self.parameters()).device,
        )
    else:
        assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."

    # not allow to duplicate outputs when greedy decoding
    if do_sample is False:
        if num_beams == 1:
            # no_beam_search greedy generation conditions
            assert (
                num_return_sequences == 1
            ), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1"

        else:
            # beam_search greedy generation conditions
            assert (
                num_beams >= num_return_sequences
            ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"

    # create attention mask if necessary
    # TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140
    # NOTE(review): both fallbacks below produce a 2-D mask shaped like
    # input_ids, which the expansion block further down cannot handle.
    if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids):
        attention_mask = input_ids.ne(pad_token_id).long()
    elif attention_mask is None:
        attention_mask = input_ids.new_ones(input_ids.shape)

    # set pad_token_id to eos_token_id if not set. Important that this is done after
    # attention_mask is created
    if pad_token_id is None and eos_token_id is not None:
        print(
            "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
        )
        pad_token_id = eos_token_id

    # vocab size
    if hasattr(self.config, "vocab_size"):
        vocab_size = self.config.vocab_size
    elif (
        self.config.is_encoder_decoder
        and hasattr(self.config, "decoder")
        and hasattr(self.config.decoder, "vocab_size")
    ):
        vocab_size = self.config.decoder.vocab_size
    else:
        raise ValueError("either self.config.vocab_size or self.config.decoder.vocab_size needs to be defined")

    # set effective batch size and effective batch multiplier according to do_sample
    if do_sample:
        effective_batch_size = batch_size * num_return_sequences
        effective_batch_mult = num_return_sequences
    else:
        effective_batch_size = batch_size
        effective_batch_mult = 1

    if self.config.is_encoder_decoder:
        if decoder_start_token_id is None:
            # see if BOS token can be used for decoder_start_token_id
            if bos_token_id is not None:
                decoder_start_token_id = bos_token_id
            elif (
                hasattr(self.config, "decoder")
                and hasattr(self.config.decoder, "bos_token_id")
                and self.config.decoder.bos_token_id is not None
            ):
                decoder_start_token_id = self.config.decoder.bos_token_id
            else:
                raise ValueError(
                    "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
                )

        assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
        assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)

        # get encoder and store encoder outputs
        encoder = self.get_encoder()
        encoder_outputs: ModelOutput = encoder(input_ids, attention_mask=attention_mask, return_dict=True)

    # Expand input ids if num_beams > 1 or num_return_sequences > 1
    if num_return_sequences > 1 or num_beams > 1:
        # TODO: make this a call-back function.
        # input_ids=caps,
        # input_video_embeds=vfeats,
        # attention_mask=attention_mask,
        # token_type_ids=token_type_ids,
        # NOTE(review): both values default to None when missing from
        # model_kwargs but are dereferenced below, so callers must supply
        # them on this path.
        input_video_embeds = model_kwargs.pop("input_video_embeds", None)
        token_type_ids = model_kwargs.pop("token_type_ids", None)

        # Each tensor gets a new dim 1 of size effective_batch_mult * num_beams ...
        input_ids_len = input_ids.shape[-1]
        input_ids = input_ids.unsqueeze(1).expand(
            batch_size, effective_batch_mult * num_beams, input_ids_len)

        input_video_embeds_len, input_video_embeds_hidden = input_video_embeds.size(1), input_video_embeds.size(2)
        input_video_embeds = input_video_embeds.unsqueeze(1).expand(
            batch_size, effective_batch_mult * num_beams, input_video_embeds_len, input_video_embeds_hidden)

        # NOTE(review): size(2) is read, so attention_mask must be 3-D here.
        attention_mask_from_len, attention_mask_to_len = attention_mask.size(1), attention_mask.size(2)
        attention_mask = attention_mask.unsqueeze(1).expand(
            batch_size, effective_batch_mult * num_beams, attention_mask_from_len, attention_mask_to_len
        )

        token_type_ids_len = token_type_ids.size(1)
        token_type_ids = token_type_ids.unsqueeze(1).expand(
            batch_size, effective_batch_mult * num_beams, token_type_ids_len
        )

        # ... which is then folded into the batch dimension.
        # contiguous ...
        input_ids = input_ids.contiguous().view(
            effective_batch_size * num_beams, input_ids_len
        )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)

        input_video_embeds = input_video_embeds.contiguous().view(
            effective_batch_size * num_beams, input_video_embeds_len, input_video_embeds_hidden)

        attention_mask = attention_mask.contiguous().view(
            effective_batch_size * num_beams, attention_mask_from_len, attention_mask_to_len
        )  # shape: (batch_size * num_return_sequences * num_beams, from_len, to_len)

        token_type_ids = token_type_ids.contiguous().view(
            effective_batch_size * num_beams, token_type_ids_len
        )

        # Put the expanded tensors back so they reach the decoding loop.
        model_kwargs["input_video_embeds"] = input_video_embeds
        model_kwargs["token_type_ids"] = token_type_ids

    if self.config.is_encoder_decoder:
        device = next(self.parameters()).device
        if decoder_input_ids is not None:
            # give initial decoder input ids
            input_ids = decoder_input_ids.repeat(effective_batch_size * num_beams, 1).to(device)
        else:
            # create empty decoder input_ids
            input_ids = torch.full(
                (effective_batch_size * num_beams, 1),
                decoder_start_token_id,
                dtype=torch.long,
                device=device,
            )
        cur_len = input_ids.shape[-1]

        assert (
            batch_size == encoder_outputs.last_hidden_state.shape[0]
        ), f"expected encoder_outputs.last_hidden_state to have 1st dimension bs={batch_size}, got {encoder_outputs.last_hidden_state.shape[0]} "

        # expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1)
        expanded_batch_idxs = (
            torch.arange(batch_size)
            .view(-1, 1)
            .repeat(1, num_beams * effective_batch_mult)
            .view(-1)
            .to(input_ids.device)
        )

        # expand encoder_outputs
        encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
            0, expanded_batch_idxs
        )

        # save encoder_outputs in `model_kwargs`
        model_kwargs["encoder_outputs"] = encoder_outputs

    else:
        cur_len = input_ids.shape[-1]

    assert (
        cur_len < max_length
    ), f"The context has {cur_len} number of tokens, but `max_length` is only {max_length}. Please make sure that `max_length` is bigger than the number of tokens, by setting either `generate(max_length=...,...)` or `config.max_length = ...`"

    # Dispatch to the beam-search or the single-beam decoding loop.
    if num_beams > 1:
        output = self._generate_beam_search(
            input_ids,
            cur_len=cur_len,
            max_length=max_length,
            min_length=min_length,
            do_sample=do_sample,
            early_stopping=early_stopping,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            bad_words_ids=bad_words_ids,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            batch_size=effective_batch_size,
            num_return_sequences=num_return_sequences,
            length_penalty=length_penalty,
            num_beams=num_beams,
            vocab_size=vocab_size,
            attention_mask=attention_mask,
            use_cache=use_cache,
            model_kwargs=model_kwargs,
        )
    else:
        output = self._generate_no_beam_search(
            input_ids,
            cur_len=cur_len,
            max_length=max_length,
            min_length=min_length,
            do_sample=do_sample,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            bad_words_ids=bad_words_ids,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            batch_size=effective_batch_size,
            attention_mask=attention_mask,
            use_cache=use_cache,
            model_kwargs=model_kwargs,
        )

    return output
def _generate_beam_search(
    self,
    input_ids,
    cur_len,
    max_length,
    min_length,
    do_sample,
    early_stopping,
    temperature,
    top_k,
    top_p,
    repetition_penalty,
    no_repeat_ngram_size,
    bad_words_ids,
    pad_token_id,
    eos_token_id,
    batch_size,
    num_return_sequences,
    length_penalty,
    num_beams,
    vocab_size,
    attention_mask,
    use_cache,
    model_kwargs,
):
    """Generate sequences for each example with beam search.

    Runs the model step by step until ``cur_len`` reaches ``max_length`` or
    every example is marked done, keeping ``num_beams`` candidates per
    example and collecting finished hypotheses in ``BeamHypotheses``.

    Args:
        input_ids: token ids with one row per beam; indexed as
            ``batch_idx * num_beams + beam_id``.
        cur_len: current sequence length; incremented once per step.
        batch_size: number of examples (beams are counted separately).
        attention_mask: passed unchanged to
            ``prepare_inputs_for_generation`` at every step.
        model_kwargs: extra keyword arguments forwarded to
            ``prepare_inputs_for_generation``.
        The remaining arguments are the decoding options resolved by
        ``generate``.

    Returns:
        A tensor created via ``input_ids.new`` of shape
        ``(output_batch_size, sent_max_len)`` holding the best hypotheses,
        each followed by ``eos_token_id`` when it fits within
        ``max_length``; shorter rows are filled with ``pad_token_id``.
    """

    # generated hypotheses
    generated_hyps = [
        BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping)
        for _ in range(batch_size)
    ]

    # scores for each sentence in the beam
    beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)

    # for greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times
    if do_sample is False:
        beam_scores[:, 1:] = -1e9
    beam_scores = beam_scores.view(-1)  # shape (batch_size * num_beams,)

    # cache compute states
    past = None

    # done sentences
    done = [False for _ in range(batch_size)]

    while cur_len < max_length:
        model_inputs = self.prepare_inputs_for_generation(
            input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
        )
        outputs = self(**model_inputs, return_dict=True)  # (batch_size * num_beams, cur_len, vocab_size)
        # Only the last position's logits are used to pick the next token.
        next_token_logits = outputs.logits[:, -1, :]  # (batch_size * num_beams, vocab_size)

        # if model has past, then set the past variable to speed up decoding
        if "past_key_values" in outputs:
            past = outputs.past_key_values
        elif "mems" in outputs:
            past = outputs.mems

        if self.config.is_encoder_decoder and do_sample is False:
            # TODO (PVP) still a bit hacky here - there might be a better solution
            next_token_logits = self.adjust_logits_during_generation(
                next_token_logits, cur_len=cur_len, max_length=max_length
            )

        scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)

        scores = self.postprocess_next_token_scores(
            scores=scores,
            input_ids=input_ids,
            no_repeat_ngram_size=no_repeat_ngram_size,
            bad_words_ids=bad_words_ids,
            cur_len=cur_len,
            min_length=min_length,
            max_length=max_length,
            eos_token_id=eos_token_id,
            repetition_penalty=repetition_penalty,
            batch_size=batch_size,
            num_beams=num_beams,
        )

        assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(
            scores.shape, (batch_size * num_beams, vocab_size)
        )

        if do_sample:
            _scores = scores + beam_scores[:, None].expand_as(scores)  # (batch_size * num_beams, vocab_size)
            # Temperature
            if temperature != 1.0:
                _scores = _scores / temperature
            # Top-p/top-k filtering
            _scores = top_k_top_p_filtering(
                _scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
            )  # (batch_size * num_beams, vocab_size)
            # re-organize to group the beam together to sample from all beam_idxs
            _scores = _scores.contiguous().view(
                batch_size, num_beams * vocab_size
            )  # (batch_size, num_beams * vocab_size)

            # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
            probs = F.softmax(_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)  # (batch_size, num_beams * 2)
            # Compute next scores
            next_scores = torch.gather(_scores, -1, next_tokens)  # (batch_size, num_beams * 2)
            # sort the sampled vector to make sure that the first num_beams samples are the best
            next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1)
            next_tokens = torch.gather(next_tokens, -1, next_scores_indices)  # (batch_size, num_beams * 2)

        else:
            next_scores = scores + beam_scores[:, None].expand_as(scores)  # (batch_size * num_beams, vocab_size)

            # re-organize to group the beam together (we are keeping top hypothesis across beams)
            next_scores = next_scores.view(
                batch_size, num_beams * vocab_size
            )  # (batch_size, num_beams * vocab_size)

            next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)

        assert next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams)

        # next batch beam content: (score, token_id, source beam row) triples
        next_batch_beam = []

        # for each sentence
        for batch_idx in range(batch_size):

            # if we are done with this sentence, add a pad token
            if done[batch_idx]:
                assert (
                    len(generated_hyps[batch_idx]) >= num_beams
                ), "Batch can only be done if at least {} beams have been generated".format(num_beams)
                assert (
                    eos_token_id is not None and pad_token_id is not None
                ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
                next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams)  # pad the batch
                continue

            # next sentence beam content, this will get added to next_batch_beam
            next_sent_beam = []

            # next tokens for this sentence
            for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
                zip(next_tokens[batch_idx], next_scores[batch_idx])
            ):
                # get beam and token IDs: candidates index the flattened
                # (num_beams * vocab_size) axis, so split the index back up.
                beam_id = beam_token_id // vocab_size
                token_id = beam_token_id % vocab_size

                effective_beam_id = batch_idx * num_beams + beam_id
                # add to generated hypotheses if end of sentence
                if (eos_token_id is not None) and (token_id.item() == eos_token_id):
                    # if beam_token does not belong to top num_beams tokens, it should not be added
                    is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
                    if is_beam_token_worse_than_top_num_beams:
                        continue
                    generated_hyps[batch_idx].add(
                        input_ids[effective_beam_id].clone(),
                        beam_token_score.item(),
                    )
                else:
                    # add next predicted token since it is not eos_token
                    next_sent_beam.append((beam_token_score, token_id, effective_beam_id))

                # once the beam for next step is full, don't add more tokens to it.
                if len(next_sent_beam) == num_beams:
                    break

            # Check if we are done so that we can save a pad step if all(done)
            done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
                next_scores[batch_idx].max().item(), cur_len
            )

            # update next beam content
            assert len(next_sent_beam) == num_beams, "Beam should always be full"
            next_batch_beam.extend(next_sent_beam)
            assert len(next_batch_beam) == num_beams * (batch_idx + 1), "We should have added num_beams each step"

        # stop when we are done with each sentence
        if all(done):
            break

        # sanity check / prepare next batch
        assert len(next_batch_beam) == batch_size * num_beams
        beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
        beam_tokens = input_ids.new([x[1] for x in next_batch_beam])
        beam_idx = input_ids.new([x[2] for x in next_batch_beam])

        # re-order batch and update current length
        input_ids = input_ids[beam_idx, :]
        input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)
        cur_len = cur_len + 1

        # re-order internal states
        if past is not None:
            past = self._reorder_cache(past, beam_idx)

        # extend attention_mask for new generated input if only decoder
        # (huxu): move out since we trim attention_mask by ourselves.
        # The upstream step below is intentionally disabled:
        # if self.config.is_encoder_decoder is False:
        #     attention_mask = torch.cat(
        #         [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
        #     )

    # finalize all open beam hypotheses and add to generated hypotheses
    for batch_idx in range(batch_size):
        if done[batch_idx]:
            continue

        # test that beam scores match previously calculated scores if not eos and batch_idx not done
        if eos_token_id is not None and all(
            (token_id % vocab_size).item() != eos_token_id for token_id in next_tokens[batch_idx]
        ):
            assert torch.all(
                next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx]
            ), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format(
                next_scores[:, :num_beams][batch_idx],
                beam_scores.view(batch_size, num_beams)[batch_idx],
            )

        # need to add best num_beams hypotheses to generated hyps
        for beam_id in range(num_beams):
            effective_beam_id = batch_idx * num_beams + beam_id
            final_score = beam_scores[effective_beam_id].item()
            final_tokens = input_ids[effective_beam_id]
            generated_hyps[batch_idx].add(final_tokens, final_score)

    # depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch
    output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
    output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences

    # select the best hypotheses
    sent_lengths = input_ids.new(output_batch_size)
    best = []

    # retrieve best hypotheses
    for i, hypotheses in enumerate(generated_hyps):
        # ascending by score, so pop() yields the best remaining hypothesis
        sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
        for j in range(output_num_return_sequences_per_batch):
            effective_batch_idx = output_num_return_sequences_per_batch * i + j
            best_hyp = sorted_hyps.pop()[1]
            sent_lengths[effective_batch_idx] = len(best_hyp)
            best.append(best_hyp)

    # prepare for adding eos: one extra slot, capped at max_length
    sent_max_len = min(sent_lengths.max().item() + 1, max_length)
    decoded = input_ids.new(output_batch_size, sent_max_len)
    # shorter batches are padded if needed
    if sent_lengths.min().item() != sent_lengths.max().item():
        assert pad_token_id is not None, "`pad_token_id` has to be defined"
        decoded.fill_(pad_token_id)

    # fill with hypotheses and eos_token_id if the latter fits in
    for i, hypo in enumerate(best):
        decoded[i, : sent_lengths[i]] = hypo
        if sent_lengths[i] < max_length:
            decoded[i, sent_lengths[i]] = eos_token_id

    return decoded
def _generate_no_beam_search(
    self,
    input_ids,
    cur_len,
    max_length,
    min_length,
    do_sample,
    temperature,
    top_k,
    top_p,
    repetition_penalty,
    no_repeat_ngram_size,
    bad_words_ids,
    pad_token_id,
    eos_token_id,
    batch_size,
    attention_mask,
    use_cache,
    model_kwargs,
):
    """Generate sequences for each example without beam search (num_beams == 1).

    All returned sequences are generated independently. One token per
    sequence is appended on every step until ``cur_len`` reaches
    ``max_length`` or, when ``eos_token_id`` is set, until every sequence
    has produced it.

    Args:
        input_ids: prompt token ids; new tokens are concatenated on the
            last dimension.
        cur_len: current sequence length, incremented once per step.
        max_length: generation stops once ``cur_len`` reaches this value.
        min_length, repetition_penalty, no_repeat_ngram_size, bad_words_ids:
            forwarded to ``self.postprocess_next_token_scores``.
        do_sample: sample from the filtered distribution when true,
            otherwise pick the highest-scoring token.
        temperature, top_k, top_p: sampling controls, only used when
            ``do_sample`` is true.
        pad_token_id: token appended to sequences that already emitted
            ``eos_token_id``.
        eos_token_id: end-of-sequence id, or ``None`` to always run until
            ``max_length``.
        batch_size: number of sequences being generated.
        attention_mask, use_cache, model_kwargs: forwarded to
            ``self.prepare_inputs_for_generation``.

    Returns:
        ``input_ids`` extended with the generated tokens.
    """
    # 1 while a sequence is still being generated, 0 once it emitted EOS.
    unfinished_sents = input_ids.new(batch_size).fill_(1)

    past = None
    while cur_len < max_length:
        model_inputs = self.prepare_inputs_for_generation(
            input_ids,
            past=past,
            attention_mask=attention_mask,
            use_cache=use_cache,
            **model_kwargs
        )
        outputs = self(**model_inputs, return_dict=True)
        next_token_logits = outputs.logits[:, -1, :]

        scores = self.postprocess_next_token_scores(
            scores=next_token_logits,
            input_ids=input_ids,
            no_repeat_ngram_size=no_repeat_ngram_size,
            bad_words_ids=bad_words_ids,
            cur_len=cur_len,
            min_length=min_length,
            max_length=max_length,
            eos_token_id=eos_token_id,
            repetition_penalty=repetition_penalty,
            batch_size=batch_size,
            num_beams=1,
        )

        # if model has past, then set the past variable to speed up decoding
        if "past_key_values" in outputs:
            past = outputs.past_key_values
        elif "mems" in outputs:
            past = outputs.mems

        if do_sample:
            # Temperature (higher temperature => more likely to sample
            # low probability tokens)
            if temperature != 1.0:
                scores = scores / temperature
            # Top-p/top-k filtering
            next_token_logscores = top_k_top_p_filtering(
                scores, top_k=top_k, top_p=top_p
            )
            # Sample
            probs = F.softmax(next_token_logscores, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            # Greedy decoding. Use the post-processed scores rather than
            # the raw logits so the constraints applied above are honoured
            # even if the post-processor returns a new tensor instead of
            # editing the logits in place.
            next_token = torch.argmax(scores, dim=-1)

        # update generations and finished sentences
        if eos_token_id is not None:
            # pad finished sentences if eos_token_id exist
            tokens_to_add = (
                next_token * unfinished_sents
                + (pad_token_id) * (1 - unfinished_sents)
            )
        else:
            tokens_to_add = next_token

        # add token and increase length by one
        input_ids = torch.cat(
            [input_ids, tokens_to_add.unsqueeze(-1)], dim=-1
        )
        cur_len = cur_len + 1

        if eos_token_id is not None:
            eos_in_sents = tokens_to_add == eos_token_id
            # unfinished_sents is set to zero if eos in sentence
            unfinished_sents.mul_((~eos_in_sents).long())

        # stop when there is a </s> in each sentence, or if we exceed the
        # maximum length
        if unfinished_sents.max() == 0:
            break

        # NOTE(review): attention_mask is passed unchanged on every step;
        # the code that appended a column of ones per generated token for
        # decoder-only models was commented out in the original. Confirm
        # this is intended.

    return input_ids
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) Facebook, Inc. All Rights Reserved
import torch
from torch import nn
try:
from transformers.modeling_bert import (
BertPreTrainedModel,
BertModel,
BertEncoder,
BertPredictionHeadTransform,
)
except ImportError:
pass
from ..modules import VideoTokenMLP, MMBertEmbeddings
# --------------- fine-tuning models ---------------
class MMBertForJoint(BertPreTrainedModel):
    """A BertModel with isolated attention mask to separate modality.

    Video features are projected by ``VideoTokenMLP`` and then encoded
    together with the text tokens by ``MMBertModel``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.videomlp = VideoTokenMLP(config)
        self.bert = MMBertModel(config)
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        input_video_embeds=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        next_sentence_label=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        separate_forward_split=None,
    ):
        """Project the video features and run the joint encoder.

        ``next_sentence_label`` is accepted but not used here. Returns
        whatever ``self.bert`` returns.
        """
        return_dict = (
            return_dict if return_dict is not None
            else self.config.use_return_dict
        )
        # `input_video_embeds` defaults to None; skip the projection in
        # that case (as MMBertForEncoder does) instead of feeding None to
        # the MLP.
        video_tokens = (
            self.videomlp(input_video_embeds)
            if input_video_embeds is not None
            else None
        )

        outputs = self.bert(
            input_ids,
            video_tokens,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            separate_forward_split=separate_forward_split,
        )
        return outputs
class MMBertForTokenClassification(BertPreTrainedModel):
    """A BertModel similar to MMJointUni, with extra wrapper layer
    to be fine-tuned from other pretrained MMFusion model.

    The classifier width is read from ``config.num_classes`` when the
    config defines it and otherwise falls back to 779 (the number of
    classes for COIN).
    """

    # Number of classes for COIN; used when the config has no `num_classes`.
    DEFAULT_NUM_CLASSES = 779

    def __init__(self, config):
        super().__init__(config)
        self.videomlp = VideoTokenMLP(config)
        self.bert = MMBertModel(config)
        # NOTE(review): this dropout is created but never applied in
        # `forward`; it is kept so construction stays unchanged.
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        num_classes = getattr(
            config, "num_classes", self.DEFAULT_NUM_CLASSES)
        self.classifier = nn.Linear(config.hidden_size, num_classes)
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        input_video_embeds=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        next_sentence_label=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        separate_forward_split=None,
    ):
        """Encode video + text and classify every output position.

        ``next_sentence_label`` is accepted but not used here.

        Returns:
            A 1-tuple holding the classifier output computed from the
            first element returned by ``self.bert``.
        """
        return_dict = (
            return_dict if return_dict is not None
            else self.config.use_return_dict
        )
        # `input_video_embeds` defaults to None; skip the projection in
        # that case instead of feeding None to the MLP.
        video_tokens = (
            self.videomlp(input_video_embeds)
            if input_video_embeds is not None
            else None
        )
        outputs = self.bert(
            input_ids,
            video_tokens,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            separate_forward_split=separate_forward_split,
        )
        return (self.classifier(outputs[0]),)
# ------------ pre-training models ----------------
class MMBertForEncoder(BertPreTrainedModel):
    """A BertModel for Contrastive Learning.

    Wraps ``MMBertModel`` and, when video features are supplied, projects
    them with ``VideoTokenMLP`` before encoding.
    """

    def __init__(self, config):
        super().__init__(config)
        self.videomlp = VideoTokenMLP(config)
        self.bert = MMBertModel(config)
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        input_video_embeds=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run the encoder on text, optionally preceded by video tokens.

        Returns whatever ``self.bert`` returns.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Text-only calls pass no video features; forward None unchanged.
        video_tokens = (
            None
            if input_video_embeds is None
            else self.videomlp(input_video_embeds)
        )

        return self.bert(
            input_ids,
            video_tokens,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
class MMBertForMFMMLM(BertPreTrainedModel):
    """A BertModel with shared prediction head on MFM-MLM."""

    def __init__(self, config):
        super().__init__(config)
        self.videomlp = VideoTokenMLP(config)
        self.bert = MMBertModel(config)
        self.cls = MFMMLMHead(config)
        self.hidden_size = config.hidden_size
        self.init_weights()

    def get_output_embeddings(self):
        """Return the vocabulary decoder of the prediction head."""
        return self.cls.predictions.decoder

    def forward(
        self,
        input_ids=None,
        input_video_embeds=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        masked_frame_labels=None,
        target_video_hidden_states=None,
        non_masked_frame_mask=None,
        masked_lm_labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Encode video + text and score the selected video/text positions.

        The prediction head is only evaluated when both
        ``masked_frame_labels`` and ``masked_lm_labels`` are given; in that
        case ``target_video_hidden_states`` (together with
        ``input_video_embeds`` and ``non_masked_frame_mask``) is required.

        Returns:
            ``(mfm_scores, prediction_scores) + encoder outputs``; the two
            scores are ``None`` when the head was not evaluated.

        Raises:
            ValueError: if the inputs needed by the prediction head are
                missing.
        """
        return_dict = (
            return_dict if return_dict is not None
            else self.config.use_return_dict
        )
        if input_video_embeds is not None:
            video_tokens = self.videomlp(input_video_embeds)
        else:
            video_tokens = None

        # Always bound, so a missing target surfaces as a clear error
        # below instead of an UnboundLocalError.
        non_masked_frame_hidden_states = None
        if target_video_hidden_states is not None:
            if video_tokens is None or non_masked_frame_mask is None:
                raise ValueError(
                    "`input_video_embeds` and `non_masked_frame_mask` are "
                    "required when `target_video_hidden_states` is given"
                )
            target_video_hidden_states = self.videomlp(
                target_video_hidden_states)

            non_masked_frame_hidden_states = video_tokens.masked_select(
                non_masked_frame_mask.unsqueeze(-1)
            ).view(-1, self.hidden_size)

        compute_scores = (
            masked_frame_labels is not None and masked_lm_labels is not None
        )
        if compute_scores and non_masked_frame_hidden_states is None:
            # Fail before the (expensive) encoder pass.
            raise ValueError(
                "`target_video_hidden_states` is required when both "
                "`masked_frame_labels` and `masked_lm_labels` are given"
            )

        outputs = self.bert(
            input_ids,
            video_tokens,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        mfm_scores, prediction_scores = None, None
        if compute_scores:
            # split the sequence.
            text_offset = masked_frame_labels.size(1) + 1  # [CLS]
            video_sequence_output = sequence_output[
                :, 1:text_offset
            ]  # remove [SEP] as not in video_label.
            text_sequence_output = torch.cat(
                [sequence_output[:, :1], sequence_output[:, text_offset:]],
                dim=1
            )

            hidden_size = sequence_output.size(-1)
            selected_video_output = video_sequence_output.masked_select(
                masked_frame_labels.unsqueeze(-1)
            ).view(-1, hidden_size)

            # only compute select tokens to training to speed up.
            labels_mask = masked_lm_labels != -100

            selected_text_output = text_sequence_output.masked_select(
                labels_mask.unsqueeze(-1)
            ).view(-1, hidden_size)
            mfm_scores, prediction_scores = self.cls(
                selected_video_output,
                target_video_hidden_states,
                non_masked_frame_hidden_states,
                selected_text_output,
            )

        output = (
            mfm_scores,
            prediction_scores,
        ) + outputs
        return output
class BertMFMMLMPredictionHead(nn.Module):
    """Prediction head producing video logits and vocabulary logits."""

    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The decoder has no bias of its own; a separate per-token bias
        # parameter is attached to it below.
        self.decoder = nn.Linear(
            config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Link the two so the bias is resized together with the decoder
        # by `resize_token_embeddings`.
        self.decoder.bias = self.bias

    def forward(
        self,
        video_hidden_states=None,
        target_video_hidden_states=None,
        non_masked_frame_hidden_states=None,
        text_hidden_states=None,
    ):
        """Score video states against frames and text states on the vocab.

        Returns:
            ``(video_logits, text_logits)``. Each is ``None`` when its
            hidden states were not given. ``video_logits`` holds the
            row-wise dot product with the target first, followed by the
            dot products with every non-masked frame.
        """
        video_logits = None
        if video_hidden_states is not None:
            video_feats = self.transform(video_hidden_states)
            # One logit per row: dot product with that row's own target.
            target_logits = torch.bmm(
                video_feats.unsqueeze(1),
                target_video_hidden_states.unsqueeze(-1),
            ).squeeze(-1)
            # Dot product of every row with every non-masked frame.
            frame_logits = torch.mm(
                video_feats,
                non_masked_frame_hidden_states.transpose(1, 0),
            )
            video_logits = torch.cat([target_logits, frame_logits], dim=1)

        text_logits = None
        if text_hidden_states is not None:
            text_logits = self.decoder(self.transform(text_hidden_states))

        return video_logits, text_logits
class MFMMLMHead(nn.Module):
    """Thin wrapper exposing the MFM-MLM head as ``self.predictions``."""

    def __init__(self, config):
        super().__init__()
        self.predictions = BertMFMMLMPredictionHead(config)

    def forward(
        self,
        video_hidden_states=None,
        target_video_hidden_states=None,
        non_masked_frame_hidden_states=None,
        text_hidden_states=None,
    ):
        """Delegate to ``self.predictions``.

        Returns the ``(video_logits, text_logits)`` pair it produces.
        """
        head_outputs = self.predictions(
            video_hidden_states,
            target_video_hidden_states,
            non_masked_frame_hidden_states,
            text_hidden_states,
        )
        video_logits, text_logits = head_outputs
        return video_logits, text_logits
class MMBertForMTM(MMBertForMFMMLM):
    """``MMBertForMFMMLM`` variant whose prediction head is ``MTMHead``.

    The forward pass is inherited unchanged; only construction differs.
    """

    def __init__(self, config):
        # Call BertPreTrainedModel.__init__ directly rather than the
        # parent's __init__, so the parent's MFMMLMHead is never built.
        BertPreTrainedModel.__init__(self, config)
        self.hidden_size = config.hidden_size
        self.videomlp = VideoTokenMLP(config)
        self.bert = MMBertModel(config)
        self.cls = MTMHead(config)
        self.init_weights()
class BertMTMPredictionHead(nn.Module):
    """Prediction head scoring video and text states on frames and vocab."""

    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)
        self.decoder = nn.Linear(
            config.hidden_size, config.vocab_size, bias=False)

    def forward(
        self,
        video_hidden_states=None,
        target_video_hidden_states=None,
        non_masked_frame_hidden_states=None,
        text_hidden_states=None,
    ):
        """Return ``(video_logits, text_logits)``.

        ``video_logits`` concatenates, in order: the row-wise dot product
        with the target, the dot products with the non-masked frames, and
        the vocabulary logits. ``text_logits`` concatenates the vocabulary
        logits followed by the dot products with the non-masked frames.
        Either is ``None`` when its hidden states were not given.
        """
        # Transposed once up front; both branches multiply against it.
        frames_t = non_masked_frame_hidden_states.transpose(1, 0)

        video_logits = None
        if video_hidden_states is not None:
            video_feats = self.transform(video_hidden_states)
            video_parts = [
                torch.bmm(
                    video_feats.unsqueeze(1),
                    target_video_hidden_states.unsqueeze(-1),
                ).squeeze(-1),
                torch.mm(video_feats, frames_t),
                self.decoder(video_feats),
            ]
            video_logits = torch.cat(video_parts, dim=1)

        text_logits = None
        if text_hidden_states is not None:
            text_feats = self.transform(text_hidden_states)
            # text first so label does not need to be shifted.
            text_parts = [
                self.decoder(text_feats),
                torch.mm(text_feats, frames_t),
            ]
            text_logits = torch.cat(text_parts, dim=1)

        return video_logits, text_logits
class MTMHead(nn.Module):
    """Thin wrapper exposing the MTM head as ``self.predictions``."""

    def __init__(self, config):
        super().__init__()
        self.predictions = BertMTMPredictionHead(config)

    def forward(
        self,
        video_hidden_states=None,
        target_video_hidden_states=None,
        non_masked_frame_hidden_states=None,
        text_hidden_states=None,
    ):
        """Delegate to ``self.predictions``.

        Returns the ``(video_logits, text_logits)`` pair it produces.
        """
        head_outputs = self.predictions(
            video_hidden_states,
            target_video_hidden_states,
            non_masked_frame_hidden_states,
            text_hidden_states,
        )
        video_logits, text_logits = head_outputs
        return video_logits, text_logits
class MMBertModel(BertModel):
    """MMBertModel has MMBertEmbedding to support video tokens."""

    def __init__(self, config, add_pooling_layer=True):
        # NOTE(review): `add_pooling_layer` is accepted but not forwarded
        # to BertModel.__init__; confirm whether that is intended.
        super().__init__(config)
        # overwrite embedding
        self.embeddings = MMBertEmbeddings(config)
        self.encoder = MultiLayerAttentionMaskBertEncoder(config)
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        input_video_embeds=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        separate_forward_split=None,
    ):
        """Embed text (and optional video tokens) and run the encoder.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given.
        When ``separate_forward_split`` is set, the embedded sequence is
        cut at that position, the two parts are encoded independently
        (each with the matching block of the attention mask) and the
        results are concatenated along the sequence dimension.

        Returns:
            A tuple ``(sequence_output, pooled_output) + extras`` where
            ``extras`` are the remaining encoder outputs.

        Raises:
            ValueError: if both or neither of ``input_ids`` and
                ``inputs_embeds`` are given.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None
            else self.config.use_return_dict
        )

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids "
                "and inputs_embeds at the same time"
            )
        if input_ids is not None:
            text_input = input_ids
        elif inputs_embeds is not None:
            # Read sizes from `inputs_embeds` here; `input_ids` is None on
            # this path.
            text_input = inputs_embeds
        else:
            raise ValueError(
                "You have to specify either input_ids or inputs_embeds")

        video_len = (
            input_video_embeds.size(1)
            if input_video_embeds is not None
            else 0
        )
        input_shape = (text_input.size(0), text_input.size(1) + video_len)
        device = text_input.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(
                input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions
        # [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case
        # we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = \
            self.get_extended_attention_mask(
                attention_mask, input_shape, device)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to
        # [batch_size, num_heads, seq_length, seq_length]
        if self.config.is_decoder and encoder_hidden_states is not None:
            (
                encoder_batch_size,
                encoder_sequence_length,
                _,
            ) = encoder_hidden_states.size()
            encoder_hidden_shape = (
                encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(
                    encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(
                encoder_attention_mask
            )
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or
        # [num_hidden_layers x num_heads]
        # and head_mask is converted to shape
        # [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(
            head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids,
            input_video_embeds,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
        )

        def run_encoder(hidden_states, mask):
            # Every encoder call shares all arguments except the hidden
            # states and the self-attention mask.
            return self.encoder(
                hidden_states,
                attention_mask=mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_extended_attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        if separate_forward_split is None:
            encoder_outputs = run_encoder(
                embedding_output, extended_attention_mask)
        else:
            # Encode the part before and the part after the split point
            # independently, each with its own diagonal block of the mask.
            segment_outputs = []
            for segment in (
                slice(None, separate_forward_split),
                slice(separate_forward_split, None),
            ):
                segment_output = run_encoder(
                    embedding_output[:, segment],
                    extended_attention_mask[:, :, :, segment, segment],
                )
                assert (
                    len(segment_output) <= 2
                ), "we do not support merge on attention for now."
                segment_outputs.append(segment_output)

            first, second = segment_outputs
            merged = [torch.cat([first[0], second[0]], dim=1)]
            if len(first) == 2 and len(second) == 2:
                # Per-layer hidden states are merged layer by layer.
                merged.append([
                    torch.cat(layer_pair, dim=1)
                    for layer_pair in zip(first[1], second[1])
                ])
            encoder_outputs = tuple(merged)

        sequence_output = encoder_outputs[0]
        pooled_output = (
            self.pooler(sequence_output) if self.pooler is not None else None
        )

        return (sequence_output, pooled_output) + encoder_outputs[1:]

    def get_extended_attention_mask(self, attention_mask, input_shape, device):
        """This is borrowed from `modeling_utils.py` with the support of
        multi-layer attention masks.
        The second dim is expected to be number of layers.
        See `MMAttentionMaskProcessor`.
        Makes broadcastable attention and causal masks so that future
        and masked tokens are ignored.

        Arguments:
            attention_mask (:obj:`torch.Tensor`):
                Mask with ones indicating tokens to attend to,
                zeros for tokens to ignore.
            input_shape (:obj:`Tuple[int]`):
                The shape of the input to the model.
            device: (:obj:`torch.device`):
                The device of the input to the model.

        Returns:
            :obj:`torch.Tensor` The extended attention mask. For a 4-D
            input it is cast to :obj:`self.dtype`; other inputs are
            handled by the parent implementation.
        """
        if attention_mask.dim() != 4:
            return super().get_extended_attention_mask(
                attention_mask, input_shape, device
            )

        # 4-D mask: insert a broadcastable heads dimension after the
        # layers dimension.
        extended_attention_mask = attention_mask[:, :, None, :, :]
        extended_attention_mask = extended_attention_mask.to(
            dtype=self.dtype
        )  # fp16 compatibility
        # 1 -> 0.0 (keep), 0 -> -10000.0 (effectively removed by softmax).
        extended_attention_mask = (1.0 - extended_attention_mask) \
            * -10000.0
        return extended_attention_mask
class MultiLayerAttentionMaskBertEncoder(BertEncoder):
    """extend BertEncoder with the capability of
    multiple layers of attention mask."""

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=False,
    ):
        """Run every layer, optionally with a different mask per layer.

        A 5-D ``attention_mask`` is indexed on its second dimension to get
        each layer's mask; any other mask (including ``None``) is passed
        to every layer unchanged.

        Returns:
            A tuple of the final hidden states, followed by the
            per-layer hidden states and the attentions when requested.
            ``return_dict`` is accepted but a tuple is always returned.
        """
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # `attention_mask` defaults to None, so check before `.dim()`.
        per_layer_mask = (
            attention_mask is not None and attention_mask.dim() == 5
        )
        use_checkpoint = getattr(
            self.config, "gradient_checkpointing", False)
        if use_checkpoint:
            # Imported here because this file only imports `torch` itself.
            from torch.utils.checkpoint import checkpoint

        def create_custom_forward(module):
            # The checkpoint call passes tensors positionally, so the
            # `output_attentions` flag is bound here.
            def custom_forward(*inputs):
                return module(*inputs, output_attentions)
            return custom_forward

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            layer_head_mask = head_mask[i] if head_mask is not None else None

            layer_attention_mask = (
                attention_mask[:, i, :, :, :]
                if per_layer_mask
                else attention_mask
            )

            if use_checkpoint:
                layer_outputs = checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    layer_attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    layer_attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions,
                )
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        return tuple(
            v
            for v in [hidden_states, all_hidden_states, all_attentions]
            if v is not None
        )
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .mm import *
try:
from .expmm import *
except ImportError:
pass
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) Facebook, Inc. All Rights Reserved
import torch
from torch import nn
try:
from transformers.modeling_bert import (
BertEmbeddings,
ACT2FN,
)
except ImportError:
pass
class VideoTokenMLP(nn.Module):
    """Two-layer MLP projecting video features to ``config.hidden_size``.

    The input width is ``config.input_dim`` when the config defines it and
    512 otherwise.
    """

    def __init__(self, config):
        super().__init__()
        input_dim = getattr(config, "input_dim", 512)
        hidden_size = config.hidden_size
        self.linear1 = nn.Linear(input_dim, hidden_size)
        self.LayerNorm = nn.LayerNorm(hidden_size)
        self.activation = ACT2FN[config.hidden_act]
        self.linear2 = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states):
        """Apply linear1 -> activation -> LayerNorm -> linear2."""
        projected = self.activation(self.linear1(hidden_states))
        return self.linear2(self.LayerNorm(projected))
class MMBertEmbeddings(BertEmbeddings):
    """BertEmbeddings that can splice video tokens in after ``[CLS]``."""

    def __init__(self, config):
        super().__init__(config)
        self.max_video_len = config.max_video_len
        if hasattr(config, "use_seg_emb") and config.use_seg_emb:
            # the original VLM paper uses seg_embeddings for temporal space.
            # although not used it changed the randomness of initialization.
            # we keep it for reproducibility.
            self.seg_embeddings = nn.Embedding(256, config.hidden_size)

    def forward(
        self,
        input_ids,
        input_video_embeds,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
    ):
        """Build the embeddings for ``[CLS] video tokens [SEP] caption [SEP]``.

        Args:
            input_ids: text token ids, or ``None`` when ``inputs_embeds``
                is given.
            input_video_embeds: video token embeddings inserted after the
                first text position, or ``None`` for text-only input.
            token_type_ids: optional; zeros are used when omitted.
            position_ids: optional; derived from the video/text layout
                when omitted.
            inputs_embeds: optional pre-computed text embeddings used
                instead of looking up ``input_ids``.

        Returns:
            The summed word/video, position and token-type embeddings
            after LayerNorm and dropout.
        """
        input_tensor = input_ids if input_ids is not None else inputs_embeds
        text_len = input_tensor.size(1)
        if input_video_embeds is not None:
            input_shape = (
                input_tensor.size(0),
                text_len + input_video_embeds.size(1),
            )
        else:
            input_shape = (input_tensor.size(0), text_len)

        if position_ids is None:
            # Auto skip position embeddings for text only case.
            # use cases:
            # (1) action localization and segmentation:
            #     feed in len-1 dummy video token needs text part to
            #     skip input_video_embeds.size(1) for the right
            #     position_ids for video [SEP] and rest text tokens.
            # (2) MMFusionShare for two forward passings:
            #     in `forward_text`: input_video_embeds is None.
            #     need to skip video [SEP] token.
            #
            # video_len + 1: [CLS] + video_embed
            # self.max_video_len + 1: [SEP] for video.
            # self.max_video_len + 2: first text token.
            # self.max_video_len + text_len: rest for text.
            #
            # `text_len` comes from whichever text input was given, so
            # this also works when only `inputs_embeds` is passed.
            if input_video_embeds is not None:
                video_len = input_video_embeds.size(1)
                starting_offset = self.max_video_len + 1  # video [SEP]
                ending_offset = self.max_video_len + text_len
            else:
                video_len = 0
                starting_offset = self.max_video_len + 2  # first text token.
                ending_offset = self.max_video_len + text_len + 1
            position_ids = torch.cat([
                self.position_ids[:, :video_len + 1],
                self.position_ids[:, starting_offset:ending_offset]
            ], dim=1)

        if token_type_ids is None:
            token_type_ids = torch.zeros(
                input_shape, dtype=torch.long, device=self.position_ids.device
            )

        # the format of input_ids is [CLS] [SEP] caption [SEP] padding.
        # the goal is to build [CLS] video tokens [SEP] caption [SEP] .
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        if input_video_embeds is not None:
            inputs_mm_embeds = torch.cat([
                inputs_embeds[:, :1], input_video_embeds, inputs_embeds[:, 1:]
            ], dim=1)
        else:
            # text only for `MMFusionShare`.
            inputs_mm_embeds = inputs_embeds

        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        embeddings = inputs_mm_embeds + position_embeddings
        embeddings += token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
class AlignHead(nn.Module):
    """this will load pre-trained weights for NSP, which is desirable.

    A single linear layer mapping ``config.hidden_size`` features to two
    logits.
    """

    def __init__(self, config):
        super().__init__()
        self.seq_relationship = nn.Linear(
            in_features=config.hidden_size, out_features=2
        )

    def forward(self, dropout_pooled_output):
        """Return the two logits for ``dropout_pooled_output``."""
        return self.seq_relationship(dropout_pooled_output)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import numpy as np
import pickle
import time
try:
import faiss
except ImportError:
pass
from collections import defaultdict
from ..utils import get_local_rank, print_on_rank0
class VectorRetriever(object):
    """
    How2 Video Retriever.
    Reference usage of FAISS:
    https://github.com/fairinternal/fairseq-py/blob/paraphrase_pretraining/fairseq/data/multilingual_faiss_dataset.py

    Vectors are cached until enough have been collected to train the
    FAISS index; after training they are added to the index directly.
    Each video id is mapped to the position of its vector in the index.
    """

    def __init__(self, hidden_size, cent, db_type, examples_per_cent_to_train):
        """Create an untrained index.

        Args:
            hidden_size: dimensionality of the stored vectors.
            cent: number of IVF centroids.
            db_type: ``"flatl2"`` or ``"pq"``.
            examples_per_cent_to_train: vectors to collect per centroid
                before the index is trained.

        Raises:
            ValueError: for an unknown ``db_type``.
        """
        if db_type == "flatl2":
            quantizer = faiss.IndexFlatL2(hidden_size)  # the other index
            self.db = faiss.IndexIVFFlat(
                quantizer, hidden_size, cent, faiss.METRIC_L2)
        elif db_type == "pq":
            self.db = faiss.index_factory(
                hidden_size, f"IVF{cent}_HNSW32,PQ32"
            )
        else:
            raise ValueError("unknown type of db", db_type)
        self.train_thres = cent * examples_per_cent_to_train
        self.train_cache = []
        self.train_len = 0
        self.videoid_to_vectoridx = {}
        self.vectoridx_to_videoid = None
        self.make_direct_maps_done = False

    def make_direct_maps(self):
        """Build the FAISS direct map needed by ``reconstruct``."""
        faiss.downcast_index(self.db).make_direct_map()
        # Record completion so `search_by_video_ids` does not rebuild the
        # direct map on every call.
        self.make_direct_maps_done = True

    def __len__(self):
        return self.db.ntotal

    def save(self, out_dir):
        """Write the index and the video-id mapping into ``out_dir``."""
        faiss.write_index(
            self.db,
            os.path.join(out_dir, "faiss_idx")
        )
        with open(
                os.path.join(
                    out_dir, "videoid_to_vectoridx.pkl"),
                "wb") as fw:
            pickle.dump(
                self.videoid_to_vectoridx, fw,
                protocol=pickle.HIGHEST_PROTOCOL
            )

    def load(self, out_dir):
        """Load the index and the video-id mapping from ``out_dir``."""
        fn = os.path.join(out_dir, "faiss_idx")
        self.db = faiss.read_index(fn)
        # NOTE(security): pickle.load executes arbitrary code from the
        # file; only load directories produced by a trusted `save`.
        with open(
                os.path.join(out_dir, "videoid_to_vectoridx.pkl"), "rb") as fr:
            self.videoid_to_vectoridx = pickle.load(fr)
        # Anything derived from the previous index/mapping is now stale.
        self.vectoridx_to_videoid = None
        self.make_direct_maps_done = False

    def add(self, hidden_states, video_ids, last=False):
        """Add vectors for video ids that have not been seen before.

        Args:
            hidden_states: 2-D ``float32`` numpy array, one row per id.
            video_ids: ids aligned with the rows of ``hidden_states``.
            last: accepted but not used.
        """
        assert len(hidden_states) == len(video_ids), "{}, {}".format(
            str(len(hidden_states)), str(len(video_ids)))
        assert len(hidden_states.shape) == 2
        assert hidden_states.dtype == np.float32

        # Keep only rows whose id is new; the id's vector index is its
        # insertion rank.
        valid_idx = []
        for idx, video_id in enumerate(video_ids):
            if video_id not in self.videoid_to_vectoridx:
                valid_idx.append(idx)
                self.videoid_to_vectoridx[video_id] = \
                    len(self.videoid_to_vectoridx)

        hidden_states = hidden_states[valid_idx]
        if self.db.is_trained:
            self.db.add(hidden_states)
            return

        # Not trained yet: cache until the threshold is reached.
        self.train_cache.append(hidden_states)
        self.train_len += hidden_states.shape[0]
        if self.train_len < self.train_thres:
            return
        self.finalize_training()

    def finalize_training(self):
        """Train the index on the cached vectors, then add all of them."""
        hidden_states = np.concatenate(self.train_cache, axis=0)
        # Release the cache but keep the attribute, so the object stays
        # in a consistent state.
        self.train_cache = []
        local_rank = get_local_rank()
        if local_rank == 0:
            start = time.time()
            print("training db on", self.train_thres, "/", self.train_len)
        self.db.train(hidden_states[:self.train_thres])
        if local_rank == 0:
            print("training db for", time.time() - start)
        self.db.add(hidden_states)

    def _get_vectoridx_to_videoid(self):
        """Return the vector-index -> video-id map, building it on first use."""
        if self.vectoridx_to_videoid is None:
            self.vectoridx_to_videoid = {
                vector_idx: video_id
                for video_id, vector_idx in self.videoid_to_vectoridx.items()
            }
            assert len(self.vectoridx_to_videoid) \
                == len(self.videoid_to_vectoridx)
        return self.vectoridx_to_videoid

    def search(
        self,
        query_hidden_states,
        orig_dist,
    ):
        """Return the nearest stored video id for every query vector.

        Entries without a hit, and entries whose distance is
        ``<= orig_dist``, are set to -1. The result is an ``int32`` array.

        Raises:
            ValueError: if the id mapping and the index differ in size.
        """
        if len(self.videoid_to_vectoridx) != self.db.ntotal:
            raise ValueError(
                "cannot search: size mismatch in-between index and db",
                len(self.videoid_to_vectoridx),
                self.db.ntotal
            )
        vectoridx_to_videoid = self._get_vectoridx_to_videoid()

        # MultilingualFaissDataset uses the following; not sure the purpose.
        # faiss.ParameterSpace().set_index_parameter(self.db, "nprobe", 10)
        queried_dist, index = self.db.search(query_hidden_states, 1)
        queried_dist, index = queried_dist[:, 0], index[:, 0]

        outputs = np.array(
            [vectoridx_to_videoid[_index]
                if _index != -1 else (-1, -1, -1) for _index in index],
            dtype=np.int32)
        # NOTE(review): VectorRetrieverDM.search keeps a hit when
        # queried_dist < orig_dist, whereas this discards hits with
        # queried_dist <= orig_dist; confirm which direction is intended.
        outputs[queried_dist <= orig_dist] = -1
        return outputs

    def search_by_video_ids(
        self,
        video_ids,
        retri_factor
    ):
        """Return, for each video id, itself followed by its neighbours.

        The stored vector of every id is reconstructed and used to query
        ``retri_factor`` nearest neighbours; the query video itself is
        removed from the neighbours and placed first.

        Raises:
            ValueError: if the id mapping and the index differ in size.
        """
        if len(self.videoid_to_vectoridx) != self.db.ntotal:
            raise ValueError(
                len(self.videoid_to_vectoridx),
                self.db.ntotal
            )

        if not self.make_direct_maps_done:
            self.make_direct_maps()
        vectoridx_to_videoid = self._get_vectoridx_to_videoid()

        vector_ids = [
            self.videoid_to_vectoridx[video_id] for video_id in video_ids
        ]
        query_hidden_states = np.stack(
            [self.db.reconstruct(vector_id) for vector_id in vector_ids]
        )

        # MultilingualFaissDataset uses the following; not sure the reason.
        # faiss.ParameterSpace().set_index_parameter(self.db, "nprobe", 10)
        _, index = self.db.search(query_hidden_states, retri_factor)
        outputs = []
        for sample_idx, sample in enumerate(index):
            # the first video_id is always the video itself.
            cands = [video_ids[sample_idx]]
            for vector_idx in sample:
                if vector_idx >= 0 \
                        and vector_ids[sample_idx] != vector_idx:
                    cands.append(vectoridx_to_videoid[vector_idx])
            outputs.append(cands)
        return outputs
class VectorRetrieverDM(VectorRetriever):
    """
    Video retriever (How2) whose FAISS index is used with a direct map.

    Reference usage of FAISS:
    https://github.com/fairinternal/fairseq-py/blob/paraphrase_pretraining/fairseq/data/multilingual_faiss_dataset.py
    """

    def __init__(
        self,
        hidden_size,
        cent,
        db_type,
        examples_per_cent_to_train
    ):
        super().__init__(
            hidden_size, cent, db_type, examples_per_cent_to_train)
        # the direct map is built lazily on the first search.
        self.make_direct_maps_done = False

    def make_direct_maps(self):
        """Build the direct map on the index and remember that it is done."""
        faiss.downcast_index(self.db).make_direct_map()
        self.make_direct_maps_done = True

    def _prepare_for_search(self):
        """Validate the index size and lazily build the lookup structures.

        Raises:
            ValueError: if the number of registered video ids differs from
                the number of vectors stored in the index.
        """
        if len(self.videoid_to_vectoridx) != self.db.ntotal:
            raise ValueError(
                len(self.videoid_to_vectoridx),
                self.db.ntotal
            )
        if not self.make_direct_maps_done:
            self.make_direct_maps()
        if self.vectoridx_to_videoid is None:
            self.vectoridx_to_videoid = {
                self.videoid_to_vectoridx[videoid]: videoid
                for videoid in self.videoid_to_vectoridx
            }
            assert len(self.vectoridx_to_videoid) \
                == len(self.videoid_to_vectoridx)

    def search(
        self,
        query_hidden_states,
        orig_dist,
    ):
        """Return the nearest stored video for every query vector.

        Args:
            query_hidden_states: query vectors handed to the index.
            orig_dist: per-query distance; a neighbour is only accepted
                when its distance is strictly smaller than this value.

        Returns:
            A list with one entry per query: the video id of the nearest
            neighbour, or `None` when nothing was found or the neighbour
            is not closer than `orig_dist`.
        """
        self._prepare_for_search()
        # MultilingualFaissDataset uses the following; not sure the reason.
        # faiss.ParameterSpace().set_index_parameter(self.db, "nprobe", 10)
        queried_dist, index = self.db.search(query_hidden_states, 1)
        # search returns one column per requested neighbour; keep the only
        # column so that each entry is a scalar usable as a dict key.
        queried_dist, index = queried_dist[:, 0], index[:, 0]
        outputs = []
        for sample_idx, vector_idx in enumerate(index):
            if vector_idx >= 0 \
                    and queried_dist[sample_idx] < orig_dist[sample_idx]:
                outputs.append(self.vectoridx_to_videoid[vector_idx])
            else:
                outputs.append(None)
        return outputs

    def search_by_video_ids(
        self,
        video_ids,
        retri_factor=8
    ):
        """Retrieve neighbours for videos that are already in the index.

        Args:
            video_ids: ids to query; each must have been added before.
            retri_factor: number of neighbours requested per query.

        Returns:
            One candidate list per query, starting with the query id
            itself followed by the ids of its neighbours (the query's own
            vector and `-1` placeholders are dropped).
        """
        self._prepare_for_search()
        query_hidden_states = []
        vector_ids = []
        for video_id in video_ids:
            vector_id = self.videoid_to_vectoridx[video_id]
            vector_ids.append(vector_id)
            query_hidden_states.append(self.db.reconstruct(vector_id))
        query_hidden_states = np.stack(query_hidden_states)
        # MultilingualFaissDataset uses the following; not sure the reason.
        # faiss.ParameterSpace().set_index_parameter(self.db, "nprobe", 10)
        _, index = self.db.search(query_hidden_states, retri_factor)
        outputs = []
        for sample_idx, sample in enumerate(index):
            # the first video_id is always the video itself.
            cands = [video_ids[sample_idx]]
            for vector_idx in sample:
                if vector_idx >= 0 \
                        and vector_ids[sample_idx] != vector_idx:
                    cands.append(
                        self.vectoridx_to_videoid[vector_idx]
                    )
            outputs.append(cands)
        return outputs
class MMVectorRetriever(VectorRetrieverDM):
    """
    Multimodal vector retriever:
    text retrieves video or video retrieves text.

    `self.db` is a dict with one index per modality (`"video"`, `"text"`);
    both indexes share the same id map.
    """

    def __init__(self, hidden_size, cent, db_type, examples_per_cent_to_train):
        # run the parent constructor twice and keep the `self.db` produced
        # by each run, so that the two modalities get separate indexes.
        super().__init__(
            hidden_size, cent, db_type, examples_per_cent_to_train)
        video_db = self.db
        super().__init__(
            hidden_size, cent, db_type, examples_per_cent_to_train)
        text_db = self.db
        self.db = {"video": video_db, "text": text_db}
        # cache filled lazily by `get_clips_by_video_id`.
        self.video_to_videoid = defaultdict(list)

    def __len__(self):
        assert self.db["video"].ntotal == self.db["text"].ntotal
        return self.db["video"].ntotal

    def make_direct_maps(self):
        """Build the direct map on both indexes and record that it is done."""
        faiss.downcast_index(self.db["video"]).make_direct_map()
        faiss.downcast_index(self.db["text"]).make_direct_map()
        # without this flag every search would rebuild the maps.
        self.make_direct_maps_done = True

    def save(self, out_dir):
        """Write both indexes and the id map into `out_dir`."""
        faiss.write_index(
            self.db["video"],
            os.path.join(out_dir, "video_faiss_idx")
        )
        faiss.write_index(
            self.db["text"],
            os.path.join(out_dir, "text_faiss_idx")
        )
        with open(
                os.path.join(
                    out_dir, "videoid_to_vectoridx.pkl"),
                "wb") as fw:
            pickle.dump(
                self.videoid_to_vectoridx, fw,
                protocol=pickle.HIGHEST_PROTOCOL
            )

    def load(self, out_dir):
        """Read both indexes and the id map written by `save`.

        NOTE(review): the id map is unpickled; only load directories
        produced by a trusted run.
        """
        fn = os.path.join(out_dir, "video_faiss_idx")
        video_db = faiss.read_index(fn)
        fn = os.path.join(out_dir, "text_faiss_idx")
        text_db = faiss.read_index(fn)
        self.db = {"video": video_db, "text": text_db}
        with open(
                os.path.join(out_dir, "videoid_to_vectoridx.pkl"), "rb") as fr:
            self.videoid_to_vectoridx = pickle.load(fr)
        self.video_to_videoid = defaultdict(list)

    def add(self, hidden_states, video_ids):
        """Add `(video, text)` vector pairs to the two indexes.

        Args:
            hidden_states: 3-d array whose first axis is aligned with
                `video_ids` and whose second axis holds the
                `(video, text)` pair.
            video_ids: one id per row; ids that are already registered
                are skipped.

        Until the indexes are trained, batches are only cached; training
        happens once the cache is estimated to hold `self.train_thres`
        examples, after which the whole cache is added.
        """
        assert len(hidden_states) == len(video_ids), "{}, {}".format(
            str(len(hidden_states)), str(len(video_ids)))
        assert len(hidden_states.shape) == 3
        assert len(self.video_to_videoid) == 0
        valid_idx = []
        for idx, video_id in enumerate(video_ids):
            if video_id not in self.videoid_to_vectoridx:
                valid_idx.append(idx)
                self.videoid_to_vectoridx[video_id] = \
                    len(self.videoid_to_vectoridx)
        batch_size = hidden_states.shape[0]
        hidden_states = hidden_states[valid_idx]
        # put the modality axis first: index 0 is video, index 1 is text.
        hidden_states = np.transpose(hidden_states, (1, 0, 2)).copy()
        if not self.db["video"].is_trained:
            self.train_cache.append(hidden_states)
            # estimate based on the size of the current batch.
            train_len = batch_size * len(self.train_cache)
            if train_len < self.train_thres:
                return
            hidden_states = np.concatenate(self.train_cache, axis=1)
            del self.train_cache
            self.db["video"].train(hidden_states[0, :self.train_thres])
            self.db["text"].train(hidden_states[1, :self.train_thres])
        self.db["video"].add(hidden_states[0])
        self.db["text"].add(hidden_states[1])

    def get_clips_by_video_id(self, video_id):
        """Return all registered `(video_id, video_clip, text_clip)` keys
        that belong to `video_id`."""
        if not self.video_to_videoid:
            # use a distinct loop name: reusing `video_id` here would
            # overwrite the argument and return the wrong video's clips.
            for vid, video_clip, text_clip in self.videoid_to_vectoridx:
                self.video_to_videoid[vid].append(
                    (vid, video_clip, text_clip))
        return self.video_to_videoid[video_id]

    def search(
        self,
        video_ids,
        target_modality,
        retri_factor=8
    ):
        """Cross-modal retrieval for ids that are already in the indexes.

        Args:
            video_ids: ids to query.
            target_modality: `"video"` to query with the stored text
                vector, anything else to query with the stored video
                vector; it is also the key of the index being searched.
            retri_factor: number of neighbours requested per query.

        Returns:
            One list of retrieved ids per query (`-1` placeholders
            dropped).

        Raises:
            ValueError: if the id map and the indexes disagree on size.
        """
        if len(self.videoid_to_vectoridx) != len(self):
            raise ValueError(
                len(self.videoid_to_vectoridx),
                len(self)
            )
        if not self.make_direct_maps_done:
            self.make_direct_maps()
        if self.vectoridx_to_videoid is None:
            self.vectoridx_to_videoid = {
                self.videoid_to_vectoridx[videoid]: videoid
                for videoid in self.videoid_to_vectoridx
            }
            assert len(self.vectoridx_to_videoid) \
                == len(self.videoid_to_vectoridx)
        src_modality = "text" if target_modality == "video" else "video"
        query_hidden_states = []
        vector_ids = []
        for video_id in video_ids:
            vector_id = self.videoid_to_vectoridx[video_id]
            vector_ids.append(vector_id)
            query_hidden_state = self.db[src_modality].reconstruct(vector_id)
            query_hidden_states.append(query_hidden_state)
        query_hidden_states = np.stack(query_hidden_states)
        # MultilingualFaissDataset uses the following; not sure the reason.
        # faiss.ParameterSpace().set_index_parameter(self.db, "nprobe", 10)
        _, index = self.db[target_modality].search(
            query_hidden_states, retri_factor)
        outputs = []
        for sample_idx, sample in enumerate(index):
            cands = []
            for vector_idx in sample:
                if vector_idx >= 0:
                    cands.append(
                        self.vectoridx_to_videoid[vector_idx]
                    )
            outputs.append(cands)
        return outputs
# Copyright (c) Facebook, Inc. All Rights Reserved
import torch
import os
import numpy as np
import pickle
from . import retri
from ..utils import get_local_rank
class VectorPool(object):
    """
    Base class of retrieval space.
    """

    def __init__(self, config):
        from transformers import AutoConfig
        self.hidden_size = AutoConfig.from_pretrained(
            config.dataset.bert_name).hidden_size
        self.retriever_cls = getattr(retri, config.retriever_cls)

    def __call__(self, sample, **kwargs):
        """Consume one batch; implemented by subclasses."""
        raise NotImplementedError

    def build_retriver(
        self,
        retriever_cls=None,
        hidden_size=None,
        centroids=512,
        db_type="flatl2",
        examples_per_cent_to_train=48
    ):
        """Create a retriever, store it as `self.retriver` and return it.

        Args:
            retriever_cls: callable building the retriever; falls back to
                `self.retriever_cls` when omitted.
            hidden_size: vector size; falls back to `self.hidden_size`
                when omitted.
            centroids, db_type, examples_per_cent_to_train: forwarded
                unchanged to `retriever_cls`.
        """
        if retriever_cls is None:
            retriever_cls = self.retriever_cls
        if hidden_size is None:
            hidden_size = self.hidden_size
        self.retriver = retriever_cls(
            hidden_size, centroids, db_type, examples_per_cent_to_train)
        return self.retriver

    def __repr__(self):
        if hasattr(self, "retriver"):
            retriver_name = str(len(self.retriver))
        else:
            retriver_name = "no retriver field yet"
        return self.__class__.__name__ \
            + "(" + retriver_name + ")"
class VideoVectorPool(VectorPool):
    """
    Video-level pool: the clips of a video are averaged into one vector.
    """

    def __init__(self, config):
        super().__init__(config)
        self.build_retriver(self.retriever_cls, self.hidden_size)

    def __call__(self, sample, subsampling, **kwargs):
        """Average the clip vectors of each video and add them to the
        retriever.

        `sample` provides `"pooled_video"`, `"pooled_text"` and
        `"video_id"`; `subsampling` is the number of consecutive rows that
        are averaged into one video vector.
        """
        fused = (
            sample["pooled_video"] + sample["pooled_text"]) / 2.
        per_clip = fused.view(-1, subsampling, fused.size(-1))
        per_video = torch.mean(per_clip, dim=1).cpu().detach().numpy()

        # a sharded video_id is a 3-tuple whose first item is the real id.
        video_ids = [
            video_id[0]
            if isinstance(video_id, tuple) and len(video_id) == 3
            else video_id
            for video_id in sample["video_id"]
        ]
        assert len(video_ids) == len(per_video)
        self.retriver.add(per_video.astype("float32"), video_ids)
class DistributedVectorPool(VectorPool):
    """
    support sync of multiple gpus/nodes.

    Every process caches its vectors and ids in memory, dumps them into
    `<save_dir>/retri`, and then builds its own retriever from the dumps
    of all ranks.
    """

    def __init__(self, config):
        super().__init__(config)
        self.out_dir = os.path.join(
            config.fairseq.checkpoint.save_dir,
            "retri")
        os.makedirs(self.out_dir, exist_ok=True)
        self.hidden_states = []
        self.video_ids = []

    def build_retriver(
        self,
        retriever_cls=None,
        hidden_size=None,
        centroids=4096,
        db_type="flatl2",
        examples_per_cent_to_train=48
    ):
        """Merge the dumps of all ranks into a new retriever and return it.

        Args:
            retriever_cls: callable building the retriever; defaults to
                `self.retriever_cls`.
            hidden_size: vector size; defaults to `self.hidden_size`.
            centroids, db_type, examples_per_cent_to_train: forwarded
                unchanged to `retriever_cls`.
        """
        if retriever_cls is None:
            retriever_cls = self.retriever_cls
        if hidden_size is None:
            hidden_size = self.hidden_size
        if torch.distributed.is_initialized():
            self.save()
            # sync saving.
            torch.distributed.barrier()
            world_size = torch.distributed.get_world_size()
        else:
            # single process: still dump what was collected, otherwise the
            # load below reads a missing or stale rank-0 file. Nothing is
            # written when the cache is empty, so that existing dumps can
            # still be re-loaded.
            if self.hidden_states:
                self.save()
            world_size = 1
        self.retriver = retriever_cls(
            hidden_size, centroids, db_type, examples_per_cent_to_train)
        # each gpu process has its own retriever.
        for local_rank in range(world_size):
            if get_local_rank() == 0:
                print("load local_rank", local_rank)
            hidden_states, video_ids = self.load(local_rank)
            hidden_states = hidden_states.astype("float32")
            self.retriver.add(hidden_states, video_ids)
        return self.retriver

    def load(self, local_rank):
        """Read the vectors and ids dumped by the given rank.

        NOTE(review): the ids are unpickled; only load dumps written by a
        trusted run.
        """
        hidden_states = np.load(
            os.path.join(
                self.out_dir,
                "hidden_state" + str(local_rank) + ".npy"
            )
        )
        with open(
                os.path.join(
                    self.out_dir, "video_id" + str(local_rank) + ".pkl"),
                "rb") as fr:
            video_ids = pickle.load(fr)
        return hidden_states, video_ids

    def save(self):
        """Dump the cached vectors and ids of this process.

        The file names carry the distributed rank, or 0 when
        torch.distributed is not initialized.
        """
        hidden_states = np.vstack(self.hidden_states)
        assert len(hidden_states) == len(self.video_ids), "{}, {}".format(
            len(hidden_states),
            len(self.video_ids)
        )
        local_rank = torch.distributed.get_rank() \
            if torch.distributed.is_initialized() else 0
        np.save(
            os.path.join(
                self.out_dir,
                "hidden_state" + str(local_rank) + ".npy"),
            hidden_states)
        with open(
                os.path.join(
                    self.out_dir,
                    "video_id" + str(local_rank) + ".pkl"),
                "wb") as fw:
            pickle.dump(
                self.video_ids,
                fw,
                protocol=pickle.HIGHEST_PROTOCOL
            )
class DistributedVideoVectorPool(DistributedVectorPool):
    """
    Video-level pool: the clips of a video are averaged into one vector.
    """

    def __call__(self, sample, subsampling, **kwargs):
        """Average the clip vectors of each video and cache them locally.

        `sample` provides `"pooled_video"`, `"pooled_text"` and
        `"video_id"`; `subsampling` is the number of consecutive rows that
        are averaged into one video vector.
        """
        fused = (
            sample["pooled_video"] + sample["pooled_text"]) / 2.
        per_clip = fused.view(-1, subsampling, fused.size(-1))
        per_video = torch.mean(per_clip, dim=1).cpu().detach().numpy()

        # a sharded video_id is a 3-tuple whose first item is the real id.
        video_ids = [
            video_id[0]
            if isinstance(video_id, tuple) and len(video_id) == 3
            else video_id
            for video_id in sample["video_id"]
        ]
        assert len(video_ids) == len(per_video)
        self.hidden_states.append(per_video)
        self.video_ids.extend(video_ids)
# ------------ the following are deprecated --------------
class TextClipVectorPool(VectorPool):
    """Deprecated pool that stores the pooled text vector of each clip."""

    def __init__(self, config):
        # the base constructor is not called; the retriever is built
        # directly from the config.
        from transformers import AutoConfig
        bert_config = AutoConfig.from_pretrained(config.dataset.bert_name)
        self.build_retriver(
            getattr(retri, config.retriever_cls), bert_config.hidden_size)

    def __call__(self, sample, **kwargs):
        """Add the pooled text vectors of a batch, keyed by the columns of
        `sample["clip_meta"]` from index 3 onwards."""
        clip_meta = sample["clip_meta"].cpu()
        assert torch.all(torch.le(clip_meta[:, 4], clip_meta[:, 5]))
        text_meta = [tuple(row.tolist()) for row in clip_meta[:, 3:]]

        if not hasattr(self, "retriver"):
            raise NotImplementedError
        # build_retriver is called.
        self.retriver.add(
            sample["pooled_text"].cpu().numpy().astype("float32"),
            text_meta
        )
class MMClipVectorPool(VectorPool):
    """
    Multimodal Clip-level vector pool.

    NOTE(review): `__call__` appends to `self.hidden_states` and
    `self.video_ids`, which the base constructor visible in this file does
    not create -- confirm the intended base class before reusing this
    (deprecated) pool.
    """

    def __init__(self, out_dir):
        """use hidden_states to store `(video, text)`;
        use video_ids to store `(video_id, start, end)`."""
        super().__init__(out_dir)

    def __call__(self, sample, **kwargs):
        """Cache the `(video, text)` vector pair and the clip boundaries of
        every clip in the batch."""
        video_vecs = sample["pooled_video"].cpu().unsqueeze(1).numpy()
        text_vecs = sample["pooled_text"].cpu().unsqueeze(1).numpy()
        self.hidden_states.append(
            np.concatenate([video_vecs, text_vecs], axis=1)
        )

        video_starts = sample["video_start"].cpu()
        video_ends = sample["video_end"].cpu()
        assert torch.all(torch.le(video_starts, video_ends))

        text_starts = sample["text_start"].cpu()
        text_ends = sample["text_end"].cpu()
        assert torch.all(torch.le(text_starts, text_ends))

        # every video id is repeated once per clip sampled from it.
        subsample_size = \
            sample["pooled_video"].size(0) // len(sample["video_id"])
        repeated_ids = []
        for video_id in sample["video_id"]:
            repeated_ids.extend([video_id] * subsample_size)

        self.video_ids.extend(
            (
                video_id,
                (int(video_start), int(video_end)),
                (int(text_start), int(text_end))
            )
            for video_id, video_start, video_end, text_start, text_end
            in zip(
                repeated_ids, video_starts, video_ends,
                text_starts, text_ends)
        )
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .processor import *
from .how2processor import *
from .how2retriprocessor import *
from .dsprocessor import *
try:
from .rawvideoprocessor import *
from .codecprocessor import *
from .webvidprocessor import *
from .expprocessor import *
from .exphow2processor import *
from .exphow2retriprocessor import *
from .expcodecprocessor import *
from .expfeatureencoder import *
from .expdsprocessor import *
except ImportError:
pass
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import random
import json
import pickle
from tqdm import tqdm
import os
import numpy as np
class CaptionDedupProcessor(object):
    """remove overlapping of caption sentences(clip).

    `self.data` maps a video id to a JSON string holding parallel
    `start`, `end` and `text` lists.

    Some statistics:
    caption:
    {'t_clip_len': 246.6448431320854,
    'video_len': 281.09174795676245,
    'clip_tps': 0.8841283727427481,
    'video_tps': 0.7821156477732097,
    'min_clip_len': 0.0,
    'max_clip_len': 398.3,
    'mean_clip_len': 3.196580003006861,
    'num_clip': 77.15897706301081}

    raw_caption:
    {'t_clip_len': 238.95908778424115,
    'video_len': 267.5914859862507,
    'clip_tps': 2.4941363624267963,
    'video_tps': 2.258989769647173,
    'min_clip_len': 0.0,
    'max_clip_len': 398.3,
    'mean_clip_len': 3.0537954186814265,
    'num_clip': 78.24986779481756}
    """

    def __init__(self, pkl_file):
        # NOTE(review): the caption file is unpickled; only load trusted
        # files.
        with open(pkl_file, "rb") as fd:
            self.data = pickle.load(fd)
        self.stat = {
            "t_clip_len": [],
            "video_len": [],
            "clip_tps": [],
            "video_tps": [],
            "clip_len": [],
        }

    def __call__(self):
        """Deduplicate every caption in place and print the statistics."""
        for idx, video_id in enumerate(tqdm(self.data)):
            caption = json.loads(self.data[video_id])
            caption = self._dedup(caption)
            if idx < 4096:  # for the first 4096 examples, compute the statistics.
                self.save_stat(video_id, caption)
            self.data[video_id] = json.dumps(caption)
        self.print_stat()

    def single(self, video_id):
        """Print one caption before and after deduplication (debug aid)."""
        caption = json.loads(self.data[video_id])
        for start, end, text in zip(
                caption["start"], caption["end"], caption["text"]):
            print(start, end, text)
        print("@" * 100)
        caption = self._dedup(caption)
        for start, end, text in zip(
                caption["start"], caption["end"], caption["text"]):
            print(start, end, text)
        print("#" * 100)
        self.save_stat(video_id, caption)
        self.print_stat()

    def finalize(self, tgt_fn):
        """Pickle the (deduplicated) captions into `tgt_fn`."""
        with open(tgt_fn, "wb") as fw:
            pickle.dump(self.data, fw, pickle.HIGHEST_PROTOCOL)

    def save_stat(self, video_id, caption):
        """Record length/token statistics of one caption.

        Videos without a feature file, or whose clip time or video length
        is zero, are skipped: no ratio can be computed for them and all
        statistic lists have to stay aligned.
        """
        video_fn = os.path.join(
            "data/feat/feat_how2_s3d", video_id + ".npy"
        )
        if not os.path.isfile(video_fn):
            return
        # only the .npy header is read to obtain the number of rows.
        with open(video_fn, "rb") as fr:
            version = np.lib.format.read_magic(fr)
            shape, fortran, dtype = np.lib.format._read_array_header(
                fr, version)
        video_len = shape[0]

        t_clip_len = 0.0
        t_tokens = 0
        clip_lens = []
        for idx, (start, end, text) in enumerate(
            zip(caption["start"], caption["end"], caption["text"])
        ):
            # time covered by the previous clip is not counted twice.
            clip_len = (
                (end - max(caption["end"][idx - 1], start))
                if idx > 0
                else end - start
            )
            t_clip_len += clip_len
            t_tokens += len(text.split(" "))
            clip_lens.append(clip_len)
        if not t_clip_len or not video_len:
            return
        self.stat["clip_len"].extend(clip_lens)
        self.stat["t_clip_len"].append(t_clip_len)
        self.stat["video_len"].append(video_len)
        self.stat["clip_tps"].append(t_tokens / t_clip_len)
        self.stat["video_tps"].append(t_tokens / video_len)

    def print_stat(self):
        """Print the aggregated statistics, if any were collected."""
        if not self.stat["video_tps"] or not self.stat["clip_len"]:
            print("no statistics collected.")
            return
        result = {
            "t_clip_len": np.mean(self.stat["t_clip_len"]),
            "video_len": np.mean(self.stat["video_len"]),
            "clip_tps": np.mean(self.stat["clip_tps"]),
            "video_tps": np.mean(self.stat["video_tps"]),
            "min_clip_len": min(self.stat["clip_len"]),
            "max_clip_len": max(self.stat["clip_len"]),
            "mean_clip_len": np.mean(self.stat["clip_len"]),
            "num_clip": len(self.stat["clip_len"]) / len(self.stat["video_tps"]),
        }
        print(result)

    def _dedup(self, caption):
        """Merge or trim consecutive clips whose texts overlap.

        Args:
            caption: dict with parallel `start`, `end` and `text` lists.

        Returns:
            A new dict of the same layout. Entries whose text is not a
            string or is blank are dropped; an empty caption yields empty
            lists.
        """
        def random_merge(end_idx, start, end, text, starts, ends, texts):
            # `text[:end_idx]` duplicates the tail of the previous text;
            # the duplicated part is randomly kept in one of the two clips.
            if random.random() > 0.5:
                # print(clip_idx, "[PARTIAL INTO PREV]", end_idx)
                # overlapped part goes to the end of previous.
                ends[-1] = max(ends[-1], start)  # ?
                rest_text = text[end_idx:].strip()
                if rest_text:
                    starts.append(max(ends[-1], start))
                    ends.append(max(end, starts[-1]))
                    texts.append(rest_text)
            else:  # goes to the beginning of the current.
                # strip the previous.
                left_text = texts[-1][:-end_idx].strip()
                if left_text:
                    # print(clip_idx, "[PREV PARTIAL INTO CUR]", end_idx)
                    ends[-1] = min(ends[-1], start)
                    texts[-1] = left_text
                else:
                    # print(clip_idx, "[PREV LEFT NOTHING ALL INTO CUR]", end_idx)
                    starts.pop(-1)
                    ends.pop(-1)
                    texts.pop(-1)
                starts.append(start)
                ends.append(end)
                texts.append(text)

        starts, ends, texts = [], [], []
        for clip_idx, (start, end, text) in enumerate(
            zip(caption["start"], caption["end"], caption["text"])
        ):
            if not isinstance(text, str):
                continue
            text = text.replace("\n", " ").strip()
            if len(text) == 0:
                continue
            if not texts:
                # the first usable clip is taken as is.
                starts.append(start)
                ends.append(end)
                texts.append(text)
                continue
            # print(clip_idx, texts[-5:])
            # print(clip_idx, start, end, text)
            if texts[-1].endswith(text):  # subset of prev caption -> merge
                # print(clip_idx, "[MERGE INTO PREV]")
                ends[-1] = max(ends[-1], end)
            elif text.startswith(texts[-1]):  # superset of prev caption -> merge
                # print(clip_idx, "[PREV MERGE INTO CUR]")
                texts[-1] = text
                starts[-1] = min(starts[-1], start)
                ends[-1] = max(ends[-1], end)
            else:  # overlapping or non-overlapping.
                for end_idx in range(1, len(text) + 1):
                    if texts[-1].endswith(text[:end_idx]):
                        random_merge(end_idx, start, end, text, starts, ends, texts)
                        break
                else:
                    starts.append(start)
                    ends.append(end)
                    texts.append(text)

            assert (ends[-1] + 0.001) >= starts[-1] and len(
                texts[-1]
            ) > 0, "{} {} {} <- {} {} {}, {} {} {}".format(
                str(starts[-1]),
                str(ends[-1]),
                texts[-1],
                caption["start"][clip_idx - 1],
                caption["end"][clip_idx - 1],
                caption["text"][clip_idx - 1],
                str(start),
                str(end),
                text,
            )

        return {"start": starts, "end": ends, "text": texts}
if __name__ == "__main__":
    import argparse

    def convert_to_pickle(src_fn, tgt_fn):
        """Re-encode the caption JSON file as a pickle that maps each video
        id to the JSON string of its caption."""
        with open(src_fn) as fd:
            captions = json.load(fd)
        serialized = {
            video_id: json.dumps(caption)
            for video_id, caption in captions.items()
        }
        with open(tgt_fn, "wb") as fw:
            pickle.dump(serialized, fw, pickle.HIGHEST_PROTOCOL)

    parser = argparse.ArgumentParser(description="dedup how2 caption")
    parser.add_argument('--how2dir', default="data/how2")
    args = parser.parse_args()

    how2dir = args.how2dir
    raw_caption_json = os.path.join(how2dir, "raw_caption.json")
    raw_caption_pickle = os.path.join(how2dir, "raw_caption.pkl")
    raw_caption_dedup_pickle = os.path.join(how2dir, "raw_caption_dedup.pkl")

    # the pickle is only generated once and reused afterwards.
    if not os.path.isfile(raw_caption_pickle):
        convert_to_pickle(raw_caption_json, raw_caption_pickle)

    deduper = CaptionDedupProcessor(raw_caption_pickle)
    deduper()
    deduper.finalize(raw_caption_dedup_pickle)

    # demo:
    #   deduper = CaptionDedupProcessor("data/how2/raw_caption.pkl")
    #   deduper.single("HfIeQ9pzL5U")
# Copyright (c) Facebook, Inc. All Rights Reserved
"""
Processors for all downstream (ds) tasks.
"""
import json
import os
import pickle
import random
import math
import numpy as np
import torch
from collections import defaultdict
from .processor import (
MetaProcessor,
VideoProcessor,
TextProcessor,
Aligner,
MMAttentionMask2DProcessor,
)
from .how2processor import TextGenerationProcessor
# ------------- A General Aligner for all downstream tasks-----------------
class DSAligner(Aligner):
    """
    Downstream (DS) aligner shared by all datasets.
    """

    def __call__(self, video_id, video_feature, text_feature, wps=0.7):
        """Align the whole video with the whole text as one single clip.

        Args:
            video_id: passed through to the output.
            video_feature: video features; at most `self.max_video_len`
                of them are covered by the clip.
            text_feature: the text of the clip.
            wps: divisor turning the text length into the end time of
                the text clip.

        Returns:
            dict with `caps`, `cmasks`, `vfeats`, `vmasks`, `video_id`.
        """
        # the video always starts at 0 and is capped at max_video_len.
        clip_start = 0
        clip_end = min(len(video_feature), self.max_video_len)

        # the whole sequence is a single clip.
        vfeats, vmasks = self._build_video_seq(
            video_feature, {"start": [clip_start], "end": [clip_end]}
        )

        whole_text = {
            "cap": [text_feature],
            "start": [clip_start],
            "end": [len(text_feature) / wps],
        }
        caps, cmasks = self._build_text_seq(whole_text, [0])

        return {
            "caps": caps,
            "cmasks": cmasks,
            "vfeats": vfeats,
            "vmasks": vmasks,
            "video_id": video_id,
        }
class NLGTextProcessor(TextProcessor):
    """
    Text processor that additionally returns the original text as ref.
    """

    def __call__(self, text_id):
        """Return `(processed text, text_id)`."""
        processed = super().__call__(text_id)
        return processed, text_id
class DSNLGAligner(DSAligner):
    """extend with the capability of 2d mask for generation."""

    def __init__(self, config):
        super().__init__(config)
        self.attnmasker = MMAttentionMask2DProcessor()
        from transformers import AutoTokenizer
        # [CLS]/[SEP] are registered as the bos/eos tokens.
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.bert_name, use_fast=self.use_fast,
            bos_token="[CLS]", eos_token="[SEP]"
        )
        self.bos_token_id = self.tokenizer.bos_token_id
        self.eos_token_id = self.tokenizer.eos_token_id
        self.textgen = TextGenerationProcessor(self.tokenizer)

    def __call__(self, video_id, video_feature, text_feature):
        """Build a generation example.

        `text_feature[0]` is aligned like in the parent class. The result
        is extended with `text_label` and a 2d `attention_mask`; on the
        test split a decoded `ref` string is added and `caps` only holds
        the `[cls, sep, bos]` prefix.
        """
        output = super().__call__(video_id, video_feature, text_feature[0])
        if self.split != "test":
            caps, text_label = self.textgen(output["caps"])
            cmasks = output["cmasks"]
        else:
            # output.update({"ref": text_feature[1]})
            output["ref"] = self.tokenizer.decode(
                output["caps"], skip_special_tokens=True)
            text_label = output["caps"]
            cmasks = torch.BoolTensor([1] * text_label.size(0))
            caps = torch.LongTensor([
                self.cls_token_id,
                self.sep_token_id,
                self.bos_token_id])

        attention_mask = self.attnmasker(
            output["vmasks"], cmasks, "textgen")

        output["caps"] = caps
        output["cmasks"] = cmasks
        output["text_label"] = text_label
        output["attention_mask"] = attention_mask
        return output
# -------------------- MSRVTT ------------------------
class MSRVTTMetaProcessor(MetaProcessor):
    """MSRVTT dataset.
    reference: `howto100m/msrvtt_dataloader.py`
    """

    def __init__(self, config):
        super().__init__(config)
        import pandas as pd
        data = pd.read_csv(self._get_split_path(config))
        # TODO: add a text1ka flag.
        if config.split == "train" \
                and config.full_test_path is not None \
                and config.jsfusion_path is not None:
            # add testing videos from full_test_path not used by jfusion.
            additional_data = pd.read_csv(config.full_test_path)
            jsfusion_ids = set(
                pd.read_csv(config.jsfusion_path)["video_id"].values)
            extra_ids = [
                video_id for video_id in additional_data["video_id"]
                if video_id not in jsfusion_ids
            ]
            if extra_ids:
                # DataFrame.append no longer exists in pandas >= 2.0; one
                # concat also avoids re-copying the frame per added row.
                data = pd.concat(
                    [data, pd.DataFrame({"video_id": extra_ids})],
                    ignore_index=True)
        if config.dup is not None and config.split == "train" \
                and config.dup > 1:
            # repeat the training set `dup` times.
            data = pd.concat([data] * config.dup, ignore_index=True)
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return `(video_id, sentence)`.

        When the data has no `sentence` column (training), the video id
        itself is returned as the text key.
        """
        vid = self.data["video_id"].values[idx]
        if "sentence" in self.data:  # for testing.
            sentence = self.data["sentence"].values[idx]
        else:  # for training.
            sentence = vid
        return vid, sentence
class MSRVTTTextProcessor(TextProcessor):
    """MSRVTT dataset.
    reference: `msrvtt_dataloader.py` `MSRVTT_TrainDataLoader`.
    TODO (huxu): add max_words.
    """

    def __init__(self, config):
        super().__init__(config)
        self.sentences = None
        if config.json_path is None or config.split != "train":
            return
        with open(config.json_path) as fd:
            self.data = json.load(fd)
        # group all captions by the video they describe.
        self.sentences = defaultdict(list)
        for entry in self.data["sentences"]:
            self.sentences[entry["video_id"]].append(entry["caption"])

    def __call__(self, text_id):
        """Tokenize a caption.

        When captions were loaded, `text_id` is a video id and one of its
        captions is drawn at random; otherwise `text_id` is the sentence.
        """
        if self.sentences is None:
            sentence = text_id
        else:
            candidates = self.sentences[text_id]
            sentence = candidates[random.randint(0, len(candidates) - 1)]
        return self.tokenizer(
            sentence, add_special_tokens=False)["input_ids"]
class MSRVTTNLGTextProcessor(MSRVTTTextProcessor):
    """TODO: change dsaligner and merge to avoid any NLG text processor."""

    def __call__(self, text_id):
        """Like the parent, but return `(token ids, raw sentence)`."""
        if self.sentences is None:
            sentence = text_id
        else:
            candidates = self.sentences[text_id]
            sentence = candidates[random.randint(0, len(candidates) - 1)]
        token_ids = self.tokenizer(
            sentence, add_special_tokens=False)["input_ids"]
        return token_ids, sentence
class MSRVTTQAMetaProcessor(MetaProcessor):
    """MSRVTT-QA: retrieval-based multi-choice QA from JSFusion dataset.
    For simplicity, we use the train retrieval model.
    reference: `https://github.com/yj-yu/lsmdc`
    """

    def __init__(self, config):
        super().__init__(config)
        import pandas as pd
        csv_data = pd.read_csv(self._get_split_path(config), sep="\t")
        columns = [
            csv_data[name].values
            for name in ("vid_key", "a1", "a2", "a3", "a4", "a5", "answer")
        ]
        # each item is `(video_id, (answer, [a1, a2, a3, a4, a5]))`.
        self.data = [
            (vid_key.replace("msr", "video"), (answer, [a1, a2, a3, a4, a5]))
            for vid_key, a1, a2, a3, a4, a5, answer in zip(*columns)
        ]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
class MSRVTTQATextProcessor(TextProcessor):
    """MSRVTT-QA dataset.
    text_ans is of format `(answer, [a1, a2, a3, a4, a5])`.
    """

    def __call__(self, text_ans):
        """Tokenize the candidate answers in place and return `text_ans`.

        Candidates that are no longer strings are left untouched.
        """
        candidates = text_ans[1]
        for ans_idx, ans in enumerate(candidates):
            if not isinstance(ans, str):
                continue
            tokenized = self.tokenizer(ans, add_special_tokens=False)
            candidates[ans_idx] = tokenized["input_ids"]
        return text_ans
class MSRVTTQAAligner(DSAligner):
    """MSRVTT dataset.
    similar to sample in how2.
    we call __call__ multiple times.
    """

    def __call__(self, video_id, video_feature, text_feature, wps=0.7):
        """Align the video with every candidate answer.

        `text_feature` is `(answer, [candidate, ...])`. The result of the
        last candidate is returned, with `caps`/`cmasks` replaced by the
        stacks over all candidates and `answers` holding the answer.
        """
        answer, candidates = text_feature[0], text_feature[1]
        all_caps, all_cmasks = [], []
        for candidate in candidates:
            output = super().__call__(
                video_id, video_feature, candidate, wps)
            all_caps.append(output["caps"])
            all_cmasks.append(output["cmasks"])
        output["caps"] = torch.stack(all_caps)
        output["cmasks"] = torch.stack(all_cmasks)
        output["answers"] = torch.LongTensor([answer])
        return output
# -------------------- Youcook -----------------------
class YoucookMetaProcessor(MetaProcessor):
    """Youcook dataset.
    reference: `howto100m/youcook_dataloader.py`
    note that the data can be different as the
    (1) some videos already in Howto100m are removed.
    (2) stop words are removed from caption
    TODO (huxu): make a flag to load the original caption.
    (see youcookii_annotations_trainval.json).

    The max_video_len can be 264 and text can be 64 tokens.
    In reality we may not need that long. see projects/task/youcook.yaml
    """

    def __init__(self, config):
        super().__init__(config)
        split_path = self._get_split_path(config)
        print(split_path)
        # NOTE(review): the split file is unpickled; only load trusted
        # files.
        with open(self._get_split_path(config), "rb") as fd:
            records = pickle.load(fd)
        available_ids = {
            os.path.splitext(fn)[0] for fn in os.listdir(config.vfeat_dir)
        }

        kept = []
        seen_ids, kept_ids = set(), set()
        for rec in records:  # filter videos not available.
            # record ids look like `<video_id>_<clip index>`.
            video_id = rec["id"][:rec["id"].rindex("_")]
            seen_ids.add(video_id)
            if video_id not in available_ids:
                continue
            kept_ids.add(video_id)
            kept.append(rec)
        print("total video_ids in .pkl", len(seen_ids))
        print("valid video_ids in .pkl", len(kept_ids))
        print("please verify {train,val}_list.txt")
        self.data = kept

        with open(config.trainval_annotation) as fd:
            self.youcook_annotation = json.load(fd)["database"]
        use_annotation = config.use_annotation_text is True
        if use_annotation:
            print("using text in annotation.")
        self.use_annotation_caption = use_annotation

    def __getitem__(self, idx):
        """Return `((video_id, start, end), caption)` of one clip."""
        rec = self.data[idx]
        rec_id = rec["id"]
        sep = rec_id.rindex("_")
        video_id, clip_id = rec_id[:sep], int(rec_id[sep + 1:])
        clip = self.youcook_annotation[video_id]["annotations"][clip_id]
        start, end = clip["segment"]
        # the caption comes either from the annotation or from the record.
        if self.use_annotation_caption:
            caption = clip["sentence"]
        else:
            caption = rec["caption"]
        return (video_id, start, end), caption
class YoucookVideoProcessor(VideoProcessor):
    """video_fn is a tuple of (video_id, start, end) now."""

    def __call__(self, video_fn):
        """Load the features of the video and return rows `start:end`."""
        video_id, start, end = video_fn
        feat_path = os.path.join(self.vfeat_dir, video_id + ".npy")
        return np.load(feat_path)[start:end]
class YoucookNLGMetaProcessor(MetaProcessor):
    """NLG uses the original split:
    `train_list.txt` and `val_list.txt`
    """

    def __init__(self, config):
        super().__init__(config)
        split_path = self._get_split_path(config)
        print(split_path)
        # the video id is the second `/`-separated field of each line.
        with open(self._get_split_path(config)) as fd:
            listed_ids = [line.strip().split("/")[1] for line in fd]
        print("total video_ids in train/val_list.txt", len(listed_ids))

        available_ids = {
            os.path.splitext(fn)[0] for fn in os.listdir(config.vfeat_dir)
        }
        video_ids = [vid for vid in listed_ids if vid in available_ids]
        print("valid video_ids in train/val_list.txt", len(video_ids))

        with open(config.trainval_annotation) as fd:
            self.youcook_annotation = json.load(fd)["database"]

        # one item `((video_id, start, end), caption)` per annotated clip.
        self.data = [
            ((video_id, clip["segment"][0], clip["segment"][1]),
             clip["sentence"])
            for video_id in video_ids
            for clip in self.youcook_annotation[video_id]["annotations"]
            if len(clip["segment"]) == 2
        ] if False else self._collect_clips(video_ids)

    def _collect_clips(self, video_ids):
        """List `((video_id, start, end), caption)` for every clip."""
        data = []
        for video_id in video_ids:
            annotations = self.youcook_annotation[video_id]["annotations"]
            for clip in annotations:
                start, end = clip["segment"]
                data.append(((video_id, start, end), clip["sentence"]))
        return data

    def __getitem__(self, idx):
        return self.data[idx]
# --------------------- CrossTask -------------------------
class CrossTaskMetaProcessor(MetaProcessor):
def __init__(self, config):
    """Read the CrossTask video lists and task descriptions and keep the
    `(task, vid)` pairs of the split named by `config.split`."""
    super().__init__(config)
    np.random.seed(0)  # deterministic random split.

    vfeat_dir = config.vfeat_dir
    annotation_path = config.annotation_path
    task_vids = self._get_vids(
        config.train_csv_path, vfeat_dir, annotation_path)
    val_vids = self._get_vids(
        config.val_csv_path, vfeat_dir, annotation_path)

    # filter out those task and vids appear in val_vids.
    filtered = {}
    for task, vids in task_vids.items():
        filtered[task] = [
            vid for vid in vids
            if not (task in val_vids and vid in val_vids[task])]
    task_vids = filtered

    primary_info = self._read_task_info(config.primary_path)
    test_tasks = set(primary_info['steps'])

    # if args.use_related:
    related_info = self._read_task_info(config.related_path)
    task_steps = dict(primary_info['steps'])
    task_steps.update(related_info['steps'])
    n_steps = dict(primary_info['n_steps'])
    n_steps.update(related_info['n_steps'])
    # else:
    #     task_steps = primary_info['steps']
    #     n_steps = primary_info['n_steps']

    # filter and keep task in primary or related.
    all_tasks = set(n_steps)
    task_vids = {
        task: vids for task, vids in task_vids.items()
        if task in all_tasks}

    # vocab-by-step matrix (A) and vocab (M)
    # (huxu): we do not use BoW.
    # A, M = self._get_A(task_steps, share="words")

    train_vids, test_vids = self._random_split(
        task_vids, test_tasks, config.n_train)
    print("train_num_videos", sum(len(vids) for vids in train_vids.values()))
    print("test_num_videos", sum(len(vids) for vids in test_vids.values()))

    # added by huxu to automatically determine the split.
    # "valid" and "test" share the same videos.
    split_map = {
        "train": train_vids,
        "valid": test_vids,
        "test": test_vids
    }
    self.vids = [
        (task, vid)
        for task, vids in split_map[config.split].items()
        for vid in vids
    ]
    self.task_steps = task_steps
    self.n_steps = n_steps
def __getitem__(self, idx):
    """Return the same `(task, vid, steps, n_steps)` tuple twice, once as
    the video key and once as the text key."""
    task, vid = self.vids[idx]
    steps = self.task_steps[task]
    n_steps = self.n_steps[task]
    assert len(steps) == n_steps
    item = (task, vid, steps, n_steps)
    return item, (task, vid, steps, n_steps)
def __len__(self):
return len(self.vids)
def _random_split(self, task_vids, test_tasks, n_train):
train_vids = {}
test_vids = {}
for task, vids in task_vids.items():
if task in test_tasks and len(vids) > n_train:
train_vids[task] = np.random.choice(
vids, n_train, replace=False).tolist()
test_vids[task] = [
vid for vid in vids if vid not in train_vids[task]]
else:
train_vids[task] = vids
return train_vids, test_vids
def _get_vids(self, path, vfeat_dir, annotation_path):
"""refactored from
https://github.com/DmZhukov/CrossTask/blob/master/data.py
changes: add `vfeat_dir` to check if the video is available.
add `annotation_path` to check if the video is available.
"""
task_vids = {}
with open(path, 'r') as f:
for line in f:
task, vid, url = line.strip().split(',')
# double check the video is available.
if not os.path.exists(
os.path.join(vfeat_dir, vid + ".npy")):
continue
# double check the annotation is available.
if not os.path.exists(os.path.join(
annotation_path,
task + "_" + vid + ".csv")):
continue
if task not in task_vids:
task_vids[task] = []
task_vids[task].append(vid)
return task_vids
def _read_task_info(self, path):
titles = {}
urls = {}
n_steps = {}
steps = {}
with open(path, 'r') as f:
idx = f.readline()
while idx != '':
idx = idx.strip()
titles[idx] = f.readline().strip()
urls[idx] = f.readline().strip()
n_steps[idx] = int(f.readline().strip())
steps[idx] = f.readline().strip().split(',')
next(f)
idx = f.readline()
return {
'title': titles,
'url': urls,
'n_steps': n_steps,
'steps': steps
}
def _get_A(self, task_steps, share="words"):
raise ValueError("running get_A is not allowed for BERT.")
"""Step-to-component matrices."""
if share == 'words':
# share words
task_step_comps = {
task: [step.split(' ') for step in steps]
for task, steps in task_steps.items()}
elif share == 'task_words':
# share words within same task
task_step_comps = {
task: [[task+'_'+tok for tok in step.split(' ')] for step in steps]
for task, steps in task_steps.items()}
elif share == 'steps':
# share whole step descriptions
task_step_comps = {
task: [[step] for step in steps] for task, steps in task_steps.items()}
else:
# no sharing
task_step_comps = {
task: [[task+'_'+step] for step in steps]
for task, steps in task_steps.items()}
# BERT tokenizer here?
vocab = []
for task, steps in task_step_comps.items():
for step in steps:
vocab.extend(step)
vocab = {comp: m for m, comp in enumerate(set(vocab))}
M = len(vocab)
A = {}
for task, steps in task_step_comps.items():
K = len(steps)
a = torch.zeros(M, K)
for k, step in enumerate(steps):
a[[vocab[comp] for comp in step], k] = 1
a /= a.sum(dim=0)
A[task] = a
return A, M
class CrossTaskVideoProcessor(VideoProcessor):
    """Load the stored feature array of a CrossTask video."""

    def __call__(self, video_fn):
        """Return the array saved at `<self.vfeat_dir>/<vid>.npy`.

        `video_fn` is a `(task, vid, steps, n_steps)` tuple; only `vid` is used.
        """
        _task, vid, _steps, _n_steps = video_fn
        return np.load(os.path.join(self.vfeat_dir, vid + ".npy"))
class CrossTaskTextProcessor(TextProcessor):
    """Tokenize every step description of a CrossTask task."""

    def __call__(self, text_id):
        """Return one list of token ids per step, without special tokens.

        `text_id` is a `(task, vid, steps, n_steps)` tuple; only `steps` is used.
        """
        _task, _vid, steps, _n_steps = text_id
        return [
            self.tokenizer(step_str, add_special_tokens=False)["input_ids"]
            for step_str in steps
        ]
class CrossTaskAligner(Aligner):
    """
    TODO: the formulation of the task is not clear yet; finish this later.
    """

    def __init__(self, config):
        """Read the annotation directory and sliding-window settings."""
        super().__init__(config)
        self.annotation_path = config.annotation_path
        self.sliding_window = config.sliding_window
        self.sliding_window_size = config.sliding_window_size

    def __call__(self, video_id, video_feature, text_feature):
        """Window the video and tensorize every step description.

        `video_id` is a `(task, vid, steps, n_steps)` tuple, `text_feature`
        an iterable with one token-id list per step.

        Returns: a dict with "caps", "cmasks", "vfeats", "vmasks", "targets",
        "video_id", "task" and "video_len".
        """
        task, vid, _steps, n_steps = video_id
        annot_path = os.path.join(
            self.annotation_path, task + '_' + vid + '.csv')
        video_len = len(video_feature)
        assignment = self._read_assignment(video_len, n_steps, annot_path)
        labels = torch.from_numpy(assignment).float()

        vfeats, vmasks, targets = [], [], []
        # sliding window on video features and targets.
        for window_start in range(0, video_len, self.sliding_window):
            remaining = video_len - window_start
            window_len = min(remaining, self.sliding_window_size)
            window_end = window_start + window_len
            vfeat, vmask = self._build_video_seq(
                video_feature[window_start:window_end],
                {"start": [0], "end": [window_len]},
            )
            target = labels[window_start:window_end]
            assert len(vfeat) >= len(target), "{},{}".format(len(vfeat), len(target))
            # TODO: randomly drop all zero targets for training ?
            vfeats.append(vfeat)
            vmasks.append(vmask)
            targets.append(target)
            # This window already reached the end of the video.
            if remaining <= self.sliding_window_size:
                break

        vfeats = torch.stack(vfeats)
        vmasks = torch.stack(vmasks)
        targets = torch.cat(targets, dim=0)

        caps, cmasks = [], []
        for step in text_feature:
            # Each step is tensorized as a single-clip text feature.
            cap, cmask = self._build_text_seq(
                {"start": [0], "end": [1], "cap": [step]}, [0]
            )
            caps.append(cap)
            cmasks.append(cmask)
        caps = torch.stack(caps)
        cmasks = torch.stack(cmasks)

        return {
            "caps": caps,
            "cmasks": cmasks,
            "vfeats": vfeats,  # X for original code.
            "vmasks": vmasks,
            "targets": targets,
            "video_id": vid,
            "task": task,
            "video_len": video_len  # for later checking.
        }

    def _read_assignment(self, T, K, path):
        """
        refactored from https://github.com/DmZhukov/CrossTask/blob/master/data.py
        How to interpret constraints on the loss that is going to be minimized:
        lambd is a big number;
        self.lambd * C is a big number for all valid position (csv stores invalids)
        def forward(self, O, Y, C):
            return (Y*(self.lambd * C - self.lsm(O))).mean(dim=0).sum()
        This loads the csv file and fills in the step column from start to end rows.

        Returns: a `[T, K]` uint8 array; each csv line `step,start,end` sets
        rows `floor(start)` to `ceil(end)` (exclusive) of column `step - 1` to 1.
        """
        assignment = np.zeros([T, K], dtype=np.uint8)
        with open(path, 'r') as f:
            for line in f:
                step, start, end = line.strip().split(',')
                first_row = int(math.floor(float(start)))
                last_row = int(math.ceil(float(end)))
                column = int(step) - 1  # steps are 1-based in the csv.
                assignment[first_row:last_row, column] = 1
        return assignment
# --------------------- COIN -------------------------
class MetaTextBinarizer(Aligner):
    """Tensorize one tokenized text as a single text clip."""

    def __call__(self, text_feature):
        """Return `{"caps": ..., "cmasks": ...}` for `text_feature`."""
        single_clip = {
            "cap": [text_feature],
            "start": [0.],
            "end": [100.],
        }
        caps, cmasks = self._build_text_seq(single_clip, [0])
        return {"caps": caps, "cmasks": cmasks}
class COINActionSegmentationMetaProcessor(MetaProcessor):
    """Index COIN videos of one split together with their labelled segments."""

    # Maps the trainer split name to the "subset" value in the COIN json.
    split_map = {
        "train": "training",
        "valid": "testing",
        "test": "testing",
    }

    def __init__(self, config):
        """Load the COIN database and collect the videos of `self.split`.

        Videos without `<config.vfeat_dir>/<video_id>.npy` are skipped.
        """
        super().__init__(config)
        with open(self._get_split_path(config)) as fr:
            database = json.load(fr)["database"]

        # always use testing to determine label_set
        id2label = {}
        for rec in database.values():
            if rec["subset"] != "testing":
                continue
            for segment in rec["annotation"]:
                id2label[int(segment["id"])] = segment["label"]

        # text_labels is used for ZS setting
        self.text_labels = ["none"] * len(id2label)
        for label_id, label in id2label.items():
            self.text_labels[label_id - 1] = label

        id2label[0] = "O"
        print("num of labels", len(id2label))

        # filter the data by split.
        data = []
        for video_id, rec in database.items():
            feature_file = os.path.join(config.vfeat_dir, video_id + ".npy")
            if not os.path.isfile(feature_file):
                continue
            if rec["subset"] != COINActionSegmentationMetaProcessor.split_map[self.split]:
                continue
            clips = {"start": [], "end": [], "label": []}
            for segment in rec["annotation"]:
                start, end = segment["segment"]
                label = int(segment["id"])
                clips["start"].append(start)
                clips["end"].append(end)
                clips["label"].append(label)
            data.append((video_id, clips))
        self.data = data

    def meta_text_labels(self, config):
        """Tokenize and tensorize every text label, then collate them."""
        from transformers import default_data_collator
        from ..utils import get_local_rank

        text_processor = TextProcessor(config)
        binarizer = MetaTextBinarizer(config)
        # TODO: add prompts to .yaml.
        text_labels = list(self.text_labels)

        if get_local_rank() == 0:
            print(text_labels)

        outputs = [
            binarizer(text_processor(text_label))
            for text_label in text_labels
        ]
        return default_data_collator(outputs)

    def __getitem__(self, idx):
        """Return the `(video_id, clips)` pair at `idx`."""
        return self.data[idx]
class COINActionSegmentationTextProcessor(TextProcessor):
    """Pass-through text processor for COIN action segmentation."""

    def __call__(self, text_label):
        """Return `text_label` unchanged."""
        return text_label
class COINActionSegmentationAligner(Aligner):
    """Cut a COIN video into sliding windows with per-position labels."""

    def __init__(self, config):
        """Read the sliding-window stride and size from `config`."""
        super().__init__(config)
        self.sliding_window = config.sliding_window
        self.sliding_window_size = config.sliding_window_size

    def __call__(self, video_id, video_feature, text_feature):
        """Build windowed features, window targets and whole-video targets.

        `text_feature` is a dict with parallel "start", "end" and "label" lists.

        Returns: a dict with "caps", "cmasks", "vfeats", "vmasks", "targets",
        "video_id", "video_len" and "video_targets".
        """
        segments = list(zip(
            text_feature["start"], text_feature["end"], text_feature["label"]))

        video_len = len(video_feature)
        vfeats, vmasks, targets = [], [], []
        # sliding window on video features and targets.
        for window_start in range(0, video_len, self.sliding_window):
            remaining = video_len - window_start
            window_len = min(remaining, self.sliding_window_size)
            window_end = window_start + window_len
            vfeat, vmask = self._build_video_seq(
                video_feature[window_start:window_end],
                {"start": [0], "end": [window_len]},
            )
            # Positions where `vmask` is set start at 0; the rest stay -100.
            target = torch.full_like(vmask, -100, dtype=torch.long)
            target[vmask] = 0
            for seg_start, seg_end, label_id in segments:
                if (window_start < seg_end) and (seg_start < window_end):
                    # Clip the segment to this window, in window coordinates.
                    lo = max(0, math.floor(seg_start) - window_start)
                    hi = min(window_len, math.ceil(seg_end) - window_start)
                    target[lo:hi] = label_id
            vfeats.append(vfeat)
            vmasks.append(vmask)
            targets.append(target)
            # This window already reached the end of the video.
            if remaining <= self.sliding_window_size:
                break

        vfeats = torch.stack(vfeats)
        vmasks = torch.stack(vmasks)
        targets = torch.stack(targets)

        # Labels over the whole, un-windowed video.
        video_targets = torch.full((video_len,), 0)
        for seg_start, seg_end, label_id in segments:
            lo = max(0, math.floor(seg_start))
            hi = min(video_len, math.ceil(seg_end))
            video_targets[lo:hi] = label_id

        num_windows = vfeats.size(0)
        caps_row = torch.LongTensor(
            [[self.cls_token_id, self.sep_token_id,
              self.pad_token_id, self.sep_token_id]],
        )
        caps = caps_row.repeat(num_windows, 1)
        cmasks_row = torch.BoolTensor(
            [[0, 1, 0, 1]]  # pad are valid for attention.
        )
        cmasks = cmasks_row.repeat(num_windows, 1)

        return {
            "caps": caps,
            "cmasks": cmasks,
            "vfeats": vfeats,  # X for original code.
            "vmasks": vmasks,
            "targets": targets,
            "video_id": video_id,
            "video_len": video_len,  # for later checking.
            "video_targets": video_targets
        }
class DiDeMoMetaProcessor(MetaProcessor):
    """reference: https://github.com/LisaAnne/LocalizingMoments/blob/master/utils/eval.py
    https://github.com/LisaAnne/LocalizingMoments/blob/master/utils/data_processing.py
    """

    def __init__(self, config):
        """Load `(video, description)` pairs from the split json file.

        The split path must contain "test".
        """
        super().__init__(config)
        split_path = self._get_split_path(config)
        assert "test" in split_path, "DiDeMo only supports zero-shot testing for now."
        with open(split_path) as data_file:
            records = json.load(data_file)
        self.data = [
            (record["video"], record["description"]) for record in records
        ]

    def __len__(self):
        """Number of loaded records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the `(video, description)` pair at `idx`."""
        return self.data[idx]
class DiDeMoTextProcessor(TextProcessor):
    """reference: https://github.com/LisaAnne/LocalizingMoments/blob/master/utils/eval.py
    https://github.com/LisaAnne/LocalizingMoments/blob/master/utils/data_processing.py
    """

    def __call__(self, text):
        """Return the token ids of `text`, without special tokens."""
        encoded = self.tokenizer(text, add_special_tokens=False)
        return encoded["input_ids"]
class DiDeMoAligner(DSAligner):
    """
    check video length.
    """

    def __call__(self, video_id, video_feature, text_feature):
        """Delegate unchanged to the parent aligner."""
        aligned = super().__call__(video_id, video_feature, text_feature)
        return aligned
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) Facebook, Inc. All Rights Reserved
import torch
import math
import pickle
import random
import os
import numpy as np
from collections import deque
from typing import Optional, Tuple, List
from .processor import (
Processor,
MetaProcessor,
TextProcessor,
Aligner,
MMAttentionMask2DProcessor
)
from ..utils import ShardedTensor
class How2MetaProcessor(MetaProcessor):
    """List of video ids read from the split file, one id per line."""

    def __init__(self, config):
        """Read the stripped lines of the split file into `self.data`."""
        super().__init__(config)
        with open(self._get_split_path(config)) as fd:
            self.data = list(map(str.strip, fd))

    def __getitem__(self, idx):
        """Return the video id at `idx` twice, as a pair."""
        return (self.data[idx],) * 2
class ShardedHow2MetaProcessor(How2MetaProcessor):
    """How2 video ids together with their position inside feature shards."""

    def __init__(self, config):
        """Store the split and feature directory, then index the shards."""
        super().__init__(config)
        self.split = str(config.split)
        self.vfeat_dir = config.vfeat_dir
        self._init_shard()

    def _init_shard(self):
        """Load `<vfeat_dir>/<prefix>_meta.pkl` and invert it.

        The prefix is "train" for the train split and "val" for both "valid"
        and "test". The result is stored in `self.video_id_to_shard` as
        `{video_id: (shard_id, index_in_shard)}`.

        Raises: ValueError for any other split.
        """
        if self.split not in ("train", "valid", "test"):
            raise ValueError("unsupported for MetaProcessor:", self.split)
        if self.split == "test":
            print("use how2 val as test.")
        prefix = "train" if self.split == "train" else "val"

        meta_fn = os.path.join(self.vfeat_dir, prefix + "_meta.pkl")
        with open(meta_fn, "rb") as fr:
            meta = pickle.load(fr)

        self.video_id_to_shard = {
            video_id: (shard_id, video_idx)
            for shard_id in meta
            for video_idx, video_id in enumerate(meta[shard_id])
        }

    def __getitem__(self, idx):
        """Return `(video_id, idx, shard_id, shard_idx)` twice, as a pair."""
        _, video_id = super().__getitem__(idx)
        shard_id, shard_idx = self.video_id_to_shard[video_id]
        meta = (video_id, idx, shard_id, shard_idx)
        return meta, meta
class ShardedVideoProcessor(Processor):
    """
    mmaped shards of numpy video features.
    """

    def __init__(self, config):
        """Store the split name and the feature directory."""
        self.split = str(config.split)
        self.vfeat_dir = config.vfeat_dir

    def __call__(self, video_id):
        """Return the feature of one video from its shard.

        `video_id` is a 4-tuple whose last two fields are the shard id and
        the index inside that shard. "valid" and "test" both read "val" shards.

        Raises: ValueError for an unknown split.
        """
        _, _, shard_id, video_idx = video_id
        prefix_by_split = {"train": "train", "valid": "val", "test": "val"}
        if self.split not in prefix_by_split:
            raise ValueError("unknown split", self.split)
        shard_path = os.path.join(
            self.vfeat_dir, prefix_by_split[self.split] + "_" + str(shard_id))
        shard = ShardedTensor.load(shard_path, "r")
        return shard[video_idx]
class ShardedTextProcessor(Processor):
    """Read captions and their start/end times from text shards."""

    def __init__(self, config):
        """Store the text feature path prefix and the split name."""
        self.tfeat_dir = str(config.tfeat_dir)
        self.split = str(config.split)

    def __call__(self, video_id):
        """Return `{"start": [...], "end": [...], "cap": [...]}` for a video.

        `video_id` is a 4-tuple whose last two fields are the shard id and
        the index inside that shard. Entries equal to -1 are dropped from
        every caption row.

        Raises: ValueError for an unknown split.
        """
        _, _, shard_id, shard_idx = video_id
        prefix_by_split = {"train": "train", "valid": "val", "test": "val"}
        if self.split not in prefix_by_split:
            raise ValueError("unknown split", self.split)
        # `tfeat_dir` is concatenated as a plain string prefix, not joined.
        target_path = (
            self.tfeat_dir + prefix_by_split[self.split] + "_" + str(shard_id))

        startend = ShardedTensor.load(
            target_path + ".startends", "r")[shard_idx]
        cap_ids = ShardedTensor.load(
            target_path + ".caps_ids", "r")[shard_idx]

        clips = (cap_ids[clip_idx] for clip_idx in range(len(cap_ids)))
        cap = [clip[clip != -1].tolist() for clip in clips]

        start = startend[:, 0].tolist()
        end = startend[:, 1].tolist()
        return {"start": start, "end": end, "cap": cap}
class FixedLenAligner(Aligner):
    """
    In the model we assume text is on the left (closer to BERT formulation)
    and video is on the right.
    We fix the total length of text + video.
    max_video_len is in number of secs.
    max_text_len is in number of tokens.

    special tokens formats:
    we use the format [CLS] [SEP] text tokens [SEP] [PAD] ...
    [CLS] will be split out into:
    [CLS] video tokens [SEP] text tokens [SEP] [PAD] ...
    token_type_ids will be generated by the model (for now).
    0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
    | first sequence    | second sequence |
    so each sequence owns a [SEP] token for no-ops.
    """

    def __init__(self, config):
        """Create the text clip sampler and decide the subsampling count.

        `config.subsampling` will change batch_size in trainer.
        `config.clip_per_video` (used by RetriTask) doesn't change batch_size
        in trainer; when it is set it takes precedence.
        """
        super().__init__(config)
        self.text_clip_sampler = TextClipSamplingProcessor(
            self.max_len - self.max_video_len - 3
        )
        subsampling = config.subsampling
        if config.clip_per_video is not None:
            subsampling = config.clip_per_video
        self.subsampling = subsampling

    def _get_text_maxlen(self):
        """Maximum number of text tokens for one sample."""
        # use max text len
        return self.text_clip_sampler.max_text_len

    def __call__(self, video_id, video_feature, text_feature):
        """Draw `self.subsampling` samples from one video and collate them.

        Each sample is centered on a uniformly chosen text clip.

        Returns: the collated batch with an extra "video_id" entry
        (`video_id` itself when it is a str, otherwise `video_id[0]`).
        Raises: ValueError when `self.subsampling` is None or below 1.
        """
        from transformers import default_data_collator

        video_idx = video_id[1]
        if self.subsampling is None or not (self.subsampling >= 1):
            raise ValueError(
                "dataset.subsampling must be >= 1 for efficient video loading.")

        batch = []
        for _ in range(self.subsampling):
            centerclip_idx = random.randint(
                0, len(text_feature["start"]) - 1)
            batch.append(
                self.sampling(
                    video_idx,
                    video_feature,
                    text_feature,
                    centerclip_idx,
                    self._get_text_maxlen()
                ))
        batch = self.batch_post_processing(batch, video_feature)
        batch = default_data_collator(batch)

        batch["video_id"] = video_id if isinstance(video_id, str) \
            else video_id[0]
        # e2e: make sure frame ids is into tensor.
        assert torch.is_tensor(batch["vfeats"])
        return batch

    def sampling(
        self,
        video_idx,
        video_feature,
        text_feature,
        centerclip_idx=None,
        sampled_max_text_len=None,
    ):
        """Sample neighbouring text clips and the video span covering them.

        The video span runs from the floored start of the first sampled clip
        to the ceiled end of the last one, clamped to `[0, video_len]`.

        Returns: a dict with "caps", "cmasks", "vfeats", "vmasks",
        "video_start", "video_end", "text_start" and "text_end".
        """
        text_clip_indexs = self.text_clip_sampler(
            text_feature, centerclip_idx,
            sampled_max_text_len
        )
        if isinstance(video_feature, np.ndarray):
            video_len = len(video_feature)
        else:
            # Fall back to the end time of the last text clip.
            video_len = math.ceil(text_feature["end"][-1])

        video_end = min(
            math.ceil(text_feature["end"][text_clip_indexs[-1]]),
            video_len
        )
        video_start = max(
            min(
                math.floor(text_feature["start"][text_clip_indexs[0]]),
                video_end),
            0
        )

        video_clips = {"start": [video_start], "end": [video_end]}

        # tensorize.
        vfeats, vmasks = self._build_video_seq(
            video_feature, video_clips
        )
        caps, cmasks = self._build_text_seq(
            text_feature, text_clip_indexs
        )

        text_start = text_clip_indexs[0]
        text_end = text_clip_indexs[-1] + 1

        return {
            "caps": caps,
            "cmasks": cmasks,
            "vfeats": vfeats,
            "vmasks": vmasks,
            "video_start": video_start,
            "video_end": video_end,
            "text_start": text_start,
            "text_end": text_end,
        }
class VariedLenAligner(FixedLenAligner):
    """`FixedLenAligner` whose text length limit is drawn per sample."""

    def __init__(self, config):
        """Read the bounds of the sampled text length from `config`."""
        super().__init__(config)
        self.sampled_min_len = config.sampled_min_len
        self.sampled_max_len = config.sampled_max_len

    def _get_text_maxlen(self):
        """Draw a text length limit uniformly from the configured range."""
        shortest, longest = self.sampled_min_len, self.sampled_max_len
        return random.randint(shortest, longest)
class StartClipAligner(VariedLenAligner):
    """Aligner that always samples around the first text clip."""

    def sampling(
        self,
        video_idx,
        video_feature,
        text_feature,
        centerclip_idx=None,
        sampled_max_text_len=None,
    ):
        """Sample with the center clip fixed to index 0.

        `centerclip_idx` and `sampled_max_text_len` are ignored.
        NOTE(review): `sampled_max_text_len` is not forwarded to the parent;
        confirm this is intended.
        """
        first_clip = 0
        return super().sampling(
            video_idx, video_feature, text_feature, first_clip)
class OverlappedAligner(VariedLenAligner):
    """video clip and text clip has overlappings
    but may not be the same start/end."""

    def __init__(self, config):
        """Read the video length bounds and create the video clip sampler."""
        super().__init__(config)
        self.sampled_video_min_len = config.sampled_video_min_len
        self.sampled_video_max_len = config.sampled_video_max_len
        self.video_clip_sampler = VideoClipSamplingProcessor()

    def _get_video_maxlen(self):
        """Draw a video length limit uniformly from the configured range."""
        shortest = self.sampled_video_min_len
        longest = self.sampled_video_max_len
        return random.randint(shortest, longest)

    def sampling(
        self,
        video_idx,
        video_feature,
        text_feature,
        centerclip_idx=None,
        sampled_max_text_len=None,
    ):
        """Sample text clips, then grow a video clip from a center in their span.

        Returns: a dict with "caps", "cmasks", "vfeats", "vmasks",
        "video_start", "video_end", "text_start" and "text_end".
        """
        text_clip_indexs = self.text_clip_sampler(
            text_feature, centerclip_idx,
            sampled_max_text_len
        )
        first_clip, last_clip = text_clip_indexs[0], text_clip_indexs[-1]

        video_len = (
            len(video_feature)
            if isinstance(video_feature, np.ndarray)
            else math.ceil(text_feature["end"][-1])
        )

        # Time span covered by the sampled text clips.
        low = math.floor(text_feature["start"][first_clip])
        high = math.ceil(text_feature["end"][last_clip])
        center = (
            random.randint(low, high) if low < high
            else int((low + high) // 2)
        )
        # Keep the center inside the available video features.
        num_feats = video_feature.shape[0]
        center = max(0, min(num_feats - 1, center))
        assert 0 <= center < num_feats

        video_clips = self.video_clip_sampler(
            video_len, self._get_video_maxlen(), center
        )

        # tensorize.
        vfeats, vmasks = self._build_video_seq(video_feature, video_clips)
        caps, cmasks = self._build_text_seq(text_feature, text_clip_indexs)

        return {
            "caps": caps,
            "cmasks": cmasks,
            "vfeats": vfeats,
            "vmasks": vmasks,
            "video_start": video_clips["start"][0],
            "video_end": video_clips["end"][0],
            "text_start": first_clip,
            "text_end": last_clip + 1,
        }
class MFMMLMAligner(FixedLenAligner):
    """
    `FixedLenAligner` with Masked Language Model and Masked Frame Model.
    """

    def __init__(self, config):
        """Create the clip sampler and the token/frame masking processors.

        Defaults when the config value is None: `keep_prob` 1.0, `mm_type`
        "full", `lazy_vfeat_mask` False, `mm_prob` 0.
        """
        super().__init__(config)

        keep_prob = config.keep_prob
        if keep_prob is None:
            keep_prob = 1.0
        # Replaces the sampler created by the parent constructor.
        self.text_clip_sampler = TextClipSamplingProcessor(
            self.max_len - self.max_video_len - 3, keep_prob
        )
        self.sampled_min_len = config.sampled_min_len
        self.sampled_max_len = config.sampled_max_len

        self.masked_token_sampler = TextMaskingProcessor(config)

        mm_type = config.mm_type
        if mm_type is None:
            mm_type = "full"
        self.mm_type = mm_type

        if self.mm_type == "textgen":
            self.attnmasker = MMAttentionMask2DProcessor()
        else:
            self.attnmasker = None

        self.masked_frame_sampler = FrameMaskingProcessor(config)

        lazy_vfeat_mask = config.lazy_vfeat_mask
        if lazy_vfeat_mask is None:
            lazy_vfeat_mask = False
        self.lazy_vfeat_mask = lazy_vfeat_mask

        mm_prob = config.mm_prob
        if mm_prob is None:
            mm_prob = 0.
        self.mm_prob = mm_prob

    def __call__(self, video_id, video_feature, text_feature):
        """Sample one or several masked examples from a video.

        With `self.subsampling > 1` that many samples are drawn and collated;
        otherwise a single sample is produced.

        Returns: the batch with an extra "video_id" entry.
        """
        from transformers import default_data_collator

        draw_many = self.subsampling is not None and self.subsampling > 1
        if draw_many:
            samples = []
            for _ in range(self.subsampling):
                centerclip_idx = random.randint(
                    0, len(text_feature["start"]) - 1)
                sampled_max_text_len = random.randint(
                    self.sampled_min_len, self.sampled_max_len
                )
                sample = self.sampling(
                    video_id,
                    video_feature,
                    text_feature,
                    centerclip_idx,
                    sampled_max_text_len,
                )
                samples.append(sample)
            batch = default_data_collator(
                self.batch_post_processing(samples, video_feature))
        else:
            single = self.sampling(video_id, video_feature, text_feature)
            batch = self.batch_post_processing(single, video_feature)

        if isinstance(video_id, str):
            batch["video_id"] = video_id
        else:
            batch["video_id"] = video_id[0]
        return batch

    def sampling(
        self,
        video_id,
        video_feature,
        text_feature,
        centerclip_idx=None,
        sampled_max_text_len=None,
    ):
        """Sample like `FixedLenAligner`, then add masking labels.

        With probability `self.mm_prob` a whole modality is masked: either
        the text (using `self.mm_type`) or the video ("full"), chosen by a
        second random draw. Otherwise both samplers use their default masking.

        Returns: the parent output with "caps" replaced and "video_label",
        "text_label" (and "attention_mask" when an attention masker is
        configured) added.
        """
        output = FixedLenAligner.sampling(
            self,
            video_id, video_feature, text_feature,
            centerclip_idx, sampled_max_text_len)

        masking_text = masking_video = None
        if random.random() < self.mm_prob:
            if random.random() > 0.5:
                masking_text, masking_video = self.mm_type, "no"
            else:
                masking_text, masking_video = "no", "full"

        # Features are passed along only when masking is not lazy.
        video_feats = None if self.lazy_vfeat_mask else output["vfeats"]
        video_label = self.masked_frame_sampler(
            output["vmasks"], masking_video, vfeats=video_feats)
        caps, text_label = self.masked_token_sampler(
            output["caps"], masking_text)

        output["caps"] = caps
        output["video_label"] = video_label
        output["text_label"] = text_label

        if self.attnmasker is not None:
            output["attention_mask"] = self.attnmasker(
                output["vmasks"], output["cmasks"], masking_text)
        return output
class FrameMaskingProcessor(Processor):
    """Choose which video positions are masked for the Masked Frame Model."""

    def __init__(self, config):
        """Use `config.mfm_probability` when set, otherwise 0.15."""
        self.mfm_probability = 0.15
        if config.mfm_probability is not None:
            self.mfm_probability = config.mfm_probability

    def __call__(self, vmasks, modality_masking=None, vfeats=None):
        """
        We perform lazy masking to save data transfer time.
        It only generates video_labels by default and the MFM model
        will do the actual masking.

        `modality_masking` selects the masking probability: None uses
        `mfm_probability`, "full" 1, "no" 0 and "inverse"
        `1 - mfm_probability`. When `vfeats` is given, its rows selected by
        the returned label are zeroed in place.

        Return: `video_label` is a binary mask.
        Raises: ValueError for any other `modality_masking` value.
        """
        video_label = vmasks.clone()

        if modality_masking is None:
            mask_prob = self.mfm_probability
        elif modality_masking == "full":
            mask_prob = 1.
        elif modality_masking == "no":
            mask_prob = 0.
        elif modality_masking == "inverse":
            mask_prob = 1. - self.mfm_probability
        else:
            raise ValueError("unknown modality masking.", modality_masking)

        probability_matrix = torch.full(video_label.shape, mask_prob)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        # We only compute loss on masked tokens
        video_label[~masked_indices] = 0
        if vfeats is not None:
            vfeats[video_label, :] = 0.0
        return video_label
class TextGenerationProcessor(Processor):
    """Turn a token sequence into shifted inputs and labels for generation."""

    def __init__(self, tokenizer):
        """Remember the bos and pad token ids of `tokenizer`."""
        self.bos_token_id = tokenizer.bos_token_id
        self.pad_token_id = tokenizer.pad_token_id

    def __call__(self, inputs):
        """Shift the text part of `inputs` right by one and build labels.

        `inputs` is modified in place: from position 2 on it becomes the bos
        token followed by the previous content without its last element, and
        positions that held the pad id are set back to the pad id. Labels are
        a copy of the original `inputs` with the first two positions and all
        pad positions set to -100.

        Returns: `(inputs, labels)`.
        """
        targets = inputs.clone()
        # [CLS] [SEP] for video
        targets[:2] = -100
        # keep [SEP] for text.
        is_pad = targets == self.pad_token_id
        targets[is_pad] = -100

        bos = torch.LongTensor([self.bos_token_id])
        shifted_text = torch.cat([bos, inputs[2:-1]])
        inputs[2:] = shifted_text
        inputs[is_pad] = self.pad_token_id

        assert len(inputs) == len(targets)
        return inputs, targets
class TextMaskingProcessor(Processor):
    """Token masking for masked language modeling and text generation.

    The masking logic is borrowed from
    `transformers/data/data_collator.DataCollatorForLanguageModeling`.
    """

    def __init__(self, config):
        """Read `config.mlm_probability` (default 0.15) and load the tokenizer
        named by `config.bert_name`."""
        self.mlm_probability = 0.15
        if config.mlm_probability is not None:
            self.mlm_probability = config.mlm_probability
        self.bert_name = config.bert_name
        # [CLS] is used as bos_token and [SEP] is used as eos_token.
        # https://huggingface.co/transformers/master/model_doc/bertgeneration.html
        from transformers import AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.bert_name, bos_token="[CLS]", eos_token="[SEP]")
        self.textgen = TextGenerationProcessor(self.tokenizer)

    def __call__(
        self, inputs: torch.Tensor,
        modality_masking=None,
        special_tokens_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Prepare masked token inputs/labels.

        `modality_masking` selects the behavior:
            None: traditional bert masking with `mlm_probability`.
            "no": no masking.
            "full": every non-special token is selected for masking.
            "inverse": masking with probability `1 - mlm_probability`.
            "textgen...": autoregressive generation inputs/labels; when the
                value also contains "mask", the inputs are additionally
                corrupted.
            "mask": corrupt the inputs and return labels filled with -100.

        Selected tokens are corrupted as 80% MASK, 10% random, 10% original.
        `inputs` is modified in place.

        Returns: `(inputs, labels)`; labels are -100 on unselected positions.
        Raises: ValueError for an unknown `modality_masking`.
        """
        labels = inputs.clone()
        # We sample a few tokens in each sequence for MLM training
        # (with probability `self.mlm_probability`)
        if modality_masking is not None:
            if modality_masking == "full":
                probability_matrix = torch.full(labels.shape, 1.)
            elif modality_masking == "no":
                probability_matrix = torch.full(labels.shape, 0.)
            elif modality_masking.startswith("textgen"):
                # [CLS] [SEP] <s> ...
                inputs, labels = self.textgen(inputs)
                if "mask" not in modality_masking:
                    return inputs, labels
                inputs = self.mask_input(inputs, special_tokens_mask)
                return inputs, labels
            elif modality_masking == "mask":
                inputs = self.mask_input(inputs, special_tokens_mask)
                labels = torch.full(inputs.shape, -100)
                return inputs, labels
            elif modality_masking == "inverse":
                probability_matrix = torch.full(
                    labels.shape, 1. - self.mlm_probability)
            else:
                raise ValueError("unknown modality masking.", modality_masking)
        else:
            probability_matrix = torch.full(labels.shape, self.mlm_probability)

        special_tokens_mask = self._resolve_special_tokens_mask(
            labels, special_tokens_mask)
        # Special tokens are never selected.
        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        inputs = self._corrupt_selected_tokens(inputs, masked_indices)
        return inputs, labels

    def mask_input(self, inputs, special_tokens_mask=None):
        """Corrupt `inputs` in place with probability `mlm_probability`.

        This is the masking used together with autoregressive generation;
        it does not produce labels.

        Returns: the modified `inputs`.
        """
        probability_matrix = torch.full(
            inputs.shape, self.mlm_probability)
        special_tokens_mask = self._resolve_special_tokens_mask(
            inputs, special_tokens_mask)
        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        return self._corrupt_selected_tokens(inputs, masked_indices)

    def _resolve_special_tokens_mask(self, token_ids, special_tokens_mask):
        """Return a bool mask of special tokens, computing it when not given."""
        if special_tokens_mask is None:
            special_tokens_mask = self.get_special_tokens_mask(
                token_ids.tolist(), already_has_special_tokens=True
            )
            return torch.tensor(special_tokens_mask, dtype=torch.bool)
        return special_tokens_mask.bool()

    def _corrupt_selected_tokens(self, inputs, masked_indices):
        """Apply the 80/10/10 corruption to the selected positions in place."""
        # 80% of the time,
        # we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = (
            torch.bernoulli(
                torch.full(inputs.shape, 0.8)).bool() & masked_indices
        )
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(
            self.tokenizer.mask_token
        )

        # 10% of the time, we replace masked input tokens with random word
        # (half of the 20% that was not replaced by [MASK]).
        indices_random = (
            torch.bernoulli(torch.full(inputs.shape, 0.5)).bool()
            & masked_indices
            & ~indices_replaced
        )
        random_words = torch.randint(
            len(self.tokenizer), inputs.shape, dtype=torch.long
        )
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input
        # tokens unchanged
        return inputs

    def get_special_tokens_mask(
        self, token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False
    ) -> List[int]:
        """Return a 0/1 list marking special tokens.

        Note: the version from transformers do not consider pad
        as special tokens; here [SEP], [CLS] and [PAD] all count.

        With `already_has_special_tokens` the ids in `token_ids_0` are
        checked one by one; otherwise the mask for sequences that will be
        wrapped in special tokens is returned.

        Raises: ValueError when `token_ids_1` is given together with
        `already_has_special_tokens`.
        """
        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if "
                    "the provided sequence of "
                    "ids is already formated with special tokens "
                    "for the model."
                )
            # Built once instead of once per token.
            special_ids = (
                self.tokenizer.sep_token_id,
                self.tokenizer.cls_token_id,
                self.tokenizer.pad_token_id,
            )
            return [1 if token in special_ids else 0 for token in token_ids_0]

        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1]
class TextClipSamplingProcessor(Processor):
    """Sample a run of text clips around a center clip.

    The run grows to the left or right at random until the accumulated
    number of tokens reaches the text limit, the covered time span reaches
    the video limit, or no clip is left on either side.
    """

    def __init__(self, max_text_len, keep_prob=1.0):
        """Store the default token limit and the clip keep probability.

        A random draw above `keep_prob` skips the neighbouring clip, provided
        one more clip exists beyond it.
        """
        self.max_text_len = max_text_len
        self.max_video_len = 256  # always hold.
        self.keep_prob = keep_prob

    def __call__(
        self,
        text_feature,
        centerclip_idx=None,
        sampled_max_text_len=None,
        sampled_max_video_len=None,
    ):
        """Return a deque of clip indexes, in ascending order.

        `text_feature` is a dict with parallel "start", "end" and "cap"
        lists. `centerclip_idx` defaults to a uniformly drawn clip; the two
        `sampled_max_*` arguments override the limits stored on the instance.
        """
        # Let's use all caps for now and see if 256 can cover all of them.
        if sampled_max_text_len is not None:
            max_text_len = sampled_max_text_len
        else:
            max_text_len = self.max_text_len
        if sampled_max_video_len is not None:
            max_video_len = sampled_max_video_len
        else:
            max_video_len = self.max_video_len

        t_num_clips = len(text_feature["start"])

        if centerclip_idx is None:
            centerclip_idx = random.randint(0, t_num_clips - 1)

        # `start_idx` is the leftmost clip taken so far; `end_idx` is the
        # next candidate on the right.
        start_idx, end_idx = centerclip_idx, centerclip_idx + 1
        text_clip_indexs = deque()
        text_clip_indexs.append(start_idx)
        text_len = len(text_feature["cap"][start_idx])

        video_len = max(
            0,
            text_feature["end"][start_idx]
            - text_feature["start"][start_idx],
        )

        while (
            (start_idx > 0 or end_idx < t_num_clips)
            and text_len < max_text_len
            and video_len < max_video_len
        ):
            if random.random() > 0.5 and end_idx < t_num_clips:
                # skip the next one?
                if random.random() > self.keep_prob and (end_idx + 1) < t_num_clips:
                    end_idx = end_idx + 1
                text_clip_indexs.append(end_idx)
                text_len += len(text_feature["cap"][end_idx])
                end_idx += 1
            elif start_idx > 0:
                # Same skip rule when growing to the left.
                if random.random() > self.keep_prob and (start_idx - 1) > 0:
                    start_idx = start_idx - 1
                start_idx -= 1
                text_clip_indexs.insert(0, start_idx)
                text_len += len(text_feature["cap"][start_idx])
            else:
                # Nothing left on the left side: grow right when possible.
                if end_idx < t_num_clips:
                    if random.random() > self.keep_prob and (end_idx + 1) < t_num_clips:
                        end_idx = end_idx + 1
                    text_clip_indexs.append(end_idx)
                    text_len += len(text_feature["cap"][end_idx])
                    end_idx += 1
                else:
                    return text_clip_indexs
            # Time span from the first to the last sampled clip.
            video_len = max(
                0,
                text_feature["end"][text_clip_indexs[-1]]
                - text_feature["start"][text_clip_indexs[0]],
            )
        return text_clip_indexs
class VideoClipSamplingProcessor(Processor):
    """Grow a video clip step by step from a center position."""

    def __call__(self, video_len, max_video_len, center):
        """
        `video_len`: length of the video.
        `max_video_len`: maximum video tokens allowed in a sequence.
        `center`: initial starting index.

        The clip starts empty at `center` and grows by one position per step,
        to the right when the left boundary is reached, to the left when the
        right boundary is reached, and at random otherwise.

        Returns: `{"start": [start], "end": [end]}`.
        """
        assert center >= 0 and center < video_len
        left, right = center, center
        grown = 0
        while (left > 0 or right < video_len) and grown < max_video_len:
            # decide the direction to grow.
            if left <= 0:
                grow_right = True
            elif right >= video_len:
                grow_right = False
            else:
                grow_right = random.random() > 0.5

            if grow_right:
                right += 1
            else:
                left -= 1
            grown += 1
        return {"start": [left], "end": [right]}
class How2MILNCEAligner(FixedLenAligner):
    """reference: `antoine77340/MIL-NCE_HowTo100M/video_loader.py`

    Pairs one short video window with `num_candidates` neighbouring captions.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_candidates = 4  # captions returned per sample.
        self.min_time = 5.0  # minimum time span of the sampled caption.
        self.num_sec = 3.2  # length of the video window that is cut out.
        # self.num_sec = self.num_frames / float(self.fps) num_frames=16 / fps = 5
        # self.num_frames = 16

    def sampling(
        self,
        video_id,
        video_feature,
        text_feature,
        centerclip_idx=None,  # will be ignored.
        sampled_max_text_len=None  # will be ignored.
    ):
        """Build one training example.

        Returns a dict with:
        `caps` / `cmasks`: stacked per-candidate text tensors and masks,
        `vfeats` / `vmasks`: zero-padded video features of length
        `self.max_video_len` and the matching boolean mask.
        """
        text, start, end = self._get_text(text_feature)
        video = self._get_video(video_feature, start, end)

        vfeats = torch.zeros((self.max_video_len, video_feature.shape[1]))
        vmasks = torch.zeros((self.max_video_len,), dtype=torch.bool)
        vfeats[: video.shape[0]] = torch.from_numpy(np.array(video))
        vmasks[: video.shape[0]] = 1

        caps, cmasks = [], []
        for words in text:
            # each `words` is a list of clip indexes for one candidate.
            cap, cmask = self._build_text_seq(text_feature, words)
            caps.append(cap)
            cmasks.append(cmask)
        caps = torch.stack(caps)
        cmasks = torch.stack(cmasks)
        # video of shape: (video_len)
        # text of shape (num_candidates, max_text_len)
        return {
            "caps": caps,
            "cmasks": cmasks,
            "vfeats": vfeats,
            "vmasks": vmasks,
            # "video_id": video_id,
        }

    def _get_video(self, video_feature, start, end):
        """Cut a window of `int(num_sec)` rows starting at a random offset
        drawn from `[start, max(start, end - num_sec)]`."""
        start_seek = random.randint(start, int(max(start, end - self.num_sec)))
        # duration = self.num_sec + 0.1
        return video_feature[start_seek : int(start_seek + self.num_sec)]

    def _get_text(self, cap):
        """Pick a random caption and its neighbouring candidates.

        Returns `(words, start, end)` where `words` is a list of one-element
        clip-index lists (one per candidate) and `start` / `end` are the
        integer time bounds of the picked caption, widened to `min_time`.
        """
        ind = random.randint(0, len(cap["start"]) - 1)
        if self.num_candidates == 1:
            # keep the same nesting as the multi-candidate branch:
            # `sampling` iterates candidates and passes each one on as a
            # list of clip indexes, so a bare `[ind]` would hand over an int.
            words = [[ind]]
        else:
            words = []
            cap_start = self._find_nearest_candidates(cap, ind)
            last_idx = len(cap["cap"]) - 1
            for i in range(self.num_candidates):
                # clamp into the valid range; `cap_start` can be negative.
                words.append([max(0, min(last_idx, cap_start + i))])

        start, end = cap["start"][ind], cap["end"][ind]
        # TODO: May need to be improved for edge cases.
        # expand the min time.
        if end - start < self.min_time:
            diff = self.min_time - end + start
            start = max(0, start - diff / 2)
            end = start + self.min_time
        return words, int(start), int(end)

    def _find_nearest_candidates(self, caption, ind):
        """find the range of the clips.

        Greedily widens `[start, end]` around `ind` towards the side that
        yields the shorter time span and returns the index of the first
        candidate. The result may be negative when the last clip is reached;
        the caller clamps it.
        """
        start, end = ind, ind
        # diff = caption["end"][end] - caption["start"][start]
        n_candidate = 1
        while n_candidate < self.num_candidates:
            # the first clip
            if start == 0:
                return 0
            # we add () in the following condition to fix the bug.
            elif end == (len(caption["start"]) - 1):
                return start - (self.num_candidates - n_candidate)
            elif (caption["end"][end] - caption["start"][start - 1]) < (
                caption["end"][end + 1] - caption["start"][start]
            ):
                start -= 1
            else:
                end += 1
            n_candidate += 1
        return start
class PKLJSONStrTextProcessor(TextProcessor):
    """`caption.json` from howto100m are preprocessed as a
    dict `[video_id, json_str]`.
    JSON parsing and tokenization are conducted on-the-fly and the result is
    cached back into the dict.
    """

    def __init__(self, config, max_clip_text_len=96):
        print("[Warning] PKLJSONStrTextProcessor is slow for num_workers > 0.")
        self.caption_pkl_path = str(config.caption_pkl_path)
        # NOTE(review): pickle.load executes arbitrary code from the file;
        # only point `caption_pkl_path` at trusted data.
        with open(self.caption_pkl_path, "rb") as fd:
            self.data = pickle.load(fd)
        self.max_clip_text_len = max_clip_text_len
        from transformers import AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            str(config.bert_name), use_fast=config.use_fast
        )

    def __call__(self, video_id):
        """Return the caption dict of `video_id` with token ids under "cap".

        Entries still stored as a JSON string are parsed, tokenized clip by
        clip (each clip's text is cut to `max_clip_text_len` characters
        first) and written back to `self.data`. Non-string clips become an
        empty id list.
        """
        caption = self.data[video_id]
        if not isinstance(caption, str):
            # not a raw JSON string: returned unchanged.
            return caption

        import json
        caption = json.loads(caption)
        char_limit = self.max_clip_text_len
        caption["cap"] = [
            self.tokenizer(
                text_clip[:char_limit], add_special_tokens=False
            )["input_ids"]
            if isinstance(text_clip, str)
            else []
            for text_clip in caption["text"]
        ]
        caption.pop("text")  # save space.
        self.data[video_id] = caption
        return caption
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .how2processor import (
ShardedHow2MetaProcessor,
ShardedVideoProcessor,
ShardedTextProcessor,
VariedLenAligner,
OverlappedAligner
)
class ShardedHow2VideoRetriMetaProcessor(ShardedHow2MetaProcessor):
    """Meta processor whose items are batches (lists) of video ids."""

    def __init__(self, config):
        super().__init__(config)
        self.num_video_per_batch = config.num_video_per_batch
        # only keep as many videos as fill a whole multiple of
        # 8 * num_video_per_batch; the remainder is dropped.
        group_size = 8 * self.num_video_per_batch
        usable = (len(self.data) // group_size) * group_size
        self.cands = [
            self.data[offset:offset + self.num_video_per_batch]
            for offset in range(0, usable, self.num_video_per_batch)
        ]

    def __len__(self):
        return len(self.cands)

    def set_candidates(self, cands):
        """Replace the candidate batches and print the old/new batch count."""
        # no changes on num of batches.
        print(len(self.cands), "->", len(cands))
        # assert len(self.cands) == len(cands)
        self.cands = cands

    def __getitem__(self, idx):
        """Return the batch at `idx` twice, as a list of
        `(video_id, -1, shard_id, video_idx)` tuples."""
        video_ids = self.cands[idx]
        assert isinstance(video_ids, list)
        shard_lookup = self.video_id_to_shard
        located = []
        for video_id in video_ids:
            shard_id, video_idx = shard_lookup[video_id]
            located.append((video_id, -1, shard_id, video_idx))
        return located, located
class ShardedVideoRetriVideoProcessor(ShardedVideoProcessor):
    """In the retrieval case the video_id
    is a list of tuples: `(shard_id, video_idx)` ."""

    def __call__(self, sharded_video_idxs):
        """Load the features of every entry via the parent processor."""
        assert isinstance(sharded_video_idxs, list)
        # bind the parent call once; zero-argument super() is not usable
        # inside a comprehension on older Python versions.
        load_one = super().__call__
        return [load_one(entry) for entry in sharded_video_idxs]
class ShardedVideoRetriTextProcessor(ShardedTextProcessor):
    """In the retrieval case the video_id
    is a list of tuples: `(shard_id, video_idx)` ."""

    def __call__(self, sharded_video_idxs):
        """Load the captions of every entry via the parent processor."""
        assert isinstance(sharded_video_idxs, list)
        # bind the parent call once; zero-argument super() is not usable
        # inside a comprehension on older Python versions.
        load_one = super().__call__
        return [load_one(entry) for entry in sharded_video_idxs]
class VideoRetriAligner(VariedLenAligner):
    """Align every video of a retrieval batch and collate the results."""

    # Retritask will trim dim-0.
    def __call__(self, sharded_video_idxs, video_features, text_features):
        """Run the parent aligner per video, collate the outputs with
        `default_data_collator` and attach the plain ids as "video_id"."""
        from transformers import default_data_collator

        align_one = super().__call__
        examples, plain_ids = [], []
        triples = zip(sharded_video_idxs, video_features, text_features)
        for video_id, video_feature, text_feature in triples:
            examples.append(align_one(video_id, video_feature, text_feature))
            # sharded ids are tuples whose first field is the video id.
            plain_ids.append(
                video_id[0] if isinstance(video_id, tuple) else video_id
            )
        batch = default_data_collator(examples)
        batch["video_id"] = plain_ids
        return batch
class VideoRetriOverlappedAligner(OverlappedAligner):
    """Align every video of a retrieval batch and collate the results."""

    # Retritask will trim dim-0.
    def __call__(self, sharded_video_idxs, video_features, text_features):
        """Run the parent aligner per video, collate the outputs with
        `default_data_collator` and attach the plain ids as "video_id"."""
        from transformers import default_data_collator

        align_one = super().__call__
        examples, plain_ids = [], []
        triples = zip(sharded_video_idxs, video_features, text_features)
        for video_id, video_feature, text_feature in triples:
            examples.append(align_one(video_id, video_feature, text_feature))
            # sharded ids are tuples whose first field is the video id.
            plain_ids.append(
                video_id[0] if isinstance(video_id, tuple) else video_id
            )
        batch = default_data_collator(examples)
        batch["video_id"] = plain_ids
        return batch
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Contains a PyTorch definition for Gated Separable 3D network (S3D-G)
with a text module for computing joint text-video embedding from raw text
and video input. The following code will enable you to load the HowTo100M
pretrained S3D Text-Video model from:
A. Miech, J.-B. Alayrac, L. Smaira, I. Laptev, J. Sivic and A. Zisserman,
End-to-End Learning of Visual Representations from Uncurated Instructional Videos.
https://arxiv.org/abs/1912.06430.
S3D-G was proposed by:
S. Xie, C. Sun, J. Huang, Z. Tu and K. Murphy,
Rethinking Spatiotemporal Feature Learning For Video Understanding.
https://arxiv.org/abs/1712.04851.
Tensorflow code: https://github.com/tensorflow/models/blob/master/research/slim/nets/s3dg.py
The S3D architecture was slightly modified with a space to depth trick for TPU
optimization.
"""
import torch as th
import torch.nn.functional as F
import torch.nn as nn
import os
import numpy as np
import re
class InceptionBlock(nn.Module):
    """Four-branch inception block with optional per-branch self gating.

    Branches: 1x1x1 conv; 1x1x1 conv followed by a separable 3x3x3 conv
    (twice, with different widths); 3x3x3 max-pool followed by 1x1x1 conv.
    The branch outputs are concatenated along the channel dimension.
    """

    def __init__(
        self,
        input_dim,
        num_outputs_0_0a,
        num_outputs_1_0a,
        num_outputs_1_0b,
        num_outputs_2_0a,
        num_outputs_2_0b,
        num_outputs_3_0b,
        gating=True,
    ):
        super().__init__()
        unit_kernel = [1, 1, 1]
        cube_kernel = [3, 3, 3]
        # NOTE: submodules are created in the original order so parameter
        # registration (and seeded initialisation) stays the same.
        self.conv_b0 = STConv3D(input_dim, num_outputs_0_0a, unit_kernel)
        self.conv_b1_a = STConv3D(input_dim, num_outputs_1_0a, unit_kernel)
        self.conv_b1_b = STConv3D(
            num_outputs_1_0a,
            num_outputs_1_0b,
            cube_kernel,
            padding=1,
            separable=True,
        )
        self.conv_b2_a = STConv3D(input_dim, num_outputs_2_0a, unit_kernel)
        self.conv_b2_b = STConv3D(
            num_outputs_2_0a,
            num_outputs_2_0b,
            cube_kernel,
            padding=1,
            separable=True,
        )
        self.maxpool_b3 = th.nn.MaxPool3d((3, 3, 3), stride=1, padding=1)
        self.conv_b3_b = STConv3D(input_dim, num_outputs_3_0b, unit_kernel)
        self.gating = gating
        # channel count of the concatenated output.
        self.output_dim = sum(
            (
                num_outputs_0_0a,
                num_outputs_1_0b,
                num_outputs_2_0b,
                num_outputs_3_0b,
            )
        )
        if gating:
            self.gating_b0 = SelfGating(num_outputs_0_0a)
            self.gating_b1 = SelfGating(num_outputs_1_0b)
            self.gating_b2 = SelfGating(num_outputs_2_0b)
            self.gating_b3 = SelfGating(num_outputs_3_0b)

    def forward(self, input):
        """Apply the four branches to `input` and concatenate on dim 1."""
        branch0 = self.conv_b0(input)
        branch1 = self.conv_b1_b(self.conv_b1_a(input))
        branch2 = self.conv_b2_b(self.conv_b2_a(input))
        branch3 = self.conv_b3_b(self.maxpool_b3(input))
        if self.gating:
            branch0 = self.gating_b0(branch0)
            branch1 = self.gating_b1(branch1)
            branch2 = self.gating_b2(branch2)
            branch3 = self.gating_b3(branch3)
        return th.cat((branch0, branch1, branch2, branch3), dim=1)
class SelfGating(nn.Module):
    """Channel-wise feature gating as used in S3D-G."""

    def __init__(self, input_dim):
        super().__init__()
        self.fc = nn.Linear(input_dim, input_dim)

    def forward(self, input_tensor):
        """Scale every channel of a 5-D tensor by a sigmoid gate computed
        from that tensor's mean over dims 2, 3 and 4."""
        pooled = input_tensor.mean(dim=(2, 3, 4))
        gate = th.sigmoid(self.fc(pooled))
        # broadcast the (batch, channel) gate over the three trailing dims.
        return gate[..., None, None, None] * input_tensor
class STConv3D(nn.Module):
    """3-D convolution + batch norm + ReLU, optionally factorised into a
    spatial (1 x k x k) and a temporal (k x 1 x 1) convolution.

    `kernel_size`: sequence of three ints (time, height, width).
    `stride`, `padding`: an int or a list/tuple of three values.
    `separable`: request the spatial/temporal factorisation. A kernel with
    temporal extent 1 has nothing to factor out, so it is built as a single
    convolution and `self.separable` is False in that case.
    """

    def __init__(
        self, input_dim, output_dim, kernel_size, stride=1, padding=0, separable=False
    ):
        super(STConv3D, self).__init__()
        assert len(kernel_size) == 3
        # Previously `separable=True` with kernel_size[0] == 1 reached the
        # separable branch with its kernel/stride/padding locals unbound.
        self.separable = bool(separable) and kernel_size[0] != 1
        self.relu = nn.ReLU(inplace=True)

        if self.separable:
            spatial_kernel_size = [1, kernel_size[1], kernel_size[2]]
            temporal_kernel_size = [kernel_size[0], 1, 1]
            spatial_stride, temporal_stride = self._factorize(stride, 1)
            spatial_padding, temporal_padding = self._factorize(padding, 0)
            self.conv1 = nn.Conv3d(
                input_dim,
                output_dim,
                kernel_size=spatial_kernel_size,
                stride=spatial_stride,
                padding=spatial_padding,
                bias=False,
            )
            self.bn1 = nn.BatchNorm3d(output_dim)
            self.conv2 = nn.Conv3d(
                output_dim,
                output_dim,
                kernel_size=temporal_kernel_size,
                stride=temporal_stride,
                padding=temporal_padding,
                bias=False,
            )
            self.bn2 = nn.BatchNorm3d(output_dim)
        else:
            self.conv1 = nn.Conv3d(
                input_dim,
                output_dim,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=False,
            )
            self.bn1 = nn.BatchNorm3d(output_dim)

    @staticmethod
    def _factorize(value, neutral):
        """Split a stride/padding spec into (spatial, temporal) triples.

        `value` is a scalar or a list/tuple of three entries (t, h, w);
        `neutral` fills the axes a factor does not act on (1 for stride,
        0 for padding).
        """
        if isinstance(value, (list, tuple)) and len(value) == 3:
            t, h, w = value
        else:
            t = h = w = value
        return [neutral, h, w], [t, neutral, neutral]

    def forward(self, input):
        """Apply conv-bn-relu once, or twice when the block is separable."""
        out = self.relu(self.bn1(self.conv1(input)))
        if self.separable:
            out = self.relu(self.bn2(self.conv2(out)))
        return out
class MaxPool3dTFPadding(th.nn.Module):
    """3-D max pooling that zero-pads the input the way TensorFlow's
    "SAME" padding distributes it (extra padding goes to the end).

    `kernel_size`: int or three ints (depth, height, width).
    `stride`: int, three ints, or None to reuse `kernel_size`.
    `padding`: only "SAME" is supported; anything else raises ValueError.
    """

    def __init__(self, kernel_size, stride=None, padding="SAME"):
        super(MaxPool3dTFPadding, self).__init__()
        if padding != "SAME":
            # previously this left `self.pad` undefined and only failed
            # later, inside forward().
            raise ValueError(
                "MaxPool3dTFPadding only supports padding='SAME', got %r"
                % (padding,)
            )
        kernel_size = self._triple(kernel_size)
        # mirror MaxPool3d: a missing stride defaults to the kernel size.
        stride = kernel_size if stride is None else self._triple(stride)
        padding_shape = self._get_padding_shape(kernel_size, stride)
        self.padding_shape = padding_shape
        # NOTE: padding value is 0, not -inf, so negative-only windows at
        # the border yield 0.
        self.pad = th.nn.ConstantPad3d(padding_shape, 0)
        self.pool = th.nn.MaxPool3d(kernel_size, stride, ceil_mode=True)

    @staticmethod
    def _triple(value):
        """Expand an int to a 3-tuple; pass sequences through as a tuple."""
        if isinstance(value, int):
            return (value, value, value)
        return tuple(value)

    def _get_padding_shape(self, filter_shape, stride):
        """Return the pad tuple for `ConstantPad3d`.

        `filter_shape` and `stride` are given as (depth, height, width);
        the result is ordered last dimension first:
        (w_before, w_after, h_before, h_after, d_before, d_after).
        """
        pairs = []
        for filter_dim, stride_val in zip(filter_shape, stride):
            pad_along = max(filter_dim - stride_val, 0)
            pad_before = pad_along // 2
            pairs.append((pad_before, pad_along - pad_before))
        # ConstantPad3d pads the last dimension first, so emit W, H, D.
        # (The previous ordering was H, W, D, which is only equivalent when
        # the height and width settings coincide.)
        return tuple(p for pair in reversed(pairs) for p in pair)

    def forward(self, inp):
        """Pad `inp` and max-pool it."""
        return self.pool(self.pad(inp))
class Sentence_Embedding(nn.Module):
    """Bag-of-words sentence encoder: word embedding -> linear + ReLU ->
    max over words -> linear projection to `embd_dim`."""

    def __init__(
        self,
        embd_dim,
        num_embeddings=66250,
        word_embedding_dim=300,
        token_to_word_path="dict.npy",
        max_words=16,
        output_dim=2048,
    ):
        super().__init__()
        self.word_embd = nn.Embedding(num_embeddings, word_embedding_dim)
        self.fc1 = nn.Linear(word_embedding_dim, output_dim)
        self.fc2 = nn.Linear(output_dim, embd_dim)
        self.word_to_token = {}
        self.max_words = max_words
        # vocabulary ids start at 1; 0 is used for padding.
        vocabulary = np.load(token_to_word_path)
        for position, word in enumerate(vocabulary):
            self.word_to_token[word] = position + 1

    def _zero_pad_tensor_token(self, tensor, size):
        """Truncate or right-pad a 1-D id tensor with zeros to `size`."""
        missing = size - len(tensor)
        if missing <= 0:
            return tensor[:size]
        return th.cat((tensor, th.zeros(missing).long()), dim=0)

    def _split_text(self, sentence):
        """Split a sentence into runs of word characters and apostrophes."""
        return re.findall(r"[\w']+", str(sentence))

    def _words_to_token(self, words):
        """Map known words to ids (unknown words are dropped) and return a
        LongTensor of length `max_words`."""
        ids = [
            self.word_to_token[word] for word in words if word in self.word_to_token
        ]
        if not ids:
            return th.zeros(self.max_words).long()
        return self._zero_pad_tensor_token(th.LongTensor(ids), self.max_words)

    def _words_to_ids(self, x):
        """Convert an iterable of sentences into a (n, max_words) id tensor."""
        rows = []
        for sent in x:
            rows.append(self._words_to_token(self._split_text(sent.lower())))
        return th.stack(rows, dim=0)

    def forward(self, x):
        """Embed a batch of raw sentences; returns {'text_embedding': ...}."""
        token_ids = self._words_to_ids(x)
        hidden = F.relu(self.fc1(self.word_embd(token_ids)))
        # max-pool over the word dimension.
        pooled = th.max(hidden, dim=1)[0]
        return {'text_embedding': self.fc2(pooled)}
class S3D(nn.Module):
    """S3D-G video backbone with an attached `Sentence_Embedding` text
    module; `forward` embeds video only."""

    def __init__(self, dict_path, num_classes=512, gating=True, space_to_depth=True):
        super().__init__()
        self.num_classes = num_classes
        self.gating = gating
        self.space_to_depth = space_to_depth
        # NOTE: submodules are created in the original order so parameter
        # registration (and seeded initialisation) stays the same.
        if space_to_depth:
            self.conv1 = STConv3D(
                24, 64, [2, 4, 4], stride=1, padding=(1, 2, 2), separable=False
            )
        else:
            self.conv1 = STConv3D(
                3, 64, [3, 7, 7], stride=2, padding=(1, 3, 3), separable=False
            )
        self.conv_2b = STConv3D(64, 64, [1, 1, 1], separable=False)
        self.conv_2c = STConv3D(64, 192, [3, 3, 3], padding=1, separable=True)
        # NOTE(review): this replaces the boolean `gating` flag stored above
        # with a module, so the `if self.gating` test in forward() appears to
        # always pass, and the flag is not forwarded to the inception blocks.
        # Left as-is; confirm against the checkpoint layout before changing.
        self.gating = SelfGating(192)
        self.maxpool_2a = MaxPool3dTFPadding(
            kernel_size=(1, 3, 3), stride=(1, 2, 2), padding="SAME"
        )
        self.maxpool_3a = MaxPool3dTFPadding(
            kernel_size=(1, 3, 3), stride=(1, 2, 2), padding="SAME"
        )
        self.mixed_3b = InceptionBlock(192, 64, 96, 128, 16, 32, 32)
        self.mixed_3c = InceptionBlock(
            self.mixed_3b.output_dim, 128, 128, 192, 32, 96, 64
        )
        self.maxpool_4a = MaxPool3dTFPadding(
            kernel_size=(3, 3, 3), stride=(2, 2, 2), padding="SAME"
        )
        self.mixed_4b = InceptionBlock(
            self.mixed_3c.output_dim, 192, 96, 208, 16, 48, 64
        )
        self.mixed_4c = InceptionBlock(
            self.mixed_4b.output_dim, 160, 112, 224, 24, 64, 64
        )
        self.mixed_4d = InceptionBlock(
            self.mixed_4c.output_dim, 128, 128, 256, 24, 64, 64
        )
        self.mixed_4e = InceptionBlock(
            self.mixed_4d.output_dim, 112, 144, 288, 32, 64, 64
        )
        self.mixed_4f = InceptionBlock(
            self.mixed_4e.output_dim, 256, 160, 320, 32, 128, 128
        )
        # the same pooling module is registered under two names.
        self.maxpool_5a = self.maxPool3d_5a_2x2 = MaxPool3dTFPadding(
            kernel_size=(2, 2, 2), stride=(2, 2, 2), padding="SAME"
        )
        self.mixed_5b = InceptionBlock(
            self.mixed_4f.output_dim, 256, 160, 320, 32, 128, 128
        )
        self.mixed_5c = InceptionBlock(
            self.mixed_5b.output_dim, 384, 192, 384, 48, 128, 128
        )
        self.fc = nn.Linear(self.mixed_5c.output_dim, num_classes)
        self.text_module = Sentence_Embedding(
            num_classes, token_to_word_path=dict_path
        )

    def _space_to_depth(self, input):
        """3D space to depth trick for TPU optimization.

        Folds every 2x2x2 block of (T, H, W) into the channel dimension:
        (B, C, T, H, W) -> (B, 8 * C, T // 2, H // 2, W // 2).
        """
        B, C, T, H, W = input.shape
        half_t, half_h, half_w = T // 2, H // 2, W // 2
        blocks = input.view(B, C, half_t, 2, half_h, 2, half_w, 2)
        blocks = blocks.permute(0, 3, 5, 7, 1, 2, 4, 6)
        return blocks.contiguous().view(B, 8 * C, half_t, half_h, half_w)

    def forward(self, inputs):
        """Defines the S3DG base architecture.

        Returns a dict with the projected 'video_embedding' and the pooled
        'mixed_5c' features.
        """
        if self.space_to_depth:
            inputs = self._space_to_depth(inputs)
        net = self.conv1(inputs)
        if self.space_to_depth:
            # we need to replicate 'SAME' tensorflow padding
            net = net[:, :, 1:, 1:, 1:]

        for stem_layer in (self.maxpool_2a, self.conv_2b, self.conv_2c):
            net = stem_layer(net)
        if self.gating:
            net = self.gating(net)

        trunk = (
            self.maxpool_3a,
            self.mixed_3b,
            self.mixed_3c,
            self.maxpool_4a,
            self.mixed_4b,
            self.mixed_4c,
            self.mixed_4d,
            self.mixed_4e,
            self.mixed_4f,
            self.maxpool_5a,
            self.mixed_5b,
            self.mixed_5c,
        )
        for layer in trunk:
            net = layer(net)

        # global average over time and space.
        net = th.mean(net, dim=[2, 3, 4])
        return {'video_embedding': self.fc(net), 'mixed_5c': net}
# Copyright (c) Facebook, Inc. All Rights Reserved
import numpy as np
import os
import torch
class Processor:
    """
    Base class of all processors for video (codec, feature etc.) and text.
    Subclasses implement `__call__`.
    """

    def __call__(self, **kwargs):
        """Process the given inputs; not implemented on the base class."""
        raise NotImplementedError
class MetaProcessor(Processor):
    """
    A meta processor is expected to load the metadata of a dataset:
    (e.g., video_ids, or captions).
    You must implement the `__getitem__` (meta datasets are rather diverse.).
    Subclasses are expected to populate `self.data`, which `__len__` reads.
    """

    def __init__(self, config):
        self.split = config.split

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        raise NotImplementedError

    def _get_split_path(self, config):
        """Return the data path for `config.split` ("train", "valid" or
        "test"); falls back to the train path when the split is None."""
        train_path = config.train_path
        valid_path = config.val_path
        test_path = config.test_path
        if config.split is None:
            return train_path
        path_by_split = {
            "train": train_path,
            "valid": valid_path,
            "test": test_path,
        }
        return path_by_split[config.split]
class TextProcessor(Processor):
    """
    A generic Text processor: rename this as `withTokenizer`.
    tokenize a string of text on-the-fly.
    Warning: mostly used for end tasks.
        (on-the-fly tokenization is slow for how2.)
    TODO(huxu): move this class as a subclass.
    """

    def __init__(self, config):
        self.bert_name = str(config.bert_name)
        self.use_fast = config.use_fast
        # imported lazily, at construction time.
        from transformers import AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.bert_name, use_fast=self.use_fast
        )

    def __call__(self, text_id):
        """Tokenize `text_id` without special tokens and return its ids."""
        encoded = self.tokenizer(text_id, add_special_tokens=False)
        return encoded["input_ids"]
class VideoProcessor(Processor):
    """
    A generic video processor: load a numpy video tokens by default.
    """

    def __init__(self, config):
        self.vfeat_dir = config.vfeat_dir

    def __call__(self, video_fn):
        """Load `<vfeat_dir>/<video_fn>.npy`; when `video_fn` is a tuple
        its first element is used as the file stem."""
        stem = video_fn[0] if isinstance(video_fn, tuple) else video_fn
        assert isinstance(stem, str)
        feature_path = os.path.join(self.vfeat_dir, stem + ".npy")
        return np.load(feature_path)
class Aligner:
    """
    An aligner pairs video and text and outputs a dict of tensors (for a model).
    """

    def __init__(self, config):
        """__init__ needs to be light weight for more workers/threads."""
        self.split = config.split
        self.max_video_len = config.max_video_len
        self.max_len = config.max_len
        from transformers import AutoTokenizer
        # the tokenizer itself is not kept, only its special token ids.
        tokenizer = AutoTokenizer.from_pretrained(
            str(config.bert_name), use_fast=config.use_fast
        )
        self.cls_token_id = tokenizer.cls_token_id
        self.sep_token_id = tokenizer.sep_token_id
        self.pad_token_id = tokenizer.pad_token_id
        self.mask_token_id = tokenizer.mask_token_id

    def __call__(self, video_id, video_feature, text_feature):
        raise NotImplementedError

    def _build_video_seq(self, video_feature, video_clips=None):
        """
        `video_feature`: available video tokens.
        `video_clips`: video clip sequence to build.

        Returns `(vfeats, vmasks)`: a float32 tensor of `max_video_len` rows
        holding the selected clips back to back (zero padded) and a boolean
        mask marking the filled rows.
        """
        if not isinstance(video_feature, np.ndarray):
            raise ValueError(
                "unsupported type of video_feature", type(video_feature)
            )
        capacity = self.max_video_len
        if video_clips is None:
            # this is borrowed from DSAligner
            # the whole sequence is a single clip.
            video_clips = {
                "start": [0],
                "end": [min(len(video_feature), capacity)],
            }

        vfeats = np.zeros((capacity, video_feature.shape[1]), dtype=np.float32)
        vmasks = torch.zeros((capacity,), dtype=torch.bool)
        filled = 0
        for start, end in zip(video_clips["start"], video_clips["end"]):
            # never copy more than the room that is left.
            clip_len = min(capacity - filled, end - start)
            if clip_len <= 0:
                continue
            target = slice(filled, filled + clip_len)
            vfeats[target] = video_feature[start: start + clip_len]
            vmasks[target] = 1
            filled += clip_len
        return torch.from_numpy(vfeats), vmasks

    def _build_text_seq(self, text_feature, text_clip_indexs=None):
        """
        `text_feature`: all available clips.
        `text_clip_indexes`: clip sequence to build.

        Returns `(caps, cmasks)`: token ids laid out as
        [CLS] [SEP] text... [SEP] followed by padding, of length
        `max_len - max_video_len`, and a boolean mask of the non-pad part.
        """
        if text_clip_indexs is None:
            text_clip_indexs = [0]
        if isinstance(text_feature, dict):
            tokens = []
            for clip_idx in text_clip_indexs:
                tokens.extend(text_feature["cap"][clip_idx])
        else:
            # already a flat token list.
            tokens = text_feature

        # three positions are reserved for [CLS] and the two [SEP]s.
        text_budget = self.max_len - self.max_video_len - 3
        tokens = tokens[:text_budget]
        tokens = (
            [self.cls_token_id, self.sep_token_id] + tokens + [self.sep_token_id]
        )
        n_pad = self.max_len - len(tokens) - self.max_video_len
        padded = tokens + [self.pad_token_id] * n_pad
        caps = torch.LongTensor(padded)
        cmasks = torch.zeros((len(padded),), dtype=torch.bool)
        cmasks[: len(tokens)] = 1
        return caps, cmasks

    def batch_post_processing(self, batch, video_feature):
        """Hook for subclasses; returns `batch` unchanged here."""
        return batch
class MMAttentionMask2DProcessor(Processor):
    """text generation requires 2d mask
    that is harder to generate by GPU at this stage.

    All masks are built for the joint layout
    [CLS] + video slots + [SEP] + text slots, i.e. `cmask[:1]`, `vmask`,
    `cmask[1:]` concatenated.
    """

    def __call__(self, vmask, cmask, mtype):
        """Build the 2-D attention mask selected by `mtype`
        ("textgen", "videogen", anything else: full multimodal mask)."""
        if mtype == "textgen":
            return self._build_textgeneration_mask(vmask, cmask)
        if mtype == "videogen":
            return self._build_videogeneration_mask(vmask, cmask)
        return self._build_mm_mask(vmask, cmask)

    @staticmethod
    def _true(shape, device):
        """All-True boolean tensor of `shape` on `device`."""
        return torch.ones(shape, dtype=torch.bool, device=device)

    @staticmethod
    def _false(shape, device):
        """All-False boolean tensor of `shape` on `device`."""
        return torch.zeros(shape, dtype=torch.bool, device=device)

    def _build_mm_mask(self, vmask, cmask):
        """Every row is the same 1-D joint mask."""
        joint = torch.cat([cmask[:1], vmask, cmask[1:]], dim=0)
        return joint[None, :].repeat(joint.size(0), 1)

    def _build_videogeneration_mask(self, vmask, cmask):
        """Mask for video generation: text-side rows see [CLS] and text
        only; each real video token sees [CLS], the video tokens up to and
        including itself, and the text."""
        dev = cmask.device
        n_video_slots = vmask.size(0)
        text_part = cmask[2:]

        # cls_mask is only about text otherwise it will leak generation.
        cls_text_mask = torch.cat([
            # [CLS]
            self._true((1,), dev),
            # video tokens and [SEP] for video.
            self._false((n_video_slots + 1,), dev),
            text_part,
        ], dim=0)

        # concat horizontally.
        video_len = int(vmask.sum())
        video_masks = torch.cat([
            # [CLS]
            self._true((video_len, 1), dev),
            # causal (lower-triangular) part over real video tokens.
            torch.tril(self._true((video_len, video_len), dev)),
            # video_padding
            self._false((video_len, n_video_slots - video_len), dev),
            # [SEP] for video (unused).
            self._false((video_len, 1), dev),
            text_part.unsqueeze(0).repeat(video_len, 1),
        ], dim=1)

        cls_text_row = cls_text_mask[None, :]
        text_masks = cls_text_row.repeat(cmask.size(0) - 2, 1)
        video_padding_masks = cls_text_row.repeat(
            n_video_slots - video_len, 1)
        # the row of the video [SEP] uses the full joint mask.
        sep_row = torch.cat([cmask[:1], vmask, cmask[1:]], dim=0)[None, :]

        return torch.cat([
            cls_text_row,
            video_masks,
            video_padding_masks,
            sep_row,
            text_masks,
        ], dim=0)

    def _build_textgeneration_mask(self, vmask, cmask):
        """Mask for text generation: [CLS], video and [SEP] rows see the
        video side only; each real text token sees the video side plus the
        text tokens up to and including itself."""
        dev = cmask.device
        n_text_slots = cmask.size(0) - 2

        # cls_mask is only about video otherwise it will leak generation.
        cls_video_mask = torch.cat([
            # [CLS]
            self._true((1,), dev),
            vmask,
            # [SEP]
            self._true((1,), dev),
            self._false((n_text_slots,), dev),
        ], dim=0)

        # concat horizontally.
        text_len = int(cmask[2:].sum())
        text_masks = torch.cat([
            # [CLS]
            self._true((text_len, 1), dev),
            vmask.unsqueeze(0).repeat(text_len, 1),
            # [SEP] for video.
            self._true((text_len, 1), dev),
            # causal (lower-triangular) part over real text tokens.
            torch.tril(self._true((text_len, text_len), dev)),
            # padding.
            self._false((text_len, n_text_slots - text_len), dev),
        ], dim=1)

        cls_video_row = cls_video_mask[None, :]
        cls_video_masks = cls_video_row.repeat(vmask.size(0) + 2, 1)
        text_padding_masks = cls_video_row.repeat(n_text_slots - text_len, 1)

        return torch.cat([
            cls_video_masks, text_masks, text_padding_masks], dim=0)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .task import *
from .vlmtask import *
from .retritask import *
try:
from .fairseqmmtask import *
except ImportError:
pass
try:
from .milncetask import *
except ImportError:
pass
try:
from .expretritask import *
except ImportError:
pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment