"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "b4727a1216bb21df2795e973063ed07202235d7e"
Unverified Commit bbbc5c15 authored by Stas Bekman, committed by GitHub
Browse files

[AutoModel] fix `torch_dtype=auto` in `from_pretrained` (#23379)



* [automodel] fix torch_dtype=auto in from_pretrained

* add test

* fix logic

* Update src/transformers/models/auto/auto_factory.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

---------
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 8a588093
......@@ -435,19 +435,24 @@ class _BaseAutoModelClass:
]
hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
if not isinstance(config, PretrainedConfig):
kwargs_copy = copy.deepcopy(kwargs)
kwargs_orig = copy.deepcopy(kwargs)
# ensure not to pollute the config object with torch_dtype="auto" - since it's
# meaningless in the context of the config object - torch.dtype values are acceptable
if kwargs_copy.get("torch_dtype", None) == "auto":
_ = kwargs_copy.pop("torch_dtype")
if kwargs.get("torch_dtype", None) == "auto":
_ = kwargs.pop("torch_dtype")
config, kwargs = AutoConfig.from_pretrained(
pretrained_model_name_or_path,
return_unused_kwargs=True,
trust_remote_code=trust_remote_code,
**hub_kwargs,
**kwargs_copy,
**kwargs,
)
# if torch_dtype=auto was passed here, ensure to pass it on
if kwargs_orig.get("torch_dtype", None) == "auto":
kwargs["torch_dtype"] = "auto"
if hasattr(config, "auto_map") and cls.__name__ in config.auto_map:
if not trust_remote_code:
raise ValueError(
......
......@@ -2920,6 +2920,10 @@ class ModelUtilsTest(TestCasePlus):
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
self.assertEqual(model.dtype, torch.float16)
# 3. now retest that AutoModel behaves the same wrt torch_dtype="auto" as T5ForConditionalGeneration
model = AutoModel.from_pretrained(model_path, torch_dtype="auto")
self.assertEqual(model.dtype, torch.float16)
# test fp16 save_pretrained, loaded with the explicit fp16
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16)
self.assertEqual(model.dtype, torch.float16)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment