"tests/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "ff8f7082a3ec1d71dd55512c5528493f9e2129f5"
Unverified Commit 7acf8345 authored by Sayak Paul, committed by GitHub

[Tests] Enable more general testing for `torch.compile()` with LoRA hotswapping (#11322)

* refactor hotswap tester.

* fix seeds.

* add to nightly ci.

* move comment.

* move to nightly
parent 599c8871
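For context, here is a minimal sketch of the hotswap-under-compile flow these tests exercise. The checkpoint id, adapter file paths, and target rank are placeholders; the concrete calls are the ones visible in the test diff that follows.

    import torch
    from diffusers import UNet2DConditionModel

    # Placeholder checkpoint; any diffusers model with PEFT LoRA support follows the same flow.
    model = UNet2DConditionModel.from_pretrained("your/checkpoint", subfolder="unet").to("cuda")

    # Reserve room for the largest LoRA rank so later swaps fit without recompilation.
    model.enable_lora_hotswap(target_rank=13)

    # Load the first adapter, then compile once.
    model.load_lora_adapter("adapter0/pytorch_lora_weights.safetensors", adapter_name="adapter0", prefix=None)
    model = torch.compile(model, mode="reduce-overhead")

    # Hotswap the second adapter into the same slot; with error_on_recompile=True any
    # recompilation raises, which is exactly what the tests below guard against.
    with torch._dynamo.config.patch(error_on_recompile=True):
        model.load_lora_adapter("adapter1/pytorch_lora_weights.safetensors", adapter_name="adapter0", hotswap=True, prefix=None)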
@@ -142,6 +142,7 @@ jobs:
       HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
       # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
       CUBLAS_WORKSPACE_CONFIG: :16:8
+      RUN_COMPILE: yes
     run: |
       python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
         -s -v -k "not Flax and not Onnx" \
...
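The new RUN_COMPILE variable is what lets the compile-gated tests actually run in this nightly job. A rough sketch of how such an environment flag typically gates a test suite (analogous to RUN_SLOW; the exact helper names in diffusers.utils.testing_utils may differ, so treat this as an assumption rather than the library's code):

    import os
    import unittest

    def parse_flag_from_env(key, default=False):
        # Interpret "yes"/"true"/"1" (any case) as enabled.
        return os.environ.get(key, str(default)).lower() in ("yes", "true", "t", "1")

    _run_compile_tests = parse_flag_from_env("RUN_COMPILE", default=False)

    def is_torch_compile(test_case):
        # Skip the decorated test class unless RUN_COMPILE=yes is set, as in the job above.
        return unittest.skipUnless(_run_compile_tests, "test requires RUN_COMPILE=yes")(test_case)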
@@ -62,7 +62,6 @@ from diffusers.utils.testing_utils import (
    backend_max_memory_allocated,
    backend_reset_peak_memory_stats,
    backend_synchronize,
-    floats_tensor,
    get_python_version,
    is_torch_compile,
    numpy_cosine_similarity_distance,
@@ -1754,7 +1753,7 @@ class TorchCompileTesterMixin:
@require_peft_backend
@require_peft_version_greater("0.14.0")
@is_torch_compile
-class TestLoraHotSwappingForModel(unittest.TestCase):
+class LoraHotSwappingForModelTesterMixin:
    """Test that hotswapping does not result in recompilation on the model directly.

    We're not extensively testing the hotswapping functionality since it is implemented in PEFT and is extensively
@@ -1775,48 +1774,24 @@ class TestLoraHotSwappingForModel(unittest.TestCase):
        gc.collect()
        backend_empty_cache(torch_device)

-    def get_small_unet(self):
-        # from diffusers UNet2DConditionModelTests
-        torch.manual_seed(0)
-        init_dict = {
-            "block_out_channels": (4, 8),
-            "norm_num_groups": 4,
-            "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
-            "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
-            "cross_attention_dim": 8,
-            "attention_head_dim": 2,
-            "out_channels": 4,
-            "in_channels": 4,
-            "layers_per_block": 1,
-            "sample_size": 16,
-        }
-        model = UNet2DConditionModel(**init_dict)
-        return model.to(torch_device)
-
-    def get_unet_lora_config(self, lora_rank, lora_alpha, target_modules):
+    def get_lora_config(self, lora_rank, lora_alpha, target_modules):
        # from diffusers test_models_unet_2d_condition.py
        from peft import LoraConfig

-        unet_lora_config = LoraConfig(
+        lora_config = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_alpha,
            target_modules=target_modules,
            init_lora_weights=False,
            use_dora=False,
        )
-        return unet_lora_config
+        return lora_config

-    def get_dummy_input(self):
-        # from UNet2DConditionModelTests
-        batch_size = 4
-        num_channels = 4
-        sizes = (16, 16)
-        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
-        time_step = torch.tensor([10]).to(torch_device)
-        encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
-        return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
+    def get_linear_module_name_other_than_attn(self, model):
+        linear_names = [
+            name for name, module in model.named_modules() if isinstance(module, nn.Linear) and "to_" not in name
+        ]
+        return linear_names[0]

    def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_modules1=None):
        """
@@ -1834,23 +1809,27 @@ class TestLoraHotSwappingForModel(unittest.TestCase):
        fine.
        """
        # create 2 adapters with different ranks and alphas
-        dummy_input = self.get_dummy_input()
+        torch.manual_seed(0)
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
        alpha0, alpha1 = rank0, rank1
        max_rank = max([rank0, rank1])
        if target_modules1 is None:
            target_modules1 = target_modules0[:]
-        lora_config0 = self.get_unet_lora_config(rank0, alpha0, target_modules0)
-        lora_config1 = self.get_unet_lora_config(rank1, alpha1, target_modules1)
+        lora_config0 = self.get_lora_config(rank0, alpha0, target_modules0)
+        lora_config1 = self.get_lora_config(rank1, alpha1, target_modules1)

-        unet = self.get_small_unet()
-        unet.add_adapter(lora_config0, adapter_name="adapter0")
+        model.add_adapter(lora_config0, adapter_name="adapter0")
        with torch.inference_mode():
-            output0_before = unet(**dummy_input)["sample"]
+            torch.manual_seed(0)
+            output0_before = model(**inputs_dict)["sample"]

-        unet.add_adapter(lora_config1, adapter_name="adapter1")
-        unet.set_adapter("adapter1")
+        model.add_adapter(lora_config1, adapter_name="adapter1")
+        model.set_adapter("adapter1")
        with torch.inference_mode():
-            output1_before = unet(**dummy_input)["sample"]
+            torch.manual_seed(0)
+            output1_before = model(**inputs_dict)["sample"]

        # sanity checks:
        tol = 5e-3
@@ -1860,40 +1839,43 @@ class TestLoraHotSwappingForModel(unittest.TestCase):
        with tempfile.TemporaryDirectory() as tmp_dirname:
            # save the adapter checkpoints
-            unet.save_lora_adapter(os.path.join(tmp_dirname, "0"), safe_serialization=True, adapter_name="adapter0")
-            unet.save_lora_adapter(os.path.join(tmp_dirname, "1"), safe_serialization=True, adapter_name="adapter1")
-            del unet
+            model.save_lora_adapter(os.path.join(tmp_dirname, "0"), safe_serialization=True, adapter_name="adapter0")
+            model.save_lora_adapter(os.path.join(tmp_dirname, "1"), safe_serialization=True, adapter_name="adapter1")
+            del model

            # load the first adapter
-            unet = self.get_small_unet()
+            torch.manual_seed(0)
+            init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+            model = self.model_class(**init_dict).to(torch_device)
+
            if do_compile or (rank0 != rank1):
                # no need to prepare if the model is not compiled or if the ranks are identical
-                unet.enable_lora_hotswap(target_rank=max_rank)
+                model.enable_lora_hotswap(target_rank=max_rank)

            file_name0 = os.path.join(os.path.join(tmp_dirname, "0"), "pytorch_lora_weights.safetensors")
            file_name1 = os.path.join(os.path.join(tmp_dirname, "1"), "pytorch_lora_weights.safetensors")
-            unet.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)
+            model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)

            if do_compile:
-                unet = torch.compile(unet, mode="reduce-overhead")
+                model = torch.compile(model, mode="reduce-overhead")

            with torch.inference_mode():
-                output0_after = unet(**dummy_input)["sample"]
+                output0_after = model(**inputs_dict)["sample"]
            assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)

            # hotswap the 2nd adapter
-            unet.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)
+            model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)

            # we need to call forward to potentially trigger recompilation
            with torch.inference_mode():
-                output1_after = unet(**dummy_input)["sample"]
+                output1_after = model(**inputs_dict)["sample"]
            assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)

            # check error when not passing valid adapter name
            name = "does-not-exist"
            msg = f"Trying to hotswap LoRA adapter '{name}' but there is no existing adapter by that name"
            with self.assertRaisesRegex(ValueError, msg):
-                unet.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None)
+                model.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None)

    @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
    def test_hotswapping_model(self, rank0, rank1):
@@ -1910,6 +1892,9 @@ class TestLoraHotSwappingForModel(unittest.TestCase):
    @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
    def test_hotswapping_compiled_model_conv2d(self, rank0, rank1):
+        if "unet" not in self.model_class.__name__.lower():
+            return
+
        # It's important to add this context to raise an error on recompilation
        target_modules = ["conv", "conv1", "conv2"]
        with torch._dynamo.config.patch(error_on_recompile=True):
@@ -1917,52 +1902,77 @@ class TestLoraHotSwappingForModel(unittest.TestCase):
    @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
    def test_hotswapping_compiled_model_both_linear_and_conv2d(self, rank0, rank1):
+        if "unet" not in self.model_class.__name__.lower():
+            return
+
        # It's important to add this context to raise an error on recompilation
        target_modules = ["to_q", "conv"]
        with torch._dynamo.config.patch(error_on_recompile=True):
            self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)

+    @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
+    def test_hotswapping_compiled_model_both_linear_and_other(self, rank0, rank1):
+        # In `test_hotswapping_compiled_model_both_linear_and_conv2d()`, we check if we can do hotswapping
+        # with `torch.compile()` for models that have both linear and conv layers. In this test, we check
+        # if we can target a linear layer from the transformer blocks and another linear layer from non-attention
+        # block.
+        target_modules = ["to_q"]
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+
+        target_modules.append(self.get_linear_module_name_other_than_attn(model))
+        del model
+
+        # It's important to add this context to raise an error on recompilation
+        with torch._dynamo.config.patch(error_on_recompile=True):
+            self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
+
    def test_enable_lora_hotswap_called_after_adapter_added_raises(self):
        # ensure that enable_lora_hotswap is called before loading the first adapter
-        lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
-        unet = self.get_small_unet()
-        unet.add_adapter(lora_config)
+        lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+        model.add_adapter(lora_config)

        msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.")
        with self.assertRaisesRegex(RuntimeError, msg):
-            unet.enable_lora_hotswap(target_rank=32)
+            model.enable_lora_hotswap(target_rank=32)

    def test_enable_lora_hotswap_called_after_adapter_added_warning(self):
        # ensure that enable_lora_hotswap is called before loading the first adapter
        from diffusers.loaders.peft import logger

-        lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
-        unet = self.get_small_unet()
-        unet.add_adapter(lora_config)
+        lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+        model.add_adapter(lora_config)

        msg = (
            "It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
        )
        with self.assertLogs(logger=logger, level="WARNING") as cm:
-            unet.enable_lora_hotswap(target_rank=32, check_compiled="warn")
+            model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
            assert any(msg in log for log in cm.output)

    def test_enable_lora_hotswap_called_after_adapter_added_ignore(self):
        # check possibility to ignore the error/warning
-        lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
-        unet = self.get_small_unet()
-        unet.add_adapter(lora_config)
+        lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+        model.add_adapter(lora_config)

        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")  # Capture all warnings
-            unet.enable_lora_hotswap(target_rank=32, check_compiled="warn")
+            model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
            self.assertEqual(len(w), 0, f"Expected no warnings, but got: {[str(warn.message) for warn in w]}")

    def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
        # check that wrong argument value raises an error
-        lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
-        unet = self.get_small_unet()
-        unet.add_adapter(lora_config)
+        lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+        model.add_adapter(lora_config)

        msg = re.escape("check_compiles should be one of 'error', 'warn', or 'ignore', got 'wrong-argument' instead.")
        with self.assertRaisesRegex(ValueError, msg):
-            unet.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")
+            model.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")

    def test_hotswap_second_adapter_targets_more_layers_raises(self):
        # check the error and log
...
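The mixin above stays model-agnostic because it leans on prepare_init_args_and_inputs_for_common(), which every ModelTesterMixin-based test class already defines. As a rough illustration only (an assumed host class; for the UNet case the values mirror the removed get_small_unet() and get_dummy_input() helpers shown earlier), such a method looks like this:

    import unittest

    import torch

    from diffusers import UNet2DConditionModel
    from diffusers.utils.testing_utils import floats_tensor, torch_device

    class ExampleUNetTests(unittest.TestCase):  # illustrative host class, not the real one
        model_class = UNet2DConditionModel

        def prepare_init_args_and_inputs_for_common(self):
            # Tiny UNet config, mirroring the values from the removed get_small_unet().
            init_dict = {
                "block_out_channels": (4, 8),
                "norm_num_groups": 4,
                "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
                "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
                "cross_attention_dim": 8,
                "attention_head_dim": 2,
                "out_channels": 4,
                "in_channels": 4,
                "layers_per_block": 1,
                "sample_size": 16,
            }
            # Inputs mirroring the removed get_dummy_input().
            inputs_dict = {
                "sample": floats_tensor((4, 4, 16, 16)).to(torch_device),
                "timestep": torch.tensor([10]).to(torch_device),
                "encoder_hidden_states": floats_tensor((4, 4, 8)).to(torch_device),
            }
            return init_dict, inputs_dict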
@@ -22,7 +22,7 @@ from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor
from diffusers.models.embeddings import ImageProjection
from diffusers.utils.testing_utils import enable_full_determinism, torch_device

-from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
+from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin

enable_full_determinism()
@@ -78,7 +78,9 @@ def create_flux_ip_adapter_state_dict(model):
    return ip_state_dict

-class FluxTransformerTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
+class FluxTransformerTests(
+    ModelTesterMixin, TorchCompileTesterMixin, LoraHotSwappingForModelTesterMixin, unittest.TestCase
+):
    model_class = FluxTransformer2DModel
    main_input_name = "hidden_states"
    # We override the items here because the transformer under consideration is small.
...
@@ -53,7 +53,7 @@ from diffusers.utils.testing_utils import (
    torch_device,
)

-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, UNetTesterMixin

if is_peft_available():
@@ -350,7 +350,9 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True):
    return custom_diffusion_attn_procs

-class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class UNet2DConditionModelTests(
+    ModelTesterMixin, LoraHotSwappingForModelTesterMixin, UNetTesterMixin, unittest.TestCase
+):
    model_class = UNet2DConditionModel
    main_input_name = "sample"
    # We override the items here because the unet under consideration is small.
...