Unverified Commit 6307312d authored by Marc Sun's avatar Marc Sun Committed by GitHub
Browse files

Add check for tied parameters (#24029)

* Add check for tied parameters

* Fix style

* fix style

* Fix versioning

* Change if to elif
parent 7da3ce04
......@@ -96,6 +96,10 @@ if is_accelerate_available():
from accelerate.utils import get_balanced_memory
else:
get_balanced_memory = None
if version.parse(accelerate_version) > version.parse("0.19.0"):
from accelerate.utils import check_tied_parameters_on_same_device
else:
check_tied_parameters_on_same_device = None
else:
find_tied_parameters = None
......@@ -2824,6 +2828,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
del device_map_without_lm_head
elif device_map is not None:
model.tie_weights()
tied_params = find_tied_parameters(model)
# check if we don't have tied param in different devices
if check_tied_parameters_on_same_device is not None:
check_tied_parameters_on_same_device(tied_params, device_map)
if from_tf:
if resolved_archive_file.endswith(".index"):
# Load from a TensorFlow 1.X checkpoint - provided by original authors
......@@ -3015,6 +3026,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
unexpected_keys = list(set(loaded_keys) - set(expected_keys))
if find_tied_parameters is not None:
model.tie_weights()
tied_params = find_tied_parameters(model)
else:
tied_params = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment