"docs/source/model_doc/bart.mdx" did not exist on "d0422de5634d3b14ac5159d7f8a2ab9336821d22"
Unverified Commit 7e00247f authored by feifang24's avatar feifang24 Committed by GitHub
Browse files

check for key 'torch.dtype' in nested dicts in config (#16065)

parent 5d2fed2e
......@@ -849,12 +849,15 @@ class PretrainedConfig(PushToHubMixin):
def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None:
"""
Checks whether the passed dictionary has a *torch_dtype* key and if it's not None, converts torch.dtype to a
string of just the type. For example, `torch.float32` get converted into *"float32"* string, which can then be
stored in the json format.
Checks whether the passed dictionary and its nested dicts have a *torch_dtype* key and if it's not None,
converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"*
string, which can then be stored in the json format.
"""
if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str):
d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
for value in d.values():
if isinstance(value, dict):
self.dict_torch_dtype_to_str(value)
@classmethod
def register_for_auto_class(cls, auto_class="AutoConfig"):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment