Unverified Commit 6e639d3e authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Add write_file (#2765)

* Add write_file

* Fix lint
parent 635406c3
...@@ -9,7 +9,7 @@ import torchvision ...@@ -9,7 +9,7 @@ import torchvision
from PIL import Image from PIL import Image
from torchvision.io.image import ( from torchvision.io.image import (
decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file, decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file,
encode_png, write_png) encode_png, write_png, write_file)
import numpy as np import numpy as np
from common_utils import get_tmp_dir from common_utils import get_tmp_dir
...@@ -225,6 +225,18 @@ class ImageTester(unittest.TestCase): ...@@ -225,6 +225,18 @@ class ImageTester(unittest.TestCase):
RuntimeError, "No such file or directory: 'tst'"): RuntimeError, "No such file or directory: 'tst'"):
read_file('tst') read_file('tst')
def test_write_file(self):
with get_tmp_dir() as d:
fname, content = 'test1.bin', b'TorchVision\211\n'
fpath = os.path.join(d, fname)
content_tensor = torch.tensor(list(content), dtype=torch.uint8)
write_file(fpath, content_tensor)
with open(fpath, 'rb') as f:
saved_content = f.read()
self.assertEqual(content, saved_content)
os.unlink(fpath)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -20,4 +20,5 @@ static auto registry = torch::RegisterOperators() ...@@ -20,4 +20,5 @@ static auto registry = torch::RegisterOperators()
.op("image::encode_jpeg", &encodeJPEG) .op("image::encode_jpeg", &encodeJPEG)
.op("image::write_jpeg", &writeJPEG) .op("image::write_jpeg", &writeJPEG)
.op("image::read_file", &read_file) .op("image::read_file", &read_file)
.op("image::write_file", &write_file)
.op("image::decode_image", &decode_image); .op("image::decode_image", &decode_image);
...@@ -18,3 +18,23 @@ torch::Tensor read_file(std::string filename) { ...@@ -18,3 +18,23 @@ torch::Tensor read_file(std::string filename) {
return data; return data;
} }
void write_file(std::string filename, torch::Tensor& data) {
// Check that the input tensor is on CPU
TORCH_CHECK(data.device() == torch::kCPU, "Input tensor should be on CPU");
// Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Input tensor dtype should be uint8");
// Check that the input tensor is 3-dimensional
TORCH_CHECK(data.dim() == 1, "Input data should be a 1-dimensional tensor");
auto fileBytes = data.data_ptr<uint8_t>();
auto fileCStr = filename.c_str();
FILE* outfile = fopen(fileCStr, "wb");
TORCH_CHECK(outfile != NULL, "Error opening output file");
fwrite(fileBytes, sizeof(uint8_t), data.numel(), outfile);
fclose(outfile);
}
...@@ -5,3 +5,5 @@ ...@@ -5,3 +5,5 @@
#include <torch/torch.h> #include <torch/torch.h>
C10_EXPORT torch::Tensor read_file(std::string filename); C10_EXPORT torch::Tensor read_file(std::string filename);
C10_EXPORT void write_file(std::string filename, torch::Tensor& data);
...@@ -38,6 +38,18 @@ def read_file(path: str) -> torch.Tensor: ...@@ -38,6 +38,18 @@ def read_file(path: str) -> torch.Tensor:
return data return data
def write_file(filename: str, data: torch.Tensor) -> None:
"""
Writes the contents of a uint8 tensor with one dimension to a
file.
Arguments:
filename (str): the path to the file to be written
data (Tensor): the contents to be written to the output file
"""
torch.ops.image.write_file(filename, data)
def decode_png(input: torch.Tensor) -> torch.Tensor: def decode_png(input: torch.Tensor) -> torch.Tensor:
""" """
Decodes a PNG image into a 3 dimensional RGB Tensor. Decodes a PNG image into a 3 dimensional RGB Tensor.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment