Unverified commit 5c186003, authored by Marc Sun and committed by GitHub
Browse files

Fix low cpu mem usage tests (#30808)

* Fix tests

* fix udop failing test

* remove skip

* style
parent 934e1b84
...@@ -1297,7 +1297,7 @@ class UdopStack(UdopPreTrainedModel): ...@@ -1297,7 +1297,7 @@ class UdopStack(UdopPreTrainedModel):
# get weights from encoder position bias # get weights from encoder position bias
self.relative_bias = self._get_relative_bias(config) self.relative_bias = self._get_relative_bias(config)
# tie weights of original position bias of encoder def _tie_weights(self):
for bias in self.relative_bias.biases: for bias in self.relative_bias.biases:
if isinstance(bias, RelativePositionBias1D): if isinstance(bias, RelativePositionBias1D):
self._tie_or_clone_weights( self._tie_or_clone_weights(
......
...@@ -21,7 +21,6 @@ import os.path ...@@ -21,7 +21,6 @@ import os.path
import random import random
import re import re
import tempfile import tempfile
import unittest
import warnings import warnings
from collections import defaultdict from collections import defaultdict
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
...@@ -444,7 +443,6 @@ class ModelTesterMixin: ...@@ -444,7 +443,6 @@ class ModelTesterMixin:
@slow @slow
@require_accelerate @require_accelerate
@mark.accelerate_tests @mark.accelerate_tests
@unittest.skip("Need to fix since we have a device mismatch")
def test_save_load_low_cpu_mem_usage(self): def test_save_load_low_cpu_mem_usage(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
with tempfile.TemporaryDirectory() as saved_model_path: with tempfile.TemporaryDirectory() as saved_model_path:
...@@ -457,7 +455,6 @@ class ModelTesterMixin: ...@@ -457,7 +455,6 @@ class ModelTesterMixin:
@slow @slow
@require_accelerate @require_accelerate
@mark.accelerate_tests @mark.accelerate_tests
@unittest.skip("Need to fix since we have a device mismatch")
def test_save_load_low_cpu_mem_usage_checkpoints(self): def test_save_load_low_cpu_mem_usage_checkpoints(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
with tempfile.TemporaryDirectory() as saved_model_path: with tempfile.TemporaryDirectory() as saved_model_path:
...@@ -471,7 +468,6 @@ class ModelTesterMixin: ...@@ -471,7 +468,6 @@ class ModelTesterMixin:
@slow @slow
@require_accelerate @require_accelerate
@mark.accelerate_tests @mark.accelerate_tests
@unittest.skip("Need to fix since we have a device mismatch")
def test_save_load_low_cpu_mem_usage_no_safetensors(self): def test_save_load_low_cpu_mem_usage_no_safetensors(self):
with tempfile.TemporaryDirectory() as saved_model_path: with tempfile.TemporaryDirectory() as saved_model_path:
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
...@@ -482,6 +478,8 @@ class ModelTesterMixin: ...@@ -482,6 +478,8 @@ class ModelTesterMixin:
self._check_save_load_low_cpu_mem_usage(model_class, saved_model_path) self._check_save_load_low_cpu_mem_usage(model_class, saved_model_path)
def _check_save_load_low_cpu_mem_usage(self, model_class, saved_model_path): def _check_save_load_low_cpu_mem_usage(self, model_class, saved_model_path):
from accelerate.utils.modeling import named_module_tensors
# Load the low usage and the normal models. # Load the low usage and the normal models.
model_low_usage, loading_info = model_class.from_pretrained( model_low_usage, loading_info = model_class.from_pretrained(
saved_model_path, saved_model_path,
...@@ -496,16 +494,13 @@ class ModelTesterMixin: ...@@ -496,16 +494,13 @@ class ModelTesterMixin:
# The low_cpu_mem_usage=True causes the model params to be initialized with device=meta, and then # The low_cpu_mem_usage=True causes the model params to be initialized with device=meta, and then
# subsequently loaded with the correct values and onto the correct device. We check if there are any # subsequently loaded with the correct values and onto the correct device. We check if there are any
# remaining params that were not properly loaded. # remaining params that were not properly loaded.
for name, param in model_low_usage.named_parameters(): for name, tensor in named_module_tensors(model_low_usage, recurse=True):
self.assertNotEqual( self.assertNotEqual(
param.device, tensor.device,
torch.device("meta"), torch.device("meta"),
"Parameter '" + name + "' has not been properly loaded and has device=meta.", "Tensor '" + name + "' has not been properly loaded and has device=meta.",
) )
# Tests moving the model to a device other than meta.
model_low_usage.to(torch_device)
# Check that the parameters are equal. # Check that the parameters are equal.
for p1, p2 in zip(model_low_usage.parameters(), model_non_low_usage.parameters()): for p1, p2 in zip(model_low_usage.parameters(), model_non_low_usage.parameters()):
self.assertEquals(p1.data.ne(p2.data).sum(), 0) self.assertEquals(p1.data.ne(p2.data).sum(), 0)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment