Unverified Commit 77940b81 authored by Ponku's avatar Ponku Committed by GitHub
Browse files

Moved pfm file reading into dataset utils (#6270)

* Moved pfm file reading into dataset utils

* Made _read_pfm private. Fixed doc format issues.
parent 9effc4cd
import itertools
import os
import re
from abc import ABC, abstractmethod
from glob import glob
from pathlib import Path
......@@ -10,7 +9,7 @@ import torch
from PIL import Image
from ..io.image import _read_png_16
from .utils import verify_str_arg
from .utils import verify_str_arg, _read_pfm
from .vision import VisionDataset
......@@ -472,31 +471,3 @@ def _read_16bits_png_with_flow_and_valid_mask(file_name):
# For consistency with other datasets, we convert to numpy
return flow.numpy(), valid_flow_mask.numpy()
def _read_pfm(file_name):
"""Read flow in .pfm format"""
with open(file_name, "rb") as f:
header = f.readline().rstrip()
if header != b"PF":
raise ValueError("Invalid PFM file")
dim_match = re.match(rb"^(\d+)\s(\d+)\s$", f.readline())
if not dim_match:
raise Exception("Malformed PFM header.")
w, h = (int(dim) for dim in dim_match.groups())
scale = float(f.readline().rstrip())
if scale < 0: # little-endian
endian = "<"
scale = -scale
else:
endian = ">" # big-endian
data = np.fromfile(f, dtype=endian + "f")
data = data.reshape(h, w, 3).transpose(2, 0, 1)
data = np.flip(data, axis=1) # flip on h dimension
data = data[:2, :, :]
return data.astype(np.float32)
......@@ -18,6 +18,7 @@ import zipfile
from typing import Any, Callable, List, Iterable, Optional, TypeVar, Dict, IO, Tuple, Iterator
from urllib.parse import urlparse
import numpy as np
import requests
import torch
from torch.utils.model_zoo import tqdm
......@@ -483,3 +484,39 @@ def verify_str_arg(
raise ValueError(msg)
return value
def _read_pfm(file_name: str, slice_channels: int = 2) -> np.ndarray:
"""Read file in .pfm format. Might contain either 1 or 3 channels of data.
Args:
file_name (str): Path to the file.
slice_channels (int): Number of channels to slice out of the file.
Useful for reading different data formats stored in .pfm files: Optical Flows, Stereo Disparity Maps, etc.
"""
with open(file_name, "rb") as f:
header = f.readline().rstrip()
if header not in [b"PF", b"Pf"]:
raise ValueError("Invalid PFM file")
dim_match = re.match(rb"^(\d+)\s(\d+)\s$", f.readline())
if not dim_match:
raise Exception("Malformed PFM header.")
w, h = (int(dim) for dim in dim_match.groups())
scale = float(f.readline().rstrip())
if scale < 0: # little-endian
endian = "<"
scale = -scale
else:
endian = ">" # big-endian
data = np.fromfile(f, dtype=endian + "f")
pfm_channels = 3 if header == b"PF" else 1
data = data.reshape(h, w, pfm_channels).transpose(2, 0, 1)
data = np.flip(data, axis=1) # flip on h dimension
data = data[:slice_channels, :, :]
return data.astype(np.float32)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment