renzhc / diffusers_dcu · Commits · abcb2597

Commit abcb2597, authored Jun 27, 2022 by patil-suraj

Merge branch 'main' of https://github.com/huggingface/diffusers into main

Parents: 183056f2, c991ffd4
Showing 4 changed files with 181 additions and 65 deletions:

    examples/train_unconditional.py    +17  -12
    src/diffusers/models/unet_ldm.py   +54  -53
    src/diffusers/training_utils.py    +88   -0
    tests/test_modeling_utils.py       +22   -0
examples/train_unconditional.py
@@ -9,14 +9,14 @@ from accelerate import Accelerator
 from datasets import load_dataset
 from diffusers import DDPM, DDPMScheduler, UNetModel
 from diffusers.hub_utils import init_git_repo, push_to_hub
-from diffusers.modeling_utils import unwrap_model
 from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel
 from diffusers.utils import logging
 from torchvision.transforms import (
     CenterCrop,
     Compose,
     InterpolationMode,
-    Lambda,
+    Normalize,
     RandomHorizontalFlip,
     Resize,
     ToTensor,
@@ -48,7 +48,7 @@ def main(args):
             CenterCrop(args.resolution),
             RandomHorizontalFlip(),
             ToTensor(),
-            Lambda(lambda x: x * 2 - 1),
+            Normalize([0.5], [0.5]),
         ]
     )
     dataset = load_dataset(args.dataset, split="train")
@@ -71,6 +71,8 @@ def main(args):
         model, optimizer, train_dataloader, lr_scheduler
     )
 
+    ema_model = EMAModel(model, inv_gamma=1.0, power=3 / 4)
+
     if args.push_to_hub:
         repo = init_git_repo(args, at_init=True)
@@ -87,6 +89,7 @@ def main(args):
     logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
     logger.info(f" Total optimization steps = {max_steps}")
 
+    global_step = 0
     for epoch in range(args.num_epochs):
         model.train()
         with tqdm(total=len(train_dataloader), unit="ba") as pbar:
@@ -117,19 +120,22 @@ def main(args):
                 torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                 optimizer.step()
                 lr_scheduler.step()
+                ema_model.step(model, global_step)
                 optimizer.zero_grad()
                 pbar.update(1)
-                pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"])
-            optimizer.step()
+                pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"], ema_decay=ema_model.decay)
+                global_step += 1
 
-        if is_distributed:
-            torch.distributed.barrier()
+        accelerator.wait_for_everyone()
 
         # Generate a sample image for visual inspection
-        if args.local_rank in [-1, 0]:
+        if accelerator.is_main_process:
+            model.eval()
             with torch.no_grad():
-                pipeline = DDPM(unet=unwrap_model(model), noise_scheduler=noise_scheduler)
+                pipeline = DDPM(unet=accelerator.unwrap_model(ema_model.averaged_model), noise_scheduler=noise_scheduler)
                 generator = torch.manual_seed(0)
 
                 # run pipeline in inference (sample random noise and denoise)
@@ -151,8 +157,7 @@ def main(args):
                 push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
             else:
                 pipeline.save_pretrained(args.output_dir)
-        if is_distributed:
-            torch.distributed.barrier()
+        accelerator.wait_for_everyone()
 
 
 if __name__ == "__main__":
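Taken together, the changes above follow the standard EMA recipe: keep a frozen shadow copy of the weights, update it right after each optimizer step, and sample from the averaged copy instead of the live model. A minimal self-contained sketch of that pattern, assuming a toy linear model in place of the UNet (not from this commit):

    import torch
    from diffusers.training_utils import EMAModel

    model = torch.nn.Linear(4, 4)  # toy stand-in for the UNet
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    ema_model = EMAModel(model, inv_gamma=1.0, power=3 / 4)

    global_step = 0
    for _ in range(100):
        loss = model(torch.randn(8, 4)).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema_model.step(model, global_step)  # shadow update after the optimizer step
        global_step += 1

    # evaluation/sampling uses the smoothed copy, as the pipeline code above does
    eval_model = ema_model.averaged_model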
src/diffusers/models/unet_ldm.py
@@ -82,61 +82,62 @@ def Normalize(in_channels):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
 
 
-class LinearAttention(nn.Module):
-    def __init__(self, dim, heads=4, dim_head=32):
-        super().__init__()
-        self.heads = heads
-        hidden_dim = dim_head * heads
-        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
-        self.to_out = nn.Conv2d(hidden_dim, dim, 1)
-
-    def forward(self, x):
-        b, c, h, w = x.shape
-        qkv = self.to_qkv(x)
-        q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
-        k = k.softmax(dim=-1)
-        context = torch.einsum("bhdn,bhen->bhde", k, v)
-        out = torch.einsum("bhde,bhdn->bhen", context, q)
-        out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
-        return self.to_out(out)
-
-
-class SpatialSelfAttention(nn.Module):
-    def __init__(self, in_channels):
-        super().__init__()
-        self.in_channels = in_channels
-
-        self.norm = Normalize(in_channels)
-        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-
-    def forward(self, x):
-        h_ = x
-        h_ = self.norm(h_)
-        q = self.q(h_)
-        k = self.k(h_)
-        v = self.v(h_)
-
-        # compute attention
-        b, c, h, w = q.shape
-        q = rearrange(q, "b c h w -> b (h w) c")
-        k = rearrange(k, "b c h w -> b c (h w)")
-        w_ = torch.einsum("bij,bjk->bik", q, k)
-
-        w_ = w_ * (int(c) ** (-0.5))
-        w_ = torch.nn.functional.softmax(w_, dim=2)
-
-        # attend to values
-        v = rearrange(v, "b c h w -> b c (h w)")
-        w_ = rearrange(w_, "b i j -> b j i")
-        h_ = torch.einsum("bij,bjk->bik", v, w_)
-        h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
-        h_ = self.proj_out(h_)
-
-        return x + h_
+# class LinearAttention(nn.Module):
+#     def __init__(self, dim, heads=4, dim_head=32):
+#         super().__init__()
+#         self.heads = heads
+#         hidden_dim = dim_head * heads
+#         self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
+#         self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+#
+#     def forward(self, x):
+#         b, c, h, w = x.shape
+#         qkv = self.to_qkv(x)
+#         q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
+#         import ipdb; ipdb.set_trace()
+#         k = k.softmax(dim=-1)
+#         context = torch.einsum("bhdn,bhen->bhde", k, v)
+#         out = torch.einsum("bhde,bhdn->bhen", context, q)
+#         out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
+#         return self.to_out(out)
+#
+# class SpatialSelfAttention(nn.Module):
+#     def __init__(self, in_channels):
+#         super().__init__()
+#         self.in_channels = in_channels
+#
+#         self.norm = Normalize(in_channels)
+#         self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+#         self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+#         self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+#         self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+#
+#     def forward(self, x):
+#         h_ = x
+#         h_ = self.norm(h_)
+#         q = self.q(h_)
+#         k = self.k(h_)
+#         v = self.v(h_)
+#
+#         # compute attention
+#         b, c, h, w = q.shape
+#         q = rearrange(q, "b c h w -> b (h w) c")
+#         k = rearrange(k, "b c h w -> b c (h w)")
+#         w_ = torch.einsum("bij,bjk->bik", q, k)
+#
+#         w_ = w_ * (int(c) ** (-0.5))
+#         w_ = torch.nn.functional.softmax(w_, dim=2)
+#
+#         # attend to values
+#         v = rearrange(v, "b c h w -> b c (h w)")
+#         w_ = rearrange(w_, "b i j -> b j i")
+#         h_ = torch.einsum("bij,bjk->bik", v, w_)
+#         h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
+#         h_ = self.proj_out(h_)
+#
+#         return x + h_
+#
 
 
 class CrossAttention(nn.Module):
     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
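For reference, the einops pattern in the (now commented-out) LinearAttention.forward splits the fused 1x1-conv output into per-head q, k, v with the pixel grid flattened. A toy-shape sketch, with hypothetical sizes chosen only for illustration:

    import torch
    from einops import rearrange

    b, heads, c, h, w = 2, 4, 32, 8, 8
    qkv = torch.randn(b, 3 * heads * c, h, w)  # fused projection: 3 * heads * dim_head channels
    q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=heads, qkv=3)
    print(q.shape)  # torch.Size([2, 4, 32, 64]): per-head channels over the h*w flattened pixels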
src/diffusers/training_utils.py  0 → 100644 (new file)
import copy

import torch


class EMAModel:
    """
    Exponential Moving Average of models weights
    """

    def __init__(
        self,
        model,
        update_after_step=0,
        inv_gamma=1.0,
        power=2 / 3,
        min_value=0.0,
        max_value=0.9999,
        device=None,
    ):
        """
        @crowsonkb's notes on EMA Warmup:
            If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are
            good values for models you plan to train for a million or more steps (reaches decay
            factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models
            you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
            215.4k steps).
        Args:
            inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
            power (float): Exponential factor of EMA warmup. Default: 2/3.
            min_value (float): The minimum EMA decay rate. Default: 0.
        """

        self.averaged_model = copy.deepcopy(model)
        self.averaged_model.requires_grad_(False)

        self.update_after_step = update_after_step
        self.inv_gamma = inv_gamma
        self.power = power
        self.min_value = min_value
        self.max_value = max_value

        if device is not None:
            self.averaged_model = self.averaged_model.to(device=device)

        self.decay = 0.0

    def get_decay(self, optimization_step):
        """
        Compute the decay factor for the exponential moving average.
        """
        step = max(0, optimization_step - self.update_after_step - 1)
        value = 1 - (1 + step / self.inv_gamma) ** -self.power

        if step <= 0:
            return 0.0

        return max(self.min_value, min(value, self.max_value))

    @torch.no_grad()
    def step(self, new_model, optimization_step):
        ema_state_dict = {}
        ema_params = self.averaged_model.state_dict()

        self.decay = self.get_decay(optimization_step)

        for key, param in new_model.named_parameters():
            if isinstance(param, dict):
                continue
            try:
                ema_param = ema_params[key]
            except KeyError:
                ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
                ema_params[key] = ema_param

            if not param.requires_grad:
                ema_params[key].copy_(param.to(dtype=ema_param.dtype).data)
                ema_param = ema_params[key]
            else:
                ema_param.mul_(self.decay)
                ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay)

            ema_state_dict[key] = ema_param

        for key, param in new_model.named_buffers():
            ema_state_dict[key] = param

        self.averaged_model.load_state_dict(ema_state_dict, strict=False)
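As a quick sanity check on the milestones quoted in the docstring, the warmup schedule is decay(step) = 1 - (1 + step / inv_gamma) ** -power, clamped to [min_value, max_value]. A small sketch (not part of the commit) reproduces the quoted numbers:

    # decay(step) = 1 - (1 + step / inv_gamma) ** -power, before clamping
    def decay(step, inv_gamma=1.0, power=2 / 3):
        return 1 - (1 + step / inv_gamma) ** -power

    print(round(decay(31_600), 4))                # ~0.999  at 31.6K steps  (power=2/3)
    print(round(decay(1_000_000), 4))             # ~0.9999 at 1M steps     (power=2/3)
    print(round(decay(10_000, power=3 / 4), 4))   # ~0.999  at 10K steps    (power=3/4)
    print(round(decay(215_400, power=3 / 4), 4))  # ~0.9999 at 215.4K steps (power=3/4)

This also makes clear why train_unconditional.py passes power=3 / 4: the example trains for far fewer than a million steps, so the faster-warming schedule is the appropriate one.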
tests/test_modeling_utils.py
@@ -510,6 +510,28 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
         self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
 
+    def test_output_pretrained_spatial_transformer(self):
+        model = UNetLDMModel.from_pretrained("fusing/unet-ldm-dummy-spatial")
+        model.eval()
+
+        torch.manual_seed(0)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(0)
+
+        noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
+        context = torch.ones((1, 16, 64), dtype=torch.float32)
+        time_step = torch.tensor([10] * noise.shape[0])
+
+        with torch.no_grad():
+            output = model(noise, time_step, context=context)
+
+        output_slice = output[0, -1, -3:, -3:].flatten()
+        # fmt: off
+        expected_output_slice = torch.tensor([61.3445, 56.9005, 29.4339, 59.5497, 60.7375, 34.1719, 48.1951, 42.6569, 25.0890])
+        # fmt: on
+        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
+
 
 class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
     model_class = UNetGradTTSModel