Unverified Commit e946260c authored by twaka's avatar twaka Committed by GitHub
Browse files

use get_tensor in safe_open (#1696)

parent edb30558
......@@ -243,8 +243,8 @@ def hf_model_weights_iterator(
for st_file in hf_weights_files:
with safe_open(st_file, framework="pt") as f:
for name in f.keys():
param = f.get_slice(name)
yield name, convert_pyslice_to_tensor(param)
param = f.get_tensor(name)
yield name, param
else:
for bin_file in hf_weights_files:
state = torch.load(bin_file, map_location="cpu")
......@@ -265,12 +265,7 @@ def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
tensor first.
"""
if not isinstance(x, torch.Tensor):
try:
x = x[:]
except IndexError:
# IndexError happens when the tensor is empty.
# transformer.h.0.attn.masked_bias is empty in some gpt2 models.
return torch.Tensor()
x = x[:]
return x
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment