renzhc / diffusers_dcu

Commit d9316bf8 (unverified)
Authored Jul 04, 2022 by Anton Lozhkov, committed by GitHub on Jul 04, 2022
Parent: 3abf4bc4

Fix mutable proj_out weight in the Attention layer (#73)

* Catch unused params in DDP
* Fix proj_out, add test
Showing 4 changed files with 29 additions and 8 deletions (+29, -8):

examples/train_unconditional.py    +8  -2
src/diffusers/models/attention.py  +5  -5
src/diffusers/training_utils.py    +1  -1
tests/test_modeling_utils.py       +15 -0
examples/train_unconditional.py

@@ -4,7 +4,7 @@ import os
 import torch
 import torch.nn.functional as F
-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedDataParallelKwargs
 from accelerate.logging import get_logger
 from datasets import load_dataset
 from diffusers import DDIMPipeline, DDIMScheduler, UNetModel

@@ -27,8 +27,14 @@ logger = get_logger(__name__)
 def main(args):
+    ddp_unused_params = DistributedDataParallelKwargs(find_unused_parameters=True)
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
-    accelerator = Accelerator(mixed_precision=args.mixed_precision, log_with="tensorboard", logging_dir=logging_dir)
+    accelerator = Accelerator(
+        mixed_precision=args.mixed_precision,
+        log_with="tensorboard",
+        logging_dir=logging_dir,
+        kwargs_handlers=[ddp_unused_params],
+    )

     model = UNetModel(
         attn_resolutions=(16,),
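For context, the change above routes DDP options through Accelerate's kwargs handlers so that find_unused_parameters=True reaches torch's DistributedDataParallel. A minimal standalone sketch of the same pattern, separate from this training script (variable names here are illustrative, not from the repo):

from accelerate import Accelerator, DistributedDataParallelKwargs

# Tell DDP to tolerate parameters that receive no gradient in a given forward
# pass. This costs an extra traversal of the autograd graph per step, so it is
# only worth enabling when the model genuinely has unused parameters.
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

# model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)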
src/diffusers/models/attention.py

@@ -70,7 +70,7 @@ class AttentionBlock(nn.Module):
         if encoder_channels is not None:
             self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)
-        self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
+        self.proj = zero_module(nn.Conv1d(channels, channels, 1))

         self.overwrite_qkv = overwrite_qkv
         if overwrite_qkv:

@@ -108,15 +108,15 @@
             proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
             proj_out.bias.data = module.proj_out.bias.data

-            self.proj_out = proj_out
+            self.proj = proj_out
         elif self.overwrite_linear:
             self.qkv.weight.data = torch.concat(
                 [self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0
             )[:, :, None]
             self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0)

-            self.proj_out.weight.data = self.NIN_3.W.data.T[:, :, None]
-            self.proj_out.bias.data = self.NIN_3.b.data
+            self.proj.weight.data = self.NIN_3.W.data.T[:, :, None]
+            self.proj.bias.data = self.NIN_3.b.data

             self.norm.weight.data = self.GroupNorm_0.weight.data
             self.norm.bias.data = self.GroupNorm_0.bias.data

@@ -149,7 +149,7 @@ class AttentionBlock(nn.Module):
         a = torch.einsum("bts,bcs->bct", weight, v)
         h = a.reshape(bs, -1, length)
-        h = self.proj_out(h)
+        h = self.proj(h)
         h = h.reshape(b, c, *spatial)

         result = x + h
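The rename above replaces the proj_out attribute with proj, so the zero-initialized output projection created in __init__, the weight-overwriting conversion paths, and forward() all refer to one consistently named module (the commit title describes the old arrangement as a "mutable proj_out weight"). The zero_module pattern it relies on is worth spelling out; a minimal standalone illustration under that assumption, not the repo's AttentionBlock:

import torch
import torch.nn as nn


def zero_module(module: nn.Module) -> nn.Module:
    # Zero every parameter so the projected branch starts as a no-op and the
    # residual connection initially passes the input through unchanged.
    for p in module.parameters():
        p.detach().zero_()
    return module


class ResidualProjection(nn.Module):
    # Hypothetical stand-in for the attention block's output projection.
    def __init__(self, channels: int):
        super().__init__()
        self.proj = zero_module(nn.Conv1d(channels, channels, 1))

    def forward(self, x: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        # x, h: (batch, channels, length); at initialization the output equals
        # x exactly because self.proj maps everything to zero.
        return x + self.proj(h)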
src/diffusers/training_utils.py

@@ -30,7 +30,7 @@ class EMAModel:
             min_value (float): The minimum EMA decay rate. Default: 0.
         """
-        self.averaged_model = copy.deepcopy(model)
+        self.averaged_model = copy.deepcopy(model).eval()
         self.averaged_model.requires_grad_(False)

         self.update_after_step = update_after_step
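The change above puts the EMA shadow copy in eval() mode in addition to having its gradients disabled: the averaged model is only written to through in-place parameter updates and read for evaluation, never trained directly. A rough sketch of that kind of update, assuming a plain fixed decay rather than the repo's EMAModel internals:

import copy

import torch


def ema_update(model: torch.nn.Module, averaged: torch.nn.Module, decay: float = 0.999) -> None:
    # In-place exponential moving average of the live model's parameters.
    with torch.no_grad():
        for p, avg_p in zip(model.parameters(), averaged.parameters()):
            avg_p.mul_(decay).add_(p, alpha=1.0 - decay)


model = torch.nn.Linear(4, 4)
# Mirror the setup in the diff: a frozen, eval-mode copy of the live model.
averaged = copy.deepcopy(model).eval()
averaged.requires_grad_(False)

ema_update(model, averaged)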
tests/test_modeling_utils.py

@@ -52,6 +52,7 @@ from diffusers.configuration_utils import ConfigMixin
 from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.pipelines.bddm.pipeline_bddm import DiffWave
 from diffusers.testing_utils import floats_tensor, slow, torch_device
+from diffusers.training_utils import EMAModel


 torch.backends.cuda.matmul.allow_tf32 = False

@@ -197,6 +198,20 @@ class ModelTesterMixin:
         loss = torch.nn.functional.mse_loss(output, noise)
         loss.backward()

+    def test_ema_training(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.train()
+        ema_model = EMAModel(model, device=torch_device)
+
+        output = model(**inputs_dict)
+
+        noise = torch.randn((inputs_dict["x"].shape[0],) + self.output_shape).to(torch_device)
+        loss = torch.nn.functional.mse_loss(output, noise)
+        loss.backward()
+        ema_model.step(model)
+

 class UnetModelTests(ModelTesterMixin, unittest.TestCase):
     model_class = UNetModel
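To exercise only the new test locally, one option is to drive it through unittest's loader; this sketch assumes the tests directory is importable as a package from the repository root (with pytest installed, "python -m pytest tests/test_modeling_utils.py -k test_ema_training" is the shorter route):

import unittest

# Load just the EMA training test for the UNet model tester and run it.
suite = unittest.defaultTestLoader.loadTestsFromName(
    "tests.test_modeling_utils.UnetModelTests.test_ema_training"
)
unittest.TextTestRunner(verbosity=2).run(suite)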