Commit b1809ef9 authored by 王孝诚's avatar 王孝诚 Committed by huteng.ht
Browse files

fix: directIO segmentation fault

* feat: add test case for fall back
* fix: fix lint
* fix: fix lint
* feat: add ret check for file operation
* fix: fix VeTurboIO to veTurboIO
parent e5edc542
...@@ -184,3 +184,11 @@ class TestLoad(TestCase): ...@@ -184,3 +184,11 @@ class TestLoad(TestCase):
self.cuda_tensors_0, self.pt_filepath_enc_h, "cuda:0", use_cipher=True, enable_fast_mode=False self.cuda_tensors_0, self.pt_filepath_enc_h, "cuda:0", use_cipher=True, enable_fast_mode=False
) )
del os.environ["VETURBOIO_CIPHER_HEADER"] del os.environ["VETURBOIO_CIPHER_HEADER"]
def test_load_directIO_fall_back(self):
with tempfile.NamedTemporaryFile(dir="/dev/shm") as tmpFile:
veturboio.save_file(self.tensors_0, tmpFile.file.name)
tmpFile.flush()
loaded_tensors = veturboio.load(tmpFile.name, map_location="cpu", use_direct_io=True)
for key in self.tensors_0.keys():
self.assertTrue(torch.allclose(self.tensors_0[key], loaded_tensors[key]))
...@@ -14,15 +14,20 @@ ...@@ -14,15 +14,20 @@
* limitations under the License. * limitations under the License.
*/ */
#include "include/load_utils.h" #include "include/load_utils.h"
#include "include/logging.h"
#include "include/cipher.h" #include "include/cipher.h"
#include "include/fastcrypto.h" #include "include/fastcrypto.h"
#include <errno.h>
void read_file_thread_fread(int thread_id, string file_path, char *addr, char *dev_mem, size_t block_size, void read_file_thread_fread(int thread_id, string file_path, char *addr, char *dev_mem, size_t block_size,
size_t total_size, size_t global_offset, bool use_direct_io, CipherInfo cipher_info) size_t total_size, size_t global_offset, bool use_direct_io, CipherInfo cipher_info)
{ {
size_t offset = thread_id * block_size; size_t offset = thread_id * block_size;
size_t read_size = block_size; size_t read_size = block_size;
int fd; int fd = -1;
int ret = 0;
size_t size_read = 0;
if (offset + read_size >= total_size) if (offset + read_size >= total_size)
{ {
read_size = (total_size > offset) ? total_size - offset : 0; read_size = (total_size > offset) ? total_size - offset : 0;
...@@ -30,16 +35,48 @@ void read_file_thread_fread(int thread_id, string file_path, char *addr, char *d ...@@ -30,16 +35,48 @@ void read_file_thread_fread(int thread_id, string file_path, char *addr, char *d
// TODO: use_direct_io if sfcs file detected // TODO: use_direct_io if sfcs file detected
if (use_direct_io) if (use_direct_io)
{ {
fd = open(file_path.c_str(), O_RDONLY | O_DIRECT); if ((fd = open(file_path.c_str(), O_RDONLY | O_DIRECT)) < 0)
{
if (errno == EINVAL)
{
logWarn("open file using directIO failed, fall back to bufferIO", file_path.c_str(),
std::strerror(EINVAL));
}
else
{
logError("open file using directIO failed", file_path.c_str(), std::strerror(errno));
throw std::runtime_error("veTurboIO Exception: can't apply open operation");
}
}
} }
else if (fd == -1)
{ {
fd = open(file_path.c_str(), O_RDONLY); if ((fd = open(file_path.c_str(), O_RDONLY)) < 0)
{
logError("open file using bufferIO failed", file_path.c_str(), std::strerror(errno));
throw std::runtime_error("veTurboIO Exception: can't apply open operation");
}
} }
FILE *fp = fdopen(fd, "rb"); FILE *fp = fdopen(fd, "rb");
fseek(fp, global_offset + offset, SEEK_SET); if (fp == NULL)
fread(addr + offset, 1, read_size, fp); {
fclose(fp); logError("can't apply fdopen to file", file_path.c_str(), std::strerror(errno));
throw std::runtime_error("veTurboIO Exception: can't apply fdopen operation");
}
if ((ret = fseek(fp, global_offset + offset, SEEK_SET)) < 0)
{
logError("can't apply fseek to file", file_path.c_str(), std::strerror(errno));
throw std::runtime_error("veTurboIO Exception: can't apply fseek operation");
}
if ((size_read = fread(addr + offset, 1, read_size, fp)) == 0)
{
logWarn("read file with 0 bytes returned", file_path.c_str(), offset, read_size);
}
if ((ret = fclose(fp)) < 0)
{
logError("can't apply fclose to file", file_path.c_str(), std::strerror(errno));
throw std::runtime_error("veTurboIO Exception: can't apply fclose operation");
}
// Decrypt if use_cipher is true // Decrypt if use_cipher is true
if (cipher_info.use_cipher) if (cipher_info.use_cipher)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment