OpenDAS / diffusers · Commits

Commit 59aefe9e (unverified)
device map legacy attention block weight conversion (#3804)
Authored Jun 16, 2023 by Will Berman; committed by GitHub on Jun 16, 2023
Parent: 3ddc2b73

Showing 3 changed files with 137 additions and 10 deletions (+137, -10)
Files changed:
  src/diffusers/models/attention_processor.py    +1   -0
  src/diffusers/models/modeling_utils.py          +92  -10
  tests/models/test_attention_processor.py        +44  -0
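The commit targets checkpoints whose attention blocks were saved under the deprecated weight names (query / key / value / proj_attn) rather than the current to_q / to_k / to_v / to_out.0 names. As an orienting sketch only, the renaming implied by the attribute swaps and the proj_attn key moves in modeling_utils.py below amounts to the following; the helper name and the example path prefix are illustrative and not part of the diff.

# Illustrative sketch, not code from this commit: the old-to-new attention
# weight key mapping implied by the conversion code below. `path` is the module
# prefix of a given attention block (for example "mid_block.attentions.0").
def convert_attn_keys(state_dict, path):
    renames = {
        f"{path}.query.weight": f"{path}.to_q.weight",
        f"{path}.query.bias": f"{path}.to_q.bias",
        f"{path}.key.weight": f"{path}.to_k.weight",
        f"{path}.key.bias": f"{path}.to_k.bias",
        f"{path}.value.weight": f"{path}.to_v.weight",
        f"{path}.value.bias": f"{path}.to_v.bias",
        f"{path}.proj_attn.weight": f"{path}.to_out.0.weight",
        f"{path}.proj_attn.bias": f"{path}.to_out.0.bias",
    }
    for old, new in renames.items():
        if old in state_dict:
            state_dict[new] = state_dict.pop(old)
    return state_dict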
src/diffusers/models/attention_processor.py

@@ -78,6 +78,7 @@ class Attention(nn.Module):
         self.upcast_softmax = upcast_softmax
         self.rescale_output_factor = rescale_output_factor
         self.residual_connection = residual_connection
+        self.dropout = dropout

         # we make use of this private variable to know whether this class is loaded
         # with an deprecated state dict so that we can convert it on the fly
src/diffusers/models/modeling_utils.py

@@ -22,7 +22,7 @@ from functools import partial
 from typing import Any, Callable, List, Optional, Tuple, Union

 import torch
-from torch import Tensor, device
+from torch import Tensor, device, nn

 from .. import __version__
 from ..utils import (
@@ -646,15 +646,47 @@ class ModelMixin(torch.nn.Module):
         else:  # else let accelerate handle loading and dispatching.
             # Load weights and dispatch according to the device_map
             # by default the device_map is None and the weights are loaded on the CPU
-            accelerate.load_checkpoint_and_dispatch(
-                model,
-                model_file,
-                device_map,
-                max_memory=max_memory,
-                offload_folder=offload_folder,
-                offload_state_dict=offload_state_dict,
-                dtype=torch_dtype,
-            )
+            try:
+                accelerate.load_checkpoint_and_dispatch(
+                    model,
+                    model_file,
+                    device_map,
+                    max_memory=max_memory,
+                    offload_folder=offload_folder,
+                    offload_state_dict=offload_state_dict,
+                    dtype=torch_dtype,
+                )
+            except AttributeError as e:
+                # When using accelerate loading, we do not have the ability to load the state
+                # dict and rename the weight names manually. Additionally, accelerate skips
+                # torch loading conventions and directly writes into `module.{_buffers, _parameters}`
+                # (which look like they should be private variables?), so we can't use the standard hooks
+                # to rename parameters on load. We need to mimic the original weight names so the correct
+                # attributes are available. After we have loaded the weights, we convert the deprecated
+                # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
+                # the weights so we don't have to do this again.
+
+                if "'Attention' object has no attribute" in str(e):
+                    logger.warn(
+                        f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
+                        " was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
+                        " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
+                        " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
+                        " please also re-upload it or open a PR on the original repository."
+                    )
+                    model._temp_convert_self_to_deprecated_attention_blocks()
+                    accelerate.load_checkpoint_and_dispatch(
+                        model,
+                        model_file,
+                        device_map,
+                        max_memory=max_memory,
+                        offload_folder=offload_folder,
+                        offload_state_dict=offload_state_dict,
+                        dtype=torch_dtype,
+                    )
+                    model._undo_temp_convert_self_to_deprecated_attention_blocks()
+                else:
+                    raise e

         loading_info = {
             "missing_keys": [],
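The comment in the new except branch explains the mechanism: accelerate writes loaded tensors straight into module attributes rather than going through the usual load_state_dict hooks, so the only way to accept a checkpoint saved with the deprecated attention names is to temporarily expose the sub-modules under those names, load, and then rename back. The warning then asks the user to re-save the checkpoint so the conversion only happens once. A minimal usage sketch of that re-save step, assuming a local output folder name and reusing the tiny test pipeline and device_map value from the new test below:

from diffusers import DiffusionPipeline

# Passing a device_map routes loading through accelerate.load_checkpoint_and_dispatch,
# which is the code path patched in this commit. If the checkpoint still uses the
# deprecated attention names, the warning above is emitted and the weights are
# converted on the fly.
pipe = DiffusionPipeline.from_pretrained(
    "hf-internal-testing/tiny-stable-diffusion-pipe", device_map="sequential", safety_checker=None
)

# Re-saving persists the converted (to_q/to_k/to_v/to_out.0) names, so future loads
# skip the on-the-fly renaming. "converted-checkpoint" is just an example path.
pipe.save_pretrained("converted-checkpoint")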
@@ -889,3 +921,53 @@ class ModelMixin(torch.nn.Module):
             state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
             if f"{path}.proj_attn.bias" in state_dict:
                 state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
+
+    def _temp_convert_self_to_deprecated_attention_blocks(self):
+        deprecated_attention_block_modules = []
+
+        def recursive_find_attn_block(module):
+            if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
+                deprecated_attention_block_modules.append(module)
+
+            for sub_module in module.children():
+                recursive_find_attn_block(sub_module)
+
+        recursive_find_attn_block(self)
+
+        for module in deprecated_attention_block_modules:
+            module.query = module.to_q
+            module.key = module.to_k
+            module.value = module.to_v
+            module.proj_attn = module.to_out[0]
+
+            # We don't _have_ to delete the old attributes, but it's helpful to ensure
+            # that _all_ the weights are loaded into the new attributes and we're not
+            # making an incorrect assumption that this model should be converted when
+            # it really shouldn't be.
+            del module.to_q
+            del module.to_k
+            del module.to_v
+            del module.to_out
+
+    def _undo_temp_convert_self_to_deprecated_attention_blocks(self):
+        deprecated_attention_block_modules = []
+
+        def recursive_find_attn_block(module):
+            if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
+                deprecated_attention_block_modules.append(module)
+
+            for sub_module in module.children():
+                recursive_find_attn_block(sub_module)
+
+        recursive_find_attn_block(self)
+
+        for module in deprecated_attention_block_modules:
+            module.to_q = module.query
+            module.to_k = module.key
+            module.to_v = module.value
+            module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)])
+
+            del module.query
+            del module.key
+            del module.value
+            del module.proj_attn
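The two helpers implement a temporary rename: every module flagged with _from_deprecated_attn_block has its to_q / to_k / to_v / to_out[0] sub-modules re-exposed under the old query / key / value / proj_attn names for the duration of the load, then renamed back. Below is a self-contained toy illustration of that trick, not the diffusers implementation; it uses load_state_dict in place of accelerate's direct attribute writes, but the aliasing idea is the same.

import torch
import torch.nn as nn

class TinyAttn(nn.Module):
    def __init__(self):
        super().__init__()
        self.to_q = nn.Linear(4, 4)                # new-style attribute name
        self._from_deprecated_attn_block = True    # same flag the helpers above check

attn = TinyAttn()

# temporarily expose the projection under the deprecated name
attn.query = attn.to_q
del attn.to_q

# a loader that only knows the old names can now populate the weights
old_state = {"query.weight": torch.ones(4, 4), "query.bias": torch.zeros(4)}
attn.load_state_dict(old_state)

# undo the rename: the loaded weights are now reachable as to_q again
attn.to_q = attn.query
del attn.query
assert torch.equal(attn.to_q.weight, torch.ones(4, 4))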
tests/models/test_attention_processor.py

import tempfile
import unittest

import numpy as np
import torch

from diffusers import DiffusionPipeline
from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor
@@ -73,3 +76,44 @@ class AttnAddedKVProcessorTests(unittest.TestCase):
         only_cross_attn_out = attn(**forward_args)

         self.assertTrue((only_cross_attn_out != self_and_cross_attn_out).all())
+
+
+class DeprecatedAttentionBlockTests(unittest.TestCase):
+    def test_conversion_when_using_device_map(self):
+        pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None)
+
+        pre_conversion = pipe(
+            "foo",
+            num_inference_steps=2,
+            generator=torch.Generator("cpu").manual_seed(0),
+            output_type="np",
+        ).images
+
+        # the initial conversion succeeds
+        pipe = DiffusionPipeline.from_pretrained(
+            "hf-internal-testing/tiny-stable-diffusion-pipe", device_map="sequential", safety_checker=None
+        )
+
+        conversion = pipe(
+            "foo",
+            num_inference_steps=2,
+            generator=torch.Generator("cpu").manual_seed(0),
+            output_type="np",
+        ).images
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # save the converted model
+            pipe.save_pretrained(tmpdir)
+
+            # can also load the converted weights
+            pipe = DiffusionPipeline.from_pretrained(tmpdir, device_map="sequential", safety_checker=None)
+
+            after_conversion = pipe(
+                "foo",
+                num_inference_steps=2,
+                generator=torch.Generator("cpu").manual_seed(0),
+                output_type="np",
+            ).images
+
+        self.assertTrue(np.allclose(pre_conversion, conversion))
+        self.assertTrue(np.allclose(conversion, after_conversion))
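As a follow-up to the save/reload round trip exercised by the test, one way to confirm that a re-saved checkpoint really uses the new attention block names is to scan its keys. This is a hedged sketch: the weight file path below is an assumption (a typical diffusers model folder with a safetensors weight file), not something fixed by this commit.

import safetensors.torch

# Path is illustrative; point it at the model weights inside the re-saved folder.
state_dict = safetensors.torch.load_file("converted-checkpoint/unet/diffusion_pytorch_model.safetensors")

deprecated = [k for k in state_dict if any(p in k for p in (".query.", ".key.", ".value.", ".proj_attn."))]
assert not deprecated, f"still contains deprecated attention keys: {deprecated}"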