Unverified Commit f5c0bfa5 authored by peterjc123's avatar peterjc123 Committed by GitHub
Browse files

Make read_file and write_file accept unicode strings on Windows (#2949)

* Make read_file accept unicode strings on Windows

* More fixes

* Remove definitions from source files

* Move string definitions to header

* Add checks

* Fix comments

* Update macro

* Fix comments

* Fix lint

* include windows header

* Change func signature in header

* Use from_blob

* Fix fread calls

* Fix clang format

* Fix missing return

* Avoid copy
parent 5b61a5c8
...@@ -221,6 +221,18 @@ class ImageTester(unittest.TestCase): ...@@ -221,6 +221,18 @@ class ImageTester(unittest.TestCase):
RuntimeError, "No such file or directory: 'tst'"): RuntimeError, "No such file or directory: 'tst'"):
read_file('tst') read_file('tst')
def test_read_file_non_ascii(self):
    # read_file must accept a path containing non-ASCII characters and
    # return the file's raw bytes as a uint8 tensor.
    file_name = '日本語(Japanese).bin'
    payload = b'TorchVision\211\n'
    with get_tmp_dir() as tmp_dir:
        file_path = os.path.join(tmp_dir, file_name)
        # Create the fixture with Python's own I/O so only the read path
        # under test goes through read_file.
        with open(file_path, 'wb') as handle:
            handle.write(payload)

        loaded = read_file(file_path)
        reference = torch.tensor(list(payload), dtype=torch.uint8)
        self.assertTrue(loaded.equal(reference))
        os.unlink(file_path)
def test_write_file(self): def test_write_file(self):
with get_tmp_dir() as d: with get_tmp_dir() as d:
fname, content = 'test1.bin', b'TorchVision\211\n' fname, content = 'test1.bin', b'TorchVision\211\n'
...@@ -233,6 +245,18 @@ class ImageTester(unittest.TestCase): ...@@ -233,6 +245,18 @@ class ImageTester(unittest.TestCase):
self.assertEqual(content, saved_content) self.assertEqual(content, saved_content)
os.unlink(fpath) os.unlink(fpath)
def test_write_file_non_ascii(self):
    # write_file must accept a path containing non-ASCII characters and
    # write the tensor's bytes to disk unchanged.
    file_name = '日本語(Japanese).bin'
    payload = b'TorchVision\211\n'
    with get_tmp_dir() as tmp_dir:
        file_path = os.path.join(tmp_dir, file_name)
        payload_tensor = torch.tensor(list(payload), dtype=torch.uint8)
        write_file(file_path, payload_tensor)

        # Read back with Python's own I/O so only the write path under
        # test goes through write_file.
        with open(file_path, 'rb') as handle:
            written = handle.read()
        self.assertEqual(payload, written)
        os.unlink(file_path)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
#include "read_write_file_cpu.h" #include "read_write_file_cpu.h"
// According to
// https://docs.microsoft.com/en-us/cpp/c-runtime-library/reference/stat-functions?view=vs-2019,
// we should use _stat64 for 64-bit file size on Windows.
#ifdef _WIN32 #ifdef _WIN32
#define VISION_STAT _stat64 #define WIN32_LEAN_AND_MEAN
#else #include <Windows.h>
#define VISION_STAT stat
// Converts a UTF-8 encoded std::string into a wide string for use with the
// wide-character Windows file APIs (this file passes the result to _wstat64
// and _wfopen).
//
// Returns an empty wstring for empty input. Raises through TORCH_CHECK when
// the input length does not fit in the int that MultiByteToWideChar takes,
// or when either conversion call fails.
std::wstring utf8_decode(const std::string& str) {
  if (str.empty()) {
    return std::wstring();
  }

  // MultiByteToWideChar takes the byte count as an int. Verify the cast
  // round-trips so an oversized input is rejected instead of being silently
  // truncated or wrapped to a negative length.
  const int input_len = static_cast<int>(str.size());
  TORCH_CHECK(
      input_len > 0 &&
          static_cast<std::string::size_type>(input_len) == str.size(),
      "Error converting the content to Unicode");

  // First pass: a null output buffer with size 0 asks for the required
  // number of wide characters.
  const int size_needed =
      MultiByteToWideChar(CP_UTF8, 0, str.c_str(), input_len, nullptr, 0);
  TORCH_CHECK(size_needed > 0, "Error converting the content to Unicode");

  // Second pass: perform the conversion into a buffer of exactly that size.
  std::wstring wstrTo(size_needed, 0);
  const int written = MultiByteToWideChar(
      CP_UTF8, 0, str.c_str(), input_len, &wstrTo[0], size_needed);
  // The call returns 0 on failure; anything other than the size reported by
  // the first pass means the buffer does not hold the full conversion.
  TORCH_CHECK(
      written == size_needed, "Error converting the content to Unicode");
  return wstrTo;
}
#endif #endif
torch::Tensor read_file(std::string filename) { torch::Tensor read_file(const std::string& filename) {
struct VISION_STAT stat_buf; #ifdef _WIN32
int rc = VISION_STAT(filename.c_str(), &stat_buf); // According to
// https://docs.microsoft.com/en-us/cpp/c-runtime-library/reference/stat-functions?view=vs-2019,
// we should use struct __stat64 and _wstat64 for 64-bit file size on Windows.
struct __stat64 stat_buf;
auto fileW = utf8_decode(filename);
int rc = _wstat64(fileW.c_str(), &stat_buf);
#else
struct stat stat_buf;
int rc = stat(filename.c_str(), &stat_buf);
#endif
// errno is a variable defined in errno.h // errno is a variable defined in errno.h
TORCH_CHECK( TORCH_CHECK(
rc == 0, "[Errno ", errno, "] ", strerror(errno), ": '", filename, "'"); rc == 0, "[Errno ", errno, "] ", strerror(errno), ": '", filename, "'");
...@@ -21,9 +44,20 @@ torch::Tensor read_file(std::string filename) { ...@@ -21,9 +44,20 @@ torch::Tensor read_file(std::string filename) {
TORCH_CHECK(size > 0, "Expected a non empty file"); TORCH_CHECK(size > 0, "Expected a non empty file");
#ifdef _WIN32 #ifdef _WIN32
auto data = // TODO: Once torch::from_file handles UTF-8 paths correctly, we should move
torch::from_file(filename, /*shared=*/false, /*size=*/size, torch::kU8) // back to use the following implementation since it uses file mapping.
.clone(); // auto data =
// torch::from_file(filename, /*shared=*/false, /*size=*/size,
// torch::kU8).clone()
FILE* infile = _wfopen(fileW.c_str(), L"rb");
TORCH_CHECK(infile != nullptr, "Error opening input file");
auto data = torch::empty({size}, torch::kU8);
auto dataBytes = data.data_ptr<uint8_t>();
fread(dataBytes, sizeof(uint8_t), size, infile);
fclose(infile);
#else #else
auto data = auto data =
torch::from_file(filename, /*shared=*/false, /*size=*/size, torch::kU8); torch::from_file(filename, /*shared=*/false, /*size=*/size, torch::kU8);
...@@ -32,7 +66,7 @@ torch::Tensor read_file(std::string filename) { ...@@ -32,7 +66,7 @@ torch::Tensor read_file(std::string filename) {
return data; return data;
} }
void write_file(std::string filename, torch::Tensor& data) { void write_file(const std::string& filename, torch::Tensor& data) {
// Check that the input tensor is on CPU // Check that the input tensor is on CPU
TORCH_CHECK(data.device() == torch::kCPU, "Input tensor should be on CPU"); TORCH_CHECK(data.device() == torch::kCPU, "Input tensor should be on CPU");
...@@ -44,7 +78,12 @@ void write_file(std::string filename, torch::Tensor& data) { ...@@ -44,7 +78,12 @@ void write_file(std::string filename, torch::Tensor& data) {
auto fileBytes = data.data_ptr<uint8_t>(); auto fileBytes = data.data_ptr<uint8_t>();
auto fileCStr = filename.c_str(); auto fileCStr = filename.c_str();
#ifdef _WIN32
auto fileW = utf8_decode(filename);
FILE* outfile = _wfopen(fileW.c_str(), L"wb");
#else
FILE* outfile = fopen(fileCStr, "wb"); FILE* outfile = fopen(fileCStr, "wb");
#endif
TORCH_CHECK(outfile != nullptr, "Error opening output file"); TORCH_CHECK(outfile != nullptr, "Error opening output file");
......
...@@ -4,6 +4,6 @@ ...@@ -4,6 +4,6 @@
#include <sys/stat.h> #include <sys/stat.h>
#include <torch/torch.h> #include <torch/torch.h>
C10_EXPORT torch::Tensor read_file(std::string filename); C10_EXPORT torch::Tensor read_file(const std::string& filename);
C10_EXPORT void write_file(std::string filename, torch::Tensor& data); C10_EXPORT void write_file(const std::string& filename, torch::Tensor& data);
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment