OpenDAS / diffusers · commit d9316bf8
Unverified commit d9316bf8, authored Jul 04, 2022 by Anton Lozhkov, committed via GitHub on Jul 04, 2022.
Fix mutable proj_out weight in the Attention layer (#73)

* Catch unused params in DDP
* Fix proj_out, add test
Parent: 3abf4bc4

Changes: 4 changed files with 29 additions and 8 deletions (+29, -8)
- examples/train_unconditional.py (+8, -2)
- src/diffusers/models/attention.py (+5, -5)
- src/diffusers/training_utils.py (+1, -1)
- tests/test_modeling_utils.py (+15, -0)
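The "Catch unused params in DDP" half of this commit exists because DistributedDataParallel, by default, expects every registered parameter to receive a gradient on each backward pass and errors out otherwise. A minimal sketch of that failure mode, using an illustrative module rather than the repo's code:

import torch
import torch.nn as nn

class BlockWithUnusedParams(nn.Module):
    # Registers a submodule that forward() never calls. Wrapped in
    # DistributedDataParallel with the default find_unused_parameters=False,
    # backward() fails because the reducer waits for gradients that never arrive.
    def __init__(self):
        super().__init__()
        self.used = nn.Linear(4, 4)
        self.unused = nn.Linear(4, 4)  # registered, but skipped in forward()

    def forward(self, x):
        return self.used(x)

Passing DistributedDataParallelKwargs(find_unused_parameters=True) to Accelerator, as the training script below now does, forwards that flag to DDP so such parameters are tolerated.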
examples/train_unconditional.py

@@ -4,7 +4,7 @@ import os
 import torch
 import torch.nn.functional as F
-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedDataParallelKwargs
 from accelerate.logging import get_logger
 from datasets import load_dataset
 from diffusers import DDIMPipeline, DDIMScheduler, UNetModel

@@ -27,8 +27,14 @@ logger = get_logger(__name__)
 def main(args):
+    ddp_unused_params = DistributedDataParallelKwargs(find_unused_parameters=True)
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
-    accelerator = Accelerator(mixed_precision=args.mixed_precision, log_with="tensorboard", logging_dir=logging_dir)
+    accelerator = Accelerator(
+        mixed_precision=args.mixed_precision,
+        log_with="tensorboard",
+        logging_dir=logging_dir,
+        kwargs_handlers=[ddp_unused_params],
+    )

     model = UNetModel(
         attn_resolutions=(16,),
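The handler only takes effect once Accelerate wraps the model for distributed training. A minimal, self-contained sketch of that pattern (the model, optimizer, and dataloader here are placeholders, not taken from the script):

import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator, DistributedDataParallelKwargs

# Tolerate parameters that receive no gradient in a given step.
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

model = torch.nn.Linear(4, 4)                      # placeholder model
optimizer = torch.optim.Adam(model.parameters())
dataloader = DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=2)

# prepare() wraps the model in DistributedDataParallel when the script is
# launched with multiple processes, applying the kwargs handler above.
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)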
src/diffusers/models/attention.py

@@ -70,7 +70,7 @@ class AttentionBlock(nn.Module):
         if encoder_channels is not None:
             self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)
-        self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
+        self.proj = zero_module(nn.Conv1d(channels, channels, 1))

         self.overwrite_qkv = overwrite_qkv
         if overwrite_qkv:

@@ -108,15 +108,15 @@ class AttentionBlock(nn.Module):
             proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
             proj_out.bias.data = module.proj_out.bias.data

-            self.proj_out = proj_out
+            self.proj = proj_out
         elif self.overwrite_linear:
             self.qkv.weight.data = torch.concat(
                 [self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0
             )[:, :, None]
             self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0)

-            self.proj_out.weight.data = self.NIN_3.W.data.T[:, :, None]
-            self.proj_out.bias.data = self.NIN_3.b.data
+            self.proj.weight.data = self.NIN_3.W.data.T[:, :, None]
+            self.proj.bias.data = self.NIN_3.b.data

             self.norm.weight.data = self.GroupNorm_0.weight.data
             self.norm.bias.data = self.GroupNorm_0.bias.data

@@ -149,7 +149,7 @@ class AttentionBlock(nn.Module):
         a = torch.einsum("bts,bcs->bct", weight, v)
         h = a.reshape(bs, -1, length)

-        h = self.proj_out(h)
+        h = self.proj(h)

         h = h.reshape(b, c, *spatial)
         result = x + h
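The renamed projection is still built through zero_module, which this diff does not show. As an assumption about that helper (a common pattern in diffusion-model codebases, not copied from this file), it simply zeroes a module's parameters so the freshly constructed projection contributes nothing until training moves it away from zero:

import torch.nn as nn

def zero_module_sketch(module: nn.Module) -> nn.Module:
    # Zero every parameter in place; the wrapped module starts as a no-op.
    for p in module.parameters():
        p.detach().zero_()
    return module

proj = zero_module_sketch(nn.Conv1d(64, 64, 1))  # mirrors self.proj above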
src/diffusers/training_utils.py

@@ -30,7 +30,7 @@ class EMAModel:
             min_value (float): The minimum EMA decay rate. Default: 0.
         """
-        self.averaged_model = copy.deepcopy(model)
+        self.averaged_model = copy.deepcopy(model).eval()
         self.averaged_model.requires_grad_(False)

         self.update_after_step = update_after_step
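Chaining .eval() onto the deep copy keeps layers with distinct train/eval behavior (dropout, batch norm) in inference mode for the averaged weights, while the existing requires_grad_(False) call keeps autograd from tracking them. The same pattern in isolation, as a minimal sketch outside EMAModel:

import copy
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))

# EMA-style shadow copy: evaluation mode, no gradient tracking.
averaged_model = copy.deepcopy(model).eval()
averaged_model.requires_grad_(False)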
tests/test_modeling_utils.py

@@ -52,6 +52,7 @@ from diffusers.configuration_utils import ConfigMixin
 from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.pipelines.bddm.pipeline_bddm import DiffWave
 from diffusers.testing_utils import floats_tensor, slow, torch_device
+from diffusers.training_utils import EMAModel

 torch.backends.cuda.matmul.allow_tf32 = False

@@ -197,6 +198,20 @@ class ModelTesterMixin:
         loss = torch.nn.functional.mse_loss(output, noise)
         loss.backward()

+    def test_ema_training(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.train()
+        ema_model = EMAModel(model, device=torch_device)
+
+        output = model(**inputs_dict)
+        noise = torch.randn((inputs_dict["x"].shape[0],) + self.output_shape).to(torch_device)
+        loss = torch.nn.functional.mse_loss(output, noise)
+        loss.backward()
+        ema_model.step(model)
+

 class UnetModelTests(ModelTesterMixin, unittest.TestCase):
     model_class = UNetModel
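Outside the test suite, the calls exercised by test_ema_training slot into an ordinary training step roughly as follows. Only the EMAModel constructor (with its device keyword) and ema_model.step(model) are taken from this commit; the model, optimizer, and data below are illustrative placeholders:

import torch
from diffusers.training_utils import EMAModel

model = torch.nn.Linear(4, 4)                # placeholder for a diffusers model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
ema_model = EMAModel(model, device="cpu")    # mirrors EMAModel(model, device=torch_device) in the test

for _ in range(10):
    x = torch.randn(2, 4)
    loss = torch.nn.functional.mse_loss(model(x), torch.zeros(2, 4))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema_model.step(model)                    # update the running average after each optimizer step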