Unverified Commit ffb6777a authored by Marc Sun's avatar Marc Sun Committed by GitHub
Browse files

remove format check for safetensors file (#10864)

remove check
parent 85fcbaf3
...@@ -134,19 +134,6 @@ def _fetch_remapped_cls_from_config(config, old_class): ...@@ -134,19 +134,6 @@ def _fetch_remapped_cls_from_config(config, old_class):
return old_class return old_class
def _check_archive_and_maybe_raise_error(checkpoint_file, format_list):
"""
Check format of the archive
"""
with safetensors.safe_open(checkpoint_file, framework="pt") as f:
metadata = f.metadata()
if metadata is not None and metadata.get("format") not in format_list:
raise OSError(
f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
"you save your model with the `save_pretrained` method."
)
def _determine_param_device(param_name: str, device_map: Optional[Dict[str, Union[int, str, torch.device]]]): def _determine_param_device(param_name: str, device_map: Optional[Dict[str, Union[int, str, torch.device]]]):
""" """
Find the device of param_name from the device_map. Find the device of param_name from the device_map.
...@@ -183,7 +170,6 @@ def load_state_dict( ...@@ -183,7 +170,6 @@ def load_state_dict(
# tensors are loaded on cpu # tensors are loaded on cpu
with dduf_entries[checkpoint_file].as_mmap() as mm: with dduf_entries[checkpoint_file].as_mmap() as mm:
return safetensors.torch.load(mm) return safetensors.torch.load(mm)
_check_archive_and_maybe_raise_error(checkpoint_file, format_list=["pt", "flax"])
if disable_mmap: if disable_mmap:
return safetensors.torch.load(open(checkpoint_file, "rb").read()) return safetensors.torch.load(open(checkpoint_file, "rb").read())
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment