Commit e16439e5 authored by 胡腾's avatar 胡腾 Committed by huteng.ht
Browse files

feat(sfcs): parse sfcs confs from environ in json format

* feat(sfcs): parse sfcs confs from environ in json format
* fix(ut): init_sfcs_conf with filepath
* fix: multi sfcs conf
* fix: only pass file in saver loader init
parent 66b20b34
...@@ -86,8 +86,24 @@ def print_load_time(fs_name, tensor_size, load_times): ...@@ -86,8 +86,24 @@ def print_load_time(fs_name, tensor_size, load_times):
print(f"{fs_name:<10} {str(tensor_size):<15}", ' '.join(load_times)) print(f"{fs_name:<10} {str(tensor_size):<15}", ' '.join(load_times))
def sfcs_env():
    """Point the SFCS_* environment variables at the CI test filesystem.

    Reads the CI-provided credentials (CI_SFCS_AK / CI_SFCS_SK) and installs
    the full set of SFCS settings expected by the SFCS SDK.
    """
    ci_settings = {
        'SFCS_FSNAME': 'byted-cpu-sfcs',
        'SFCS_REGION': 'cn-beijing',
        'SFCS_ACCESS_KEY': os.environ['CI_SFCS_AK'],
        'SFCS_SECRET_KEY': os.environ['CI_SFCS_SK'],
        'SFCS_AUTHENTICATION_SERVICE_NAME': 'cfs',
        'SFCS_NS_ID': '18014398509481988',
        'SFCS_UFS_PATH': 'tos://yinzq-bucket/',
        'SFCS_MULTI_NIC_WHITELIST': 'eth0',
        'SFCS_NETWORK_SEGMENT': '172.31.128.0/17',
        'SFCS_NAMENODE_ENDPOINT_ADDRESS': '100.67.19.231',
        'SFCS_LOG_SEVERITY': 'ERROR',
    }
    os.environ.update(ci_settings)
def main(): def main():
args = parse_args() args = parse_args()
if args.base_dir.startswith('sfcs://'):
sfcs_env()
load_modes = args.load_mode.split(',') load_modes = args.load_mode.split(',')
# warmup GPU otherwise the first case would be slow # warmup GPU otherwise the first case would be slow
device = torch.device(args.map_location) device = torch.device(args.map_location)
......
...@@ -30,9 +30,9 @@ import numpy as np ...@@ -30,9 +30,9 @@ import numpy as np
from veturboio.ops.cipher import CipherInfo, DataPipeClient from veturboio.ops.cipher import CipherInfo, DataPipeClient
from veturboio.ops.sfcs_utils import ( from veturboio.ops.sfcs_utils import (
SFCS_OPT_ENV_LIST, SFCS_OPT_ENV_LIST,
SFCS_PROPERTIES,
SFCS_REQ_ENV_LIST, SFCS_REQ_ENV_LIST,
credentials_helper, credentials_helper,
generate_sfcs_conf_xml,
init_sfcs_conf, init_sfcs_conf,
) )
...@@ -233,60 +233,86 @@ class TestCredentials(TestCase): ...@@ -233,60 +233,86 @@ class TestCredentials(TestCase):
def test_sfcs_conf(self): def test_sfcs_conf(self):
# case 1: a xml file already exists, do nothing # case 1: a xml file already exists, do nothing
with tempfile.NamedTemporaryFile() as sfcs_conf: sfcs_conf = os.path.join(os.getcwd(), 'base_model.xml')
os.environ['LIBCFS_CONF'] = sfcs_conf.name generate_sfcs_conf_xml(sfcs_conf, {'test': 'test'})
init_sfcs_conf() init_sfcs_conf('/base_model/tensor.pt')
self.assertFalse(credentials_helper.running) self.assertEqual(os.environ['LIBCFS_CONF'], sfcs_conf)
self.assertEqual(len(credentials_helper.threads), 0)
self.assertEqual(len(credentials_helper.running), 0)
os.remove(sfcs_conf)
for e in SFCS_REQ_ENV_LIST: for e in SFCS_REQ_ENV_LIST:
os.environ[e] = 'test-value' os.environ[e] = 'test-value'
# case 2: env SFCS_ACCESS_KEY and SFCS_SECRET_KEY and SFCS_NAMENODE_ENDPOINT_ADDRESS exists # case 2: env SFCS_ACCESS_KEY and SFCS_SECRET_KEY and SFCS_NAMENODE_ENDPOINT_ADDRESS exists
with tempfile.TemporaryDirectory() as conf_dir: os.environ['SFCS_ACCESS_KEY'] = 'AKTPODg0MzV**2ZDcxMDg'
conf_path = os.path.join(conf_dir, 'libcfs.xml') os.environ['SFCS_SECRET_KEY'] = 'TVRNNVlqRmxPR1**mRoTkdWbE1ESQ=='
os.environ['LIBCFS_CONF'] = conf_path os.environ['SFCS_NAMENODE_ENDPOINT_ADDRESS'] = '100.67.19.231'
os.environ['SFCS_ACCESS_KEY'] = 'AKTPODg0MzV**2ZDcxMDg' sfcs_conf = os.path.join(os.getcwd(), 'base_model2.xml')
os.environ['SFCS_SECRET_KEY'] = 'TVRNNVlqRmxPR1**mRoTkdWbE1ESQ==' init_sfcs_conf('/base_model2/tensor.pt')
os.environ['SFCS_NAMENODE_ENDPOINT_ADDRESS'] = '100.67.19.231' self.assertEqual(os.environ['LIBCFS_CONF'], sfcs_conf)
init_sfcs_conf() self.assertEqual(len(credentials_helper.threads), 0)
self.assertEqual(SFCS_PROPERTIES['cfs.access.key'], 'AKTPODg0MzV**2ZDcxMDg') self.assertEqual(len(credentials_helper.running), 0)
self.assertEqual(SFCS_PROPERTIES['cfs.secret.key'], 'TVRNNVlqRmxPR1**mRoTkdWbE1ESQ==') self.assertTrue(os.path.exists(sfcs_conf))
self.assertEqual(SFCS_PROPERTIES['cfs.namenode.endpoint.address.test-value'], '100.67.19.231') os.remove(sfcs_conf)
self.assertFalse(credentials_helper.running)
self.assertTrue(os.path.exists(conf_path))
# case 3: use datapipe socket to get and refresh ak, sk, st and namenode_ip # case 3: use datapipe socket to get and refresh ak, sk, st and namenode_ip
DataPipeClient.DATAPIPE_SOCKET_PATH = self.server_address DataPipeClient.DATAPIPE_SOCKET_PATH = self.server_address
with tempfile.TemporaryDirectory() as conf_dir: os.environ.pop('SFCS_ACCESS_KEY', None)
conf_path = os.path.join(conf_dir, 'libcfs.xml') os.environ.pop('SFCS_SECRET_KEY', None)
os.environ['LIBCFS_CONF'] = conf_path os.environ.pop('SFCS_NAMENODE_ENDPOINT_ADDRESS', None)
os.environ.pop('SFCS_ACCESS_KEY', None) sfcs_conf3 = os.path.join(os.getcwd(), 'base_model3.xml')
os.environ.pop('SFCS_SECRET_KEY', None) sfcs_conf4 = os.path.join(os.getcwd(), 'base_model4.xml')
os.environ.pop('SFCS_NAMENODE_ENDPOINT_ADDRESS', None) init_sfcs_conf('/base_model3/tensor.pt')
SFCS_PROPERTIES.pop('cfs.access.key') init_sfcs_conf('/base_model4/tensor.pt')
SFCS_PROPERTIES.pop('cfs.secret.key') self.assertTrue('base_model3' in credentials_helper.threads)
SFCS_PROPERTIES.pop('cfs.namenode.endpoint.address.test-value') self.assertTrue('base_model4' in credentials_helper.threads)
init_sfcs_conf() self.assertTrue(credentials_helper.running['base_model3'])
self.assertEqual(SFCS_PROPERTIES['cfs.access.key'], 'AKTPODg0MzV**2ZDcxMDg') self.assertTrue(credentials_helper.running['base_model4'])
self.assertEqual(SFCS_PROPERTIES['cfs.secret.key'], 'TVRNNVlqRmxPR1**mRoTkdWbE1ESQ==') self.assertTrue(os.path.exists(sfcs_conf3))
self.assertEqual(SFCS_PROPERTIES['cfs.namenode.endpoint.address.test-value'], '100.67.19.231') self.assertTrue(os.path.exists(sfcs_conf4))
self.assertEqual(SFCS_PROPERTIES['cfs.security.token'], 'STSeyJBY2NvdW50SW**kXXXXXXX') os.remove(sfcs_conf3)
self.assertTrue(credentials_helper.running) os.remove(sfcs_conf4)
self.assertTrue(os.path.exists(conf_path)) sleep(3)
t1 = credentials_helper.current_time self.assertTrue(os.path.exists(sfcs_conf3))
sleep(3) self.assertTrue(os.path.exists(sfcs_conf4))
t2 = credentials_helper.current_time print(credentials_helper.threads)
self.assertTrue(t1 < t2)
credentials_helper.stop() def test_sfcs_conf_json(self):
for e in SFCS_REQ_ENV_LIST:
os.environ[e] = 'test-value'
os.environ['SFCS_FSNAME'] = json.dumps({'base_model1': 'test-value1', 'base_model2': 'test-value2'})
os.environ['SFCS_NS_ID'] = json.dumps({'base_model1': 'test-value1', 'base_model2': 'test-value2'})
os.environ['SFCS_UFS_PATH'] = json.dumps({'base_model1': 'test-value1', 'base_model2': 'test-value2'})
DataPipeClient.DATAPIPE_SOCKET_PATH = self.server_address
os.environ.pop('SFCS_ACCESS_KEY', None)
os.environ.pop('SFCS_SECRET_KEY', None)
os.environ.pop('SFCS_NAMENODE_ENDPOINT_ADDRESS', None)
sfcs_conf1 = os.path.join(os.getcwd(), 'base_model1.xml')
sfcs_conf2 = os.path.join(os.getcwd(), 'base_model2.xml')
init_sfcs_conf('/base_model1/tensor.pt')
init_sfcs_conf('/base_model2/tensor.pt')
self.assertTrue('base_model1' in credentials_helper.threads)
self.assertTrue('base_model2' in credentials_helper.threads)
self.assertTrue(credentials_helper.running['base_model1'])
self.assertTrue(credentials_helper.running['base_model2'])
self.assertTrue(os.path.exists(sfcs_conf1))
self.assertTrue(os.path.exists(sfcs_conf2))
os.remove(sfcs_conf1)
os.remove(sfcs_conf2)
sleep(3)
self.assertTrue(os.path.exists(sfcs_conf1))
self.assertTrue(os.path.exists(sfcs_conf2))
print(credentials_helper.threads)
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
credentials_helper.stop()
os.environ.pop('LIBCFS_CONF', None) os.environ.pop('LIBCFS_CONF', None)
for e in SFCS_REQ_ENV_LIST: for e in SFCS_REQ_ENV_LIST:
os.environ.pop(e, None) os.environ.pop(e, None)
for e in SFCS_OPT_ENV_LIST: for e in SFCS_OPT_ENV_LIST:
os.environ.pop(e, None) os.environ.pop(e, None)
SFCS_PROPERTIES.pop('cfs.security.token', None)
cls.server.shutdown() cls.server.shutdown()
cls.server.server_close() cls.server.server_close()
cls.thread.join() cls.thread.join()
......
...@@ -29,10 +29,6 @@ import veturboio.ops.sfcs_utils as sfcs_utils ...@@ -29,10 +29,6 @@ import veturboio.ops.sfcs_utils as sfcs_utils
def init_sfcs_env(): def init_sfcs_env():
sfcs_conf = os.getcwd() + '/libcfs.xml'
if os.path.exists(sfcs_conf):
os.remove(sfcs_conf)
os.environ['SFCS_FSNAME'] = 'byted-cpu-sfcs' os.environ['SFCS_FSNAME'] = 'byted-cpu-sfcs'
os.environ['SFCS_REGION'] = 'cn-beijing' os.environ['SFCS_REGION'] = 'cn-beijing'
os.environ['SFCS_ACCESS_KEY'] = os.environ['CI_SFCS_AK'] os.environ['SFCS_ACCESS_KEY'] = os.environ['CI_SFCS_AK']
...@@ -45,8 +41,6 @@ def init_sfcs_env(): ...@@ -45,8 +41,6 @@ def init_sfcs_env():
os.environ['SFCS_NAMENODE_ENDPOINT_ADDRESS'] = '100.67.19.231' os.environ['SFCS_NAMENODE_ENDPOINT_ADDRESS'] = '100.67.19.231'
os.environ['SFCS_LOG_SEVERITY'] = 'ERROR' os.environ['SFCS_LOG_SEVERITY'] = 'ERROR'
sfcs_utils.init_sfcs_conf()
class TestSFCS(TestCase): class TestSFCS(TestCase):
@classmethod @classmethod
...@@ -57,6 +51,12 @@ class TestSFCS(TestCase): ...@@ -57,6 +51,12 @@ class TestSFCS(TestCase):
filepath = "/data.bin" filepath = "/data.bin"
filesize = 1024 * 1024 filesize = 1024 * 1024
first_path = os.path.abspath(filepath).split("/")[1]
sfcs_conf = os.path.join(os.getcwd(), first_path + '.xml')
if os.path.exists(sfcs_conf):
os.remove(sfcs_conf)
sfcs_utils.init_sfcs_conf(filepath)
sfcs_utils.sfcs_delete_file(filepath) sfcs_utils.sfcs_delete_file(filepath)
arr_0 = np.empty([filesize], dtype=np.byte) arr_0 = np.empty([filesize], dtype=np.byte)
......
...@@ -77,16 +77,18 @@ def load( ...@@ -77,16 +77,18 @@ def load(
use_sfcs_sdk, file = is_sfcs_path(file) use_sfcs_sdk, file = is_sfcs_path(file)
if enable_fast_mode == False: if enable_fast_mode == False:
loader = PosixLoader() loader = PosixLoader(file)
elif use_sfcs_sdk: elif use_sfcs_sdk:
loader = SfcsClientLoader( loader = SfcsClientLoader(
helper, helper=helper,
file=file,
num_thread=num_thread, num_thread=num_thread,
use_pinmem=use_pinmem, use_pinmem=use_pinmem,
use_direct_io=use_direct_io, use_direct_io=use_direct_io,
) )
else: else:
loader = FasterPosixLoader( loader = FasterPosixLoader(
file,
helper, helper,
num_thread=num_thread, num_thread=num_thread,
use_pinmem=use_pinmem, use_pinmem=use_pinmem,
...@@ -126,14 +128,14 @@ def save_file( ...@@ -126,14 +128,14 @@ def save_file(
""" """
use_sfcs_sdk, file = is_sfcs_path(file) use_sfcs_sdk, file = is_sfcs_path(file)
if use_sfcs_sdk: if use_sfcs_sdk:
saver = SfcsClientSaver(use_cipher=use_cipher) saver = SfcsClientSaver(file=file, use_cipher=use_cipher)
else: else:
saver = PosixSaver(use_cipher=use_cipher) saver = PosixSaver(file=file, use_cipher=use_cipher)
# TODO: there are some bugs while state_dict is loaded from veturboio # TODO: there are some bugs while state_dict is loaded from veturboio
if not force_save_shared_tensor: if not force_save_shared_tensor:
try: try:
saver.save_file(state_dict, file, metadata=metadata) saver.save_file(state_dict, metadata=metadata)
except ValueError as e: except ValueError as e:
msg = str(e) msg = str(e)
raise ValueError(msg) raise ValueError(msg)
...@@ -154,7 +156,7 @@ def save_file( ...@@ -154,7 +156,7 @@ def save_file(
if force_contiguous: if force_contiguous:
state_dict = {k: v.contiguous() for k, v in state_dict.items()} state_dict = {k: v.contiguous() for k, v in state_dict.items()}
return saver.save_file(state_dict, file, metadata=metadata) return saver.save_file(state_dict, metadata=metadata)
def save_model(model: torch.nn.Module, file: FILE_PATH, use_cipher: Optional[bool] = False) -> None: def save_model(model: torch.nn.Module, file: FILE_PATH, use_cipher: Optional[bool] = False) -> None:
...@@ -177,11 +179,11 @@ def save_model(model: torch.nn.Module, file: FILE_PATH, use_cipher: Optional[boo ...@@ -177,11 +179,11 @@ def save_model(model: torch.nn.Module, file: FILE_PATH, use_cipher: Optional[boo
use_sfcs_sdk, file = is_sfcs_path(file) use_sfcs_sdk, file = is_sfcs_path(file)
if use_sfcs_sdk: if use_sfcs_sdk:
saver = SfcsClientSaver(use_cipher=use_cipher) saver = SfcsClientSaver(file=file, use_cipher=use_cipher)
else: else:
saver = PosixSaver(use_cipher=use_cipher) saver = PosixSaver(file=file, use_cipher=use_cipher)
return saver.save_model(model, file) return saver.save_model(model)
def save_pt(state_dict: Dict[str, torch.Tensor], file: FILE_PATH, use_cipher: Optional[bool] = False) -> None: def save_pt(state_dict: Dict[str, torch.Tensor], file: FILE_PATH, use_cipher: Optional[bool] = False) -> None:
...@@ -203,8 +205,8 @@ def save_pt(state_dict: Dict[str, torch.Tensor], file: FILE_PATH, use_cipher: Op ...@@ -203,8 +205,8 @@ def save_pt(state_dict: Dict[str, torch.Tensor], file: FILE_PATH, use_cipher: Op
""" """
use_sfcs_sdk, file = is_sfcs_path(file) use_sfcs_sdk, file = is_sfcs_path(file)
if use_sfcs_sdk: if use_sfcs_sdk:
saver = SfcsClientSaver(use_cipher=use_cipher) saver = SfcsClientSaver(file=file, use_cipher=use_cipher)
else: else:
saver = PosixSaver(use_cipher=use_cipher) saver = PosixSaver(file=file, use_cipher=use_cipher)
return saver.save_pt(state_dict, file) return saver.save_pt(state_dict)
...@@ -34,9 +34,7 @@ class BaseLoader: ...@@ -34,9 +34,7 @@ class BaseLoader:
def __init__(self, method: str) -> None: def __init__(self, method: str) -> None:
self.method = method self.method = method
def load_to_bytes( def load_to_bytes(self, offset: int, count: int, cipher_info: CipherInfo = CipherInfo(False)) -> bytes:
self, file: FILE_PATH, offset: int, count: int, cipher_info: CipherInfo = CipherInfo(False)
) -> bytes:
raise NotImplementedError raise NotImplementedError
def load_safetensors(self, safetensors_file: Any, map_location: str = "cpu") -> Dict[str, torch.Tensor]: def load_safetensors(self, safetensors_file: Any, map_location: str = "cpu") -> Dict[str, torch.Tensor]:
...@@ -65,13 +63,12 @@ class BaseLoader: ...@@ -65,13 +63,12 @@ class BaseLoader:
class PosixLoader(BaseLoader): class PosixLoader(BaseLoader):
def __init__(self) -> None: def __init__(self, file: FILE_PATH) -> None:
super().__init__(method="posix") super().__init__(method="posix")
self.file = file
def load_to_bytes( def load_to_bytes(self, offset: int, count: int, cipher_info: CipherInfo = CipherInfo(False)) -> bytes:
self, file: FILE_PATH, offset: int, count: int, cipher_info: CipherInfo = CipherInfo(False) arr = np.fromfile(self.file, dtype=np.uint8, offset=offset, count=count)
) -> bytes:
arr = np.fromfile(file, dtype=np.uint8, offset=offset, count=count)
if cipher_info.use_cipher: if cipher_info.use_cipher:
h_off = CipherInfo.HEADER_SIZE if cipher_info.use_header else 0 h_off = CipherInfo.HEADER_SIZE if cipher_info.use_header else 0
decrypt(cipher_info, arr, arr, offset - h_off) decrypt(cipher_info, arr, arr, offset - h_off)
...@@ -107,12 +104,12 @@ class PosixLoader(BaseLoader): ...@@ -107,12 +104,12 @@ class PosixLoader(BaseLoader):
return state_dict return state_dict
def load_pt( def load_pt(
self, file: FILE_PATH, map_location: str = "cpu", cipher_info: CipherInfo = CipherInfo(False) self, map_location: str = "cpu", cipher_info: CipherInfo = CipherInfo(False)
) -> Dict[str, torch.Tensor]: ) -> Dict[str, torch.Tensor]:
if cipher_info.use_cipher: if cipher_info.use_cipher:
h_off = CipherInfo.HEADER_SIZE if cipher_info.use_header else 0 h_off = CipherInfo.HEADER_SIZE if cipher_info.use_header else 0
arr = np.fromfile(file, dtype=np.uint8, offset=h_off, count=-1) arr = np.fromfile(self.file, dtype=np.uint8, offset=h_off, count=-1)
decrypt(cipher_info, arr, arr, 0) decrypt(cipher_info, arr, arr, 0)
return torch.load(io.BytesIO(arr.data), map_location=map_location) return torch.load(io.BytesIO(arr.data), map_location=map_location)
return torch.load(file, map_location=map_location) return torch.load(self.file, map_location=map_location)
...@@ -32,12 +32,13 @@ from .base_loader import PosixLoader ...@@ -32,12 +32,13 @@ from .base_loader import PosixLoader
class FasterPosixLoader(PosixLoader): class FasterPosixLoader(PosixLoader):
def __init__( def __init__(
self, self,
file: FILE_PATH,
helper: IOHelper, helper: IOHelper,
num_thread: int = 32, num_thread: int = 32,
use_pinmem: bool = False, use_pinmem: bool = False,
use_direct_io: bool = False, use_direct_io: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__(file)
self.helper = helper self.helper = helper
self.num_thread = num_thread self.num_thread = num_thread
self.use_pinmem = use_pinmem self.use_pinmem = use_pinmem
...@@ -72,12 +73,12 @@ class FasterPosixLoader(PosixLoader): ...@@ -72,12 +73,12 @@ class FasterPosixLoader(PosixLoader):
return SafetensorsFile.split_tensor_to_state_dict(total_tensor, safetensors_file) return SafetensorsFile.split_tensor_to_state_dict(total_tensor, safetensors_file)
def load_pt( def load_pt(
self, file: FILE_PATH, map_location: str = "cpu", cipher_info: CipherInfo = CipherInfo(False) self, map_location: str = "cpu", cipher_info: CipherInfo = CipherInfo(False)
) -> Dict[str, torch.Tensor]: ) -> Dict[str, torch.Tensor]:
if cipher_info.use_cipher: if cipher_info.use_cipher:
h_off = CipherInfo.HEADER_SIZE if cipher_info.use_header else 0 h_off = CipherInfo.HEADER_SIZE if cipher_info.use_header else 0
arr = np.fromfile(file, dtype=np.uint8, offset=h_off, count=-1) arr = np.fromfile(self.file, dtype=np.uint8, offset=h_off, count=-1)
decrypt(cipher_info, arr, arr, 0) decrypt(cipher_info, arr, arr, 0)
return torch.load(io.BytesIO(arr.data), map_location=map_location) return torch.load(io.BytesIO(arr.data), map_location=map_location)
return torch.load(file, map_location=map_location) return torch.load(self.file, map_location=map_location)
...@@ -33,6 +33,7 @@ from veturboio.types import FILE_PATH ...@@ -33,6 +33,7 @@ from veturboio.types import FILE_PATH
class SfcsClientLoader(BaseLoader): class SfcsClientLoader(BaseLoader):
def __init__( def __init__(
self, self,
file: FILE_PATH,
helper: IOHelper, helper: IOHelper,
num_thread: int = 32, num_thread: int = 32,
use_pinmem: bool = False, use_pinmem: bool = False,
...@@ -40,24 +41,23 @@ class SfcsClientLoader(BaseLoader): ...@@ -40,24 +41,23 @@ class SfcsClientLoader(BaseLoader):
) -> None: ) -> None:
super().__init__(method="client") super().__init__(method="client")
self.file = file
self.helper = helper self.helper = helper
self.num_thread = num_thread self.num_thread = num_thread
self.use_pinmem = use_pinmem self.use_pinmem = use_pinmem
self.use_direct_io = use_direct_io self.use_direct_io = use_direct_io
init_sfcs_conf() init_sfcs_conf(file)
def load_to_bytes( def load_to_bytes(self, offset: int, count: int, cipher_info: CipherInfo = CipherInfo(False)) -> bytes:
self, file: FILE_PATH, offset: int, count: int, cipher_info: CipherInfo = CipherInfo(False) file_size = sfcs_get_file_size(self.file)
) -> bytes:
file_size = sfcs_get_file_size(file)
if offset + count > file_size: if offset + count > file_size:
count = file_size - offset count = file_size - offset
file_bytes = bytes(count) file_bytes = bytes(count)
candidate = np.frombuffer(file_bytes, dtype=np.byte) candidate = np.frombuffer(file_bytes, dtype=np.byte)
sfcs_read_file( sfcs_read_file(
file, candidate, length=count, offset=offset, num_thread=self.num_thread, cipher_info=cipher_info self.file, candidate, length=count, offset=offset, num_thread=self.num_thread, cipher_info=cipher_info
) )
return file_bytes return file_bytes
...@@ -89,9 +89,9 @@ class SfcsClientLoader(BaseLoader): ...@@ -89,9 +89,9 @@ class SfcsClientLoader(BaseLoader):
return SafetensorsFile.split_tensor_to_state_dict(total_tensor, safetensors_file) return SafetensorsFile.split_tensor_to_state_dict(total_tensor, safetensors_file)
def load_pt( def load_pt(
self, file: FILE_PATH, map_location: str = "cpu", cipher_info: CipherInfo = CipherInfo(False) self, map_location: str = "cpu", cipher_info: CipherInfo = CipherInfo(False)
) -> Dict[str, torch.Tensor]: ) -> Dict[str, torch.Tensor]:
file_size = sfcs_get_file_size(file) file_size = sfcs_get_file_size(self.file)
h_off = CipherInfo.HEADER_SIZE if cipher_info.use_header else 0 h_off = CipherInfo.HEADER_SIZE if cipher_info.use_header else 0
file_bytes = self.load_to_bytes(file, offset=h_off, count=file_size - h_off, cipher_info=cipher_info) file_bytes = self.load_to_bytes(offset=h_off, count=file_size - h_off, cipher_info=cipher_info)
return torch.load(BytesIO(file_bytes), map_location=map_location) return torch.load(BytesIO(file_bytes), map_location=map_location)
...@@ -14,18 +14,20 @@ See the License for the specific language governing permissions and ...@@ -14,18 +14,20 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
''' '''
import json
import os import os
import shutil import shutil
import tempfile import tempfile
import threading import threading
import xml.dom.minidom import xml.dom.minidom
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Optional from typing import Optional, Tuple
import numpy as np import numpy as np
from loguru import logger from loguru import logger
from veturboio.ops.cipher import CipherInfo, DataPipeClient from veturboio.ops.cipher import CipherInfo, DataPipeClient
from veturboio.types import FILE_PATH
try: try:
from veturboio.utils.load_veturboio_ext import load_veturboio_ext from veturboio.utils.load_veturboio_ext import load_veturboio_ext
...@@ -36,7 +38,6 @@ except ImportError: ...@@ -36,7 +38,6 @@ except ImportError:
SFCSFile = None SFCSFile = None
logger.warning("veturboio_ext not found, fallback to pure python implementation") logger.warning("veturboio_ext not found, fallback to pure python implementation")
SFCS_REQ_ENV_LIST = [ SFCS_REQ_ENV_LIST = [
'SFCS_FSNAME', 'SFCS_FSNAME',
'SFCS_REGION', 'SFCS_REGION',
...@@ -54,51 +55,48 @@ SFCS_OPT_ENV_LIST = [ ...@@ -54,51 +55,48 @@ SFCS_OPT_ENV_LIST = [
'SFCS_NAMENODE_ENDPOINT_ADDRESS', 'SFCS_NAMENODE_ENDPOINT_ADDRESS',
] ]
def default_sfcs_properties() -> dict:
    """Return a fresh copy of the baseline SFCS client properties.

    A new dict is built on every call so per-group credential and endpoint
    fields can be added without mutating shared module state.
    """
    return {
        'cfs.filesystem.fs-mode': 'ACC',
        'cfs.filesystem.task-id': 'sfcs',
        'cfs.filesystem.resolve.addr.by.dns': 'false',
        'cfs.metrics.emitters': 'metric_server;local_prometheus',
        'cfs.client.metadata-cache.enable': 'false',
        'rpc.client.channel.pool.size': '32',
        'dfs.default.replica': '2',
        'cfs.client.multi-nic.enabled': 'true',
        'fs.datanode.router.ignore-main-nic': 'true',
        'cfs.datanode.router.shuffle': 'true',
    }
class CredentialsHelper: class CredentialsHelper:
def __init__(self): def __init__(self):
self.lock = threading.Lock() self.lock = threading.Lock()
self.running = False self.running = {}
# daemon thread will stop when parent thread exits # daemon thread will stop when parent thread exits
self.thread = None self.threads = {}
self.stop_flag = threading.Event()
self.client = None self.client = None
self.current_time = 0
self.expired_time = 0 def run(self, group: str, sfcs_conf_path: str) -> None:
self.ak = None if not self.running.get(group, False):
self.sk = None
self.st = None
self.name_node_ip = None
self.sfcs_conf_path = None
self.stop_flag = None
def run(self, sfcs_conf_path) -> None:
if not self.running:
with self.lock: with self.lock:
if not self.running: if not self.running.get(group, False):
self.thread = threading.Thread(target=self.refresh_loop, daemon=True) if self.client is None:
self.stop_flag = threading.Event() self.client = DataPipeClient()
self.client = DataPipeClient() init_ts = self.do_refresh(group, sfcs_conf_path)
self.sfcs_conf_path = sfcs_conf_path if not init_ts:
if not self.do_refresh(): raise RuntimeError(f'Credentials helper for {sfcs_conf_path} first fetch failed')
raise RuntimeError('Credentials helper first fetch failed') self.threads[group] = threading.Thread(
self.thread.start() target=self.refresh_loop, args=(group, sfcs_conf_path, init_ts[0], init_ts[1]), daemon=True
self.running = True )
logger.info('CredentialsHelper refresh thread strat') self.threads[group].start()
self.running[group] = True
logger.info(f'CredentialsHelper refresh thread for {sfcs_conf_path} start')
return return
logger.info('CredentialsHelper thread is already running, do nothing') logger.info(f'CredentialsHelper thread for {sfcs_conf_path} is already running, do nothing')
def stop(self): def stop(self):
self.stop_flag.set() self.stop_flag.set()
...@@ -115,77 +113,92 @@ class CredentialsHelper: ...@@ -115,77 +113,92 @@ class CredentialsHelper:
return False return False
return True return True
def refresh_loop(self) -> None: def refresh_loop(self, group: str, sfcs_conf_path: str, current_time: float, expired_time: float) -> None:
while True: while True:
now = datetime.now(tz=timezone.utc).timestamp() now = datetime.now(tz=timezone.utc).timestamp()
ts_ref = (self.current_time + self.expired_time) / 2 ts_ref = (current_time + expired_time) / 2
if now >= ts_ref: if now >= ts_ref:
if not self.do_refresh(): ts = self.do_refresh(group, sfcs_conf_path)
raise RuntimeError('Credentials helper do refresh failed') if not ts:
raise RuntimeError(f'Credentials helper do refresh at {sfcs_conf_path} failed')
current_time, expired_time = ts[0], ts[1]
else: else:
if self.stop_flag.wait(ts_ref - now): if self.stop_flag.wait(ts_ref - now):
return return
def do_refresh(self) -> bool: def do_refresh(self, group: str, sfcs_conf_path: str) -> Optional[Tuple[float, float]]:
d = self.client.get_sfcs_ak_sk_st() d = self.client.get_sfcs_ak_sk_st()
if self.is_valid_res(d): if self.is_valid_res(d):
self.name_node_ip = d['SfcsNameNodeAddress'] name_node_ip = d['SfcsNameNodeAddress']
d = d['Cred'] d = d['Cred']
self.current_time = datetime.fromisoformat(d['CurrentTime']).timestamp() current_time = datetime.fromisoformat(d['CurrentTime']).timestamp()
self.expired_time = datetime.fromisoformat(d['ExpiredTime']).timestamp() expired_time = datetime.fromisoformat(d['ExpiredTime']).timestamp()
self.ak = d['AccessKeyId'] ak = d['AccessKeyId']
self.sk = d['SecretAccessKey'] sk = d['SecretAccessKey']
self.st = d['SessionToken'] st = d['SessionToken']
# update SFCS_PROPERTIES and then write xml try:
SFCS_PROPERTIES['cfs.access.key'] = self.ak sfcs_fsname = json.loads(os.getenv('SFCS_FSNAME'))[group]
SFCS_PROPERTIES['cfs.secret.key'] = self.sk except:
SFCS_PROPERTIES['cfs.security.token'] = self.st sfcs_fsname = os.getenv('SFCS_FSNAME')
SFCS_PROPERTIES['cfs.namenode.endpoint.address.' + os.getenv('SFCS_FSNAME')] = self.name_node_ip properties = init_sfcs_properties(group)
generate_sfcs_conf_xml(self.sfcs_conf_path) properties['cfs.access.key'] = ak
logger.info('Credentials are successfully refreshed!') properties['cfs.secret.key'] = sk
return True properties['cfs.security.token'] = st
properties['cfs.namenode.endpoint.address.' + sfcs_fsname] = name_node_ip
generate_sfcs_conf_xml(sfcs_conf_path, properties)
logger.info(f'Credentials are successfully refreshed at {sfcs_conf_path}!')
return current_time, expired_time
else: else:
return False return None
credentials_helper = CredentialsHelper() credentials_helper = CredentialsHelper()
def init_sfcs_properties(group: str) -> dict:
    """Build the SFCS property dict for *group* from the SFCS_* environment.

    SFCS_FSNAME / SFCS_NS_ID / SFCS_UFS_PATH may each be a JSON object keyed
    by group (multi-model) or a plain string (single model).

    Raises ValueError when a required SFCS_* variable is missing.
    """
    for env in SFCS_REQ_ENV_LIST:
        if os.getenv(env) is None:
            raise ValueError('environ ' + env + ' not set')
    try:
        sfcs_fsname = json.loads(os.getenv('SFCS_FSNAME'))[group]
        sfcs_ns_id = json.loads(os.getenv('SFCS_NS_ID'))[group]
        sfcs_ufs_path = json.loads(os.getenv('SFCS_UFS_PATH'))[group]
        logger.info("parse sfcs fsname, ns_id and ufs_path from environ in JSON format")
    except (TypeError, ValueError, KeyError):
        # not JSON (or group missing): treat the env values as plain strings
        sfcs_fsname = os.getenv('SFCS_FSNAME')
        sfcs_ns_id = os.getenv('SFCS_NS_ID')
        sfcs_ufs_path = os.getenv('SFCS_UFS_PATH')
        logger.info("parse sfcs fsname, ns_id and ufs_path from environ in STRING format")

    properties = default_sfcs_properties()
    properties['dfs.default.uri'] = 'cfs://' + sfcs_fsname + '.sfcs-' + os.getenv('SFCS_REGION') + '.ivolces.com'
    properties['dfs.authentication.service.name'] = os.getenv('SFCS_AUTHENTICATION_SERVICE_NAME')
    properties['cfs.filesystem.ns-id'] = sfcs_ns_id
    properties['cfs.filesystem.ufs-path'] = sfcs_ufs_path
    properties['cfs.metrics.server.host'] = 'metricserver.cfs-' + os.getenv('SFCS_REGION') + '.ivolces.com'
    properties['cfs.client.multi-nic.whitelist'] = os.getenv('SFCS_MULTI_NIC_WHITELIST')
    properties['cfs.client.network.segment'] = os.getenv('SFCS_NETWORK_SEGMENT')
    properties['dfs.client.log.severity'] = os.getenv('SFCS_LOG_SEVERITY')
    # optional
    properties['cfs.filesystem.sync-interval'] = os.getenv('SFCS_SYNC_INTERVAL', "-1")
    properties['cfs.access.key'] = os.getenv('SFCS_ACCESS_KEY')
    properties['cfs.secret.key'] = os.getenv('SFCS_SECRET_KEY')
    properties['cfs.namenode.endpoint.address.' + sfcs_fsname] = os.getenv('SFCS_NAMENODE_ENDPOINT_ADDRESS')
    return properties
)
def generate_sfcs_conf_xml(sfcs_conf): def generate_sfcs_conf_xml(sfcs_conf: FILE_PATH, sfcs_properties: dict):
doc = xml.dom.minidom.Document() doc = xml.dom.minidom.Document()
configuration = doc.createElement('configuration') configuration = doc.createElement('configuration')
doc.appendChild(configuration) doc.appendChild(configuration)
for key in SFCS_PROPERTIES: for key in sfcs_properties:
property = doc.createElement('property') property = doc.createElement('property')
name = doc.createElement('name') name = doc.createElement('name')
name.appendChild(doc.createTextNode(key)) name.appendChild(doc.createTextNode(key))
value = doc.createElement('value') value = doc.createElement('value')
value.appendChild(doc.createTextNode(SFCS_PROPERTIES[key])) value.appendChild(doc.createTextNode(sfcs_properties[key]))
property.appendChild(name) property.appendChild(name)
property.appendChild(value) property.appendChild(value)
...@@ -200,29 +213,33 @@ def generate_sfcs_conf_xml(sfcs_conf): ...@@ -200,29 +213,33 @@ def generate_sfcs_conf_xml(sfcs_conf):
shutil.move(tmp_conf.name, sfcs_conf) shutil.move(tmp_conf.name, sfcs_conf)
def sfcs_conf_group(file: str) -> str:
    """Return the conf group for *file*: the first component of its absolute path.

    e.g. '/base_model3/tensor.pt' -> 'base_model3'. Each group gets its own
    '<group>.xml' SFCS conf file in the working directory.
    """
    return os.path.abspath(file).split("/")[1]
def init_sfcs_conf(file: FILE_PATH):
    """Ensure an SFCS conf xml exists for the group that *file* belongs to.

    Sets LIBCFS_CONF to '<cwd>/<group>.xml', then either keeps an existing
    file (case 1), generates one from env credentials (case 2), or delegates
    to the credentials helper for fetch-and-refresh (case 3).
    """
    group = sfcs_conf_group(file)
    sfcs_conf = os.path.join(os.getcwd(), group + '.xml')
    os.environ['LIBCFS_CONF'] = sfcs_conf
    logger.info(f'environ LIBCFS_CONF set to {sfcs_conf}')

    if os.path.isfile(sfcs_conf):
        # case 1: a xml file already exists, do nothing
        logger.info('LIBCFS_CONF file exists')
    else:
        if (
            os.getenv('SFCS_ACCESS_KEY')
            and os.getenv('SFCS_SECRET_KEY')
            and os.getenv('SFCS_NAMENODE_ENDPOINT_ADDRESS')
        ):
            # case 2: env SFCS_ACCESS_KEY, SFCS_SECRET_KEY and SFCS_NAMENODE_ENDPOINT_ADDRESS exist
            logger.info('Use aksk and namenode_ip in env to generate sfcs config')
            properties = init_sfcs_properties(group)
            generate_sfcs_conf_xml(sfcs_conf, properties)
        else:
            # case 3: use datapipe socket to get and refresh ak, sk, and st
            logger.info('Use credentials helper to generate and update sfcs config')
            credentials_helper.run(group, sfcs_conf)
def sfcs_get_file_size(file_path: str) -> int: def sfcs_get_file_size(file_path: str) -> int:
......
...@@ -106,7 +106,7 @@ class SafetensorsFile: ...@@ -106,7 +106,7 @@ class SafetensorsFile:
# cipher related # cipher related
self._cipher_info = CipherInfo(False) self._cipher_info = CipherInfo(False)
if use_cipher or os.getenv("VETURBOIO_USE_CIPHER", "0") == "1": if use_cipher or os.getenv("VETURBOIO_USE_CIPHER", "0") == "1":
header_bytes = loader.load_to_bytes(file, offset=0, count=CipherInfo.HEADER_SIZE) header_bytes = loader.load_to_bytes(offset=0, count=CipherInfo.HEADER_SIZE)
self._cipher_info = CipherInfo(True, header_bytes) self._cipher_info = CipherInfo(True, header_bytes)
if self._cipher_info.use_header: if self._cipher_info.use_header:
...@@ -114,15 +114,15 @@ class SafetensorsFile: ...@@ -114,15 +114,15 @@ class SafetensorsFile:
else: else:
h_off = 0 h_off = 0
magic_number = loader.load_to_bytes(file, offset=8 + h_off, count=1, cipher_info=self._cipher_info)[0] magic_number = loader.load_to_bytes(offset=8 + h_off, count=1, cipher_info=self._cipher_info)[0]
if magic_number != SAFETENSORS_FILE_MAGIC_NUM: if magic_number != SAFETENSORS_FILE_MAGIC_NUM:
self._is_valid = False self._is_valid = False
return return
self._meta_size = np.frombuffer( self._meta_size = np.frombuffer(
loader.load_to_bytes(file, offset=h_off, count=8, cipher_info=self._cipher_info), dtype=np.int64 loader.load_to_bytes(offset=h_off, count=8, cipher_info=self._cipher_info), dtype=np.int64
)[0] )[0]
meta_bytes = loader.load_to_bytes(file, offset=8 + h_off, count=self._meta_size, cipher_info=self._cipher_info) meta_bytes = loader.load_to_bytes(offset=8 + h_off, count=self._meta_size, cipher_info=self._cipher_info)
meta_dict = json.loads(meta_bytes.decode("utf-8")) meta_dict = json.loads(meta_bytes.decode("utf-8"))
self._shared_tensor = {} self._shared_tensor = {}
...@@ -208,6 +208,6 @@ class SafetensorsFile: ...@@ -208,6 +208,6 @@ class SafetensorsFile:
def load(self, map_location: str = "cpu") -> Dict[str, torch.Tensor]: def load(self, map_location: str = "cpu") -> Dict[str, torch.Tensor]:
if not self._is_valid: if not self._is_valid:
return self._loader.load_pt(self.file, map_location, self._cipher_info) return self._loader.load_pt(map_location, self._cipher_info)
else: else:
return self._loader.load_safetensors(self, map_location) return self._loader.load_safetensors(self, map_location)
...@@ -39,8 +39,9 @@ class BaseSaver: ...@@ -39,8 +39,9 @@ class BaseSaver:
class PosixSaver(BaseSaver): class PosixSaver(BaseSaver):
def __init__(self, use_cipher: bool = False) -> None: def __init__(self, file: FILE_PATH, use_cipher: bool = False) -> None:
super().__init__(method="posix") super().__init__(method="posix")
self.file = file
use_cipher = use_cipher or os.getenv("VETURBOIO_USE_CIPHER", "0") == "1" use_cipher = use_cipher or os.getenv("VETURBOIO_USE_CIPHER", "0") == "1"
use_header = use_cipher and os.getenv("VETURBOIO_CIPHER_HEADER", "0") == "1" use_header = use_cipher and os.getenv("VETURBOIO_CIPHER_HEADER", "0") == "1"
if use_header: if use_header:
...@@ -48,7 +49,7 @@ class PosixSaver(BaseSaver): ...@@ -48,7 +49,7 @@ class PosixSaver(BaseSaver):
else: else:
self.cipher_info = CipherInfo(use_cipher) self.cipher_info = CipherInfo(use_cipher)
def save_file(self, state_dict: Dict[str, torch.Tensor], file: FILE_PATH, metadata: Dict[str, str] = None) -> None: def save_file(self, state_dict: Dict[str, torch.Tensor], metadata: Dict[str, str] = None) -> None:
if self.cipher_info.use_cipher: if self.cipher_info.use_cipher:
with tempfile.NamedTemporaryFile(dir="/dev/shm") as tmpfile: with tempfile.NamedTemporaryFile(dir="/dev/shm") as tmpfile:
tmp_file_path = tmpfile.name tmp_file_path = tmpfile.name
...@@ -56,15 +57,15 @@ class PosixSaver(BaseSaver): ...@@ -56,15 +57,15 @@ class PosixSaver(BaseSaver):
tmp_file_size = os.path.getsize(tmp_file_path) tmp_file_size = os.path.getsize(tmp_file_path)
tmp_file_bytes = np.memmap(tmp_file_path, dtype=np.uint8, mode='r', shape=tmp_file_size) tmp_file_bytes = np.memmap(tmp_file_path, dtype=np.uint8, mode='r', shape=tmp_file_size)
h_off = CipherInfo.HEADER_SIZE if self.cipher_info.use_header else 0 h_off = CipherInfo.HEADER_SIZE if self.cipher_info.use_header else 0
file_bytes = np.memmap(file, dtype=np.uint8, mode='w+', shape=tmp_file_size + h_off) file_bytes = np.memmap(self.file, dtype=np.uint8, mode='w+', shape=tmp_file_size + h_off)
encrypt(self.cipher_info, tmp_file_bytes, file_bytes[h_off:], 0) encrypt(self.cipher_info, tmp_file_bytes, file_bytes[h_off:], 0)
if h_off: if h_off:
file_bytes[:h_off] = np.frombuffer(self.cipher_info.to_header_bytes(), dtype=np.uint8) file_bytes[:h_off] = np.frombuffer(self.cipher_info.to_header_bytes(), dtype=np.uint8)
file_bytes.flush() file_bytes.flush()
else: else:
safetenors_save_file(state_dict, file, metadata=metadata) safetenors_save_file(state_dict, self.file, metadata=metadata)
def save_model(self, model: torch.nn.Module, file: FILE_PATH) -> None: def save_model(self, model: torch.nn.Module) -> None:
if self.cipher_info.use_cipher: if self.cipher_info.use_cipher:
with tempfile.NamedTemporaryFile(dir="/dev/shm") as tmpfile: with tempfile.NamedTemporaryFile(dir="/dev/shm") as tmpfile:
tmp_file_path = tmpfile.name tmp_file_path = tmpfile.name
...@@ -72,15 +73,15 @@ class PosixSaver(BaseSaver): ...@@ -72,15 +73,15 @@ class PosixSaver(BaseSaver):
tmp_file_size = os.path.getsize(tmp_file_path) tmp_file_size = os.path.getsize(tmp_file_path)
tmp_file_bytes = np.memmap(tmp_file_path, dtype=np.uint8, mode='r', shape=tmp_file_size) tmp_file_bytes = np.memmap(tmp_file_path, dtype=np.uint8, mode='r', shape=tmp_file_size)
h_off = CipherInfo.HEADER_SIZE if self.cipher_info.use_header else 0 h_off = CipherInfo.HEADER_SIZE if self.cipher_info.use_header else 0
file_bytes = np.memmap(file, dtype=np.uint8, mode='w+', shape=tmp_file_size + h_off) file_bytes = np.memmap(self.file, dtype=np.uint8, mode='w+', shape=tmp_file_size + h_off)
encrypt(self.cipher_info, tmp_file_bytes, file_bytes[h_off:], 0) encrypt(self.cipher_info, tmp_file_bytes, file_bytes[h_off:], 0)
if h_off: if h_off:
file_bytes[:h_off] = np.frombuffer(self.cipher_info.to_header_bytes(), dtype=np.uint8) file_bytes[:h_off] = np.frombuffer(self.cipher_info.to_header_bytes(), dtype=np.uint8)
file_bytes.flush() file_bytes.flush()
else: else:
safetensors_save_model(model, file) safetensors_save_model(model, self.file)
def save_pt(self, state_dict: Dict[str, torch.Tensor], file: FILE_PATH) -> None: def save_pt(self, state_dict: Dict[str, torch.Tensor]) -> None:
if self.cipher_info.use_cipher: if self.cipher_info.use_cipher:
with tempfile.NamedTemporaryFile(dir="/dev/shm") as tmpfile: with tempfile.NamedTemporaryFile(dir="/dev/shm") as tmpfile:
tmp_file_path = tmpfile.name tmp_file_path = tmpfile.name
...@@ -88,10 +89,10 @@ class PosixSaver(BaseSaver): ...@@ -88,10 +89,10 @@ class PosixSaver(BaseSaver):
tmp_file_size = os.path.getsize(tmp_file_path) tmp_file_size = os.path.getsize(tmp_file_path)
tmp_file_bytes = np.memmap(tmp_file_path, dtype=np.uint8, mode='r', shape=tmp_file_size) tmp_file_bytes = np.memmap(tmp_file_path, dtype=np.uint8, mode='r', shape=tmp_file_size)
h_off = CipherInfo.HEADER_SIZE if self.cipher_info.use_header else 0 h_off = CipherInfo.HEADER_SIZE if self.cipher_info.use_header else 0
file_bytes = np.memmap(file, dtype=np.uint8, mode='w+', shape=tmp_file_size + h_off) file_bytes = np.memmap(self.file, dtype=np.uint8, mode='w+', shape=tmp_file_size + h_off)
encrypt(self.cipher_info, tmp_file_bytes, file_bytes[h_off:], 0) encrypt(self.cipher_info, tmp_file_bytes, file_bytes[h_off:], 0)
if h_off: if h_off:
file_bytes[:h_off] = np.frombuffer(self.cipher_info.to_header_bytes(), dtype=np.uint8) file_bytes[:h_off] = np.frombuffer(self.cipher_info.to_header_bytes(), dtype=np.uint8)
file_bytes.flush() file_bytes.flush()
else: else:
torch.save(state_dict, file) torch.save(state_dict, self.file)
...@@ -30,10 +30,11 @@ from veturboio.types import FILE_PATH ...@@ -30,10 +30,11 @@ from veturboio.types import FILE_PATH
class SfcsClientSaver(BaseSaver): class SfcsClientSaver(BaseSaver):
def __init__(self, use_cipher: bool = False) -> None: def __init__(self, file: FILE_PATH, use_cipher: bool = False) -> None:
super().__init__(method="client") super().__init__(method="client")
init_sfcs_conf() self.file = file
init_sfcs_conf(file)
use_cipher = use_cipher or os.getenv("VETURBOIO_USE_CIPHER", "0") == "1" use_cipher = use_cipher or os.getenv("VETURBOIO_USE_CIPHER", "0") == "1"
use_header = use_cipher and os.getenv("VETURBOIO_CIPHER_HEADER", "0") == "1" use_header = use_cipher and os.getenv("VETURBOIO_CIPHER_HEADER", "0") == "1"
...@@ -42,7 +43,7 @@ class SfcsClientSaver(BaseSaver): ...@@ -42,7 +43,7 @@ class SfcsClientSaver(BaseSaver):
else: else:
self.cipher_info = CipherInfo(use_cipher) self.cipher_info = CipherInfo(use_cipher)
def save_file(self, state_dict: Dict[str, torch.Tensor], file: FILE_PATH, metadata: Dict[str, str] = None) -> None: def save_file(self, state_dict: Dict[str, torch.Tensor], metadata: Dict[str, str] = None) -> None:
with tempfile.NamedTemporaryFile(dir="/dev/shm") as tmpfile: with tempfile.NamedTemporaryFile(dir="/dev/shm") as tmpfile:
file_path = tmpfile.name file_path = tmpfile.name
safetenors_save_file(state_dict, file_path, metadata=metadata) safetenors_save_file(state_dict, file_path, metadata=metadata)
...@@ -55,9 +56,9 @@ class SfcsClientSaver(BaseSaver): ...@@ -55,9 +56,9 @@ class SfcsClientSaver(BaseSaver):
file_bytes[h_off:] = np.fromfile(file_path, dtype=np.byte, count=file_size) file_bytes[h_off:] = np.fromfile(file_path, dtype=np.byte, count=file_size)
else: else:
file_bytes = np.memmap(file_path, dtype=np.byte, mode='r+', shape=file_size) file_bytes = np.memmap(file_path, dtype=np.byte, mode='r+', shape=file_size)
sfcs_write_file(file, file_bytes, len(file_bytes), self.cipher_info) sfcs_write_file(self.file, file_bytes, len(file_bytes), self.cipher_info)
def save_model(self, model: torch.nn.Module, file: FILE_PATH) -> None: def save_model(self, model: torch.nn.Module) -> None:
with tempfile.NamedTemporaryFile(dir="/dev/shm") as tmpfile: with tempfile.NamedTemporaryFile(dir="/dev/shm") as tmpfile:
file_path = tmpfile.name file_path = tmpfile.name
safetensors_save_model(model, file_path) safetensors_save_model(model, file_path)
...@@ -70,9 +71,9 @@ class SfcsClientSaver(BaseSaver): ...@@ -70,9 +71,9 @@ class SfcsClientSaver(BaseSaver):
file_bytes[h_off:] = np.fromfile(file_path, dtype=np.byte, count=file_size) file_bytes[h_off:] = np.fromfile(file_path, dtype=np.byte, count=file_size)
else: else:
file_bytes = np.memmap(file_path, dtype=np.byte, mode='r+', shape=file_size) file_bytes = np.memmap(file_path, dtype=np.byte, mode='r+', shape=file_size)
sfcs_write_file(file, file_bytes, len(file_bytes), self.cipher_info) sfcs_write_file(self.file, file_bytes, len(file_bytes), self.cipher_info)
def save_pt(self, state_dict: Dict[str, torch.Tensor], file: FILE_PATH) -> None: def save_pt(self, state_dict: Dict[str, torch.Tensor]) -> None:
with tempfile.NamedTemporaryFile(dir="/dev/shm") as tmpfile: with tempfile.NamedTemporaryFile(dir="/dev/shm") as tmpfile:
file_path = tmpfile.name file_path = tmpfile.name
torch.save(state_dict, file_path) torch.save(state_dict, file_path)
...@@ -85,4 +86,4 @@ class SfcsClientSaver(BaseSaver): ...@@ -85,4 +86,4 @@ class SfcsClientSaver(BaseSaver):
file_bytes[h_off:] = np.fromfile(file_path, dtype=np.byte, count=file_size) file_bytes[h_off:] = np.fromfile(file_path, dtype=np.byte, count=file_size)
else: else:
file_bytes = np.memmap(file_path, dtype=np.byte, mode='r+', shape=file_size) file_bytes = np.memmap(file_path, dtype=np.byte, mode='r+', shape=file_size)
sfcs_write_file(file, file_bytes, len(file_bytes), self.cipher_info) sfcs_write_file(self.file, file_bytes, len(file_bytes), self.cipher_info)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment