renzhc / diffusers_dcu / Commits

Commit 59aefe9e (unverified), authored Jun 16, 2023 by Will Berman, committed by GitHub on Jun 16, 2023

device map legacy attention block weight conversion (#3804)

Parent: 3ddc2b73
Showing 3 changed files with 137 additions and 10 deletions:

- src/diffusers/models/attention_processor.py (+1, -0)
- src/diffusers/models/modeling_utils.py (+92, -10)
- tests/models/test_attention_processor.py (+44, -0)
src/diffusers/models/attention_processor.py

@@ -78,6 +78,7 @@ class Attention(nn.Module):
        self.upcast_softmax = upcast_softmax
        self.rescale_output_factor = rescale_output_factor
        self.residual_connection = residual_connection
        self.dropout = dropout

        # we make use of this private variable to know whether this class is loaded
        # with a deprecated state dict so that we can convert it on the fly
...
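For reference, a minimal sketch (toy values, not the real class) of why the block now also keeps the raw dropout probability: the undo step added in modeling_utils.py below rebuilds to_out as an nn.ModuleList of the output projection plus a fresh nn.Dropout, which needs the float rather than only the existing Dropout submodule.

from torch import nn

# Hypothetical stand-in values; the real Attention block stores `self.dropout = dropout`
# so this reconstruction is possible when undoing the temporary rename.
proj_attn = nn.Linear(4, 4)
dropout = 0.1
to_out = nn.ModuleList([proj_attn, nn.Dropout(dropout)])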
src/diffusers/models/modeling_utils.py

@@ -22,7 +22,7 @@ from functools import partial
from typing import Any, Callable, List, Optional, Tuple, Union

import torch
-from torch import Tensor, device
+from torch import Tensor, device, nn

from .. import __version__
from ..utils import (
...
@@ -646,6 +646,35 @@ class ModelMixin(torch.nn.Module):
            else:  # else let accelerate handle loading and dispatching.
                # Load weights and dispatch according to the device_map
                # by default the device_map is None and the weights are loaded on the CPU
                try:
                    accelerate.load_checkpoint_and_dispatch(
                        model,
                        model_file,
                        device_map,
                        max_memory=max_memory,
                        offload_folder=offload_folder,
                        offload_state_dict=offload_state_dict,
                        dtype=torch_dtype,
                    )
                except AttributeError as e:
                    # When using accelerate loading, we do not have the ability to load the state
                    # dict and rename the weight names manually. Additionally, accelerate skips
                    # torch loading conventions and directly writes into `module.{_buffers, _parameters}`
                    # (which look like they should be private variables?), so we can't use the standard hooks
                    # to rename parameters on load. We need to mimic the original weight names so the correct
                    # attributes are available. After we have loaded the weights, we convert the deprecated
                    # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
                    # the weights so we don't have to do this again.

                    if "'Attention' object has no attribute" in str(e):
                        logger.warn(
                            f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
                            " was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
                            " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
                            " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
                            " please also re-upload it or open a PR on the original repository."
                        )
                        model._temp_convert_self_to_deprecated_attention_blocks()
                        accelerate.load_checkpoint_and_dispatch(
                            model,
                            model_file,
...
@@ -655,6 +684,9 @@ class ModelMixin(torch.nn.Module):
                            offload_state_dict=offload_state_dict,
                            dtype=torch_dtype,
                        )
                        model._undo_temp_convert_self_to_deprecated_attention_blocks()
                    else:
                        raise e

        loading_info = {
            "missing_keys": [],
...
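To make the try/except above concrete, here is a minimal sketch of the temporary aliasing idea using plain torch and load_state_dict rather than accelerate, with toy names: expose the new `to_q` submodule under the deprecated `query` name so a checkpoint saved with the old key can be loaded, then restore the new name afterwards.

import torch
from torch import nn


class ToyAttention(nn.Module):
    # stand-in for the real Attention block: new attribute name, old checkpoint keys
    def __init__(self):
        super().__init__()
        self.to_q = nn.Linear(4, 4)


old_state_dict = {"query.weight": torch.randn(4, 4), "query.bias": torch.randn(4)}
block = ToyAttention()

# Loading old-style keys directly fails: the module only knows `to_q`.
try:
    block.load_state_dict(old_state_dict)
except RuntimeError as err:
    print("direct load fails:", type(err).__name__)

# Temporarily expose the submodule under its deprecated name, load, then undo,
# mirroring _temp_convert_self_to_deprecated_attention_blocks and its _undo_ counterpart.
block.query = block.to_q
del block.to_q
block.load_state_dict(old_state_dict)
block.to_q = block.query
del block.query

print(block.to_q.weight.shape)  # the deprecated weights now live on the new attribute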
@@ -889,3 +921,53 @@ class ModelMixin(torch.nn.Module):
            state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")

            if f"{path}.proj_attn.bias" in state_dict:
                state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")

    def _temp_convert_self_to_deprecated_attention_blocks(self):
        deprecated_attention_block_modules = []

        def recursive_find_attn_block(module):
            if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
                deprecated_attention_block_modules.append(module)

            for sub_module in module.children():
                recursive_find_attn_block(sub_module)

        recursive_find_attn_block(self)

        for module in deprecated_attention_block_modules:
            module.query = module.to_q
            module.key = module.to_k
            module.value = module.to_v
            module.proj_attn = module.to_out[0]

            # We don't _have_ to delete the old attributes, but it's helpful to ensure
            # that _all_ the weights are loaded into the new attributes and we're not
            # making an incorrect assumption that this model should be converted when
            # it really shouldn't be.
            del module.to_q
            del module.to_k
            del module.to_v
            del module.to_out

    def _undo_temp_convert_self_to_deprecated_attention_blocks(self):
        deprecated_attention_block_modules = []

        def recursive_find_attn_block(module):
            if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
                deprecated_attention_block_modules.append(module)

            for sub_module in module.children():
                recursive_find_attn_block(sub_module)

        recursive_find_attn_block(self)

        for module in deprecated_attention_block_modules:
            module.to_q = module.query
            module.to_k = module.key
            module.to_v = module.value
            module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)])

            del module.query
            del module.key
            del module.value
            del module.proj_attn
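As the warning above encourages, a checkpoint that still uses the deprecated attention block names can be converted once and re-saved so the on-the-fly renaming is no longer needed. A minimal sketch of that workflow (the checkpoint id and output path are placeholders, borrowed from the test below):

from diffusers import DiffusionPipeline

# Load a checkpoint with deprecated attention block weight names; with a
# device_map, the conversion above runs on the fly and emits the warning.
pipe = DiffusionPipeline.from_pretrained(
    "hf-internal-testing/tiny-stable-diffusion-pipe",  # placeholder checkpoint
    device_map="sequential",
    safety_checker=None,
)

# Re-save: the state dict is written with the new to_q/to_k/to_v/to_out.0 names.
pipe.save_pretrained("./converted-checkpoint")  # hypothetical local path

# Later loads use the converted weights directly, with no renaming or warning.
pipe = DiffusionPipeline.from_pretrained(
    "./converted-checkpoint", device_map="sequential", safety_checker=None
)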
tests/models/test_attention_processor.py

import tempfile
import unittest

import numpy as np
import torch

from diffusers import DiffusionPipeline
from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor
...
@@ -73,3 +76,44 @@ class AttnAddedKVProcessorTests(unittest.TestCase):
        only_cross_attn_out = attn(**forward_args)

        self.assertTrue((only_cross_attn_out != self_and_cross_attn_out).all())


class DeprecatedAttentionBlockTests(unittest.TestCase):
    def test_conversion_when_using_device_map(self):
        pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None)

        pre_conversion = pipe(
            "foo",
            num_inference_steps=2,
            generator=torch.Generator("cpu").manual_seed(0),
            output_type="np",
        ).images

        # the initial conversion succeeds
        pipe = DiffusionPipeline.from_pretrained(
            "hf-internal-testing/tiny-stable-diffusion-pipe", device_map="sequential", safety_checker=None
        )

        conversion = pipe(
            "foo",
            num_inference_steps=2,
            generator=torch.Generator("cpu").manual_seed(0),
            output_type="np",
        ).images

        with tempfile.TemporaryDirectory() as tmpdir:
            # save the converted model
            pipe.save_pretrained(tmpdir)

            # can also load the converted weights
            pipe = DiffusionPipeline.from_pretrained(tmpdir, device_map="sequential", safety_checker=None)

            after_conversion = pipe(
                "foo",
                num_inference_steps=2,
                generator=torch.Generator("cpu").manual_seed(0),
                output_type="np",
            ).images

        self.assertTrue(np.allclose(pre_conversion, conversion))
        self.assertTrue(np.allclose(conversion, after_conversion))