Unverified Commit a9e4cea0 authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Add read_file (#2728)

* Add read_file

* Add test for non-existent file

* Fix lint

* Lint v2

* Try fix windows

* Try fix Windows v2

* Lint

* Windows v3

* Missed one change in the adapted function

* Try again on Windows

* One more try on Windows

* Give up on tempfile for now

* Are extensions what's missing on Windows?

* Investigating if the issue is on our side

* Try deleting tensor which could hold into the file

* Put back temporary folder
parent 217e26fc
......@@ -8,10 +8,13 @@ import torch
import torchvision
from PIL import Image
from torchvision.io.image import (
read_png, decode_png, read_jpeg, decode_jpeg, encode_jpeg, write_jpeg, decode_image, _read_file,
read_png, decode_png, read_jpeg, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file,
encode_png, write_png)
import numpy as np
from common_utils import get_tmp_dir
IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
IMAGE_DIR = os.path.join(IMAGE_ROOT, "fakedata", "imagefolder")
DAMAGED_JPEG = os.path.join(IMAGE_ROOT, 'damaged_jpeg')
......@@ -206,15 +209,35 @@ class ImageTester(unittest.TestCase):
for img_path in get_images(IMAGE_ROOT, ".jpg"):
img_pil = torch.load(img_path.replace('jpg', 'pth'))
img_pil = img_pil.permute(2, 0, 1)
img_ljpeg = decode_image(_read_file(img_path))
img_ljpeg = decode_image(read_file(img_path))
self.assertTrue(img_ljpeg.equal(img_pil))
for img_path in get_images(IMAGE_DIR, ".png"):
img_pil = torch.from_numpy(np.array(Image.open(img_path)))
img_pil = img_pil.permute(2, 0, 1)
img_lpng = decode_image(_read_file(img_path))
img_lpng = decode_image(read_file(img_path))
self.assertTrue(img_lpng.equal(img_pil))
def test_read_file(self):
with get_tmp_dir() as d:
fname, content = 'test1.bin', b'TorchVision\211\n'
fpath = os.path.join(d, fname)
with open(fpath, 'wb') as f:
f.write(content)
data = read_file(fpath)
expected = torch.tensor(list(content), dtype=torch.uint8)
self.assertTrue(data.equal(expected))
# Windows holds into the file until the tensor is alive
# so need to del the tensor before deleting the file see
# https://github.com/pytorch/vision/issues/2743#issuecomment-703817293
del data
os.unlink(fpath)
with self.assertRaisesRegex(
RuntimeError, "No such file or directory: 'tst'"):
read_file('tst')
if __name__ == '__main__':
unittest.main()
......@@ -19,4 +19,5 @@ static auto registry = torch::RegisterOperators()
.op("image::decode_jpeg", &decodeJPEG)
.op("image::encode_jpeg", &encodeJPEG)
.op("image::write_jpeg", &writeJPEG)
.op("image::read_file", &read_file)
.op("image::decode_image", &decode_image);
......@@ -4,6 +4,7 @@
#include <torch/script.h>
#include <torch/torch.h>
#include "read_image_cpu.h"
#include "read_write_file_cpu.h"
#include "readjpeg_cpu.h"
#include "readpng_cpu.h"
#include "writejpeg_cpu.h"
......
#include "read_write_file_cpu.h"
torch::Tensor read_file(std::string filename) {
// CHECK if this only works on Windows for files smaller than 2GB
// https://stackoverflow.com/questions/5840148/how-can-i-get-a-files-size-in-c
struct stat stat_buf;
int rc = stat(filename.c_str(), &stat_buf);
// errno is a variable defined in errno.h
TORCH_CHECK(
rc == 0, "[Errno ", errno, "] ", strerror(errno), ": '", filename, "'");
int64_t size = stat_buf.st_size;
TORCH_CHECK(size > 0, "Expected a non empty file");
auto data =
torch::from_file(filename, /*shared=*/false, /*size=*/size, torch::kU8);
return data;
}
#pragma once
#include <errno.h>
#include <sys/stat.h>
#include <torch/torch.h>
C10_EXPORT torch::Tensor read_file(std::string filename);
......@@ -23,14 +23,18 @@ except (ImportError, OSError):
pass
def _read_file(path: str) -> torch.Tensor:
if not os.path.isfile(path):
raise ValueError("Expected a valid file path.")
size = os.path.getsize(path)
if size == 0:
raise ValueError("Expected a non empty file.")
data = torch.from_file(path, dtype=torch.uint8, size=size)
def read_file(path: str) -> torch.Tensor:
"""
Reads and outputs the bytes contents of a file as a uint8 Tensor
with one dimension.
Arguments:
path (str): the path to the file to be read
Returns:
data (Tensor)
"""
data = torch.ops.image.read_file(path)
return data
......@@ -61,7 +65,7 @@ def read_png(path: str) -> torch.Tensor:
Returns:
output (Tensor[3, image_height, image_width])
"""
data = _read_file(path)
data = read_file(path)
return decode_png(data)
......@@ -119,7 +123,7 @@ def read_jpeg(path: str) -> torch.Tensor:
Returns:
output (Tensor[3, image_height, image_width])
"""
data = _read_file(path)
data = read_file(path)
return decode_jpeg(data)
......@@ -187,5 +191,5 @@ def read_image(path: str) -> torch.Tensor:
Returns:
output (Tensor[3, image_height, image_width])
"""
data = _read_file(path)
data = read_file(path)
return decode_image(data)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment