Unverified Commit 77940b81 authored by Ponku's avatar Ponku Committed by GitHub
Browse files

Moved pfm file reading into dataset utils (#6270)

* Moved pfm file reading into dataset utils

* Made _read_pfm private. Fixed doc format issues.
parent 9effc4cd
import itertools import itertools
import os import os
import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from glob import glob from glob import glob
from pathlib import Path from pathlib import Path
...@@ -10,7 +9,7 @@ import torch ...@@ -10,7 +9,7 @@ import torch
from PIL import Image from PIL import Image
from ..io.image import _read_png_16 from ..io.image import _read_png_16
from .utils import verify_str_arg from .utils import verify_str_arg, _read_pfm
from .vision import VisionDataset from .vision import VisionDataset
...@@ -472,31 +471,3 @@ def _read_16bits_png_with_flow_and_valid_mask(file_name): ...@@ -472,31 +471,3 @@ def _read_16bits_png_with_flow_and_valid_mask(file_name):
# For consistency with other datasets, we convert to numpy # For consistency with other datasets, we convert to numpy
return flow.numpy(), valid_flow_mask.numpy() return flow.numpy(), valid_flow_mask.numpy()
def _read_pfm(file_name):
"""Read flow in .pfm format"""
with open(file_name, "rb") as f:
header = f.readline().rstrip()
if header != b"PF":
raise ValueError("Invalid PFM file")
dim_match = re.match(rb"^(\d+)\s(\d+)\s$", f.readline())
if not dim_match:
raise Exception("Malformed PFM header.")
w, h = (int(dim) for dim in dim_match.groups())
scale = float(f.readline().rstrip())
if scale < 0: # little-endian
endian = "<"
scale = -scale
else:
endian = ">" # big-endian
data = np.fromfile(f, dtype=endian + "f")
data = data.reshape(h, w, 3).transpose(2, 0, 1)
data = np.flip(data, axis=1) # flip on h dimension
data = data[:2, :, :]
return data.astype(np.float32)
...@@ -18,6 +18,7 @@ import zipfile ...@@ -18,6 +18,7 @@ import zipfile
from typing import Any, Callable, List, Iterable, Optional, TypeVar, Dict, IO, Tuple, Iterator from typing import Any, Callable, List, Iterable, Optional, TypeVar, Dict, IO, Tuple, Iterator
from urllib.parse import urlparse from urllib.parse import urlparse
import numpy as np
import requests import requests
import torch import torch
from torch.utils.model_zoo import tqdm from torch.utils.model_zoo import tqdm
...@@ -483,3 +484,39 @@ def verify_str_arg( ...@@ -483,3 +484,39 @@ def verify_str_arg(
raise ValueError(msg) raise ValueError(msg)
return value return value
def _read_pfm(file_name: str, slice_channels: int = 2) -> np.ndarray:
"""Read file in .pfm format. Might contain either 1 or 3 channels of data.
Args:
file_name (str): Path to the file.
slice_channels (int): Number of channels to slice out of the file.
Useful for reading different data formats stored in .pfm files: Optical Flows, Stereo Disparity Maps, etc.
"""
with open(file_name, "rb") as f:
header = f.readline().rstrip()
if header not in [b"PF", b"Pf"]:
raise ValueError("Invalid PFM file")
dim_match = re.match(rb"^(\d+)\s(\d+)\s$", f.readline())
if not dim_match:
raise Exception("Malformed PFM header.")
w, h = (int(dim) for dim in dim_match.groups())
scale = float(f.readline().rstrip())
if scale < 0: # little-endian
endian = "<"
scale = -scale
else:
endian = ">" # big-endian
data = np.fromfile(f, dtype=endian + "f")
pfm_channels = 3 if header == b"PF" else 1
data = data.reshape(h, w, pfm_channels).transpose(2, 0, 1)
data = np.flip(data, axis=1) # flip on h dimension
data = data[:slice_channels, :, :]
return data.astype(np.float32)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment