"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "87ec8048085a5d892e78e60cd1fb0b4e37219c1e"
Unverified commit f95b0533, authored by vfdev and committed by GitHub

Updated video classification ref example with new transforms (#2935)

* [WIP] Update ref example video classification

* [WIP] Updated video classification ref example

* Replaced memory-format conversion functions with classes
Parent commit: 044fcf24
# Video Classification
TODO: Add some info about the context and the dataset we use, etc.
## Data preparation
If you have already downloaded the [Kinetics400 dataset](https://deepmind.com/research/open-source/kinetics),
please proceed directly to the next section.
To download the videos, you can use https://github.com/Showmax/kinetics-downloader
## Training
We assume the training and validation AVI videos are stored at `/data/kinetics400/train` and
`/data/kinetics400/val`.
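The dataset objects are created with `torchvision.datasets.Kinetics400`, which discovers the action classes from the sub-folder names, so each split is expected to contain one folder per class (e.g. `/data/kinetics400/train/abseiling/some_video.avi`). The snippet below is a minimal sketch of loading such a layout; the clip length of 16 frames is an assumed value, while the frame rate and extensions mirror `train.py`:
```python
import torchvision

# Minimal sketch (not part of the reference scripts). Building the clip index
# scans every video, which is why train.py offers --cache-dataset.
dataset = torchvision.datasets.Kinetics400(
    "/data/kinetics400/train",
    frames_per_clip=16,        # assumed clip length
    step_between_clips=1,
    frame_rate=15,
    extensions=('avi', 'mp4'),
)
print(len(dataset.classes), "classes,", len(dataset), "clips")
```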
### Multiple GPUs
Run the training on a single node with 8 GPUs:
```bash
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py --data-path=/data/kinetics400 --train-dir=train --val-dir=val --batch-size=16 --cache-dataset --sync-bn --apex
```
### Single GPU
**Note:** training on a single GPU can be extremely slow.
```bash
python train.py --data-path=/data/kinetics400 --train-dir=train --val-dir=val --batch-size=8 --cache-dataset
```
`train.py`:
```diff
@@ -7,13 +7,13 @@ from torch.utils.data.dataloader import default_collate
 from torch import nn
 import torchvision
 import torchvision.datasets.video_utils
-from torchvision import transforms
+from torchvision import transforms as T
 from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler, RandomClipSampler
 
 import utils
 
 from scheduler import WarmupMultiStepLR
-import transforms as T
+from transforms import ConvertBHWCtoBCHW, ConvertBCHWtoCBHW
 
 try:
     from apex import amp
@@ -119,11 +119,13 @@ def main(args):
     st = time.time()
     cache_path = _get_cache_path(traindir)
     transform_train = torchvision.transforms.Compose([
-        T.ToFloatTensorInZeroOne(),
+        ConvertBHWCtoBCHW(),
+        T.ConvertImageDtype(torch.float32),
         T.Resize((128, 171)),
         T.RandomHorizontalFlip(),
         normalize,
-        T.RandomCrop((112, 112))
+        T.RandomCrop((112, 112)),
+        ConvertBCHWtoCBHW()
     ])
 
     if args.cache_dataset and os.path.exists(cache_path):
@@ -139,7 +141,8 @@ def main(args):
             frames_per_clip=args.clip_len,
             step_between_clips=1,
             transform=transform_train,
-            frame_rate=15
+            frame_rate=15,
+            extensions=('avi', 'mp4', )
         )
         if args.cache_dataset:
             print("Saving dataset_train to {}".format(cache_path))
@@ -152,10 +155,12 @@ def main(args):
     cache_path = _get_cache_path(valdir)
     transform_test = torchvision.transforms.Compose([
-        T.ToFloatTensorInZeroOne(),
+        ConvertBHWCtoBCHW(),
+        T.ConvertImageDtype(torch.float32),
         T.Resize((128, 171)),
         normalize,
-        T.CenterCrop((112, 112))
+        T.CenterCrop((112, 112)),
+        ConvertBCHWtoCBHW()
     ])
 
     if args.cache_dataset and os.path.exists(cache_path):
@@ -171,7 +176,8 @@ def main(args):
             frames_per_clip=args.clip_len,
             step_between_clips=1,
             transform=transform_test,
-            frame_rate=15
+            frame_rate=15,
+            extensions=('avi', 'mp4',)
         )
         if args.cache_dataset:
             print("Saving dataset_test to {}".format(cache_path))
@@ -265,7 +271,7 @@ def main(args):
 def parse_args():
     import argparse
-    parser = argparse.ArgumentParser(description='PyTorch Classification Training')
+    parser = argparse.ArgumentParser(description='PyTorch Video Classification Training')
 
     parser.add_argument('--data-path', default='/datasets01_101/kinetics/070618/', help='dataset')
     parser.add_argument('--train-dir', default='train_avi-480p', help='name of train dir')
```
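As a sanity check of the updated pipeline, the standalone sketch below runs a random uint8 clip through the new training transform. The normalization statistics are assumed Kinetics-400 values (they are defined elsewhere in `train.py` and not shown in this diff), and `transforms` is the local module of this example:
```python
import torch
from torchvision import transforms as T

from transforms import ConvertBHWCtoBCHW, ConvertBCHWtoCBHW  # local transforms.py

# clips come out of the Kinetics400 dataset as uint8 tensors of shape (T, H, W, C)
clip = torch.randint(0, 256, (16, 240, 320, 3), dtype=torch.uint8)

# assumed Kinetics-400 normalization statistics
normalize = T.Normalize(mean=[0.43216, 0.394666, 0.37645],
                        std=[0.22803, 0.22145, 0.216989])

transform_train = T.Compose([
    ConvertBHWCtoBCHW(),                 # (T, H, W, C) -> (T, C, H, W)
    T.ConvertImageDtype(torch.float32),  # uint8 [0, 255] -> float32 [0, 1]
    T.Resize((128, 171)),
    T.RandomHorizontalFlip(),
    normalize,
    T.RandomCrop((112, 112)),
    ConvertBCHWtoCBHW(),                 # (T, C, H, W) -> (C, T, H, W)
])

print(transform_train(clip).shape)  # torch.Size([3, 16, 112, 112])
```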
`transforms.py` (rewritten): the functional clip transforms and their class wrappers are removed; the new file keeps only two memory-format conversion modules.

Old `transforms.py` (removed):
```python
import torch
import random


def crop(vid, i, j, h, w):
    return vid[..., i:(i + h), j:(j + w)]


def center_crop(vid, output_size):
    h, w = vid.shape[-2:]
    th, tw = output_size

    i = int(round((h - th) / 2.))
    j = int(round((w - tw) / 2.))
    return crop(vid, i, j, th, tw)


def hflip(vid):
    return vid.flip(dims=(-1,))


# NOTE: for those functions, which generally expect mini-batches, we keep them
# as non-minibatch so that they are applied as if they were 4d (thus image).
# this way, we only apply the transformation in the spatial domain
def resize(vid, size, interpolation='bilinear'):
    # NOTE: using bilinear interpolation because we don't work on minibatches
    # at this level
    scale = None
    if isinstance(size, int):
        scale = float(size) / min(vid.shape[-2:])
        size = None
    return torch.nn.functional.interpolate(
        vid, size=size, scale_factor=scale, mode=interpolation, align_corners=False)


def pad(vid, padding, fill=0, padding_mode="constant"):
    # NOTE: don't want to pad on temporal dimension, so let as non-batch
    # (4d) before padding. This works as expected
    return torch.nn.functional.pad(vid, padding, value=fill, mode=padding_mode)


def to_normalized_float_tensor(vid):
    return vid.permute(3, 0, 1, 2).to(torch.float32) / 255


def normalize(vid, mean, std):
    shape = (-1,) + (1,) * (vid.dim() - 1)
    mean = torch.as_tensor(mean).reshape(shape)
    std = torch.as_tensor(std).reshape(shape)
    return (vid - mean) / std


# Class interface

class RandomCrop(object):
    def __init__(self, size):
        self.size = size

    @staticmethod
    def get_params(vid, output_size):
        """Get parameters for ``crop`` for a random crop.
        """
        h, w = vid.shape[-2:]
        th, tw = output_size
        if w == tw and h == th:
            return 0, 0, h, w
        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw

    def __call__(self, vid):
        i, j, h, w = self.get_params(vid, self.size)
        return crop(vid, i, j, h, w)


class CenterCrop(object):
    def __init__(self, size):
        self.size = size

    def __call__(self, vid):
        return center_crop(vid, self.size)


class Resize(object):
    def __init__(self, size):
        self.size = size

    def __call__(self, vid):
        return resize(vid, self.size)


class ToFloatTensorInZeroOne(object):
    def __call__(self, vid):
        return to_normalized_float_tensor(vid)


class Normalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, vid):
        return normalize(vid, self.mean, self.std)


class RandomHorizontalFlip(object):
    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, vid):
        if random.random() < self.p:
            return hflip(vid)
        return vid


class Pad(object):
    def __init__(self, padding, fill=0):
        self.padding = padding
        self.fill = fill

    def __call__(self, vid):
        return pad(vid, self.padding, self.fill)
```

New `transforms.py`:
```python
import torch
import torch.nn as nn


class ConvertBHWCtoBCHW(nn.Module):
    """Convert tensor from (B, H, W, C) to (B, C, H, W)
    """

    def forward(self, vid: torch.Tensor) -> torch.Tensor:
        return vid.permute(0, 3, 1, 2)


class ConvertBCHWtoCBHW(nn.Module):
    """Convert tensor from (B, C, H, W) to (C, B, H, W)
    """

    def forward(self, vid: torch.Tensor) -> torch.Tensor:
        return vid.permute(1, 0, 2, 3)
```
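The removed `ToFloatTensorInZeroOne` transform combined the dtype conversion and the (T, H, W, C) -> (C, T, H, W) permutation in a single step. The sketch below (assuming `T.ConvertImageDtype` rescales uint8 input by 1/255, matching the old `/ 255`) checks that the new modules reproduce that behaviour:
```python
import torch
from torchvision import transforms as T

from transforms import ConvertBHWCtoBCHW, ConvertBCHWtoCBHW


def to_normalized_float_tensor(vid):
    # behaviour of the removed transform: (T, H, W, C) uint8 -> (C, T, H, W) float in [0, 1]
    return vid.permute(3, 0, 1, 2).to(torch.float32) / 255


vid = torch.randint(0, 256, (8, 64, 64, 3), dtype=torch.uint8)

new_pipeline = T.Compose([
    ConvertBHWCtoBCHW(),                 # (T, H, W, C) -> (T, C, H, W)
    T.ConvertImageDtype(torch.float32),  # uint8 -> float32 in [0, 1]
    ConvertBCHWtoCBHW(),                 # (T, C, H, W) -> (C, T, H, W)
])

assert torch.allclose(new_pipeline(vid), to_normalized_float_tensor(vid))
```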