Unverified Commit da971b22 authored by fxmarty, committed by GitHub


Keep relevant weights in fp32 when `model._keep_in_fp32_modules` is set even when `accelerate` is not installed (#26225)

* fix bug where weight would not be kept in fp32

* nit

* address review comments

* fix test
parent e3a4bd2b
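
For context, the intended behavior can be checked with a minimal sketch (editorial illustration, not part of the diff; it assumes the `t5-small` checkpoint, whose `_keep_in_fp32_modules` lists `wo`): loading in float16 should keep the `wo` projection in float32 even when `accelerate` is not installed, which is exactly what the test below asserts.

```python
import torch
from transformers import T5ForConditionalGeneration

# T5 declares _keep_in_fp32_modules = ["wo"], so the wo projection should be
# upcast to float32 even though the rest of the model is loaded in float16.
model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.float16)
print(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype)  # torch.float32
print(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype)  # torch.float16
```

Before this fix, the upcast was silently skipped whenever `accelerate` was missing, because the gating below required `is_accelerate_available()`.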
@@ -2950,26 +2950,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             dtype_orig = cls._set_default_torch_dtype(torch_dtype)
 
         # Check if `_keep_in_fp32_modules` is not None
-        use_keep_in_fp32_modules = (
-            (cls._keep_in_fp32_modules is not None)
-            and is_accelerate_available()
-            and (torch_dtype == torch.float16 or load_in_4bit or load_in_8bit)
+        use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
+            torch_dtype == torch.float16 or load_in_4bit or load_in_8bit
         )
-        if (
-            (cls._keep_in_fp32_modules is not None)
-            and not is_accelerate_available()
-            and torch_dtype == torch.float16
-        ):
-            logger.warning(
-                "For stability purposes, it is recommended to have accelerate installed when using this model in"
-                " torch.float16, please install it with `pip install accelerate`"
-            )
 
         if is_sharded:
             loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
         else:
             loaded_state_dict_keys = list(state_dict.keys())
 
-        if low_cpu_mem_usage or use_keep_in_fp32_modules:
+        if low_cpu_mem_usage or (use_keep_in_fp32_modules and is_accelerate_available()):
+            # In case some weights need to be kept in float32 and accelerate is not installed,
+            # we later on want to take the path where state_dict is not None, that is the one
+            # that do not require accelerate.
             state_dict = None
         config.name_or_path = pretrained_model_name_or_path
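
Net effect of this hunk in `PreTrainedModel.from_pretrained`: whether weights are kept in fp32 no longer depends on `accelerate`, while the meta-device loading path (which drops `state_dict` and does require `accelerate`) is only taken when `accelerate` is importable. A simplified sketch of the gating, with hypothetical names standing in for the real local variables:

```python
def loading_decisions(keep_in_fp32_modules, is_fp16, load_in_4bit, load_in_8bit,
                      low_cpu_mem_usage, accelerate_installed):
    # Illustrative only -- not the actual transformers code path.
    # fp32-keep is now decided independently of accelerate.
    use_keep_in_fp32 = keep_in_fp32_modules is not None and (
        is_fp16 or load_in_4bit or load_in_8bit
    )
    # Dropping state_dict (meta-device loading) still requires accelerate; without it,
    # the plain state_dict path is used and the upcast happens after loading.
    drop_state_dict = low_cpu_mem_usage or (use_keep_in_fp32 and accelerate_installed)
    return use_keep_in_fp32, drop_state_dict
```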
@@ -2990,7 +2982,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
 
         # Check first if we are `from_pt`
         if use_keep_in_fp32_modules:
-            low_cpu_mem_usage = True
+            if is_accelerate_available():
+                low_cpu_mem_usage = True
             keep_in_fp32_modules = model._keep_in_fp32_modules
         else:
             keep_in_fp32_modules = []
@@ -3465,7 +3458,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         if keep_in_fp32_modules is not None:
             for name, param in model.named_parameters():
                 if any(module_to_keep_in_fp32 in name for module_to_keep_in_fp32 in keep_in_fp32_modules):
-                    param = param.to(torch.float32)
+                    # param = param.to(torch.float32) does not work here as only in the local scope.
+                    param.data = param.data.to(torch.float32)
 
         # Make sure we are able to load base models as well as derived models (with heads)
         start_prefix = ""
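
The new comment in this hunk is about Python name binding: inside the loop, `param = param.to(torch.float32)` only rebinds the local variable and leaves the module's parameter untouched, whereas assigning to `param.data` mutates the tensor the module actually holds. A self-contained illustration (editorial, not from the diff):

```python
import torch
from torch import nn

linear = nn.Linear(2, 2).to(torch.float16)

for _, param in linear.named_parameters():
    param = param.to(torch.float32)  # rebinds the local name only
print(linear.weight.dtype)  # still torch.float16

for _, param in linear.named_parameters():
    param.data = param.data.to(torch.float32)  # mutates the parameter in place
print(linear.weight.dtype)  # torch.float32
```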
@@ -3592,7 +3586,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                         remove_prefix_from_model,
                         ignore_mismatched_sizes,
                     )
-
                 if low_cpu_mem_usage:
                     if not is_fsdp_enabled() or is_fsdp_enabled_and_dist_rank_0():
                         new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
......
@@ -1046,15 +1046,30 @@ class T5ModelFp16Tests(unittest.TestCase):
         r"""
         A test to check whether the argument `keep_in_fp32_modules` correctly does its job
         """
+        orig_import = __import__
+        accelerate_mock = unittest.mock.Mock()
+
+        # mock import of accelerate
+        def import_accelerate_mock(name, *args, **kwargs):
+            if name == "accelerate":
+                if accelerate_available:
+                    return accelerate_mock
+                else:
+                    raise ImportError
+            return orig_import(name, *args, **kwargs)
+
         # Load without using `accelerate`
-        model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.float16)
-        self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)
-        self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16)
+        with unittest.mock.patch("builtins.__import__", side_effect=import_accelerate_mock):
+            accelerate_available = False
 
-        # Load without in bf16
-        model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.bfloat16)
-        self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16)
-        self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16)
+            model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.float16)
+            self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)
+            self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16)
+
+            # Load without in bf16
+            model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.bfloat16)
+            self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16)
+            self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16)
 
         # Load using `accelerate` in bf16
         model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.bfloat16, device_map="auto")
......
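
The updated test simulates an environment without `accelerate` by patching `builtins.__import__`. The pattern is reusable on its own; a minimal standalone sketch (hypothetical snippet, independent of transformers):

```python
import unittest.mock

orig_import = __import__

def fake_import(name, *args, **kwargs):
    # Pretend `accelerate` is not installed; defer all other imports.
    if name == "accelerate":
        raise ImportError(f"No module named '{name}'")
    return orig_import(name, *args, **kwargs)

with unittest.mock.patch("builtins.__import__", side_effect=fake_import):
    try:
        import accelerate  # noqa: F401
    except ImportError:
        print("accelerate import blocked as expected")
```

The test's mock additionally has a branch that returns a `unittest.mock.Mock()` when `accelerate_available` is true, although only the `False` case is exercised in the hunk shown here.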