Unverified commit 5c0b7f31 authored by Francisco Massa, committed by GitHub

Video reference scripts (#1180)

* Copy classification scripts for video classification

* Initial version of video classification

* add version

* Training of r2plus1d_18 on kinetics works

Gives slightly better results than expected, with 57.336 top-1 clip accuracy. Note that some clips are currently counted twice in this evaluation.

* Cleanups on training script

* Lint

* Minor improvements

* Remove some hacks

* Lint
parent 2287c8f2

sampler.py
import math
import torch
from torch.utils.data import Sampler
import torch.distributed as dist
import torchvision.datasets.video_utils
class DistributedSampler(Sampler):
"""
Extension of DistributedSampler, as discussed in
https://github.com/pytorch/pytorch/issues/23430
"""
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False):
if num_replicas is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
num_replicas = dist.get_world_size()
if rank is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = dist.get_rank()
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle
def __iter__(self):
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch)
if self.shuffle:
indices = torch.randperm(len(self.dataset), generator=g).tolist()
else:
indices = list(range(len(self.dataset)))
# add extra samples to make it evenly divisible
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
if isinstance(self.dataset, Sampler):
orig_indices = list(iter(self.dataset))
indices = [orig_indices[i] for i in indices]
return iter(indices)
def __len__(self):
return self.num_samples
def set_epoch(self, epoch):
self.epoch = epoch
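
A brief illustrative sketch (not part of this commit) of the extension described in the docstring: unlike torch.utils.data.DistributedSampler, this variant also accepts another Sampler, in which case each process receives an evenly strided slice of the indices that the wrapped sampler produces.

# Illustrative only: split the output of a wrapped sampler across two
# hypothetical processes, without initializing torch.distributed.
base = torch.utils.data.SequentialSampler(list(range(10)))
rank0 = DistributedSampler(base, num_replicas=2, rank=0)
rank1 = DistributedSampler(base, num_replicas=2, rank=1)
print(list(rank0))  # [0, 2, 4, 6, 8]
print(list(rank1))  # [1, 3, 5, 7, 9]
# Note: iter(...) is called on the wrapped sampler independently on each rank,
# so a random wrapped sampler (e.g. RandomClipSampler) yields a different
# permutation per rank.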
class UniformClipSampler(torch.utils.data.Sampler):
"""
    Samples at most `max_clips_per_video` clips for each video, equally spaced
Arguments:
video_clips (VideoClips): video clips to sample from
max_clips_per_video (int): maximum number of clips to be sampled per video
"""
def __init__(self, video_clips, max_clips_per_video):
if not isinstance(video_clips, torchvision.datasets.video_utils.VideoClips):
raise TypeError("Expected video_clips to be an instance of VideoClips, "
"got {}".format(type(video_clips)))
self.video_clips = video_clips
self.max_clips_per_video = max_clips_per_video
def __iter__(self):
idxs = []
s = 0
# select at most max_clips_per_video for each video, uniformly spaced
for c in self.video_clips.clips:
length = len(c)
step = max(length // self.max_clips_per_video, 1)
sampled = torch.arange(length)[::step] + s
s += length
idxs.append(sampled)
idxs = torch.cat(idxs).tolist()
return iter(idxs)
def __len__(self):
return sum(min(len(c), self.max_clips_per_video) for c in self.video_clips.clips)
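
A small worked example (not part of this commit) of the spacing rule in __iter__ above: the per-video stride is length // max_clips_per_video (at least 1), so when the clip count is not an exact multiple, slightly more than max_clips_per_video indices can be selected, while __len__ reports min(len(c), max_clips_per_video).

# Illustrative only: one video with 10 candidate clips and a cap of 3 clips.
length, max_clips_per_video, s = 10, 3, 0
step = max(length // max_clips_per_video, 1)   # 10 // 3 = 3
sampled = torch.arange(length)[::step] + s
print(sampled.tolist())  # [0, 3, 6, 9] -- four clips despite the cap of 3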
scheduler.py
import torch
from bisect import bisect_right
class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
def __init__(
self,
optimizer,
milestones,
gamma=0.1,
warmup_factor=1.0 / 3,
warmup_iters=5,
warmup_method="linear",
last_epoch=-1,
):
        if list(milestones) != sorted(milestones):
            raise ValueError(
                "Milestones should be a list of increasing integers. Got {}".format(milestones)
            )
        if warmup_method not in ("constant", "linear"):
            raise ValueError(
                "Only 'constant' or 'linear' warmup_method accepted, "
                "got {}".format(warmup_method)
            )
self.milestones = milestones
self.gamma = gamma
self.warmup_factor = warmup_factor
self.warmup_iters = warmup_iters
self.warmup_method = warmup_method
super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)
def get_lr(self):
warmup_factor = 1
if self.last_epoch < self.warmup_iters:
if self.warmup_method == "constant":
warmup_factor = self.warmup_factor
elif self.warmup_method == "linear":
alpha = float(self.last_epoch) / self.warmup_iters
warmup_factor = self.warmup_factor * (1 - alpha) + alpha
return [
base_lr *
warmup_factor *
self.gamma ** bisect_right(self.milestones, self.last_epoch)
for base_lr in self.base_lrs
]
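
A self-contained sketch (not part of this commit; it assumes this file is importable as scheduler.py, which is how the training script below imports it) showing the resulting schedule: the learning rate ramps linearly from roughly warmup_factor * base_lr up to base_lr over warmup_iters steps, then is multiplied by gamma at each milestone.

# Illustrative only: inspect the learning rate schedule with a dummy optimizer.
import torch
from scheduler import WarmupMultiStepLR  # assumed module name, matching train.py

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
scheduler = WarmupMultiStepLR(optimizer, milestones=[100, 200], gamma=0.1,
                              warmup_iters=10, warmup_factor=0.001)
lrs = []
for _ in range(250):
    optimizer.step()
    lrs.append(optimizer.param_groups[0]["lr"])
    scheduler.step()
# lrs[0] is 1e-4; the rate ramps linearly to 0.1 by iteration 10,
# then drops to 0.01 at iteration 100 and to 0.001 at iteration 200.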
train.py
from __future__ import print_function
import datetime
import os
import time
import sys
import torch
import torch.utils.data
from torch.utils.data.dataloader import default_collate
from torch import nn
import torchvision
import torchvision.datasets.video_utils
from torchvision import transforms
import utils
from sampler import DistributedSampler, UniformClipSampler
from scheduler import WarmupMultiStepLR
import transforms as T
try:
from apex import amp
except ImportError:
amp = None
def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, apex=False):
model.train()
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
metric_logger.add_meter('clips/s', utils.SmoothedValue(window_size=10, fmt='{value:.3f}'))
header = 'Epoch: [{}]'.format(epoch)
for video, target in metric_logger.log_every(data_loader, print_freq, header):
start_time = time.time()
video, target = video.to(device), target.to(device)
output = model(video)
loss = criterion(output, target)
optimizer.zero_grad()
if apex:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
optimizer.step()
acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
batch_size = video.shape[0]
metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
metric_logger.meters['clips/s'].update(batch_size / (time.time() - start_time))
lr_scheduler.step()
def evaluate(model, criterion, data_loader, device):
model.eval()
metric_logger = utils.MetricLogger(delimiter=" ")
header = 'Test:'
with torch.no_grad():
for video, target in metric_logger.log_every(data_loader, 100, header):
video = video.to(device, non_blocking=True)
target = target.to(device, non_blocking=True)
output = model(video)
loss = criterion(output, target)
acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
# FIXME need to take into account that the datasets
# could have been padded in distributed setup
batch_size = video.shape[0]
metric_logger.update(loss=loss.item())
metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
# gather the stats from all processes
metric_logger.synchronize_between_processes()
print(' * Clip Acc@1 {top1.global_avg:.3f} Clip Acc@5 {top5.global_avg:.3f}'
.format(top1=metric_logger.acc1, top5=metric_logger.acc5))
return metric_logger.acc1.global_avg
def _get_cache_path(filepath):
import hashlib
h = hashlib.sha1(filepath.encode()).hexdigest()
cache_path = os.path.join("~", ".torch", "vision", "datasets", "kinetics", h[:10] + ".pt")
cache_path = os.path.expanduser(cache_path)
return cache_path
def collate_fn(batch):
# remove audio from the batch
batch = [(d[0], d[2]) for d in batch]
return default_collate(batch)
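
For context, a quick illustration (not part of this commit) of what the collate function drops: each dataset item is a (video, audio, label) tuple (see the KineticsVideo change at the bottom of this commit), and the audio element is discarded before default_collate stacks the batch.

# Illustrative only: fake (video, audio, label) items; audio is discarded.
fake_batch = [(torch.zeros(3, 16, 112, 112), torch.zeros(1, 0), 7),
              (torch.zeros(3, 16, 112, 112), torch.zeros(1, 0), 2)]
videos, labels = collate_fn(fake_batch)
print(videos.shape)  # torch.Size([2, 3, 16, 112, 112])
print(labels)        # tensor([7, 2])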
def main(args):
if args.apex:
if sys.version_info < (3, 0):
raise RuntimeError("Apex currently only supports Python 3. Aborting.")
if amp is None:
raise RuntimeError("Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
"to enable mixed-precision training.")
if args.output_dir:
utils.mkdir(args.output_dir)
utils.init_distributed_mode(args)
print(args)
print("torch version: ", torch.__version__)
print("torchvision version: ", torchvision.__version__)
device = torch.device(args.device)
torch.backends.cudnn.benchmark = True
# Data loading code
print("Loading data")
traindir = os.path.join(args.data_path, 'train_avi-480p')
valdir = os.path.join(args.data_path, 'val_avi-480p')
normalize = T.Normalize(mean=[0.43216, 0.394666, 0.37645],
std=[0.22803, 0.22145, 0.216989])
print("Loading training data")
st = time.time()
cache_path = _get_cache_path(traindir)
transform_train = torchvision.transforms.Compose([
T.ToFloatTensorInZeroOne(),
T.Resize((128, 171)),
T.RandomHorizontalFlip(),
normalize,
T.RandomCrop((112, 112))
])
if args.cache_dataset and os.path.exists(cache_path):
print("Loading dataset_train from {}".format(cache_path))
dataset, _ = torch.load(cache_path)
dataset.transform = transform_train
else:
if args.distributed:
print("It is recommended to pre-compute the dataset cache "
"on a single-gpu first, as it will be faster")
dataset = torchvision.datasets.KineticsVideo(
traindir,
frames_per_clip=args.clip_len,
step_between_clips=1,
transform=transform_train
)
if args.cache_dataset:
print("Saving dataset_train to {}".format(cache_path))
utils.mkdir(os.path.dirname(cache_path))
utils.save_on_master((dataset, traindir), cache_path)
dataset.video_clips.compute_clips(args.clip_len, 1, frame_rate=15)
print("Took", time.time() - st)
print("Loading validation data")
cache_path = _get_cache_path(valdir)
transform_test = torchvision.transforms.Compose([
T.ToFloatTensorInZeroOne(),
T.Resize((128, 171)),
normalize,
T.CenterCrop((112, 112))
])
if args.cache_dataset and os.path.exists(cache_path):
print("Loading dataset_test from {}".format(cache_path))
dataset_test, _ = torch.load(cache_path)
dataset_test.transform = transform_test
else:
if args.distributed:
print("It is recommended to pre-compute the dataset cache "
"on a single-gpu first, as it will be faster")
dataset_test = torchvision.datasets.KineticsVideo(
valdir,
frames_per_clip=args.clip_len,
step_between_clips=1,
transform=transform_test
)
if args.cache_dataset:
print("Saving dataset_test to {}".format(cache_path))
utils.mkdir(os.path.dirname(cache_path))
utils.save_on_master((dataset_test, valdir), cache_path)
dataset_test.video_clips.compute_clips(args.clip_len, 1, frame_rate=15)
print("Creating data loaders")
train_sampler = torchvision.datasets.video_utils.RandomClipSampler(dataset.video_clips, args.clips_per_video)
test_sampler = UniformClipSampler(dataset_test.video_clips, args.clips_per_video)
if args.distributed:
train_sampler = DistributedSampler(train_sampler)
test_sampler = DistributedSampler(test_sampler)
data_loader = torch.utils.data.DataLoader(
dataset, batch_size=args.batch_size,
sampler=train_sampler, num_workers=args.workers,
pin_memory=True, collate_fn=collate_fn)
data_loader_test = torch.utils.data.DataLoader(
dataset_test, batch_size=args.batch_size,
sampler=test_sampler, num_workers=args.workers,
pin_memory=True, collate_fn=collate_fn)
print("Creating model")
# model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained)
model = torchvision.models.video.__dict__[args.model]()
model.to(device)
if args.distributed and args.sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
criterion = nn.CrossEntropyLoss()
lr = args.lr * args.world_size
optimizer = torch.optim.SGD(
model.parameters(), lr=lr, momentum=args.momentum, weight_decay=args.weight_decay)
if args.apex:
model, optimizer = amp.initialize(model, optimizer,
opt_level=args.apex_opt_level
)
# convert scheduler to be per iteration, not per epoch, for warmup that lasts
# between different epochs
warmup_iters = args.lr_warmup_epochs * len(data_loader)
lr_milestones = [len(data_loader) * m for m in args.lr_milestones]
lr_scheduler = WarmupMultiStepLR(
optimizer, milestones=lr_milestones, gamma=args.lr_gamma,
warmup_iters=warmup_iters, warmup_factor=1e-5)
model_without_ddp = model
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module
if args.resume:
checkpoint = torch.load(args.resume, map_location='cpu')
model_without_ddp.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
args.start_epoch = checkpoint['epoch'] + 1
if args.test_only:
evaluate(model, criterion, data_loader_test, device=device)
return
print("Start training")
start_time = time.time()
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
train_sampler.set_epoch(epoch)
train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader,
device, epoch, args.print_freq, args.apex)
evaluate(model, criterion, data_loader_test, device=device)
if args.output_dir:
checkpoint = {
'model': model_without_ddp.state_dict(),
'optimizer': optimizer.state_dict(),
'lr_scheduler': lr_scheduler.state_dict(),
'epoch': epoch,
'args': args}
utils.save_on_master(
checkpoint,
os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
utils.save_on_master(
checkpoint,
os.path.join(args.output_dir, 'checkpoint.pth'))
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
def parse_args():
import argparse
parser = argparse.ArgumentParser(description='PyTorch Classification Training')
parser.add_argument('--data-path', default='/datasets01_101/kinetics/070618/', help='dataset')
parser.add_argument('--model', default='r2plus1d_18', help='model')
parser.add_argument('--device', default='cuda', help='device')
parser.add_argument('--clip-len', default=16, type=int, metavar='N',
help='number of frames per clip')
parser.add_argument('--clips-per-video', default=5, type=int, metavar='N',
help='maximum number of clips per video to consider')
parser.add_argument('-b', '--batch-size', default=24, type=int)
parser.add_argument('--epochs', default=45, type=int, metavar='N',
help='number of total epochs to run')
parser.add_argument('-j', '--workers', default=10, type=int, metavar='N',
                        help='number of data loading workers (default: 10)')
parser.add_argument('--lr', default=0.01, type=float, help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
help='momentum')
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
metavar='W', help='weight decay (default: 1e-4)',
dest='weight_decay')
parser.add_argument('--lr-milestones', nargs='+', default=[20, 30, 40], type=int, help='decrease lr on milestones')
parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
parser.add_argument('--lr-warmup-epochs', default=10, type=int, help='number of warmup epochs')
parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
parser.add_argument('--output-dir', default='.', help='path where to save')
parser.add_argument('--resume', default='', help='resume from checkpoint')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
help='start epoch')
parser.add_argument(
"--cache-dataset",
dest="cache_dataset",
help="Cache the datasets for quicker initialization. It also serializes the transforms",
action="store_true",
)
parser.add_argument(
"--sync-bn",
dest="sync_bn",
help="Use sync batch norm",
action="store_true",
)
parser.add_argument(
"--test-only",
dest="test_only",
help="Only test the model",
action="store_true",
)
parser.add_argument(
"--pretrained",
dest="pretrained",
help="Use pre-trained models from the modelzoo",
action="store_true",
)
# Mixed precision training parameters
parser.add_argument('--apex', action='store_true',
help='Use apex for mixed precision training')
    parser.add_argument('--apex-opt-level', default='O1', type=str,
                        help='For apex mixed precision training. '
                             'O0 for FP32 training, O1 for mixed precision training. '
                             'For further details, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet'
                        )
# distributed training parameters
parser.add_argument('--world-size', default=1, type=int,
help='number of distributed processes')
parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
args = parser.parse_args()
return args
if __name__ == "__main__":
args = parse_args()
main(args)
transforms.py
import torch
import random
def crop(vid, i, j, h, w):
return vid[..., i:(i + h), j:(j + w)]
def center_crop(vid, output_size):
h, w = vid.shape[-2:]
th, tw = output_size
i = int(round((h - th) / 2.))
j = int(round((w - tw) / 2.))
return crop(vid, i, j, th, tw)
def hflip(vid):
return vid.flip(dims=(-1,))
# NOTE: these functions generally expect mini-batches, but we keep the input
# non-batched so that a (C, T, H, W) video is treated like a 4D image batch.
# This way, the transformation is only applied in the spatial domain.
def resize(vid, size, interpolation='bilinear'):
# NOTE: using bilinear interpolation because we don't work on minibatches
# at this level
scale = None
if isinstance(size, int):
scale = float(size) / min(vid.shape[-2:])
size = None
return torch.nn.functional.interpolate(
vid, size=size, scale_factor=scale, mode=interpolation, align_corners=False)
def pad(vid, padding, fill=0, padding_mode="constant"):
    # NOTE: we don't want to pad the temporal dimension, so we keep the video as a
    # non-batched (4D) tensor before padding; this works as expected.
return torch.nn.functional.pad(vid, padding, value=fill, mode=padding_mode)
def to_normalized_float_tensor(vid):
return vid.permute(3, 0, 1, 2).to(torch.float32) / 255
def normalize(vid, mean, std):
shape = (-1,) + (1,) * (vid.dim() - 1)
mean = torch.as_tensor(mean).reshape(shape)
std = torch.as_tensor(std).reshape(shape)
return (vid - mean) / std
# Class interface
class RandomCrop(object):
def __init__(self, size):
self.size = size
@staticmethod
def get_params(vid, output_size):
"""Get parameters for ``crop`` for a random crop.
"""
h, w = vid.shape[-2:]
th, tw = output_size
if w == tw and h == th:
return 0, 0, h, w
i = random.randint(0, h - th)
j = random.randint(0, w - tw)
return i, j, th, tw
def __call__(self, vid):
i, j, h, w = self.get_params(vid, self.size)
return crop(vid, i, j, h, w)
class CenterCrop(object):
def __init__(self, size):
self.size = size
def __call__(self, vid):
return center_crop(vid, self.size)
class Resize(object):
def __init__(self, size):
self.size = size
def __call__(self, vid):
return resize(vid, self.size)
class ToFloatTensorInZeroOne(object):
def __call__(self, vid):
return to_normalized_float_tensor(vid)
class Normalize(object):
def __init__(self, mean, std):
self.mean = mean
self.std = std
def __call__(self, vid):
return normalize(vid, self.mean, self.std)
class RandomHorizontalFlip(object):
def __init__(self, p=0.5):
self.p = p
def __call__(self, vid):
if random.random() < self.p:
return hflip(vid)
return vid
class Pad(object):
def __init__(self, padding, fill=0):
self.padding = padding
self.fill = fill
def __call__(self, vid):
return pad(vid, self.padding, self.fill)
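
A hedged end-to-end sketch (not part of this commit) of how these transforms compose on a raw clip, mirroring transform_train in the training script above. VideoClips/KineticsVideo return clips as (T, H, W, C) uint8 tensors, which ToFloatTensorInZeroOne converts to (C, T, H, W) floats in [0, 1] before resizing, flipping, normalizing and cropping.

# Illustrative only: run the training transform pipeline on a fake clip.
import torch
import torchvision

fake_clip = torch.randint(0, 256, (16, 240, 320, 3), dtype=torch.uint8)  # (T, H, W, C)
transform_train = torchvision.transforms.Compose([
    ToFloatTensorInZeroOne(),
    Resize((128, 171)),
    RandomHorizontalFlip(),
    Normalize(mean=[0.43216, 0.394666, 0.37645], std=[0.22803, 0.22145, 0.216989]),
    RandomCrop((112, 112)),
])
out = transform_train(fake_clip)
print(out.shape)  # torch.Size([3, 16, 112, 112])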
utils.py
from __future__ import print_function
from collections import defaultdict, deque
import datetime
import time
import torch
import torch.distributed as dist
import errno
import os
class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size=20, fmt=None):
if fmt is None:
fmt = "{median:.4f} ({global_avg:.4f})"
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n
def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
if not is_dist_avail_and_initialized():
return
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
dist.barrier()
dist.all_reduce(t)
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@property
def median(self):
d = torch.tensor(list(self.deque))
return d.median().item()
@property
def avg(self):
d = torch.tensor(list(self.deque), dtype=torch.float32)
return d.mean().item()
@property
def global_avg(self):
return self.total / self.count
@property
def max(self):
return max(self.deque)
@property
def value(self):
return self.deque[-1]
def __str__(self):
return self.fmt.format(
median=self.median,
avg=self.avg,
global_avg=self.global_avg,
max=self.max,
value=self.value)
class MetricLogger(object):
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, attr))
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append(
"{}: {}".format(name, str(meter))
)
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()
def add_meter(self, name, meter):
self.meters[name] = meter
def log_every(self, iterable, print_freq, header=None):
i = 0
if not header:
header = ''
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt='{avg:.4f}')
data_time = SmoothedValue(fmt='{avg:.4f}')
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
if torch.cuda.is_available():
log_msg = self.delimiter.join([
header,
'[{0' + space_fmt + '}/{1}]',
'eta: {eta}',
'{meters}',
'time: {time}',
'data: {data}',
'max mem: {memory:.0f}'
])
else:
log_msg = self.delimiter.join([
header,
'[{0' + space_fmt + '}/{1}]',
'eta: {eta}',
'{meters}',
'time: {time}',
'data: {data}'
])
MB = 1024.0 * 1024.0
for obj in iterable:
data_time.update(time.time() - end)
yield obj
iter_time.update(time.time() - end)
if i % print_freq == 0:
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available():
print(log_msg.format(
i, len(iterable), eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB))
else:
print(log_msg.format(
i, len(iterable), eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time)))
i += 1
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('{} Total time: {}'.format(header, total_time_str))
def accuracy(output, target, topk=(1,)):
"""Computes the accuracy over the k top predictions for the specified values of k"""
with torch.no_grad():
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target[None])
res = []
for k in topk:
correct_k = correct[:k].flatten().sum(dtype=torch.float32)
res.append(correct_k * (100.0 / batch_size))
return res
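
A tiny worked example (not part of this commit) of the accuracy helper above: with three samples and five classes, top-1 counts exact argmax matches while top-5 here trivially counts everything, so the expected outputs are roughly 66.67 and 100.0.

# Illustrative only: 3 samples, 5 classes.
logits = torch.tensor([[0.1, 0.9, 0.0, 0.0, 0.0],   # predicts class 1 (correct)
                       [0.8, 0.1, 0.0, 0.0, 0.1],   # predicts class 0 (correct)
                       [0.2, 0.1, 0.6, 0.0, 0.1]])  # predicts class 2, target is 4
targets = torch.tensor([1, 0, 4])
top1, top5 = accuracy(logits, targets, topk=(1, 5))
print(top1.item(), top5.item())  # ~66.67 100.0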
def mkdir(path):
try:
os.makedirs(path)
except OSError as e:
if e.errno != errno.EEXIST:
raise
def setup_for_distributed(is_master):
"""
    This function disables printing when not in the master process
"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop('force', False)
if is_master or force:
builtin_print(*args, **kwargs)
__builtin__.print = print
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_world_size():
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def is_main_process():
return get_rank() == 0
def save_on_master(*args, **kwargs):
if is_main_process():
torch.save(*args, **kwargs)
def init_distributed_mode(args):
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE'])
args.gpu = int(os.environ['LOCAL_RANK'])
elif 'SLURM_PROCID' in os.environ:
args.rank = int(os.environ['SLURM_PROCID'])
args.gpu = args.rank % torch.cuda.device_count()
elif hasattr(args, "rank"):
pass
else:
print('Not using distributed mode')
args.distributed = False
return
args.distributed = True
torch.cuda.set_device(args.gpu)
args.dist_backend = 'nccl'
print('| distributed init (rank {}): {}'.format(
args.rank, args.dist_url), flush=True)
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
world_size=args.world_size, rank=args.rank)
setup_for_distributed(args.rank == 0)
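
As a usage note (not part of this commit): init_distributed_mode above reads RANK, WORLD_SIZE and LOCAL_RANK from the environment (matching what torch.distributed.launch sets when run with --use_env) and, once distributed mode is active, setup_for_distributed silences print on non-master processes unless force=True is passed. A minimal sketch of that last behaviour:

# Illustrative only: after this call, plain print() is dropped on non-master ranks.
setup_for_distributed(is_master=False)
print("per-iteration log line")           # suppressed
print("critical message", force=True)     # still printed via the original print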
torchvision/datasets/kinetics.py
@@ -5,7 +5,7 @@ from .vision import VisionDataset
 class KineticsVideo(VisionDataset):
-    def __init__(self, root, frames_per_clip, step_between_clips=1):
+    def __init__(self, root, frames_per_clip, step_between_clips=1, transform=None):
         super(KineticsVideo, self).__init__(root)
         extensions = ('avi',)
@@ -15,6 +15,7 @@ class KineticsVideo(VisionDataset):
         self.classes = classes
         video_list = [x[0] for x in self.samples]
         self.video_clips = VideoClips(video_list, frames_per_clip, step_between_clips)
+        self.transform = transform

     def __len__(self):
         return self.video_clips.num_clips()
@@ -23,4 +24,7 @@ class KineticsVideo(VisionDataset):
         video, audio, info, video_idx = self.video_clips.get_clip(idx)
         label = self.samples[video_idx][1]
+        if self.transform is not None:
+            video = self.transform(video)
+
         return video, audio, label
torchvision/datasets/video_utils.py
@@ -4,6 +4,8 @@ import torch
 import torch.utils.data
 from torchvision.io import read_video_timestamps, read_video

+from .utils import tqdm
+
 def unfold(tensor, size, step, dilation=1):
     """
@@ -59,11 +61,33 @@ class VideoClips(object):
     def _compute_frame_pts(self):
         self.video_pts = []
         self.video_fps = []
-        # TODO maybe paralellize this
-        for video_file in self.video_paths:
-            clips, fps = read_video_timestamps(video_file)
-            self.video_pts.append(torch.as_tensor(clips))
-            self.video_fps.append(fps)
+
+        # strategy: use a DataLoader to parallelize read_video_timestamps
+        # so need to create a dummy dataset first
+        class DS(object):
+            def __init__(self, x):
+                self.x = x
+
+            def __len__(self):
+                return len(self.x)
+
+            def __getitem__(self, idx):
+                return read_video_timestamps(self.x[idx])
+
+        import torch.utils.data
+        dl = torch.utils.data.DataLoader(
+            DS(self.video_paths),
+            batch_size=16,
+            num_workers=torch.get_num_threads(),
+            collate_fn=lambda x: x)
+
+        with tqdm(total=len(dl)) as pbar:
+            for batch in dl:
+                pbar.update(1)
+                clips, fps = list(zip(*batch))
+                clips = [torch.as_tensor(c) for c in clips]
+                self.video_pts.extend(clips)
+                self.video_fps.extend(fps)

     def _init_from_metadata(self, metadata):
         assert len(self.video_paths) == len(metadata["video_pts"])
...
torchvision/io/video.py
@@ -21,7 +21,7 @@ install PyAV on your system.

 # PyAV has some reference cycles
 _CALLED_TIMES = 0
-_GC_COLLECTION_INTERVAL = 20
+_GC_COLLECTION_INTERVAL = 10

 def write_video(filename, video_array, fps, video_codec='libx264', options=None):
@@ -95,7 +95,8 @@ def _read_from_stream(container, start_offset, end_offset, stream, stream_name):
         # TODO check if stream needs to always be the video stream here or not
         container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
     except av.AVError:
-        print("Corrupted file?", container.name)
+        # TODO add some warnings in this case
+        # print("Corrupted file?", container.name)
         return []

     buffer_count = 0
     for idx, frame in enumerate(container.decode(**stream_name)):
...