Commit e16439e5 authored by 胡腾's avatar 胡腾 Committed by huteng.ht
Browse files

feat(sfcs): parse sfcs confs from environ in json format

* feat(sfcs): parse sfcs confs from environ in json format
* fix(ut): init_sfcs_conf with filepath
* fix: multi sfcs conf
* fix: only pass file in saver loader init
parent 66b20b34
......@@ -86,8 +86,24 @@ def print_load_time(fs_name, tensor_size, load_times):
print(f"{fs_name:<10} {str(tensor_size):<15}", ' '.join(load_times))
def sfcs_env():
    """Populate the SFCS_* environment variables used by the benchmark.

    Credentials are taken from the CI-provided ``CI_SFCS_AK`` / ``CI_SFCS_SK``
    variables (a ``KeyError`` is raised if they are absent); everything else is
    a fixed test-cluster setting.
    """
    settings = {
        'SFCS_FSNAME': 'byted-cpu-sfcs',
        'SFCS_REGION': 'cn-beijing',
        'SFCS_ACCESS_KEY': os.environ['CI_SFCS_AK'],
        'SFCS_SECRET_KEY': os.environ['CI_SFCS_SK'],
        'SFCS_AUTHENTICATION_SERVICE_NAME': 'cfs',
        'SFCS_NS_ID': '18014398509481988',
        'SFCS_UFS_PATH': 'tos://yinzq-bucket/',
        'SFCS_MULTI_NIC_WHITELIST': 'eth0',
        'SFCS_NETWORK_SEGMENT': '172.31.128.0/17',
        'SFCS_NAMENODE_ENDPOINT_ADDRESS': '100.67.19.231',
        'SFCS_LOG_SEVERITY': 'ERROR',
    }
    os.environ.update(settings)
def main():
args = parse_args()
if args.base_dir.startswith('sfcs://'):
sfcs_env()
load_modes = args.load_mode.split(',')
# warmup GPU otherwise the first case would be slow
device = torch.device(args.map_location)
......
......@@ -30,9 +30,9 @@ import numpy as np
from veturboio.ops.cipher import CipherInfo, DataPipeClient
from veturboio.ops.sfcs_utils import (
SFCS_OPT_ENV_LIST,
SFCS_PROPERTIES,
SFCS_REQ_ENV_LIST,
credentials_helper,
generate_sfcs_conf_xml,
init_sfcs_conf,
)
......@@ -233,60 +233,86 @@ class TestCredentials(TestCase):
def test_sfcs_conf(self):
# case 1: a xml file already exists, do nothing
with tempfile.NamedTemporaryFile() as sfcs_conf:
os.environ['LIBCFS_CONF'] = sfcs_conf.name
init_sfcs_conf()
self.assertFalse(credentials_helper.running)
sfcs_conf = os.path.join(os.getcwd(), 'base_model.xml')
generate_sfcs_conf_xml(sfcs_conf, {'test': 'test'})
init_sfcs_conf('/base_model/tensor.pt')
self.assertEqual(os.environ['LIBCFS_CONF'], sfcs_conf)
self.assertEqual(len(credentials_helper.threads), 0)
self.assertEqual(len(credentials_helper.running), 0)
os.remove(sfcs_conf)
for e in SFCS_REQ_ENV_LIST:
os.environ[e] = 'test-value'
# case 2: env SFCS_ACCESS_KEY and SFCS_SECRET_KEY and SFCS_NAMENODE_ENDPOINT_ADDRESS exists
with tempfile.TemporaryDirectory() as conf_dir:
conf_path = os.path.join(conf_dir, 'libcfs.xml')
os.environ['LIBCFS_CONF'] = conf_path
os.environ['SFCS_ACCESS_KEY'] = 'AKTPODg0MzV**2ZDcxMDg'
os.environ['SFCS_SECRET_KEY'] = 'TVRNNVlqRmxPR1**mRoTkdWbE1ESQ=='
os.environ['SFCS_NAMENODE_ENDPOINT_ADDRESS'] = '100.67.19.231'
init_sfcs_conf()
self.assertEqual(SFCS_PROPERTIES['cfs.access.key'], 'AKTPODg0MzV**2ZDcxMDg')
self.assertEqual(SFCS_PROPERTIES['cfs.secret.key'], 'TVRNNVlqRmxPR1**mRoTkdWbE1ESQ==')
self.assertEqual(SFCS_PROPERTIES['cfs.namenode.endpoint.address.test-value'], '100.67.19.231')
self.assertFalse(credentials_helper.running)
self.assertTrue(os.path.exists(conf_path))
os.environ['SFCS_ACCESS_KEY'] = 'AKTPODg0MzV**2ZDcxMDg'
os.environ['SFCS_SECRET_KEY'] = 'TVRNNVlqRmxPR1**mRoTkdWbE1ESQ=='
os.environ['SFCS_NAMENODE_ENDPOINT_ADDRESS'] = '100.67.19.231'
sfcs_conf = os.path.join(os.getcwd(), 'base_model2.xml')
init_sfcs_conf('/base_model2/tensor.pt')
self.assertEqual(os.environ['LIBCFS_CONF'], sfcs_conf)
self.assertEqual(len(credentials_helper.threads), 0)
self.assertEqual(len(credentials_helper.running), 0)
self.assertTrue(os.path.exists(sfcs_conf))
os.remove(sfcs_conf)
# case 3: use datapipe socket to get and refresh ak, sk, st and namenode_ip
DataPipeClient.DATAPIPE_SOCKET_PATH = self.server_address
with tempfile.TemporaryDirectory() as conf_dir:
conf_path = os.path.join(conf_dir, 'libcfs.xml')
os.environ['LIBCFS_CONF'] = conf_path
os.environ.pop('SFCS_ACCESS_KEY', None)
os.environ.pop('SFCS_SECRET_KEY', None)
os.environ.pop('SFCS_NAMENODE_ENDPOINT_ADDRESS', None)
SFCS_PROPERTIES.pop('cfs.access.key')
SFCS_PROPERTIES.pop('cfs.secret.key')
SFCS_PROPERTIES.pop('cfs.namenode.endpoint.address.test-value')
init_sfcs_conf()
self.assertEqual(SFCS_PROPERTIES['cfs.access.key'], 'AKTPODg0MzV**2ZDcxMDg')
self.assertEqual(SFCS_PROPERTIES['cfs.secret.key'], 'TVRNNVlqRmxPR1**mRoTkdWbE1ESQ==')
self.assertEqual(SFCS_PROPERTIES['cfs.namenode.endpoint.address.test-value'], '100.67.19.231')
self.assertEqual(SFCS_PROPERTIES['cfs.security.token'], 'STSeyJBY2NvdW50SW**kXXXXXXX')
self.assertTrue(credentials_helper.running)
self.assertTrue(os.path.exists(conf_path))
t1 = credentials_helper.current_time
sleep(3)
t2 = credentials_helper.current_time
self.assertTrue(t1 < t2)
credentials_helper.stop()
os.environ.pop('SFCS_ACCESS_KEY', None)
os.environ.pop('SFCS_SECRET_KEY', None)
os.environ.pop('SFCS_NAMENODE_ENDPOINT_ADDRESS', None)
sfcs_conf3 = os.path.join(os.getcwd(), 'base_model3.xml')
sfcs_conf4 = os.path.join(os.getcwd(), 'base_model4.xml')
init_sfcs_conf('/base_model3/tensor.pt')
init_sfcs_conf('/base_model4/tensor.pt')
self.assertTrue('base_model3' in credentials_helper.threads)
self.assertTrue('base_model4' in credentials_helper.threads)
self.assertTrue(credentials_helper.running['base_model3'])
self.assertTrue(credentials_helper.running['base_model4'])
self.assertTrue(os.path.exists(sfcs_conf3))
self.assertTrue(os.path.exists(sfcs_conf4))
os.remove(sfcs_conf3)
os.remove(sfcs_conf4)
sleep(3)
self.assertTrue(os.path.exists(sfcs_conf3))
self.assertTrue(os.path.exists(sfcs_conf4))
print(credentials_helper.threads)
def test_sfcs_conf_json(self):
for e in SFCS_REQ_ENV_LIST:
os.environ[e] = 'test-value'
os.environ['SFCS_FSNAME'] = json.dumps({'base_model1': 'test-value1', 'base_model2': 'test-value2'})
os.environ['SFCS_NS_ID'] = json.dumps({'base_model1': 'test-value1', 'base_model2': 'test-value2'})
os.environ['SFCS_UFS_PATH'] = json.dumps({'base_model1': 'test-value1', 'base_model2': 'test-value2'})
DataPipeClient.DATAPIPE_SOCKET_PATH = self.server_address
os.environ.pop('SFCS_ACCESS_KEY', None)
os.environ.pop('SFCS_SECRET_KEY', None)
os.environ.pop('SFCS_NAMENODE_ENDPOINT_ADDRESS', None)
sfcs_conf1 = os.path.join(os.getcwd(), 'base_model1.xml')
sfcs_conf2 = os.path.join(os.getcwd(), 'base_model2.xml')
init_sfcs_conf('/base_model1/tensor.pt')
init_sfcs_conf('/base_model2/tensor.pt')
self.assertTrue('base_model1' in credentials_helper.threads)
self.assertTrue('base_model2' in credentials_helper.threads)
self.assertTrue(credentials_helper.running['base_model1'])
self.assertTrue(credentials_helper.running['base_model2'])
self.assertTrue(os.path.exists(sfcs_conf1))
self.assertTrue(os.path.exists(sfcs_conf2))
os.remove(sfcs_conf1)
os.remove(sfcs_conf2)
sleep(3)
self.assertTrue(os.path.exists(sfcs_conf1))
self.assertTrue(os.path.exists(sfcs_conf2))
print(credentials_helper.threads)
@classmethod
def tearDownClass(cls):
credentials_helper.stop()
os.environ.pop('LIBCFS_CONF', None)
for e in SFCS_REQ_ENV_LIST:
os.environ.pop(e, None)
for e in SFCS_OPT_ENV_LIST:
os.environ.pop(e, None)
SFCS_PROPERTIES.pop('cfs.security.token', None)
cls.server.shutdown()
cls.server.server_close()
cls.thread.join()
......
......@@ -29,10 +29,6 @@ import veturboio.ops.sfcs_utils as sfcs_utils
def init_sfcs_env():
sfcs_conf = os.getcwd() + '/libcfs.xml'
if os.path.exists(sfcs_conf):
os.remove(sfcs_conf)
os.environ['SFCS_FSNAME'] = 'byted-cpu-sfcs'
os.environ['SFCS_REGION'] = 'cn-beijing'
os.environ['SFCS_ACCESS_KEY'] = os.environ['CI_SFCS_AK']
......@@ -45,8 +41,6 @@ def init_sfcs_env():
os.environ['SFCS_NAMENODE_ENDPOINT_ADDRESS'] = '100.67.19.231'
os.environ['SFCS_LOG_SEVERITY'] = 'ERROR'
sfcs_utils.init_sfcs_conf()
class TestSFCS(TestCase):
@classmethod
......@@ -57,6 +51,12 @@ class TestSFCS(TestCase):
filepath = "/data.bin"
filesize = 1024 * 1024
first_path = os.path.abspath(filepath).split("/")[1]
sfcs_conf = os.path.join(os.getcwd(), first_path + '.xml')
if os.path.exists(sfcs_conf):
os.remove(sfcs_conf)
sfcs_utils.init_sfcs_conf(filepath)
sfcs_utils.sfcs_delete_file(filepath)
arr_0 = np.empty([filesize], dtype=np.byte)
......
......@@ -77,16 +77,18 @@ def load(
use_sfcs_sdk, file = is_sfcs_path(file)
if enable_fast_mode == False:
loader = PosixLoader()
loader = PosixLoader(file)
elif use_sfcs_sdk:
loader = SfcsClientLoader(
helper,
helper=helper,
file=file,
num_thread=num_thread,
use_pinmem=use_pinmem,
use_direct_io=use_direct_io,
)
else:
loader = FasterPosixLoader(
file,
helper,
num_thread=num_thread,
use_pinmem=use_pinmem,
......@@ -126,14 +128,14 @@ def save_file(
"""
use_sfcs_sdk, file = is_sfcs_path(file)
if use_sfcs_sdk:
saver = SfcsClientSaver(use_cipher=use_cipher)
saver = SfcsClientSaver(file=file, use_cipher=use_cipher)
else:
saver = PosixSaver(use_cipher=use_cipher)
saver = PosixSaver(file=file, use_cipher=use_cipher)
# TODO: there are some bugs while state_dict is loaded from veturboio
if not force_save_shared_tensor:
try:
saver.save_file(state_dict, file, metadata=metadata)
saver.save_file(state_dict, metadata=metadata)
except ValueError as e:
msg = str(e)
raise ValueError(msg)
......@@ -154,7 +156,7 @@ def save_file(
if force_contiguous:
state_dict = {k: v.contiguous() for k, v in state_dict.items()}
return saver.save_file(state_dict, file, metadata=metadata)
return saver.save_file(state_dict, metadata=metadata)
def save_model(model: torch.nn.Module, file: FILE_PATH, use_cipher: Optional[bool] = False) -> None:
......@@ -177,11 +179,11 @@ def save_model(model: torch.nn.Module, file: FILE_PATH, use_cipher: Optional[boo
use_sfcs_sdk, file = is_sfcs_path(file)
if use_sfcs_sdk:
saver = SfcsClientSaver(use_cipher=use_cipher)
saver = SfcsClientSaver(file=file, use_cipher=use_cipher)
else:
saver = PosixSaver(use_cipher=use_cipher)
saver = PosixSaver(file=file, use_cipher=use_cipher)
return saver.save_model(model, file)
return saver.save_model(model)
def save_pt(state_dict: Dict[str, torch.Tensor], file: FILE_PATH, use_cipher: Optional[bool] = False) -> None:
......@@ -203,8 +205,8 @@ def save_pt(state_dict: Dict[str, torch.Tensor], file: FILE_PATH, use_cipher: Op
"""
use_sfcs_sdk, file = is_sfcs_path(file)
if use_sfcs_sdk:
saver = SfcsClientSaver(use_cipher=use_cipher)
saver = SfcsClientSaver(file=file, use_cipher=use_cipher)
else:
saver = PosixSaver(use_cipher=use_cipher)
saver = PosixSaver(file=file, use_cipher=use_cipher)
return saver.save_pt(state_dict, file)
return saver.save_pt(state_dict)
......@@ -34,9 +34,7 @@ class BaseLoader:
def __init__(self, method: str) -> None:
self.method = method
def load_to_bytes(
self, file: FILE_PATH, offset: int, count: int, cipher_info: CipherInfo = CipherInfo(False)
) -> bytes:
def load_to_bytes(self, offset: int, count: int, cipher_info: CipherInfo = CipherInfo(False)) -> bytes:
raise NotImplementedError
def load_safetensors(self, safetensors_file: Any, map_location: str = "cpu") -> Dict[str, torch.Tensor]:
......@@ -65,13 +63,12 @@ class BaseLoader:
class PosixLoader(BaseLoader):
def __init__(self) -> None:
def __init__(self, file: FILE_PATH) -> None:
super().__init__(method="posix")
self.file = file
def load_to_bytes(
self, file: FILE_PATH, offset: int, count: int, cipher_info: CipherInfo = CipherInfo(False)
) -> bytes:
arr = np.fromfile(file, dtype=np.uint8, offset=offset, count=count)
def load_to_bytes(self, offset: int, count: int, cipher_info: CipherInfo = CipherInfo(False)) -> bytes:
arr = np.fromfile(self.file, dtype=np.uint8, offset=offset, count=count)
if cipher_info.use_cipher:
h_off = CipherInfo.HEADER_SIZE if cipher_info.use_header else 0
decrypt(cipher_info, arr, arr, offset - h_off)
......@@ -107,12 +104,12 @@ class PosixLoader(BaseLoader):
return state_dict
def load_pt(
self, file: FILE_PATH, map_location: str = "cpu", cipher_info: CipherInfo = CipherInfo(False)
self, map_location: str = "cpu", cipher_info: CipherInfo = CipherInfo(False)
) -> Dict[str, torch.Tensor]:
if cipher_info.use_cipher:
h_off = CipherInfo.HEADER_SIZE if cipher_info.use_header else 0
arr = np.fromfile(file, dtype=np.uint8, offset=h_off, count=-1)
arr = np.fromfile(self.file, dtype=np.uint8, offset=h_off, count=-1)
decrypt(cipher_info, arr, arr, 0)
return torch.load(io.BytesIO(arr.data), map_location=map_location)
return torch.load(file, map_location=map_location)
return torch.load(self.file, map_location=map_location)
......@@ -32,12 +32,13 @@ from .base_loader import PosixLoader
class FasterPosixLoader(PosixLoader):
def __init__(
self,
file: FILE_PATH,
helper: IOHelper,
num_thread: int = 32,
use_pinmem: bool = False,
use_direct_io: bool = False,
) -> None:
super().__init__()
super().__init__(file)
self.helper = helper
self.num_thread = num_thread
self.use_pinmem = use_pinmem
......@@ -72,12 +73,12 @@ class FasterPosixLoader(PosixLoader):
return SafetensorsFile.split_tensor_to_state_dict(total_tensor, safetensors_file)
def load_pt(
self, file: FILE_PATH, map_location: str = "cpu", cipher_info: CipherInfo = CipherInfo(False)
self, map_location: str = "cpu", cipher_info: CipherInfo = CipherInfo(False)
) -> Dict[str, torch.Tensor]:
if cipher_info.use_cipher:
h_off = CipherInfo.HEADER_SIZE if cipher_info.use_header else 0
arr = np.fromfile(file, dtype=np.uint8, offset=h_off, count=-1)
arr = np.fromfile(self.file, dtype=np.uint8, offset=h_off, count=-1)
decrypt(cipher_info, arr, arr, 0)
return torch.load(io.BytesIO(arr.data), map_location=map_location)
return torch.load(file, map_location=map_location)
return torch.load(self.file, map_location=map_location)
......@@ -33,6 +33,7 @@ from veturboio.types import FILE_PATH
class SfcsClientLoader(BaseLoader):
def __init__(
self,
file: FILE_PATH,
helper: IOHelper,
num_thread: int = 32,
use_pinmem: bool = False,
......@@ -40,24 +41,23 @@ class SfcsClientLoader(BaseLoader):
) -> None:
super().__init__(method="client")
self.file = file
self.helper = helper
self.num_thread = num_thread
self.use_pinmem = use_pinmem
self.use_direct_io = use_direct_io
init_sfcs_conf()
init_sfcs_conf(file)
def load_to_bytes(
self, file: FILE_PATH, offset: int, count: int, cipher_info: CipherInfo = CipherInfo(False)
) -> bytes:
file_size = sfcs_get_file_size(file)
def load_to_bytes(self, offset: int, count: int, cipher_info: CipherInfo = CipherInfo(False)) -> bytes:
file_size = sfcs_get_file_size(self.file)
if offset + count > file_size:
count = file_size - offset
file_bytes = bytes(count)
candidate = np.frombuffer(file_bytes, dtype=np.byte)
sfcs_read_file(
file, candidate, length=count, offset=offset, num_thread=self.num_thread, cipher_info=cipher_info
self.file, candidate, length=count, offset=offset, num_thread=self.num_thread, cipher_info=cipher_info
)
return file_bytes
......@@ -89,9 +89,9 @@ class SfcsClientLoader(BaseLoader):
return SafetensorsFile.split_tensor_to_state_dict(total_tensor, safetensors_file)
def load_pt(
self, file: FILE_PATH, map_location: str = "cpu", cipher_info: CipherInfo = CipherInfo(False)
self, map_location: str = "cpu", cipher_info: CipherInfo = CipherInfo(False)
) -> Dict[str, torch.Tensor]:
file_size = sfcs_get_file_size(file)
file_size = sfcs_get_file_size(self.file)
h_off = CipherInfo.HEADER_SIZE if cipher_info.use_header else 0
file_bytes = self.load_to_bytes(file, offset=h_off, count=file_size - h_off, cipher_info=cipher_info)
file_bytes = self.load_to_bytes(offset=h_off, count=file_size - h_off, cipher_info=cipher_info)
return torch.load(BytesIO(file_bytes), map_location=map_location)
......@@ -14,18 +14,20 @@ See the License for the specific language governing permissions and
limitations under the License.
'''
import json
import os
import shutil
import tempfile
import threading
import xml.dom.minidom
from datetime import datetime, timezone
from typing import Optional
from typing import Optional, Tuple
import numpy as np
from loguru import logger
from veturboio.ops.cipher import CipherInfo, DataPipeClient
from veturboio.types import FILE_PATH
try:
from veturboio.utils.load_veturboio_ext import load_veturboio_ext
......@@ -36,7 +38,6 @@ except ImportError:
SFCSFile = None
logger.warning("veturboio_ext not found, fallback to pure python implementation")
SFCS_REQ_ENV_LIST = [
'SFCS_FSNAME',
'SFCS_REGION',
......@@ -54,51 +55,48 @@ SFCS_OPT_ENV_LIST = [
'SFCS_NAMENODE_ENDPOINT_ADDRESS',
]
SFCS_PROPERTIES = {
'cfs.filesystem.fs-mode': 'ACC',
'cfs.filesystem.task-id': 'sfcs',
'cfs.filesystem.resolve.addr.by.dns': 'false',
'cfs.metrics.emitters': 'metric_server;local_prometheus',
'cfs.client.metadata-cache.enable': 'false',
'rpc.client.channel.pool.size': '32',
'dfs.default.replica': '2',
'cfs.client.multi-nic.enabled': 'true',
'fs.datanode.router.ignore-main-nic': 'true',
'cfs.datanode.router.shuffle': 'true',
}
def default_sfcs_properties() -> dict:
    """Return a fresh dict of the baseline SFCS client properties.

    A new dict is built on every call so that callers can mutate their copy
    (e.g. adding per-group credentials) without affecting other configurations.
    """
    base_settings = [
        ('cfs.filesystem.fs-mode', 'ACC'),
        ('cfs.filesystem.task-id', 'sfcs'),
        ('cfs.filesystem.resolve.addr.by.dns', 'false'),
        ('cfs.metrics.emitters', 'metric_server;local_prometheus'),
        ('cfs.client.metadata-cache.enable', 'false'),
        ('rpc.client.channel.pool.size', '32'),
        ('dfs.default.replica', '2'),
        ('cfs.client.multi-nic.enabled', 'true'),
        ('fs.datanode.router.ignore-main-nic', 'true'),
        ('cfs.datanode.router.shuffle', 'true'),
    ]
    return dict(base_settings)
class CredentialsHelper:
def __init__(self):
self.lock = threading.Lock()
self.running = False
self.running = {}
# daemon thread will stop when parent thread exits
self.thread = None
self.threads = {}
self.stop_flag = threading.Event()
self.client = None
self.current_time = 0
self.expired_time = 0
self.ak = None
self.sk = None
self.st = None
self.name_node_ip = None
self.sfcs_conf_path = None
self.stop_flag = None
def run(self, sfcs_conf_path) -> None:
if not self.running:
def run(self, group: str, sfcs_conf_path: str) -> None:
if not self.running.get(group, False):
with self.lock:
if not self.running:
self.thread = threading.Thread(target=self.refresh_loop, daemon=True)
self.stop_flag = threading.Event()
self.client = DataPipeClient()
self.sfcs_conf_path = sfcs_conf_path
if not self.do_refresh():
raise RuntimeError('Credentials helper first fetch failed')
self.thread.start()
self.running = True
logger.info('CredentialsHelper refresh thread strat')
if not self.running.get(group, False):
if self.client is None:
self.client = DataPipeClient()
init_ts = self.do_refresh(group, sfcs_conf_path)
if not init_ts:
raise RuntimeError(f'Credentials helper for {sfcs_conf_path} first fetch failed')
self.threads[group] = threading.Thread(
target=self.refresh_loop, args=(group, sfcs_conf_path, init_ts[0], init_ts[1]), daemon=True
)
self.threads[group].start()
self.running[group] = True
logger.info(f'CredentialsHelper refresh thread for {sfcs_conf_path} start')
return
logger.info('CredentialsHelper thread is already running, do nothing')
logger.info(f'CredentialsHelper thread for {sfcs_conf_path} is already running, do nothing')
def stop(self):
self.stop_flag.set()
......@@ -115,77 +113,92 @@ class CredentialsHelper:
return False
return True
def refresh_loop(self) -> None:
def refresh_loop(self, group: str, sfcs_conf_path: str, current_time: float, expired_time: float) -> None:
while True:
now = datetime.now(tz=timezone.utc).timestamp()
ts_ref = (self.current_time + self.expired_time) / 2
ts_ref = (current_time + expired_time) / 2
if now >= ts_ref:
if not self.do_refresh():
raise RuntimeError('Credentials helper do refresh failed')
ts = self.do_refresh(group, sfcs_conf_path)
if not ts:
raise RuntimeError(f'Credentials helper do refresh at {sfcs_conf_path} failed')
current_time, expired_time = ts[0], ts[1]
else:
if self.stop_flag.wait(ts_ref - now):
return
def do_refresh(self) -> bool:
def do_refresh(self, group: str, sfcs_conf_path: str) -> Optional[Tuple[float, float]]:
d = self.client.get_sfcs_ak_sk_st()
if self.is_valid_res(d):
self.name_node_ip = d['SfcsNameNodeAddress']
name_node_ip = d['SfcsNameNodeAddress']
d = d['Cred']
self.current_time = datetime.fromisoformat(d['CurrentTime']).timestamp()
self.expired_time = datetime.fromisoformat(d['ExpiredTime']).timestamp()
self.ak = d['AccessKeyId']
self.sk = d['SecretAccessKey']
self.st = d['SessionToken']
# update SFCS_PROPERTIES and then write xml
SFCS_PROPERTIES['cfs.access.key'] = self.ak
SFCS_PROPERTIES['cfs.secret.key'] = self.sk
SFCS_PROPERTIES['cfs.security.token'] = self.st
SFCS_PROPERTIES['cfs.namenode.endpoint.address.' + os.getenv('SFCS_FSNAME')] = self.name_node_ip
generate_sfcs_conf_xml(self.sfcs_conf_path)
logger.info('Credentials are successfully refreshed!')
return True
current_time = datetime.fromisoformat(d['CurrentTime']).timestamp()
expired_time = datetime.fromisoformat(d['ExpiredTime']).timestamp()
ak = d['AccessKeyId']
sk = d['SecretAccessKey']
st = d['SessionToken']
try:
sfcs_fsname = json.loads(os.getenv('SFCS_FSNAME'))[group]
except:
sfcs_fsname = os.getenv('SFCS_FSNAME')
properties = init_sfcs_properties(group)
properties['cfs.access.key'] = ak
properties['cfs.secret.key'] = sk
properties['cfs.security.token'] = st
properties['cfs.namenode.endpoint.address.' + sfcs_fsname] = name_node_ip
generate_sfcs_conf_xml(sfcs_conf_path, properties)
logger.info(f'Credentials are successfully refreshed at {sfcs_conf_path}!')
return current_time, expired_time
else:
return False
return None
credentials_helper = CredentialsHelper()
def init_sfcs_properties():
def init_sfcs_properties(group: str) -> dict:
for env in SFCS_REQ_ENV_LIST:
if os.getenv(env) is None:
raise ValueError('environ ' + env + ' not set')
SFCS_PROPERTIES['dfs.default.uri'] = (
'cfs://' + os.getenv('SFCS_FSNAME') + '.sfcs-' + os.getenv('SFCS_REGION') + '.ivolces.com'
)
SFCS_PROPERTIES['dfs.authentication.service.name'] = os.getenv('SFCS_AUTHENTICATION_SERVICE_NAME')
SFCS_PROPERTIES['cfs.filesystem.ns-id'] = os.getenv('SFCS_NS_ID')
SFCS_PROPERTIES['cfs.filesystem.ufs-path'] = os.getenv('SFCS_UFS_PATH')
SFCS_PROPERTIES['cfs.metrics.server.host'] = 'metricserver.cfs-' + os.getenv('SFCS_REGION') + '.ivolces.com'
SFCS_PROPERTIES['cfs.client.multi-nic.whitelist'] = os.getenv('SFCS_MULTI_NIC_WHITELIST')
SFCS_PROPERTIES['cfs.client.network.segment'] = os.getenv('SFCS_NETWORK_SEGMENT')
SFCS_PROPERTIES['dfs.client.log.severity'] = os.getenv('SFCS_LOG_SEVERITY')
try:
sfcs_fsname = json.loads(os.getenv('SFCS_FSNAME'))[group]
sfcs_ns_id = json.loads(os.getenv('SFCS_NS_ID'))[group]
sfcs_ufs_path = json.loads(os.getenv('SFCS_UFS_PATH'))[group]
logger.info(f"parse sfcs fsname, ns_id and ufs_path from environ in JSON format")
except:
sfcs_fsname = os.getenv('SFCS_FSNAME')
sfcs_ns_id = os.getenv('SFCS_NS_ID')
sfcs_ufs_path = os.getenv('SFCS_UFS_PATH')
logger.info(f"parse sfcs fsname, ns_id and ufs_path from environ in STRING format")
properties = default_sfcs_properties()
properties['dfs.default.uri'] = 'cfs://' + sfcs_fsname + '.sfcs-' + os.getenv('SFCS_REGION') + '.ivolces.com'
properties['dfs.authentication.service.name'] = os.getenv('SFCS_AUTHENTICATION_SERVICE_NAME')
properties['cfs.filesystem.ns-id'] = sfcs_ns_id
properties['cfs.filesystem.ufs-path'] = sfcs_ufs_path
properties['cfs.metrics.server.host'] = 'metricserver.cfs-' + os.getenv('SFCS_REGION') + '.ivolces.com'
properties['cfs.client.multi-nic.whitelist'] = os.getenv('SFCS_MULTI_NIC_WHITELIST')
properties['cfs.client.network.segment'] = os.getenv('SFCS_NETWORK_SEGMENT')
properties['dfs.client.log.severity'] = os.getenv('SFCS_LOG_SEVERITY')
# optional
SFCS_PROPERTIES['cfs.filesystem.sync-interval'] = os.getenv('SFCS_SYNC_INTERVAL', "-1")
SFCS_PROPERTIES['cfs.access.key'] = os.getenv('SFCS_ACCESS_KEY')
SFCS_PROPERTIES['cfs.secret.key'] = os.getenv('SFCS_SECRET_KEY')
SFCS_PROPERTIES['cfs.namenode.endpoint.address.' + os.getenv('SFCS_FSNAME')] = os.getenv(
'SFCS_NAMENODE_ENDPOINT_ADDRESS'
)
properties['cfs.filesystem.sync-interval'] = os.getenv('SFCS_SYNC_INTERVAL', "-1")
properties['cfs.access.key'] = os.getenv('SFCS_ACCESS_KEY')
properties['cfs.secret.key'] = os.getenv('SFCS_SECRET_KEY')
properties['cfs.namenode.endpoint.address.' + sfcs_fsname] = os.getenv('SFCS_NAMENODE_ENDPOINT_ADDRESS')
return properties
def generate_sfcs_conf_xml(sfcs_conf):
def generate_sfcs_conf_xml(sfcs_conf: FILE_PATH, sfcs_properties: dict):
doc = xml.dom.minidom.Document()
configuration = doc.createElement('configuration')
doc.appendChild(configuration)
for key in SFCS_PROPERTIES:
for key in sfcs_properties:
property = doc.createElement('property')
name = doc.createElement('name')
name.appendChild(doc.createTextNode(key))
value = doc.createElement('value')
value.appendChild(doc.createTextNode(SFCS_PROPERTIES[key]))
value.appendChild(doc.createTextNode(sfcs_properties[key]))
property.appendChild(name)
property.appendChild(value)
......@@ -200,29 +213,33 @@ def generate_sfcs_conf_xml(sfcs_conf):
shutil.move(tmp_conf.name, sfcs_conf)
def init_sfcs_conf():
if not os.getenv('LIBCFS_CONF'):
logger.warning('environ LIBCFS_CONF not set, set it to ' + os.getcwd() + '/libcfs.xml')
os.environ['LIBCFS_CONF'] = os.getcwd() + '/libcfs.xml'
def sfcs_conf_group(file: FILE_PATH) -> str:
    """Return the conf group for *file*: the first component of its absolute path.

    E.g. ``/base_model/tensor.pt`` -> ``base_model``; the group names the
    per-model ``<group>.xml`` SFCS configuration file.
    """
    absolute = os.path.abspath(file)
    return absolute.split("/")[1]
def init_sfcs_conf(file: FILE_PATH):
group = sfcs_conf_group(file)
sfcs_conf = os.path.join(os.getcwd(), group + '.xml')
os.environ['LIBCFS_CONF'] = sfcs_conf
logger.info(f'environ LIBCFS_CONF set to {sfcs_conf}')
sfcs_conf = os.getenv('LIBCFS_CONF')
if os.path.exists(sfcs_conf):
if os.path.isfile(sfcs_conf):
# case 1: a xml file already exists, do nothing
logger.warning('LIBCFS_CONF file exists')
logger.info('LIBCFS_CONF file exists')
else:
init_sfcs_properties()
if (
os.getenv('SFCS_ACCESS_KEY')
and os.getenv('SFCS_SECRET_KEY')
and os.getenv('SFCS_NAMENODE_ENDPOINT_ADDRESS')
):
# case 2: env SFCS_ACCESS_KEY, SFCS_SECRET_KEY and SFCS_NAMENODE_ENDPOINT_ADDRESS exist
logger.warning('Use aksk and namenode_ip in env to generate sfcs config')
generate_sfcs_conf_xml(sfcs_conf)
logger.info('Use aksk and namenode_ip in env to generate sfcs config')
properties = init_sfcs_properties(group)
generate_sfcs_conf_xml(sfcs_conf, properties)
else:
# case 3: use datapipe socket to get and refresh ak, sk, and st
logger.warning('Use credentials helper to generate and update sfcs config')
credentials_helper.run(sfcs_conf)
logger.info('Use credentials helper to generate and update sfcs config')
credentials_helper.run(group, sfcs_conf)
def sfcs_get_file_size(file_path: str) -> int:
......
......@@ -106,7 +106,7 @@ class SafetensorsFile:
# cipher related
self._cipher_info = CipherInfo(False)
if use_cipher or os.getenv("VETURBOIO_USE_CIPHER", "0") == "1":
header_bytes = loader.load_to_bytes(file, offset=0, count=CipherInfo.HEADER_SIZE)
header_bytes = loader.load_to_bytes(offset=0, count=CipherInfo.HEADER_SIZE)
self._cipher_info = CipherInfo(True, header_bytes)
if self._cipher_info.use_header:
......@@ -114,15 +114,15 @@ class SafetensorsFile:
else:
h_off = 0
magic_number = loader.load_to_bytes(file, offset=8 + h_off, count=1, cipher_info=self._cipher_info)[0]
magic_number = loader.load_to_bytes(offset=8 + h_off, count=1, cipher_info=self._cipher_info)[0]
if magic_number != SAFETENSORS_FILE_MAGIC_NUM:
self._is_valid = False
return
self._meta_size = np.frombuffer(
loader.load_to_bytes(file, offset=h_off, count=8, cipher_info=self._cipher_info), dtype=np.int64
loader.load_to_bytes(offset=h_off, count=8, cipher_info=self._cipher_info), dtype=np.int64
)[0]
meta_bytes = loader.load_to_bytes(file, offset=8 + h_off, count=self._meta_size, cipher_info=self._cipher_info)
meta_bytes = loader.load_to_bytes(offset=8 + h_off, count=self._meta_size, cipher_info=self._cipher_info)
meta_dict = json.loads(meta_bytes.decode("utf-8"))
self._shared_tensor = {}
......@@ -208,6 +208,6 @@ class SafetensorsFile:
def load(self, map_location: str = "cpu") -> Dict[str, torch.Tensor]:
if not self._is_valid:
return self._loader.load_pt(self.file, map_location, self._cipher_info)
return self._loader.load_pt(map_location, self._cipher_info)
else:
return self._loader.load_safetensors(self, map_location)
......@@ -39,8 +39,9 @@ class BaseSaver:
class PosixSaver(BaseSaver):
def __init__(self, use_cipher: bool = False) -> None:
def __init__(self, file: FILE_PATH, use_cipher: bool = False) -> None:
super().__init__(method="posix")
self.file = file
use_cipher = use_cipher or os.getenv("VETURBOIO_USE_CIPHER", "0") == "1"
use_header = use_cipher and os.getenv("VETURBOIO_CIPHER_HEADER", "0") == "1"
if use_header:
......@@ -48,7 +49,7 @@ class PosixSaver(BaseSaver):
else:
self.cipher_info = CipherInfo(use_cipher)
def save_file(self, state_dict: Dict[str, torch.Tensor], file: FILE_PATH, metadata: Dict[str, str] = None) -> None:
def save_file(self, state_dict: Dict[str, torch.Tensor], metadata: Dict[str, str] = None) -> None:
if self.cipher_info.use_cipher:
with tempfile.NamedTemporaryFile(dir="/dev/shm") as tmpfile:
tmp_file_path = tmpfile.name
......@@ -56,15 +57,15 @@ class PosixSaver(BaseSaver):
tmp_file_size = os.path.getsize(tmp_file_path)
tmp_file_bytes = np.memmap(tmp_file_path, dtype=np.uint8, mode='r', shape=tmp_file_size)
h_off = CipherInfo.HEADER_SIZE if self.cipher_info.use_header else 0
file_bytes = np.memmap(file, dtype=np.uint8, mode='w+', shape=tmp_file_size + h_off)
file_bytes = np.memmap(self.file, dtype=np.uint8, mode='w+', shape=tmp_file_size + h_off)
encrypt(self.cipher_info, tmp_file_bytes, file_bytes[h_off:], 0)
if h_off:
file_bytes[:h_off] = np.frombuffer(self.cipher_info.to_header_bytes(), dtype=np.uint8)
file_bytes.flush()
else:
safetenors_save_file(state_dict, file, metadata=metadata)
safetenors_save_file(state_dict, self.file, metadata=metadata)
def save_model(self, model: torch.nn.Module, file: FILE_PATH) -> None:
def save_model(self, model: torch.nn.Module) -> None:
if self.cipher_info.use_cipher:
with tempfile.NamedTemporaryFile(dir="/dev/shm") as tmpfile:
tmp_file_path = tmpfile.name
......@@ -72,15 +73,15 @@ class PosixSaver(BaseSaver):
tmp_file_size = os.path.getsize(tmp_file_path)
tmp_file_bytes = np.memmap(tmp_file_path, dtype=np.uint8, mode='r', shape=tmp_file_size)
h_off = CipherInfo.HEADER_SIZE if self.cipher_info.use_header else 0
file_bytes = np.memmap(file, dtype=np.uint8, mode='w+', shape=tmp_file_size + h_off)
file_bytes = np.memmap(self.file, dtype=np.uint8, mode='w+', shape=tmp_file_size + h_off)
encrypt(self.cipher_info, tmp_file_bytes, file_bytes[h_off:], 0)
if h_off:
file_bytes[:h_off] = np.frombuffer(self.cipher_info.to_header_bytes(), dtype=np.uint8)
file_bytes.flush()
else:
safetensors_save_model(model, file)
safetensors_save_model(model, self.file)
def save_pt(self, state_dict: Dict[str, torch.Tensor], file: FILE_PATH) -> None:
def save_pt(self, state_dict: Dict[str, torch.Tensor]) -> None:
if self.cipher_info.use_cipher:
with tempfile.NamedTemporaryFile(dir="/dev/shm") as tmpfile:
tmp_file_path = tmpfile.name
......@@ -88,10 +89,10 @@ class PosixSaver(BaseSaver):
tmp_file_size = os.path.getsize(tmp_file_path)
tmp_file_bytes = np.memmap(tmp_file_path, dtype=np.uint8, mode='r', shape=tmp_file_size)
h_off = CipherInfo.HEADER_SIZE if self.cipher_info.use_header else 0
file_bytes = np.memmap(file, dtype=np.uint8, mode='w+', shape=tmp_file_size + h_off)
file_bytes = np.memmap(self.file, dtype=np.uint8, mode='w+', shape=tmp_file_size + h_off)
encrypt(self.cipher_info, tmp_file_bytes, file_bytes[h_off:], 0)
if h_off:
file_bytes[:h_off] = np.frombuffer(self.cipher_info.to_header_bytes(), dtype=np.uint8)
file_bytes.flush()
else:
torch.save(state_dict, file)
torch.save(state_dict, self.file)
......@@ -30,10 +30,11 @@ from veturboio.types import FILE_PATH
class SfcsClientSaver(BaseSaver):
def __init__(self, use_cipher: bool = False) -> None:
def __init__(self, file: FILE_PATH, use_cipher: bool = False) -> None:
super().__init__(method="client")
init_sfcs_conf()
self.file = file
init_sfcs_conf(file)
use_cipher = use_cipher or os.getenv("VETURBOIO_USE_CIPHER", "0") == "1"
use_header = use_cipher and os.getenv("VETURBOIO_CIPHER_HEADER", "0") == "1"
......@@ -42,7 +43,7 @@ class SfcsClientSaver(BaseSaver):
else:
self.cipher_info = CipherInfo(use_cipher)
def save_file(self, state_dict: Dict[str, torch.Tensor], file: FILE_PATH, metadata: Dict[str, str] = None) -> None:
def save_file(self, state_dict: Dict[str, torch.Tensor], metadata: Dict[str, str] = None) -> None:
with tempfile.NamedTemporaryFile(dir="/dev/shm") as tmpfile:
file_path = tmpfile.name
safetenors_save_file(state_dict, file_path, metadata=metadata)
......@@ -55,9 +56,9 @@ class SfcsClientSaver(BaseSaver):
file_bytes[h_off:] = np.fromfile(file_path, dtype=np.byte, count=file_size)
else:
file_bytes = np.memmap(file_path, dtype=np.byte, mode='r+', shape=file_size)
sfcs_write_file(file, file_bytes, len(file_bytes), self.cipher_info)
sfcs_write_file(self.file, file_bytes, len(file_bytes), self.cipher_info)
def save_model(self, model: torch.nn.Module, file: FILE_PATH) -> None:
def save_model(self, model: torch.nn.Module) -> None:
with tempfile.NamedTemporaryFile(dir="/dev/shm") as tmpfile:
file_path = tmpfile.name
safetensors_save_model(model, file_path)
......@@ -70,9 +71,9 @@ class SfcsClientSaver(BaseSaver):
file_bytes[h_off:] = np.fromfile(file_path, dtype=np.byte, count=file_size)
else:
file_bytes = np.memmap(file_path, dtype=np.byte, mode='r+', shape=file_size)
sfcs_write_file(file, file_bytes, len(file_bytes), self.cipher_info)
sfcs_write_file(self.file, file_bytes, len(file_bytes), self.cipher_info)
def save_pt(self, state_dict: Dict[str, torch.Tensor], file: FILE_PATH) -> None:
def save_pt(self, state_dict: Dict[str, torch.Tensor]) -> None:
with tempfile.NamedTemporaryFile(dir="/dev/shm") as tmpfile:
file_path = tmpfile.name
torch.save(state_dict, file_path)
......@@ -85,4 +86,4 @@ class SfcsClientSaver(BaseSaver):
file_bytes[h_off:] = np.fromfile(file_path, dtype=np.byte, count=file_size)
else:
file_bytes = np.memmap(file_path, dtype=np.byte, mode='r+', shape=file_size)
sfcs_write_file(file, file_bytes, len(file_bytes), self.cipher_info)
sfcs_write_file(self.file, file_bytes, len(file_bytes), self.cipher_info)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment