"...text-generation-inference.git" did not exist on "0759ec495e15a865d2a59befc2b796b5acc09b50"
Commit 8dc44dfb authored by 王孝诚, committed by huteng.ht

feat: add clone mode for shared tensor

* feat: add clone mode for shared tensor
* feat: add parameter illustration
parent 7feb2732
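
For context (not part of the commit): a "shared tensor" here is a state_dict entry that aliases another entry's storage, as happens with weight tying. safetensors-style savers cannot serialize such aliases directly (the ValueError path in save_file below exists for exactly this case), so veturboio's force_save_shared_tensor path deduplicates them, and this commit adds a clone mode to that path. A minimal plain-PyTorch sketch of the situation being handled:

import torch
import torch.nn as nn

emb = nn.Embedding(10, 4)
head = nn.Linear(4, 10, bias=False)
head.weight = emb.weight  # weight tying: both modules now share one Parameter

state_dict = {"emb.weight": emb.weight, "head.weight": head.weight}
# Two keys, one storage: this is the "shared tensor" case.
assert state_dict["emb.weight"].data_ptr() == state_dict["head.weight"].data_ptr()
# safetensors.torch.save_file(state_dict, "tied.safetensors")  # would raise: tensors share memory
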
@@ -55,9 +55,19 @@ class TestSave(TestCase):
     def test_save_file(self):
         veturboio.save_file(self.tensors_0, self.filepath_0)
         with safe_open(self.filepath_0, framework="pt", device="cpu") as f:
+            assert len(f.keys()) == 2
             for key in f.keys():
                 self.assertTrue(torch.allclose(self.tensors_0[key], f.get_tensor(key)))
 
+    def test_save_file_for_clone_share_tensors(self):
+        share_dict = {"key1": self.tensors_0["weight1"], "key2": self.tensors_0["weight1"]}
+        veturboio.save_file(share_dict, self.filepath_0, force_save_shared_tensor=True, force_clone_shared_tensor=True)
+        assert len(share_dict) == 2  # assert save_file won't change the user's state_dict
+        with safe_open(self.filepath_0, framework="pt", device="cpu") as f:
+            for key in f.keys():
+                assert key in share_dict
+                self.assertTrue(torch.allclose(share_dict[key], f.get_tensor(key)))
+
     def test_save_model(self):
         veturboio.save_model(self.model, self.filepath_3, use_cipher=True)
         loaded_tensors = veturboio.load(self.filepath_3, map_location="cpu", use_cipher=True)
...
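
A hypothetical round-trip sketch of the same scenario outside the test class (not part of the test file); the path and tensor values are illustrative, and it assumes veturboio.load returns a dict of tensors, as the existing tests suggest, and that clone mode writes both aliased keys into the file:

import torch
import veturboio

t = torch.randn(4, 4)
share_dict = {"key1": t, "key2": t}  # two keys alias the same storage
assert share_dict["key1"].data_ptr() == share_dict["key2"].data_ptr()

veturboio.save_file(
    share_dict,
    "/tmp/shared.safetensors",
    force_save_shared_tensor=True,
    force_clone_shared_tensor=True,
)

# The caller's dict still has both keys, and every saved entry round-trips.
assert len(share_dict) == 2
loaded = veturboio.load("/tmp/shared.safetensors", map_location="cpu")
for key in loaded:
    assert torch.allclose(share_dict[key], loaded[key])
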
@@ -18,6 +18,7 @@ import os
 from typing import Dict, Optional
 
 import torch
+from loguru import logger
 from safetensors.torch import _remove_duplicate_names
 from safetensors.torch import save_file as safetenors_save_file
 from safetensors.torch import save_model as safetensors_save_model
@@ -104,6 +105,7 @@ def save_file(
     file: FILE_PATH,
     force_contiguous: bool = True,
     force_save_shared_tensor: bool = False,
+    force_clone_shared_tensor: bool = False,
     metadata: Dict[str, str] = None,
     use_cipher: Optional[bool] = False,
 ) -> None:
@@ -114,6 +116,8 @@ def save_file(
         file (FILE_PATH): file path
         force_contiguous (bool, optional): force contiguous. Defaults to True.
         force_save_shared_tensor (bool, optional): force save shared tensor. Defaults to False.
+        force_clone_shared_tensor (bool, optional): clone shared tensors instead of deleting the duplicate
+            keys when force_save_shared_tensor is enabled. Defaults to False.
         metadata (Dict[str, str], optional): metadata. Defaults to None.
         use_cipher (bool, optional): decrypt file. Defaults to False.
@@ -134,6 +138,8 @@ def save_file(
 
     # TODO: there are some bugs while state_dict is loaded from veturboio
     if not force_save_shared_tensor:
+        if force_clone_shared_tensor:
+            logger.warning("force_clone_shared_tensor won't take effect while force_save_shared_tensor is False")
         try:
             saver.save_file(state_dict, metadata=metadata)
         except ValueError as e:
@@ -152,7 +158,10 @@ def save_file(
                 if to_remove not in metadata:
                     # Do not override user data
                     metadata[to_remove] = kept_name
-                del state_dict[to_remove]
+                if force_clone_shared_tensor:
+                    state_dict[to_remove] = state_dict[to_remove].clone()
+                else:
+                    del state_dict[to_remove]
 
         if force_contiguous:
             state_dict = {k: v.contiguous() for k, v in state_dict.items()}
...
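
The behavioral difference the new branch introduces can be shown in isolation. This is a hedged sketch, not the library's code: dedup_shared and to_remove_names are made-up stand-ins for what save_file derives via _remove_duplicate_names.

import torch

def dedup_shared(state_dict, to_remove_names, force_clone_shared_tensor=False):
    # Stand-in for the branch added in this commit: either break the aliasing
    # by cloning, or drop the duplicate key (the pre-existing behavior).
    for to_remove in to_remove_names:
        if force_clone_shared_tensor:
            state_dict[to_remove] = state_dict[to_remove].clone()
        else:
            del state_dict[to_remove]
    return state_dict

w = torch.randn(2, 2)

dropped = dedup_shared({"a": w, "b": w}, ["b"])
assert set(dropped) == {"a"}                              # duplicate key removed

cloned = dedup_shared({"a": w, "b": w}, ["b"], force_clone_shared_tensor=True)
assert set(cloned) == {"a", "b"}                          # both keys kept
assert cloned["a"].data_ptr() != cloned["b"].data_ptr()   # but no longer aliased

Cloning trades extra memory for keeping every key present in the caller's state_dict, which is what the new test asserts.
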