Unverified commit 66db33dd, authored by fxmarty, committed by GitHub

Fix mismatching loading in from_pretrained with/without accelerate (#28414)

* fix mismatching behavior in from_pretrained with/without accelerate

* meaningful refactor

* remove added space

* add test

* fix model on the hub

* comment

* use tiny model

* style
Parent: 002566f3
```diff
@@ -756,18 +756,23 @@ def _load_state_dict_into_meta_model(
             else:
                 param = param.to(dtype)
 
-            # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model
-            if dtype is None:
-                old_param = model
-                splits = param_name.split(".")
-                for split in splits:
-                    old_param = getattr(old_param, split)
-                    if old_param is None:
-                        break
-
-                if old_param is not None:
-                    param = param.to(old_param.dtype)
+            # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
+            # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
+            # Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
+            old_param = model
+            splits = param_name.split(".")
+            for split in splits:
+                old_param = getattr(old_param, split)
+                if old_param is None:
+                    break
+
+            if old_param is not None:
+                if dtype is None:
+                    param = param.to(old_param.dtype)
+
+                if old_param.is_contiguous():
+                    param = param.contiguous()
 
             set_module_kwargs["value"] = param
 
             if device_map is None:
```
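For context: plain `load_state_dict` copies checkpoint values into the model's existing parameters via `param.copy_(input_param)`, which preserves the parameter's memory layout, while the meta-device path in `modeling_utils.py` used when loading with accelerate (e.g. `device_map="auto"`) sets the checkpoint tensor on the model more directly, so the parameter previously inherited the checkpoint's possibly non-contiguous layout. Below is a minimal sketch of that difference using standalone tensors, not the actual `_load_state_dict_into_meta_model` code; the variable names are illustrative only.

```python
import torch

# A checkpoint can contain non-contiguous tensors, e.g. a transposed view:
checkpoint_param = torch.randn(4, 2).t()  # shape (2, 4), non-contiguous
assert not checkpoint_param.is_contiguous()

# A freshly initialized model parameter is contiguous:
model_param = torch.nn.Parameter(torch.empty(2, 4))
assert model_param.is_contiguous()

# Plain load_state_dict goes through `param.copy_(input_param)`, which writes
# into the parameter's existing storage and therefore keeps it contiguous:
with torch.no_grad():
    model_param.copy_(checkpoint_param)
assert model_param.is_contiguous()

# The accelerate/meta-device path materializes the checkpoint tensor itself,
# so without the fix the parameter would keep the checkpoint's layout.
# The fix mirrors the existing parameter's contiguity before assignment:
param = checkpoint_param
if model_param.is_contiguous():
    param = param.contiguous()
assert param.is_contiguous()
```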
```diff
@@ -34,6 +34,7 @@ from requests.exceptions import HTTPError
 from transformers import (
     AutoConfig,
     AutoModel,
+    OwlViTForObjectDetection,
     PretrainedConfig,
     is_torch_available,
     logging,
@@ -835,6 +836,23 @@ class ModelUtilsTest(TestCasePlus):
         outputs2 = new_model_with_offload(inputs)
         self.assertTrue(torch.allclose(outputs1[0].cpu(), outputs2[0].cpu()))
 
+    @slow
+    @require_torch
+    def test_from_pretrained_non_contiguous_checkpoint(self):
+        # See: https://github.com/huggingface/transformers/pull/28414
+        # Tiny models on the Hub have contiguous weights, contrarily to google/owlvit
+        model = OwlViTForObjectDetection.from_pretrained("fxmarty/owlvit-tiny-non-contiguous-weight")
+        self.assertTrue(model.owlvit.visual_projection.weight.is_contiguous())
+
+        model = OwlViTForObjectDetection.from_pretrained(
+            "fxmarty/owlvit-tiny-non-contiguous-weight", device_map="auto"
+        )
+        self.assertTrue(model.owlvit.visual_projection.weight.is_contiguous())
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(tmp_dir, safe_serialization=False)
+            model.save_pretrained(tmp_dir, safe_serialization=True)
+
     def test_cached_files_are_used_when_internet_is_down(self):
         # A mock response for an HTTP head request to emulate server down
         response_mock = mock.Mock()
```
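The `save_pretrained` calls at the end of the test are the part that used to break: safetensors serialization requires contiguous tensors, so a model whose parameters came back non-contiguous from `from_pretrained` could not be re-saved with `safe_serialization=True`. A rough sketch of that failure mode follows; it borrows the `visual_projection.weight` name from the test above, the file name is made up, and the `ValueError` is what safetensors raises for non-contiguous tensors at the time of writing.

```python
import tempfile

import torch
from safetensors.torch import save_file

# A non-contiguous tensor, standing in for the weights of the repro checkpoint:
weight = torch.randn(4, 2).t()
assert not weight.is_contiguous()

with tempfile.TemporaryDirectory() as tmp_dir:
    try:
        # safetensors refuses non-contiguous tensors, so a model loaded with
        # non-contiguous parameters could not be saved with safe_serialization=True.
        save_file({"visual_projection.weight": weight}, f"{tmp_dir}/model.safetensors")
    except ValueError as err:
        print(f"Rejected: {err}")

    # With contiguous parameters, as from_pretrained now guarantees, saving works:
    save_file({"visual_projection.weight": weight.contiguous()}, f"{tmp_dir}/model.safetensors")
```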