Unverified Commit 8cfc6678 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Adapt find_tied_parameters to handle breaking change in Accelerate (#22360)

parent 204737fc
...@@ -154,7 +154,12 @@ def get_keys_to_not_convert(model): ...@@ -154,7 +154,12 @@ def get_keys_to_not_convert(model):
tied_model = deepcopy(model) # this has 0 cost since it is done inside `init_empty_weights` context manager` tied_model = deepcopy(model) # this has 0 cost since it is done inside `init_empty_weights` context manager`
tied_model.tie_weights() tied_model.tie_weights()
tied_keys = list(find_tied_parameters(tied_model).values()) tied_params = find_tied_parameters(tied_model)
# For compatibility with Accelerate < 0.18
if isinstance(tied_params, dict):
tied_keys = list(tied_params.values())
else:
tied_keys = sum([x[1:] for x in tied_params], [])
has_tied_params = len(tied_keys) > 0 has_tied_params = len(tied_keys) > 0
# Check if it is a base model # Check if it is a base model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment