"git@developer.sourcefind.cn:change/sglang.git" did not exist on "a42736bbb8feeb047c240c062a0ba9de03c2ab40"
Unverified commit e84f4ba0, authored by Brayden Zhong, committed by GitHub
Browse files

[Misc] Fix issues reported by torchfix (#4837)

parent b149b393
...@@ -92,7 +92,7 @@ def convert_bin_to_safetensor_file( ...@@ -92,7 +92,7 @@ def convert_bin_to_safetensor_file(
pt_filename: str, pt_filename: str,
sf_filename: str, sf_filename: str,
) -> None: ) -> None:
loaded = torch.load(pt_filename, map_location="cpu") loaded = torch.load(pt_filename, map_location="cpu", weights_only=True)
if "state_dict" in loaded: if "state_dict" in loaded:
loaded = loaded["state_dict"] loaded = loaded["state_dict"]
shared = _shared_pointers(loaded) shared = _shared_pointers(loaded)
...@@ -380,7 +380,7 @@ def np_cache_weights_iterator( ...@@ -380,7 +380,7 @@ def np_cache_weights_iterator(
disable=not enable_tqdm, disable=not enable_tqdm,
bar_format=_BAR_FORMAT, bar_format=_BAR_FORMAT,
): ):
state = torch.load(bin_file, map_location="cpu") state = torch.load(bin_file, map_location="cpu", weights_only=True)
for name, param in state.items(): for name, param in state.items():
param_path = os.path.join(np_folder, name) param_path = os.path.join(np_folder, name)
with open(param_path, "wb") as f: with open(param_path, "wb") as f:
......
...@@ -252,7 +252,7 @@ def resample_patch_embed( ...@@ -252,7 +252,7 @@ def resample_patch_embed(
try: try:
from torch import vmap from torch import vmap
except ImportError: except ImportError:
from functorch import vmap from torch.func import vmap
assert len(patch_embed.shape) == 4, "Four dimensions expected" assert len(patch_embed.shape) == 4, "Four dimensions expected"
assert len(new_size) == 2, "New shape should only be hw" assert len(new_size) == 2, "New shape should only be hw"
...@@ -1084,7 +1084,7 @@ def create_siglip_vit( ...@@ -1084,7 +1084,7 @@ def create_siglip_vit(
) )
if ckpt_path: if ckpt_path:
state_dict = torch.load(ckpt_path, map_location="cpu") state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
incompatible_keys = model.load_state_dict(state_dict, strict=False) incompatible_keys = model.load_state_dict(state_dict, strict=False)
print( print(
......
...@@ -586,5 +586,5 @@ def load_token_map(token_map_path: str) -> List[int]: ...@@ -586,5 +586,5 @@ def load_token_map(token_map_path: str) -> List[int]:
ignore_patterns=["*.bin", "*.safetensors"], ignore_patterns=["*.bin", "*.safetensors"],
) )
token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path)) token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
hot_token_id = torch.load(token_map_path) hot_token_id = torch.load(token_map_path, weights_only=True)
return torch.tensor(hot_token_id, dtype=torch.int32) return torch.tensor(hot_token_id, dtype=torch.int32)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment