Commit 0ed05516 authored by huteng.ht
Browse files

feat: upgrade to sdk v1 latest version



* 70b2701 on master
Signed-off-by: huteng.ht <huteng.ht@bytedance.com>
parent 61d052cb
Pipeline #2963 canceled with stages
...@@ -24,39 +24,67 @@ from safetensors.torch import save_file as safetenors_save_file ...@@ -24,39 +24,67 @@ from safetensors.torch import save_file as safetenors_save_file
from safetensors.torch import save_model as safetensors_save_model from safetensors.torch import save_model as safetensors_save_model
from veturboio.ops.cipher import CipherInfo, CipherMode, create_cipher_with_header from veturboio.ops.cipher import CipherInfo, CipherMode, create_cipher_with_header
from veturboio.ops.sfcs_utils import init_sfcs_conf, sfcs_get_file_size, sfcs_write_file from veturboio.ops.io_utils import IOHelper
from veturboio.ops.io_utils import save_file as fast_save_file
from veturboio.ops.sfcs_utils import (
init_sfcs_conf,
path_mapper,
sfcs_delete_file,
sfcs_write_file,
sfcs_write_file_in_parallel,
)
from veturboio.saver.base_saver import BaseSaver from veturboio.saver.base_saver import BaseSaver
from veturboio.types import FILE_PATH from veturboio.types import FILE_PATH
class SfcsClientSaver(BaseSaver): class SfcsClientSaver(BaseSaver):
def __init__(self, file: FILE_PATH, use_cipher: bool = False) -> None: def __init__(
self,
file: FILE_PATH,
helper: IOHelper = None,
use_cipher: bool = False,
) -> None:
super().__init__(method="client") super().__init__(method="client")
self.file = file self.file = file
init_sfcs_conf(file) self.helper = helper
mount_path = init_sfcs_conf(file)
self.sfcs_valid_path = path_mapper(self.file, mount_path)
use_cipher = use_cipher or os.getenv("VETURBOIO_USE_CIPHER", "0") == "1" use_cipher = use_cipher or os.getenv("VETURBOIO_USE_CIPHER", "0") == "1"
use_header = use_cipher and os.getenv("VETURBOIO_CIPHER_HEADER", "0") == "1" use_header = use_cipher and os.getenv("VETURBOIO_CIPHER_HEADER", "0") == "1"
if use_header: if use_header:
self.cipher_info = create_cipher_with_header(CipherMode.CTR_128) self.cipher_info = create_cipher_with_header(CipherMode.CTR_128, os.path.abspath(self.file))
else: else:
self.cipher_info = CipherInfo(use_cipher) self.cipher_info = CipherInfo(use_cipher, None, os.path.abspath(self.file))
def save_file(self, state_dict: Dict[str, torch.Tensor], metadata: Dict[str, str] = None) -> None: def save_file(
with tempfile.NamedTemporaryFile(dir="/dev/shm") as tmpfile: self, state_dict: Dict[str, torch.Tensor], metadata: Dict[str, str] = None, enable_fast_mode: bool = False
file_path = tmpfile.name ) -> None:
safetenors_save_file(state_dict, file_path, metadata=metadata) if enable_fast_mode:
fast_save_file(
file_size = os.path.getsize(file_path) state_dict,
if self.cipher_info.use_header: self.sfcs_valid_path,
h_off = CipherInfo.HEADER_SIZE helper=self.helper,
file_bytes = np.empty(file_size + h_off, dtype=np.byte) metadata=metadata,
file_bytes[:h_off] = np.frombuffer(self.cipher_info.to_header_bytes(), dtype=np.byte) cipher_info=self.cipher_info,
file_bytes[h_off:] = np.fromfile(file_path, dtype=np.byte, count=file_size) use_sfcs_sdk=True,
else: )
file_bytes = np.memmap(file_path, dtype=np.byte, mode='r+', shape=file_size) else:
sfcs_write_file(self.file, file_bytes, len(file_bytes), self.cipher_info) with tempfile.NamedTemporaryFile(dir="/dev/shm") as tmpfile:
file_path = tmpfile.name
safetenors_save_file(state_dict, file_path, metadata=metadata)
file_size = os.path.getsize(file_path)
if self.cipher_info.use_header:
h_off = CipherInfo.HEADER_SIZE
file_bytes = np.empty(file_size + h_off, dtype=np.byte)
file_bytes[:h_off] = np.frombuffer(self.cipher_info.to_header_bytes(), dtype=np.byte)
file_bytes[h_off:] = np.fromfile(file_path, dtype=np.byte, count=file_size)
else:
file_bytes = np.memmap(file_path, dtype=np.byte, mode='r+', shape=file_size)
sfcs_write_file(self.sfcs_valid_path, file_bytes, len(file_bytes), self.cipher_info)
def save_model(self, model: torch.nn.Module) -> None: def save_model(self, model: torch.nn.Module) -> None:
with tempfile.NamedTemporaryFile(dir="/dev/shm") as tmpfile: with tempfile.NamedTemporaryFile(dir="/dev/shm") as tmpfile:
...@@ -71,7 +99,7 @@ class SfcsClientSaver(BaseSaver): ...@@ -71,7 +99,7 @@ class SfcsClientSaver(BaseSaver):
file_bytes[h_off:] = np.fromfile(file_path, dtype=np.byte, count=file_size) file_bytes[h_off:] = np.fromfile(file_path, dtype=np.byte, count=file_size)
else: else:
file_bytes = np.memmap(file_path, dtype=np.byte, mode='r+', shape=file_size) file_bytes = np.memmap(file_path, dtype=np.byte, mode='r+', shape=file_size)
sfcs_write_file(self.file, file_bytes, len(file_bytes), self.cipher_info) sfcs_write_file(self.sfcs_valid_path, file_bytes, len(file_bytes), self.cipher_info)
def save_pt(self, state_dict: Dict[str, torch.Tensor]) -> None: def save_pt(self, state_dict: Dict[str, torch.Tensor]) -> None:
with tempfile.NamedTemporaryFile(dir="/dev/shm") as tmpfile: with tempfile.NamedTemporaryFile(dir="/dev/shm") as tmpfile:
...@@ -86,4 +114,4 @@ class SfcsClientSaver(BaseSaver): ...@@ -86,4 +114,4 @@ class SfcsClientSaver(BaseSaver):
file_bytes[h_off:] = np.fromfile(file_path, dtype=np.byte, count=file_size) file_bytes[h_off:] = np.fromfile(file_path, dtype=np.byte, count=file_size)
else: else:
file_bytes = np.memmap(file_path, dtype=np.byte, mode='r+', shape=file_size) file_bytes = np.memmap(file_path, dtype=np.byte, mode='r+', shape=file_size)
sfcs_write_file(self.file, file_bytes, len(file_bytes), self.cipher_info) sfcs_write_file(self.sfcs_valid_path, file_bytes, len(file_bytes), self.cipher_info)
...@@ -18,8 +18,8 @@ import os ...@@ -18,8 +18,8 @@ import os
from loguru import logger from loguru import logger
LIBCFS_DEFAULT_URL = "https://veturbo-cn-beijing.tos-cn-beijing.volces.com/veturboio/libcfs/libcfs.so" LIBCFS_DEFAULT_URL = "https://veturbo-cn-beijing.tos-cn-beijing.volces.com/veturboio/libcfs/libcloudfs.so"
LIBCFS_DEFAULT_PATH = "/usr/lib/libcfs.so" LIBCFS_DEFAULT_PATH = "/usr/lib/libcloudfs.so"
def load_libcfs(): def load_libcfs():
...@@ -29,7 +29,7 @@ def load_libcfs(): ...@@ -29,7 +29,7 @@ def load_libcfs():
import requests import requests
libcfs_url = os.getenv("LIBCFS_URL", LIBCFS_DEFAULT_URL) libcfs_url = os.getenv("LIBCFS_URL", LIBCFS_DEFAULT_URL)
logger.info(f"download libcfs.so from {libcfs_url}, save to {libcfs_path}") logger.info(f"download libcloudfs.so from {libcfs_url}, save to {libcfs_path}")
r = requests.get(libcfs_url, timeout=60) r = requests.get(libcfs_url, timeout=60)
with open(libcfs_path, 'wb') as f: with open(libcfs_path, 'wb') as f:
f.write(r.content) f.write(r.content)
......
...@@ -14,4 +14,4 @@ See the License for the specific language governing permissions and ...@@ -14,4 +14,4 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
''' '''
__version__ = "0.1.3" __version__ = "0.1.3rc4"
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment