Commit 3d92aebb authored by bailuo

add preprocessing

parent fcc0bcf3
Pipeline #1379 canceled with stages
BSD 3-Clause License
Copyright (c) 2020, princeton-vl
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# RAFT
This repository contains the source code for our paper:
[RAFT: Recurrent All-Pairs Field Transforms for Optical Flow](https://arxiv.org/pdf/2003.12039.pdf)<br/>
ECCV 2020 <br/>
Zachary Teed and Jia Deng<br/>
<img src="RAFT.png">
## Requirements
The code has been tested with PyTorch 1.6 and CUDA 10.1.
```Shell
conda create --name raft
conda activate raft
conda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.1 matplotlib tensorboard scipy opencv -c pytorch
```
## Demos
Pretrained models can be downloaded by running
```Shell
./download_models.sh
```
or downloaded from [Google Drive](https://drive.google.com/drive/folders/1sWDsfuZ3Up38EUQt7-JDTT1HcGHuJgvT?usp=sharing).
You can demo a trained model on a sequence of frames:
```Shell
python demo.py --model=models/raft-things.pth --path=demo-frames
```
## Required Data
To evaluate/train RAFT, you will need to download the required datasets.
* [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs)
* [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html)
* [Sintel](http://sintel.is.tue.mpg.de/)
* [KITTI](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow)
* [HD1K](http://hci-benchmark.iwr.uni-heidelberg.de/) (optional)
By default `datasets.py` will search for the datasets in these locations. You can create symbolic links in the `datasets` folder to wherever the datasets were downloaded (see the example commands after the layout below).
```Shell
├── datasets
├── Sintel
├── test
├── training
├── KITTI
├── testing
├── training
├── devkit
├── FlyingChairs_release
├── data
├── FlyingThings3D
├── frames_cleanpass
├── frames_finalpass
├── optical_flow
```
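For example, assuming the datasets were unpacked under a hypothetical `/path/to/data`, the symbolic links could be created with:
```Shell
mkdir -p datasets
ln -s /path/to/data/Sintel datasets/Sintel
ln -s /path/to/data/KITTI datasets/KITTI
ln -s /path/to/data/FlyingChairs_release datasets/FlyingChairs_release
ln -s /path/to/data/FlyingThings3D datasets/FlyingThings3D
```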
## Evaluation
You can evaluate a trained model using `evaluate.py`
```Shell
python evaluate.py --model=models/raft-things.pth --dataset=sintel --mixed_precision
```
## Training
We used the following training schedule in our paper (2 GPUs). Training logs will be written to the `runs` directory, which can be visualized using TensorBoard.
```Shell
./train_standard.sh
```
If you have an RTX GPU, training can be accelerated using mixed precision. You can expect similar results in this setting (1 GPU).
```Shell
./train_mixed.sh
```
## (Optional) Efficient Implementation
You can optionally use our alternate (efficient) implementation by compiling the provided CUDA extension
```Shell
cd alt_cuda_corr && python setup.py install && cd ..
```
and running `demo.py` and `evaluate.py` with the `--alternate_corr` flag. Note that this implementation is somewhat slower than all-pairs, but uses significantly less GPU memory during the forward pass.
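For example, the evaluation command above can be run with the compiled extension by adding the flag:
```Shell
python evaluate.py --model=models/raft-things.pth --dataset=sintel --mixed_precision --alternate_corr
```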
#include <torch/extension.h>
#include <vector>
// CUDA forward declarations
std::vector<torch::Tensor> corr_cuda_forward(
torch::Tensor fmap1,
torch::Tensor fmap2,
torch::Tensor coords,
int radius);
std::vector<torch::Tensor> corr_cuda_backward(
torch::Tensor fmap1,
torch::Tensor fmap2,
torch::Tensor coords,
torch::Tensor corr_grad,
int radius);
// C++ interface
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> corr_forward(
torch::Tensor fmap1,
torch::Tensor fmap2,
torch::Tensor coords,
int radius) {
CHECK_INPUT(fmap1);
CHECK_INPUT(fmap2);
CHECK_INPUT(coords);
return corr_cuda_forward(fmap1, fmap2, coords, radius);
}
std::vector<torch::Tensor> corr_backward(
torch::Tensor fmap1,
torch::Tensor fmap2,
torch::Tensor coords,
torch::Tensor corr_grad,
int radius) {
CHECK_INPUT(fmap1);
CHECK_INPUT(fmap2);
CHECK_INPUT(coords);
CHECK_INPUT(corr_grad);
return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &corr_forward, "CORR forward");
m.def("backward", &corr_backward, "CORR backward");
}
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
#define BLOCK_H 4
#define BLOCK_W 8
#define BLOCK_HW (BLOCK_H * BLOCK_W)
#define CHANNEL_STRIDE 32
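// Launch layout (see corr_cuda_forward/backward below): each thread block covers a
// BLOCK_H x BLOCK_W tile of frame-1 pixels, with one thread per pixel, and channels are
// processed in chunks of CHANNEL_STRIDE that are staged through shared memory.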
__forceinline__ __device__
bool within_bounds(int h, int w, int H, int W) {
return h >= 0 && h < H && w >= 0 && w < W;
}
template <typename scalar_t>
__global__ void corr_forward_kernel(
const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1,
const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2,
const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords,
torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr,
int r)
{
const int b = blockIdx.x;
const int h0 = blockIdx.y * blockDim.x;
const int w0 = blockIdx.z * blockDim.y;
const int tid = threadIdx.x * blockDim.y + threadIdx.y;
const int H1 = fmap1.size(1);
const int W1 = fmap1.size(2);
const int H2 = fmap2.size(1);
const int W2 = fmap2.size(2);
const int N = coords.size(1);
const int C = fmap1.size(3);
__shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1];
__shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1];
__shared__ scalar_t x2s[BLOCK_HW];
__shared__ scalar_t y2s[BLOCK_HW];
for (int c=0; c<C; c+=CHANNEL_STRIDE) {
for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
int k1 = k + tid / CHANNEL_STRIDE;
int h1 = h0 + k1 / BLOCK_W;
int w1 = w0 + k1 % BLOCK_W;
int c1 = tid % CHANNEL_STRIDE;
auto fptr = fmap1[b][h1][w1];
if (within_bounds(h1, w1, H1, W1))
f1[c1][k1] = fptr[c+c1];
else
f1[c1][k1] = 0.0;
}
__syncthreads();
for (int n=0; n<N; n++) {
int h1 = h0 + threadIdx.x;
int w1 = w0 + threadIdx.y;
if (within_bounds(h1, w1, H1, W1)) {
x2s[tid] = coords[b][n][h1][w1][0];
y2s[tid] = coords[b][n][h1][w1][1];
}
scalar_t dx = x2s[tid] - floor(x2s[tid]);
scalar_t dy = y2s[tid] - floor(y2s[tid]);
int rd = 2*r + 1;
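// Each of the (rd+1) x (rd+1) integer sample positions around floor(coords) contributes,
// with bilinear weights derived from (dx, dy), to up to four neighboring cells of the
// rd x rd correlation window, i.e. the correlation is bilinearly interpolated at the
// fractional lookup coordinates.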
for (int iy=0; iy<rd+1; iy++) {
for (int ix=0; ix<rd+1; ix++) {
for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
int k1 = k + tid / CHANNEL_STRIDE;
int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;
int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
int c2 = tid % CHANNEL_STRIDE;
auto fptr = fmap2[b][h2][w2];
if (within_bounds(h2, w2, H2, W2))
f2[c2][k1] = fptr[c+c2];
else
f2[c2][k1] = 0.0;
}
__syncthreads();
scalar_t s = 0.0;
for (int k=0; k<CHANNEL_STRIDE; k++)
s += f1[k][tid] * f2[k][tid];
int ix_nw = H1*W1*((iy-1) + rd*(ix-1));
int ix_ne = H1*W1*((iy-1) + rd*ix);
int ix_sw = H1*W1*(iy + rd*(ix-1));
int ix_se = H1*W1*(iy + rd*ix);
scalar_t nw = s * (dy) * (dx);
scalar_t ne = s * (dy) * (1-dx);
scalar_t sw = s * (1-dy) * (dx);
scalar_t se = s * (1-dy) * (1-dx);
scalar_t* corr_ptr = &corr[b][n][0][h1][w1];
if (iy > 0 && ix > 0 && within_bounds(h1, w1, H1, W1))
*(corr_ptr + ix_nw) += nw;
if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1))
*(corr_ptr + ix_ne) += ne;
if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1))
*(corr_ptr + ix_sw) += sw;
if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1))
*(corr_ptr + ix_se) += se;
}
}
}
}
}
template <typename scalar_t>
__global__ void corr_backward_kernel(
const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1,
const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2,
const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords,
const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr_grad,
torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1_grad,
torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2_grad,
torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords_grad,
int r)
{
const int b = blockIdx.x;
const int h0 = blockIdx.y * blockDim.x;
const int w0 = blockIdx.z * blockDim.y;
const int tid = threadIdx.x * blockDim.y + threadIdx.y;
const int H1 = fmap1.size(1);
const int W1 = fmap1.size(2);
const int H2 = fmap2.size(1);
const int W2 = fmap2.size(2);
const int N = coords.size(1);
const int C = fmap1.size(3);
__shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1];
__shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1];
__shared__ scalar_t f1_grad[CHANNEL_STRIDE][BLOCK_HW+1];
__shared__ scalar_t f2_grad[CHANNEL_STRIDE][BLOCK_HW+1];
__shared__ scalar_t x2s[BLOCK_HW];
__shared__ scalar_t y2s[BLOCK_HW];
for (int c=0; c<C; c+=CHANNEL_STRIDE) {
for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
int k1 = k + tid / CHANNEL_STRIDE;
int h1 = h0 + k1 / BLOCK_W;
int w1 = w0 + k1 % BLOCK_W;
int c1 = tid % CHANNEL_STRIDE;
auto fptr = fmap1[b][h1][w1];
if (within_bounds(h1, w1, H1, W1))
f1[c1][k1] = fptr[c+c1];
else
f1[c1][k1] = 0.0;
f1_grad[c1][k1] = 0.0;
}
__syncthreads();
int h1 = h0 + threadIdx.x;
int w1 = w0 + threadIdx.y;
for (int n=0; n<N; n++) {
x2s[tid] = coords[b][n][h1][w1][0];
y2s[tid] = coords[b][n][h1][w1][1];
scalar_t dx = x2s[tid] - floor(x2s[tid]);
scalar_t dy = y2s[tid] - floor(y2s[tid]);
int rd = 2*r + 1;
for (int iy=0; iy<rd+1; iy++) {
for (int ix=0; ix<rd+1; ix++) {
for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
int k1 = k + tid / CHANNEL_STRIDE;
int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;
int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
int c2 = tid % CHANNEL_STRIDE;
auto fptr = fmap2[b][h2][w2];
if (within_bounds(h2, w2, H2, W2))
f2[c2][k1] = fptr[c+c2];
else
f2[c2][k1] = 0.0;
f2_grad[c2][k1] = 0.0;
}
__syncthreads();
const scalar_t* grad_ptr = &corr_grad[b][n][0][h1][w1];
scalar_t g = 0.0;
int ix_nw = H1*W1*((iy-1) + rd*(ix-1));
int ix_ne = H1*W1*((iy-1) + rd*ix);
int ix_sw = H1*W1*(iy + rd*(ix-1));
int ix_se = H1*W1*(iy + rd*ix);
if (iy > 0 && ix > 0 && within_bounds(h1, w1, H1, W1))
g += *(grad_ptr + ix_nw) * dy * dx;
if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1))
g += *(grad_ptr + ix_ne) * dy * (1-dx);
if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1))
g += *(grad_ptr + ix_sw) * (1-dy) * dx;
if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1))
g += *(grad_ptr + ix_se) * (1-dy) * (1-dx);
for (int k=0; k<CHANNEL_STRIDE; k++) {
f1_grad[k][tid] += g * f2[k][tid];
f2_grad[k][tid] += g * f1[k][tid];
}
for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
int k1 = k + tid / CHANNEL_STRIDE;
int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;
int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
int c2 = tid % CHANNEL_STRIDE;
scalar_t* fptr = &fmap2_grad[b][h2][w2][0];
if (within_bounds(h2, w2, H2, W2))
atomicAdd(fptr+c+c2, f2_grad[c2][k1]);
}
}
}
}
__syncthreads();
for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
int k1 = k + tid / CHANNEL_STRIDE;
int h1 = h0 + k1 / BLOCK_W;
int w1 = w0 + k1 % BLOCK_W;
int c1 = tid % CHANNEL_STRIDE;
scalar_t* fptr = &fmap1_grad[b][h1][w1][0];
if (within_bounds(h1, w1, H1, W1))
fptr[c+c1] += f1_grad[c1][k1];
}
}
}
std::vector<torch::Tensor> corr_cuda_forward(
torch::Tensor fmap1,
torch::Tensor fmap2,
torch::Tensor coords,
int radius)
{
const auto B = coords.size(0);
const auto N = coords.size(1);
const auto H = coords.size(2);
const auto W = coords.size(3);
const auto rd = 2 * radius + 1;
auto opts = fmap1.options();
auto corr = torch::zeros({B, N, rd*rd, H, W}, opts);
const dim3 blocks(B, (H+BLOCK_H-1)/BLOCK_H, (W+BLOCK_W-1)/BLOCK_W);
const dim3 threads(BLOCK_H, BLOCK_W);
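// one thread block per (batch element, BLOCK_H x BLOCK_W pixel tile); each thread handles one pixel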
corr_forward_kernel<float><<<blocks, threads>>>(
fmap1.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
fmap2.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
coords.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
corr.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
radius);
return {corr};
}
std::vector<torch::Tensor> corr_cuda_backward(
torch::Tensor fmap1,
torch::Tensor fmap2,
torch::Tensor coords,
torch::Tensor corr_grad,
int radius)
{
const auto B = coords.size(0);
const auto N = coords.size(1);
const auto H1 = fmap1.size(1);
const auto W1 = fmap1.size(2);
const auto H2 = fmap2.size(1);
const auto W2 = fmap2.size(2);
const auto C = fmap1.size(3);
auto opts = fmap1.options();
auto fmap1_grad = torch::zeros({B, H1, W1, C}, opts);
auto fmap2_grad = torch::zeros({B, H2, W2, C}, opts);
auto coords_grad = torch::zeros({B, N, H1, W1, 2}, opts);
const dim3 blocks(B, (H1+BLOCK_H-1)/BLOCK_H, (W1+BLOCK_W-1)/BLOCK_W);
const dim3 threads(BLOCK_H, BLOCK_W);
corr_backward_kernel<float><<<blocks, threads>>>(
fmap1.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
fmap2.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
coords.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
corr_grad.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
fmap1_grad.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
fmap2_grad.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
coords_grad.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
radius);
return {fmap1_grad, fmap2_grad, coords_grad};
}
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='correlation',
ext_modules=[
CUDAExtension('alt_cuda_corr',
sources=['correlation.cpp', 'correlation_kernel.cu'],
extra_compile_args={'cxx': [], 'nvcc': ['-O3']}),
],
cmdclass={
'build_ext': BuildExtension
})
"""
Chain cycle-consistent correspondences to create longer and denser tracks.
The rules for chaining: only cycle-consistent flows between adjacent frames are chained,
and where a direct cycle-consistent flow exists, the chained flow is overwritten by the direct flow,
which is considered more reliable; the procedure then continues iteratively.
The chained flows go through a final appearance check in both feature and RGB space
to reduce spurious correspondences. One can think of this process as augmenting the original
direct optical flows (which are left unchanged) with some chained ones. This is optional, and we found it
to help the optimization process for sequences where the number of valid flows is very imbalanced
across regions.
"""
import json
import argparse
import os
import glob
import imageio
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")
DEVICE = 'cuda'
def gen_grid(h, w, device, normalize=False, homogeneous=False):
if normalize:
lin_y = torch.linspace(-1., 1., steps=h, device=device)
lin_x = torch.linspace(-1., 1., steps=w, device=device)
else:
lin_y = torch.arange(0, h, device=device)
lin_x = torch.arange(0, w, device=device)
grid_y, grid_x = torch.meshgrid((lin_y, lin_x))
grid = torch.stack((grid_x, grid_y), -1)
if homogeneous:
grid = torch.cat([grid, torch.ones_like(grid[..., :1])], dim=-1)
return grid # [h, w, 2 or 3]
def normalize_coords(coords, h, w, no_shift=False):
assert coords.shape[-1] == 2
if no_shift:
return coords / torch.tensor([w-1., h-1.], device=coords.device) * 2
else:
return coords / torch.tensor([w-1., h-1.], device=coords.device) * 2 - 1.
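# normalize_coords (above) maps pixel coordinates to the [-1, 1] range expected by
# F.grid_sample (or to [0, 2] when no_shift is set).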
def run(args):
feature_name = 'dino'
scene_dir = args.data_dir
print('chaining raft optical flow for {}....'.format(scene_dir))
img_files = sorted(glob.glob(os.path.join(scene_dir, 'color', '*')))
num_imgs = len(img_files)
pbar = tqdm(total=num_imgs*(num_imgs-1))
out_dir = os.path.join(scene_dir, 'raft_exhaustive')
out_mask_dir = os.path.join(scene_dir, 'raft_masks')
count_out_dir = os.path.join(scene_dir, 'count_maps')
out_flow_stats_file = os.path.join(scene_dir, 'flow_stats.json')
os.makedirs(out_dir, exist_ok=True)
os.makedirs(out_mask_dir, exist_ok=True)
os.makedirs(count_out_dir, exist_ok=True)
h, w = imageio.imread(img_files[0]).shape[:2]
grid = gen_grid(h, w, 'cuda')[None] # [b, h, w, 2]
flow_stats = {}
count_maps = np.zeros((num_imgs, h, w), dtype=np.uint16)
images = [torch.from_numpy(imageio.imread(img_file) / 255.).float().permute(2, 0, 1)[None].to(DEVICE)
for img_file in img_files]
features = [torch.from_numpy(np.load(os.path.join(scene_dir, 'features', feature_name,
os.path.basename(img_file) + '.npy'))).float().to(DEVICE)
for img_file in img_files]
for i in range(num_imgs - 1):
imgname_i = os.path.basename(img_files[i])
imgname_i_plus_1 = os.path.basename(img_files[i + 1])
start_flow_file = os.path.join(scene_dir, 'raft_exhaustive', '{}_{}.npy'.format(imgname_i, imgname_i_plus_1))
start_flow = np.load(start_flow_file)
start_flow = torch.from_numpy(start_flow).float()[None].cuda() # [b, h, w, 2]
start_mask_file = start_flow_file.replace('raft_exhaustive', 'raft_masks').replace('.npy', '.png')
start_cycle_mask = imageio.imread(start_mask_file)[..., 0] > 0
feature_i = features[i].permute(2, 0, 1)[None]
feature_i = F.interpolate(feature_i, size=start_flow.shape[1:3], mode='bilinear')
accumulated_flow = start_flow
accumulated_cycle_mask = start_cycle_mask
for j in range(i + 1, num_imgs):
# # vis
imgname_j = os.path.basename(img_files[j])
direct_flow_file = os.path.join(scene_dir, 'raft_exhaustive', '{}_{}.npy'.format(imgname_i, imgname_j))
direct_flow = np.load(direct_flow_file)
direct_flow = torch.from_numpy(direct_flow).float()[None].cuda() # [b, h, w, 2]
direct_mask_file = direct_flow_file.replace('raft_exhaustive', 'raft_masks').replace('.npy', '.png')
direct_masks = imageio.imread(direct_mask_file)
direct_cycle_mask = direct_masks[..., 0] > 0
direct_occlusion_mask = direct_masks[..., 1] > 0
direct_mask = direct_cycle_mask | direct_occlusion_mask
direct_mask = torch.from_numpy(direct_mask)[None]
accumulated_flow[direct_mask] = direct_flow[direct_mask]
curr_coords = grid + accumulated_flow
curr_coords_normed = normalize_coords(curr_coords, h, w) # [b, h, w, 2]
feature_j = features[j].permute(2, 0, 1)[None]
feature_j_sampled = F.grid_sample(feature_j, curr_coords_normed, align_corners=True)
feature_sim = torch.cosine_similarity(feature_i, feature_j_sampled, dim=1).squeeze(0).cpu().numpy()
image_j_sampled = F.grid_sample(images[j], curr_coords_normed, align_corners=True).squeeze()
rgb_sim = torch.norm(images[i] - image_j_sampled, dim=1).squeeze(0).cpu().numpy()
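# appearance check described in the module docstring: keep a chained correspondence only where
# DINO feature similarity is above 0.5 and the RGB difference is below 0.3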
accumulated_cycle_mask *= (feature_sim > 0.5) * (rgb_sim < 0.3)
accumulated_cycle_mask[direct_cycle_mask] = True
accumulated_cycle_mask[direct_occlusion_mask] = False
np.save(os.path.join(out_dir, '{}_{}.npy'.format(imgname_i, imgname_j)), accumulated_flow[0].cpu().numpy())
out_mask = np.concatenate([255 * accumulated_cycle_mask[..., None].astype(np.uint8),
direct_masks[..., 1:]],
axis=-1)
imageio.imwrite('{}/{}_{}.png'.format(out_mask_dir, imgname_i, imgname_j), out_mask)
count_maps[i] += (out_mask / 255).sum(axis=-1).astype(np.uint16)
if not imgname_i in flow_stats.keys():
flow_stats[imgname_i] = {}
flow_stats[imgname_i][imgname_j] = int(np.sum(out_mask/255))
pbar.update(1)
if j == num_imgs - 1:
continue
imgname_j_plus_1 = os.path.basename(img_files[j + 1])
flow_file = os.path.join(scene_dir, 'raft_exhaustive', '{}_{}.npy'.format(imgname_j, imgname_j_plus_1))
curr_flow = np.load(flow_file)
curr_flow = torch.from_numpy(curr_flow).float()[None].cuda() # [b, h, w, 2]
curr_mask_file = flow_file.replace('raft_exhaustive', 'raft_masks').replace('.npy', '.png')
curr_cycle_mask = imageio.imread(curr_mask_file)[..., 0] > 0
flow_curr_sampled = F.grid_sample(curr_flow.permute(0, 3, 1, 2), curr_coords_normed,
align_corners=True).permute(0, 2, 3, 1)
curr_cycle_mask_sampled = F.grid_sample(torch.from_numpy(curr_cycle_mask).float()[None, None].cuda(),
curr_coords_normed, align_corners=True).squeeze().cpu().numpy() == 1
# update
accumulated_flow += flow_curr_sampled
accumulated_cycle_mask *= curr_cycle_mask_sampled
for i in range(num_imgs - 1, 0, -1):
imgname_i = os.path.basename(img_files[i])
imgname_i_minus_1 = os.path.basename(img_files[i - 1])
start_flow_file = os.path.join(scene_dir, 'raft_exhaustive', '{}_{}.npy'.format(imgname_i, imgname_i_minus_1))
start_flow = np.load(start_flow_file)
start_flow = torch.from_numpy(start_flow).float()[None].cuda() # [b, h, w, 2]
start_mask_file = start_flow_file.replace('raft_exhaustive', 'raft_masks').replace('.npy', '.png')
start_cycle_mask = imageio.imread(start_mask_file)[..., 0] > 0
feature_i = features[i].permute(2, 0, 1)[None]
feature_i = F.interpolate(feature_i, size=start_flow.shape[1:3], mode='bilinear')
accumulated_flow = start_flow
accumulated_cycle_mask = start_cycle_mask
for j in range(i - 1, -1, -1):
# # vis
imgname_j = os.path.basename(img_files[j])
direct_flow_file = os.path.join(scene_dir, 'raft_exhaustive', '{}_{}.npy'.format(imgname_i, imgname_j))
direct_flow = np.load(direct_flow_file)
direct_flow = torch.from_numpy(direct_flow).float()[None].cuda() # [b, h, w, 2]
direct_mask_file = direct_flow_file.replace('raft_exhaustive', 'raft_masks').replace('.npy', '.png')
direct_masks = imageio.imread(direct_mask_file)
direct_cycle_mask = direct_masks[..., 0] > 0
direct_occlusion_mask = direct_masks[..., 1] > 0
direct_mask = direct_cycle_mask | direct_occlusion_mask
direct_mask = torch.from_numpy(direct_mask)[None]
accumulated_flow[direct_mask] = direct_flow[direct_mask]
curr_coords = grid + accumulated_flow
curr_coords_normed = normalize_coords(curr_coords, h, w) # [b, h, w, 2]
feature_j = features[j].permute(2, 0, 1)[None]
feature_j_sampled = F.grid_sample(feature_j, curr_coords_normed, align_corners=True)
feature_sim = torch.cosine_similarity(feature_i, feature_j_sampled, dim=1).squeeze(0).cpu().numpy()
image_j_sampled = F.grid_sample(images[j], curr_coords_normed, align_corners=True).squeeze()
rgb_sim = torch.norm(images[i] - image_j_sampled, dim=1).squeeze(0).cpu().numpy()
accumulated_cycle_mask *= (feature_sim > 0.5) * (rgb_sim < 0.3)
accumulated_cycle_mask[direct_cycle_mask] = True
accumulated_cycle_mask[direct_occlusion_mask] = False
np.save(os.path.join(out_dir, '{}_{}.npy'.format(imgname_i, imgname_j)), accumulated_flow[0].cpu().numpy())
out_mask = np.concatenate([255 * accumulated_cycle_mask[..., None].astype(np.uint8),
direct_masks[..., 1:]],
axis=-1)
imageio.imwrite('{}/{}_{}.png'.format(out_mask_dir, imgname_i, imgname_j), out_mask)
count_maps[i] += (out_mask / 255).sum(axis=-1).astype(np.uint16)
if not imgname_i in flow_stats.keys():
flow_stats[imgname_i] = {}
flow_stats[imgname_i][imgname_j] = int(np.sum(out_mask / 255))
pbar.update(1)
if j == 0:
continue
imgname_j_minus_1 = os.path.basename(img_files[j - 1])
flow_file = os.path.join(scene_dir, 'raft_exhaustive', '{}_{}.npy'.format(imgname_j, imgname_j_minus_1))
curr_flow = np.load(flow_file)
curr_flow = torch.from_numpy(curr_flow).float()[None].cuda() # [b, h, w, 2]
curr_mask_file = flow_file.replace('raft_exhaustive', 'raft_masks').replace('.npy', '.png')
curr_cycle_mask = imageio.imread(curr_mask_file)[..., 0] > 0
flow_curr_sampled = F.grid_sample(curr_flow.permute(0, 3, 1, 2), curr_coords_normed,
align_corners=True).permute(0, 2, 3, 1)
curr_cycle_mask_sampled = F.grid_sample(torch.from_numpy(curr_cycle_mask).float()[None, None].cuda(),
curr_coords_normed, align_corners=True).squeeze().cpu().numpy() == 1
# update
accumulated_flow += flow_curr_sampled
accumulated_cycle_mask *= curr_cycle_mask_sampled
pbar.close()
with open(out_flow_stats_file, 'w') as fp:
json.dump(flow_stats, fp)
for i in range(num_imgs):
img_name = os.path.basename(img_files[i])
imageio.imwrite(os.path.join(count_out_dir, img_name.replace('.jpg', '.png')), count_maps[i])
print('chaining raft optical flow for {} is done'.format(scene_dir))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', type=str, default='', help='dataset dir')
args = parser.parse_args()
run(args)
import torch
import torch.nn.functional as F
from utils.utils import bilinear_sampler, coords_grid
try:
import alt_cuda_corr
except:
# alt_cuda_corr is not compiled
pass
class CorrBlock:
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
self.num_levels = num_levels
self.radius = radius
self.corr_pyramid = []
# all pairs correlation
corr = CorrBlock.corr(fmap1, fmap2)
batch, h1, w1, dim, h2, w2 = corr.shape
corr = corr.reshape(batch*h1*w1, dim, h2, w2)
self.corr_pyramid.append(corr)
for i in range(self.num_levels-1):
corr = F.avg_pool2d(corr, 2, stride=2)
self.corr_pyramid.append(corr)
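# pooling halves only the second frame's spatial dims (h2, w2); the batch*h1*w1 source
# pixels stay at full resolution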
def __call__(self, coords):
r = self.radius
coords = coords.permute(0, 2, 3, 1)
batch, h1, w1, _ = coords.shape
out_pyramid = []
for i in range(self.num_levels):
corr = self.corr_pyramid[i]
dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)
centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
coords_lvl = centroid_lvl + delta_lvl
corr = bilinear_sampler(corr, coords_lvl)
corr = corr.view(batch, h1, w1, -1)
out_pyramid.append(corr)
out = torch.cat(out_pyramid, dim=-1)
return out.permute(0, 3, 1, 2).contiguous().float()
@staticmethod
def corr(fmap1, fmap2):
batch, dim, ht, wd = fmap1.shape
fmap1 = fmap1.view(batch, dim, ht*wd)
fmap2 = fmap2.view(batch, dim, ht*wd)
corr = torch.matmul(fmap1.transpose(1,2), fmap2)
corr = corr.view(batch, ht, wd, 1, ht, wd)
return corr / torch.sqrt(torch.tensor(dim).float())
class AlternateCorrBlock:
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
self.num_levels = num_levels
self.radius = radius
self.pyramid = [(fmap1, fmap2)]
for i in range(self.num_levels):
fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
self.pyramid.append((fmap1, fmap2))
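# note: __call__ below always uses the level-0 fmap1 for lookups; only fmap2 (and the
# lookup coordinates) change with the pyramid level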
def __call__(self, coords):
coords = coords.permute(0, 2, 3, 1)
B, H, W, _ = coords.shape
dim = self.pyramid[0][0].shape[1]
corr_list = []
for i in range(self.num_levels):
r = self.radius
fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
corr_list.append(corr.squeeze(1))
corr = torch.stack(corr_list, dim=1)
corr = corr.reshape(B, -1, H, W)
return corr / torch.sqrt(torch.tensor(dim).float())
# Data loading based on https://github.com/NVIDIA/flownet2-pytorch
import numpy as np
import torch
import torch.utils.data as data
import torch.nn.functional as F
import os
import math
import random
from glob import glob
import os.path as osp
from utils import frame_utils
from utils.augmentor import FlowAugmentor, SparseFlowAugmentor
class FlowDataset(data.Dataset):
def __init__(self, aug_params=None, sparse=False):
self.augmentor = None
self.sparse = sparse
if aug_params is not None:
if sparse:
self.augmentor = SparseFlowAugmentor(**aug_params)
else:
self.augmentor = FlowAugmentor(**aug_params)
self.is_test = False
self.init_seed = False
self.flow_list = []
self.image_list = []
self.extra_info = []
def __getitem__(self, index):
if self.is_test:
img1 = frame_utils.read_gen(self.image_list[index][0])
img2 = frame_utils.read_gen(self.image_list[index][1])
img1 = np.array(img1).astype(np.uint8)[..., :3]
img2 = np.array(img2).astype(np.uint8)[..., :3]
img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
return img1, img2, self.extra_info[index]
if not self.init_seed:
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
torch.manual_seed(worker_info.id)
np.random.seed(worker_info.id)
random.seed(worker_info.id)
self.init_seed = True
index = index % len(self.image_list)
valid = None
if self.sparse:
flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
else:
flow = frame_utils.read_gen(self.flow_list[index])
img1 = frame_utils.read_gen(self.image_list[index][0])
img2 = frame_utils.read_gen(self.image_list[index][1])
flow = np.array(flow).astype(np.float32)
img1 = np.array(img1).astype(np.uint8)
img2 = np.array(img2).astype(np.uint8)
# grayscale images
if len(img1.shape) == 2:
img1 = np.tile(img1[...,None], (1, 1, 3))
img2 = np.tile(img2[...,None], (1, 1, 3))
else:
img1 = img1[..., :3]
img2 = img2[..., :3]
if self.augmentor is not None:
if self.sparse:
img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
else:
img1, img2, flow = self.augmentor(img1, img2, flow)
img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
flow = torch.from_numpy(flow).permute(2, 0, 1).float()
if valid is not None:
valid = torch.from_numpy(valid)
else:
valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)
return img1, img2, flow, valid.float()
def __rmul__(self, v):
self.flow_list = v * self.flow_list
self.image_list = v * self.image_list
return self
def __len__(self):
return len(self.image_list)
class MpiSintel(FlowDataset):
def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'):
super(MpiSintel, self).__init__(aug_params)
flow_root = osp.join(root, split, 'flow')
image_root = osp.join(root, split, dstype)
if split == 'test':
self.is_test = True
for scene in os.listdir(image_root):
image_list = sorted(glob(osp.join(image_root, scene, '*.png')))
for i in range(len(image_list)-1):
self.image_list += [ [image_list[i], image_list[i+1]] ]
self.extra_info += [ (scene, i) ] # scene and frame_id
if split != 'test':
self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo')))
class FlyingChairs(FlowDataset):
def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'):
super(FlyingChairs, self).__init__(aug_params)
images = sorted(glob(osp.join(root, '*.ppm')))
flows = sorted(glob(osp.join(root, '*.flo')))
assert (len(images)//2 == len(flows))
split_list = np.loadtxt('chairs_split.txt', dtype=np.int32)
for i in range(len(flows)):
xid = split_list[i]
if (split=='training' and xid==1) or (split=='validation' and xid==2):
self.flow_list += [ flows[i] ]
self.image_list += [ [images[2*i], images[2*i+1]] ]
class FlyingThings3D(FlowDataset):
def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'):
super(FlyingThings3D, self).__init__(aug_params)
for cam in ['left']:
for direction in ['into_future', 'into_past']:
image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*')))
image_dirs = sorted([osp.join(f, cam) for f in image_dirs])
flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*')))
flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs])
for idir, fdir in zip(image_dirs, flow_dirs):
images = sorted(glob(osp.join(idir, '*.png')) )
flows = sorted(glob(osp.join(fdir, '*.pfm')) )
for i in range(len(flows)-1):
if direction == 'into_future':
self.image_list += [ [images[i], images[i+1]] ]
self.flow_list += [ flows[i] ]
elif direction == 'into_past':
self.image_list += [ [images[i+1], images[i]] ]
self.flow_list += [ flows[i+1] ]
class KITTI(FlowDataset):
def __init__(self, aug_params=None, split='training', root='datasets/KITTI'):
super(KITTI, self).__init__(aug_params, sparse=True)
if split == 'testing':
self.is_test = True
root = osp.join(root, split)
images1 = sorted(glob(osp.join(root, 'image_2/*_10.png')))
images2 = sorted(glob(osp.join(root, 'image_2/*_11.png')))
for img1, img2 in zip(images1, images2):
frame_id = img1.split('/')[-1]
self.extra_info += [ [frame_id] ]
self.image_list += [ [img1, img2] ]
if split == 'training':
self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png')))
class HD1K(FlowDataset):
def __init__(self, aug_params=None, root='datasets/HD1k'):
super(HD1K, self).__init__(aug_params, sparse=True)
seq_ix = 0
while 1:
flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix)))
images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix)))
if len(flows) == 0:
break
for i in range(len(flows)-1):
self.flow_list += [flows[i]]
self.image_list += [ [images[i], images[i+1]] ]
seq_ix += 1
def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
""" Create the data loader for the corresponding trainign set """
if args.stage == 'chairs':
aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
train_dataset = FlyingChairs(aug_params, split='training')
elif args.stage == 'things':
aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')
final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')
train_dataset = clean_dataset + final_dataset
elif args.stage == 'sintel':
aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
sintel_final = MpiSintel(aug_params, split='training', dstype='final')
if TRAIN_DS == 'C+T+K+S+H':
kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things
elif TRAIN_DS == 'C+T+K/S':
train_dataset = 100*sintel_clean + 100*sintel_final + things
elif args.stage == 'kitti':
aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
train_dataset = KITTI(aug_params, split='training')
train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
pin_memory=False, shuffle=True, num_workers=4, drop_last=True)
print('Training with %d image pairs' % len(train_dataset))
return train_loader
import torch
import torch.nn as nn
import torch.nn.functional as F
class ResidualBlock(nn.Module):
def __init__(self, in_planes, planes, norm_fn='group', stride=1):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
self.relu = nn.ReLU(inplace=True)
num_groups = planes // 8
if norm_fn == 'group':
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
if not stride == 1:
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
elif norm_fn == 'batch':
self.norm1 = nn.BatchNorm2d(planes)
self.norm2 = nn.BatchNorm2d(planes)
if not stride == 1:
self.norm3 = nn.BatchNorm2d(planes)
elif norm_fn == 'instance':
self.norm1 = nn.InstanceNorm2d(planes)
self.norm2 = nn.InstanceNorm2d(planes)
if not stride == 1:
self.norm3 = nn.InstanceNorm2d(planes)
elif norm_fn == 'none':
self.norm1 = nn.Sequential()
self.norm2 = nn.Sequential()
if not stride == 1:
self.norm3 = nn.Sequential()
if stride == 1:
self.downsample = None
else:
self.downsample = nn.Sequential(
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
def forward(self, x):
y = x
y = self.relu(self.norm1(self.conv1(y)))
y = self.relu(self.norm2(self.conv2(y)))
if self.downsample is not None:
x = self.downsample(x)
return self.relu(x+y)
class BottleneckBlock(nn.Module):
def __init__(self, in_planes, planes, norm_fn='group', stride=1):
super(BottleneckBlock, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
self.relu = nn.ReLU(inplace=True)
num_groups = planes // 8
if norm_fn == 'group':
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
if not stride == 1:
self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
elif norm_fn == 'batch':
self.norm1 = nn.BatchNorm2d(planes//4)
self.norm2 = nn.BatchNorm2d(planes//4)
self.norm3 = nn.BatchNorm2d(planes)
if not stride == 1:
self.norm4 = nn.BatchNorm2d(planes)
elif norm_fn == 'instance':
self.norm1 = nn.InstanceNorm2d(planes//4)
self.norm2 = nn.InstanceNorm2d(planes//4)
self.norm3 = nn.InstanceNorm2d(planes)
if not stride == 1:
self.norm4 = nn.InstanceNorm2d(planes)
elif norm_fn == 'none':
self.norm1 = nn.Sequential()
self.norm2 = nn.Sequential()
self.norm3 = nn.Sequential()
if not stride == 1:
self.norm4 = nn.Sequential()
if stride == 1:
self.downsample = None
else:
self.downsample = nn.Sequential(
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)
def forward(self, x):
y = x
y = self.relu(self.norm1(self.conv1(y)))
y = self.relu(self.norm2(self.conv2(y)))
y = self.relu(self.norm3(self.conv3(y)))
if self.downsample is not None:
x = self.downsample(x)
return self.relu(x+y)
class BasicEncoder(nn.Module):
def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
super(BasicEncoder, self).__init__()
self.norm_fn = norm_fn
if self.norm_fn == 'group':
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
elif self.norm_fn == 'batch':
self.norm1 = nn.BatchNorm2d(64)
elif self.norm_fn == 'instance':
self.norm1 = nn.InstanceNorm2d(64)
elif self.norm_fn == 'none':
self.norm1 = nn.Sequential()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
self.relu1 = nn.ReLU(inplace=True)
self.in_planes = 64
self.layer1 = self._make_layer(64, stride=1)
self.layer2 = self._make_layer(96, stride=2)
self.layer3 = self._make_layer(128, stride=2)
# output convolution
self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
self.dropout = None
if dropout > 0:
self.dropout = nn.Dropout2d(p=dropout)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def _make_layer(self, dim, stride=1):
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
layers = (layer1, layer2)
self.in_planes = dim
return nn.Sequential(*layers)
def forward(self, x):
# if input is list, combine batch dimension
is_list = isinstance(x, tuple) or isinstance(x, list)
if is_list:
batch_dim = x[0].shape[0]
x = torch.cat(x, dim=0)
x = self.conv1(x)
x = self.norm1(x)
x = self.relu1(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.conv2(x)
if self.training and self.dropout is not None:
x = self.dropout(x)
if is_list:
x = torch.split(x, [batch_dim, batch_dim], dim=0)
return x
class SmallEncoder(nn.Module):
def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
super(SmallEncoder, self).__init__()
self.norm_fn = norm_fn
if self.norm_fn == 'group':
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)
elif self.norm_fn == 'batch':
self.norm1 = nn.BatchNorm2d(32)
elif self.norm_fn == 'instance':
self.norm1 = nn.InstanceNorm2d(32)
elif self.norm_fn == 'none':
self.norm1 = nn.Sequential()
self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
self.relu1 = nn.ReLU(inplace=True)
self.in_planes = 32
self.layer1 = self._make_layer(32, stride=1)
self.layer2 = self._make_layer(64, stride=2)
self.layer3 = self._make_layer(96, stride=2)
self.dropout = None
if dropout > 0:
self.dropout = nn.Dropout2d(p=dropout)
self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def _make_layer(self, dim, stride=1):
layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
layers = (layer1, layer2)
self.in_planes = dim
return nn.Sequential(*layers)
def forward(self, x):
# if input is list, combine batch dimension
is_list = isinstance(x, tuple) or isinstance(x, list)
if is_list:
batch_dim = x[0].shape[0]
x = torch.cat(x, dim=0)
x = self.conv1(x)
x = self.norm1(x)
x = self.relu1(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.conv2(x)
if self.training and self.dropout is not None:
x = self.dropout(x)
if is_list:
x = torch.split(x, [batch_dim, batch_dim], dim=0)
return x
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from update import BasicUpdateBlock, SmallUpdateBlock
from extractor import BasicEncoder, SmallEncoder
from corr import CorrBlock, AlternateCorrBlock
from utils.utils import bilinear_sampler, coords_grid, upflow8
try:
autocast = torch.cuda.amp.autocast
except:
# dummy autocast for PyTorch < 1.6
class autocast:
def __init__(self, enabled):
pass
def __enter__(self):
pass
def __exit__(self, *args):
pass
class RAFT(nn.Module):
def __init__(self, args):
super(RAFT, self).__init__()
self.args = args
if args.small:
self.hidden_dim = hdim = 96
self.context_dim = cdim = 64
args.corr_levels = 4
args.corr_radius = 3
else:
self.hidden_dim = hdim = 128
self.context_dim = cdim = 128
args.corr_levels = 4
args.corr_radius = 4
if 'dropout' not in self.args:
self.args.dropout = 0
if 'alternate_corr' not in self.args:
self.args.alternate_corr = False
# feature network, context network, and update block
if args.small:
self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
else:
self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
def freeze_bn(self):
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()
def initialize_flow(self, img):
""" Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
N, C, H, W = img.shape
coords0 = coords_grid(N, H//8, W//8, device=img.device)
coords1 = coords_grid(N, H//8, W//8, device=img.device)
# optical flow computed as difference: flow = coords1 - coords0
return coords0, coords1
def upsample_flow(self, flow, mask):
""" Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
N, _, H, W = flow.shape
mask = mask.view(N, 1, 9, 8, 8, H, W)
mask = torch.softmax(mask, dim=2)
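# mask holds 9 convex-combination weights (softmax over dim=2) for every pixel of each
# 8x8 upsampling neighborhood; the fine flow is a weighted sum of the 3x3 coarse neighbors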
up_flow = F.unfold(8 * flow, [3,3], padding=1)
up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
up_flow = torch.sum(mask * up_flow, dim=2)
up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
return up_flow.reshape(N, 2, 8*H, 8*W)
def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
""" Estimate optical flow between pair of frames """
image1 = 2 * (image1 / 255.0) - 1.0
image2 = 2 * (image2 / 255.0) - 1.0
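# input images are in [0, 255]; map them to [-1, 1]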
image1 = image1.contiguous()
image2 = image2.contiguous()
hdim = self.hidden_dim
cdim = self.context_dim
# run the feature network
with autocast(enabled=self.args.mixed_precision):
fmap1, fmap2 = self.fnet([image1, image2])
fmap1 = fmap1.float()
fmap2 = fmap2.float()
if self.args.alternate_corr:
corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
else:
corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
# run the context network
with autocast(enabled=self.args.mixed_precision):
cnet = self.cnet(image1)
net, inp = torch.split(cnet, [hdim, cdim], dim=1)
net = torch.tanh(net)
inp = torch.relu(inp)
coords0, coords1 = self.initialize_flow(image1)
if flow_init is not None:
coords1 = coords1 + flow_init
flow_predictions = []
for itr in range(iters):
coords1 = coords1.detach()
corr = corr_fn(coords1) # index correlation volume
flow = coords1 - coords0
with autocast(enabled=self.args.mixed_precision):
net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
# F(t+1) = F(t) + \Delta(t)
coords1 = coords1 + delta_flow
# upsample predictions
if up_mask is None:
flow_up = upflow8(coords1 - coords0)
else:
flow_up = self.upsample_flow(coords1 - coords0, up_mask)
flow_predictions.append(flow_up)
if test_mode:
return coords1 - coords0, flow_up
return flow_predictions
import torch
import torch.nn as nn
import torch.nn.functional as F
class FlowHead(nn.Module):
def __init__(self, input_dim=128, hidden_dim=256):
super(FlowHead, self).__init__()
self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
return self.conv2(self.relu(self.conv1(x)))
class ConvGRU(nn.Module):
def __init__(self, hidden_dim=128, input_dim=192+128):
super(ConvGRU, self).__init__()
self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
def forward(self, h, x):
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz(hx))
r = torch.sigmoid(self.convr(hx))
q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))
h = (1-z) * h + z * q
return h
class SepConvGRU(nn.Module):
def __init__(self, hidden_dim=128, input_dim=192+128):
super(SepConvGRU, self).__init__()
self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
def forward(self, h, x):
# horizontal
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz1(hx))
r = torch.sigmoid(self.convr1(hx))
q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
h = (1-z) * h + z * q
# vertical
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz2(hx))
r = torch.sigmoid(self.convr2(hx))
q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
h = (1-z) * h + z * q
return h
class SmallMotionEncoder(nn.Module):
def __init__(self, args):
super(SmallMotionEncoder, self).__init__()
cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
self.conv = nn.Conv2d(128, 80, 3, padding=1)
def forward(self, flow, corr):
cor = F.relu(self.convc1(corr))
flo = F.relu(self.convf1(flow))
flo = F.relu(self.convf2(flo))
cor_flo = torch.cat([cor, flo], dim=1)
out = F.relu(self.conv(cor_flo))
return torch.cat([out, flow], dim=1)
class BasicMotionEncoder(nn.Module):
def __init__(self, args):
super(BasicMotionEncoder, self).__init__()
cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)
def forward(self, flow, corr):
cor = F.relu(self.convc1(corr))
cor = F.relu(self.convc2(cor))
flo = F.relu(self.convf1(flow))
flo = F.relu(self.convf2(flo))
cor_flo = torch.cat([cor, flo], dim=1)
out = F.relu(self.conv(cor_flo))
return torch.cat([out, flow], dim=1)
class SmallUpdateBlock(nn.Module):
def __init__(self, args, hidden_dim=96):
super(SmallUpdateBlock, self).__init__()
self.encoder = SmallMotionEncoder(args)
self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
def forward(self, net, inp, corr, flow):
motion_features = self.encoder(flow, corr)
inp = torch.cat([inp, motion_features], dim=1)
net = self.gru(net, inp)
delta_flow = self.flow_head(net)
return net, None, delta_flow
class BasicUpdateBlock(nn.Module):
def __init__(self, args, hidden_dim=128, input_dim=128):
super(BasicUpdateBlock, self).__init__()
self.args = args
self.encoder = BasicMotionEncoder(args)
self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
self.mask = nn.Sequential(
nn.Conv2d(128, 256, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 64*9, 1, padding=0))
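# 64*9 channels = (8x8 upsampling positions) x (9 coarse neighbors); these weights are
# consumed by RAFT.upsample_flow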
def forward(self, net, inp, corr, flow, upsample=True):
motion_features = self.encoder(flow, corr)
inp = torch.cat([inp, motion_features], dim=1)
net = self.gru(net, inp)
delta_flow = self.flow_head(net)
# scale mask to balance gradients
mask = .25 * self.mask(net)
return net, mask, delta_flow