# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
make a general fairseq task for MM pretraining.
"""
import random
from fairseq.tasks import LegacyFairseqTask, register_task
from .task import Task
from .retritask import RetriTask
from ..datasets import FairseqMMDataset
from .. import utils
@register_task("mmtask")
class FairseqMMTask(LegacyFairseqTask):
@staticmethod
def add_args(parser):
        # Add the command-line argument pointing to the task config, which
        # carries all settings outside the fairseq parser.
parser.add_argument(
"taskconfig",
metavar="FILE",
            help="taskconfig to load all configurations outside fairseq parser.",
)
@classmethod
def setup_task(cls, args, **kwargs):
return FairseqMMTask(args)
def __init__(self, args):
super().__init__(args)
config = utils.load_config(args)
self.mmtask = Task.config_task(config)
self.mmtask.build_dataset()
self.mmtask.build_model()
self.mmtask.build_loss()
def load_dataset(self, split, **kwargs):
split_map = {
"train": self.mmtask.train_data,
"valid": self.mmtask.val_data,
"test": self.mmtask.test_data,
}
if split not in split_map:
raise ValueError("unknown split type.")
if split_map[split] is not None:
self.datasets[split] = FairseqMMDataset(split_map[split])
def get_batch_iterator(
self,
dataset,
max_tokens=None,
max_sentences=None,
max_positions=None,
ignore_invalid_inputs=False,
required_batch_size_multiple=1,
seed=1,
num_shards=1,
shard_id=0,
num_workers=0,
epoch=1,
data_buffer_size=0,
disable_iterator_cache=False,
skip_remainder_batch=False,
grouped_shuffling=False,
update_epoch_batch_itr=False,
):
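        # For retrieval-augmented pretraining (RetriTask), rebuild the candidate
        # list once `retri_epoch` is reached, then defer to fairseq's standard
        # batch iterator.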
random.seed(epoch)
if dataset.mmdataset.split == "train" and isinstance(self.mmtask, RetriTask):
if epoch >= self.mmtask.config.retri_epoch:
if not hasattr(self.mmtask, "retri_dataloader"):
self.mmtask.build_dataloader()
self.mmtask.retrive_candidates(epoch)
        return super().get_batch_iterator(
            dataset,
            max_tokens,
            max_sentences,
            max_positions,
            ignore_invalid_inputs,
            required_batch_size_multiple,
            seed,
            num_shards,
            shard_id,
            num_workers,
            epoch,
            data_buffer_size,
            disable_iterator_cache,
            skip_remainder_batch,
            grouped_shuffling,
            update_epoch_batch_itr,
        )
@property
def source_dictionary(self):
return None
@property
def target_dictionary(self):
return None
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from .task import Task
class MILNCETask(Task):
def reshape_subsample(self, sample):
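        # Flatten the subsample dimension of every tensor; captions and caption
        # masks get one extra merge of their first two dimensions afterwards.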
if (
hasattr(self.config.dataset, "subsampling")
and self.config.dataset.subsampling is not None
and self.config.dataset.subsampling > 1
):
for key in sample:
if torch.is_tensor(sample[key]):
tensor = self.flat_subsample(sample[key])
if key in ["caps", "cmasks"]:
size = tensor.size()
batch_size = size[0] * size[1]
expanded_size = (batch_size,) + size[2:]
tensor = tensor.view(expanded_size)
sample[key] = tensor
return sample
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import torch
import pickle
import random
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from ..processors import (
ShardedHow2MetaProcessor,
ShardedVideoProcessor,
ShardedTextProcessor,
VariedLenAligner,
)
from ..datasets import MMDataset
from .task import Task
from ..modules import vectorpool
from ..evaluators.predictor import Predictor
from ..utils import set_seed, get_local_rank, get_world_size
class RetriTask(Task):
    """abstract class for tasks with retrieval."""
def reshape_subsample(self, sample):
for key in sample:
if torch.is_tensor(sample[key]):
sample[key] = self.flat_subsample(sample[key])
return sample
def flat_subsample(self, tensor):
if tensor.size(0) == 1:
tensor = tensor.squeeze(0)
return tensor
def build_dataloader(self):
"""called by `get_batch_iterator` in fairseqmmtask. """
# TODO: hard-code dataloader for retri for now and configurable in .yaml.
# reuse the `train.lst`.
self.config.dataset.split = "train"
meta_processor = ShardedHow2MetaProcessor(self.config.dataset)
video_processor = ShardedVideoProcessor(self.config.dataset)
text_processor = ShardedTextProcessor(self.config.dataset)
aligner = VariedLenAligner(self.config.dataset)
aligner.subsampling = self.config.dataset.clip_per_video
self.retri_data = MMDataset(
meta_processor, video_processor, text_processor, aligner
)
retri_sampler = DistributedSampler(self.retri_data)
infer_scale = 16
batch_size = self.config.dataset.num_video_per_batch \
* infer_scale
self.retri_dataloader = DataLoader(
self.retri_data,
collate_fn=self.retri_data.collater,
batch_size=batch_size,
shuffle=False,
sampler=retri_sampler,
num_workers=self.config.fairseq.dataset.num_workers
)
return self.retri_dataloader
def retrive_candidates(self, epoch, dataloader=None):
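        # Each rank runs the retrieval predictor and writes its candidate batches
        # to <save_dir>/retri/batched_e{epoch}_videos{rank}.pkl; all ranks then
        # merge every pickle identically and feed the candidates back to the
        # training meta processor via `set_candidates`.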
if get_local_rank() == 0:
print("running retrieval model.")
out_dir = os.path.join(
self.config.fairseq.checkpoint.save_dir, "retri")
os.makedirs(out_dir, exist_ok=True)
if not os.path.isfile(
os.path.join(
out_dir, "batched_e" + str(epoch) + "_videos0.pkl")
):
if dataloader is None:
dataloader = self.retri_dataloader
self.model.eval()
self.model.is_train = False
assert self.retri_data.meta_processor.data == \
self.train_data.meta_processor.data # video_ids not mutated.
self._retri_predict(epoch, dataloader)
self.model.train()
self.model.is_train = True
torch.distributed.barrier()
output = self._retri_sync(epoch, out_dir)
torch.distributed.barrier()
self.train_data.meta_processor.set_candidates(output)
return output
class VideoRetriTask(RetriTask):
"""RetriTask on video level."""
def reshape_subsample(self, sample):
if (
hasattr(self.config.dataset, "clip_per_video")
and self.config.dataset.clip_per_video is not None
and self.config.dataset.clip_per_video > 1
):
for key in sample:
if torch.is_tensor(sample[key]):
sample[key] = self.flat_subsample(sample[key])
return sample
def flat_subsample(self, tensor):
if tensor.size(0) == 1:
tensor = tensor.squeeze(0)
return Task.flat_subsample(self, tensor)
def _retri_predict(self, epoch, dataloader):
set_seed(epoch)
        # save for retrieval.
predictor = VideoPredictor(self.config)
predictor.predict_loop(
self.model, dataloader)
set_seed(epoch) # get the same text clips.
        # retrieval.
retri_predictor = VideoRetriPredictor(
self.config)
retri_predictor.predict_loop(
self.model, predictor.vecpool.retriver, epoch)
del predictor
del retri_predictor
def _retri_sync(self, epoch, out_dir):
        # every GPU does the same merge.
batched_videos = []
for local_rank in range(get_world_size()):
fn = os.path.join(
out_dir,
"batched_e" + str(epoch) + "_videos" + str(local_rank) + ".pkl")
with open(fn, "rb") as fr:
batched_videos.extend(pickle.load(fr))
print(
"[INFO] batched_videos",
len(batched_videos), len(batched_videos[0]))
return batched_videos
class VideoPredictor(Predictor):
def __init__(self, config):
vectorpool_cls = getattr(vectorpool, config.vectorpool_cls)
self.vecpool = vectorpool_cls(config)
def predict_loop(
self,
model,
dataloader,
early_stop=-1,
):
with torch.no_grad():
if get_local_rank() == 0:
dataloader = tqdm(dataloader)
for batch_idx, batch in enumerate(dataloader):
if batch_idx == early_stop:
break
self(batch, model)
return self.finalize()
def __call__(self, sample, model, **kwargs):
param = next(model.parameters())
dtype = param.dtype
device = param.device
subsample = sample["vfeats"].size(1)
sample = self.to_ctx(sample, device, dtype)
for key in sample:
if torch.is_tensor(sample[key]):
size = sample[key].size()
if len(size) >= 2:
batch_size = size[0] * size[1]
expanded_size = (
(batch_size,) + size[2:] if len(size) > 2
else (batch_size,)
)
sample[key] = sample[key].view(expanded_size)
outputs = model(**sample)
sample.update(outputs)
self.vecpool(sample, subsample)
def finalize(self):
print("[INFO]", self.vecpool)
if not self.vecpool.retriver.db.is_trained:
self.vecpool.retriver.finalize_training()
return self.vecpool.retriver
class VideoRetriPredictor(Predictor):
"""
Online Retrieval Predictor for Clips (used by RetriTask).
TODO: merge this with VisPredictor?
"""
def __init__(self, config):
self.pred_dir = os.path.join(
config.fairseq.checkpoint.save_dir,
"retri")
self.num_cands = config.num_cands
self.num_video_per_batch = config.dataset.num_video_per_batch
def predict_loop(
self,
model,
retriver,
epoch,
early_stop=-1
):
        # a fake loop that only tries to recover video vectors
        # from video_ids.
batched_videos = []
# obtain available video_ids.
video_ids = list(retriver.videoid_to_vectoridx.keys())
dataloader = random.sample(
video_ids,
len(video_ids) // self.num_video_per_batch
)
if get_local_rank() == 0:
dataloader = tqdm(dataloader)
for batch_idx, batch in enumerate(dataloader):
# batch is one video id.
if batch_idx == early_stop:
break
video_ids = retriver.search_by_video_ids(
[batch], self.num_cands)[0]
if len(video_ids) > self.num_video_per_batch:
                # we move the center to make the clusters robust.
video_ids = random.sample(video_ids, self.num_video_per_batch)
batched_videos.append(video_ids)
return self.finalize(batched_videos, epoch)
def finalize(self, batched_videos, epoch):
fn = os.path.join(
self.pred_dir,
"batched_e" + str(epoch) + "_videos" + str(get_local_rank()) + ".pkl")
with open(fn, "wb") as fw:
pickle.dump(batched_videos, fw, pickle.HIGHEST_PROTOCOL)
return batched_videos
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from .. import tasks
from .. import models
from .. import losses
from ..datasets import MMDataset
from .. import processors
class Task(object):
"""
A task refers to one generic training task (e.g., training one model).
"""
@classmethod
    def config_task(cls, config):
        """
        determine whether to load a hard-coded task or configure a generic one,
        depending on whether a task string is available in config.
        """
if config.task is not None:
# TODO (huxu): expand the search scope.
task_cls = getattr(tasks, config.task)
return task_cls(config)
else:
return Task(config)
def __init__(self, config):
self.config = config
self.train_data = None
self.val_data = None
self.test_data = None
self.model = None
self.loss_fn = None
self.eval_fn = None
    def build_dataset(self):
        """fill in `self.train_data`, `self.val_data` and `self.test_data`.
        TODO (huxu): move processor breakdown to MMDataset."""
meta_processor_cls = getattr(
processors, self.config.dataset.meta_processor)
video_processor_cls = getattr(
processors, self.config.dataset.video_processor)
text_processor_cls = getattr(
processors, self.config.dataset.text_processor)
aligner_cls = getattr(
processors, self.config.dataset.aligner)
if self.config.dataset.train_path is not None:
self.config.dataset.split = "train"
# may be used by meta processor.
# meta_processor controls different dataset.
meta_processor = meta_processor_cls(self.config.dataset)
video_processor = video_processor_cls(self.config.dataset)
text_processor = text_processor_cls(self.config.dataset)
aligner = aligner_cls(self.config.dataset)
self.train_data = MMDataset(
meta_processor, video_processor, text_processor, aligner
)
print("train_len", len(self.train_data))
output = self.train_data[0]
self.train_data.print_example(output)
if self.config.dataset.val_path is not None:
self.config.dataset.split = "valid"
# may be used by meta processor.
meta_processor = meta_processor_cls(self.config.dataset)
video_processor = video_processor_cls(self.config.dataset)
text_processor = text_processor_cls(self.config.dataset)
aligner = aligner_cls(self.config.dataset)
self.val_data = MMDataset(
meta_processor, video_processor, text_processor, aligner
)
print("val_len", len(self.val_data))
output = self.val_data[0]
self.val_data.print_example(output)
        if self.config.dataset.split == "test":
            # the following is run via launching fairseq-validate.
            meta_processor = meta_processor_cls(self.config.dataset)
            video_processor = video_processor_cls(self.config.dataset)
            text_processor = text_processor_cls(self.config.dataset)
            aligner = aligner_cls(self.config.dataset)
            self.test_data = MMDataset(
                meta_processor, video_processor, text_processor, aligner
            )
print("test_len", len(self.test_data))
output = self.test_data[0]
self.test_data.print_example(output)
def build_model(self, checkpoint=None):
if self.model is None:
model_cls = getattr(models, self.config.model.model_cls)
self.model = model_cls(self.config)
if checkpoint is not None:
self.load_checkpoint(checkpoint)
return self.model
def load_checkpoint(self, checkpoint):
if self.model is None:
raise ValueError("model is not initialized.")
state_dict = torch.load(checkpoint)
state_dict = self._trim_state_dict(state_dict)
self.model.load_state_dict(state_dict, strict=False)
        # if it's an fp16 model, convert it back to fp32.
if next(self.model.parameters()).dtype == torch.float16:
self.model = self.model.float()
return self.model
def _trim_state_dict(self, state_dict):
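        # Unwrap "state_dict" (generic) or "model" (fairseq) checkpoint containers
        # and strip the "mmmodel." prefix added by the fairseq wrapper model.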
from collections import OrderedDict
if "state_dict" in state_dict:
state_dict = state_dict["state_dict"]
if "model" in state_dict: # fairseq checkpoint format.
state_dict = state_dict["model"]
ret_state_dict = OrderedDict()
for (
key,
value,
) in state_dict.items():
# remove fairseq wrapper since this is a task.
if key.startswith("mmmodel"):
key = key[len("mmmodel."):]
ret_state_dict[key] = value
return ret_state_dict
def build_loss(self):
if self.loss_fn is None and self.config.loss is not None:
loss_cls = getattr(losses, self.config.loss.loss_cls)
self.loss_fn = loss_cls()
return self.loss_fn
def flat_subsample(self, tensor):
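        # Merge the batch and subsample dimensions:
        # (batch, subsample, ...) -> (batch * subsample, ...).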
size = tensor.size()
if len(size) >= 2:
batch_size = size[0] * size[1]
expanded_size = (
(batch_size,) + size[2:] if len(size) > 2
else (batch_size,)
)
tensor = tensor.view(expanded_size)
return tensor
def reshape_subsample(self, sample):
if (
hasattr(self.config.dataset, "subsampling")
and self.config.dataset.subsampling is not None
and self.config.dataset.subsampling > 1
):
for key in sample:
if torch.is_tensor(sample[key]):
sample[key] = self.flat_subsample(sample[key])
return sample
def __call__(self, model, sample):
loss = None
loss_scalar = float("inf")
sample = self.reshape_subsample(sample)
outputs = self.model(**sample)
sample.update(outputs)
if self.loss_fn is not None:
loss = self.loss_fn(**sample)
loss_scalar = loss.item()
batch_size = sample["caps"].size(0)
sample_size = 1
return {
"loss": loss,
"loss_scalar": loss_scalar,
"max_len": self.config.dataset.max_len,
"batch_size": batch_size,
"sample_size": sample_size,
}
    def build_dataloader(self):
        """only used by trainers that do not build their own dataloaders."""
raise NotImplementedError
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from .task import Task
class VLMTask(Task):
    """A VLM task for reproducibility.
    The collator splits subsamples into two sub-batches.
    This should introduce no logic changes,
    but it changes the randomness in frame masking.
    """
def flat_subsample(self, tensor):
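        # Reshape (batch, 2k, ...) into (batch * k, 2, ...), then concatenate the
        # two halves along the batch dimension so paired clips end up in separate
        # sub-batches.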
size = tensor.size()
if len(size) >= 2:
batch_size = size[0] * (size[1] // 2)
expanded_size = (
(batch_size, 2) + size[2:] if len(size) > 2
else (batch_size, 2)
)
tensor = tensor.view(expanded_size)
tensor = torch.cat([tensor[:, 0], tensor[:, 1]], dim=0)
return tensor
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import random
import numpy as np
import torch
from .shardedtensor import *
from .load_config import *
def set_seed(seed=43211):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
if torch.backends.cudnn.enabled:
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
def get_world_size():
if torch.distributed.is_initialized():
world_size = torch.distributed.get_world_size()
else:
world_size = 1
return world_size
def get_local_rank():
return torch.distributed.get_rank() \
if torch.distributed.is_initialized() else 0
def print_on_rank0(func):
local_rank = get_local_rank()
if local_rank == 0:
print("[INFO]", func)
class RetriMeter(object):
"""
Statistics on whether retrieval yields a better pair.
"""
def __init__(self, freq=1024):
self.freq = freq
self.total = 0
self.replace = 0
self.updates = 0
def __call__(self, data):
if isinstance(data, np.ndarray):
self.replace += data.shape[0] - int((data[:, 0] == -1).sum())
self.total += data.shape[0]
elif torch.is_tensor(data):
self.replace += int(data.sum())
self.total += data.size(0)
else:
raise ValueError("unsupported RetriMeter data type.", type(data))
self.updates += 1
if get_local_rank() == 0 and self.updates % self.freq == 0:
print("[INFO]", self)
def __repr__(self):
return "RetriMeter (" + str(self.replace / self.total) \
+ "/" + str(self.replace) + "/" + str(self.total) + ")"
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import omegaconf
from omegaconf import OmegaConf
def load_config(args=None, config_file=None, overwrite_fairseq=False):
"""TODO (huxu): move fairseq overwrite to another function."""
if args is not None:
config_file = args.taskconfig
config = recursive_config(config_file)
if config.dataset.subsampling is not None:
batch_size = config.fairseq.dataset.batch_size // config.dataset.subsampling
print(
"adjusting batch_size to {} due to subsampling {}.".format(
batch_size, config.dataset.subsampling
)
)
config.fairseq.dataset.batch_size = batch_size
is_test = config.dataset.split is not None and config.dataset.split == "test"
if not is_test:
if (
config.fairseq.checkpoint is None
or config.fairseq.checkpoint.save_dir is None
):
raise ValueError("fairseq save_dir or save_path must be specified.")
save_dir = config.fairseq.checkpoint.save_dir
os.makedirs(save_dir, exist_ok=True)
if config.fairseq.common.tensorboard_logdir is not None:
tb_run_dir = suffix_rundir(
save_dir, config.fairseq.common.tensorboard_logdir
)
config.fairseq.common.tensorboard_logdir = tb_run_dir
print(
"update tensorboard_logdir as", config.fairseq.common.tensorboard_logdir
)
os.makedirs(save_dir, exist_ok=True)
OmegaConf.save(config=config, f=os.path.join(save_dir, "config.yaml"))
if overwrite_fairseq and config.fairseq is not None and args is not None:
# flatten fields.
for group in config.fairseq:
for field in config.fairseq[group]:
print("overwrite args." + field, "as", config.fairseq[group][field])
setattr(args, field, config.fairseq[group][field])
return config
def recursive_config(config_path):
"""allows for stacking of configs in any depth."""
config = OmegaConf.load(config_path)
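    # A config may point to a base config via an `includes:` key; OmegaConf.merge
    # applies the including (child) config on top, so its values override the base.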
if config.includes is not None:
includes = config.includes
config.pop("includes")
base_config = recursive_config(includes)
config = OmegaConf.merge(base_config, config)
return config
def suffix_rundir(save_dir, run_dir):
max_id = -1
for search_dir in os.listdir(save_dir):
if search_dir.startswith(run_dir):
splits = search_dir.split("_")
cur_id = int(splits[1]) if len(splits) > 1 else 0
max_id = max(max_id, cur_id)
return os.path.join(save_dir, run_dir + "_" + str(max_id + 1))
def overwrite_dir(config, replace, basedir):
for key in config:
if isinstance(config[key], str) and config[key].startswith(basedir):
config[key] = config[key].replace(basedir, replace)
if isinstance(config[key], omegaconf.dictconfig.DictConfig):
overwrite_dir(config[key], replace, basedir)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import pickle
import numpy as np
class ShardedTensor(object):
def __init__(self, data, starts):
self.data = data
self.starts = starts
assert self.starts[0] == 0
assert self.starts[-1] == len(self.data)
assert (self.starts[1:] >= self.starts[:-1]).all()
assert (self.starts > -1).all()
@staticmethod
def from_list(xs):
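        # Concatenate variable-length arrays into one flat buffer and record
        # prefix-sum offsets in `starts` so each item can be sliced back out.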
        starts = np.full((len(xs) + 1,), -1, dtype=np.int64)
data = np.concatenate(xs, axis=0)
starts[0] = 0
for i, x in enumerate(xs):
starts[i + 1] = starts[i] + x.shape[0]
assert (starts > -1).all()
return ShardedTensor(data, starts)
def __getitem__(self, i):
return self.data[self.starts[i] : self.starts[i + 1]]
def __len__(self):
return len(self.starts) - 1
def lengths(self):
return self.starts[1:] - self.starts[:-1]
def save(self, path):
np.save(path + "_starts", self.starts)
np.save(path + "_data", self.data)
@staticmethod
def load(path, mmap_mode=None):
starts = np.load(path + "_starts.npy", mmap_mode)
data = np.load(path + "_data.npy", mmap_mode)
return ShardedTensor(data, starts)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
from mmpt.utils import recursive_config
class BaseJob(object):
def __init__(self, yaml_file, dryrun=False):
self.yaml_file = yaml_file
self.config = recursive_config(yaml_file)
self.dryrun = dryrun
def submit(self, **kwargs):
raise NotImplementedError
def _normalize_cmd(self, cmd_list):
cmd_list = list(cmd_list)
yaml_index = cmd_list.index("[yaml]")
cmd_list[yaml_index] = self.yaml_file
return cmd_list
class LocalJob(BaseJob):
CMD_CONFIG = {
"local_single": [
"fairseq-train", "[yaml]", "--user-dir", "mmpt",
"--task", "mmtask", "--arch", "mmarch",
"--criterion", "mmloss",
],
"local_small": [
"fairseq-train", "[yaml]", "--user-dir", "mmpt",
"--task", "mmtask", "--arch", "mmarch",
"--criterion", "mmloss",
"--distributed-world-size", "2"
],
"local_big": [
"fairseq-train", "[yaml]", "--user-dir", "mmpt",
"--task", "mmtask", "--arch", "mmarch",
"--criterion", "mmloss",
"--distributed-world-size", "8"
],
"local_predict": ["python", "mmpt_cli/predict.py", "[yaml]"],
}
def __init__(self, yaml_file, job_type=None, dryrun=False):
super().__init__(yaml_file, dryrun)
if job_type is None:
self.job_type = "local_single"
if self.config.task_type is not None:
self.job_type = self.config.task_type
else:
self.job_type = job_type
if self.job_type in ["local_single", "local_small"]:
if self.config.fairseq.dataset.batch_size > 32:
                print("consider decreasing batch_size to 32 for local testing.")
def submit(self):
cmd_list = self._normalize_cmd(LocalJob.CMD_CONFIG[self.job_type])
if "predict" not in self.job_type:
# append fairseq args.
from mmpt.utils import load_config
config = load_config(config_file=self.yaml_file)
for field in config.fairseq:
for key in config.fairseq[field]:
if key in ["fp16", "reset_optimizer", "reset_dataloader", "reset_meters"]: # a list of binary flag.
param = ["--" + key.replace("_", "-")]
else:
if key == "lr":
value = str(config.fairseq[field][key][0])
elif key == "adam_betas":
value = "'"+str(config.fairseq[field][key])+"'"
else:
value = str(config.fairseq[field][key])
param = [
"--" + key.replace("_", "-"),
value
]
cmd_list.extend(param)
print("launching", " ".join(cmd_list))
if not self.dryrun:
os.system(" ".join(cmd_list))
return JobStatus("12345678")
class JobStatus(object):
def __init__(self, job_id):
self.job_id = job_id
def __repr__(self):
return self.job_id
def __str__(self):
return self.job_id
def done(self):
return False
def running(self):
return False
def result(self):
if self.done():
return "{} is done.".format(self.job_id)
else:
return "{} is running.".format(self.job_id)
def stderr(self):
return self.result()
def stdout(self):
return self.result()
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import glob
import argparse
import pprint
import omegaconf
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from mmpt.utils import load_config, set_seed
from mmpt.evaluators import Evaluator
from mmpt.evaluators import predictor as predictor_path
from mmpt.tasks import Task
from mmpt import processors
from mmpt.datasets import MMDataset
def get_dataloader(config):
meta_processor_cls = getattr(processors, config.dataset.meta_processor)
video_processor_cls = getattr(processors, config.dataset.video_processor)
text_processor_cls = getattr(processors, config.dataset.text_processor)
aligner_cls = getattr(processors, config.dataset.aligner)
meta_processor = meta_processor_cls(config.dataset)
video_processor = video_processor_cls(config.dataset)
text_processor = text_processor_cls(config.dataset)
aligner = aligner_cls(config.dataset)
test_data = MMDataset(
meta_processor,
video_processor,
text_processor,
aligner,
)
print("test_len", len(test_data))
output = test_data[0]
test_data.print_example(output)
test_dataloader = DataLoader(
test_data,
batch_size=config.fairseq.dataset.batch_size,
shuffle=False,
num_workers=6,
collate_fn=test_data.collater,
)
return test_dataloader
def main(args):
config = load_config(args)
if isinstance(config, omegaconf.dictconfig.DictConfig):
print(OmegaConf.to_yaml(config))
else:
pp = pprint.PrettyPrinter(indent=4)
        pp.pprint(config)
mmtask = Task.config_task(config)
mmtask.build_model()
test_dataloader = get_dataloader(config)
checkpoint_search_path = os.path.dirname(config.eval.save_path)
results = []
prefix = os.path.basename(args.taskconfig)
if prefix.startswith("test"):
        # loop over all checkpoints for datasets without a validation set.
if "best" not in config.fairseq.common_eval.path:
print("eval each epoch.")
for checkpoint in glob.glob(checkpoint_search_path + "/checkpoint*"):
model = mmtask.load_checkpoint(checkpoint)
ckpt = os.path.basename(checkpoint)
evaluator = Evaluator(config)
output = evaluator.evaluate(
model, test_dataloader, ckpt + "_merged")
results.append((checkpoint, output))
        # lastly, evaluate the checkpoint specified by the config.
model = mmtask.load_checkpoint(config.fairseq.common_eval.path)
evaluator = Evaluator(config)
output = evaluator.evaluate(model, test_dataloader)
results.append((config.fairseq.common_eval.path, output))
best_result = None
best_metric = 0.
for checkpoint, result in results:
print(checkpoint)
evaluator.metric.print_computed_metrics(result)
best_score = evaluator.metric.best_metric(result)
if best_score > best_metric:
best_result = (checkpoint, result)
best_metric = best_score
print("best results:")
print(best_result[0])
evaluator.metric.print_computed_metrics(best_result[1])
elif prefix.startswith("vis"):
model = mmtask.load_checkpoint(config.fairseq.common_eval.path)
predictor_cls = getattr(predictor_path, config.predictor)
predictor = predictor_cls(config)
predictor.predict_loop(model, test_dataloader, mmtask, None)
else:
raise ValueError("unknown prefix of the config file", args.taskconfig)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("taskconfig", type=str)
args = parser.parse_args()
main(args)
# Pretraining
(If you are new to the ideas of `mmpt.processors`, see [README](README.md) first.)
We mostly use the [howto100M](https://github.com/antoine77340/howto100m) dataset for pretraining (other datasets are coming), so you are unlikely to need to write a new `MetaProcessor`, `VideoProcessor`, or `TextProcessor`; you will more likely only work on a new `Aligner`, a new model, and a loss.
### Data Sharding
Pretraining on HowTo100M is heavy on IO since we have millions of videos and captions on disk that cannot fit into memory.
It is desirable to have an optimized preprocessing step before the actual dataloading.
We support data sharding to pack multiple videos into shards of training data, for both videos and captions (see [dataset](DATASET.md) for preprocessing).
These shards are memory-mapped to reduce the frequency of IO access to millions of files (see the processors starting with `Sharded*`).
The default config for the How2 dataset is `projects/task/how2.yaml`.
Many thanks to Dmytro Okhonko for sharing the code from the MARGE project.
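For a rough sense of the on-disk format, here is a minimal sketch using the `ShardedTensor` helper included in this commit (`mmpt/utils/shardedtensor.py`); the feature shapes and the output path are illustrative only:

```python
import numpy as np
from mmpt.utils import ShardedTensor

# pack three videos with different frame counts into one shard
feats = [np.random.rand(n, 512).astype(np.float32) for n in (30, 12, 45)]
shard = ShardedTensor.from_list(feats)
shard.save("data/feat/demo_shard")  # writes demo_shard_data.npy and demo_shard_starts.npy

# later, memory-map the shard and slice a single video back out
shard = ShardedTensor.load("data/feat/demo_shard", mmap_mode="r")
video_1 = shard[1]  # shape (12, 512), read lazily from the memory-mapped file
```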
### Training
Pretraining on HowTo100M is expected to run on one or more nodes, each with 8 GPUs (32 GB memory each).
Launching MFM+MLM pretraining can be done via:
```python locallaunch.py projects/mfmmlm/how2.yaml```
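`locallaunch.py` expands this into a plain `fairseq-train` command via the `LocalJob` class included in this commit, flattening the `fairseq.*` options from the YAML into command-line flags. A dry-run sketch (assuming `LocalJob` is importable from `mmpt_cli.localjob`) only prints the expanded command:

```python
# Dry-run sketch: prints the expanded fairseq-train command without launching it.
# The import path mmpt_cli.localjob is an assumption about this commit's layout.
from mmpt_cli.localjob import LocalJob

job = LocalJob("projects/mfmmlm/how2.yaml", job_type="local_single", dryrun=True)
job.submit()
# -> launching fairseq-train projects/mfmmlm/how2.yaml --user-dir mmpt --task mmtask \
#    --arch mmarch --criterion mmloss ... (plus the flattened fairseq options)
```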
### Pre-training with a Retrieval Model (VideoCLIP)
This project now supports alternating between running a retrieval model and pre-training.
We implement a basic retrieval model built on the hidden states of a video and faiss.
You may need to install faiss via `conda install faiss-cpu -c pytorch`.
Right now, the hidden states of a video are computed as the average of the pooled visual/text hidden states over 8 clips.
See `mmpt/tasks/retritask.py` for more details.
The `.yaml` config for running pre-training with a retrieval model can be found at `projects/retri/videoretri.yaml`.
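The retrieval-specific fields that `RetriTask` and its predictors read from the config are set in that file; the sketch below lists them, with field names taken from `retritask.py` and `fairseqmmtask.py` (the comments are descriptive assumptions, not an exhaustive reference):

```python
# A sketch of the retrieval-related config fields read by the code above;
# the field names come from the code, the interpretation comments are assumptions.
from mmpt.utils import recursive_config

config = recursive_config("projects/retri/videoretri.yaml")
config.retri_epoch                  # epoch at which retrieval kicks in (get_batch_iterator)
config.vectorpool_cls               # vector pool / faiss retriever class used by VideoPredictor
config.num_cands                    # number of candidate videos returned per query
config.dataset.clip_per_video       # clips sampled per video (8 in the description above)
config.dataset.num_video_per_batch  # videos per retrieval batch
```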
project_dir: mfmmlm
run_task:
- how2.yaml
- [vtt.yaml, vttcap.yaml, vttqa.yaml, youcook.yaml, youcookcap.yaml, crosstask.yaml, coin.yaml]
base_dir: task
task_group:
pretrain:
task_list:
- how2.yaml
dataset:
subsampling: 32
sampled_min_len: 10
sampled_max_len: 64
max_video_len: 32
max_len: 96
aligner: MFMMLMAligner
lazy_vfeat_mask: True
mfm_probability: 0.15
mlm_probability: 0.15
mm_prob: 0.5
model:
model_cls: MMFusionMFMMLM
mm_encoder_cls: MMFusionForMFMMLM
loss:
loss_cls: MFMMLM
fairseq:
common:
fp16: true
dataset:
batch_size: 256
optimization:
max_epoch: 15
finetune:
task_list:
- vtt.yaml
- vttqa.yaml
- youcook.yaml
- youcookcap.yaml
- crosstask.yaml
- coin.yaml
dataset:
max_video_len: 32
max_len: 96
fairseq:
common:
fp16: true
# do not write any model or loss here (they are expected to be fixed in mmfusion).
test:
task_list:
- test_vtt.yaml
- test_vttqa.yaml
- test_youcook.yaml
- test_youcookcap.yaml
- test_crosstask.yaml
- test_crosstask_zs.yaml
- test_coin.yaml
dataset:
max_video_len: 32
max_len: 96
includes: projects/mfmmlm.yaml
project_dir: mtm/mmfusionmtm
task_group:
pretrain:
task: VLMTask # reproducible
dataset:
aligner: MFMMLMAligner
model:
use_seg_emb: True # reproducible
model_cls: MMFusionMTM
mm_encoder_cls: MMBertForMFMMLM
loss:
loss_cls: MTM
finetune:
model:
use_seg_emb: True # reproducible
test:
model:
use_seg_emb: True # reproducible
includes: projects/mtm/mmfusionmtm.yaml
project_dir: mtm/vlm
task_group:
pretrain:
dataset:
sampled_min_len: 8
loss:
loss_cls: MTM
dataset:
video_processor: VideoProcessor
bert_name: bert-base-uncased
meta_processor: COINActionSegmentationMetaProcessor
train_path: data/coin/COIN.json
val_path: data/coin/COIN.json
vfeat_dir: data/feat/feat_coin_s3d
text_processor: COINActionSegmentationTextProcessor
aligner: COINActionSegmentationAligner
num_iso_layer: 12
sliding_window: 8
sliding_window_size: 32
max_video_len: 32
max_len: 96
fairseq:
common:
tensorboard_logdir: run
log_interval: 1000
fp16: true
dataset:
num_workers: 4
batch_size: 1
optimization:
lr:
- 5.0e-05
clip_norm: 2.0
optimizer: adam
adam_betas: (0.9, 0.98)
lr_scheduler: polynomial_decay
total_num_update: 1000000
warmup_updates: 122
weight_decay: 0.0
ddp_backend: no_c10d
max_epoch: 8
checkpoint:
restore_file: runs/mtm/vlm/checkpoint_best.pt
reset_optimizer: true
reset_dataloader: true
reset_meters: true
save_dir: runs/mtm/vlm/coin
task_type: sweep_big
model:
model_cls: MMFusionActionSegmentation
mm_encoder_cls: MMBertForTokenClassification
use_seg_emb: true
loss:
loss_cls: CrossEntropy
dataset:
video_processor: CrossTaskVideoProcessor
bert_name: bert-base-uncased
meta_processor: CrossTaskMetaProcessor
train_path: data/crosstask/crosstask_release/videos.csv
train_csv_path: data/crosstask/crosstask_release/videos.csv
val_path: data/crosstask/crosstask_release/videos_val.csv
val_csv_path: data/crosstask/crosstask_release/videos_val.csv
primary_path: data/crosstask/crosstask_release/tasks_primary.txt
related_path: data/crosstask/crosstask_release/tasks_related.txt
vfeat_dir: data/feat/feat_crosstask_s3d
annotation_path: data/crosstask/crosstask_release/annotations
n_train: 30
text_processor: CrossTaskTextProcessor
aligner: CrossTaskAligner
num_iso_layer: 12
sliding_window: 16
sliding_window_size: 32
max_video_len: 32
max_len: 96
fairseq:
common:
tensorboard_logdir: run
log_interval: 1000
fp16: true
dataset:
num_workers: 4
batch_size: 1
optimization:
lr:
- 5.0e-05
clip_norm: 2.0
optimizer: adam
adam_betas: (0.9, 0.98)
lr_scheduler: polynomial_decay
total_num_update: 1000000
warmup_updates: 122
weight_decay: 0.0
ddp_backend: no_c10d
max_epoch: 5
checkpoint:
restore_file: runs/mtm/vlm/checkpoint11.pt
reset_optimizer: true
reset_dataloader: true
reset_meters: true
save_dir: runs/mtm/vlm/crosstask
task_type: sweep_small
model:
model_cls: MMFusionActionLocalization
mm_encoder_cls: MMBertForJoint
use_seg_emb: true
loss:
loss_cls: BCE
dataset:
video_processor: ShardedVideoProcessor
bert_name: bert-base-uncased
meta_processor: ShardedHow2MetaProcessor
train_path: data/how2/how2_s3d_train.lst
val_path: data/how2/how2_s3d_val.lst
vfeat_dir: data/feat/feat_how2_s3d_shard_small
text_processor: ShardedTextProcessor
tfeat_dir: data/feat/feat_how2_s3d_shard_small/raw_caption_dedup.bert-base-uncased.
aligner: MFMMLMAligner
subsampling: 32
sampled_min_len: 8
sampled_max_len: 64
max_video_len: 32
max_len: 96
lazy_vfeat_mask: true
mfm_probability: 0.15
mlm_probability: 0.15
mm_prob: 0.5
fairseq:
common:
tensorboard_logdir: run
log_interval: 1000
fp16: true
dataset:
num_workers: 4
batch_size: 256
optimization:
lr:
- 5.0e-05
clip_norm: 2.0
optimizer: adam
adam_betas: (0.9, 0.98)
lr_scheduler: polynomial_decay
total_num_update: 1000000
warmup_updates: 1000
weight_decay: 0.0
ddp_backend: no_c10d
max_epoch: 15
checkpoint:
save_dir: runs/mtm/vlm
save_interval_updates: 1024
keep_interval_updates: 2
keep_last_epochs: 30
task_type: sweep_big
slurm_config: big
eval:
save_path: runs/mtm/vlm
model:
model_cls: MMFusionMTM
mm_encoder_cls: MMBertForMFMMLM
use_seg_emb: true
loss:
loss_cls: MTM
task: VLMTask
slurm_config: big
task_type: local_predict
dataset:
split: test
video_processor: VideoProcessor
aligner: COINActionSegmentationAligner
bert_name: bert-base-uncased
test_path: data/coin/COIN.json
meta_processor: COINActionSegmentationMetaProcessor
vfeat_dir: data/feat/feat_coin_s3d
text_processor: COINActionSegmentationTextProcessor
num_iso_layer: 12
sliding_window: 16
sliding_window_size: 32
max_video_len: 32
max_len: 96
fairseq:
dataset:
batch_size: 1
valid_subset: test
num_workers: 2
common_eval:
path: runs/mtm/vlm/coin/checkpoint_best.pt
model:
model_cls: MMFusionActionSegmentation
mm_encoder_cls: MMBertForTokenClassification
use_seg_emb: true
eval:
save_path: runs/mtm/vlm/coin/eval
metric: COINActionSegmentationMetric
predictor: COINPredictor
slurm_config: big
task_type: local_predict
dataset:
split: test
video_processor: CrossTaskVideoProcessor
aligner: CrossTaskAligner
bert_name: bert-base-uncased
meta_processor: CrossTaskMetaProcessor
test_path: data/crosstask/crosstask_release/videos_val.csv
train_csv_path: data/crosstask/crosstask_release/videos.csv
val_path: data/crosstask/crosstask_release/videos_val.csv
val_csv_path: data/crosstask/crosstask_release/videos_val.csv
primary_path: data/crosstask/crosstask_release/tasks_primary.txt
related_path: data/crosstask/crosstask_release/tasks_related.txt
vfeat_dir: data/feat/feat_crosstask_s3d
annotation_path: data/crosstask/crosstask_release/annotations
n_train: 30
text_processor: CrossTaskTextProcessor
num_iso_layer: 12
sliding_window: 16
sliding_window_size: 32
max_video_len: 32
max_len: 96
fairseq:
dataset:
batch_size: 1
valid_subset: test
num_workers: 2
common_eval:
path: runs/mtm/vlm/crosstask/checkpoint_best.pt
model:
model_cls: MMFusionActionLocalization
mm_encoder_cls: MMBertForJoint
use_seg_emb: true
eval:
save_path: runs/mtm/vlm/crosstask/eval
metric: CrossTaskMetric
predictor: CrossTaskPredictor
slurm_config: big
task_type: local_predict
dataset:
split: test
video_processor: CrossTaskVideoProcessor
aligner: CrossTaskAligner
bert_name: bert-base-uncased
meta_processor: CrossTaskMetaProcessor
test_path: data/crosstask/crosstask_release/videos_val.csv
train_csv_path: data/crosstask/crosstask_release/videos.csv
val_path: data/crosstask/crosstask_release/videos_val.csv
val_csv_path: data/crosstask/crosstask_release/videos_val.csv
primary_path: data/crosstask/crosstask_release/tasks_primary.txt
related_path: data/crosstask/crosstask_release/tasks_related.txt
vfeat_dir: data/feat/feat_crosstask_s3d
annotation_path: data/crosstask/crosstask_release/annotations
n_train: 30
text_processor: CrossTaskTextProcessor
num_iso_layer: 12
sliding_window: 16
sliding_window_size: 32
max_video_len: 32
max_len: 96
fairseq:
dataset:
batch_size: 1
valid_subset: test
num_workers: 2
common_eval:
path: runs/mtm/vlm/checkpoint_best.pt
model:
model_cls: MMFusionActionLocalization
mm_encoder_cls: MMBertForJoint
use_seg_emb: true
eval:
save_path: runs/mtm/vlm/crosstask_zs/eval
metric: CrossTaskMetric
predictor: CrossTaskPredictor