Unverified Commit 2f550758 authored by Stas Bekman, committed by GitHub


[from_pretrained] extend `torch_dtype="auto"` to look up `config.torch_dtype` first, expand docs (#21524)

* [from_pretrained] expand on torch_dtype entry

* fold 4 into 1

* style

* support torch_dtype='config' plus tests

* style

* oops

* fold config into auto, fix bug

* fix check

* better log

* better log

* clean up
parent 9e40bba6
@@ -1904,7 +1904,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
<Tip>
To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.
@@ -1932,8 +1931,27 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
Tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
This is an experimental feature and subject to change at any moment.
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
will be automatically derived from the model's weights.
Override the default `torch.dtype` and load the model under a specific `dtype`. The different options
are:

1. `torch.float16`, `torch.bfloat16` or `torch.float`: load in the specified `dtype`, ignoring the
model's `config.torch_dtype` if one exists. If `torch_dtype` is not passed at all, the model is
loaded in `torch.float` (fp32).

2. `"auto"` - the `torch_dtype` entry in the model's `config.json` file is used if present. If that
entry isn't found, the `dtype` of the first floating-point weight in the checkpoint is used instead.
This loads the model in the `dtype` it was saved in at the end of training; it cannot be used as an
indicator of how the model was trained, since a model may be trained in one of the half-precision
dtypes but saved in fp32.
<Tip>
For some models the `dtype` they were trained in is unknown - you may try to check the model's paper, or
reach out to the authors and ask them to add this information to the model's card and to insert the
`torch_dtype` entry in `config.json` on the Hub.
</Tip>
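As a quick illustration of the options above, a minimal usage sketch (not part of the diff; the checkpoint name is only an example, any Hub model works):

```python
import torch
from transformers import AutoModel

# Explicit dtype: ignores config.torch_dtype and loads the weights in fp16.
model = AutoModel.from_pretrained("sshleifer/tiny-gpt2", torch_dtype=torch.float16)
print(model.dtype)  # torch.float16

# "auto": uses config.torch_dtype if present, otherwise the dtype of the first
# floating-point weight found in the checkpoint.
model = AutoModel.from_pretrained("sshleifer/tiny-gpt2", torch_dtype="auto")
print(model.dtype)

# No torch_dtype argument: the model is loaded in fp32.
model = AutoModel.from_pretrained("sshleifer/tiny-gpt2")
print(model.dtype)  # torch.float32
```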
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
A map that specifies where each submodule should go. It doesn't need to be refined to each
parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
@@ -2098,10 +2116,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
" bitsandbytes `pip install -i https://test.pypi.org/simple/ bitsandbytes` or"
" `pip install bitsandbytes`"
)
if torch_dtype == "auto" or torch_dtype != torch.float16:
if torch_dtype != torch.float16:
# We force the `dtype` to be float16, this is a requirement from `bitsandbytes`
logger.warning(
f"Overriding torch_dtype={torch_dtype} with `torch_dtype=torch.float16` due to "
"requirements of `bitsandbytes` to enable model loading in mixed int8. "
"Either pass torch_dtype=torch.float16 or don't pass this argument at all to remove this warning."
)
torch_dtype = torch.float16
logger.info("Loading the model in mixed int8 - forcing the weights to be casted in float16")
if device_map is None:
raise ValueError(
"A device map needs to be passed to run convert models into mixed-int8 format. Please run"
@@ -2388,17 +2411,25 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
if torch_dtype is not None:
if isinstance(torch_dtype, str):
if torch_dtype == "auto":
if is_sharded and "dtype" in sharded_metadata:
torch_dtype = sharded_metadata["dtype"]
elif not is_sharded:
torch_dtype = get_state_dict_dtype(state_dict)
if hasattr(config, "torch_dtype") and config.torch_dtype is not None:
torch_dtype = config.torch_dtype
logger.info(f"Will use torch_dtype={torch_dtype} as defined in model's config object")
else:
one_state_dict = load_state_dict(resolved_archive_file[0])
torch_dtype = get_state_dict_dtype(one_state_dict)
del one_state_dict # free CPU memory
if is_sharded and "dtype" in sharded_metadata:
torch_dtype = sharded_metadata["dtype"]
elif not is_sharded:
torch_dtype = get_state_dict_dtype(state_dict)
else:
one_state_dict = load_state_dict(resolved_archive_file[0])
torch_dtype = get_state_dict_dtype(one_state_dict)
del one_state_dict # free CPU memory
logger.info(
"Since the `torch_dtype` attribute can't be found in the model's config object, "
f"will use torch_dtype={torch_dtype} as derived from the model's weights"
)
else:
raise ValueError(
f"`torch_dtype` can be either a `torch.dtype` or `auto`, but received {torch_dtype}"
f'`torch_dtype` can be either `torch.dtype` or `"auto"`, but received {torch_dtype}'
)
dtype_orig = cls._set_default_torch_dtype(torch_dtype)
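The resolution order implemented above can be summarized with a small hypothetical helper (not part of transformers; `first_weight_dtype` stands in for whatever `get_state_dict_dtype` would return from the checkpoint):

```python
import torch


def resolve_auto_dtype(config, first_weight_dtype: torch.dtype) -> torch.dtype:
    # torch_dtype="auto": prefer the config entry, else fall back to the weights.
    if getattr(config, "torch_dtype", None) is not None:
        return config.torch_dtype  # e.g. written by save_pretrained into config.json
    return first_weight_dtype  # dtype of the first floating-point weight in the checkpoint


# A config without a torch_dtype entry falls back to the checkpoint's dtype, e.g.:
# resolve_auto_dtype(config_without_entry, torch.bfloat16) -> torch.bfloat16
```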
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Factory function to build auto-model classes."""
import copy
import importlib
from collections import OrderedDict
@@ -431,12 +432,18 @@ class _BaseAutoModelClass:
]
hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
if not isinstance(config, PretrainedConfig):
kwargs_copy = copy.deepcopy(kwargs)
# ensure not to pollute the config object with torch_dtype="auto" - since it's
# meaningless in the context of the config object - torch.dtype values are acceptable
if kwargs_copy.get("torch_dtype", None) == "auto":
_ = kwargs_copy.pop("torch_dtype")
config, kwargs = AutoConfig.from_pretrained(
pretrained_model_name_or_path,
return_unused_kwargs=True,
trust_remote_code=trust_remote_code,
**hub_kwargs,
**kwargs,
**kwargs_copy,
)
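The copy-and-pop above keeps the string `"auto"` out of the config object, which only expects real `torch.dtype` values (or `None`). A standalone sketch of the same pattern, with hypothetical kwargs:

```python
kwargs = {"torch_dtype": "auto", "output_attentions": True}

config_kwargs = dict(kwargs)  # work on a copy so the caller's kwargs stay untouched
if config_kwargs.get("torch_dtype", None) == "auto":
    config_kwargs.pop("torch_dtype")  # keep "auto" away from the config object

# config_kwargs can now be forwarded to the config loader, while the original kwargs
# still carry torch_dtype="auto" for the weight-loading step.
```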
if hasattr(config, "auto_map") and cls.__name__ in config.auto_map:
if not trust_remote_code:
@@ -2785,7 +2785,6 @@ class ModelUtilsTest(TestCasePlus):
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
@require_torch
def test_model_from_config_torch_dtype(self):
# test that the model can be instantiated with dtype of user's choice - as long as it's a
# float dtype. To make it happen config.torch_dtype needs to be set before instantiating the
@@ -2804,7 +2803,6 @@ class ModelUtilsTest(TestCasePlus):
with self.assertRaises(ValueError):
model = AutoModel.from_config(config, torch_dtype=torch.int64)
@require_torch
def test_model_from_pretrained_torch_dtype(self):
# test that the model can be instantiated with dtype of either
# 1. explicit from_pretrained's torch_dtype argument
@@ -2818,11 +2816,25 @@ class ModelUtilsTest(TestCasePlus):
model = T5ForConditionalGeneration.from_pretrained(TINY_T5)
self.assertEqual(model.dtype, torch.float32)
def remove_torch_dtype(model_path):
file = f"{model_path}/config.json"
with open(file, "r", encoding="utf-8") as f:
s = json.load(f)
s.pop("torch_dtype")
with open(file, "w", encoding="utf-8") as f:
json.dump(s, f)
# test the default fp32 save_pretrained => from_pretrained cycle
model.save_pretrained(model_path)
model = T5ForConditionalGeneration.from_pretrained(model_path)
self.assertEqual(model.dtype, torch.float32)
# test with auto-detection
# 1. test torch_dtype="auto" via `config.torch_dtype`
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
self.assertEqual(model.dtype, torch.float32)
# 2. test torch_dtype="auto" via auto-derivation
# now remove the torch_dtype entry from config.json and try "auto" again which should
# perform auto-derivation from weights
remove_torch_dtype(model_path)
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
self.assertEqual(model.dtype, torch.float32)
@@ -2833,24 +2845,32 @@ class ModelUtilsTest(TestCasePlus):
# test fp16 save_pretrained, loaded with auto-detection
model = model.half()
model.save_pretrained(model_path)
# 1. test torch_dtype="auto" via `config.torch_dtype`
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
self.assertEqual(model.config.torch_dtype, torch.float16)
self.assertEqual(model.dtype, torch.float16)
# tests `config.torch_dtype` saving
with open(f"{model_path}/config.json") as f:
config_dict = json.load(f)
self.assertEqual(config_dict["torch_dtype"], "float16")
# 2. test torch_dtype="auto" via auto-derivation
# now remove the torch_dtype entry from config.json so "auto" has to derive the dtype from the weights
remove_torch_dtype(model_path)
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
self.assertEqual(model.dtype, torch.float16)
# test fp16 save_pretrained, loaded with the explicit fp16
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16)
self.assertEqual(model.dtype, torch.float16)
# test AutoModel separately as it goes through a different path
# test auto-detection
# test auto-detection - as currently TINY_T5 doesn't have torch_dtype entry
model = AutoModel.from_pretrained(TINY_T5, torch_dtype="auto")
# test that the config object didn't get polluted with torch_dtype="auto"
# there was a bug that after this call we ended up with config.torch_dtype=="auto"
self.assertNotEqual(model.config.torch_dtype, "auto")
# now test the outcome
self.assertEqual(model.dtype, torch.float32)
# test forcing an explicit dtype
model = AutoModel.from_pretrained(TINY_T5, torch_dtype=torch.float16)
self.assertEqual(model.dtype, torch.float16)
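Related to the tests above, a small sketch of how one can check up front which dtype `torch_dtype="auto"` would pick from the config ("t5-small" is only an example checkpoint):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("t5-small")
# A torch.dtype such as torch.float32 if config.json carries a torch_dtype entry,
# or None, in which case "auto" falls back to the dtype of the checkpoint's weights.
print(config.torch_dtype)
```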