"vscode:/vscode.git/clone" did not exist on "12590fdccebb34f39fb85b7dae29b80fade2b6b0"
Unverified Commit 6e639d3e authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Add write_file (#2765)

* Add write_file

* Fix lint
parent 635406c3
......@@ -9,7 +9,7 @@ import torchvision
from PIL import Image
from torchvision.io.image import (
decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file,
encode_png, write_png)
encode_png, write_png, write_file)
import numpy as np
from common_utils import get_tmp_dir
......@@ -225,6 +225,18 @@ class ImageTester(unittest.TestCase):
RuntimeError, "No such file or directory: 'tst'"):
read_file('tst')
def test_write_file(self):
with get_tmp_dir() as d:
fname, content = 'test1.bin', b'TorchVision\211\n'
fpath = os.path.join(d, fname)
content_tensor = torch.tensor(list(content), dtype=torch.uint8)
write_file(fpath, content_tensor)
with open(fpath, 'rb') as f:
saved_content = f.read()
self.assertEqual(content, saved_content)
os.unlink(fpath)
if __name__ == '__main__':
unittest.main()
......@@ -20,4 +20,5 @@ static auto registry = torch::RegisterOperators()
.op("image::encode_jpeg", &encodeJPEG)
.op("image::write_jpeg", &writeJPEG)
.op("image::read_file", &read_file)
.op("image::write_file", &write_file)
.op("image::decode_image", &decode_image);
......@@ -18,3 +18,23 @@ torch::Tensor read_file(std::string filename) {
return data;
}
void write_file(std::string filename, torch::Tensor& data) {
// Check that the input tensor is on CPU
TORCH_CHECK(data.device() == torch::kCPU, "Input tensor should be on CPU");
// Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Input tensor dtype should be uint8");
// Check that the input tensor is 3-dimensional
TORCH_CHECK(data.dim() == 1, "Input data should be a 1-dimensional tensor");
auto fileBytes = data.data_ptr<uint8_t>();
auto fileCStr = filename.c_str();
FILE* outfile = fopen(fileCStr, "wb");
TORCH_CHECK(outfile != NULL, "Error opening output file");
fwrite(fileBytes, sizeof(uint8_t), data.numel(), outfile);
fclose(outfile);
}
......@@ -5,3 +5,5 @@
#include <torch/torch.h>
C10_EXPORT torch::Tensor read_file(std::string filename);
C10_EXPORT void write_file(std::string filename, torch::Tensor& data);
......@@ -38,6 +38,18 @@ def read_file(path: str) -> torch.Tensor:
return data
def write_file(filename: str, data: torch.Tensor) -> None:
"""
Writes the contents of a uint8 tensor with one dimension to a
file.
Arguments:
filename (str): the path to the file to be written
data (Tensor): the contents to be written to the output file
"""
torch.ops.image.write_file(filename, data)
def decode_png(input: torch.Tensor) -> torch.Tensor:
"""
Decodes a PNG image into a 3 dimensional RGB Tensor.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment