Unverified Commit a536e775 authored by vincedovy's avatar vincedovy Committed by GitHub
Browse files

Fix json WindowsPath crash (#8662)



* Add check for WindowsPath in to_json_string

On Windows, os.path.join returns a WindowsPath. to_json_string does not convert this from a WindowsPath to a string. Added check for WindowsPath to to_json_saveable.

* Remove extraneous convert to string in test_check_path_types (tests/others/test_config.py)

* Fix style issues in tests/others/test_config.py

* Add unit test to test_config.py to verify that PosixPath and WindowsPath (depending on system) both work when converted to JSON

* Remove distinction between PosixPath and WindowsPath in ConfigMixIn.to_json_string(). Conditional now tests for Path, and uses Path.as_posix() to convert to string.

---------
Co-authored-by: default avatarVincent Dovydaitis <vincedovy@gmail.com>
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 3b01d72a
......@@ -23,7 +23,7 @@ import json
import os
import re
from collections import OrderedDict
from pathlib import PosixPath
from pathlib import Path
from typing import Any, Dict, Tuple, Union
import numpy as np
......@@ -587,8 +587,8 @@ class ConfigMixin:
def to_json_saveable(value):
if isinstance(value, np.ndarray):
value = value.tolist()
elif isinstance(value, PosixPath):
value = str(value)
elif isinstance(value, Path):
value = value.as_posix()
return value
config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
......
......@@ -13,8 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import tempfile
import unittest
from pathlib import Path
from diffusers import (
DDIMScheduler,
......@@ -91,6 +93,14 @@ class SampleObject4(ConfigMixin):
pass
class SampleObjectPaths(ConfigMixin):
config_name = "config.json"
@register_to_config
def __init__(self, test_file_1=Path("foo/bar"), test_file_2=Path("foo bar\\bar")):
pass
class ConfigTester(unittest.TestCase):
def test_load_not_from_mixin(self):
with self.assertRaises(ValueError):
......@@ -286,3 +296,11 @@ class ConfigTester(unittest.TestCase):
# Nevertheless "e" should still be correctly loaded to [1, 3] from SampleObject2 instead of defaulting to [1, 5]
assert new_config_2.config.e == [1, 3]
def test_check_path_types(self):
# Verify that we get a string returned from a WindowsPath or PosixPath (depending on system)
config = SampleObjectPaths()
json_string = config.to_json_string()
result = json.loads(json_string)
assert result["test_file_1"] == config.config.test_file_1.as_posix()
assert result["test_file_2"] == config.config.test_file_2.as_posix()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment