Unverified Commit a7efbb27 authored by Simon_CQK's avatar Simon_CQK Committed by GitHub
Browse files

fix(model loader): use safe_open to prevent file handle leaks. (#7684)

parent 93b6785d
......@@ -460,10 +460,12 @@ def safetensors_weights_iterator(
if disable_mmap:
with open(st_file, "rb") as f:
result = safetensors.torch.load(f.read())
for name, param in result.items():
yield name, param
else:
result = safetensors.torch.load_file(st_file, device="cpu")
for name, param in result.items():
yield name, param
with safetensors.safe_open(st_file, framework="pt", device="cpu") as f:
for name in f.keys():
yield name, f.get_tensor(name)
def multi_thread_safetensors_weights_iterator(
......@@ -496,7 +498,8 @@ def multi_thread_safetensors_weights_iterator(
with open(st_file, "rb") as f:
result = safetensors.torch.load(f.read())
else:
result = safetensors.torch.load_file(st_file, device="cpu")
with safetensors.safe_open(st_file, framework="pt", device="cpu") as f:
result = {k: f.get_tensor(k) for k in f.keys()}
return result
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment