Unverified Commit 7242b5ff authored by Benjamin Bossan's avatar Benjamin Bossan Committed by GitHub
Browse files

FIX Test to ignore warning for enable_lora_hotswap (#12421)

I noticed that the test should be for the option check_compiled="ignore"
but it was using check_compiled="warn". This has been fixed, now the
correct argument is passed.

However, the fact that the test passed means that it was incorrect to
begin with. The way that logs are collected does not collect the
logger.warning call here (not sure why). To amend this, I'm now using
assertNoLogs. With this change, the test correctly fails when the wrong
argument is passed.
parent b4297967
...@@ -25,7 +25,6 @@ import traceback ...@@ -25,7 +25,6 @@ import traceback
import unittest import unittest
import unittest.mock as mock import unittest.mock as mock
import uuid import uuid
import warnings
from collections import defaultdict from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
...@@ -2373,14 +2372,15 @@ class LoraHotSwappingForModelTesterMixin: ...@@ -2373,14 +2372,15 @@ class LoraHotSwappingForModelTesterMixin:
def test_enable_lora_hotswap_called_after_adapter_added_ignore(self): def test_enable_lora_hotswap_called_after_adapter_added_ignore(self):
# check possibility to ignore the error/warning # check possibility to ignore the error/warning
from diffusers.loaders.peft import logger
lora_config = self.get_lora_config(8, 8, target_modules=["to_q"]) lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device) model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config) model.add_adapter(lora_config)
with warnings.catch_warnings(record=True) as w: # note: assertNoLogs requires Python 3.10+
warnings.simplefilter("always") # Capture all warnings with self.assertNoLogs(logger, level="WARNING"):
model.enable_lora_hotswap(target_rank=32, check_compiled="warn") model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
self.assertEqual(len(w), 0, f"Expected no warnings, but got: {[str(warn.message) for warn in w]}")
def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self): def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
# check that wrong argument value raises an error # check that wrong argument value raises an error
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment