Unverified Commit d384265d authored by Philipp Hasper's avatar Philipp Hasper Committed by GitHub
Browse files

Fixed is_safetensors_compatible() handling of windows path separators (#5650)

Closes #4665
parent 11c12566
...@@ -158,9 +158,9 @@ def is_safetensors_compatible(filenames, variant=None, passed_components=None) - ...@@ -158,9 +158,9 @@ def is_safetensors_compatible(filenames, variant=None, passed_components=None) -
continue continue
if extension == ".bin": if extension == ".bin":
pt_filenames.append(filename) pt_filenames.append(os.path.normpath(filename))
elif extension == ".safetensors": elif extension == ".safetensors":
sf_filenames.add(filename) sf_filenames.add(os.path.normpath(filename))
for filename in pt_filenames: for filename in pt_filenames:
# filename = 'foo/bar/baz.bam' -> path = 'foo/bar', filename = 'baz', extention = '.bam' # filename = 'foo/bar/baz.bam' -> path = 'foo/bar', filename = 'baz', extention = '.bam'
...@@ -172,9 +172,8 @@ def is_safetensors_compatible(filenames, variant=None, passed_components=None) - ...@@ -172,9 +172,8 @@ def is_safetensors_compatible(filenames, variant=None, passed_components=None) -
else: else:
filename = filename filename = filename
expected_sf_filename = os.path.join(path, filename) expected_sf_filename = os.path.normpath(os.path.join(path, filename))
expected_sf_filename = f"{expected_sf_filename}.safetensors" expected_sf_filename = f"{expected_sf_filename}.safetensors"
if expected_sf_filename not in sf_filenames: if expected_sf_filename not in sf_filenames:
logger.warning(f"{expected_sf_filename} not found") logger.warning(f"{expected_sf_filename} not found")
return False return False
...@@ -1774,7 +1773,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -1774,7 +1773,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
) )
): ):
raise EnvironmentError( raise EnvironmentError(
f"Could not found the necessary `safetensors` weights in {model_filenames} (variant={variant})" f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})"
) )
if from_flax: if from_flax:
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"] ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment