Commit 6203d7ca authored by huteng.ht's avatar huteng.ht
Browse files

fix(load): set cuda device in each thread


Signed-off-by: default avatarhuteng.ht <huteng.ht@bytedance.com>
parent 8dc44dfb
...@@ -19,8 +19,8 @@ ...@@ -19,8 +19,8 @@
#include "common.h" #include "common.h"
#include "cipher.h" #include "cipher.h"
void read_file(string file_path, char *addr, char *dev_mem, int num_thread, size_t total_size, size_t global_offset, void read_file(string file_path, char *addr, int device_id, char *dev_mem, int num_thread, size_t total_size,
bool use_sfcs_sdk, bool use_direct_io, CipherInfo cipher_info); size_t global_offset, bool use_sfcs_sdk, bool use_direct_io, CipherInfo cipher_info);
size_t get_file_size(const char *file_name, bool use_sfcs_sdk); size_t get_file_size(const char *file_name, bool use_sfcs_sdk);
#endif #endif
\ No newline at end of file
...@@ -43,14 +43,15 @@ class SFCSFile ...@@ -43,14 +43,15 @@ class SFCSFile
SFCSFile(std::string file_path, CipherInfo cipher_info); SFCSFile(std::string file_path, CipherInfo cipher_info);
~SFCSFile(); ~SFCSFile();
size_t get_file_size(); size_t get_file_size();
size_t read_file_parallel(char *addr, char *dev_mem, int num_thread, size_t total_size, size_t global_offset); size_t read_file_parallel(char *addr, int device_id, char *dev_mem, int num_thread, size_t total_size,
size_t global_offset);
size_t read_file_to_array(pybind11::array_t<char> arr, size_t length, size_t offset, int num_thread); size_t read_file_to_array(pybind11::array_t<char> arr, size_t length, size_t offset, int num_thread);
size_t write_file_from_array(pybind11::array_t<char> arr, size_t length); size_t write_file_from_array(pybind11::array_t<char> arr, size_t length);
void delete_file(); void delete_file();
private: private:
size_t read_file(char *addr, size_t length, size_t offset); size_t read_file(char *addr, size_t length, size_t offset);
void read_file_thread(int thread_id, char *addr, char *dev_mem, size_t block_size, size_t total_size, void read_file_thread(int thread_id, char *addr, int device_id, char *dev_mem, size_t block_size, size_t total_size,
size_t global_offset); size_t global_offset);
size_t write_file(char *addr, size_t length); size_t write_file(char *addr, size_t length);
}; };
......
...@@ -77,8 +77,8 @@ void read_unaligned_part(std::string file_path, torch::Tensor res_tensor, int64_ ...@@ -77,8 +77,8 @@ void read_unaligned_part(std::string file_path, torch::Tensor res_tensor, int64_
{ {
throw std::runtime_error("data ptr does not satisfy the align purpose"); throw std::runtime_error("data ptr does not satisfy the align purpose");
} }
read_file(file_path, (char *)res_tensor.data_ptr(), NULL, 1, *read_unaligned_size, *offset, use_sfcs_sdk, read_file(file_path, (char *)res_tensor.data_ptr(), device_id, NULL, 1, *read_unaligned_size, *offset,
use_direct_io, cipher_info); use_sfcs_sdk, use_direct_io, cipher_info);
*total_size -= *read_unaligned_size; *total_size -= *read_unaligned_size;
*offset += *read_unaligned_size; *offset += *read_unaligned_size;
...@@ -98,8 +98,8 @@ void read_unaligned_part(std::string file_path, torch::Tensor res_tensor, int64_ ...@@ -98,8 +98,8 @@ void read_unaligned_part(std::string file_path, torch::Tensor res_tensor, int64_
if ((*offset & (BUF_ALIGN_SIZE - 1)) != 0) if ((*offset & (BUF_ALIGN_SIZE - 1)) != 0)
{ {
size_t read_head_size = min(BUF_ALIGN_SIZE - (*offset & (BUF_ALIGN_SIZE - 1)), *total_size); size_t read_head_size = min(BUF_ALIGN_SIZE - (*offset & (BUF_ALIGN_SIZE - 1)), *total_size);
read_file(file_path, tmp_buf_head, (char *)res_tensor.data_ptr(), 1, read_head_size, *offset, use_sfcs_sdk, read_file(file_path, tmp_buf_head, device_id, (char *)res_tensor.data_ptr(), 1, read_head_size, *offset,
use_direct_io, cipher_info); use_sfcs_sdk, use_direct_io, cipher_info);
*read_unaligned_size = read_head_size; *read_unaligned_size = read_head_size;
*offset += read_head_size; *offset += read_head_size;
*total_size -= read_head_size; *total_size -= read_head_size;
...@@ -109,7 +109,7 @@ void read_unaligned_part(std::string file_path, torch::Tensor res_tensor, int64_ ...@@ -109,7 +109,7 @@ void read_unaligned_part(std::string file_path, torch::Tensor res_tensor, int64_
{ {
size_t tail_offset = end_offset - (end_offset & (BUF_ALIGN_SIZE - 1)); size_t tail_offset = end_offset - (end_offset & (BUF_ALIGN_SIZE - 1));
size_t tensor_offset = tail_offset - *offset + *read_unaligned_size; size_t tensor_offset = tail_offset - *offset + *read_unaligned_size;
read_file(file_path, tmp_buf_tail, (char *)res_tensor.data_ptr() + tensor_offset, 1, read_file(file_path, tmp_buf_tail, device_id, (char *)res_tensor.data_ptr() + tensor_offset, 1,
end_offset - tail_offset, tail_offset, use_sfcs_sdk, use_direct_io, cipher_info); end_offset - tail_offset, tail_offset, use_sfcs_sdk, use_direct_io, cipher_info);
*total_size -= end_offset - tail_offset; *total_size -= end_offset - tail_offset;
} }
...@@ -131,8 +131,8 @@ void IOHelper::load_file_to_tensor(std::string file_path, torch::Tensor res_tens ...@@ -131,8 +131,8 @@ void IOHelper::load_file_to_tensor(std::string file_path, torch::Tensor res_tens
{ {
read_unaligned_part(file_path, res_tensor, &offset, device_id, &total_size, use_sfcs_sdk, use_direct_io, read_unaligned_part(file_path, res_tensor, &offset, device_id, &total_size, use_sfcs_sdk, use_direct_io,
&read_unaligned_size, cipher_info); &read_unaligned_size, cipher_info);
read_file(file_path, (char *)res_tensor.data_ptr() + read_unaligned_size, NULL, num_thread, total_size, offset, read_file(file_path, (char *)res_tensor.data_ptr() + read_unaligned_size, device_id, NULL, num_thread,
use_sfcs_sdk, use_direct_io, cipher_info); total_size, offset, use_sfcs_sdk, use_direct_io, cipher_info);
} }
else else
{ {
...@@ -153,8 +153,8 @@ void IOHelper::load_file_to_tensor(std::string file_path, torch::Tensor res_tens ...@@ -153,8 +153,8 @@ void IOHelper::load_file_to_tensor(std::string file_path, torch::Tensor res_tens
init_buffer(file_path, total_size, use_pinmem, use_sfcs_sdk); init_buffer(file_path, total_size, use_pinmem, use_sfcs_sdk);
} }
cudaSetDevice(device_id); cudaSetDevice(device_id);
read_file(file_path, pin_mem, (char *)res_tensor.data_ptr() + read_unaligned_size, num_thread, total_size, read_file(file_path, pin_mem, device_id, (char *)res_tensor.data_ptr() + read_unaligned_size, num_thread,
offset, use_sfcs_sdk, use_direct_io, CipherInfo()); total_size, offset, use_sfcs_sdk, use_direct_io, CipherInfo());
cudaDeviceSynchronize(); cudaDeviceSynchronize();
// decrypt with gpu // decrypt with gpu
if (cipher_info.use_cipher && total_size > 0) if (cipher_info.use_cipher && total_size > 0)
......
...@@ -19,8 +19,9 @@ ...@@ -19,8 +19,9 @@
#include "include/fastcrypto.h" #include "include/fastcrypto.h"
#include <errno.h> #include <errno.h>
void read_file_thread_fread(int thread_id, string file_path, char *addr, char *dev_mem, size_t block_size, void read_file_thread_fread(int thread_id, string file_path, char *addr, int device_id, char *dev_mem,
size_t total_size, size_t global_offset, bool use_direct_io, CipherInfo cipher_info) size_t block_size, size_t total_size, size_t global_offset, bool use_direct_io,
CipherInfo cipher_info)
{ {
size_t offset = thread_id * block_size; size_t offset = thread_id * block_size;
size_t read_size = block_size; size_t read_size = block_size;
...@@ -91,12 +92,15 @@ void read_file_thread_fread(int thread_id, string file_path, char *addr, char *d ...@@ -91,12 +92,15 @@ void read_file_thread_fread(int thread_id, string file_path, char *addr, char *d
} }
} }
if (dev_mem != NULL) if (dev_mem != NULL && device_id >= 0)
{
cudaSetDevice(device_id);
cudaMemcpyAsync(dev_mem + offset, addr + offset, read_size, cudaMemcpyHostToDevice); cudaMemcpyAsync(dev_mem + offset, addr + offset, read_size, cudaMemcpyHostToDevice);
}
} }
void read_file(string file_path, char *addr, char *dev_mem, int num_thread, size_t total_size, size_t global_offset, void read_file(string file_path, char *addr, int device_id, char *dev_mem, int num_thread, size_t total_size,
bool use_sfcs_sdk, bool use_direct_io, CipherInfo cipher_info) size_t global_offset, bool use_sfcs_sdk, bool use_direct_io, CipherInfo cipher_info)
{ {
if (total_size == 0) if (total_size == 0)
{ {
...@@ -114,14 +118,14 @@ void read_file(string file_path, char *addr, char *dev_mem, int num_thread, size ...@@ -114,14 +118,14 @@ void read_file(string file_path, char *addr, char *dev_mem, int num_thread, size
if (use_sfcs_sdk) if (use_sfcs_sdk)
{ {
SFCSFile sfcs_file(file_path, cipher_info); SFCSFile sfcs_file(file_path, cipher_info);
sfcs_file.read_file_parallel(addr, dev_mem, num_thread, total_size, global_offset); sfcs_file.read_file_parallel(addr, device_id, dev_mem, num_thread, total_size, global_offset);
} }
else else
{ {
for (int thread_id = 0; thread_id < num_thread; thread_id++) for (int thread_id = 0; thread_id < num_thread; thread_id++)
{ {
threads[thread_id] = std::thread(read_file_thread_fread, thread_id, file_path, addr, dev_mem, block_size, threads[thread_id] = std::thread(read_file_thread_fread, thread_id, file_path, addr, device_id, dev_mem,
total_size, global_offset, use_direct_io, cipher_info); block_size, total_size, global_offset, use_direct_io, cipher_info);
} }
for (int thread_id = 0; thread_id < num_thread; thread_id++) for (int thread_id = 0; thread_id < num_thread; thread_id++)
......
...@@ -131,8 +131,8 @@ size_t SFCSFile::read_file(char *addr, size_t length, size_t offset) ...@@ -131,8 +131,8 @@ size_t SFCSFile::read_file(char *addr, size_t length, size_t offset)
return length - count; return length - count;
} }
void SFCSFile::read_file_thread(int thread_id, char *addr, char *dev_mem, size_t block_size, size_t total_size, void SFCSFile::read_file_thread(int thread_id, char *addr, int device_id, char *dev_mem, size_t block_size,
size_t global_offset) size_t total_size, size_t global_offset)
{ {
size_t offset = thread_id * block_size; size_t offset = thread_id * block_size;
size_t read_size = block_size; size_t read_size = block_size;
...@@ -145,11 +145,15 @@ void SFCSFile::read_file_thread(int thread_id, char *addr, char *dev_mem, size_t ...@@ -145,11 +145,15 @@ void SFCSFile::read_file_thread(int thread_id, char *addr, char *dev_mem, size_t
// TODO: actual number of bytes read may be less than read_size // TODO: actual number of bytes read may be less than read_size
read_file(addr + offset, read_size, global_offset + offset); read_file(addr + offset, read_size, global_offset + offset);
if (dev_mem != NULL) if (dev_mem != NULL && device_id >= 0)
{
cudaSetDevice(device_id);
cudaMemcpyAsync(dev_mem + offset, addr + offset, read_size, cudaMemcpyHostToDevice); cudaMemcpyAsync(dev_mem + offset, addr + offset, read_size, cudaMemcpyHostToDevice);
}
} }
size_t SFCSFile::read_file_parallel(char *addr, char *dev_mem, int num_thread, size_t total_size, size_t global_offset) size_t SFCSFile::read_file_parallel(char *addr, int device_id, char *dev_mem, int num_thread, size_t total_size,
size_t global_offset)
{ {
vector<thread> threads(num_thread); vector<thread> threads(num_thread);
...@@ -166,8 +170,8 @@ size_t SFCSFile::read_file_parallel(char *addr, char *dev_mem, int num_thread, s ...@@ -166,8 +170,8 @@ size_t SFCSFile::read_file_parallel(char *addr, char *dev_mem, int num_thread, s
for (int thread_id = 0; thread_id < num_thread; thread_id++) for (int thread_id = 0; thread_id < num_thread; thread_id++)
{ {
threads[thread_id] = std::thread(&SFCSFile::read_file_thread, this, thread_id, addr, dev_mem, block_size, threads[thread_id] = std::thread(&SFCSFile::read_file_thread, this, thread_id, addr, device_id, dev_mem,
total_size, global_offset); block_size, total_size, global_offset);
} }
for (int thread_id = 0; thread_id < num_thread; thread_id++) for (int thread_id = 0; thread_id < num_thread; thread_id++)
...@@ -182,7 +186,7 @@ size_t SFCSFile::read_file_to_array(pybind11::array_t<char> arr, size_t length, ...@@ -182,7 +186,7 @@ size_t SFCSFile::read_file_to_array(pybind11::array_t<char> arr, size_t length,
{ {
pybind11::buffer_info buf_info = arr.request(); pybind11::buffer_info buf_info = arr.request();
char *addr = static_cast<char *>(buf_info.ptr); char *addr = static_cast<char *>(buf_info.ptr);
return read_file_parallel(addr, NULL, num_thread, length, offset); return read_file_parallel(addr, -1, NULL, num_thread, length, offset);
} }
size_t SFCSFile::write_file(char *addr, size_t length) size_t SFCSFile::write_file(char *addr, size_t length)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment