"examples/wav2vec/unsupervised/vscode:/vscode.git/clone" did not exist on "c394d7d13ebd01f9e1eac2a1c18a79e8867b2a2d"
Commit e5edc542 authored by 刘乐典's avatar 刘乐典 Committed by huteng.ht
Browse files

feat(security): compat with cipher header and use cipher in posix

* feat(security): compat with cipher header and use cipher in posix
* feat(security): compat with cipher header and use cipher in posix
parent 0cedada8
......@@ -71,18 +71,39 @@ for k, v in tensors2.items():
assert torch.allclose(v.cuda(), reloaded_tensor2[k])
```
### 使用SFCS读写模型启用加解密
该库可以读写 SFCS 上的模型文件,在此情况下可以启用加解密能力。读写 SFCS 上的模型文件和加解密所需的敏感信息,有两种获取方式:(1) 火山引擎可信服务的 unix domain socket,(2) 环境变量。在没有挂载可信服务 uds 的情况下,可以使用下面的环境变量:
### 读写模型启用加解密
该库底层通过两种接口读写SFCS SDK 和 POSIX。如果文件路径前缀为 `sfcs://` 就视为使用 SFCS SDK,所需的鉴权信息可以从火山引擎可信服务的 unix domain socket 获取, 或者设置以下三个环境变量:
| 环境变量名 | 含义 |
| ------------------------------ | --------------------------------- |
| SFCS_ACCESS_KEY | SFCS文件系统的AK |
| SFCS_SECRET_KEY | SFCS文件系统的SK |
| SFCS_NAMENODE_ENDPOINT_ADDRESS | SFCS文件系统name节点地址 |
| VETUROIO_KEY | 加解密的128位数据密钥的base64编码 |
| VETUROIO_IV | 加解密的128位初始向量的base64编码 |
| SFCS_ACCESS_KEY | SFCS 文件系统的 AK |
| SFCS_SECRET_KEY | SFCS 文件系统的 SK |
| SFCS_NAMENODE_ENDPOINT_ADDRESS | SFCS 文件系统 name 节点地址 |
挂载可信服务 uds 或者配置好环境变量后,可以参考下面代码在读写 SFCS 上的模型文件时启用加解密:
加解密读写模型文件所需的 data key 和 iv,共有3种获取方式,优先级按照序号:
- [1] 加密的 data key 和 iv 存放在密文模型文件的 header 中,使用火山引擎 KMS 解密得到明文的 data key。
- [1.1] 访问 KMS 所需的 AK/SK/ST 从火山引擎可信服务的 unix domain socket 获取,需要额外挂载。
- [1.2] 访问 KMS 所需的 AK/SK/ST 从环境变量获取。
- [2] 访问火山引擎可信服务的 unix domain socket 直接获取 data key 和 iv,需要额外挂载。
- [3] 环境变量直接设置 data key 和 iv。
不同方式需要设置的环境变量如下:
| 环境变量名 | 含义 |
| ------------------------------ | --------------------------------- |
| VETURBOIO_KMS_HOST | [1] KMS 服务地址,默认值 open.volcengineapi.com|
| VETURBOIO_KMS_REGION | [1] KMS 服务所在区域,默认值 cn-beijing |
| VETURBOIO_KMS_KEYRING_NAME | [1] KMS 服务解密 data key 的钥匙环名 |
| VETURBOIO_KMS_KEY_NAME | [1] KMS 服务解密 data key 的主密钥名 |
| DATAPIPE_SOCKET_PATH | [1.1][2] 可信服务 uds 的路径 |
| VETURBOIO_KMS_ACCESS_KEY | [1.2] KMS 鉴权的 AK |
| VETURBOIO_KMS_SECRET_KEY | [1.2] KMS 鉴权的 SK |
| VETURBOIO_KMS_SESSION_TOKEN | [1.2] KMS 鉴权的临时令牌,非必需|
| VETURBOIO_KEY | [3] 加解密的 128 位数据密钥的 base64 编码 |
| VETURBOIO_IV | [3] 加解密的 128 位初始向量的 base64 编码 |
按照上述三种方式设置好后,可以参考下面代码在读写模型文件时启用加解密:
```python
import torch
import veturboio
......
......@@ -80,7 +80,7 @@ def get_veturboio_extension():
include_dirs = ["veturboio/ops/csrc/include"]
library_dirs = ["veturboio/ops/csrc/lib"]
libraries = ["cfs", "fastcrypto"]
libraries = ["cfs", ":libfastcrypto_gpu.so.0.3"]
extra_link_args = [make_relative_rpath("veturboio/ops/csrc/lib")]
return CUDAExtension(
......@@ -90,6 +90,7 @@ def get_veturboio_extension():
"veturboio/ops/csrc/load_utils.cpp",
"veturboio/ops/csrc/sfcs.cpp",
"veturboio/ops/csrc/io_helper.cu",
"veturboio/ops/csrc/cipher.cpp",
],
define_macros=define_macros,
include_dirs=include_dirs,
......
......@@ -44,8 +44,28 @@ class UnixSocketHttpServer(socketserver.UnixStreamServer):
class DatapipeHandler(http.server.SimpleHTTPRequestHandler):
def do_POST(self):
action = self.headers.get('X-Datapipe-Task-Type')
if action == 'top':
# mock kms response
self.send_response(200)
self.send_header('Content-Type', 'application/json')
self.end_headers()
res = {'Result': {'Plaintext': base64.b64encode(b'abcdefgh87654321').decode('ascii')}}
self.wfile.write(bytes(json.dumps(res), encoding='ascii'))
return
self.send_response(400)
self.end_headers()
return
def do_GET(self):
action = self.headers.get('X-Datapipe-Task-Type')
if action == 'ping':
self.send_response(200)
self.send_header('Content-Type', 'application/json')
self.end_headers()
self.wfile.write(bytes(json.dumps({'message': 'pong'}), encoding='ascii'))
return
if action == 'encrypt-key':
self.send_response(200)
self.send_header('Content-Type', 'application/json')
......@@ -74,6 +94,19 @@ class DatapipeHandler(http.server.SimpleHTTPRequestHandler):
}
self.wfile.write(bytes(json.dumps(res), encoding='ascii'))
return
if action == 'kms-sts':
self.send_response(200)
self.send_header('Content-Type', 'application/json')
self.end_headers()
res = {
'Cred': {
'AccessKeyId': os.environ['CI_VENDOR_AK'],
'SecretAccessKey': os.environ['CI_VENDOR_AK'],
'SessionToken': '',
},
}
self.wfile.write(bytes(json.dumps(res), encoding='ascii'))
return
self.send_response(400)
self.end_headers()
return
......@@ -92,8 +125,55 @@ class TestCipherInfo(TestCase):
cls.thread = threading.Thread(target=run)
cls.thread.start()
cls.target_key = np.frombuffer(b'abcdefgh12345678', dtype=np.byte)
cls.target_key_2 = np.frombuffer(b'abcdefgh87654321', dtype=np.byte)
cls.target_iv = np.frombuffer(b'1234567887654321', dtype=np.byte)
def test_fetch_from_file_header(self):
os.environ.pop('VETURBOIO_KEY', None)
os.environ.pop('VETURBOIO_IV', None)
DataPipeClient.DATAPIPE_SOCKET_PATH = '/path/not/exist'
header_dict = {
'mode': 'CTR-128',
'iv': 'MTIzNDU2Nzg4NzY1NDMyMQ==',
'meta_data_key': 'bl2htKYLQ2+CjyyJ84Q3twAA9ZpCbFxwznRb0NkR9zGGRp1RK5Mb9u8NNOiahY+0yVrxNw3IVQ9Wgn6PDscw77Cb3eImjVn14hNBJRlwtSyQ7tRZLOsZBEHv5cWwDQ==',
}
header_bytes = bytearray(256 * 1024)
header_str = 'Byte3ncryptM0del' + json.dumps(header_dict)
header_bytes[: len(header_str)] = header_str.encode('utf-8')
# case1: get kms cred from env
ENV_KMS_HOST = 'VETURBOIO_KMS_HOST'
ENV_KMS_REGION = 'VETURBOIO_KMS_REGION'
ENV_KMS_AK = 'VETURBOIO_KMS_ACCESS_KEY'
ENV_KMS_SK = 'VETURBOIO_KMS_SECRET_KEY'
ENV_KMS_KEYRING = 'VETURBOIO_KMS_KEYRING_NAME'
ENV_KMS_KEY = 'VETURBOIO_KMS_KEY_NAME'
os.environ[ENV_KMS_HOST] = 'open.volcengineapi.com'
os.environ[ENV_KMS_REGION] = 'cn-beijing'
os.environ[ENV_KMS_AK] = os.environ['CI_VENDOR_AK']
os.environ[ENV_KMS_SK] = os.environ['CI_VENDOR_SK']
os.environ[ENV_KMS_KEYRING] = 'datapipe_keyring'
os.environ[ENV_KMS_KEY] = 'datapipe_key_ml_maas'
info = CipherInfo(True, header_bytes)
self.assertTrue(info.use_cipher)
self.assertTrue(info.use_header)
self.assertTrue(np.array_equal(info.key, self.target_key))
self.assertTrue(np.array_equal(info.iv, self.target_iv))
# case2: get kms cred from datapipe and access kms with datapipe proxy
os.environ.pop(ENV_KMS_HOST, None)
os.environ.pop(ENV_KMS_REGION, None)
os.environ.pop(ENV_KMS_AK, None)
os.environ.pop(ENV_KMS_SK, None)
DataPipeClient.DATAPIPE_SOCKET_PATH = self.server_address
info = CipherInfo(True, header_bytes)
info = CipherInfo(True, header_bytes)
self.assertTrue(info.use_cipher)
self.assertTrue(info.use_header)
self.assertTrue(np.array_equal(info.key, self.target_key_2))
self.assertTrue(np.array_equal(info.iv, self.target_iv))
def test_fetch_from_datapipe(self):
DataPipeClient.DATAPIPE_SOCKET_PATH = self.server_address
info = CipherInfo(True)
......@@ -103,8 +183,8 @@ class TestCipherInfo(TestCase):
def test_fetch_from_env(self):
DataPipeClient.DATAPIPE_SOCKET_PATH = '/path/not/exist'
os.environ['VETUROIO_KEY'] = base64.b64encode(b'abcdefgh12345678').decode('ascii')
os.environ['VETUROIO_IV'] = base64.b64encode(b'1234567887654321').decode('ascii')
os.environ['VETURBOIO_KEY'] = base64.b64encode(b'abcdefgh12345678').decode('ascii')
os.environ['VETURBOIO_IV'] = base64.b64encode(b'1234567887654321').decode('ascii')
info = CipherInfo(True)
self.assertTrue(info.use_cipher)
self.assertTrue(np.array_equal(info.key, self.target_key))
......@@ -112,15 +192,15 @@ class TestCipherInfo(TestCase):
def test_fallback(self):
DataPipeClient.DATAPIPE_SOCKET_PATH = '/path/not/exist'
os.environ['VETUROIO_KEY'] = base64.b64encode(b'abcdefgh12').decode('ascii')
os.environ['VETUROIO_IV'] = base64.b64encode(b'1234567887').decode('ascii')
os.environ['VETURBOIO_KEY'] = base64.b64encode(b'abcdefgh12').decode('ascii')
os.environ['VETURBOIO_IV'] = base64.b64encode(b'1234567887').decode('ascii')
info = CipherInfo(True)
self.assertFalse(info.use_cipher)
@classmethod
def tearDownClass(cls):
os.environ.pop('VETUROIO_KEY', None)
os.environ.pop('VETUROIO_IV', None)
os.environ.pop('VETURBOIO_KEY', None)
os.environ.pop('VETURBOIO_IV', None)
cls.server.shutdown()
cls.server.server_close()
cls.thread.join()
......
......@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
'''
import base64
import os
import tempfile
import unittest
......@@ -28,6 +29,19 @@ import veturboio
class TestLoad(TestCase):
@classmethod
def setUpClass(cls):
ENV_KMS_HOST = 'VETURBOIO_KMS_HOST'
ENV_KMS_REGION = 'VETURBOIO_KMS_REGION'
ENV_KMS_AK = 'VETURBOIO_KMS_ACCESS_KEY'
ENV_KMS_SK = 'VETURBOIO_KMS_SECRET_KEY'
ENV_KMS_KEYRING = 'VETURBOIO_KMS_KEYRING_NAME'
ENV_KMS_KEY = 'VETURBOIO_KMS_KEY_NAME'
os.environ[ENV_KMS_HOST] = 'open.volcengineapi.com'
os.environ[ENV_KMS_REGION] = 'cn-beijing'
os.environ[ENV_KMS_AK] = os.environ['CI_VENDOR_AK']
os.environ[ENV_KMS_SK] = os.environ['CI_VENDOR_SK']
os.environ[ENV_KMS_KEYRING] = 'datapipe_keyring'
os.environ[ENV_KMS_KEY] = 'datapipe_key_ml_maas'
cls.tempdir = tempfile.TemporaryDirectory()
cls.tensors_0 = {
......@@ -49,6 +63,26 @@ class TestLoad(TestCase):
cls.pt_filepath = os.path.join(cls.tempdir.name, "model.pt")
torch.save(cls.tensors_0, cls.pt_filepath)
# cipher
os.environ["VETURBOIO_KEY"] = base64.b64encode(b"abcdefgh12345678").decode("ascii")
os.environ["VETURBOIO_IV"] = base64.b64encode(b"1234567887654321").decode("ascii")
cls.filepath_0_enc = os.path.join(cls.tempdir.name, "model_0_enc.safetensors")
cls.filepath_1_enc = os.path.join(cls.tempdir.name, "model_1_enc.safetensors")
veturboio.save_file(cls.tensors_0, cls.filepath_0_enc, use_cipher=True)
veturboio.save_file(cls.tensors_1, cls.filepath_1_enc, use_cipher=True)
cls.pt_filepath_enc = os.path.join(cls.tempdir.name, "model_enc.pt")
veturboio.save_pt(cls.tensors_0, cls.pt_filepath_enc, use_cipher=True)
# cipher with header
os.environ["VETURBOIO_CIPHER_HEADER"] = "1"
cls.filepath_0_enc_h = os.path.join(cls.tempdir.name, "model_0_enc_h.safetensors")
veturboio.save_file(cls.tensors_0, cls.filepath_0_enc_h, use_cipher=True)
cls.pt_filepath_enc_h = os.path.join(cls.tempdir.name, "model_enc_h.pt")
veturboio.save_pt(cls.tensors_0, cls.pt_filepath_enc_h, use_cipher=True)
if torch.cuda.is_available():
cls.cuda_tensors_0 = deepcopy(cls.tensors_0)
cls.cuda_tensors_1 = deepcopy(cls.tensors_1)
......@@ -60,44 +94,93 @@ class TestLoad(TestCase):
@classmethod
def tearDownClass(cls):
cls.tempdir.cleanup()
# cls.tempdir.cleanup()
pass
def _run_pipeline(self, tensors, filepath, map_location):
loaded_tensors = veturboio.load(filepath, map_location=map_location)
def _run_pipeline(self, tensors, filepath, map_location, use_cipher, enable_fast_mode=True):
loaded_tensors = veturboio.load(
filepath, map_location=map_location, use_cipher=use_cipher, enable_fast_mode=enable_fast_mode
)
for key in tensors.keys():
self.assertTrue(torch.allclose(tensors[key], loaded_tensors[key]))
return loaded_tensors
def test_pipeline_cpu(self):
self._run_pipeline(self.tensors_0, self.filepath_0, "cpu")
self._run_pipeline(self.tensors_0, self.filepath_0, "cpu", use_cipher=False)
self._run_pipeline(self.tensors_0, self.filepath_0_enc, "cpu", use_cipher=True)
self._run_pipeline(self.tensors_0, self.filepath_0, "cpu", use_cipher=False, enable_fast_mode=False)
self._run_pipeline(self.tensors_0, self.filepath_0_enc, "cpu", use_cipher=True, enable_fast_mode=False)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_pipeline_cuda(self):
self._run_pipeline(self.cuda_tensors_0, self.filepath_0, "cuda:0")
self._run_pipeline(self.cuda_tensors_0, self.filepath_0, "cuda:0", use_cipher=False)
self._run_pipeline(self.cuda_tensors_0, self.filepath_0_enc, "cuda:0", use_cipher=True)
self._run_pipeline(self.cuda_tensors_0, self.filepath_0, "cuda:0", use_cipher=False, enable_fast_mode=False)
self._run_pipeline(self.cuda_tensors_0, self.filepath_0_enc, "cuda:0", use_cipher=True, enable_fast_mode=False)
def test_read_multi_state_dict_cpu(self):
load_tensor_0 = self._run_pipeline(self.tensors_0, self.filepath_0, "cpu")
load_tensor_1 = self._run_pipeline(self.tensors_1, self.filepath_1, "cpu")
load_tensor_0 = self._run_pipeline(self.tensors_0, self.filepath_0, "cpu", use_cipher=False)
load_tensor_1 = self._run_pipeline(self.tensors_1, self.filepath_1, "cpu", use_cipher=False)
self.assertEqual(len(load_tensor_0), 2)
self.assertEqual(len(load_tensor_1), 3)
load_tensor_0_enc = self._run_pipeline(self.tensors_0, self.filepath_0_enc, "cpu", use_cipher=True)
load_tensor_1_enc = self._run_pipeline(self.tensors_1, self.filepath_1_enc, "cpu", use_cipher=True)
self.assertEqual(len(load_tensor_0_enc), 2)
self.assertEqual(len(load_tensor_1_enc), 3)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_read_multi_state_dict_cuda(self):
load_tensor_0 = self._run_pipeline(self.cuda_tensors_0, self.filepath_0, "cuda:0")
load_tensor_1 = self._run_pipeline(self.cuda_tensors_1, self.filepath_1, "cuda:0")
load_tensor_0 = self._run_pipeline(self.cuda_tensors_0, self.filepath_0, "cuda:0", use_cipher=False)
load_tensor_1 = self._run_pipeline(self.cuda_tensors_1, self.filepath_1, "cuda:0", use_cipher=False)
self.assertEqual(len(load_tensor_0), 2)
self.assertEqual(len(load_tensor_1), 3)
load_tensor_0_enc = self._run_pipeline(self.cuda_tensors_0, self.filepath_0_enc, "cuda:0", use_cipher=True)
load_tensor_1_enc = self._run_pipeline(self.cuda_tensors_1, self.filepath_1_enc, "cuda:0", use_cipher=True)
self.assertEqual(len(load_tensor_0_enc), 2)
self.assertEqual(len(load_tensor_1_enc), 3)
def test_load_pt_cpu(self):
loaded_tensors = veturboio.load(self.pt_filepath, map_location="cpu")
loaded_tensors = veturboio.load(self.pt_filepath, map_location="cpu", use_cipher=False)
for key in self.tensors_0.keys():
self.assertTrue(torch.allclose(self.tensors_0[key], loaded_tensors[key]))
loaded_tensors_enc = veturboio.load(self.pt_filepath_enc, map_location="cpu", use_cipher=True)
for key in self.tensors_0.keys():
self.assertTrue(torch.allclose(self.tensors_0[key], loaded_tensors_enc[key]))
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_load_pt_cuda(self):
loaded_tensors = veturboio.load(self.pt_filepath, map_location="cuda:0")
loaded_tensors = veturboio.load(self.pt_filepath, map_location="cuda:0", use_cipher=False)
for key in self.tensors_0.keys():
self.assertTrue(torch.allclose(self.cuda_tensors_0[key], loaded_tensors[key]))
loaded_tensors_enc = veturboio.load(self.pt_filepath_enc, map_location="cuda:0", use_cipher=True)
for key in self.tensors_0.keys():
self.assertTrue(torch.allclose(self.cuda_tensors_0[key], loaded_tensors_enc[key]))
def test_load_cipher_header_cpu(self):
os.environ["VETURBOIO_CIPHER_HEADER"] = "1"
self._run_pipeline(self.tensors_0, self.filepath_0_enc_h, "cpu", use_cipher=True)
self._run_pipeline(self.tensors_0, self.pt_filepath_enc_h, "cpu", use_cipher=True)
self._run_pipeline(self.tensors_0, self.filepath_0_enc_h, "cpu", use_cipher=True, enable_fast_mode=False)
self._run_pipeline(self.tensors_0, self.pt_filepath_enc_h, "cpu", use_cipher=True, enable_fast_mode=False)
del os.environ["VETURBOIO_CIPHER_HEADER"]
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_load_cipher_header_cuda(self):
os.environ["VETURBOIO_CIPHER_HEADER"] = "1"
self._run_pipeline(self.cuda_tensors_0, self.filepath_0_enc_h, "cuda:0", use_cipher=True)
self._run_pipeline(self.cuda_tensors_0, self.pt_filepath_enc_h, "cuda:0", use_cipher=True)
self._run_pipeline(
self.cuda_tensors_0, self.filepath_0_enc_h, "cuda:0", use_cipher=True, enable_fast_mode=False
)
self._run_pipeline(
self.cuda_tensors_0, self.pt_filepath_enc_h, "cuda:0", use_cipher=True, enable_fast_mode=False
)
del os.environ["VETURBOIO_CIPHER_HEADER"]
......@@ -34,9 +34,19 @@ class TestSave(TestCase):
"weight2": torch.randn(2000, 10),
}
class MockModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear1 = torch.nn.Linear(100, 50)
self.linear2 = torch.nn.Linear(100, 50)
cls.model = MockModel()
cls.tempdir = tempfile.TemporaryDirectory()
cls.filepath_0 = os.path.join(cls.tempdir.name, "model_0.safetensors")
cls.filepath_1 = os.path.join(cls.tempdir.name, "model_0.pt")
cls.filepath_3 = os.path.join(cls.tempdir.name, "model_1.safetensors")
@classmethod
def tearDownClass(cls):
......@@ -48,6 +58,13 @@ class TestSave(TestCase):
for key in f.keys():
self.assertTrue(torch.allclose(self.tensors_0[key], f.get_tensor(key)))
def test_save_model(self):
veturboio.save_model(self.model, self.filepath_3, use_cipher=True)
loaded_tensors = veturboio.load(self.filepath_3, map_location="cpu", use_cipher=True)
state_dict = self.model.state_dict()
for key in state_dict.keys():
self.assertTrue(torch.allclose(state_dict[key], loaded_tensors[key]))
def test_save_pt(self):
veturboio.save_pt(self.tensors_0, self.filepath_1)
loaded_tensors = torch.load(self.filepath_1)
......
......@@ -83,24 +83,38 @@ class TestSFCSLoad(TestCase):
def setUpClass(cls):
init_sfcs_env()
os.environ['VETUROIO_KEY'] = base64.b64encode(b'abcdefgh12345678').decode('ascii')
os.environ['VETUROIO_IV'] = base64.b64encode(b'1234567887654321').decode('ascii')
# key / iv
os.environ['VETURBOIO_KEY'] = base64.b64encode(b'abcdefgh12345678').decode('ascii')
os.environ['VETURBOIO_IV'] = base64.b64encode(b'1234567887654321').decode('ascii')
# kms info
ENV_KMS_HOST = 'VETURBOIO_KMS_HOST'
ENV_KMS_REGION = 'VETURBOIO_KMS_REGION'
ENV_KMS_AK = 'VETURBOIO_KMS_ACCESS_KEY'
ENV_KMS_SK = 'VETURBOIO_KMS_SECRET_KEY'
ENV_KMS_KEYRING = 'VETURBOIO_KMS_KEYRING_NAME'
ENV_KMS_KEY = 'VETURBOIO_KMS_KEY_NAME'
os.environ[ENV_KMS_HOST] = 'open.volcengineapi.com'
os.environ[ENV_KMS_REGION] = 'cn-beijing'
os.environ[ENV_KMS_AK] = os.environ['CI_VENDOR_AK']
os.environ[ENV_KMS_SK] = os.environ['CI_VENDOR_SK']
os.environ[ENV_KMS_KEYRING] = 'datapipe_keyring'
os.environ[ENV_KMS_KEY] = 'datapipe_key_ml_maas'
cls.filepath_0 = "sfcs://model.safetensors"
cls.filepath_1 = "sfcs://model.pt"
# mock /tmp as efs mount path
cls.filepath_2 = "/model.safetensors"
cls.tensors_0 = {
"weight1": torch.ones(50, 50),
"weight2": torch.zeros(50, 50),
"weight1": torch.ones(500, 50),
"weight2": torch.zeros(500, 50),
}
class MockModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear1 = torch.nn.Linear(50, 50)
self.linear2 = torch.nn.Linear(50, 50)
self.linear1 = torch.nn.Linear(500, 50)
self.linear2 = torch.nn.Linear(500, 50)
cls.model = MockModel()
......@@ -148,3 +162,14 @@ class TestSFCSLoad(TestCase):
def test_pipeline_cuda(self):
self._run_pipeline(self.cuda_tensors_0, self.cuda_model, "cuda:0", use_cipher=False)
self._run_pipeline(self.cuda_tensors_0, self.cuda_model, "cuda:0", use_cipher=True)
def test_pipeline_cipher_header_cpu(self):
os.environ["VETURBOIO_CIPHER_HEADER"] = "1"
self._run_pipeline(self.tensors_0, self.model, "cpu", use_cipher=True)
del os.environ["VETURBOIO_CIPHER_HEADER"]
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_pipeline_cipher_header_cuda(self):
os.environ["VETURBOIO_CIPHER_HEADER"] = "1"
self._run_pipeline(self.cuda_tensors_0, self.cuda_model, "cuda:0", use_cipher=True)
del os.environ["VETURBOIO_CIPHER_HEADER"]
......@@ -58,7 +58,7 @@ def load(
use_pinmem (bool, optional): use pin memory. Defaults to False.
num_thread (int, optional): number of threads. Defaults to 32.
use_direct_io (bool, optional): open file in direct io mode. Defaults to False.
use_cipher (bool, optional): decrypt file when use sfcs sdk. Defaults to False.
use_cipher (bool, optional): decrypt file. Defaults to False.
Returns:
state_dict (Dict): state dict
......@@ -84,7 +84,6 @@ def load(
num_thread=num_thread,
use_pinmem=use_pinmem,
use_direct_io=use_direct_io,
use_cipher=use_cipher,
)
else:
loader = FasterPosixLoader(
......@@ -94,7 +93,7 @@ def load(
use_direct_io=use_direct_io,
)
safetensors_file = SafetensorsFile(file, loader)
safetensors_file = SafetensorsFile(file, loader, use_cipher)
return safetensors_file.load(map_location=map_location)
......@@ -114,7 +113,7 @@ def save_file(
force_contiguous (bool, optional): force contiguous. Defaults to True.
force_save_shared_tensor (bool, optional): force save shared tensor. Defaults to False.
metadata (Dict[str, str], optional): metadata. Defaults to None.
use_cipher (bool, optional): decrypt file when use sfcs sdk. Defaults to False.
use_cipher (bool, optional): decrypt file. Defaults to False.
Examples:
```
......@@ -129,7 +128,7 @@ def save_file(
if use_sfcs_sdk:
saver = SfcsClientSaver(use_cipher=use_cipher)
else:
saver = PosixSaver()
saver = PosixSaver(use_cipher=use_cipher)
# TODO: there are some bugs while state_dict is loaded from veturboio
if not force_save_shared_tensor:
......@@ -164,7 +163,7 @@ def save_model(model: torch.nn.Module, file: FILE_PATH, use_cipher: Optional[boo
Args:
model (torch.nn.Module): model
file (FILE_PATH): file path
use_cipher (bool, optional): decrypt file when use sfcs sdk. Defaults to False.
use_cipher (bool, optional): decrypt file. Defaults to False.
Examples:
```
......@@ -180,7 +179,7 @@ def save_model(model: torch.nn.Module, file: FILE_PATH, use_cipher: Optional[boo
if use_sfcs_sdk:
saver = SfcsClientSaver(use_cipher=use_cipher)
else:
saver = PosixSaver()
saver = PosixSaver(use_cipher=use_cipher)
return saver.save_model(model, file)
......@@ -191,7 +190,7 @@ def save_pt(state_dict: Dict[str, torch.Tensor], file: FILE_PATH, use_cipher: Op
Args:
state_dict (Dict): state dict
file (FILE_PATH): file path
use_cipher (bool, optional): encrypt file when use sfcs sdk. Defaults to False.
use_cipher (bool, optional): encrypt file. Defaults to False.
Examples:
```
......@@ -206,6 +205,6 @@ def save_pt(state_dict: Dict[str, torch.Tensor], file: FILE_PATH, use_cipher: Op
if use_sfcs_sdk:
saver = SfcsClientSaver(use_cipher=use_cipher)
else:
saver = PosixSaver()
saver = PosixSaver(use_cipher=use_cipher)
return saver.save_pt(state_dict, file)
......@@ -14,12 +14,15 @@ See the License for the specific language governing permissions and
limitations under the License.
'''
import io
from typing import Any, Dict
import numpy as np
import torch
from numpy import ndarray
from veturboio.ops.cipher import CipherInfo, decrypt
# from veturboio.safetensors import SafetensorsFile
from veturboio.types import FILE_PATH
......@@ -31,7 +34,9 @@ class BaseLoader:
def __init__(self, method: str) -> None:
self.method = method
def load_to_bytes_array(self, file: FILE_PATH, offset: int, count: int) -> ndarray:
def load_to_bytes_array(
self, file: FILE_PATH, offset: int, count: int, cipher_info: CipherInfo = CipherInfo(False)
) -> ndarray:
raise NotImplementedError
def load_safetensors(self, safetensors_file: Any, map_location: str = "cpu") -> Dict[str, torch.Tensor]:
......@@ -63,8 +68,14 @@ class PosixLoader(BaseLoader):
def __init__(self) -> None:
super().__init__(method="posix")
def load_to_bytes_array(self, file: FILE_PATH, offset: int, count: int) -> ndarray:
return np.fromfile(file, dtype=np.byte, offset=offset, count=count)
def load_to_bytes_array(
self, file: FILE_PATH, offset: int, count: int, cipher_info: CipherInfo = CipherInfo(False)
) -> ndarray:
arr = np.fromfile(file, dtype=np.uint8, offset=offset, count=count)
if cipher_info.use_cipher:
h_off = CipherInfo.HEADER_SIZE if cipher_info.use_header else 0
decrypt(cipher_info, arr, arr, offset - h_off)
return arr
def load_safetensors(self, safetensors_file: Any, map_location: str = "cpu") -> Dict[str, torch.Tensor]:
state_dict = {}
......@@ -72,14 +83,20 @@ class PosixLoader(BaseLoader):
base_offset = safetensors_file.tensor_offset
device = torch.device(map_location)
cipher_info = safetensors_file._cipher_info
mp_mode = "c" if cipher_info.use_cipher else "r"
for tensor_meta in safetensors_file.meta.values():
tensor_bytes = np.memmap(
safetensors_file.file,
dtype=np.byte,
mode="r",
dtype=np.uint8,
mode=mp_mode,
offset=base_offset + tensor_meta.data_offsets[0],
shape=tensor_meta.data_offsets[1] - tensor_meta.data_offsets[0],
)
if cipher_info.use_cipher:
h_off = CipherInfo.HEADER_SIZE if cipher_info.use_header else 0
decrypt(cipher_info, tensor_bytes, tensor_bytes, base_offset + tensor_meta.data_offsets[0] - h_off)
tensor = torch.frombuffer(tensor_bytes, dtype=tensor_meta.dtype)
tensor = tensor.view(tensor_meta.shape)
if device.type == "cuda":
......@@ -89,5 +106,13 @@ class PosixLoader(BaseLoader):
return state_dict
def load_pt(self, file: FILE_PATH, map_location: str = "cpu") -> Dict[str, torch.Tensor]:
def load_pt(
self, file: FILE_PATH, map_location: str = "cpu", cipher_info: CipherInfo = CipherInfo(False)
) -> Dict[str, torch.Tensor]:
if cipher_info.use_cipher:
h_off = CipherInfo.HEADER_SIZE if cipher_info.use_header else 0
arr = np.fromfile(file, dtype=np.uint8, offset=h_off, count=-1)
decrypt(cipher_info, arr, arr, 0)
return torch.load(io.BytesIO(arr.data), map_location=map_location)
return torch.load(file, map_location=map_location)
......@@ -14,11 +14,14 @@ See the License for the specific language governing permissions and
limitations under the License.
'''
import io
import os
from typing import Dict
import numpy as np
import torch
from veturboio.ops.cipher import CipherInfo, decrypt
from veturboio.ops.load_utils import IOHelper, load_file_to_tensor
from veturboio.safetensors import SafetensorsFile
from veturboio.types import FILE_PATH
......@@ -46,7 +49,6 @@ class FasterPosixLoader(PosixLoader):
file_size = os.path.getsize(safetensors_file.file)
base_offset = safetensors_file.tensor_offset
device = torch.device(map_location)
if device.type == "cuda":
device_id = device.index if device.index is not None else torch.cuda.current_device()
else:
......@@ -64,9 +66,18 @@ class FasterPosixLoader(PosixLoader):
use_pinmem=self.use_pinmem,
use_sfcs_sdk=False,
use_direct_io=self.use_direct_io,
cipher_info=safetensors_file._cipher_info,
)
return SafetensorsFile.split_tensor_to_state_dict(total_tensor, safetensors_file)
def load_pt(self, file: FILE_PATH, map_location: str = "cpu") -> Dict[str, torch.Tensor]:
def load_pt(
self, file: FILE_PATH, map_location: str = "cpu", cipher_info: CipherInfo = CipherInfo(False)
) -> Dict[str, torch.Tensor]:
if cipher_info.use_cipher:
h_off = CipherInfo.HEADER_SIZE if cipher_info.use_header else 0
arr = np.fromfile(file, dtype=np.uint8, offset=h_off, count=-1)
decrypt(cipher_info, arr, arr, 0)
return torch.load(io.BytesIO(arr.data), map_location=map_location)
return torch.load(file, map_location=map_location)
......@@ -37,7 +37,6 @@ class SfcsClientLoader(BaseLoader):
num_thread: int = 32,
use_pinmem: bool = False,
use_direct_io: bool = False,
use_cipher: bool = False,
) -> None:
super().__init__(method="client")
......@@ -46,15 +45,17 @@ class SfcsClientLoader(BaseLoader):
self.use_pinmem = use_pinmem
self.use_direct_io = use_direct_io
use_cipher = use_cipher or os.environ.get("VETURBOIO_USE_CIPHER", "0") == "1"
self.cipher_info = CipherInfo(use_cipher)
init_sfcs_conf()
def load_to_bytes_array(self, file: FILE_PATH, offset: int, count: int) -> ndarray:
def load_to_bytes_array(
self, file: FILE_PATH, offset: int, count: int, cipher_info: CipherInfo = CipherInfo(False)
) -> ndarray:
file_size = sfcs_get_file_size(file)
if offset + count > file_size:
count = file_size - offset
candidate = np.empty([count], dtype=np.byte)
sfcs_read_file(
file, candidate, length=count, offset=offset, num_thread=self.num_thread, cipher_info=self.cipher_info
file, candidate, length=count, offset=offset, num_thread=self.num_thread, cipher_info=cipher_info
)
return candidate
......@@ -80,12 +81,15 @@ class SfcsClientLoader(BaseLoader):
use_pinmem=self.use_pinmem,
use_sfcs_sdk=True,
use_direct_io=self.use_direct_io,
cipher_info=self.cipher_info,
cipher_info=safetensors_file._cipher_info,
)
return SafetensorsFile.split_tensor_to_state_dict(total_tensor, safetensors_file)
def load_pt(self, file: FILE_PATH, map_location: str = "cpu") -> Dict[str, torch.Tensor]:
def load_pt(
self, file: FILE_PATH, map_location: str = "cpu", cipher_info: CipherInfo = CipherInfo(False)
) -> Dict[str, torch.Tensor]:
file_size = sfcs_get_file_size(file)
file_bytes = self.load_to_bytes_array(file, offset=0, count=file_size).tobytes()
return torch.load(BytesIO(file_bytes), map_location=map_location)
h_off = CipherInfo.HEADER_SIZE if cipher_info.use_header else 0
file_bytes = self.load_to_bytes_array(file, offset=h_off, count=file_size - h_off, cipher_info=cipher_info)
return torch.load(BytesIO(file_bytes.data), map_location=map_location)
......@@ -15,118 +15,412 @@ limitations under the License.
'''
import base64
import hashlib
import hmac
import json
import os
import threading
import urllib.parse
from datetime import datetime, timezone
import secrets
import socket
from datetime import datetime
from enum import Enum
from time import sleep
from typing import Optional, Tuple
import numpy as np
import requests_unixsocket
import requests
from loguru import logger
from requests.adapters import HTTPAdapter
from urllib3.connection import HTTPConnection
from urllib3.connectionpool import HTTPConnectionPool
try:
import veturboio_ext
CtrEncWrap = veturboio_ext.CtrEncWrap
CtrDecWrap = veturboio_ext.CtrDecWrap
except ImportError:
CtrEncWrap = None
CtrDecWrap = None
logger.warning("veturboio_ext not found, fallback to pure python implementation")
class SnapdConnection(HTTPConnection):
def __init__(self, uds_path):
super().__init__("localhost")
self.uds_path = uds_path
def connect(self):
self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
self.sock.connect(self.uds_path)
class SnapdConnectionPool(HTTPConnectionPool):
def __init__(self, uds_path):
super().__init__("localhost")
self.uds_path = uds_path
def _new_conn(self):
return SnapdConnection(self.uds_path)
class SnapdAdapter(HTTPAdapter):
def __init__(self, uds_path):
super().__init__()
self.uds_path = uds_path
def get_connection(self, url, proxies=None):
return SnapdConnectionPool(self.uds_path)
class DataPipeClient:
DATAPIPE_SOCKET_PATH = os.getenv('DATAPIPE_SOCKET_PATH', '/finetune/data/datapipe.sock')
DATAPIPE_SOCKET_PATH = os.getenv('DATAPIPE_SOCKET_PATH', '/finetuned-model/datapipe.sock')
PING_HEADER = {'X-Datapipe-Task-Type': 'ping'}
ENCRYPT_HEADER = {'X-Datapipe-Task-Type': 'encrypt-key'}
SFCS_STS_HEADER = {'X-Datapipe-Task-Type': 'sfcs-sts'}
KMS_STS_HEADER = {'X-Datapipe-Task-Type': 'kms-sts'}
def __init__(self, retry: int = 3, interval: float = 0.5) -> None:
if os.path.exists(self.DATAPIPE_SOCKET_PATH):
self.url = 'http+unix://' + urllib.parse.quote(self.DATAPIPE_SOCKET_PATH, safe='')
self.session = requests_unixsocket.Session()
self.retry = retry
self.interval = interval
else:
self.url = None
self.session = None
def get_data_key_iv(self) -> Tuple[Optional[str], Optional[str]]:
if not self.session:
logger.warning('Datapipe client initialization failed')
return None, None
if not os.path.exists(self.DATAPIPE_SOCKET_PATH):
raise RuntimeError(f'Datapipe socket {self.DATAPIPE_SOCKET_PATH} does not exist')
self.url = 'http://localhost'
self.session = requests.Session()
self.session.mount(self.url, SnapdAdapter(self.DATAPIPE_SOCKET_PATH))
self.retry = retry
self.interval = interval
resp = self._get_retry(self.PING_HEADER)
if resp is None or resp['message'] != 'pong':
raise RuntimeError(f'Ping Datapipe socket {self.DATAPIPE_SOCKET_PATH} failed')
def _get_retry(self, headers: dict) -> Optional[dict]:
re = 0
while True:
try:
response = self.session.get(self.url, headers=self.ENCRYPT_HEADER)
response = self.session.get(self.url, headers=headers)
if response.status_code == 200:
res = response.json()
return res['Key'], res['IV']
return response.json()
except Exception as e:
logger.warning(e)
logger.warning(f'call with {headers} return err: {e}')
if re > self.retry:
break
sleep(self.interval)
re += 1
return None, None
return None
def get_data_key_iv(self) -> Optional[dict]:
return self._get_retry(self.ENCRYPT_HEADER)
def get_sfcs_ak_sk_st(self) -> Optional[dict]:
if not self.session:
logger.warning('Datapipe client initialization failed')
return None
return self._get_retry(self.SFCS_STS_HEADER)
re = 0
while True:
try:
response = self.session.get(self.url, headers=self.SFCS_STS_HEADER)
if response.status_code == 200:
return response.json()
except Exception as e:
logger.warning(e)
def get_kms_ak_sk_st(self) -> Optional[dict]:
return self._get_retry(self.KMS_STS_HEADER)
if re > self.retry:
break
sleep(self.interval)
re += 1
return None
class KmsService:
SERVICE = 'kms'
def __init__(
self,
ak: str,
sk: str,
keyring_name: str,
key_name: str,
region: Optional[str] = None,
host: Optional[str] = None,
st: Optional[str] = None,
uds_proxy: Optional[str] = None,
) -> None:
self._ak = ak
self._sk = sk
self._st = st
self._keyring_name = keyring_name
self._key_name = key_name
self._host = host or 'open.volcengineapi.com'
self._region = region or 'cn-beijing'
self._uds_proxy = uds_proxy
@staticmethod
def sign(key: bytes, msg: str):
return hmac.new(key, msg.encode('utf-8'), hashlib.sha256).digest()
@staticmethod
def getSignatureKey(key: str, dateStamp: str, regionName: str, serviceName: str):
kDate = KmsService.sign(key.encode('utf-8'), dateStamp)
kRegion = KmsService.sign(kDate, regionName)
kService = KmsService.sign(kRegion, serviceName)
kSigning = KmsService.sign(kService, 'request')
return kSigning
@staticmethod
def formatParameters(parameters: dict):
request_parameters_init = ''
for key in sorted(parameters):
request_parameters_init += key + '=' + parameters[key] + '&'
request_parameters = request_parameters_init[:-1]
return request_parameters
@staticmethod
def sigv4(
ak: str,
sk: str,
host: str,
region: str,
srv: str,
method: str,
params: dict,
payload: dict,
st: Optional[str] = None,
uds_proxy: Optional[str] = None,
) -> requests.Response:
now = datetime.utcnow()
current_date = now.strftime('%Y%m%dT%H%M%SZ')
datestamp = now.strftime('%Y%m%d')
cano_uri = '/'
cano_query = KmsService.formatParameters(params)
signed_headers = 'content-type;host;x-content-sha256;x-date'
payload = json.dumps(payload)
payload_hash = hashlib.sha256(payload.encode('utf-8')).hexdigest()
cano_headers = (
f'content-type:application/json\nhost:{host}\nx-content-sha256:{payload_hash}\nx-date:{current_date}\n'
)
cano_request = f'{method}\n{cano_uri}\n{cano_query}\n{cano_headers}\n{signed_headers}\n{payload_hash}'
algorithm = 'HMAC-SHA256'
cred_scope = f'{datestamp}/{region}/{srv}/request'
string_to_sign = (
f'{algorithm}\n{current_date}\n{cred_scope}\n' + hashlib.sha256(cano_request.encode('utf-8')).hexdigest()
)
signing_key = KmsService.getSignatureKey(sk, datestamp, region, srv)
signature = hmac.new(signing_key, (string_to_sign).encode('utf-8'), hashlib.sha256).hexdigest()
authorization_header = (
f'{algorithm} Credential={ak}/{cred_scope}, SignedHeaders={signed_headers}, Signature={signature}'
)
headers = {
'X-Date': current_date,
'Authorization': authorization_header,
'X-Content-Sha256': payload_hash,
'Content-Type': 'application/json',
'X-Amz-Date': '20180614T114308Z',
}
if st:
headers['X-Security-Token'] = st
request_url = f'https://{host}?{cano_query}'
session = requests.Session()
if uds_proxy:
session.mount(f'https://{host}', SnapdAdapter(uds_proxy))
headers['X-Datapipe-Task-Type'] = 'top'
resp = session.post(request_url, data=payload, headers=headers)
return resp
def encrypt(self, pt_b64: str) -> str:
params = {
'Action': 'Encrypt',
'Version': '2021-02-18',
'KeyringName': self._keyring_name,
'KeyName': self._key_name,
}
payload = {'Plaintext': pt_b64}
resp = KmsService.sigv4(
self._ak,
self._sk,
self._host,
self._region,
self.SERVICE,
'POST',
params,
payload,
self._st,
self._uds_proxy,
)
if resp.status_code == 200:
j = resp.json()
if 'Result' in j:
return resp.json()['Result']['CiphertextBlob']
raise RuntimeError(f'kms encrypt failed: {resp.text}')
def decrypt(self, ct_b64: str) -> str:
params = {
'Action': 'Decrypt',
'Version': '2021-02-18',
'KeyringName': self._keyring_name,
'KeyName': self._key_name,
}
payload = {'CiphertextBlob': ct_b64}
resp = KmsService.sigv4(
self._ak,
self._sk,
self._host,
self._region,
self.SERVICE,
'POST',
params,
payload,
self._st,
self._uds_proxy,
)
if resp.status_code == 200:
j = resp.json()
if 'Result' in j:
pt_b64 = resp.json()['Result']['Plaintext']
return pt_b64
raise RuntimeError(f'kms decrypt failed: {resp.text}')
class CipherMode(Enum):
CTR_256 = 'CTR-256'
CTR_128 = 'CTR-128'
class CipherInfo:
ENV_KEY = 'VETUROIO_KEY'
ENV_IV = 'VETUROIO_IV'
ENV_KEY = 'VETURBOIO_KEY'
ENV_IV = 'VETURBOIO_IV'
ENV_KMS_HOST = 'VETURBOIO_KMS_HOST'
ENV_KMS_REGION = 'VETURBOIO_KMS_REGION'
ENV_KMS_AK = 'VETURBOIO_KMS_ACCESS_KEY'
ENV_KMS_SK = 'VETURBOIO_KMS_SECRET_KEY'
ENV_KMS_ST = 'VETURBOIO_KMS_SESSION_TOKEN'
ENV_KMS_KEYRING = 'VETURBOIO_KMS_KEYRING_NAME'
ENV_KMS_KEY = 'VETURBOIO_KMS_KEY_NAME'
HEADER_SIZE = 262144
MAGIC_NUMBER = b'Byte3ncryptM0del'
def __init__(self, use_cipher: bool, header_bytes: Optional[bytes] = None) -> None:
self.use_cipher = use_cipher
self.use_header = False
self.mode = CipherMode.CTR_128
self.key = np.frombuffer(b'\x00' * 16, dtype=np.byte)
self.iv = np.frombuffer(b'\x00' * 16, dtype=np.byte)
if not use_cipher:
return
# case 1: get key and iv from file header part
if (
header_bytes is not None
and len(header_bytes) == self.HEADER_SIZE
and header_bytes[:16] == self.MAGIC_NUMBER
):
# parse header to get key and iv
self.use_header = True
try:
kms_srv = self.fetch_kms_client()
first_zero = header_bytes.index(0)
header_dict = json.loads(header_bytes[16:first_zero])
self.mode = CipherMode(header_dict['mode'])
key_b64 = kms_srv.decrypt(header_dict['meta_data_key'])
iv_b64 = header_dict['iv']
self.key, self.iv = self.convert_key_iv(key_b64, iv_b64)
logger.info('get cipher info from file header successfully!')
return
except Exception as e:
logger.warning(f'get cipher info from file header failed: {e}')
def __init__(self, use_cipher: bool) -> None:
if use_cipher:
# first try to get key and iv from datapipe
# case 2: get key and iv from datapipe uds
try:
client = DataPipeClient()
if client.session:
try:
key_b64, iv_b64 = client.get_data_key_iv()
self.key, self.iv = self.convert_key_iv(key_b64, iv_b64)
self.use_cipher = True
logger.info('get cipher info from datapipe socket')
return
except Exception as e:
logger.warning(e)
# then try to get key and iv from env
env_key = os.getenv(self.ENV_KEY)
env_iv = os.getenv(self.ENV_IV)
if env_key and env_iv:
try:
self.key, self.iv = self.convert_key_iv(env_key, env_iv)
self.use_cipher = True
logger.info('get cipher info from env')
return
except Exception as e:
logger.warning(e)
logger.warning('fail to get key and iv, fallback to no cipher')
resp = client.get_data_key_iv()
self.key, self.iv = self.convert_key_iv(resp['Key'], resp['IV'])
logger.info('get cipher info from datapipe uds successfully!')
return
except Exception as e:
logger.warning(f'get cipher info from datapipe uds failed: {e}')
# case 3: get key and iv from env
try:
for e in [self.ENV_KEY, self.ENV_IV]:
assert e in os.environ, f'env {e} not set'
self.key, self.iv = self.convert_key_iv(os.getenv(self.ENV_KEY), os.getenv(self.ENV_IV))
logger.info('get cipher info from env')
return
except Exception as e:
logger.warning(f'get cipher info from env failed :{e}')
# fallback to no cipher
self.use_cipher = False
self.key = np.frombuffer(b'\x00' * 16, dtype=np.byte)
self.iv = np.frombuffer(b'\x00' * 16, dtype=np.byte)
logger.warning('fail to get key and iv, fallback to no cipher')
@staticmethod
def convert_key_iv(key_b64: str, iv_b64: str) -> Tuple[np.ndarray, np.ndarray]:
key_b = base64.b64decode(key_b64, validate=True)
iv_b = base64.b64decode(iv_b64, validate=True)
if len(key_b) != 16 or len(iv_b) != 16:
raise Exception('length of key or iv is not 16')
if (len(key_b) != 16 and len(key_b) != 32) or len(iv_b) != 16:
raise Exception(f'length of key {len(key_b)} or iv {len(iv_b)} is not valid')
key = np.frombuffer(key_b, dtype=np.byte)
iv = np.frombuffer(iv_b, dtype=np.byte)
return key, iv
def fetch_kms_client(self) -> KmsService:
kms_host = os.getenv(self.ENV_KMS_HOST)
region = os.getenv(self.ENV_KMS_REGION)
ak = os.getenv(self.ENV_KMS_AK)
sk = os.getenv(self.ENV_KMS_SK)
st = os.getenv(self.ENV_KMS_ST)
keyring_name = os.getenv(self.ENV_KMS_KEYRING)
key_name = os.getenv(self.ENV_KMS_KEY)
uds_proxy = None
# try to fetch kms credential from datapipe
if os.path.exists(DataPipeClient.DATAPIPE_SOCKET_PATH):
try:
client = DataPipeClient()
uds_proxy = DataPipeClient.DATAPIPE_SOCKET_PATH
resp = client.get_kms_ak_sk_st()
ak = resp['Cred']['AccessKeyId']
sk = resp['Cred']['SecretAccessKey']
st = resp['Cred']['SessionToken']
logger.info('get kms credential from datapipe successfully!')
except Exception as e:
logger.warning(f'get kms ak/sk/st from datapipe failed: {e}')
for var in [ak, sk, keyring_name, key_name]:
assert var is not None, 'required kms info not set'
return KmsService(ak, sk, keyring_name, key_name, region, kms_host, st, uds_proxy)
def to_header_bytes(self) -> bytearray:
kms_srv = self.fetch_kms_client()
header_dict = {
'mode': self.mode.value,
'iv': base64.b64encode(self.iv.data).decode('utf-8'),
'meta_data_key': kms_srv.encrypt(base64.b64encode(self.key).decode('utf-8')),
'file_timestamp': int(datetime.utcnow().timestamp()),
}
header_json = json.dumps(header_dict)
header_bytes = bytearray(self.HEADER_SIZE)
header_bytes[:16] = self.MAGIC_NUMBER
header_bytes[16 : 16 + len(header_json)] = header_json.encode('utf-8')
return header_bytes
def create_cipher_with_header(mode: CipherMode) -> CipherInfo:
c = CipherInfo(False)
c.use_cipher = True
c.use_header = True
c.mode = mode
if c.mode == CipherMode.CTR_256:
key_bytes = secrets.token_bytes(32)
else:
key_bytes = secrets.token_bytes(16)
iv_bytes = secrets.token_bytes(16)
c.key = np.frombuffer(key_bytes, dtype=np.byte)
c.iv = np.frombuffer(iv_bytes, dtype=np.byte)
return c
def encrypt(cipher_info: CipherInfo, pt: np.ndarray, ct: np.ndarray, offset: int):
if not cipher_info.use_cipher:
logger.warning('cipher.encrypt: use_cipher False, skip')
return
enc = CtrEncWrap(cipher_info.mode.value, cipher_info.key, cipher_info.iv, offset)
ret = enc.encrypt_update(pt, ct)
if not ret:
logger.error('cipher.encrypt: failed')
def decrypt(cipher_info: CipherInfo, ct: np.ndarray, pt: np.ndarray, offset: int):
if not cipher_info.use_cipher:
logger.warning('cipher.decrypt: use_cipher False, skip')
return
dec = CtrDecWrap(cipher_info.mode.value, cipher_info.key, cipher_info.iv, offset)
ret = dec.decrypt_update(ct, pt)
if not ret:
logger.error('cipher.decrypt: failed')
/*
* Copyright (c) 2024 Beijing Volcano Engine Technology Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include "include/cipher.h"
#include <iostream>
CipherInfo::CipherInfo(bool use_cipher, pybind11::array_t<char> key_arr, pybind11::array_t<char> iv_arr,
size_t header_size)
: use_cipher(use_cipher), header_size(header_size)
{
if (use_cipher)
{
pybind11::buffer_info key_info = key_arr.request();
size_t key_size = key_info.size;
if (key_size == 16)
{
mode = "CTR-128";
}
else if (key_size == 32)
{
mode = "CTR-256";
}
else
{
throw std::runtime_error("Cipher Exception: key length invalid");
}
key = reinterpret_cast<unsigned char *>(key_info.ptr);
pybind11::buffer_info iv_info = iv_arr.request();
if ((size_t)iv_info.size != AES_BLOCK_SIZE)
{
throw std::runtime_error("Cipher Exception: iv length invalid");
}
iv = reinterpret_cast<unsigned char *>(iv_info.ptr);
}
}
CtrEncWrap::CtrEncWrap(std::string mode, pybind11::array_t<unsigned char> key_arr,
pybind11::array_t<unsigned char> iv_arr, size_t global_offset)
{
pybind11::buffer_info key_info = key_arr.request();
pybind11::buffer_info iv_info = iv_arr.request();
enc_.reset(new CtrEncrypter(mode, (unsigned char *)key_info.ptr, (unsigned char *)iv_info.ptr, global_offset));
}
size_t CtrEncWrap::encrypt_update(pybind11::array_t<unsigned char> pt, pybind11::array_t<unsigned char> ct)
{
pybind11::buffer_info pt_info = pt.request();
pybind11::buffer_info ct_info = ct.request();
unsigned char *pt_ptr = (unsigned char *)pt_info.ptr;
unsigned char *ct_ptr = (unsigned char *)ct_info.ptr;
return enc_->encrypt_update(pt_ptr, pt_info.size, ct_ptr);
}
CtrDecWrap::CtrDecWrap(std::string mode, pybind11::array_t<unsigned char> key_arr,
pybind11::array_t<unsigned char> iv_arr, size_t global_offset)
{
pybind11::buffer_info key_info = key_arr.request();
pybind11::buffer_info iv_info = iv_arr.request();
dec_.reset(new CtrDecrypter(mode, (unsigned char *)key_info.ptr, (unsigned char *)iv_info.ptr, global_offset));
}
size_t CtrDecWrap::decrypt_update(pybind11::array_t<unsigned char> ct, pybind11::array_t<unsigned char> pt)
{
pybind11::buffer_info pt_info = pt.request();
pybind11::buffer_info ct_info = ct.request();
unsigned char *pt_ptr = (unsigned char *)pt_info.ptr;
unsigned char *ct_ptr = (unsigned char *)ct_info.ptr;
return dec_->decrypt_update(ct_ptr, ct_info.size, pt_ptr);
}
/*
* Copyright (c) 2024 Beijing Volcano Engine Technology Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef VETURBOIO_CIPHER_H
#define VETURBOIO_CIPHER_H
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <string>
#include <memory>
#include "fastcrypto.h"
class CipherInfo
{
public:
bool use_cipher = false;
std::string mode = "CTR-128";
size_t header_size = 0;
unsigned char *key = NULL;
unsigned char *iv = NULL;
CipherInfo(bool use_cipher, pybind11::array_t<char> key_arr, pybind11::array_t<char> iv_arr, size_t header_size);
CipherInfo() = default;
};
class CtrEncWrap
{
private:
std::unique_ptr<CtrEncrypter> enc_;
public:
CtrEncWrap() = default;
CtrEncWrap(std::string mode, pybind11::array_t<unsigned char> key_arr, pybind11::array_t<unsigned char> iv_arr,
size_t global_offset);
size_t encrypt_update(pybind11::array_t<unsigned char> pt, pybind11::array_t<unsigned char> ct);
};
class CtrDecWrap
{
private:
std::unique_ptr<CtrDecrypter> dec_;
public:
CtrDecWrap() = default;
CtrDecWrap(std::string mode, pybind11::array_t<unsigned char> key_arr, pybind11::array_t<unsigned char> iv_arr,
size_t global_offset);
size_t decrypt_update(pybind11::array_t<unsigned char> ct, pybind11::array_t<unsigned char> pt);
};
#endif
\ No newline at end of file
......@@ -13,18 +13,33 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef AES_CPU_CTR_H
#define AES_CPU_CTR_H
#ifndef VETURBOIO_FASTCRYPTO_H
#define VETURBOIO_FASTCRYPTO_H
#include <stdio.h>
#include <string>
extern const size_t EVP_UPDATE_MAX;
extern const size_t CTR_BLOCK_SIZE;
#define EVP_UPDATE_MAX 0x7ffffff0
#define AES_BLOCK_SIZE 16
#define AES_BUF_MAX_SIZE 32
#define MAX_CTR_KEY_SIZE 32
#define FASTCRYPTO_MAGIC_SIZE 16
void ctr128_inc_by(unsigned char *counter, size_t n, size_t c);
inline void counter_inc_by(unsigned char *counter, size_t n, size_t c)
{
do
{
--n;
c += counter[n];
counter[n] = static_cast<unsigned char>(c);
c >>= 8;
} while (n);
}
typedef struct evp_cipher_ctx_st EVP_CIPHER_CTX;
typedef struct evp_cipher_st EVP_CIPHER;
typedef struct evp_mac_ctx_st EVP_MAC_CTX;
typedef struct evp_mac_st EVP_MAC;
class CtrEncrypter
{
......@@ -33,7 +48,8 @@ class CtrEncrypter
EVP_CIPHER *cipher = NULL;
public:
CtrEncrypter(const unsigned char *key, const unsigned char *iv, size_t global_offset);
CtrEncrypter() = default;
CtrEncrypter(std::string algo, const unsigned char *key, const unsigned char *iv, size_t global_offset);
~CtrEncrypter();
int encrypt_update(unsigned char *pt, size_t pt_size, unsigned char *ct);
};
......@@ -45,23 +61,16 @@ class CtrDecrypter
EVP_CIPHER *cipher = NULL;
public:
CtrDecrypter(const unsigned char *key, const unsigned char *iv, size_t global_offset);
CtrDecrypter() = default;
CtrDecrypter(std::string algo, const unsigned char *key, const unsigned char *iv, size_t global_offset);
~CtrDecrypter();
int decrypt_update(unsigned char *ct, size_t ct_size, unsigned char *pt);
};
#endif
#ifndef AES_GPU_CTR_H
#define AES_GPU_CTR_H
#include <stdio.h>
// Both encrypt and decrypt require length of ct and pt multiple of 16
int ctr_encrypt_gpu(std::string algo, const unsigned char *key, const unsigned char *iv, unsigned char *pt,
size_t pt_size, unsigned char *ct);
int ctr_encrypt_gpu(const unsigned char *key, const unsigned char *iv, unsigned char *pt, size_t pt_size,
unsigned char *ct);
int ctr_decrypt_gpu(const unsigned char *key, const unsigned char *iv, unsigned char *ct, size_t ct_size,
unsigned char *pt);
int ctr_decrypt_gpu(std::string algo, const unsigned char *key, const unsigned char *iv, unsigned char *ct,
size_t ct_size, unsigned char *pt);
#endif
\ No newline at end of file
......@@ -30,7 +30,7 @@ class IOHelper
void load_file_to_tensor(std::string file_path, torch::Tensor res_tensor, torch::Tensor sample_tensor,
int64_t offset, int64_t device_id, int64_t num_thread, bool use_pinmem, bool use_sfcs_sdk,
bool use_direct_io, bool use_cipher, pybind11::array_t<char> key_arr,
pybind11::array_t<char> iv_arr);
pybind11::array_t<char> iv_arr, int64_t header_size);
void init_buffer(string file_path, int64_t file_size, bool use_pinmem, bool use_sfcs_sdk);
void free_buffer();
};
......
......@@ -17,6 +17,7 @@
#define LOAD_UTILS_H
#include "common.h"
#include "cipher.h"
void read_file(string file_path, char *addr, char *dev_mem, int num_thread, size_t total_size, size_t global_offset,
bool use_sfcs_sdk, bool use_direct_io, CipherInfo cipher_info);
......
......@@ -22,22 +22,13 @@
#include "common.h"
#include "cfs.h"
#include "logging.h"
#include "cipher.h"
#define SFCS_NAME_NODE "default"
#define SFCS_USER_NAME "demo-user"
using namespace std;
class CipherInfo
{
public:
bool use_cipher = false;
unsigned char *key = NULL;
unsigned char *iv = NULL;
CipherInfo(bool use_cipher, pybind11::array_t<char> key_arr, pybind11::array_t<char> iv_arr);
CipherInfo(){};
};
class SFCSFile
{
public:
......@@ -47,7 +38,8 @@ class SFCSFile
CipherInfo cipher_info;
SFCSFile(std::string file_path);
SFCSFile(std::string file_path, bool use_cipher, pybind11::array_t<char> key_arr, pybind11::array_t<char> iv_arr);
SFCSFile(std::string file_path, bool use_cipher, pybind11::array_t<char> key_arr, pybind11::array_t<char> iv_arr,
size_t header_size);
SFCSFile(std::string file_path, CipherInfo cipher_info);
~SFCSFile();
size_t get_file_size();
......
......@@ -13,8 +13,9 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "include/fastcrypto.h"
#include "include/io_helper.h"
#include "include/cipher.h"
#include "include/fastcrypto.h"
IOHelper::~IOHelper()
{
......@@ -119,13 +120,13 @@ void read_unaligned_part(std::string file_path, torch::Tensor res_tensor, int64_
void IOHelper::load_file_to_tensor(std::string file_path, torch::Tensor res_tensor, torch::Tensor sample_tensor,
int64_t offset, int64_t device_id, int64_t num_thread, bool use_pinmem,
bool use_sfcs_sdk, bool use_direct_io, bool use_cipher,
pybind11::array_t<char> key_arr, pybind11::array_t<char> iv_arr)
pybind11::array_t<char> key_arr, pybind11::array_t<char> iv_arr, int64_t header_size)
{
size_t file_size = get_file_size(file_path.c_str(), use_sfcs_sdk);
size_t total_size = file_size - offset;
size_t read_unaligned_size = 0;
// set cipher
CipherInfo cipher_info(use_cipher, key_arr, iv_arr);
CipherInfo cipher_info(use_cipher, key_arr, iv_arr, header_size);
if (device_id < 0)
{
read_unaligned_part(file_path, res_tensor, &offset, device_id, &total_size, use_sfcs_sdk, use_direct_io,
......@@ -155,23 +156,24 @@ void IOHelper::load_file_to_tensor(std::string file_path, torch::Tensor res_tens
read_file(file_path, pin_mem, (char *)res_tensor.data_ptr() + read_unaligned_size, num_thread, total_size,
offset, use_sfcs_sdk, use_direct_io, CipherInfo());
cudaDeviceSynchronize();
// decrypt with gpu
if (cipher_info.use_cipher && total_size > 0)
{
if (offset % CTR_BLOCK_SIZE != 0 || total_size % CTR_BLOCK_SIZE != 0)
if (offset % AES_BLOCK_SIZE != 0 || total_size % AES_BLOCK_SIZE != 0)
{
throw std::runtime_error("cannot decrypt because gpu read is not aligned");
}
unsigned char iv[CTR_BLOCK_SIZE];
for (size_t i = 0; i < CTR_BLOCK_SIZE; i++)
unsigned char iv[AES_BLOCK_SIZE];
for (size_t i = 0; i < AES_BLOCK_SIZE; i++)
{
iv[i] = cipher_info.iv[i];
}
ctr128_inc_by(iv, CTR_BLOCK_SIZE, offset / CTR_BLOCK_SIZE);
counter_inc_by(iv, AES_BLOCK_SIZE, (offset - cipher_info.header_size) / AES_BLOCK_SIZE);
unsigned char *iv_gpu;
cudaMalloc((void **)&iv_gpu, CTR_BLOCK_SIZE);
cudaMemcpy(iv_gpu, iv, CTR_BLOCK_SIZE, cudaMemcpyHostToDevice);
cudaMalloc((void **)&iv_gpu, AES_BLOCK_SIZE);
cudaMemcpy(iv_gpu, iv, AES_BLOCK_SIZE, cudaMemcpyHostToDevice);
unsigned char *ct = reinterpret_cast<unsigned char *>(res_tensor.data_ptr()) + read_unaligned_size;
int cipher_ret = ctr_decrypt_gpu(cipher_info.key, iv_gpu, ct, total_size, ct);
int cipher_ret = ctr_decrypt_gpu(cipher_info.mode, cipher_info.key, iv_gpu, ct, total_size, ct);
if (!cipher_ret)
{
throw std::runtime_error("Cipher Exception: gpu decrypt fail");
......
libfastcrypto.so.0.1
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment