"vscode:/vscode.git/clone" did not exist on "6cea0360276e5fc7e2fecbe0cadf89cc72615279"
Commit 0ed05516 authored by huteng.ht's avatar huteng.ht
Browse files

feat: upgrade to sdk v1 latest version



* 70b2701 on master
Signed-off-by: default avatarhuteng.ht <huteng.ht@bytedance.com>
parent 61d052cb
Pipeline #2963 canceled with stages
...@@ -75,11 +75,17 @@ class SnapdAdapter(HTTPAdapter): ...@@ -75,11 +75,17 @@ class SnapdAdapter(HTTPAdapter):
class DataPipeClient: class DataPipeClient:
DATAPIPE_SOCKET_PATH = os.getenv('DATAPIPE_SOCKET_PATH', '/finetuned-model/datapipe.sock') DATAPIPE_SOCKET_PATH = os.getenv('DATAPIPE_SOCKET_PATH', '/finetuned-model/datapipe.sock')
PING_HEADER = {'X-Datapipe-Task-Type': 'ping'} PING_HEADER = {'X-Datapipe-Task-Type': 'ping'}
ENCRYPT_HEADER = {'X-Datapipe-Task-Type': 'encrypt-key'} ENCRYPT_HEADER = {
'X-Datapipe-Task-Type': 'encrypt-key',
'X-Encrypt-Caller-Pod': os.getenv('POD_NAME', ''),
'X-TOS-Path': '',
}
SFCS_STS_HEADER = {'X-Datapipe-Task-Type': 'sfcs-sts'} SFCS_STS_HEADER = {'X-Datapipe-Task-Type': 'sfcs-sts'}
KMS_STS_HEADER = {'X-Datapipe-Task-Type': 'kms-sts'} KMS_STS_HEADER = {'X-Datapipe-Task-Type': 'kms-sts'}
session = requests.Session()
def __init__(self, retry: int = 3, interval: float = 0.5) -> None: # Increment datapipe timeout to make it more robust to real scenarios
def __init__(self, retry: int = 60, interval: float = 2) -> None:
if not os.path.exists(self.DATAPIPE_SOCKET_PATH): if not os.path.exists(self.DATAPIPE_SOCKET_PATH):
raise RuntimeError(f'Datapipe socket {self.DATAPIPE_SOCKET_PATH} does not exist') raise RuntimeError(f'Datapipe socket {self.DATAPIPE_SOCKET_PATH} does not exist')
...@@ -99,8 +105,11 @@ class DataPipeClient: ...@@ -99,8 +105,11 @@ class DataPipeClient:
response = self.session.get(self.url, headers=headers) response = self.session.get(self.url, headers=headers)
if response.status_code == 200: if response.status_code == 200:
return response.json() return response.json()
logger.warning(
f'call with {headers}, retry: {re}, code: {response.status_code}, body: {response.text}'
)
except Exception as e: except Exception as e:
logger.warning(f'call with {headers} return err: {e}') logger.warning(f'call with {headers}, retry: {re}, raise exception: {e}')
if re > self.retry: if re > self.retry:
break break
...@@ -109,8 +118,11 @@ class DataPipeClient: ...@@ -109,8 +118,11 @@ class DataPipeClient:
return None return None
def get_data_key_iv(self) -> Optional[dict]: def get_data_key_iv(self, path: Optional[str] = None) -> Optional[dict]:
return self._get_retry(self.ENCRYPT_HEADER) header = self.ENCRYPT_HEADER.copy()
if path:
header['X-TOS-Path'] = path
return self._get_retry(header)
def get_sfcs_ak_sk_st(self) -> Optional[dict]: def get_sfcs_ak_sk_st(self) -> Optional[dict]:
return self._get_retry(self.SFCS_STS_HEADER) return self._get_retry(self.SFCS_STS_HEADER)
...@@ -211,7 +223,18 @@ class KmsService: ...@@ -211,7 +223,18 @@ class KmsService:
if uds_proxy: if uds_proxy:
session.mount(f'https://{host}', SnapdAdapter(uds_proxy)) session.mount(f'https://{host}', SnapdAdapter(uds_proxy))
headers['X-Datapipe-Task-Type'] = 'top' headers['X-Datapipe-Task-Type'] = 'top'
re = 0
while True:
try:
resp = session.post(request_url, data=payload, headers=headers) resp = session.post(request_url, data=payload, headers=headers)
if resp.status_code == 200:
return resp.json()
except Exception as e:
logger.warning(f'call kms with header: {headers}, return err: {e}')
if re > 3:
break
sleep(0.5)
re += 1
return resp return resp
def encrypt(self, pt_b64: str) -> str: def encrypt(self, pt_b64: str) -> str:
...@@ -222,7 +245,7 @@ class KmsService: ...@@ -222,7 +245,7 @@ class KmsService:
'KeyName': self._key_name, 'KeyName': self._key_name,
} }
payload = {'Plaintext': pt_b64} payload = {'Plaintext': pt_b64}
resp = KmsService.sigv4( js = KmsService.sigv4(
self._ak, self._ak,
self._sk, self._sk,
self._host, self._host,
...@@ -234,11 +257,9 @@ class KmsService: ...@@ -234,11 +257,9 @@ class KmsService:
self._st, self._st,
self._uds_proxy, self._uds_proxy,
) )
if resp.status_code == 200: if 'Result' in js and 'CiphertextBlob' in js['Result']:
j = resp.json() return js['Result']['CiphertextBlob']
if 'Result' in j: raise RuntimeError(f'kms encrypt failed response: {js}')
return resp.json()['Result']['CiphertextBlob']
raise RuntimeError(f'kms encrypt failed: {resp.text}')
def decrypt(self, ct_b64: str) -> str: def decrypt(self, ct_b64: str) -> str:
params = { params = {
...@@ -248,7 +269,7 @@ class KmsService: ...@@ -248,7 +269,7 @@ class KmsService:
'KeyName': self._key_name, 'KeyName': self._key_name,
} }
payload = {'CiphertextBlob': ct_b64} payload = {'CiphertextBlob': ct_b64}
resp = KmsService.sigv4( js = KmsService.sigv4(
self._ak, self._ak,
self._sk, self._sk,
self._host, self._host,
...@@ -260,12 +281,9 @@ class KmsService: ...@@ -260,12 +281,9 @@ class KmsService:
self._st, self._st,
self._uds_proxy, self._uds_proxy,
) )
if resp.status_code == 200: if 'Result' in js and 'Plaintext' in js['Result']:
j = resp.json() return js['Result']['Plaintext']
if 'Result' in j: raise RuntimeError(f'kms decrypt failed response: {js}')
pt_b64 = resp.json()['Result']['Plaintext']
return pt_b64
raise RuntimeError(f'kms decrypt failed: {resp.text}')
class CipherMode(Enum): class CipherMode(Enum):
...@@ -286,12 +304,13 @@ class CipherInfo: ...@@ -286,12 +304,13 @@ class CipherInfo:
HEADER_SIZE = 262144 HEADER_SIZE = 262144
MAGIC_NUMBER = b'Byte3ncryptM0del' MAGIC_NUMBER = b'Byte3ncryptM0del'
def __init__(self, use_cipher: bool, header_bytes: Optional[bytes] = None) -> None: def __init__(self, use_cipher: bool, header_bytes: Optional[bytes] = None, path: Optional[str] = None) -> None:
self.use_cipher = use_cipher self.use_cipher = use_cipher
self.use_header = False self.use_header = False
self.mode = CipherMode.CTR_128 self.mode = CipherMode.CTR_128
self.key = np.frombuffer(b'\x00' * 16, dtype=np.byte) self.key = np.frombuffer(b'\x00' * 16, dtype=np.byte)
self.iv = np.frombuffer(b'\x00' * 16, dtype=np.byte) self.iv = np.frombuffer(b'\x00' * 16, dtype=np.byte)
self.path = path
if not use_cipher: if not use_cipher:
return return
...@@ -319,7 +338,7 @@ class CipherInfo: ...@@ -319,7 +338,7 @@ class CipherInfo:
# case 2: get key and iv from datapipe uds # case 2: get key and iv from datapipe uds
try: try:
client = DataPipeClient() client = DataPipeClient()
resp = client.get_data_key_iv() resp = client.get_data_key_iv(self.path)
self.key, self.iv = self.convert_key_iv(resp['Key'], resp['IV']) self.key, self.iv = self.convert_key_iv(resp['Key'], resp['IV'])
logger.info('get cipher info from datapipe uds successfully!') logger.info('get cipher info from datapipe uds successfully!')
return return
...@@ -336,9 +355,9 @@ class CipherInfo: ...@@ -336,9 +355,9 @@ class CipherInfo:
except Exception as e: except Exception as e:
logger.warning(f'get cipher info from env failed :{e}') logger.warning(f'get cipher info from env failed :{e}')
# fallback to no cipher # raise error
self.use_cipher = False logger.error('fail to get cipher info in all cases')
logger.warning('fail to get key and iv, fallback to no cipher') raise RuntimeError('fail to get cipher info in all cases')
@staticmethod @staticmethod
def convert_key_iv(key_b64: str, iv_b64: str) -> Tuple[np.ndarray, np.ndarray]: def convert_key_iv(key_b64: str, iv_b64: str) -> Tuple[np.ndarray, np.ndarray]:
...@@ -368,7 +387,7 @@ class CipherInfo: ...@@ -368,7 +387,7 @@ class CipherInfo:
ak = resp['Cred']['AccessKeyId'] ak = resp['Cred']['AccessKeyId']
sk = resp['Cred']['SecretAccessKey'] sk = resp['Cred']['SecretAccessKey']
st = resp['Cred']['SessionToken'] st = resp['Cred']['SessionToken']
logger.info('get kms credential from datapipe successfully!') logger.info('get kms ak/sk/st from datapipe successfully!')
except Exception as e: except Exception as e:
logger.warning(f'get kms ak/sk/st from datapipe failed: {e}') logger.warning(f'get kms ak/sk/st from datapipe failed: {e}')
...@@ -391,8 +410,8 @@ class CipherInfo: ...@@ -391,8 +410,8 @@ class CipherInfo:
return header_bytes return header_bytes
def create_cipher_with_header(mode: CipherMode) -> CipherInfo: def create_cipher_with_header(mode: CipherMode, path: str) -> CipherInfo:
c = CipherInfo(False) c = CipherInfo(False, None, path)
c.use_cipher = True c.use_cipher = True
c.use_header = True c.use_header = True
c.mode = mode c.mode = mode
...@@ -407,6 +426,7 @@ def create_cipher_with_header(mode: CipherMode) -> CipherInfo: ...@@ -407,6 +426,7 @@ def create_cipher_with_header(mode: CipherMode) -> CipherInfo:
def encrypt(cipher_info: CipherInfo, pt: np.ndarray, ct: np.ndarray, offset: int): def encrypt(cipher_info: CipherInfo, pt: np.ndarray, ct: np.ndarray, offset: int):
# note: dtype of pt and ct should be np.uint8
if not cipher_info.use_cipher: if not cipher_info.use_cipher:
logger.warning('cipher.encrypt: use_cipher False, skip') logger.warning('cipher.encrypt: use_cipher False, skip')
return return
...@@ -417,6 +437,7 @@ def encrypt(cipher_info: CipherInfo, pt: np.ndarray, ct: np.ndarray, offset: int ...@@ -417,6 +437,7 @@ def encrypt(cipher_info: CipherInfo, pt: np.ndarray, ct: np.ndarray, offset: int
def decrypt(cipher_info: CipherInfo, ct: np.ndarray, pt: np.ndarray, offset: int): def decrypt(cipher_info: CipherInfo, ct: np.ndarray, pt: np.ndarray, offset: int):
# note: dtype of pt and ct should be np.uint8
if not cipher_info.use_cipher: if not cipher_info.use_cipher:
logger.warning('cipher.decrypt: use_cipher False, skip') logger.warning('cipher.decrypt: use_cipher False, skip')
return return
......
MLP_SECRET_KEY_FILENAME = "MLP_SECRET_KEY"
MLP_ACCESS_KEY_FILENAME = "MLP_ACCESS_KEY"
SFCS_DEFAULT_CONFIG_PATH_ENV = "SFCS_METAINFO_PATH"
SFCS_DEFAULT_METAINFO_PATH = "/root/.volc/SFCSConfiguration.json"
RDMA_NIC_ENV = "MLP_RDMA_NIC_NAMES"
DEFAULT_NIC_NAME = "eth0"
RDMA_SEGMENT_ENV = "MLP_RDMA_NETWORK_SEGMENT"
DEFAULT_CREDENTIAL_PATH_ENV = "CREDENTIAL_PATH"
DEFAULT_CREDENTIAL_PATH = "/mlplatform/.credential/"
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef _CLOUDFS_LIBCFS3_CLIENT_CFS_H_ #ifndef _CLOUDFS_LIBCFS3_CLIENT_CLOUDFS_H_
#define _CLOUDFS_LIBCFS3_CLIENT_CFS_H_ #define _CLOUDFS_LIBCFS3_CLIENT_CLOUDFS_H_
#include <errno.h> /* for EINTERNAL, etc. */ #include <errno.h> /* for EINTERNAL, etc. */
#include <fcntl.h> /* for O_RDONLY, O_WRONLY */ #include <fcntl.h> /* for O_RDONLY, O_WRONLY */
...@@ -50,7 +50,7 @@ extern "C" ...@@ -50,7 +50,7 @@ extern "C"
{ {
#endif #endif
/** /**
* Some utility decls used in libcfs. * Some utility decls used in libcloudfs.
*/ */
typedef int32_t tSize; /// size of data for read/write io ops typedef int32_t tSize; /// size of data for read/write io ops
typedef time_t tTime; /// time type in seconds typedef time_t tTime; /// time type in seconds
...@@ -613,9 +613,10 @@ extern "C" ...@@ -613,9 +613,10 @@ extern "C"
* @param fs The configured filesystem handle. * @param fs The configured filesystem handle.
* @param trg The path of target (resulting) file * @param trg The path of target (resulting) file
* @param scrs A list of paths to source files * @param scrs A list of paths to source files
* @param srcsNum Number of source paths
* @return Returns 0 on success, -1 on error. * @return Returns 0 on success, -1 on error.
*/ */
int cfsConcat(cfsFS fs, const char *trg, const char **srcs); int cfsConcat(cfsFS fs, const char *trg, const char **srcs, int srcsNum);
/** /**
* cfsGetWorkingDirectory - Get the current working directory for * cfsGetWorkingDirectory - Get the current working directory for
...@@ -797,6 +798,25 @@ extern "C" ...@@ -797,6 +798,25 @@ extern "C"
*/ */
void cfsFreeEncryptionZoneInfo(cfsEncryptionZoneInfo *infos, int numEntries); void cfsFreeEncryptionZoneInfo(cfsEncryptionZoneInfo *infos, int numEntries);
/**
* cfsFileSystemInfo - Information about a file system
*/
typedef struct
{
int64_t blockSize;
int64_t capacity;
int64_t remaining;
int64_t inodeCapacity;
int64_t inodeRemaining;
} cfsFileSytemInfo;
/**
* cfsGetFileSystemInfo - Get cloudfs filesystem information
* @param fs The configured filesystem handle.
* @return filesystem file info
*/
cfsFileSytemInfo cfsGetFileSystemInfo(cfsFS fs);
/** /**
* cfsGetHosts - Get hostnames where a particular block (determined by * cfsGetHosts - Get hostnames where a particular block (determined by
* pos & blocksize) of a file is stored. The last element in the array * pos & blocksize) of a file is stored. The last element in the array
...@@ -947,7 +967,7 @@ extern "C" ...@@ -947,7 +967,7 @@ extern "C"
* cfsGetHANamenodes - If cfs is configured with HA namenode, return all namenode informations as an array. * cfsGetHANamenodes - If cfs is configured with HA namenode, return all namenode informations as an array.
* Else return NULL. * Else return NULL.
* *
* Using configure file which is given by environment parameter LIBCFS_CONF * Using configure file which is given by environment parameter LIBCLOUDFS_CONF
* or "cloudfs.xml" in working directory. * or "cloudfs.xml" in working directory.
* *
* @param nameservice cfs name service id. * @param nameservice cfs name service id.
...@@ -1184,8 +1204,50 @@ extern "C" ...@@ -1184,8 +1204,50 @@ extern "C"
*/ */
void cfsCancelJob(cfsFS fs, const char *job_id); void cfsCancelJob(cfsFS fs, const char *job_id);
typedef struct cfsReplicaPolicy
{
char *dcNames; // "name whit ',' split"
bool distributed;
bool randomMajority;
int32_t localSwitchTarget;
int32_t otherSwitchTarget;
} cfsReplicaPolicy;
typedef struct cfsReadPolicy
{
bool localDcOnly;
char *dcNames; // Name of Datacenter, split by ','
} cfsReadPolicy;
typedef struct cfsUploadPolicy
{
int32_t uploadIntervalMs;
} cfsUploadPolicy;
typedef struct cfsPolicyResponse
{
char *path;
cfsReplicaPolicy *replicaPolicy;
cfsReadPolicy *readPolicy;
cfsUploadPolicy *uploadPolicy;
} cfsPolicyResponse;
int cfsSetPolicy(cfsFS fs, const char *path, cfsReplicaPolicy *replicaPolicy, cfsReadPolicy *readPolicy,
cfsUploadPolicy *uploadPolicy);
cfsPolicyResponse *cfsRemovePolicy(cfsFS fs, const char *path, bool removeReplicaPolicy, bool removeReadpolicy,
bool removeUploadPolicy);
cfsPolicyResponse *cfsGetPolicy(cfsFS fs, const char *path, bool isReplicaPolicy, bool isReadpolicy,
bool isUploadPolicy);
cfsPolicyResponse *cfsListPolicy(cfsFS fs, int *numEntries, bool isReplicaPolicy, bool isReadpolicy,
bool isUploadPolicy);
void cfsFreePolicyResponse(cfsPolicyResponse *resp, int numEntries);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif
#endif /* _CLOUDFS_LIBCFS3_CLIENT_CFS_H_ */ #endif /* _CLOUDFS_LIBCFS3_CLIENT_CLOUDFS_H_ */
...@@ -18,7 +18,9 @@ ...@@ -18,7 +18,9 @@
#include <torch/torch.h> #include <torch/torch.h>
#include <torch/extension.h> #include <torch/extension.h>
#if defined(USE_CUDA)
#include <cuda_runtime.h> #include <cuda_runtime.h>
#endif
#include <fcntl.h> #include <fcntl.h>
#include <unistd.h> #include <unistd.h>
#include <thread> #include <thread>
...@@ -30,7 +32,7 @@ ...@@ -30,7 +32,7 @@
#include "sfcs.h" #include "sfcs.h"
#define THREAD_NICE_ADJ -10 #define THREAD_NICE_ADJ -10
#define BUF_ALIGN_SIZE 4096 #define BUF_ALIGN_SIZE (size_t)4096
using namespace std; using namespace std;
......
...@@ -16,7 +16,8 @@ ...@@ -16,7 +16,8 @@
#ifndef IO_HELPER_H #ifndef IO_HELPER_H
#define IO_HELPER_H #define IO_HELPER_H
#include "load_utils.h" #include "posix.h"
#include "sfcs.h"
class IOHelper class IOHelper
{ {
...@@ -27,12 +28,28 @@ class IOHelper ...@@ -27,12 +28,28 @@ class IOHelper
public: public:
~IOHelper(); ~IOHelper();
void load_file_to_tensor(std::string file_path, torch::Tensor res_tensor, torch::Tensor sample_tensor, void load_file_to_tensor(std::string file_path, torch::Tensor res_tensor, size_t length, int64_t offset,
int64_t offset, int64_t device_id, int64_t num_thread, bool use_pinmem, bool use_sfcs_sdk, int64_t device_id, int64_t num_thread, bool use_pinmem, bool use_sfcs_sdk,
bool use_direct_io, bool use_cipher, pybind11::array_t<char> key_arr, bool use_direct_io, bool use_cipher, pybind11::array_t<char> key_arr,
pybind11::array_t<char> iv_arr, int64_t header_size); pybind11::array_t<char> iv_arr, int64_t header_size);
void save_tensor_to_file(torch::Tensor tensor, std::string file_path, size_t length, bool use_pinmem,
bool use_sfcs_sdk, bool use_cipher, pybind11::array_t<char> key_arr,
pybind11::array_t<char> iv_arr, int64_t header_size);
void save_tensor_to_file_cpu(torch::Tensor tensor, std::string file_path, size_t length, bool use_pinmem,
bool use_sfcs_sdk, bool use_cipher, pybind11::array_t<char> key_arr,
pybind11::array_t<char> iv_arr, int64_t header_size);
void init_buffer(string file_path, int64_t file_size, bool use_pinmem, bool use_sfcs_sdk); void init_buffer(string file_path, int64_t file_size, bool use_pinmem, bool use_sfcs_sdk);
void free_buffer(); void free_buffer();
}; };
size_t get_file_size(const char *file_name, bool use_sfcs_sdk);
void read_file(string file_path, char *addr, int device_id, char *dev_mem, int num_thread, size_t total_size,
size_t global_offset, bool use_sfcs_sdk, bool use_direct_io, CipherInfo cipher_info);
void load_file_to_tensor_cpu(std::string file_path, torch::Tensor res_tensor, size_t length, int64_t offset,
int64_t device_id, int64_t num_thread, bool use_pinmem, bool use_sfcs_sdk,
bool use_direct_io, bool use_cipher, pybind11::array_t<char> key_arr,
pybind11::array_t<char> iv_arr, int64_t header_size);
#endif #endif
...@@ -19,8 +19,28 @@ ...@@ -19,8 +19,28 @@
#include "common.h" #include "common.h"
#include "cipher.h" #include "cipher.h"
void read_file(string file_path, char *addr, int device_id, char *dev_mem, int num_thread, size_t total_size, class POSIXFile
size_t global_offset, bool use_sfcs_sdk, bool use_direct_io, CipherInfo cipher_info); {
size_t get_file_size(const char *file_name, bool use_sfcs_sdk); public:
std::string file_path;
// cipher related
CipherInfo cipher_info;
POSIXFile(std::string file_path);
POSIXFile(std::string file_path, CipherInfo cipher_info);
POSIXFile(std::string file_path, bool use_cipher, pybind11::array_t<char> key_arr, pybind11::array_t<char> iv_arr,
size_t header_size);
size_t read_file_to_address_parallel(char *addr, int device_id, char *dev_mem, int num_thread, size_t total_size,
size_t global_offset, bool use_direct_io);
size_t read_file_to_array(pybind11::array_t<char> arr, size_t length, size_t offset, int num_thread,
bool use_direct_io);
size_t write_file_from_addr(char *addr, size_t length, bool append);
private:
void read_file_to_address_thread(int thread_id, char *addr, int device_id, char *dev_mem, size_t block_size,
size_t total_size, size_t global_offset, bool use_direct_io,
CipherInfo cipher_info);
};
#endif #endif
\ No newline at end of file
...@@ -29,31 +29,62 @@ ...@@ -29,31 +29,62 @@
using namespace std; using namespace std;
class SFCSFs
{
public:
cfsFS fs;
SFCSFs();
~SFCSFs();
void concat_files(std::string file_name, vector<const char *> file_paths);
void rename_file(const char *file_path, const char *file_name);
void mkdir(std::string file_path);
int64_t get_block_size();
size_t read_file_to_addr(std::string file_name, CipherInfo cipher_info, char *addr, size_t length, size_t offset);
size_t write_file_from_addr(std::string file_name, CipherInfo cipher_info, char *addr, size_t length,
size_t offset);
void read_multi_files(pybind11::list file_paths, pybind11::list tensors, pybind11::list lengths,
pybind11::list offsets, int num_thread, bool use_cipher, pybind11::array_t<char> key_arr,
pybind11::array_t<char> iv_arr, size_t header_size);
void write_multi_files(pybind11::list file_paths, pybind11::list tensors, pybind11::list lengths,
pybind11::list offsets, int num_thread, bool use_cipher, pybind11::array_t<char> key_arr,
pybind11::array_t<char> iv_arr, size_t header_size);
void get_file_size(std::string file_name, size_t *size);
void get_multi_file_size(pybind11::list file_paths, pybind11::list sizes, int num_thread);
};
class SFCSFile class SFCSFile
{ {
public: public:
cfsFS fs; cfsFS fs;
bool fs_owner;
SFCSFs *sfcs_fs;
std::string file_path; std::string file_path;
// cipher related // cipher related
CipherInfo cipher_info; CipherInfo cipher_info;
SFCSFile(std::string file_path); SFCSFile(std::string file_path);
SFCSFile(std::string path, SFCSFs *sfcs_fs);
SFCSFile(std::string file_path, bool use_cipher, pybind11::array_t<char> key_arr, pybind11::array_t<char> iv_arr, SFCSFile(std::string file_path, bool use_cipher, pybind11::array_t<char> key_arr, pybind11::array_t<char> iv_arr,
size_t header_size); size_t header_size);
SFCSFile(std::string file_path, CipherInfo cipher_info); SFCSFile(std::string file_path, CipherInfo cipher_info);
SFCSFile(std::string file_path, SFCSFs *sfcs_fs, CipherInfo cipher_info);
~SFCSFile(); ~SFCSFile();
size_t get_file_size(); size_t get_file_size();
size_t read_file_parallel(char *addr, int device_id, char *dev_mem, int num_thread, size_t total_size, size_t read_file_to_address_parallel(char *addr, int device_id, char *dev_mem, int num_thread, size_t total_size,
size_t global_offset); size_t global_offset);
size_t read_file_to_addr(char *addr, size_t length, size_t offset);
size_t read_file_to_array(pybind11::array_t<char> arr, size_t length, size_t offset, int num_thread); size_t read_file_to_array(pybind11::array_t<char> arr, size_t length, size_t offset, int num_thread);
size_t write_file_from_array(pybind11::array_t<char> arr, size_t length); size_t write_file_from_array(pybind11::array_t<char> arr, size_t length, bool append);
size_t write_file_from_tensors(pybind11::list tensors, pybind11::list sizes, pybind11::list offsets,
std::string concat_dir, std::string concat_file);
size_t write_file_from_addr(char *addr, size_t length, size_t offset, bool append);
void delete_file(); void delete_file();
private: private:
size_t read_file(char *addr, size_t length, size_t offset); void read_file_to_address_thread(int thread_id, char *addr, int device_id, char *dev_mem, size_t block_size,
void read_file_thread(int thread_id, char *addr, int device_id, char *dev_mem, size_t block_size, size_t total_size, size_t total_size, size_t global_offset);
size_t global_offset); void write_file_from_tensor(torch::Tensor tensor, size_t length, size_t offset, std::string file_name);
size_t write_file(char *addr, size_t length);
}; };
#endif #endif
\ No newline at end of file
...@@ -60,31 +60,15 @@ void IOHelper::free_buffer() ...@@ -60,31 +60,15 @@ void IOHelper::free_buffer()
} }
} }
void read_unaligned_part(std::string file_path, torch::Tensor res_tensor, int64_t *offset, int64_t device_id, void read_unaligned_part_gpu(std::string file_path, torch::Tensor res_tensor, int64_t *offset, int64_t device_id,
size_t *total_size, bool use_sfcs_sdk, bool use_direct_io, size_t *read_unaligned_size, size_t *total_size, bool use_sfcs_sdk, bool use_direct_io, size_t *read_unaligned_size,
CipherInfo cipher_info) CipherInfo cipher_info)
{ {
// cpu align only read head part, while gpu align read both head and tail part // cpu align only read head part, while gpu align read both head and tail part
if (device_id < 0) if (device_id < 0)
{ {
// head is aligned throw std::runtime_error("read_unaligned_part_gpu only support gpu device");
if ((*offset & (BUF_ALIGN_SIZE - 1)) == 0)
{
return;
}
*read_unaligned_size = min(BUF_ALIGN_SIZE - (*offset & (BUF_ALIGN_SIZE - 1)), *total_size);
if ((uint64_t)res_tensor.data_ptr() % BUF_ALIGN_SIZE != *offset % BUF_ALIGN_SIZE)
{
throw std::runtime_error("data ptr does not satisfy the align purpose");
}
read_file(file_path, (char *)res_tensor.data_ptr(), device_id, NULL, 1, *read_unaligned_size, *offset,
use_sfcs_sdk, use_direct_io, cipher_info);
*total_size -= *read_unaligned_size;
*offset += *read_unaligned_size;
} }
else
{
size_t end_offset = *offset + *total_size; size_t end_offset = *offset + *total_size;
// both head and tail are aligned // both head and tail are aligned
if ((*offset & (BUF_ALIGN_SIZE - 1)) == 0 && ((end_offset) & (BUF_ALIGN_SIZE - 1)) == 0) if ((*offset & (BUF_ALIGN_SIZE - 1)) == 0 && ((end_offset) & (BUF_ALIGN_SIZE - 1)) == 0)
...@@ -93,8 +77,8 @@ void read_unaligned_part(std::string file_path, torch::Tensor res_tensor, int64_ ...@@ -93,8 +77,8 @@ void read_unaligned_part(std::string file_path, torch::Tensor res_tensor, int64_
} }
char tmp_buf_head[BUF_ALIGN_SIZE] = {}; char tmp_buf_head[BUF_ALIGN_SIZE] = {};
char tmp_buf_tail[BUF_ALIGN_SIZE] = {}; char tmp_buf_tail[BUF_ALIGN_SIZE] = {};
cudaSetDevice(device_id);
// read head unaligned // read head unaligned
cudaSetDevice(device_id);
if ((*offset & (BUF_ALIGN_SIZE - 1)) != 0) if ((*offset & (BUF_ALIGN_SIZE - 1)) != 0)
{ {
size_t read_head_size = min(BUF_ALIGN_SIZE - (*offset & (BUF_ALIGN_SIZE - 1)), *total_size); size_t read_head_size = min(BUF_ALIGN_SIZE - (*offset & (BUF_ALIGN_SIZE - 1)), *total_size);
...@@ -114,29 +98,27 @@ void read_unaligned_part(std::string file_path, torch::Tensor res_tensor, int64_ ...@@ -114,29 +98,27 @@ void read_unaligned_part(std::string file_path, torch::Tensor res_tensor, int64_
*total_size -= end_offset - tail_offset; *total_size -= end_offset - tail_offset;
} }
cudaDeviceSynchronize(); cudaDeviceSynchronize();
}
} }
void IOHelper::load_file_to_tensor(std::string file_path, torch::Tensor res_tensor, torch::Tensor sample_tensor, void IOHelper::load_file_to_tensor(std::string file_path, torch::Tensor res_tensor, size_t length, int64_t offset,
int64_t offset, int64_t device_id, int64_t num_thread, bool use_pinmem, int64_t device_id, int64_t num_thread, bool use_pinmem, bool use_sfcs_sdk,
bool use_sfcs_sdk, bool use_direct_io, bool use_cipher, bool use_direct_io, bool use_cipher, pybind11::array_t<char> key_arr,
pybind11::array_t<char> key_arr, pybind11::array_t<char> iv_arr, int64_t header_size) pybind11::array_t<char> iv_arr, int64_t header_size)
{ {
size_t file_size = get_file_size(file_path.c_str(), use_sfcs_sdk); size_t file_size = get_file_size(file_path.c_str(), use_sfcs_sdk);
size_t total_size = file_size - offset;
size_t read_unaligned_size = 0; size_t read_unaligned_size = 0;
size_t total_size = length > 0 ? length : file_size - offset;
// set cipher // set cipher
CipherInfo cipher_info(use_cipher, key_arr, iv_arr, header_size); CipherInfo cipher_info(use_cipher, key_arr, iv_arr, header_size);
if (device_id < 0) if (device_id < 0)
{ {
read_unaligned_part(file_path, res_tensor, &offset, device_id, &total_size, use_sfcs_sdk, use_direct_io,
&read_unaligned_size, cipher_info);
read_file(file_path, (char *)res_tensor.data_ptr() + read_unaligned_size, device_id, NULL, num_thread, read_file(file_path, (char *)res_tensor.data_ptr() + read_unaligned_size, device_id, NULL, num_thread,
total_size, offset, use_sfcs_sdk, use_direct_io, cipher_info); total_size, offset, use_sfcs_sdk, use_direct_io, cipher_info);
} }
else else
{ {
read_unaligned_part(file_path, res_tensor, &offset, device_id, &total_size, use_sfcs_sdk, use_direct_io, // read unaligned part first, since GPU can only decrypt data in integral multiple of 16 Bytes
read_unaligned_part_gpu(file_path, res_tensor, &offset, device_id, &total_size, use_sfcs_sdk, use_direct_io,
&read_unaligned_size, cipher_info); &read_unaligned_size, cipher_info);
// change use_pinmem attribute may introduce ambiguity // change use_pinmem attribute may introduce ambiguity
...@@ -169,8 +151,12 @@ void IOHelper::load_file_to_tensor(std::string file_path, torch::Tensor res_tens ...@@ -169,8 +151,12 @@ void IOHelper::load_file_to_tensor(std::string file_path, torch::Tensor res_tens
iv[i] = cipher_info.iv[i]; iv[i] = cipher_info.iv[i];
} }
counter_inc_by(iv, AES_BLOCK_SIZE, (offset - cipher_info.header_size) / AES_BLOCK_SIZE); counter_inc_by(iv, AES_BLOCK_SIZE, (offset - cipher_info.header_size) / AES_BLOCK_SIZE);
unsigned char *iv_gpu; unsigned char *iv_gpu = NULL;
cudaMalloc((void **)&iv_gpu, AES_BLOCK_SIZE); cudaMalloc((void **)&iv_gpu, AES_BLOCK_SIZE);
if (iv_gpu == NULL)
{
throw std::runtime_error("iv_gpu cannot be allocated");
}
cudaMemcpy(iv_gpu, iv, AES_BLOCK_SIZE, cudaMemcpyHostToDevice); cudaMemcpy(iv_gpu, iv, AES_BLOCK_SIZE, cudaMemcpyHostToDevice);
unsigned char *ct = reinterpret_cast<unsigned char *>(res_tensor.data_ptr()) + read_unaligned_size; unsigned char *ct = reinterpret_cast<unsigned char *>(res_tensor.data_ptr()) + read_unaligned_size;
int cipher_ret = ctr_decrypt_gpu(cipher_info.mode, cipher_info.key, iv_gpu, ct, total_size, ct); int cipher_ret = ctr_decrypt_gpu(cipher_info.mode, cipher_info.key, iv_gpu, ct, total_size, ct);
...@@ -183,3 +169,52 @@ void IOHelper::load_file_to_tensor(std::string file_path, torch::Tensor res_tens ...@@ -183,3 +169,52 @@ void IOHelper::load_file_to_tensor(std::string file_path, torch::Tensor res_tens
} }
} }
} }
void IOHelper::save_tensor_to_file(torch::Tensor tensor, std::string file_path, size_t length, bool use_pinmem,
bool use_sfcs_sdk, bool use_cipher, pybind11::array_t<char> key_arr,
pybind11::array_t<char> iv_arr, int64_t header_size)
{
char *buf;
CipherInfo cipher_info(use_cipher, key_arr, iv_arr, header_size);
if (tensor.device().is_cuda() || use_cipher)
{
// change use_pinmem attribute may introduce ambiguity
if (buffer_size_ > 0 && use_pinmem != use_pinmem_)
{
throw std::runtime_error("use_pinmem attribute of an exising IOHelper should not be changed");
}
if (pin_mem == NULL || length > buffer_size_)
{
init_buffer(file_path, length, use_pinmem, use_sfcs_sdk);
}
buf = pin_mem;
if (tensor.device().is_cuda())
{
cudaSetDevice(tensor.device().index());
cudaMemcpyAsync(buf, (char *)tensor.data_ptr(), length, cudaMemcpyDeviceToHost);
cudaDeviceSynchronize();
}
else
{
memcpy(buf, (char *)tensor.data_ptr(), length);
}
}
else
{
buf = (char *)tensor.data_ptr();
}
if (use_sfcs_sdk)
{
SFCSFile sfcs_file(file_path, cipher_info);
sfcs_file.write_file_from_addr(buf, length, 0, true);
}
else
{
POSIXFile posix_file(file_path, cipher_info);
posix_file.write_file_from_addr(buf, length, true);
}
}
#include "include/io_helper.h"
#include "include/cipher.h"
IOHelper::~IOHelper()
{
}
// init buffer with given positive size or the size of the file in specified
// path
void IOHelper::init_buffer(string file_path, int64_t buffer_size, bool use_pinmem, bool use_sfcs_sdk)
{
}
void IOHelper::free_buffer()
{
}
void IOHelper::load_file_to_tensor(std::string file_path, torch::Tensor res_tensor, size_t length, int64_t offset,
int64_t device_id, int64_t num_thread, bool use_pinmem, bool use_sfcs_sdk,
bool use_direct_io, bool use_cipher, pybind11::array_t<char> key_arr,
pybind11::array_t<char> iv_arr, int64_t header_size)
{
load_file_to_tensor_cpu(file_path, res_tensor, length, offset, device_id, num_thread, use_pinmem, use_sfcs_sdk,
use_direct_io, use_cipher, key_arr, iv_arr, header_size);
}
void IOHelper::save_tensor_to_file(torch::Tensor tensor, std::string file_path, size_t length, bool use_pinmem,
bool use_sfcs_sdk, bool use_cipher, pybind11::array_t<char> key_arr,
pybind11::array_t<char> iv_arr, int64_t header_size)
{
save_tensor_to_file_cpu(tensor, file_path, length, use_pinmem, use_sfcs_sdk, use_cipher, key_arr, iv_arr,
header_size);
}
#include "include/io_helper.h"
#include "include/cipher.h"
size_t get_file_size(const char *file_name, bool use_sfcs_sdk)
{
if (use_sfcs_sdk)
{
SFCSFile sfcs_file(file_name);
return sfcs_file.get_file_size();
}
else
{
struct stat st;
stat(file_name, &st);
return st.st_size;
}
}
void read_file(string file_path, char *addr, int device_id, char *dev_mem, int num_thread, size_t total_size,
size_t global_offset, bool use_sfcs_sdk, bool use_direct_io, CipherInfo cipher_info)
{
if (total_size == 0)
{
return;
}
if (use_sfcs_sdk)
{
SFCSFile sfcs_file(file_path, cipher_info);
sfcs_file.read_file_to_address_parallel(addr, device_id, dev_mem, num_thread, total_size, global_offset);
}
else
{
POSIXFile posix_file(file_path, cipher_info);
posix_file.read_file_to_address_parallel(addr, device_id, dev_mem, num_thread, total_size, global_offset,
use_direct_io);
}
}
// Read `length` bytes (or, when length == 0, the remainder of the file past
// `offset`) from file_path into res_tensor's buffer.
// NOTE(review): only the host path (device_id < 0) is implemented here; a
// non-negative device_id is silently ignored — presumably handled by a
// device-specific build elsewhere. TODO confirm.
void load_file_to_tensor_cpu(std::string file_path, torch::Tensor res_tensor, size_t length, int64_t offset,
                             int64_t device_id, int64_t num_thread, bool use_pinmem, bool use_sfcs_sdk,
                             bool use_direct_io, bool use_cipher, pybind11::array_t<char> key_arr,
                             pybind11::array_t<char> iv_arr, int64_t header_size)
{
    size_t file_size = get_file_size(file_path.c_str(), use_sfcs_sdk);
    size_t total_size = length > 0 ? length : file_size - offset;
    // set cipher
    CipherInfo cipher_info(use_cipher, key_arr, iv_arr, header_size);
    if (device_id < 0)
    {
        // Dead local `read_unaligned_size` (always 0) removed; the read
        // starts at the tensor's base address.
        read_file(file_path, (char *)res_tensor.data_ptr(), device_id, NULL, num_thread, total_size, offset,
                  use_sfcs_sdk, use_direct_io, cipher_info);
    }
}
// Write `length` bytes of `tensor` to file_path, optionally encrypting.
// When encrypting, the data is first copied into the helper's staging buffer
// (pin_mem) so that in-place CTR encryption performed by the writer does not
// clobber the caller's tensor memory.
void IOHelper::save_tensor_to_file_cpu(torch::Tensor tensor, std::string file_path, size_t length, bool use_pinmem,
                                       bool use_sfcs_sdk, bool use_cipher, pybind11::array_t<char> key_arr,
                                       pybind11::array_t<char> iv_arr, int64_t header_size)
{
    char *buf;
    CipherInfo cipher_info(use_cipher, key_arr, iv_arr, header_size);
    if (use_cipher)
    {
        // Changing the use_pinmem attribute of a live helper would make the
        // staging buffer's kind ambiguous, so reject it explicitly.
        if (buffer_size_ > 0 && use_pinmem != use_pinmem_)
        {
            // Fixed typo in user-facing message: "exising" -> "existing".
            throw std::runtime_error("use_pinmem attribute of an existing IOHelper should not be changed");
        }
        // (Re)allocate the staging buffer when absent or too small.
        if (pin_mem == NULL || length > buffer_size_)
        {
            init_buffer(file_path, length, use_pinmem, use_sfcs_sdk);
        }
        buf = pin_mem;
        memcpy(buf, (char *)tensor.data_ptr(), length);
    }
    else
    {
        // No encryption: write straight from the tensor's own memory.
        buf = (char *)tensor.data_ptr();
    }
    if (use_sfcs_sdk)
    {
        SFCSFile sfcs_file(file_path, cipher_info);
        sfcs_file.write_file_from_addr(buf, length, 0, true);
    }
    else
    {
        POSIXFile posix_file(file_path, cipher_info);
        posix_file.write_file_from_addr(buf, length, true);
    }
}
/*
* Copyright (c) 2024 Beijing Volcano Engine Technology Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "include/io_helper.h"
#include "include/cipher.h"
// Destructor is a no-op in this build: init_buffer/free_buffer below are
// empty stubs here, so there is no buffer state to release.
IOHelper::~IOHelper()
{
}
// init buffer with given positive size or the size of the file in specified
// path
// NOTE(review): empty stub in this translation unit — the actual allocation
// is presumably provided by another build variant of IOHelper; confirm
// before relying on pin_mem/buffer_size_ being set after this call.
void IOHelper::init_buffer(string file_path, int64_t buffer_size, bool use_pinmem, bool use_sfcs_sdk)
{
}
// Release the staging buffer. No-op stub in this build, matching the empty
// init_buffer above.
void IOHelper::free_buffer()
{
}
// Python-facing entry point: forwards all arguments unchanged to the CPU
// implementation (load_file_to_tensor_cpu).
void IOHelper::load_file_to_tensor(std::string file_path, torch::Tensor res_tensor, size_t length, int64_t offset,
                                   int64_t device_id, int64_t num_thread, bool use_pinmem, bool use_sfcs_sdk,
                                   bool use_direct_io, bool use_cipher, pybind11::array_t<char> key_arr,
                                   pybind11::array_t<char> iv_arr, int64_t header_size)
{
    load_file_to_tensor_cpu(file_path, res_tensor, length, offset, device_id, num_thread, use_pinmem, use_sfcs_sdk,
                            use_direct_io, use_cipher, key_arr, iv_arr, header_size);
}
// Python-facing entry point: forwards all arguments unchanged to the CPU
// implementation (save_tensor_to_file_cpu).
void IOHelper::save_tensor_to_file(torch::Tensor tensor, std::string file_path, size_t length, bool use_pinmem,
                                   bool use_sfcs_sdk, bool use_cipher, pybind11::array_t<char> key_arr,
                                   pybind11::array_t<char> iv_arr, int64_t header_size)
{
    save_tensor_to_file_cpu(tensor, file_path, length, use_pinmem, use_sfcs_sdk, use_cipher, key_arr, iv_arr,
                            header_size);
}
...@@ -13,14 +13,32 @@ ...@@ -13,14 +13,32 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "include/load_utils.h" #include "include/posix.h"
#include "include/logging.h" #include "include/logging.h"
#include "include/cipher.h" #include "include/cipher.h"
#include "include/fastcrypto.h" #include "include/fastcrypto.h"
#include <errno.h> #include <errno.h>
void read_file_thread_fread(int thread_id, string file_path, char *addr, int device_id, char *dev_mem, POSIXFile::POSIXFile(std::string file_path)
size_t block_size, size_t total_size, size_t global_offset, bool use_direct_io, {
this->file_path = file_path;
}
// Construct a POSIXFile for file_path with an explicit cipher configuration.
POSIXFile::POSIXFile(std::string file_path, CipherInfo cipher_info)
{
    this->file_path = file_path;
    this->cipher_info = cipher_info;
}
// Convenience constructor: builds the CipherInfo from raw key/iv arrays and
// delegates path setup to the single-argument constructor.
POSIXFile::POSIXFile(std::string file_path, bool use_cipher, pybind11::array_t<char> key_arr,
                     pybind11::array_t<char> iv_arr, size_t header_size)
    : POSIXFile(file_path)
{
    this->cipher_info = CipherInfo(use_cipher, key_arr, iv_arr, header_size);
}
void POSIXFile::read_file_to_address_thread(int thread_id, char *addr, int device_id, char *dev_mem, size_t block_size,
size_t total_size, size_t global_offset, bool use_direct_io,
CipherInfo cipher_info) CipherInfo cipher_info)
{ {
size_t offset = thread_id * block_size; size_t offset = thread_id * block_size;
...@@ -50,6 +68,7 @@ void read_file_thread_fread(int thread_id, string file_path, char *addr, int dev ...@@ -50,6 +68,7 @@ void read_file_thread_fread(int thread_id, string file_path, char *addr, int dev
} }
} }
} }
if (fd == -1) if (fd == -1)
{ {
if ((fd = open(file_path.c_str(), O_RDONLY)) < 0) if ((fd = open(file_path.c_str(), O_RDONLY)) < 0)
...@@ -58,21 +77,25 @@ void read_file_thread_fread(int thread_id, string file_path, char *addr, int dev ...@@ -58,21 +77,25 @@ void read_file_thread_fread(int thread_id, string file_path, char *addr, int dev
throw std::runtime_error("veTurboIO Exception: can't apply open operation"); throw std::runtime_error("veTurboIO Exception: can't apply open operation");
} }
} }
FILE *fp = fdopen(fd, "rb"); FILE *fp = fdopen(fd, "rb");
if (fp == NULL) if (fp == NULL)
{ {
logError("can't apply fdopen to file", file_path.c_str(), std::strerror(errno)); logError("can't apply fdopen to file", file_path.c_str(), std::strerror(errno));
throw std::runtime_error("veTurboIO Exception: can't apply fdopen operation"); throw std::runtime_error("veTurboIO Exception: can't apply fdopen operation");
} }
if ((ret = fseek(fp, global_offset + offset, SEEK_SET)) < 0) if ((ret = fseek(fp, global_offset + offset, SEEK_SET)) < 0)
{ {
logError("can't apply fseek to file", file_path.c_str(), std::strerror(errno)); logError("can't apply fseek to file", file_path.c_str(), std::strerror(errno));
throw std::runtime_error("veTurboIO Exception: can't apply fseek operation"); throw std::runtime_error("veTurboIO Exception: can't apply fseek operation");
} }
if ((size_read = fread(addr + offset, 1, read_size, fp)) == 0) if ((size_read = fread(addr + offset, 1, read_size, fp)) == 0)
{ {
logWarn("read file with 0 bytes returned", file_path.c_str(), offset, read_size); logWarn("read file with 0 bytes returned", file_path.c_str(), offset, read_size);
} }
if ((ret = fclose(fp)) < 0) if ((ret = fclose(fp)) < 0)
{ {
logError("can't apply fclose to file", file_path.c_str(), std::strerror(errno)); logError("can't apply fclose to file", file_path.c_str(), std::strerror(errno));
...@@ -92,21 +115,20 @@ void read_file_thread_fread(int thread_id, string file_path, char *addr, int dev ...@@ -92,21 +115,20 @@ void read_file_thread_fread(int thread_id, string file_path, char *addr, int dev
} }
} }
#if defined(USE_CUDA)
if (dev_mem != NULL && device_id >= 0) if (dev_mem != NULL && device_id >= 0)
{ {
cudaSetDevice(device_id); cudaSetDevice(device_id);
cudaMemcpyAsync(dev_mem + offset, addr + offset, read_size, cudaMemcpyHostToDevice); cudaMemcpyAsync(dev_mem + offset, addr + offset, read_size, cudaMemcpyHostToDevice);
} }
#elif defined(USE_NPU)
#else
#endif
} }
void read_file(string file_path, char *addr, int device_id, char *dev_mem, int num_thread, size_t total_size, size_t POSIXFile::read_file_to_address_parallel(char *addr, int device_id, char *dev_mem, int num_thread,
size_t global_offset, bool use_sfcs_sdk, bool use_direct_io, CipherInfo cipher_info) size_t total_size, size_t global_offset, bool use_direct_io)
{ {
if (total_size == 0)
{
return;
}
vector<thread> threads(num_thread); vector<thread> threads(num_thread);
size_t block_size = (size_t)ceil((double)total_size / num_thread); size_t block_size = (size_t)ceil((double)total_size / num_thread);
...@@ -115,37 +137,77 @@ void read_file(string file_path, char *addr, int device_id, char *dev_mem, int n ...@@ -115,37 +137,77 @@ void read_file(string file_path, char *addr, int device_id, char *dev_mem, int n
// re-caculate the real needed thread num; // re-caculate the real needed thread num;
num_thread = (total_size + block_size - 1) / block_size; num_thread = (total_size + block_size - 1) / block_size;
if (use_sfcs_sdk)
{
SFCSFile sfcs_file(file_path, cipher_info);
sfcs_file.read_file_parallel(addr, device_id, dev_mem, num_thread, total_size, global_offset);
}
else
{
for (int thread_id = 0; thread_id < num_thread; thread_id++) for (int thread_id = 0; thread_id < num_thread; thread_id++)
{ {
threads[thread_id] = std::thread(read_file_thread_fread, thread_id, file_path, addr, device_id, dev_mem, threads[thread_id] = std::thread(&POSIXFile::read_file_to_address_thread, this, thread_id, addr, device_id,
block_size, total_size, global_offset, use_direct_io, cipher_info); dev_mem, block_size, total_size, global_offset, use_direct_io, cipher_info);
} }
for (int thread_id = 0; thread_id < num_thread; thread_id++) for (int thread_id = 0; thread_id < num_thread; thread_id++)
{ {
threads[thread_id].join(); threads[thread_id].join();
} }
}
return total_size;
}
// Read `length` bytes starting at `offset` from this file into the numpy
// char array `arr`, using num_thread parallel readers.
// Returns the byte count reported by read_file_to_address_parallel.
size_t POSIXFile::read_file_to_array(pybind11::array_t<char> arr, size_t length, size_t offset, int num_thread,
                                     bool use_direct_io)
{
    pybind11::buffer_info buf_info = arr.request();
    char *addr = static_cast<char *>(buf_info.ptr);
    // Best-effort huge-page hint; the return value is ignored because the
    // read works either way.
    madvise(addr, length, MADV_HUGEPAGE);
    return read_file_to_address_parallel(addr, -1, NULL, num_thread, length, offset, use_direct_io);
}
size_t get_file_size(const char *file_name, bool use_sfcs_sdk) size_t POSIXFile::write_file_from_addr(char *addr, size_t length, bool append)
{ {
if (use_sfcs_sdk) int fd;
int flags = O_WRONLY;
size_t ret;
size_t count;
char *src = addr;
size_t offset = 0;
if (append)
{ {
SFCSFile sfcs_file(file_name); struct stat st;
return sfcs_file.get_file_size(); stat(file_path.c_str(), &st);
offset = st.st_size;
flags |= O_APPEND;
} }
else
if (cipher_info.use_cipher)
{ {
struct stat st; size_t h_off = cipher_info.header_size;
stat(file_name, &st); CtrEncrypter enc(cipher_info.mode, cipher_info.key, cipher_info.iv, offset - h_off);
return st.st_size; unsigned char *pt = reinterpret_cast<unsigned char *>(addr);
int cipher_ret = enc.encrypt_update(pt, length, pt);
if (!cipher_ret)
{
throw std::runtime_error("Cipher Exception: encrypt fail");
}
}
fd = open(file_path.c_str(), flags);
if (fd < 0)
{
logError("open failed", file_path.c_str(), std::strerror(errno));
throw std::runtime_error("veTurboIO Exception: open failed");
}
count = length;
while (count > 0)
{
ret = write(fd, src, count);
if (ret < 0)
{
logError("Failed to write file", file_path.c_str());
throw std::runtime_error("veTurboIO Exception: write file");
}
count -= ret;
src += ret;
} }
close(fd);
return length;
} }
...@@ -19,7 +19,22 @@ ...@@ -19,7 +19,22 @@
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{ {
py::class_<IOHelper>(m, "IOHelper").def(py::init<>()).def("load_file_to_tensor", &IOHelper::load_file_to_tensor); py::class_<IOHelper>(m, "IOHelper")
.def(py::init<>())
.def("load_file_to_tensor", &IOHelper::load_file_to_tensor)
.def("save_tensor_to_file", &IOHelper::save_tensor_to_file);
py::class_<POSIXFile>(m, "POSIXFile")
.def(py::init<std::string>())
.def(py::init<std::string, bool, pybind11::array_t<char>, pybind11::array_t<char>, size_t>())
.def("read_file_to_array", &POSIXFile::read_file_to_array);
py::class_<SFCSFs>(m, "SFCSFs")
.def(py::init<>())
.def("mkdir", &SFCSFs::mkdir)
.def("read_multi_files", &SFCSFs::read_multi_files)
.def("write_multi_files", &SFCSFs::write_multi_files)
.def("get_multi_file_size", &SFCSFs::get_multi_file_size);
py::class_<SFCSFile>(m, "SFCSFile") py::class_<SFCSFile>(m, "SFCSFile")
.def(py::init<std::string>()) .def(py::init<std::string>())
...@@ -27,6 +42,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) ...@@ -27,6 +42,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
.def("get_file_size", &SFCSFile::get_file_size) .def("get_file_size", &SFCSFile::get_file_size)
.def("read_file_to_array", &SFCSFile::read_file_to_array) .def("read_file_to_array", &SFCSFile::read_file_to_array)
.def("write_file_from_array", &SFCSFile::write_file_from_array) .def("write_file_from_array", &SFCSFile::write_file_from_array)
.def("write_file_from_tensors", &SFCSFile::write_file_from_tensors)
.def("delete_file", &SFCSFile::delete_file); .def("delete_file", &SFCSFile::delete_file);
py::class_<CtrEncWrap>(m, "CtrEncWrap") py::class_<CtrEncWrap>(m, "CtrEncWrap")
......
...@@ -17,10 +17,8 @@ ...@@ -17,10 +17,8 @@
#include "include/cipher.h" #include "include/cipher.h"
#include "include/fastcrypto.h" #include "include/fastcrypto.h"
SFCSFile::SFCSFile(std::string path) SFCSFs::SFCSFs()
{ {
file_path = path;
// construct builder // construct builder
struct cfsBuilder *bld = cfsNewBuilder(); struct cfsBuilder *bld = cfsNewBuilder();
if (bld == NULL) if (bld == NULL)
...@@ -42,11 +40,183 @@ SFCSFile::SFCSFile(std::string path) ...@@ -42,11 +40,183 @@ SFCSFile::SFCSFile(std::string path)
} }
} }
// Disconnect from the SFCS filesystem handle acquired in the constructor.
SFCSFs::~SFCSFs()
{
    cfsDisconnect(fs);
}
// Concatenate the files listed in file_paths into file_name via the SFCS
// API. Throws on any cfsConcat failure.
void SFCSFs::concat_files(std::string file_name, vector<const char *> file_paths)
{
    if (cfsConcat(fs, file_name.c_str(), file_paths.data(), file_paths.size()) == -1)
    {
        logError("Failed to concat files", cfsGetLastError());
        throw std::runtime_error("SFCS Exception: concat files");
    }
}
// Rename (move) file_path to file_name on the SFCS filesystem; throws on
// failure.
void SFCSFs::rename_file(const char *file_path, const char *file_name)
{
    if (cfsRename2(fs, file_path, file_name) == -1)
    {
        logError("Failed to rename file", file_path, cfsGetLastError());
        throw std::runtime_error("SFCS Exception: rename file");
    }
}
// Return the filesystem's default block size; throws when the query fails.
int64_t SFCSFs::get_block_size()
{
    int64_t block_size = cfsGetDefaultBlockSize(fs);
    if (block_size == -1)
    {
        logError("Failed to get default block size", cfsGetLastError());
        throw std::runtime_error("SFCS Exception: get block size");
    }
    return block_size;
}
// Create a directory at file_path on the SFCS filesystem; throws on failure.
void SFCSFs::mkdir(std::string file_path)
{
    if (cfsCreateDirectory(fs, file_path.c_str()) == -1)
    {
        logError("Failed to create dir", file_path, cfsGetLastError());
        throw std::runtime_error("SFCS Exception: create dir");
    }
}
// Per-file read worker: open file_name on this shared filesystem handle and
// read `length` bytes starting at `offset` into addr.
// Used as the thread body by read_multi_files.
size_t SFCSFs::read_file_to_addr(std::string file_name, CipherInfo cipher_info, char *addr, size_t length,
                                 size_t offset)
{
    SFCSFile sfcs_file(file_name, this, cipher_info);
    return sfcs_file.read_file_to_addr(addr, length, offset);
}
// Read num_thread files concurrently — one thread per file — each file into
// the data buffer of the tensor at the same list index.
// NOTE(review): assumes file_paths/tensors/lengths/offsets each contain at
// least num_thread entries; no bounds check is performed — confirm callers
// guarantee this.
void SFCSFs::read_multi_files(pybind11::list file_paths, pybind11::list tensors, pybind11::list lengths,
                              pybind11::list offsets, int num_thread, bool use_cipher, pybind11::array_t<char> key_arr,
                              pybind11::array_t<char> iv_arr, size_t header_size)
{
    vector<thread> threads(num_thread);
    // Materialize the Python lists once before spawning threads.
    auto file_names = file_paths.cast<std::vector<std::string>>();
    auto tensors_vector = tensors.cast<std::vector<torch::Tensor>>();
    auto lengths_vector = lengths.cast<std::vector<size_t>>();
    auto offsets_vector = offsets.cast<std::vector<size_t>>();
    CipherInfo cipher_info = CipherInfo(use_cipher, key_arr, iv_arr, header_size);
    for (int thread_id = 0; thread_id < num_thread; thread_id++)
    {
        std::string file_name = file_names[thread_id];
        size_t length = lengths_vector[thread_id];
        size_t offset = offsets_vector[thread_id];
        torch::Tensor tensor = tensors_vector[thread_id];
        char *addr = (char *)tensor.data_ptr();
        threads[thread_id] =
            std::thread(&SFCSFs::read_file_to_addr, this, file_name, cipher_info, addr, length, offset);
    }
    for (int thread_id = 0; thread_id < num_thread; thread_id++)
    {
        threads[thread_id].join();
    }
}
// Per-file write worker: open file_name on this shared filesystem handle and
// write `length` bytes from addr at `offset` (no append).
// Used as the thread body by write_multi_files.
size_t SFCSFs::write_file_from_addr(std::string file_name, CipherInfo cipher_info, char *addr, size_t length,
                                    size_t offset)
{
    SFCSFile sfcs_file(file_name, this, cipher_info);
    return sfcs_file.write_file_from_addr(addr, length, offset, false);
}
// Write num_thread tensors concurrently — one thread per file — each tensor's
// data buffer to the file at the same list index.
// NOTE(review): assumes file_paths/tensors/lengths/offsets each contain at
// least num_thread entries; no bounds check is performed — confirm callers
// guarantee this.
void SFCSFs::write_multi_files(pybind11::list file_paths, pybind11::list tensors, pybind11::list lengths,
                               pybind11::list offsets, int num_thread, bool use_cipher, pybind11::array_t<char> key_arr,
                               pybind11::array_t<char> iv_arr, size_t header_size)
{
    vector<thread> threads(num_thread);
    // Materialize the Python lists once before spawning threads.
    auto file_names = file_paths.cast<std::vector<std::string>>();
    auto tensors_vector = tensors.cast<std::vector<torch::Tensor>>();
    auto lengths_vector = lengths.cast<std::vector<size_t>>();
    auto offsets_vector = offsets.cast<std::vector<size_t>>();
    CipherInfo cipher_info = CipherInfo(use_cipher, key_arr, iv_arr, header_size);
    for (int thread_id = 0; thread_id < num_thread; thread_id++)
    {
        std::string file_name = file_names[thread_id];
        size_t length = lengths_vector[thread_id];
        size_t offset = offsets_vector[thread_id];
        torch::Tensor tensor = tensors_vector[thread_id];
        char *addr = (char *)tensor.data_ptr();
        threads[thread_id] =
            std::thread(&SFCSFs::write_file_from_addr, this, file_name, cipher_info, addr, length, offset);
    }
    for (int thread_id = 0; thread_id < num_thread; thread_id++)
    {
        threads[thread_id].join();
    }
}
// Query the size of file_name, returning it through the out-parameter
// `size` (thread-friendly form used by get_multi_file_size).
void SFCSFs::get_file_size(std::string file_name, size_t *size)
{
    SFCSFile sfcs_file(file_name, this);
    *size = sfcs_file.get_file_size();
}
// Fetch the sizes of num_thread files concurrently and append them, in input
// order, to the Python list `sizes`.
// NOTE(review): assumes file_paths has at least num_thread entries; no
// bounds check is performed.
void SFCSFs::get_multi_file_size(pybind11::list file_paths, pybind11::list sizes, int num_thread)
{
    vector<thread> threads(num_thread);
    auto file_names = file_paths.cast<std::vector<std::string>>();
    vector<size_t> lengths(num_thread);
    for (int thread_id = 0; thread_id < num_thread; thread_id++)
    {
        std::string file_name = file_names[thread_id];
        threads[thread_id] = std::thread(&SFCSFs::get_file_size, this, file_name, &lengths[thread_id]);
    }
    for (int thread_id = 0; thread_id < num_thread; thread_id++)
    {
        // Append only after the worker has joined, so the main thread is the
        // sole mutator of the Python list.
        threads[thread_id].join();
        sizes.append(lengths[thread_id]);
    }
}
// Construct an SFCSFile that owns a private filesystem connection; with
// fs_owner == true the connection is deleted in the destructor.
SFCSFile::SFCSFile(std::string path)
{
    file_path = path;
    sfcs_fs = new SFCSFs();
    fs_owner = true;
    fs = sfcs_fs->fs;
}
// Construct an SFCSFile that borrows an existing filesystem connection; with
// fs_owner == false the destructor leaves the connection untouched.
SFCSFile::SFCSFile(std::string path, SFCSFs *sfcs_fs)
{
    file_path = path;
    this->sfcs_fs = sfcs_fs;
    fs_owner = false;
    fs = sfcs_fs->fs;
}
SFCSFile::SFCSFile(std::string file_path, CipherInfo cipher_info) : SFCSFile(file_path) SFCSFile::SFCSFile(std::string file_path, CipherInfo cipher_info) : SFCSFile(file_path)
{ {
this->cipher_info = cipher_info; this->cipher_info = cipher_info;
} }
// Borrowed-connection constructor with an explicit cipher configuration.
SFCSFile::SFCSFile(std::string file_path, SFCSFs *sfcs_fs, CipherInfo cipher_info) : SFCSFile(file_path, sfcs_fs)
{
    this->cipher_info = cipher_info;
}
SFCSFile::SFCSFile(std::string file_path, bool use_cipher, pybind11::array_t<char> key_arr, SFCSFile::SFCSFile(std::string file_path, bool use_cipher, pybind11::array_t<char> key_arr,
pybind11::array_t<char> iv_arr, size_t header_size) pybind11::array_t<char> iv_arr, size_t header_size)
: SFCSFile(file_path) : SFCSFile(file_path)
...@@ -56,7 +226,10 @@ SFCSFile::SFCSFile(std::string file_path, bool use_cipher, pybind11::array_t<cha ...@@ -56,7 +226,10 @@ SFCSFile::SFCSFile(std::string file_path, bool use_cipher, pybind11::array_t<cha
SFCSFile::~SFCSFile() SFCSFile::~SFCSFile()
{ {
cfsDisconnect(fs); if (fs_owner)
{
delete sfcs_fs;
}
} }
size_t SFCSFile::get_file_size() size_t SFCSFile::get_file_size()
...@@ -66,16 +239,19 @@ size_t SFCSFile::get_file_size() ...@@ -66,16 +239,19 @@ size_t SFCSFile::get_file_size()
cfsFileInfo *file_info = cfsGetPathInfo(fs, file_path.c_str()); cfsFileInfo *file_info = cfsGetPathInfo(fs, file_path.c_str());
if (file_info == NULL) if (file_info == NULL)
{ {
logError("Failed to get path info of relative path", file_path, cfsGetLastError()); logWarn("Failed to get path info of relative path", file_path, cfsGetLastError());
cfsDisconnect(fs); cfsFreeFileInfo(file_info, 1);
throw std::runtime_error("SFCS Exception: get path info"); return 0;
} }
else
{
size = file_info->mSize; size = file_info->mSize;
cfsFreeFileInfo(file_info, 1); cfsFreeFileInfo(file_info, 1);
return size; return size;
}
} }
size_t SFCSFile::read_file(char *addr, size_t length, size_t offset) size_t SFCSFile::read_file_to_addr(char *addr, size_t length, size_t offset)
{ {
size_t count; size_t count;
int32_t ret; int32_t ret;
...@@ -131,7 +307,7 @@ size_t SFCSFile::read_file(char *addr, size_t length, size_t offset) ...@@ -131,7 +307,7 @@ size_t SFCSFile::read_file(char *addr, size_t length, size_t offset)
return length - count; return length - count;
} }
void SFCSFile::read_file_thread(int thread_id, char *addr, int device_id, char *dev_mem, size_t block_size, void SFCSFile::read_file_to_address_thread(int thread_id, char *addr, int device_id, char *dev_mem, size_t block_size,
size_t total_size, size_t global_offset) size_t total_size, size_t global_offset)
{ {
size_t offset = thread_id * block_size; size_t offset = thread_id * block_size;
...@@ -143,17 +319,21 @@ void SFCSFile::read_file_thread(int thread_id, char *addr, int device_id, char * ...@@ -143,17 +319,21 @@ void SFCSFile::read_file_thread(int thread_id, char *addr, int device_id, char *
} }
// TODO: actual number of bytes read may be less than read_size // TODO: actual number of bytes read may be less than read_size
read_file(addr + offset, read_size, global_offset + offset); read_file_to_addr(addr + offset, read_size, global_offset + offset);
#if defined(USE_CUDA)
if (dev_mem != NULL && device_id >= 0) if (dev_mem != NULL && device_id >= 0)
{ {
cudaSetDevice(device_id); cudaSetDevice(device_id);
cudaMemcpyAsync(dev_mem + offset, addr + offset, read_size, cudaMemcpyHostToDevice); cudaMemcpyAsync(dev_mem + offset, addr + offset, read_size, cudaMemcpyHostToDevice);
} }
#elif defined(USE_NPU)
#else
#endif
} }
size_t SFCSFile::read_file_parallel(char *addr, int device_id, char *dev_mem, int num_thread, size_t total_size, size_t SFCSFile::read_file_to_address_parallel(char *addr, int device_id, char *dev_mem, int num_thread,
size_t global_offset) size_t total_size, size_t global_offset)
{ {
vector<thread> threads(num_thread); vector<thread> threads(num_thread);
...@@ -170,8 +350,8 @@ size_t SFCSFile::read_file_parallel(char *addr, int device_id, char *dev_mem, in ...@@ -170,8 +350,8 @@ size_t SFCSFile::read_file_parallel(char *addr, int device_id, char *dev_mem, in
for (int thread_id = 0; thread_id < num_thread; thread_id++) for (int thread_id = 0; thread_id < num_thread; thread_id++)
{ {
threads[thread_id] = std::thread(&SFCSFile::read_file_thread, this, thread_id, addr, device_id, dev_mem, threads[thread_id] = std::thread(&SFCSFile::read_file_to_address_thread, this, thread_id, addr, device_id,
block_size, total_size, global_offset); dev_mem, block_size, total_size, global_offset);
} }
for (int thread_id = 0; thread_id < num_thread; thread_id++) for (int thread_id = 0; thread_id < num_thread; thread_id++)
...@@ -186,28 +366,49 @@ size_t SFCSFile::read_file_to_array(pybind11::array_t<char> arr, size_t length, ...@@ -186,28 +366,49 @@ size_t SFCSFile::read_file_to_array(pybind11::array_t<char> arr, size_t length,
{ {
pybind11::buffer_info buf_info = arr.request(); pybind11::buffer_info buf_info = arr.request();
char *addr = static_cast<char *>(buf_info.ptr); char *addr = static_cast<char *>(buf_info.ptr);
return read_file_parallel(addr, -1, NULL, num_thread, length, offset); madvise(addr, length, MADV_HUGEPAGE);
return read_file_to_address_parallel(addr, -1, NULL, num_thread, length, offset);
} }
size_t SFCSFile::write_file(char *addr, size_t length) size_t SFCSFile::write_file_from_addr(char *addr, size_t length, size_t offset, bool append)
{ {
size_t count; size_t count;
int32_t ret; int32_t ret;
char *dst; char *dst;
if (append)
offset = get_file_size();
if (cipher_info.use_cipher) if (cipher_info.use_cipher)
{ {
size_t h_off = cipher_info.header_size; size_t h_off = cipher_info.header_size;
int cipher_ret;
if (append == false && offset == 0)
{
CtrEncrypter enc(cipher_info.mode, cipher_info.key, cipher_info.iv, 0); CtrEncrypter enc(cipher_info.mode, cipher_info.key, cipher_info.iv, 0);
unsigned char *pt = reinterpret_cast<unsigned char *>(addr); unsigned char *pt = reinterpret_cast<unsigned char *>(addr);
int cipher_ret = enc.encrypt_update(pt + h_off, length - h_off, pt + h_off); cipher_ret = enc.encrypt_update(pt + h_off, length - h_off, pt + h_off);
}
else
{
CtrEncrypter enc(cipher_info.mode, cipher_info.key, cipher_info.iv, offset - h_off);
unsigned char *pt = reinterpret_cast<unsigned char *>(addr);
cipher_ret = enc.encrypt_update(pt, length, pt);
}
if (!cipher_ret) if (!cipher_ret)
{ {
throw std::runtime_error("Cipher Exception: encrypt fail"); throw std::runtime_error("Cipher Exception: encrypt fail");
} }
} }
cfsFile file = cfsOpenFile(fs, file_path.c_str(), O_WRONLY | O_ASYNC, 0, 0, 0); cfsFile file;
if (append)
file = cfsOpenFileAcc(fs, file_path.c_str(), O_WRONLY | O_ASYNC | O_APPEND, 0644, false, true);
else
file = cfsOpenFileAcc(fs, file_path.c_str(), O_WRONLY | O_ASYNC, 0644, false, false);
if (file == NULL) if (file == NULL)
{ {
logError("Failed to open file", file_path, cfsGetLastError()); logError("Failed to open file", file_path, cfsGetLastError());
...@@ -236,11 +437,77 @@ size_t SFCSFile::write_file(char *addr, size_t length) ...@@ -236,11 +437,77 @@ size_t SFCSFile::write_file(char *addr, size_t length)
return length - count; return length - count;
} }
size_t SFCSFile::write_file_from_array(pybind11::array_t<char> arr, size_t length) size_t SFCSFile::write_file_from_array(pybind11::array_t<char> arr, size_t length, bool append)
{ {
pybind11::buffer_info buf_info = arr.request(); pybind11::buffer_info buf_info = arr.request();
char *addr = static_cast<char *>(buf_info.ptr); char *addr = static_cast<char *>(buf_info.ptr);
return write_file(addr, length); return write_file_from_addr(addr, length, 0, append);
}
// Write `length` bytes of `tensor` to file_name at `offset`, staging the
// data in an anonymous host buffer when required (device tensor, or
// in-place CTR encryption that must not clobber the caller's memory).
void SFCSFile::write_file_from_tensor(torch::Tensor tensor, size_t length, size_t offset, std::string file_name)
{
    char *buf, *addr;
    // Portable anonymous mappings pass fd == -1 (some platforms reject 0);
    // also check for MAP_FAILED instead of using the mapping blindly.
    buf = (char *)mmap(NULL, length, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
    if (buf == MAP_FAILED)
    {
        throw std::runtime_error("veTurboIO Exception: mmap failed");
    }
    madvise(buf, length, MADV_HUGEPAGE);
    // Default to the tensor's own memory; the branches below switch to the
    // staging buffer when a host-side copy is needed.
    addr = (char *)tensor.data_ptr();
    if (tensor.device().is_cuda())
    {
#if defined(USE_CUDA)
        cudaSetDevice(tensor.device().index());
        cudaMemcpyAsync(buf, (char *)tensor.data_ptr(), length, cudaMemcpyDeviceToHost);
        cudaDeviceSynchronize();
        addr = buf;
#else
        // Previously `addr` was left uninitialized on this path (undefined
        // behavior); fail loudly instead.
        munmap(buf, length);
        throw std::runtime_error("veTurboIO Exception: CUDA tensor in non-CUDA build");
#endif
    }
    else if (cipher_info.use_cipher)
    {
        // Copy so that in-place encryption in the writer does not clobber
        // the caller's tensor.
        memcpy(buf, (char *)tensor.data_ptr(), length);
        addr = buf;
    }
    SFCSFile sfcs_file(file_name, sfcs_fs, cipher_info);
    sfcs_file.write_file_from_addr(addr, length, offset, false);
    munmap(buf, length);
}
// Write each tensor to a per-thread part file under concat_dir in parallel,
// then concatenate the parts into concat_file and rename the result to this
// file's path. Returns the total number of bytes written.
size_t SFCSFile::write_file_from_tensors(pybind11::list tensors, pybind11::list sizes, pybind11::list offsets,
                                         std::string concat_dir, std::string concat_file)
{
    int num_thread = tensors.size();
    size_t length = 0;
    vector<thread> threads(num_thread);
    vector<std::string> file_names;
    vector<const char *> file_paths;
    // Reserve up front: push_back used to reallocate file_names mid-loop,
    // which invalidated the c_str() pointers already stored in file_paths
    // (dangling pointers passed to cfsConcat).
    file_names.reserve(num_thread);
    file_paths.reserve(num_thread);
    auto tensors_vector = tensors.cast<std::vector<torch::Tensor>>();
    auto sizes_vector = sizes.cast<std::vector<size_t>>();
    auto offsets_vector = offsets.cast<std::vector<size_t>>();
    for (int thread_id = 0; thread_id < num_thread; thread_id++)
    {
        torch::Tensor tensor = tensors_vector[thread_id];
        size_t size = sizes_vector[thread_id];
        size_t offset = offsets_vector[thread_id];
        file_names.push_back(concat_dir + std::string("/") + std::to_string(thread_id));
        threads[thread_id] =
            std::thread(&SFCSFile::write_file_from_tensor, this, tensor, size, offset, file_names[thread_id]);
        file_paths.push_back(file_names[thread_id].c_str());
        length += size;
    }
    for (int thread_id = 0; thread_id < num_thread; thread_id++)
    {
        threads[thread_id].join();
    }
    sfcs_fs->concat_files(concat_file, file_paths);
    sfcs_fs->rename_file(concat_file.c_str(), file_path.c_str());
    return length;
}
void SFCSFile::delete_file() void SFCSFile::delete_file()
......
'''
Copyright (c) 2024 Beijing Volcano Engine Technology Ltd.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
'''
import json
import os
from typing import Dict, Optional
import numpy as np
import torch
from loguru import logger
from safetensors.torch import save_file as safetensors_save_file
from veturboio.ops.cipher import CipherInfo, CipherMode, create_cipher_with_header, encrypt
from veturboio.ops.sfcs_utils import sfcs_delete_file, sfcs_write_file, sfcs_write_file_in_parallel
from veturboio.safetensors import parse_state_dict
from veturboio.types import FILE_PATH
# Optional C++ extension: when veturboio_ext is unavailable, IOHelper stays
# None and callers fall back to the pure-Python path.
try:
    import veturboio_ext
    IOHelper = veturboio_ext.IOHelper
except ImportError:
    IOHelper = None
    logger.warning("veturboio_ext not found, fallback to pure python implementation")
def load_file_to_tensor(
    file_path: str,
    total_tensor: torch.Tensor,
    offset: int,
    helper: IOHelper,
    length: int = 0,
    device_id: Optional[int] = -1,
    num_thread: Optional[int] = 32,
    use_pinmem: Optional[bool] = False,
    use_sfcs_sdk: Optional[bool] = False,
    use_direct_io: Optional[bool] = False,
    cipher_info: CipherInfo = CipherInfo(False),
) -> torch.Tensor:
    """Read bytes from *file_path* into *total_tensor* via the C++ helper.

    Forwards every option positionally to ``IOHelper.load_file_to_tensor``.
    The cipher header size is forwarded only when the header is in use,
    otherwise 0 is passed.

    NOTE: the C++ side takes ``(length, offset)`` — the opposite order of
    this wrapper's ``(offset, ..., length)`` parameters.
    """
    return helper.load_file_to_tensor(
        file_path,
        total_tensor,
        length,
        offset,
        device_id,
        num_thread,
        use_pinmem,
        use_sfcs_sdk,
        use_direct_io,
        cipher_info.use_cipher,
        cipher_info.key,
        cipher_info.iv,
        CipherInfo.HEADER_SIZE if cipher_info.use_header else 0,
    )
def save_tensor_to_file(
    tensor: torch.Tensor,
    file_path: FILE_PATH,
    length: int,
    helper: IOHelper,
    use_pinmem: Optional[bool] = False,
    use_sfcs_sdk: Optional[bool] = False,
    cipher_info: CipherInfo = CipherInfo(False),
):
    """Write *length* bytes of *tensor* to *file_path* via the C++ helper.

    Forwards every option positionally to ``IOHelper.save_tensor_to_file``;
    the cipher header size is forwarded only when the header is in use,
    otherwise 0 is passed.
    """
    return helper.save_tensor_to_file(
        tensor,
        file_path,
        length,
        use_pinmem,
        use_sfcs_sdk,
        cipher_info.use_cipher,
        cipher_info.key,
        cipher_info.iv,
        CipherInfo.HEADER_SIZE if cipher_info.use_header else 0,
    )
def save_file(
    state_dict: Dict[str, torch.Tensor],
    filename: FILE_PATH,
    helper: IOHelper,
    metadata: Optional[Dict[str, str]] = None,
    use_sfcs_sdk: bool = False,
    cipher_info: CipherInfo = CipherInfo(False),
):
    """Serialize *state_dict* to *filename* in safetensors layout.

    Falls back to ``safetensors.torch.save_file`` when no C++ helper is
    available (encryption is unsupported on that path). Otherwise the JSON
    header is written first and each tensor is appended via the helper.
    """
    if helper is None:
        if cipher_info.use_cipher:
            logger.warning("helper is None, cipher is not supported in pure python implementation")
        return safetensors_save_file(state_dict, filename, metadata=metadata)
    meta, tensors, sizes, offsets = parse_state_dict(state_dict)
    if metadata:
        meta["__metadata__"] = metadata
    meta_bytes = json.dumps(meta).encode('utf-8')
    meta_len = len(meta_bytes)
    # Pad the JSON header with spaces to an 8-byte boundary so the tensor
    # data that follows stays aligned.
    if meta_len % 8 != 0:
        meta_len_pad = (meta_len + 8) // 8 * 8
        meta_bytes += b' ' * (meta_len_pad - meta_len)
        meta_len = meta_len_pad
    # safetensors layout: 8-byte little-endian header length, then the JSON.
    st_header_bytes = meta_len.to_bytes(8, 'little') + meta_bytes
    st_header_len = len(st_header_bytes)
    if use_sfcs_sdk:
        # The SFCS path writes header and tensors in one parallel call.
        sfcs_write_file_in_parallel(filename, tensors, sizes, offsets, st_header_bytes, st_header_len, cipher_info)
        return
    with open(filename, "wb") as f:
        if cipher_info.use_cipher:
            if cipher_info.use_header:
                # Plaintext cipher header precedes the encrypted payload.
                f.write(cipher_info.to_header_bytes())
            enc_st_header_arr = np.zeros(st_header_len, dtype=np.uint8)
            encrypt(cipher_info, np.frombuffer(st_header_bytes, dtype=np.uint8), enc_st_header_arr, 0)
            f.write(enc_st_header_arr.tobytes())
        else:
            f.write(st_header_bytes)
    # Append tensors only after the `with` block has closed (and therefore
    # flushed) the header: the helper appends through its own file
    # descriptor, so an unflushed Python-side buffer would interleave the
    # header with (or overwrite) the tensor data.
    for tensor, size in zip(tensors, sizes):
        save_tensor_to_file(
            tensor,
            filename,
            size,
            helper=helper,
            use_pinmem=False,
            use_sfcs_sdk=use_sfcs_sdk,
            cipher_info=cipher_info,
        )
def init_io_helper() -> IOHelper:
    """Instantiate the C++ IOHelper extension object."""
    return IOHelper()
...@@ -16,7 +16,7 @@ limitations under the License. ...@@ -16,7 +16,7 @@ limitations under the License.
from typing import Optional from typing import Optional
import torch import numpy as np
from loguru import logger from loguru import logger
from veturboio.ops.cipher import CipherInfo from veturboio.ops.cipher import CipherInfo
...@@ -26,40 +26,26 @@ try: ...@@ -26,40 +26,26 @@ try:
veturboio_ext = load_veturboio_ext() veturboio_ext = load_veturboio_ext()
IOHelper = veturboio_ext.IOHelper IOHelper = veturboio_ext.IOHelper
POSIXFile = veturboio_ext.POSIXFile
except ImportError: except ImportError:
IOHelper = None POSIXFile = None
logger.warning("veturboio_ext not found, fallback to pure python implementation") logger.warning("veturboio_ext not found, fallback to pure python implementation")
def load_file_to_tensor( def posix_read_file(
file_path: str, file_path: str,
total_tensor: torch.Tensor, arr: np.ndarray,
sample_tensor: torch.Tensor, length: int,
offset: int, offset: int,
helper: IOHelper, num_thread: Optional[int] = 1,
device_id: Optional[int] = -1,
num_thread: Optional[int] = 32,
use_pinmem: Optional[bool] = False,
use_sfcs_sdk: Optional[bool] = False,
use_direct_io: Optional[bool] = False,
cipher_info: CipherInfo = CipherInfo(False), cipher_info: CipherInfo = CipherInfo(False),
) -> torch.Tensor: use_direct_io: bool = False,
return helper.load_file_to_tensor( ) -> int:
posix_file = POSIXFile(
file_path, file_path,
total_tensor,
sample_tensor,
offset,
device_id,
num_thread,
use_pinmem,
use_sfcs_sdk,
use_direct_io,
cipher_info.use_cipher, cipher_info.use_cipher,
cipher_info.key, cipher_info.key,
cipher_info.iv, cipher_info.iv,
CipherInfo.HEADER_SIZE if cipher_info.use_header else 0, CipherInfo.HEADER_SIZE if cipher_info.use_header else 0,
) )
return posix_file.read_file_to_array(arr, length, offset, num_thread, use_direct_io)
def init_io_helper() -> IOHelper:
return IOHelper()
This diff is collapsed.
...@@ -17,7 +17,8 @@ limitations under the License. ...@@ -17,7 +17,8 @@ limitations under the License.
import json import json
import os import os
import pprint import pprint
from typing import Callable, Dict, List from multiprocessing import shared_memory
from typing import Callable, Dict, List, Optional
import numpy as np import numpy as np
import torch import torch
...@@ -97,7 +98,7 @@ class TensorMeta: ...@@ -97,7 +98,7 @@ class TensorMeta:
class SafetensorsFile: class SafetensorsFile:
def __init__(self, file: FILE_PATH, loader: BaseLoader, use_cipher: bool = False) -> None: def __init__(self, file: FILE_PATH, loader: BaseLoader, use_cipher: Optional[bool] = None) -> None:
self._file = file self._file = file
self._loader = loader self._loader = loader
...@@ -105,9 +106,9 @@ class SafetensorsFile: ...@@ -105,9 +106,9 @@ class SafetensorsFile:
# cipher related # cipher related
self._cipher_info = CipherInfo(False) self._cipher_info = CipherInfo(False)
if use_cipher or os.getenv("VETURBOIO_USE_CIPHER", "0") == "1": if use_cipher == True or use_cipher == None and os.getenv("VETURBOIO_USE_CIPHER", "0") == "1":
header_bytes = loader.load_to_bytes(offset=0, count=CipherInfo.HEADER_SIZE) header_bytes = loader.load_to_bytes(offset=0, count=CipherInfo.HEADER_SIZE)
self._cipher_info = CipherInfo(True, header_bytes) self._cipher_info = CipherInfo(True, header_bytes, os.path.abspath(self.file))
if self._cipher_info.use_header: if self._cipher_info.use_header:
h_off = CipherInfo.HEADER_SIZE h_off = CipherInfo.HEADER_SIZE
...@@ -206,8 +207,67 @@ class SafetensorsFile: ...@@ -206,8 +207,67 @@ class SafetensorsFile:
def __repr__(self) -> str: def __repr__(self) -> str:
return self.__str__() return self.__str__()
def load(self, map_location: str = "cpu") -> Dict[str, torch.Tensor]: def load(self, map_location: str = "cpu", state_dict: Dict[str, torch.Tensor] = None) -> Dict[str, torch.Tensor]:
if not self._is_valid: if not self._is_valid:
return self._loader.load_pt(map_location, self._cipher_info) return self._loader.load_pt(map_location, self._cipher_info)
else: else:
return self._loader.load_safetensors(self, map_location) return self._loader.load_safetensors(self, map_location, state_dict)
def load_to_shmem(self) -> shared_memory.SharedMemory:
return self._loader.load_to_shmem(self._cipher_info)
def parse_state_dict(state_dict: Dict[str, torch.Tensor]):
    """Build the safetensors metadata table and payload layout for *state_dict*.

    Non-bool tensors are laid out first (in dict order), followed by all bool
    tensors, matching the payload order ``save_file`` writes to disk.
    Zero-sized tensors get a metadata entry but no payload slot.

    Returns:
        meta: dict mapping tensor name -> {"dtype", "shape", "data_offsets"};
            offsets are byte ranges relative to the start of the payload.
        tensors: tensors with a non-empty payload, in layout order.
        sizes: byte size of each entry in ``tensors``.
        offsets: starting byte offset of each entry in ``tensors``.
    """
    meta = {}
    tensors = []
    sizes = []
    offsets = []
    cursor = 0
    dtype_str = {v: k for k, v in _safetensors_dtype_mapper.items()}

    def _append(key: str, tensor: torch.Tensor) -> None:
        # Record one tensor's metadata and advance the payload cursor.
        # numel() * element_size() covers float, int and bool dtypes
        # uniformly, replacing the original's fragile bare-except
        # finfo/iinfo probe (which also shadowed the builtin `bytes`).
        nonlocal cursor
        nbytes = tensor.numel() * tensor.element_size()
        end = cursor + nbytes
        meta[key] = {
            "dtype": dtype_str[tensor.dtype],
            "shape": tensor.shape,
            "data_offsets": [cursor, end],
        }
        if nbytes > 0:
            tensors.append(tensor)
            sizes.append(nbytes)
            offsets.append(cursor)
        cursor = end

    # Defer bool tensors so they land at the end of the payload,
    # preserving the original layout order.
    bool_state_dict = {}
    for key, tensor in state_dict.items():
        if tensor.dtype == torch.bool:
            bool_state_dict[key] = tensor
        else:
            _append(key, tensor)
    for key, tensor in bool_state_dict.items():
        _append(key, tensor)

    return meta, tensors, sizes, offsets
...@@ -24,6 +24,8 @@ from safetensors.torch import save_file as safetenors_save_file ...@@ -24,6 +24,8 @@ from safetensors.torch import save_file as safetenors_save_file
from safetensors.torch import save_model as safetensors_save_model from safetensors.torch import save_model as safetensors_save_model
from veturboio.ops.cipher import CipherInfo, CipherMode, create_cipher_with_header, encrypt from veturboio.ops.cipher import CipherInfo, CipherMode, create_cipher_with_header, encrypt
from veturboio.ops.io_utils import IOHelper
from veturboio.ops.io_utils import save_file as fast_save_file
from veturboio.types import FILE_PATH from veturboio.types import FILE_PATH
...@@ -39,17 +41,30 @@ class BaseSaver: ...@@ -39,17 +41,30 @@ class BaseSaver:
class PosixSaver(BaseSaver): class PosixSaver(BaseSaver):
def __init__(self, file: FILE_PATH, use_cipher: bool = False) -> None: def __init__(self, file: FILE_PATH, helper: IOHelper = None, use_cipher: bool = False) -> None:
super().__init__(method="posix") super().__init__(method="posix")
self.file = file self.file = file
use_cipher = use_cipher or os.getenv("VETURBOIO_USE_CIPHER", "0") == "1" use_cipher = use_cipher or os.getenv("VETURBOIO_USE_CIPHER", "0") == "1"
use_header = use_cipher and os.getenv("VETURBOIO_CIPHER_HEADER", "0") == "1" use_header = use_cipher and os.getenv("VETURBOIO_CIPHER_HEADER", "0") == "1"
if use_header: if use_header:
self.cipher_info = create_cipher_with_header(CipherMode.CTR_128) self.cipher_info = create_cipher_with_header(CipherMode.CTR_128, os.path.abspath(self.file))
else: else:
self.cipher_info = CipherInfo(use_cipher) self.cipher_info = CipherInfo(use_cipher, None, os.path.abspath(self.file))
def save_file(self, state_dict: Dict[str, torch.Tensor], metadata: Dict[str, str] = None) -> None: self.helper = helper
def save_file(
self, state_dict: Dict[str, torch.Tensor], metadata: Dict[str, str] = None, enable_fast_mode: bool = False
) -> None:
if enable_fast_mode:
fast_save_file(
state_dict,
self.file,
helper=self.helper,
metadata=metadata,
cipher_info=self.cipher_info,
)
else:
if self.cipher_info.use_cipher: if self.cipher_info.use_cipher:
with tempfile.NamedTemporaryFile(dir="/dev/shm") as tmpfile: with tempfile.NamedTemporaryFile(dir="/dev/shm") as tmpfile:
tmp_file_path = tmpfile.name tmp_file_path = tmpfile.name
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment