Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
c991ffd4
Commit
c991ffd4
authored
Jun 27, 2022
by
Patrick von Platen
Browse files
Merge branch 'main' of
https://github.com/huggingface/diffusers
into main
parents
3986741b
0e13d329
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
383 additions
and
12 deletions
+383
-12
examples/train_unconditional.py
examples/train_unconditional.py
+17
-12
src/diffusers/models/resnet.py
src/diffusers/models/resnet.py
+278
-0
src/diffusers/training_utils.py
src/diffusers/training_utils.py
+88
-0
No files found.
examples/train_unconditional.py
View file @
c991ffd4
...
...
@@ -9,14 +9,14 @@ from accelerate import Accelerator
from
datasets
import
load_dataset
from
diffusers
import
DDPM
,
DDPMScheduler
,
UNetModel
from
diffusers.hub_utils
import
init_git_repo
,
push_to_hub
from
diffusers.modeling_utils
import
unwrap_model
from
diffusers.optimization
import
get_scheduler
from
diffusers.training_utils
import
EMAModel
from
diffusers.utils
import
logging
from
torchvision.transforms
import
(
CenterCrop
,
Compose
,
InterpolationMode
,
Lambda
,
Normalize
,
RandomHorizontalFlip
,
Resize
,
ToTensor
,
...
...
@@ -48,7 +48,7 @@ def main(args):
CenterCrop
(
args
.
resolution
),
RandomHorizontalFlip
(),
ToTensor
(),
Lambda
(
lambda
x
:
x
*
2
-
1
),
Normalize
([
0.5
],
[
0.5
]
),
]
)
dataset
=
load_dataset
(
args
.
dataset
,
split
=
"train"
)
...
...
@@ -71,6 +71,8 @@ def main(args):
model
,
optimizer
,
train_dataloader
,
lr_scheduler
)
ema_model
=
EMAModel
(
model
,
inv_gamma
=
1.0
,
power
=
3
/
4
)
if
args
.
push_to_hub
:
repo
=
init_git_repo
(
args
,
at_init
=
True
)
...
...
@@ -87,6 +89,7 @@ def main(args):
logger
.
info
(
f
" Gradient Accumulation steps =
{
args
.
gradient_accumulation_steps
}
"
)
logger
.
info
(
f
" Total optimization steps =
{
max_steps
}
"
)
global_step
=
0
for
epoch
in
range
(
args
.
num_epochs
):
model
.
train
()
with
tqdm
(
total
=
len
(
train_dataloader
),
unit
=
"ba"
)
as
pbar
:
...
...
@@ -117,19 +120,22 @@ def main(args):
torch
.
nn
.
utils
.
clip_grad_norm_
(
model
.
parameters
(),
1.0
)
optimizer
.
step
()
lr_scheduler
.
step
()
ema_model
.
step
(
model
,
global_step
)
optimizer
.
zero_grad
()
pbar
.
update
(
1
)
pbar
.
set_postfix
(
loss
=
loss
.
detach
().
item
(),
lr
=
optimizer
.
param_groups
[
0
][
"lr"
])
pbar
.
set_postfix
(
loss
=
loss
.
detach
().
item
(),
lr
=
optimizer
.
param_groups
[
0
][
"lr"
],
ema_decay
=
ema_model
.
decay
)
global_step
+=
1
optimizer
.
step
()
if
is_distributed
:
torch
.
distributed
.
barrier
()
accelerator
.
wait_for_everyone
()
# Generate a sample image for visual inspection
if
args
.
local_rank
in
[
-
1
,
0
]:
model
.
eval
()
if
accelerator
.
is_main_process
:
with
torch
.
no_grad
():
pipeline
=
DDPM
(
unet
=
unwrap_model
(
model
),
noise_scheduler
=
noise_scheduler
)
pipeline
=
DDPM
(
unet
=
accelerator
.
unwrap_model
(
ema_model
.
averaged_model
),
noise_scheduler
=
noise_scheduler
)
generator
=
torch
.
manual_seed
(
0
)
# run pipeline in inference (sample random noise and denoise)
...
...
@@ -151,8 +157,7 @@ def main(args):
push_to_hub
(
args
,
pipeline
,
repo
,
commit_message
=
f
"Epoch
{
epoch
}
"
,
blocking
=
False
)
else
:
pipeline
.
save_pretrained
(
args
.
output_dir
)
if
is_distributed
:
torch
.
distributed
.
barrier
()
accelerator
.
wait_for_everyone
()
if
__name__
==
"__main__"
:
...
...
src/diffusers/models/resnet.py
View file @
c991ffd4
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
def avg_pool_nd(dims, *args, **kwargs):
    """
    Build an average pooling module for a 1D, 2D, or 3D signal.

    :param dims: spatial rank of the signal; must compare equal to 1, 2, or 3.
    :param args: positional arguments forwarded to the pooling constructor.
    :param kwargs: keyword arguments forwarded to the pooling constructor.
    :return: an nn.AvgPool1d, nn.AvgPool2d, or nn.AvgPool3d instance.
    :raises ValueError: if dims is not 1, 2, or 3.
    """
    pooling_by_rank = ((1, nn.AvgPool1d), (2, nn.AvgPool2d), (3, nn.AvgPool3d))
    for rank, pool_cls in pooling_by_rank:
        if dims == rank:
            return pool_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
def conv_nd(dims, *args, **kwargs):
    """
    Build a convolution module for a 1D, 2D, or 3D signal.

    :param dims: spatial rank of the signal; must compare equal to 1, 2, or 3.
    :param args: positional arguments forwarded to the convolution constructor.
    :param kwargs: keyword arguments forwarded to the convolution constructor.
    :return: an nn.Conv1d, nn.Conv2d, or nn.Conv3d instance.
    :raises ValueError: if dims is not 1, 2, or 3.
    """
    conv_by_rank = ((1, nn.Conv1d), (2, nn.Conv2d), (3, nn.Conv3d))
    for rank, conv_cls in conv_by_rank:
        if dims == rank:
            return conv_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
def conv_transpose_nd(dims, *args, **kwargs):
    """
    Build a transposed convolution module for a 1D, 2D, or 3D signal.

    :param dims: spatial rank of the signal; must compare equal to 1, 2, or 3.
    :param args: positional arguments forwarded to the module constructor.
    :param kwargs: keyword arguments forwarded to the module constructor.
    :return: an nn.ConvTranspose1d, nn.ConvTranspose2d, or nn.ConvTranspose3d
             instance.
    :raises ValueError: if dims is not 1, 2, or 3.
    """
    transposed_by_rank = (
        (1, nn.ConvTranspose1d),
        (2, nn.ConvTranspose2d),
        (3, nn.ConvTranspose3d),
    )
    for rank, conv_cls in transposed_by_rank:
        if dims == rank:
            return conv_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
def Normalize(in_channels):
    """
    Build the group normalization layer used by the blocks in this module.

    :param in_channels: number of channels to normalize.
    :return: a GroupNorm with 32 groups, eps=1e-6 and learnable affine
             parameters.
    """
    group_norm_options = {
        "num_groups": 32,
        "num_channels": in_channels,
        "eps": 1e-6,
        "affine": True,
    }
    return torch.nn.GroupNorm(**group_norm_options)
def nonlinearity(x, swish=1.0):
    """
    Apply the swish activation ``x * sigmoid(swish * x)``.

    :param x: input tensor.
    :param swish: scale applied to the sigmoid argument. When it equals 1.0
                  the fused ``F.silu`` is used.
    :return: a tensor with the same shape as x.
    """
    if swish == 1.0:
        return F.silu(x)
    # torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid.
    return x * torch.sigmoid(x * float(swish))
class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied after
                     nearest-neighbour interpolation.
    :param use_conv_transpose: a bool; when true the layer consists only of a
                     transposed convolution (kernel 4, stride 2, padding 1)
                     and no interpolation is performed.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 upsampling occurs in the inner-two dimensions.
    :param out_channels: channels produced by the convolution; falls back to
                 `channels` when not given.
    """

    def __init__(self, channels, use_conv, use_conv_transpose=False, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        self.use_conv_transpose = use_conv_transpose

        if use_conv_transpose:
            # Use the resolved self.out_channels: the raw argument defaults to
            # None, which is not a valid channel count for the convolution.
            self.conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)

    def forward(self, x):
        """
        Upsample x.

        :param x: tensor whose dimension 1 has size `channels`.
        :return: the upsampled (and optionally convolved) tensor.
        """
        assert x.shape[1] == self.channels
        if self.use_conv_transpose:
            return self.conv(x)

        if self.dims == 3:
            # Keep dimension 2 unchanged and double the last two dimensions.
            x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
        else:
            x = F.interpolate(x, scale_factor=2.0, mode="nearest")

        if self.use_conv:
            x = self.conv(x)

        return x
class Downsample(nn.Module):
    """
    Downsampling layer implemented either as a strided convolution or as
    average pooling.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool; true selects a kernel-3 strided convolution,
                     false selects average pooling.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 downsampling occurs in the inner-two dimensions.
    :param out_channels: channels produced by the convolution; falls back to
                 `channels` when not given.
    :param padding: padding handed to the convolution.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        self.padding = padding

        # A 3D signal keeps its first spatial dimension at full resolution.
        if dims != 3:
            stride = 2
        else:
            stride = (1, 2, 2)

        if use_conv:
            down_layer = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            # Pooling cannot change the channel count.
            assert self.channels == self.out_channels
            down_layer = avg_pool_nd(dims, kernel_size=stride, stride=stride)
        self.down = down_layer

    def forward(self, x):
        """
        Downsample x.

        :param x: tensor whose dimension 1 has size `channels`.
        :return: the downsampled tensor.
        """
        assert x.shape[1] == self.channels

        pad_manually = self.use_conv and self.padding == 0 and self.dims == 2
        if pad_manually:
            # Zero-pad one element at the end of each of the last two dimensions.
            x = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)

        return self.down(x)
class UNetUpsample(nn.Module):
    """
    Nearest-neighbour 2x upsampling, optionally followed by a 3x3 convolution.

    :param in_channels: channels consumed and produced by the convolution.
    :param with_conv: a bool determining if the convolution is applied.
    """

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if not self.with_conv:
            return
        self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """
        Upsample x by a factor of two and apply the convolution if enabled.
        """
        upsampled = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if not self.with_conv:
            return upsampled
        return self.conv(upsampled)
class GlideUpsample(nn.Module):
    """
    Nearest-neighbour upsampling layer with an optional trailing convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 upsampling occurs in the inner-two dimensions.
    :param out_channels: channels produced by the convolution; falls back to
                 `channels` when not given.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if not use_conv:
            return
        self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)

    def forward(self, x):
        """
        Upsample x, whose dimension 1 must have size `channels`.
        """
        assert x.shape[1] == self.channels

        if self.dims == 3:
            # Keep dimension 2 unchanged and double the last two dimensions.
            target_size = (x.shape[2], x.shape[3] * 2, x.shape[4] * 2)
            x = F.interpolate(x, target_size, mode="nearest")
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")

        return self.conv(x) if self.use_conv else x
class LDMUpsample(nn.Module):
    """
    Nearest-neighbour upsampling layer with an optional trailing convolution
    whose padding is configurable.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 upsampling occurs in the inner-two dimensions.
    :param out_channels: channels produced by the convolution; falls back to
                 `channels` when not given.
    :param padding: padding handed to the convolution.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if not use_conv:
            return
        self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)

    def forward(self, x):
        """
        Upsample x, whose dimension 1 must have size `channels`.
        """
        assert x.shape[1] == self.channels

        if self.dims == 3:
            # Keep dimension 2 unchanged and double the last two dimensions.
            target_size = (x.shape[2], x.shape[3] * 2, x.shape[4] * 2)
            x = F.interpolate(x, target_size, mode="nearest")
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")

        return self.conv(x) if self.use_conv else x
class GradTTSUpsample(torch.nn.Module):
    """
    2x upsampling implemented as a single 2D transposed convolution
    (kernel 4, stride 2, padding 1).

    :param dim: number of input and output channels.
    """

    def __init__(self, dim):
        # This class derives from torch.nn.Module, not Upsample, so
        # super(Upsample, self) would raise TypeError; use zero-argument super().
        super().__init__()
        self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1)

    def forward(self, x):
        """
        Apply the transposed convolution to x.
        """
        return self.conv(x)
class Upsample1d(nn.Module):
    """
    2x upsampling of a 1D signal through a single transposed convolution.

    :param dim: number of input and output channels.
    """

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.ConvTranspose1d(dim, dim, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        """
        Apply the transposed convolution to x.
        """
        upsampled = self.conv(x)
        return upsampled
# class ResnetBlock(nn.Module):
# def __init__(
# self,
# *,
# in_channels,
# out_channels=None,
# conv_shortcut=False,
# dropout,
# temb_channels=512,
# use_scale_shift_norm=False,
# ):
# super().__init__()
# self.in_channels = in_channels
# out_channels = in_channels if out_channels is None else out_channels
# self.out_channels = out_channels
# self.use_conv_shortcut = conv_shortcut
# self.use_scale_shift_norm = use_scale_shift_norm
# self.norm1 = Normalize(in_channels)
# self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
# temp_out_channles = 2 * out_channels if use_scale_shift_norm else out_channels
# self.temb_proj = torch.nn.Linear(temb_channels, temp_out_channles)
# self.norm2 = Normalize(out_channels)
# self.dropout = torch.nn.Dropout(dropout)
# self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
# if self.in_channels != self.out_channels:
# if self.use_conv_shortcut:
# self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
# else:
# self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
# def forward(self, x, temb):
# h = x
# h = self.norm1(h)
# h = nonlinearity(h)
# h = self.conv1(h)
# # TODO: check if this broadcasting works correctly for 1D and 3D
# temb = self.temb_proj(nonlinearity(temb))[:, :, None, None]
# if self.use_scale_shift_norm:
# out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
# scale, shift = torch.chunk(temb, 2, dim=1)
# h = self.norm2(h) * (1 + scale) + shift
# h = out_rest(h)
# else:
# h = h + temb
# h = self.norm2(h)
# h = nonlinearity(h)
# h = self.dropout(h)
# h = self.conv2(h)
# if self.in_channels != self.out_channels:
# if self.use_conv_shortcut:
# x = self.conv_shortcut(x)
# else:
# x = self.nin_shortcut(x)
# return x + h
src/diffusers/training_utils.py
0 → 100644
View file @
c991ffd4
import
copy
import
torch
class EMAModel:
    """
    Exponential Moving Average of models weights
    """

    def __init__(
        self,
        model,
        update_after_step=0,
        inv_gamma=1.0,
        power=2 / 3,
        min_value=0.0,
        max_value=0.9999,
        device=None,
    ):
        """
        @crowsonkb's notes on EMA Warmup:
            If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are
            good values for models you plan to train for a million or more steps (reaches decay
            factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models
            you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
            215.4k steps).

        Args:
            model: The model to average. A deep copy is stored as `averaged_model`
                with gradients disabled; the passed model itself is not modified.
            update_after_step (int): Offset subtracted from the optimization step
                before the decay is computed. Default: 0.
            inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
            power (float): Exponential factor of EMA warmup. Default: 2/3.
            min_value (float): The minimum EMA decay rate. Default: 0.
            max_value (float): The maximum EMA decay rate. Default: 0.9999.
            device: If not None, the averaged copy is moved to this device.
        """
        self.averaged_model = copy.deepcopy(model)
        self.averaged_model.requires_grad_(False)

        self.update_after_step = update_after_step
        self.inv_gamma = inv_gamma
        self.power = power
        self.min_value = min_value
        self.max_value = max_value

        if device is not None:
            self.averaged_model = self.averaged_model.to(device=device)

        # Decay used by the most recent call to step(); 0.0 before any step.
        self.decay = 0.0

    def get_decay(self, optimization_step):
        """
        Compute the decay factor for the exponential moving average.

        Returns 0.0 while `optimization_step - update_after_step - 1` is not
        positive; afterwards returns `1 - (1 + step / inv_gamma) ** -power`
        clamped to `[min_value, max_value]`.
        """
        step = max(0, optimization_step - self.update_after_step - 1)
        if step <= 0:
            return 0.0

        value = 1 - (1 + step / self.inv_gamma) ** -self.power
        return max(self.min_value, min(value, self.max_value))

    @torch.no_grad()
    def step(self, new_model, optimization_step):
        """
        Update the averaged model from the current weights of `new_model`.

        Parameters that require grad are blended as
        `ema = decay * ema + (1 - decay) * param`; parameters that do not
        require grad and all buffers are copied over unchanged. `self.decay`
        is refreshed from `optimization_step` first.

        Args:
            new_model: Model whose named parameters and buffers are read.
            optimization_step (int): Step number passed to `get_decay`.
        """
        ema_state_dict = {}
        ema_params = self.averaged_model.state_dict()

        self.decay = self.get_decay(optimization_step)

        for key, param in new_model.named_parameters():
            try:
                ema_param = ema_params[key]
            except KeyError:
                # No entry with this name in the averaged model: start from a
                # copy of the parameter. load_state_dict below is non-strict.
                ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
                ema_params[key] = ema_param

            if not param.requires_grad:
                ema_param.copy_(param.data.to(dtype=ema_param.dtype))
            else:
                # Match device as well as dtype so the in-place update also
                # works when the averaged copy lives on another device.
                source = param.data.to(device=ema_param.device, dtype=ema_param.dtype)
                ema_param.mul_(self.decay)
                ema_param.add_(source, alpha=1 - self.decay)

            ema_state_dict[key] = ema_param

        for key, buffer in new_model.named_buffers():
            ema_state_dict[key] = buffer

        self.averaged_model.load_state_dict(ema_state_dict, strict=False)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment