Unverified Commit b5be744d authored by Susnato Dhar, committed by GitHub

Fixed issue #21039 (#21062)

Fixed issue #21039 and added test for low_cpu_mem_usage
parent e849e5bb
@@ -2629,7 +2629,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         # This is not ideal in terms of memory, but if we don't do that not, we can't initialize them in the next step
         if low_cpu_mem_usage:
             for key in missing_keys:
-                if key.startswith(prefix):
+                if key in list(model_state_dict.keys()):
+                    key = key
+                elif f"{prefix}.{key}" in list(model_state_dict.keys()):
+                    key = f"{prefix}.{key}"
+                elif key.startswith(prefix) and ".".join(key.split(".")[1:]) in list(model_state_dict.keys()):
                     key = ".".join(key.split(".")[1:])
                 param = model_state_dict[key]
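The previous logic only stripped the base-model prefix from a missing key, so whenever a key in `missing_keys` and the corresponding key in `model_state_dict` disagreed about the prefix in any other way, the `model_state_dict[key]` lookup could fail. The new branches try the key as-is, then with the prefix prepended, then with the prefix stripped. Below is a minimal, standalone sketch of that lookup order; the `resolve_key` helper and the example state dict are illustrative only, not part of transformers:

```python
def resolve_key(key: str, model_state_dict: dict, prefix: str) -> str:
    # Same lookup order as the patched code: as-is, then prefixed, then stripped.
    if key in model_state_dict:
        return key
    if f"{prefix}.{key}" in model_state_dict:
        return f"{prefix}.{key}"
    stripped = ".".join(key.split(".")[1:])
    if key.startswith(prefix) and stripped in model_state_dict:
        return stripped
    raise KeyError(key)


# Hypothetical example: a causal-LM state dict whose base-model keys carry the
# "transformer." prefix while the head key does not.
state_dict = {"transformer.h.0.attn.c_attn.weight": None, "lm_head.weight": None}
assert resolve_key("h.0.attn.c_attn.weight", state_dict, "transformer") == "transformer.h.0.attn.c_attn.weight"
assert resolve_key("transformer.h.0.attn.c_attn.weight", state_dict, "transformer") == "transformer.h.0.attn.c_attn.weight"
assert resolve_key("lm_head.weight", state_dict, "transformer") == "lm_head.weight"
```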
@@ -3166,6 +3166,27 @@ class ModelUtilsTest(TestCasePlus):
         ):
             _ = ModelWithHead.from_pretrained(tmp_dir)
 
+    @require_torch_gpu
+    def test_pretrained_low_mem_new_config(self):
+        # Checking for 1 model(the same one which was described in the issue) .
+        model_ids = ["gpt2"]
+
+        for model_id in model_ids:
+            model_config = AutoConfig.from_pretrained(pretrained_model_name_or_path=model_id)
+            model_config.n_layer = 48
+            model_config.n_head = 25
+            model_config.n_embd = 1600
+            model = AutoModelForCausalLM.from_pretrained(
+                pretrained_model_name_or_path=model_id,
+                config=model_config,
+                ignore_mismatched_sizes=True,
+                torch_dtype=torch.float16,
+                low_cpu_mem_usage=True,
+            )
+            model_ref = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=model_id)
+
+            self.assertEqual(model.__class__.__name__, model_ref.__class__.__name__)
+
     @require_torch
     @is_staging_test
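For reference, the scenario the new test exercises can be reproduced outside the test suite roughly as follows. This is a sketch, not part of the commit; it assumes a CUDA-capable machine (matching the `@require_torch_gpu` gate) and enough memory to materialize the enlarged model. The modified dimensions (`n_layer=48`, `n_head=25`, `n_embd=1600`) are roughly the gpt2-xl shape applied on top of the gpt2 checkpoint, so many keys end up missing or size-mismatched during loading.

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM

# Enlarge the gpt2 config so the checkpoint no longer matches the freshly
# initialized model, then load with low_cpu_mem_usage=True as in the test above.
config = AutoConfig.from_pretrained("gpt2")
config.n_layer = 48
config.n_head = 25
config.n_embd = 1600

model = AutoModelForCausalLM.from_pretrained(
    "gpt2",
    config=config,
    ignore_mismatched_sizes=True,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
print(type(model).__name__)  # e.g. GPT2LMHeadModel
```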