You need to sign in or sign up before continuing.
Commit e5edc542, authored by 刘乐典 and committed by huteng.ht
Browse files

feat(security): compat with cipher header and use cipher in posix

* feat(security): compat with cipher header and use cipher in posix
* feat(security): compat with cipher header and use cipher in posix
parent 0cedada8
......@@ -14,9 +14,11 @@
* limitations under the License.
*/
#include "include/load_utils.h"
#include "include/cipher.h"
#include "include/fastcrypto.h"
void read_file_thread_fread(int thread_id, string file_path, char *addr, char *dev_mem, size_t block_size,
size_t total_size, size_t global_offset, bool use_direct_io)
size_t total_size, size_t global_offset, bool use_direct_io, CipherInfo cipher_info)
{
size_t offset = thread_id * block_size;
size_t read_size = block_size;
......@@ -39,6 +41,19 @@ void read_file_thread_fread(int thread_id, string file_path, char *addr, char *d
fread(addr + offset, 1, read_size, fp);
fclose(fp);
// Decrypt if use_cipher is true
if (cipher_info.use_cipher)
{
CtrDecrypter dec(cipher_info.mode, cipher_info.key, cipher_info.iv,
global_offset + offset - cipher_info.header_size);
unsigned char *ct = reinterpret_cast<unsigned char *>(addr + offset);
int cipher_ret = dec.decrypt_update(ct, read_size, ct);
if (!cipher_ret)
{
throw std::runtime_error("Cipher Exception: decrypt fail");
}
}
if (dev_mem != NULL)
cudaMemcpyAsync(dev_mem + offset, addr + offset, read_size, cudaMemcpyHostToDevice);
}
......@@ -69,7 +84,7 @@ void read_file(string file_path, char *addr, char *dev_mem, int num_thread, size
for (int thread_id = 0; thread_id < num_thread; thread_id++)
{
threads[thread_id] = std::thread(read_file_thread_fread, thread_id, file_path, addr, dev_mem, block_size,
total_size, global_offset, use_direct_io);
total_size, global_offset, use_direct_io, cipher_info);
}
for (int thread_id = 0; thread_id < num_thread; thread_id++)
......
......@@ -15,6 +15,7 @@
*/
#include "include/io_helper.h"
#include "include/sfcs.h"
#include "include/cipher.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
......@@ -22,9 +23,17 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
py::class_<SFCSFile>(m, "SFCSFile")
.def(py::init<std::string>())
.def(py::init<std::string, bool, pybind11::array_t<char>, pybind11::array_t<char>>())
.def(py::init<std::string, bool, pybind11::array_t<char>, pybind11::array_t<char>, size_t>())
.def("get_file_size", &SFCSFile::get_file_size)
.def("read_file_to_array", &SFCSFile::read_file_to_array)
.def("write_file_from_array", &SFCSFile::write_file_from_array)
.def("delete_file", &SFCSFile::delete_file);
}
\ No newline at end of file
py::class_<CtrEncWrap>(m, "CtrEncWrap")
.def(py::init<std::string, pybind11::array_t<unsigned char>, pybind11::array_t<unsigned char>, size_t>())
.def("encrypt_update", &CtrEncWrap::encrypt_update);
py::class_<CtrDecWrap>(m, "CtrDecWrap")
.def(py::init<std::string, pybind11::array_t<unsigned char>, pybind11::array_t<unsigned char>, size_t>())
.def("decrypt_update", &CtrDecWrap::decrypt_update);
}
......@@ -14,6 +14,7 @@
* limitations under the License.
*/
#include "include/sfcs.h"
#include "include/cipher.h"
#include "include/fastcrypto.h"
SFCSFile::SFCSFile(std::string path)
......@@ -47,10 +48,10 @@ SFCSFile::SFCSFile(std::string file_path, CipherInfo cipher_info) : SFCSFile(fil
}
SFCSFile::SFCSFile(std::string file_path, bool use_cipher, pybind11::array_t<char> key_arr,
pybind11::array_t<char> iv_arr)
pybind11::array_t<char> iv_arr, size_t header_size)
: SFCSFile(file_path)
{
this->cipher_info = CipherInfo(use_cipher, key_arr, iv_arr);
this->cipher_info = CipherInfo(use_cipher, key_arr, iv_arr, header_size);
}
SFCSFile::~SFCSFile()
......@@ -118,7 +119,7 @@ size_t SFCSFile::read_file(char *addr, size_t length, size_t offset)
// Decrypt if use_cipher is true
if (cipher_info.use_cipher)
{
CtrDecrypter dec(cipher_info.key, cipher_info.iv, offset);
CtrDecrypter dec(cipher_info.mode, cipher_info.key, cipher_info.iv, offset - cipher_info.header_size);
unsigned char *ct = reinterpret_cast<unsigned char *>(addr);
int cipher_ret = dec.decrypt_update(ct, length - count, ct);
if (!cipher_ret)
......@@ -141,6 +142,7 @@ void SFCSFile::read_file_thread(int thread_id, char *addr, char *dev_mem, size_t
read_size = (total_size > offset) ? total_size - offset : 0;
}
// TODO: actual number of bytes read may be less than read_size
read_file(addr + offset, read_size, global_offset + offset);
if (dev_mem != NULL)
......@@ -191,9 +193,10 @@ size_t SFCSFile::write_file(char *addr, size_t length)
if (cipher_info.use_cipher)
{
CtrEncrypter enc(cipher_info.key, cipher_info.iv, 0);
size_t h_off = cipher_info.header_size;
CtrEncrypter enc(cipher_info.mode, cipher_info.key, cipher_info.iv, 0);
unsigned char *pt = reinterpret_cast<unsigned char *>(addr);
int cipher_ret = enc.encrypt_update(pt, length, pt);
int cipher_ret = enc.encrypt_update(pt + h_off, length - h_off, pt + h_off);
if (!cipher_ret)
{
throw std::runtime_error("Cipher Exception: encrypt fail");
......@@ -239,7 +242,6 @@ size_t SFCSFile::write_file_from_array(pybind11::array_t<char> arr, size_t lengt
void SFCSFile::delete_file()
{
int ret;
ret = cfsDelete(fs, file_path.c_str(), 1);
if (ret == -1)
{
......@@ -247,23 +249,3 @@ void SFCSFile::delete_file()
throw std::runtime_error("SFCS Exception: delete file");
}
}
// Builds a CipherInfo from a key array and an IV array.
//
// use_cipher: stored on the object; when false the arrays are not inspected
//             and key / iv are not assigned by this constructor.
// key_arr:    must contain exactly CTR_BLOCK_SIZE elements.
// iv_arr:     must contain exactly CTR_BLOCK_SIZE elements.
//
// Throws std::runtime_error when either array has a different size.
//
// NOTE(review): key and iv are set to the arrays' buffer pointers; nothing is
// copied here. Presumably the caller has to keep both arrays alive for as long
// as this CipherInfo is in use -- confirm against the class declaration.
CipherInfo::CipherInfo(bool use_cipher, pybind11::array_t<char> key_arr, pybind11::array_t<char> iv_arr)
{
    this->use_cipher = use_cipher;
    if (!use_cipher)
    {
        // Cipher disabled: nothing to validate.
        return;
    }

    // Key first: check the element count, then keep the buffer pointer.
    pybind11::buffer_info key_buf = key_arr.request();
    if (static_cast<size_t>(key_buf.size) != CTR_BLOCK_SIZE)
    {
        throw std::runtime_error("Cipher Exception: key length invalid");
    }
    key = reinterpret_cast<unsigned char *>(key_buf.ptr);

    // Same check for the IV.
    pybind11::buffer_info iv_buf = iv_arr.request();
    if (static_cast<size_t>(iv_buf.size) != CTR_BLOCK_SIZE)
    {
        throw std::runtime_error("Cipher Exception: iv length invalid");
    }
    iv = reinterpret_cast<unsigned char *>(iv_buf.ptr);
}
\ No newline at end of file
......@@ -57,6 +57,7 @@ def load_file_to_tensor(
cipher_info.use_cipher,
cipher_info.key,
cipher_info.iv,
CipherInfo.HEADER_SIZE if cipher_info.use_header else 0,
)
......
......@@ -91,11 +91,9 @@ class CredentialsHelper:
self.thread = threading.Thread(target=self.refresh_loop, daemon=True)
self.stop_flag = threading.Event()
self.client = DataPipeClient()
if not self.client.session:
raise RuntimeError('Datapipe client initialization failed in credentials helper')
self.sfcs_conf_path = sfcs_conf_path
if not self.do_refresh():
raise RuntimeError('Credentials helper do refresh failed')
raise RuntimeError('Credentials helper first fetch failed')
self.thread.start()
self.running = True
logger.info('CredentialsHelper refresh thread strat')
......@@ -240,12 +238,24 @@ def sfcs_read_file(
num_thread: Optional[int] = 1,
cipher_info: CipherInfo = CipherInfo(False),
) -> int:
sfcs_file = SFCSFile(file_path, cipher_info.use_cipher, cipher_info.key, cipher_info.iv)
sfcs_file = SFCSFile(
file_path,
cipher_info.use_cipher,
cipher_info.key,
cipher_info.iv,
CipherInfo.HEADER_SIZE if cipher_info.use_header else 0,
)
return sfcs_file.read_file_to_array(arr, length, offset, num_thread)
def sfcs_write_file(file_path: str, arr: np.ndarray, length: int, cipher_info: CipherInfo = CipherInfo(False)) -> int:
sfcs_file = SFCSFile(file_path, cipher_info.use_cipher, cipher_info.key, cipher_info.iv)
sfcs_file = SFCSFile(
file_path,
cipher_info.use_cipher,
cipher_info.key,
cipher_info.iv,
CipherInfo.HEADER_SIZE if cipher_info.use_header else 0,
)
return sfcs_file.write_file_from_array(arr, length)
......
......@@ -15,6 +15,7 @@ limitations under the License.
'''
import json
import os
import pprint
from typing import Callable, Dict, List
......@@ -23,6 +24,7 @@ import torch
from loguru import logger
from veturboio.loader import BaseLoader
from veturboio.ops.cipher import CipherInfo
from veturboio.types import FILE_PATH
# All safetensors file will start with a json string, which is the meta info of the file.
......@@ -95,18 +97,34 @@ class TensorMeta:
class SafetensorsFile:
def __init__(self, file: FILE_PATH, loader: BaseLoader) -> None:
def __init__(self, file: FILE_PATH, loader: BaseLoader, use_cipher: bool = False) -> None:
self._file = file
self._loader = loader
self._is_valid = True
magic_number = loader.load_to_bytes_array(file, offset=8, count=1)[0]
# cipher related
self._cipher_info = CipherInfo(False)
if use_cipher or os.getenv("VETURBOIO_USE_CIPHER", "0") == "1":
header_bytes = loader.load_to_bytes_array(file, offset=0, count=CipherInfo.HEADER_SIZE).tobytes()
self._cipher_info = CipherInfo(True, header_bytes)
if self._cipher_info.use_header:
h_off = CipherInfo.HEADER_SIZE
else:
h_off = 0
magic_number = loader.load_to_bytes_array(file, offset=8 + h_off, count=1, cipher_info=self._cipher_info)[0]
if magic_number != SAFETENSORS_FILE_MAGIC_NUM:
self._is_valid = False
return
self._meta_size = np.frombuffer(loader.load_to_bytes_array(file, offset=0, count=8), dtype=np.int64)[0]
meta_bytes = loader.load_to_bytes_array(file, offset=8, count=self._meta_size)
self._meta_size = np.frombuffer(
loader.load_to_bytes_array(file, offset=h_off, count=8, cipher_info=self._cipher_info), dtype=np.int64
)[0]
meta_bytes = loader.load_to_bytes_array(
file, offset=8 + h_off, count=self._meta_size, cipher_info=self._cipher_info
)
meta_dict = json.loads(meta_bytes.tobytes().decode("utf-8"))
self._shared_tensor = {}
......@@ -129,7 +147,7 @@ class SafetensorsFile:
)
# record the offset of the tensor data
self._tensor_offset = np.dtype(np.int64).itemsize + self._meta_size
self._tensor_offset = np.dtype(np.int64).itemsize + self._meta_size + h_off
@staticmethod
def split_tensor_to_state_dict(
......@@ -192,6 +210,6 @@ class SafetensorsFile:
def load(self, map_location: str = "cpu") -> Dict[str, torch.Tensor]:
if not self._is_valid:
return self._loader.load_pt(self.file, map_location)
return self._loader.load_pt(self.file, map_location, self._cipher_info)
else:
return self._loader.load_safetensors(self, map_location)
......@@ -14,12 +14,16 @@ See the License for the specific language governing permissions and
limitations under the License.
'''
import os
import tempfile
from typing import Any, Dict
import numpy as np
import torch
from safetensors.torch import save_file as safetenors_save_file
from safetensors.torch import save_model as safetensors_save_model
from veturboio.ops.cipher import CipherInfo, CipherMode, create_cipher_with_header, encrypt
from veturboio.types import FILE_PATH
......@@ -35,14 +39,59 @@ class BaseSaver:
class PosixSaver(BaseSaver):
def __init__(self) -> None:
def __init__(self, use_cipher: bool = False) -> None:
super().__init__(method="posix")
use_cipher = use_cipher or os.getenv("VETURBOIO_USE_CIPHER", "0") == "1"
use_header = use_cipher and os.getenv("VETURBOIO_CIPHER_HEADER", "0") == "1"
if use_header:
self.cipher_info = create_cipher_with_header(CipherMode.CTR_128)
else:
self.cipher_info = CipherInfo(use_cipher)
def save_file(self, state_dict: Dict[str, torch.Tensor], file: FILE_PATH, metadata: Dict[str, str] = None) -> None:
safetenors_save_file(state_dict, file, metadata=metadata)
if self.cipher_info.use_cipher:
with tempfile.NamedTemporaryFile(dir="/dev/shm") as tmpfile:
tmp_file_path = tmpfile.name
safetenors_save_file(state_dict, tmp_file_path, metadata=metadata)
tmp_file_size = os.path.getsize(tmp_file_path)
tmp_file_bytes = np.memmap(tmp_file_path, dtype=np.uint8, mode='r', shape=tmp_file_size)
h_off = CipherInfo.HEADER_SIZE if self.cipher_info.use_header else 0
file_bytes = np.memmap(file, dtype=np.uint8, mode='w+', shape=tmp_file_size + h_off)
encrypt(self.cipher_info, tmp_file_bytes, file_bytes[h_off:], 0)
if h_off:
file_bytes[:h_off] = np.frombuffer(self.cipher_info.to_header_bytes(), dtype=np.uint8)
file_bytes.flush()
else:
safetenors_save_file(state_dict, file, metadata=metadata)
def save_model(self, model: torch.nn.Module, file: FILE_PATH) -> None:
return safetensors_save_model(model, file)
if self.cipher_info.use_cipher:
with tempfile.NamedTemporaryFile(dir="/dev/shm") as tmpfile:
tmp_file_path = tmpfile.name
safetensors_save_model(model, tmp_file_path)
tmp_file_size = os.path.getsize(tmp_file_path)
tmp_file_bytes = np.memmap(tmp_file_path, dtype=np.uint8, mode='r', shape=tmp_file_size)
h_off = CipherInfo.HEADER_SIZE if self.cipher_info.use_header else 0
file_bytes = np.memmap(file, dtype=np.uint8, mode='w+', shape=tmp_file_size + h_off)
encrypt(self.cipher_info, tmp_file_bytes, file_bytes[h_off:], 0)
if h_off:
file_bytes[:h_off] = np.frombuffer(self.cipher_info.to_header_bytes(), dtype=np.uint8)
file_bytes.flush()
else:
safetensors_save_model(model, file)
def save_pt(self, state_dict: Dict[str, torch.Tensor], file: FILE_PATH) -> None:
return torch.save(state_dict, file)
if self.cipher_info.use_cipher:
with tempfile.NamedTemporaryFile(dir="/dev/shm") as tmpfile:
tmp_file_path = tmpfile.name
torch.save(state_dict, tmp_file_path)
tmp_file_size = os.path.getsize(tmp_file_path)
tmp_file_bytes = np.memmap(tmp_file_path, dtype=np.uint8, mode='r', shape=tmp_file_size)
h_off = CipherInfo.HEADER_SIZE if self.cipher_info.use_header else 0
file_bytes = np.memmap(file, dtype=np.uint8, mode='w+', shape=tmp_file_size + h_off)
encrypt(self.cipher_info, tmp_file_bytes, file_bytes[h_off:], 0)
if h_off:
file_bytes[:h_off] = np.frombuffer(self.cipher_info.to_header_bytes(), dtype=np.uint8)
file_bytes.flush()
else:
torch.save(state_dict, file)
......@@ -23,7 +23,7 @@ import torch
from safetensors.torch import save_file as safetenors_save_file
from safetensors.torch import save_model as safetensors_save_model
from veturboio.ops.cipher import CipherInfo
from veturboio.ops.cipher import CipherInfo, CipherMode, create_cipher_with_header
from veturboio.ops.sfcs_utils import init_sfcs_conf, sfcs_get_file_size, sfcs_write_file
from veturboio.saver.base_saver import BaseSaver
from veturboio.types import FILE_PATH
......@@ -35,8 +35,12 @@ class SfcsClientSaver(BaseSaver):
init_sfcs_conf()
use_cipher = use_cipher or os.environ.get("VETURBOIO_USE_CIPHER", "0") == "1"
self.cipher_info = CipherInfo(use_cipher)
use_cipher = use_cipher or os.getenv("VETURBOIO_USE_CIPHER", "0") == "1"
use_header = use_cipher and os.getenv("VETURBOIO_CIPHER_HEADER", "0") == "1"
if use_header:
self.cipher_info = create_cipher_with_header(CipherMode.CTR_128)
else:
self.cipher_info = CipherInfo(use_cipher)
def save_file(self, state_dict: Dict[str, torch.Tensor], file: FILE_PATH, metadata: Dict[str, str] = None) -> None:
with tempfile.NamedTemporaryFile(dir="/dev/shm") as tmpfile:
......@@ -44,9 +48,14 @@ class SfcsClientSaver(BaseSaver):
safetenors_save_file(state_dict, file_path, metadata=metadata)
file_size = os.path.getsize(file_path)
file_bytes = np.memmap(file_path, dtype=np.byte, mode='r+', shape=file_size)
sfcs_write_file(file, file_bytes, file_size, self.cipher_info)
if self.cipher_info.use_header:
h_off = CipherInfo.HEADER_SIZE
file_bytes = np.empty(file_size + h_off, dtype=np.byte)
file_bytes[:h_off] = np.frombuffer(self.cipher_info.to_header_bytes(), dtype=np.byte)
file_bytes[h_off:] = np.fromfile(file_path, dtype=np.byte, count=file_size)
else:
file_bytes = np.memmap(file_path, dtype=np.byte, mode='r+', shape=file_size)
sfcs_write_file(file, file_bytes, len(file_bytes), self.cipher_info)
def save_model(self, model: torch.nn.Module, file: FILE_PATH) -> None:
with tempfile.NamedTemporaryFile(dir="/dev/shm") as tmpfile:
......@@ -54,9 +63,14 @@ class SfcsClientSaver(BaseSaver):
safetensors_save_model(model, file_path)
file_size = os.path.getsize(file_path)
file_bytes = np.memmap(file_path, dtype=np.byte, mode='r+', shape=file_size)
sfcs_write_file(file, file_bytes, file_size, self.cipher_info)
if self.cipher_info.use_header:
h_off = CipherInfo.HEADER_SIZE
file_bytes = np.empty(file_size + h_off, dtype=np.byte)
file_bytes[:h_off] = np.frombuffer(self.cipher_info.to_header_bytes(), dtype=np.byte)
file_bytes[h_off:] = np.fromfile(file_path, dtype=np.byte, count=file_size)
else:
file_bytes = np.memmap(file_path, dtype=np.byte, mode='r+', shape=file_size)
sfcs_write_file(file, file_bytes, len(file_bytes), self.cipher_info)
def save_pt(self, state_dict: Dict[str, torch.Tensor], file: FILE_PATH) -> None:
with tempfile.NamedTemporaryFile(dir="/dev/shm") as tmpfile:
......@@ -64,6 +78,11 @@ class SfcsClientSaver(BaseSaver):
torch.save(state_dict, file_path)
file_size = os.path.getsize(file_path)
file_bytes = np.memmap(file_path, dtype=np.byte, mode='r+', shape=file_size)
sfcs_write_file(file, file_bytes, file_size, self.cipher_info)
if self.cipher_info.use_header:
h_off = CipherInfo.HEADER_SIZE
file_bytes = np.empty(file_size + h_off, dtype=np.byte)
file_bytes[:h_off] = np.frombuffer(self.cipher_info.to_header_bytes(), dtype=np.byte)
file_bytes[h_off:] = np.fromfile(file_path, dtype=np.byte, count=file_size)
else:
file_bytes = np.memmap(file_path, dtype=np.byte, mode='r+', shape=file_size)
sfcs_write_file(file, file_bytes, len(file_bytes), self.cipher_info)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment