Commit 66b20b34 authored by 胡腾, committed by huteng.ht
Browse files

feat(sfcs): reduce memcpy

* feat(sfcs): reduce memcpy
* chore: rename load_to_bytes_array to load_to_bytes
parent b1809ef9
...@@ -34,9 +34,9 @@ class BaseLoader: ...@@ -34,9 +34,9 @@ class BaseLoader:
def __init__(self, method: str) -> None: def __init__(self, method: str) -> None:
self.method = method self.method = method
def load_to_bytes_array( def load_to_bytes(
self, file: FILE_PATH, offset: int, count: int, cipher_info: CipherInfo = CipherInfo(False) self, file: FILE_PATH, offset: int, count: int, cipher_info: CipherInfo = CipherInfo(False)
) -> ndarray: ) -> bytes:
raise NotImplementedError raise NotImplementedError
def load_safetensors(self, safetensors_file: Any, map_location: str = "cpu") -> Dict[str, torch.Tensor]: def load_safetensors(self, safetensors_file: Any, map_location: str = "cpu") -> Dict[str, torch.Tensor]:
...@@ -68,14 +68,14 @@ class PosixLoader(BaseLoader): ...@@ -68,14 +68,14 @@ class PosixLoader(BaseLoader):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__(method="posix") super().__init__(method="posix")
def load_to_bytes_array( def load_to_bytes(
self, file: FILE_PATH, offset: int, count: int, cipher_info: CipherInfo = CipherInfo(False) self, file: FILE_PATH, offset: int, count: int, cipher_info: CipherInfo = CipherInfo(False)
) -> ndarray: ) -> bytes:
arr = np.fromfile(file, dtype=np.uint8, offset=offset, count=count) arr = np.fromfile(file, dtype=np.uint8, offset=offset, count=count)
if cipher_info.use_cipher: if cipher_info.use_cipher:
h_off = CipherInfo.HEADER_SIZE if cipher_info.use_header else 0 h_off = CipherInfo.HEADER_SIZE if cipher_info.use_header else 0
decrypt(cipher_info, arr, arr, offset - h_off) decrypt(cipher_info, arr, arr, offset - h_off)
return arr return arr.tobytes()
def load_safetensors(self, safetensors_file: Any, map_location: str = "cpu") -> Dict[str, torch.Tensor]: def load_safetensors(self, safetensors_file: Any, map_location: str = "cpu") -> Dict[str, torch.Tensor]:
state_dict = {} state_dict = {}
......
...@@ -47,17 +47,19 @@ class SfcsClientLoader(BaseLoader): ...@@ -47,17 +47,19 @@ class SfcsClientLoader(BaseLoader):
init_sfcs_conf() init_sfcs_conf()
def load_to_bytes_array( def load_to_bytes(
self, file: FILE_PATH, offset: int, count: int, cipher_info: CipherInfo = CipherInfo(False) self, file: FILE_PATH, offset: int, count: int, cipher_info: CipherInfo = CipherInfo(False)
) -> ndarray: ) -> bytes:
file_size = sfcs_get_file_size(file) file_size = sfcs_get_file_size(file)
if offset + count > file_size: if offset + count > file_size:
count = file_size - offset count = file_size - offset
candidate = np.empty([count], dtype=np.byte)
file_bytes = bytes(count)
candidate = np.frombuffer(file_bytes, dtype=np.byte)
sfcs_read_file( sfcs_read_file(
file, candidate, length=count, offset=offset, num_thread=self.num_thread, cipher_info=cipher_info file, candidate, length=count, offset=offset, num_thread=self.num_thread, cipher_info=cipher_info
) )
return candidate return file_bytes
def load_safetensors( def load_safetensors(
self, safetensors_file: SafetensorsFile, map_location: str = "cpu" self, safetensors_file: SafetensorsFile, map_location: str = "cpu"
...@@ -91,5 +93,5 @@ class SfcsClientLoader(BaseLoader): ...@@ -91,5 +93,5 @@ class SfcsClientLoader(BaseLoader):
) -> Dict[str, torch.Tensor]: ) -> Dict[str, torch.Tensor]:
file_size = sfcs_get_file_size(file) file_size = sfcs_get_file_size(file)
h_off = CipherInfo.HEADER_SIZE if cipher_info.use_header else 0 h_off = CipherInfo.HEADER_SIZE if cipher_info.use_header else 0
file_bytes = self.load_to_bytes_array(file, offset=h_off, count=file_size - h_off, cipher_info=cipher_info) file_bytes = self.load_to_bytes(file, offset=h_off, count=file_size - h_off, cipher_info=cipher_info)
return torch.load(BytesIO(file_bytes.data), map_location=map_location) return torch.load(BytesIO(file_bytes), map_location=map_location)
...@@ -106,7 +106,7 @@ class SafetensorsFile: ...@@ -106,7 +106,7 @@ class SafetensorsFile:
# cipher related # cipher related
self._cipher_info = CipherInfo(False) self._cipher_info = CipherInfo(False)
if use_cipher or os.getenv("VETURBOIO_USE_CIPHER", "0") == "1": if use_cipher or os.getenv("VETURBOIO_USE_CIPHER", "0") == "1":
header_bytes = loader.load_to_bytes_array(file, offset=0, count=CipherInfo.HEADER_SIZE).tobytes() header_bytes = loader.load_to_bytes(file, offset=0, count=CipherInfo.HEADER_SIZE)
self._cipher_info = CipherInfo(True, header_bytes) self._cipher_info = CipherInfo(True, header_bytes)
if self._cipher_info.use_header: if self._cipher_info.use_header:
...@@ -114,18 +114,16 @@ class SafetensorsFile: ...@@ -114,18 +114,16 @@ class SafetensorsFile:
else: else:
h_off = 0 h_off = 0
magic_number = loader.load_to_bytes_array(file, offset=8 + h_off, count=1, cipher_info=self._cipher_info)[0] magic_number = loader.load_to_bytes(file, offset=8 + h_off, count=1, cipher_info=self._cipher_info)[0]
if magic_number != SAFETENSORS_FILE_MAGIC_NUM: if magic_number != SAFETENSORS_FILE_MAGIC_NUM:
self._is_valid = False self._is_valid = False
return return
self._meta_size = np.frombuffer( self._meta_size = np.frombuffer(
loader.load_to_bytes_array(file, offset=h_off, count=8, cipher_info=self._cipher_info), dtype=np.int64 loader.load_to_bytes(file, offset=h_off, count=8, cipher_info=self._cipher_info), dtype=np.int64
)[0] )[0]
meta_bytes = loader.load_to_bytes_array( meta_bytes = loader.load_to_bytes(file, offset=8 + h_off, count=self._meta_size, cipher_info=self._cipher_info)
file, offset=8 + h_off, count=self._meta_size, cipher_info=self._cipher_info meta_dict = json.loads(meta_bytes.decode("utf-8"))
)
meta_dict = json.loads(meta_bytes.tobytes().decode("utf-8"))
self._shared_tensor = {} self._shared_tensor = {}
self._ignored_meta = {} self._ignored_meta = {}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment