"docs/vscode:/vscode.git/clone" did not exist on "5761ceb35a9ae0bd9e49a59438c725c61cda4f10"
Unverified Commit 5c186003 authored by Marc Sun's avatar Marc Sun Committed by GitHub
Browse files

Fix low cpu mem usage tests (#30808)

* Fix tests

* fix udop failing test

* remove skip

* style
parent 934e1b84
......@@ -1297,7 +1297,7 @@ class UdopStack(UdopPreTrainedModel):
# get weights from encoder position bias
self.relative_bias = self._get_relative_bias(config)
# tie weights of original position bias of encoder
def _tie_weights(self):
for bias in self.relative_bias.biases:
if isinstance(bias, RelativePositionBias1D):
self._tie_or_clone_weights(
......
......@@ -21,7 +21,6 @@ import os.path
import random
import re
import tempfile
import unittest
import warnings
from collections import defaultdict
from typing import Dict, List, Tuple
......@@ -444,7 +443,6 @@ class ModelTesterMixin:
@slow
@require_accelerate
@mark.accelerate_tests
@unittest.skip("Need to fix since we have a device mismatch")
def test_save_load_low_cpu_mem_usage(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
with tempfile.TemporaryDirectory() as saved_model_path:
......@@ -457,7 +455,6 @@ class ModelTesterMixin:
@slow
@require_accelerate
@mark.accelerate_tests
@unittest.skip("Need to fix since we have a device mismatch")
def test_save_load_low_cpu_mem_usage_checkpoints(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
with tempfile.TemporaryDirectory() as saved_model_path:
......@@ -471,7 +468,6 @@ class ModelTesterMixin:
@slow
@require_accelerate
@mark.accelerate_tests
@unittest.skip("Need to fix since we have a device mismatch")
def test_save_load_low_cpu_mem_usage_no_safetensors(self):
with tempfile.TemporaryDirectory() as saved_model_path:
for model_class in self.all_model_classes:
......@@ -482,6 +478,8 @@ class ModelTesterMixin:
self._check_save_load_low_cpu_mem_usage(model_class, saved_model_path)
def _check_save_load_low_cpu_mem_usage(self, model_class, saved_model_path):
from accelerate.utils.modeling import named_module_tensors
# Load the low usage and the normal models.
model_low_usage, loading_info = model_class.from_pretrained(
saved_model_path,
......@@ -496,16 +494,13 @@ class ModelTesterMixin:
# The low_cpu_mem_usage=True causes the model params to be initialized with device=meta, and then
# subsequently loaded with the correct values and onto the correct device. We check if there are any
# remaining params that were not properly loaded.
for name, param in model_low_usage.named_parameters():
for name, tensor in named_module_tensors(model_low_usage, recurse=True):
self.assertNotEqual(
param.device,
tensor.device,
torch.device("meta"),
"Parameter '" + name + "' has not been properly loaded and has device=meta.",
"Tensor '" + name + "' has not been properly loaded and has device=meta.",
)
# Tests moving the model to a device other than meta.
model_low_usage.to(torch_device)
# Check that the parameters are equal.
for p1, p2 in zip(model_low_usage.parameters(), model_non_low_usage.parameters()):
self.assertEquals(p1.data.ne(p2.data).sum(), 0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment