Unverified commit 59aefe9e, authored by Will Berman and committed by GitHub

device map legacy attention block weight conversion (#3804)

parent 3ddc2b73
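
For orientation, a minimal sketch of the scenario this commit addresses, based on the warning message and the new test below (the checkpoint name is the tiny test fixture used in the test; the local save path is illustrative):

from diffusers import DiffusionPipeline

# Loading a checkpoint that still uses the legacy attention block weight names
# together with a device_map previously failed with an AttributeError inside
# accelerate.load_checkpoint_and_dispatch. After this change the weights are
# renamed on the fly and a warning asks the user to re-save the model.
pipe = DiffusionPipeline.from_pretrained(
    "hf-internal-testing/tiny-stable-diffusion-pipe",
    device_map="sequential",
    safety_checker=None,
)

# Re-saving persists the converted (non-deprecated) weight names, so the
# on-the-fly conversion is no longer needed on subsequent loads.
pipe.save_pretrained("./converted-checkpoint")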
@@ -78,6 +78,7 @@ class Attention(nn.Module):
         self.upcast_softmax = upcast_softmax
         self.rescale_output_factor = rescale_output_factor
         self.residual_connection = residual_connection
+        self.dropout = dropout

         # we make use of this private variable to know whether this class is loaded
         # with an deprecated state dict so that we can convert it on the fly
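
The new `self.dropout = dropout` line keeps the raw dropout probability around on the module. A small illustrative snippet (not code from this diff) of why that matters for the undo helper added further down, which rebuilds `to_out` after loading with the deprecated names:

from torch import nn

# The temporary rename deletes `to_out` (a ModuleList of [projection, dropout]),
# so the stored float is what lets the undo step reconstruct it afterwards,
# mirroring `module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)])`.
dropout = 0.0  # illustrative value
proj_attn = nn.Linear(32, 32)  # illustrative projection
to_out = nn.ModuleList([proj_attn, nn.Dropout(dropout)])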
@@ -22,7 +22,7 @@ from functools import partial
 from typing import Any, Callable, List, Optional, Tuple, Union

 import torch
-from torch import Tensor, device
+from torch import Tensor, device, nn

 from .. import __version__
 from ..utils import (
@@ -646,15 +646,47 @@ class ModelMixin(torch.nn.Module):
             else:  # else let accelerate handle loading and dispatching.
                 # Load weights and dispatch according to the device_map
                 # by default the device_map is None and the weights are loaded on the CPU
-                accelerate.load_checkpoint_and_dispatch(
-                    model,
-                    model_file,
-                    device_map,
-                    max_memory=max_memory,
-                    offload_folder=offload_folder,
-                    offload_state_dict=offload_state_dict,
-                    dtype=torch_dtype,
-                )
+                try:
+                    accelerate.load_checkpoint_and_dispatch(
+                        model,
+                        model_file,
+                        device_map,
+                        max_memory=max_memory,
+                        offload_folder=offload_folder,
+                        offload_state_dict=offload_state_dict,
+                        dtype=torch_dtype,
+                    )
+                except AttributeError as e:
+                    # When using accelerate loading, we do not have the ability to load the state
+                    # dict and rename the weight names manually. Additionally, accelerate skips
+                    # torch loading conventions and directly writes into `module.{_buffers, _parameters}`
+                    # (which look like they should be private variables?), so we can't use the standard hooks
+                    # to rename parameters on load. We need to mimic the original weight names so the correct
+                    # attributes are available. After we have loaded the weights, we convert the deprecated
+                    # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
+                    # the weights so we don't have to do this again.
+
+                    if "'Attention' object has no attribute" in str(e):
+                        logger.warn(
+                            f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
+                            " was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
+                            " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
+                            " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
+                            " please also re-upload it or open a PR on the original repository."
+                        )
+                        model._temp_convert_self_to_deprecated_attention_blocks()
+                        accelerate.load_checkpoint_and_dispatch(
+                            model,
+                            model_file,
+                            device_map,
+                            max_memory=max_memory,
+                            offload_folder=offload_folder,
+                            offload_state_dict=offload_state_dict,
+                            dtype=torch_dtype,
+                        )
+                        model._undo_temp_convert_self_to_deprecated_attention_blocks()
+                    else:
+                        raise e

             loading_info = {
                 "missing_keys": [],
@@ -889,3 +921,53 @@ class ModelMixin(torch.nn.Module):
                 state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
             if f"{path}.proj_attn.bias" in state_dict:
                 state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
+
+    def _temp_convert_self_to_deprecated_attention_blocks(self):
+        deprecated_attention_block_modules = []
+
+        def recursive_find_attn_block(module):
+            if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
+                deprecated_attention_block_modules.append(module)
+
+            for sub_module in module.children():
+                recursive_find_attn_block(sub_module)
+
+        recursive_find_attn_block(self)
+
+        for module in deprecated_attention_block_modules:
+            module.query = module.to_q
+            module.key = module.to_k
+            module.value = module.to_v
+            module.proj_attn = module.to_out[0]
+
+            # We don't _have_ to delete the old attributes, but it's helpful to ensure
+            # that _all_ the weights are loaded into the new attributes and we're not
+            # making an incorrect assumption that this model should be converted when
+            # it really shouldn't be.
+            del module.to_q
+            del module.to_k
+            del module.to_v
+            del module.to_out
+
+    def _undo_temp_convert_self_to_deprecated_attention_blocks(self):
+        deprecated_attention_block_modules = []
+
+        def recursive_find_attn_block(module):
+            if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
+                deprecated_attention_block_modules.append(module)
+
+            for sub_module in module.children():
+                recursive_find_attn_block(sub_module)
+
+        recursive_find_attn_block(self)
+
+        for module in deprecated_attention_block_modules:
+            module.to_q = module.query
+            module.to_k = module.key
+            module.to_v = module.value
+            module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)])
+
+            del module.query
+            del module.key
+            del module.value
+            del module.proj_attn
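
For quick reference, the attribute mapping that these two helpers toggle between, read off the renames in this diff (the dict name is illustrative and does not exist in the codebase):

# Deprecated AttentionBlock attribute -> current Attention attribute
DEPRECATED_TO_CURRENT = {
    "query": "to_q",
    "key": "to_k",
    "value": "to_v",
    "proj_attn": "to_out.0",  # to_out is [Linear, Dropout]; index 0 is the projection
}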
+import tempfile
 import unittest

+import numpy as np
 import torch

+from diffusers import DiffusionPipeline
 from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor

@@ -73,3 +76,44 @@ class AttnAddedKVProcessorTests(unittest.TestCase):
         only_cross_attn_out = attn(**forward_args)

         self.assertTrue((only_cross_attn_out != self_and_cross_attn_out).all())
+
+
+class DeprecatedAttentionBlockTests(unittest.TestCase):
+    def test_conversion_when_using_device_map(self):
+        pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None)
+
+        pre_conversion = pipe(
+            "foo",
+            num_inference_steps=2,
+            generator=torch.Generator("cpu").manual_seed(0),
+            output_type="np",
+        ).images
+
+        # the initial conversion succeeds
+        pipe = DiffusionPipeline.from_pretrained(
+            "hf-internal-testing/tiny-stable-diffusion-pipe", device_map="sequential", safety_checker=None
+        )
+
+        conversion = pipe(
+            "foo",
+            num_inference_steps=2,
+            generator=torch.Generator("cpu").manual_seed(0),
+            output_type="np",
+        ).images
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # save the converted model
+            pipe.save_pretrained(tmpdir)
+
+            # can also load the converted weights
+            pipe = DiffusionPipeline.from_pretrained(tmpdir, device_map="sequential", safety_checker=None)
+
+        after_conversion = pipe(
+            "foo",
+            num_inference_steps=2,
+            generator=torch.Generator("cpu").manual_seed(0),
+            output_type="np",
+        ).images
+
+        self.assertTrue(np.allclose(pre_conversion, conversion))
+        self.assertTrue(np.allclose(conversion, after_conversion))