Unverified Commit 694949ed authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Open Fix .flo file decoding (#4870)

parent 6d9a42c3
......@@ -1914,11 +1914,13 @@ class SintelTestCase(datasets_utils.ImageDatasetTestCase):
def test_flow(self):
# Make sure flow exists for train split, and make sure there are as many flow values as (pairs of) images
h, w = self.FLOW_H, self.FLOW_W
expected_flow = np.arange(2 * h * w).reshape(h, w, 2).transpose(2, 0, 1)
with self.create_dataset(split="train") as (dataset, _):
assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list)
for _, _, flow in dataset:
assert flow.shape == (2, self.FLOW_H, self.FLOW_W)
np.testing.assert_allclose(flow, np.arange(flow.size).reshape(flow.shape))
assert flow.shape == (2, h, w)
np.testing.assert_allclose(flow, expected_flow)
# Make sure flow is always None for test split
with self.create_dataset(split="test") as (dataset, _):
......@@ -2041,11 +2043,14 @@ class FlyingChairsTestCase(datasets_utils.ImageDatasetTestCase):
def test_flow(self, config):
# Make sure flow always exists, and make sure there are as many flow values as (pairs of) images
# Also make sure the flow is properly decoded
h, w = self.FLOW_H, self.FLOW_W
expected_flow = np.arange(2 * h * w).reshape(h, w, 2).transpose(2, 0, 1)
with self.create_dataset(config=config) as (dataset, _):
assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list)
for _, _, flow in dataset:
assert flow.shape == (2, self.FLOW_H, self.FLOW_W)
np.testing.assert_allclose(flow, np.arange(flow.size).reshape(flow.shape))
assert flow.shape == (2, h, w)
np.testing.assert_allclose(flow, expected_flow)
class FlyingThings3DTestCase(datasets_utils.ImageDatasetTestCase):
......@@ -2095,11 +2100,16 @@ class FlyingThings3DTestCase(datasets_utils.ImageDatasetTestCase):
@datasets_utils.test_all_configs
def test_flow(self, config):
h, w = self.FLOW_H, self.FLOW_W
expected_flow = np.arange(3 * h * w).reshape(h, w, 3).transpose(2, 0, 1)
expected_flow = np.flip(expected_flow, axis=1)
expected_flow = expected_flow[:2, :, :]
with self.create_dataset(config=config) as (dataset, _):
assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list)
for _, _, flow in dataset:
assert flow.shape == (2, self.FLOW_H, self.FLOW_W)
# We don't check the values because the reshaping and flipping makes it hard to figure out
np.testing.assert_allclose(flow, expected_flow)
def test_bad_input(self):
with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"):
......
......@@ -376,7 +376,7 @@ def _read_flo(file_name):
w = int(np.fromfile(f, "<i4", count=1))
h = int(np.fromfile(f, "<i4", count=1))
data = np.fromfile(f, "<f4", count=2 * w * h)
return data.reshape(2, h, w)
return data.reshape(h, w, 2).transpose(2, 0, 1)
def _read_16bits_png_with_flow_and_valid_mask(file_name):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment