Commit b1809ef9 authored by 王孝诚's avatar 王孝诚 Committed by huteng.ht
Browse files

fix: directIO segmentation fault

* feat: add test case for fall back
* fix: fix lint
* fix: fix lint
* feat: add ret check for file operation
* fix: fix VeTurboIO to veTurboIO
parent e5edc542
......@@ -184,3 +184,11 @@ class TestLoad(TestCase):
self.cuda_tensors_0, self.pt_filepath_enc_h, "cuda:0", use_cipher=True, enable_fast_mode=False
)
del os.environ["VETURBOIO_CIPHER_HEADER"]
def test_load_directIO_fall_back(self):
with tempfile.NamedTemporaryFile(dir="/dev/shm") as tmpFile:
veturboio.save_file(self.tensors_0, tmpFile.file.name)
tmpFile.flush()
loaded_tensors = veturboio.load(tmpFile.name, map_location="cpu", use_direct_io=True)
for key in self.tensors_0.keys():
self.assertTrue(torch.allclose(self.tensors_0[key], loaded_tensors[key]))
......@@ -14,15 +14,20 @@
* limitations under the License.
*/
#include "include/load_utils.h"
#include "include/logging.h"
#include "include/cipher.h"
#include "include/fastcrypto.h"
#include <errno.h>
void read_file_thread_fread(int thread_id, string file_path, char *addr, char *dev_mem, size_t block_size,
size_t total_size, size_t global_offset, bool use_direct_io, CipherInfo cipher_info)
{
size_t offset = thread_id * block_size;
size_t read_size = block_size;
int fd;
int fd = -1;
int ret = 0;
size_t size_read = 0;
if (offset + read_size >= total_size)
{
read_size = (total_size > offset) ? total_size - offset : 0;
......@@ -30,16 +35,48 @@ void read_file_thread_fread(int thread_id, string file_path, char *addr, char *d
// TODO: use_direct_io if sfcs file detected
if (use_direct_io)
{
fd = open(file_path.c_str(), O_RDONLY | O_DIRECT);
if ((fd = open(file_path.c_str(), O_RDONLY | O_DIRECT)) < 0)
{
if (errno == EINVAL)
{
logWarn("open file using directIO failed, fall back to bufferIO", file_path.c_str(),
std::strerror(EINVAL));
}
else
{
logError("open file using directIO failed", file_path.c_str(), std::strerror(errno));
throw std::runtime_error("veTurboIO Exception: can't apply open operation");
}
}
}
else
if (fd == -1)
{
fd = open(file_path.c_str(), O_RDONLY);
if ((fd = open(file_path.c_str(), O_RDONLY)) < 0)
{
logError("open file using bufferIO failed", file_path.c_str(), std::strerror(errno));
throw std::runtime_error("veTurboIO Exception: can't apply open operation");
}
}
FILE *fp = fdopen(fd, "rb");
fseek(fp, global_offset + offset, SEEK_SET);
fread(addr + offset, 1, read_size, fp);
fclose(fp);
if (fp == NULL)
{
logError("can't apply fdopen to file", file_path.c_str(), std::strerror(errno));
throw std::runtime_error("veTurboIO Exception: can't apply fdopen operation");
}
if ((ret = fseek(fp, global_offset + offset, SEEK_SET)) < 0)
{
logError("can't apply fseek to file", file_path.c_str(), std::strerror(errno));
throw std::runtime_error("veTurboIO Exception: can't apply fseek operation");
}
if ((size_read = fread(addr + offset, 1, read_size, fp)) == 0)
{
logWarn("read file with 0 bytes returned", file_path.c_str(), offset, read_size);
}
if ((ret = fclose(fp)) < 0)
{
logError("can't apply fclose to file", file_path.c_str(), std::strerror(errno));
throw std::runtime_error("veTurboIO Exception: can't apply fclose operation");
}
// Decrypt if use_cipher is true
if (cipher_info.use_cipher)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment