Unverified Commit 3a028101 authored by Billy Cao, committed by GitHub

[QoL] Allow dtype str for torch_dtype arg of from_pretrained (#31590)

* Allow dtype str for torch_dtype in from_pretrained

* Update docstring

* Add tests for str torch_dtype
parent 11138ca0
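
In short, `torch_dtype` now accepts the name of a dtype as a plain string, in addition to a `torch.dtype` object or `"auto"`. A minimal usage sketch of the new behavior (the checkpoint path is a placeholder, not a real repo):

```python
import torch
from transformers import AutoModel

checkpoint = "path/to/tiny-checkpoint"  # placeholder name, for illustration only

# Previously supported: a torch.dtype object, or "auto" to derive the dtype from the checkpoint.
model = AutoModel.from_pretrained(checkpoint, torch_dtype=torch.float16)

# New with this commit: a plain string naming a valid torch.dtype.
model = AutoModel.from_pretrained(checkpoint, torch_dtype="float16")
assert model.dtype == torch.float16
```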
@@ -2958,6 +2958,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
                   using the `dtype` it was saved in at the end of the training. It can't be used as an indicator of how
                   the model was trained, since it could have been trained in one of the half-precision dtypes but saved in fp32.
+                3. A string that is a valid `torch.dtype`. E.g. "float32" loads the model in `torch.float32`, "float16" loads in `torch.float16` etc.
 
             <Tip>
 
             For some models the `dtype` they were trained in is unknown - you may try to check the model's paper or
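
For context, the `"auto"` option documented above derives the dtype from the checkpoint itself rather than from the caller. A rough sketch of that rule, using hypothetical `config` and `state_dict` arguments (illustrative only, not the library's internal helper):

```python
import torch

def pick_auto_dtype(config, state_dict):
    # Sketch of the "auto" rule: prefer the dtype recorded in the config,
    # otherwise fall back to the dtype of the first floating-point weight.
    if getattr(config, "torch_dtype", None) is not None:
        return config.torch_dtype
    for tensor in state_dict.values():
        if tensor.is_floating_point():
            return tensor.dtype
    return torch.float32  # assumption: default to fp32 if no float weights are found
```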
@@ -3661,9 +3663,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
                             logger.info(
                                 "Since the `torch_dtype` attribute can't be found in model's config object, "
                                 "will use torch_dtype={torch_dtype} as derived from model's weights"
                             )
+                    elif hasattr(torch, torch_dtype):
+                        torch_dtype = getattr(torch, torch_dtype)
                     else:
                         raise ValueError(
-                            f'`torch_dtype` can be either `torch.dtype` or `"auto"`, but received {torch_dtype}'
+                            f'`torch_dtype` can be one of: `torch.dtype`, `"auto"` or a string of a valid `torch.dtype`, but received {torch_dtype}'
                         )
 
                 dtype_orig = cls._set_default_torch_dtype(torch_dtype)
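
The new branch resolves a dtype string by a plain attribute lookup on the `torch` module. A minimal sketch of the same idea, with a hypothetical `resolve_torch_dtype` helper that is not part of the library:

```python
import torch

def resolve_torch_dtype(torch_dtype):
    # Hypothetical helper mirroring the branch added above.
    if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto":
        return torch_dtype  # already usable; "auto" is handled separately by from_pretrained
    if isinstance(torch_dtype, str) and hasattr(torch, torch_dtype):
        # e.g. "float32" -> torch.float32, "bfloat16" -> torch.bfloat16
        return getattr(torch, torch_dtype)
    raise ValueError(
        f'`torch_dtype` can be one of: `torch.dtype`, `"auto"` or a string of a valid `torch.dtype`, '
        f"but received {torch_dtype}"
    )

print(resolve_torch_dtype("float16"))      # torch.float16
print(resolve_torch_dtype(torch.float32))  # torch.float32
```

Strings naming a non-floating-point dtype (such as "int64") resolve here, but are rejected later by `_set_default_torch_dtype`, which is what the new test below exercises.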
@@ -445,6 +445,18 @@ class ModelUtilsTest(TestCasePlus):
         with self.assertRaises(ValueError):
             model = AutoModel.from_config(config, torch_dtype=torch.int64)
 
+    def test_model_from_config_torch_dtype_str(self):
+        # test that from_pretrained works with torch_dtype being strings like "float32" for PyTorch backend
+        model = AutoModel.from_pretrained(TINY_T5, torch_dtype="float32")
+        self.assertEqual(model.dtype, torch.float32)
+
+        model = AutoModel.from_pretrained(TINY_T5, torch_dtype="float16")
+        self.assertEqual(model.dtype, torch.float16)
+
+        # torch.set_default_dtype() supports only float dtypes, so will fail with non-float type
+        with self.assertRaises(ValueError):
+            model = AutoModel.from_pretrained(TINY_T5, torch_dtype="int64")
+
     def test_model_from_pretrained_torch_dtype(self):
         # test that the model can be instantiated with dtype of either
         # 1. explicit from_pretrained's torch_dtype argument
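
The last assertion in the new test relies on the fact that only floating-point dtypes can be used to instantiate the model, since the loader temporarily changes PyTorch's default dtype. A small sketch of that guard, using a hypothetical `assert_valid_default_dtype` helper (illustrative only):

```python
import torch

def assert_valid_default_dtype(dtype: torch.dtype) -> None:
    # Hypothetical guard, for illustration only: torch.set_default_dtype() accepts
    # floating-point dtypes only, so an integer dtype such as torch.int64 must be
    # rejected before it would be applied as the default.
    if not dtype.is_floating_point:
        raise ValueError(f"Can't instantiate a model under non-floating-point dtype {dtype}")

assert_valid_default_dtype(torch.float16)                 # passes
try:
    assert_valid_default_dtype(getattr(torch, "int64"))   # "int64" resolves, but is not a float dtype
except ValueError as err:
    print(err)
```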