Unverified commit a5cc30d7, authored by Marc Sun, committed by GitHub
Browse files

fix tied_params for meta tensor (#25101)

* fix tied_params for meta tensor

* remove duplicate
parent f1deb21f
......@@ -3060,13 +3060,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
unexpected_keys = list(unexpected_keys - model_buffers)
model.tie_weights()
ptrs = collections.defaultdict(list)
for name, tensor in model.state_dict().items():
id_tensor = id_tensor_storage(tensor) if tensor.device != torch.device("meta") else id(tensor)
ptrs[id_tensor].append(name)
if device_map is None:
ptrs = collections.defaultdict(list)
for name, tensor in model.state_dict().items():
id_tensor = id_tensor_storage(tensor)
ptrs[id_tensor].append(name)
# These are all the pointers of shared tensors.
tied_params = [names for _, names in ptrs.items() if len(names) > 1]
# These are all the pointers of shared tensors.
tied_params = [names for _, names in ptrs.items() if len(names) > 1]
else:
# id function doesn't work for meta tensor so we need this function
tied_params = find_tied_parameters(model)
for group in tied_params:
if remove_prefix_from_model:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment