Commit 6203d7ca authored by huteng.ht's avatar huteng.ht
Browse files

fix(load): set cuda device in each thread


Signed-off-by: default avatarhuteng.ht <huteng.ht@bytedance.com>
parent 8dc44dfb
......@@ -19,8 +19,8 @@
#include "common.h"
#include "cipher.h"
void read_file(string file_path, char *addr, char *dev_mem, int num_thread, size_t total_size, size_t global_offset,
bool use_sfcs_sdk, bool use_direct_io, CipherInfo cipher_info);
void read_file(string file_path, char *addr, int device_id, char *dev_mem, int num_thread, size_t total_size,
size_t global_offset, bool use_sfcs_sdk, bool use_direct_io, CipherInfo cipher_info);
size_t get_file_size(const char *file_name, bool use_sfcs_sdk);
#endif
\ No newline at end of file
......@@ -43,14 +43,15 @@ class SFCSFile
SFCSFile(std::string file_path, CipherInfo cipher_info);
~SFCSFile();
size_t get_file_size();
size_t read_file_parallel(char *addr, char *dev_mem, int num_thread, size_t total_size, size_t global_offset);
size_t read_file_parallel(char *addr, int device_id, char *dev_mem, int num_thread, size_t total_size,
size_t global_offset);
size_t read_file_to_array(pybind11::array_t<char> arr, size_t length, size_t offset, int num_thread);
size_t write_file_from_array(pybind11::array_t<char> arr, size_t length);
void delete_file();
private:
size_t read_file(char *addr, size_t length, size_t offset);
void read_file_thread(int thread_id, char *addr, char *dev_mem, size_t block_size, size_t total_size,
void read_file_thread(int thread_id, char *addr, int device_id, char *dev_mem, size_t block_size, size_t total_size,
size_t global_offset);
size_t write_file(char *addr, size_t length);
};
......
......@@ -77,8 +77,8 @@ void read_unaligned_part(std::string file_path, torch::Tensor res_tensor, int64_
{
throw std::runtime_error("data ptr does not satisfy the align purpose");
}
read_file(file_path, (char *)res_tensor.data_ptr(), NULL, 1, *read_unaligned_size, *offset, use_sfcs_sdk,
use_direct_io, cipher_info);
read_file(file_path, (char *)res_tensor.data_ptr(), device_id, NULL, 1, *read_unaligned_size, *offset,
use_sfcs_sdk, use_direct_io, cipher_info);
*total_size -= *read_unaligned_size;
*offset += *read_unaligned_size;
......@@ -98,8 +98,8 @@ void read_unaligned_part(std::string file_path, torch::Tensor res_tensor, int64_
if ((*offset & (BUF_ALIGN_SIZE - 1)) != 0)
{
size_t read_head_size = min(BUF_ALIGN_SIZE - (*offset & (BUF_ALIGN_SIZE - 1)), *total_size);
read_file(file_path, tmp_buf_head, (char *)res_tensor.data_ptr(), 1, read_head_size, *offset, use_sfcs_sdk,
use_direct_io, cipher_info);
read_file(file_path, tmp_buf_head, device_id, (char *)res_tensor.data_ptr(), 1, read_head_size, *offset,
use_sfcs_sdk, use_direct_io, cipher_info);
*read_unaligned_size = read_head_size;
*offset += read_head_size;
*total_size -= read_head_size;
......@@ -109,7 +109,7 @@ void read_unaligned_part(std::string file_path, torch::Tensor res_tensor, int64_
{
size_t tail_offset = end_offset - (end_offset & (BUF_ALIGN_SIZE - 1));
size_t tensor_offset = tail_offset - *offset + *read_unaligned_size;
read_file(file_path, tmp_buf_tail, (char *)res_tensor.data_ptr() + tensor_offset, 1,
read_file(file_path, tmp_buf_tail, device_id, (char *)res_tensor.data_ptr() + tensor_offset, 1,
end_offset - tail_offset, tail_offset, use_sfcs_sdk, use_direct_io, cipher_info);
*total_size -= end_offset - tail_offset;
}
......@@ -131,8 +131,8 @@ void IOHelper::load_file_to_tensor(std::string file_path, torch::Tensor res_tens
{
read_unaligned_part(file_path, res_tensor, &offset, device_id, &total_size, use_sfcs_sdk, use_direct_io,
&read_unaligned_size, cipher_info);
read_file(file_path, (char *)res_tensor.data_ptr() + read_unaligned_size, NULL, num_thread, total_size, offset,
use_sfcs_sdk, use_direct_io, cipher_info);
read_file(file_path, (char *)res_tensor.data_ptr() + read_unaligned_size, device_id, NULL, num_thread,
total_size, offset, use_sfcs_sdk, use_direct_io, cipher_info);
}
else
{
......@@ -153,8 +153,8 @@ void IOHelper::load_file_to_tensor(std::string file_path, torch::Tensor res_tens
init_buffer(file_path, total_size, use_pinmem, use_sfcs_sdk);
}
cudaSetDevice(device_id);
read_file(file_path, pin_mem, (char *)res_tensor.data_ptr() + read_unaligned_size, num_thread, total_size,
offset, use_sfcs_sdk, use_direct_io, CipherInfo());
read_file(file_path, pin_mem, device_id, (char *)res_tensor.data_ptr() + read_unaligned_size, num_thread,
total_size, offset, use_sfcs_sdk, use_direct_io, CipherInfo());
cudaDeviceSynchronize();
// decrypt with gpu
if (cipher_info.use_cipher && total_size > 0)
......
......@@ -19,8 +19,9 @@
#include "include/fastcrypto.h"
#include <errno.h>
void read_file_thread_fread(int thread_id, string file_path, char *addr, char *dev_mem, size_t block_size,
size_t total_size, size_t global_offset, bool use_direct_io, CipherInfo cipher_info)
void read_file_thread_fread(int thread_id, string file_path, char *addr, int device_id, char *dev_mem,
size_t block_size, size_t total_size, size_t global_offset, bool use_direct_io,
CipherInfo cipher_info)
{
size_t offset = thread_id * block_size;
size_t read_size = block_size;
......@@ -91,12 +92,15 @@ void read_file_thread_fread(int thread_id, string file_path, char *addr, char *d
}
}
if (dev_mem != NULL)
if (dev_mem != NULL && device_id >= 0)
{
cudaSetDevice(device_id);
cudaMemcpyAsync(dev_mem + offset, addr + offset, read_size, cudaMemcpyHostToDevice);
}
}
void read_file(string file_path, char *addr, char *dev_mem, int num_thread, size_t total_size, size_t global_offset,
bool use_sfcs_sdk, bool use_direct_io, CipherInfo cipher_info)
void read_file(string file_path, char *addr, int device_id, char *dev_mem, int num_thread, size_t total_size,
size_t global_offset, bool use_sfcs_sdk, bool use_direct_io, CipherInfo cipher_info)
{
if (total_size == 0)
{
......@@ -114,14 +118,14 @@ void read_file(string file_path, char *addr, char *dev_mem, int num_thread, size
if (use_sfcs_sdk)
{
SFCSFile sfcs_file(file_path, cipher_info);
sfcs_file.read_file_parallel(addr, dev_mem, num_thread, total_size, global_offset);
sfcs_file.read_file_parallel(addr, device_id, dev_mem, num_thread, total_size, global_offset);
}
else
{
for (int thread_id = 0; thread_id < num_thread; thread_id++)
{
threads[thread_id] = std::thread(read_file_thread_fread, thread_id, file_path, addr, dev_mem, block_size,
total_size, global_offset, use_direct_io, cipher_info);
threads[thread_id] = std::thread(read_file_thread_fread, thread_id, file_path, addr, device_id, dev_mem,
block_size, total_size, global_offset, use_direct_io, cipher_info);
}
for (int thread_id = 0; thread_id < num_thread; thread_id++)
......
......@@ -131,8 +131,8 @@ size_t SFCSFile::read_file(char *addr, size_t length, size_t offset)
return length - count;
}
void SFCSFile::read_file_thread(int thread_id, char *addr, char *dev_mem, size_t block_size, size_t total_size,
size_t global_offset)
void SFCSFile::read_file_thread(int thread_id, char *addr, int device_id, char *dev_mem, size_t block_size,
size_t total_size, size_t global_offset)
{
size_t offset = thread_id * block_size;
size_t read_size = block_size;
......@@ -145,11 +145,15 @@ void SFCSFile::read_file_thread(int thread_id, char *addr, char *dev_mem, size_t
// TODO: actual number of bytes read may be less than read_size
read_file(addr + offset, read_size, global_offset + offset);
if (dev_mem != NULL)
if (dev_mem != NULL && device_id >= 0)
{
cudaSetDevice(device_id);
cudaMemcpyAsync(dev_mem + offset, addr + offset, read_size, cudaMemcpyHostToDevice);
}
}
size_t SFCSFile::read_file_parallel(char *addr, char *dev_mem, int num_thread, size_t total_size, size_t global_offset)
size_t SFCSFile::read_file_parallel(char *addr, int device_id, char *dev_mem, int num_thread, size_t total_size,
size_t global_offset)
{
vector<thread> threads(num_thread);
......@@ -166,8 +170,8 @@ size_t SFCSFile::read_file_parallel(char *addr, char *dev_mem, int num_thread, s
for (int thread_id = 0; thread_id < num_thread; thread_id++)
{
threads[thread_id] = std::thread(&SFCSFile::read_file_thread, this, thread_id, addr, dev_mem, block_size,
total_size, global_offset);
threads[thread_id] = std::thread(&SFCSFile::read_file_thread, this, thread_id, addr, device_id, dev_mem,
block_size, total_size, global_offset);
}
for (int thread_id = 0; thread_id < num_thread; thread_id++)
......@@ -182,7 +186,7 @@ size_t SFCSFile::read_file_to_array(pybind11::array_t<char> arr, size_t length,
{
pybind11::buffer_info buf_info = arr.request();
char *addr = static_cast<char *>(buf_info.ptr);
return read_file_parallel(addr, NULL, num_thread, length, offset);
return read_file_parallel(addr, -1, NULL, num_thread, length, offset);
}
size_t SFCSFile::write_file(char *addr, size_t length)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment