Unverified Commit 7e00247f authored by feifang24's avatar feifang24 Committed by GitHub
Browse files

check for key 'torch.dtype' in nested dicts in config (#16065)

parent 5d2fed2e
...@@ -849,12 +849,15 @@ class PretrainedConfig(PushToHubMixin): ...@@ -849,12 +849,15 @@ class PretrainedConfig(PushToHubMixin):
def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None: def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None:
""" """
Checks whether the passed dictionary has a *torch_dtype* key and if it's not None, converts torch.dtype to a Checks whether the passed dictionary and its nested dicts have a *torch_dtype* key and if it's not None,
string of just the type. For example, `torch.float32` get converted into *"float32"* string, which can then be converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"*
stored in the json format. string, which can then be stored in the json format.
""" """
if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str): if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str):
d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1] d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
for value in d.values():
if isinstance(value, dict):
self.dict_torch_dtype_to_str(value)
@classmethod @classmethod
def register_for_auto_class(cls, auto_class="AutoConfig"): def register_for_auto_class(cls, auto_class="AutoConfig"):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment