Commit 66b20b34 authored by 胡腾's avatar 胡腾 Committed by huteng.ht
Browse files

feat(sfcs): reduce memcpy

* feat(sfcs): reduce memcpy
* chore: rename load_to_bytes_array to load_to_bytes
parent b1809ef9
......@@ -34,9 +34,9 @@ class BaseLoader:
def __init__(self, method: str) -> None:
self.method = method
def load_to_bytes_array(
def load_to_bytes(
self, file: FILE_PATH, offset: int, count: int, cipher_info: CipherInfo = CipherInfo(False)
) -> ndarray:
) -> bytes:
raise NotImplementedError
def load_safetensors(self, safetensors_file: Any, map_location: str = "cpu") -> Dict[str, torch.Tensor]:
......@@ -68,14 +68,14 @@ class PosixLoader(BaseLoader):
def __init__(self) -> None:
super().__init__(method="posix")
def load_to_bytes_array(
def load_to_bytes(
self, file: FILE_PATH, offset: int, count: int, cipher_info: CipherInfo = CipherInfo(False)
) -> ndarray:
) -> bytes:
arr = np.fromfile(file, dtype=np.uint8, offset=offset, count=count)
if cipher_info.use_cipher:
h_off = CipherInfo.HEADER_SIZE if cipher_info.use_header else 0
decrypt(cipher_info, arr, arr, offset - h_off)
return arr
return arr.tobytes()
def load_safetensors(self, safetensors_file: Any, map_location: str = "cpu") -> Dict[str, torch.Tensor]:
state_dict = {}
......
......@@ -47,17 +47,19 @@ class SfcsClientLoader(BaseLoader):
init_sfcs_conf()
def load_to_bytes_array(
def load_to_bytes(
self, file: FILE_PATH, offset: int, count: int, cipher_info: CipherInfo = CipherInfo(False)
) -> ndarray:
) -> bytes:
file_size = sfcs_get_file_size(file)
if offset + count > file_size:
count = file_size - offset
candidate = np.empty([count], dtype=np.byte)
file_bytes = bytes(count)
candidate = np.frombuffer(file_bytes, dtype=np.byte)
sfcs_read_file(
file, candidate, length=count, offset=offset, num_thread=self.num_thread, cipher_info=cipher_info
)
return candidate
return file_bytes
def load_safetensors(
self, safetensors_file: SafetensorsFile, map_location: str = "cpu"
......@@ -91,5 +93,5 @@ class SfcsClientLoader(BaseLoader):
) -> Dict[str, torch.Tensor]:
file_size = sfcs_get_file_size(file)
h_off = CipherInfo.HEADER_SIZE if cipher_info.use_header else 0
file_bytes = self.load_to_bytes_array(file, offset=h_off, count=file_size - h_off, cipher_info=cipher_info)
return torch.load(BytesIO(file_bytes.data), map_location=map_location)
file_bytes = self.load_to_bytes(file, offset=h_off, count=file_size - h_off, cipher_info=cipher_info)
return torch.load(BytesIO(file_bytes), map_location=map_location)
......@@ -106,7 +106,7 @@ class SafetensorsFile:
# cipher related
self._cipher_info = CipherInfo(False)
if use_cipher or os.getenv("VETURBOIO_USE_CIPHER", "0") == "1":
header_bytes = loader.load_to_bytes_array(file, offset=0, count=CipherInfo.HEADER_SIZE).tobytes()
header_bytes = loader.load_to_bytes(file, offset=0, count=CipherInfo.HEADER_SIZE)
self._cipher_info = CipherInfo(True, header_bytes)
if self._cipher_info.use_header:
......@@ -114,18 +114,16 @@ class SafetensorsFile:
else:
h_off = 0
magic_number = loader.load_to_bytes_array(file, offset=8 + h_off, count=1, cipher_info=self._cipher_info)[0]
magic_number = loader.load_to_bytes(file, offset=8 + h_off, count=1, cipher_info=self._cipher_info)[0]
if magic_number != SAFETENSORS_FILE_MAGIC_NUM:
self._is_valid = False
return
self._meta_size = np.frombuffer(
loader.load_to_bytes_array(file, offset=h_off, count=8, cipher_info=self._cipher_info), dtype=np.int64
loader.load_to_bytes(file, offset=h_off, count=8, cipher_info=self._cipher_info), dtype=np.int64
)[0]
meta_bytes = loader.load_to_bytes_array(
file, offset=8 + h_off, count=self._meta_size, cipher_info=self._cipher_info
)
meta_dict = json.loads(meta_bytes.tobytes().decode("utf-8"))
meta_bytes = loader.load_to_bytes(file, offset=8 + h_off, count=self._meta_size, cipher_info=self._cipher_info)
meta_dict = json.loads(meta_bytes.decode("utf-8"))
self._shared_tensor = {}
self._ignored_meta = {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment