renzhc / diffusers_dcu · Commit ba21735c

DDPM training example

Authored Jun 13, 2022 by anton-l
Parent: 2d1f7de2

Showing 9 changed files with 183 additions and 50 deletions (+183, -50)
src/diffusers/__init__.py                    +1  -1
src/diffusers/configuration_utils.py         +4  -4
src/diffusers/pipelines/__init__.py          +1  -1
src/diffusers/pipelines/conversion_glide.py  +3  -1
src/diffusers/pipelines/pipeline_bddm.py     +44 -35
src/diffusers/pipelines/pipeline_glide.py    +7  -7
src/diffusers/schedulers/scheduling_ddim.py  +1  -1
src/diffusers/schedulers/scheduling_ddpm.py  +6  -0
src/diffusers/trainers/training_ddpm.py      +116 -0
src/diffusers/__init__.py

@@ -9,6 +9,6 @@ from .models.unet import UNetModel
 from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from .models.unet_ldm import UNetLDMModel
 from .pipeline_utils import DiffusionPipeline
-from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion, BDDMPipeline
+from .pipelines import DDIM, DDPM, GLIDE, BDDMPipeline, LatentDiffusion
 from .schedulers import DDIMScheduler, DDPMScheduler, SchedulerMixin
 from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
src/diffusers/configuration_utils.py

@@ -225,11 +225,11 @@ class ConfigMixin:
         text = reader.read()
         return json.loads(text)

-    def __eq__(self, other):
-        return self.__dict__ == other.__dict__
+    # def __eq__(self, other):
+    #     return self.__dict__ == other.__dict__

-    def __repr__(self):
-        return f"{self.__class__.__name__} {self.to_json_string()}"
+    # def __repr__(self):
+    #     return f"{self.__class__.__name__} {self.to_json_string()}"

     @property
     def config(self) -> Dict[str, Any]:
src/diffusers/pipelines/__init__.py

+from .pipeline_bddm import BDDMPipeline
 from .pipeline_ddim import DDIM
 from .pipeline_ddpm import DDPM
 from .pipeline_glide import GLIDE
 from .pipeline_latent_diffusion import LatentDiffusion
-from .pipeline_bddm import BDDMPipeline
src/diffusers/pipelines/conversion_glide.py

@@ -97,7 +97,9 @@ superres_model = GLIDESuperResUNetModel(
 superres_model.load_state_dict(ups_state_dict, strict=False)

 upscale_scheduler = DDIMScheduler(timesteps=1000, beta_schedule="linear", beta_start=0.0001, beta_end=0.02, tensor_format="pt")
 glide = GLIDE(
     text_unet=text2im_model,
 ...
src/diffusers/pipelines/pipeline_bddm.py

@@ -13,14 +13,16 @@
 import math

 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import tqdm

-from ..modeling_utils import ModelMixin
 from ..configuration_utils import ConfigMixin
+from ..modeling_utils import ModelMixin
 from ..pipeline_utils import DiffusionPipeline
@@ -46,8 +48,7 @@ def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_dim_in):
     _embed = np.log(10000) / (half_dim - 1)
     _embed = torch.exp(torch.arange(half_dim) * -_embed).cuda()
     _embed = diffusion_steps * _embed
-    diffusion_step_embed = torch.cat((torch.sin(_embed),
-                                      torch.cos(_embed)), 1)
+    diffusion_step_embed = torch.cat((torch.sin(_embed), torch.cos(_embed)), 1)

     return diffusion_step_embed
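The function above is the standard sinusoidal timestep encoding. For reference, a minimal self-contained sketch of the same computation (CPU-only, dropping the `.cuda()` call in the committed code):

import numpy as np
import torch

def step_embedding(diffusion_steps, embed_dim):
    # diffusion_steps: (batch, 1) tensor of integer timesteps
    half_dim = embed_dim // 2
    scale = np.log(10000) / (half_dim - 1)
    freqs = torch.exp(torch.arange(half_dim) * -scale)            # (half_dim,)
    args = diffusion_steps * freqs                                # (batch, half_dim)
    return torch.cat((torch.sin(args), torch.cos(args)), dim=1)  # (batch, embed_dim)

emb = step_embedding(torch.tensor([[10], [500]]), embed_dim=128)
print(emb.shape)  # torch.Size([2, 128])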
@@ -67,8 +68,7 @@ class Conv(nn.Module):
     def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1):
         super().__init__()
         self.padding = dilation * (kernel_size - 1) // 2
-        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
-                              dilation=dilation, padding=self.padding)
+        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation, padding=self.padding)
         self.conv = nn.utils.weight_norm(self.conv)
         nn.init.kaiming_normal_(self.conv.weight)
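The pattern here, a "same"-padding dilated Conv1d wrapped in weight normalization with Kaiming-initialized weights, can be exercised in isolation. A minimal sketch (channel sizes are made up for illustration):

import torch
import torch.nn as nn

conv = nn.Conv1d(64, 128, kernel_size=3, dilation=2, padding=2)  # padding = dilation * (k - 1) // 2
conv = nn.utils.weight_norm(conv)  # reparameterize weight as g * v / ||v||
nn.init.kaiming_normal_(conv.weight)

x = torch.randn(1, 64, 100)
print(conv(x).shape)  # torch.Size([1, 128, 100]) -- sequence length preserved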
@@ -94,8 +94,7 @@ class ZeroConv1d(nn.Module):
 # every residual block (named residual layer in paper)
 # contains one noncausal dilated conv
 class ResidualBlock(nn.Module):
-    def __init__(self, res_channels, skip_channels, dilation,
-                 diffusion_step_embed_dim_out):
+    def __init__(self, res_channels, skip_channels, dilation, diffusion_step_embed_dim_out):
         super().__init__()
         self.res_channels = res_channels

@@ -103,15 +102,12 @@ class ResidualBlock(nn.Module):
         self.fc_t = nn.Linear(diffusion_step_embed_dim_out, self.res_channels)

         # Dilated conv layer
-        self.dilated_conv_layer = Conv(self.res_channels, 2 * self.res_channels,
-                                       kernel_size=3, dilation=dilation)
+        self.dilated_conv_layer = Conv(self.res_channels, 2 * self.res_channels, kernel_size=3, dilation=dilation)

         # Add mel spectrogram upsampler and conditioner conv1x1 layer
         self.upsample_conv2d = nn.ModuleList()
         for s in [16, 16]:
-            conv_trans2d = nn.ConvTranspose2d(1, 1, (3, 2 * s),
-                                              padding=(1, s // 2), stride=(1, s))
+            conv_trans2d = nn.ConvTranspose2d(1, 1, (3, 2 * s), padding=(1, s // 2), stride=(1, s))
             conv_trans2d = nn.utils.weight_norm(conv_trans2d)
             nn.init.kaiming_normal_(conv_trans2d.weight)
             self.upsample_conv2d.append(conv_trans2d)
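Each of the two transposed convs stretches the mel spectrogram 16x along the time axis, 256x overall, which maps frame-rate conditioning to sample-rate audio when the vocoder hop length is 256. A sketch of the shape arithmetic (toy sizes; the leaky-ReLU between layers is an assumption borrowed from DiffWave-style vocoders, not shown in this hunk):

import torch
import torch.nn as nn
import torch.nn.functional as F

mel = torch.randn(1, 1, 80, 63)  # (batch, 1, n_mels, frames)
up = nn.ModuleList(
    nn.ConvTranspose2d(1, 1, (3, 2 * s), padding=(1, s // 2), stride=(1, s)) for s in [16, 16]
)

x = mel
for layer in up:
    x = F.leaky_relu(layer(x), 0.4)
print(x.shape)  # torch.Size([1, 1, 80, 16128]) -- 63 frames * 256 samples per frame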
@@ -157,7 +153,7 @@ class ResidualBlock(nn.Module):
             h += mel_spec

         # Gated-tanh nonlinearity
-        out = torch.tanh(h[:, :self.res_channels, :]) * torch.sigmoid(h[:, self.res_channels:, :])
+        out = torch.tanh(h[:, : self.res_channels, :]) * torch.sigmoid(h[:, self.res_channels :, :])

         # Residual and skip outputs
         res = self.res_conv(out)
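The gated-tanh unit splits the 2 * res_channels output of the dilated conv into a tanh "filter" half and a sigmoid "gate" half, the WaveNet-style gating the comment refers to. A standalone sketch:

import torch

res_channels = 4
h = torch.randn(2, 2 * res_channels, 10)  # output of the dilated conv
out = torch.tanh(h[:, :res_channels, :]) * torch.sigmoid(h[:, res_channels:, :])
print(out.shape)  # torch.Size([2, 4, 10]) -- back down to res_channels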
@@ -169,10 +165,16 @@ class ResidualBlock(nn.Module):
 class ResidualGroup(nn.Module):
     def __init__(self, res_channels, skip_channels, num_res_layers, dilation_cycle,
                  diffusion_step_embed_dim_in,
                  diffusion_step_embed_dim_mid,
-                 diffusion_step_embed_dim_out):
+                 diffusion_step_embed_dim_out,
+    ):
         super().__init__()
         self.num_res_layers = num_res_layers
         self.diffusion_step_embed_dim_in = diffusion_step_embed_dim_in

@@ -185,16 +187,19 @@ class ResidualGroup(nn.Module):
         self.residual_blocks = nn.ModuleList()
         for n in range(self.num_res_layers):
             self.residual_blocks.append(
                 ResidualBlock(res_channels, skip_channels,
                               dilation=2 ** (n % dilation_cycle),
-                              diffusion_step_embed_dim_out=diffusion_step_embed_dim_out))
+                              diffusion_step_embed_dim_out=diffusion_step_embed_dim_out,
+                )
+            )

     def forward(self, input_data):
         x, mel_spectrogram, diffusion_steps = input_data

         # Embed diffusion step t
         diffusion_step_embed = calc_diffusion_step_embedding(
             diffusion_steps, self.diffusion_step_embed_dim_in)
         diffusion_step_embed = swish(self.fc_t1(diffusion_step_embed))
         diffusion_step_embed = swish(self.fc_t2(diffusion_step_embed))
...
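The `dilation=2 ** (n % dilation_cycle)` argument gives the residual stack the cycled exponential dilations used by WaveNet and DiffWave, so the receptive field grows geometrically within each cycle. For example, with 12 layers and a cycle length of 4 the per-layer dilations come out as:

num_res_layers, dilation_cycle = 12, 4
print([2 ** (n % dilation_cycle) for n in range(num_res_layers)])
# [1, 2, 4, 8, 1, 2, 4, 8, 1, 2, 4, 8]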
@@ -239,20 +244,24 @@ class DiffWave(ModelMixin, ConfigMixin):
             diffusion_step_embed_dim_out=diffusion_step_embed_dim_out,
         )

         # Initial conv1x1 with relu
         self.init_conv = nn.Sequential(Conv(in_channels, res_channels, kernel_size=1), nn.ReLU(inplace=False))

         # All residual layers
         self.residual_layer = ResidualGroup(res_channels,
                                             skip_channels,
                                             num_res_layers,
                                             dilation_cycle,
                                             diffusion_step_embed_dim_in,
                                             diffusion_step_embed_dim_mid,
-                                            diffusion_step_embed_dim_out)
+                                            diffusion_step_embed_dim_out,
+        )

         # Final conv1x1 -> relu -> zeroconv1x1
-        self.final_conv = nn.Sequential(Conv(skip_channels, skip_channels, kernel_size=1),
-                                        nn.ReLU(inplace=False),
-                                        ZeroConv1d(skip_channels, out_channels))
+        self.final_conv = nn.Sequential(
+            Conv(skip_channels, skip_channels, kernel_size=1),
+            nn.ReLU(inplace=False),
+            ZeroConv1d(skip_channels, out_channels),
+        )

     def forward(self, input_data):
         audio, mel_spectrogram, diffusion_steps = input_data
...
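`ZeroConv1d`, defined earlier in this file, ends the network with a conv1x1 whose parameters start at zero, so the model initially predicts zero noise, a common stabilization trick in diffusion models. A minimal sketch of such a layer, under the assumption that it matches its name:

import torch
import torch.nn as nn

class ZeroConv1d(nn.Module):
    # conv1x1 initialized to zero so the model's initial output is all zeros
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=1)
        nn.init.zeros_(self.conv.weight)
        nn.init.zeros_(self.conv.bias)

    def forward(self, x):
        return self.conv(x)

layer = ZeroConv1d(8, 1)
print(layer(torch.randn(2, 8, 16)).abs().max().item())  # 0.0 at init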
src/diffusers/pipelines/pipeline_glide.py

@@ -28,12 +28,7 @@ from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 from transformers.modeling_utils import PreTrainedModel
-from transformers.utils import (
-    ModelOutput,
-    add_start_docstrings_to_model_forward,
-    logging,
-    replace_return_docstrings,
-)
+from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward, logging, replace_return_docstrings

 from ..models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from ..pipeline_utils import DiffusionPipeline

@@ -871,7 +866,12 @@ class GLIDE(DiffusionPipeline):
         # Sample gaussian noise to begin loop
         image = torch.randn(
-            (batch_size, self.upscale_unet.in_channels // 2, self.upscale_unet.resolution, self.upscale_unet.resolution),
+            (
+                batch_size,
+                self.upscale_unet.in_channels // 2,
+                self.upscale_unet.resolution,
+                self.upscale_unet.resolution,
+            ),
             generator=generator,
         )
         image = image.to(torch_device) * upsample_temp
...
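Scaling the initial Gaussian noise by `upsample_temp` is GLIDE's low-temperature sampling trick for the upscaler: starting from slightly "colder" noise trades sample diversity for fidelity. Sketched in isolation with assumed shapes (0.997 is the value used in the released GLIDE code, not something this hunk pins down):

import torch

batch_size, channels, resolution = 1, 3, 256
upsample_temp = 0.997  # assumed default from the released GLIDE code

generator = torch.Generator().manual_seed(0)
image = torch.randn((batch_size, channels, resolution, resolution), generator=generator)
image = image * upsample_temp  # lower-variance starting noise for the upscaler
print(image.std())  # slightly below 1.0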
src/diffusers/schedulers/scheduling_ddim.py

(+1 -1; the changed hunk is not expanded in this view)
src/diffusers/schedulers/scheduling_ddpm.py

@@ -56,6 +56,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+        self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
+        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1 - self.alphas_cumprod)
         self.one = np.array(1.0)

         self.set_format(tensor_format=tensor_format)

@@ -131,5 +133,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         return pred_prev_image

+    def forward_step(self, original_image, noise, t):
+        noisy_image = self.sqrt_alphas_cumprod[t] * original_image + self.sqrt_one_minus_alphas_cumprod[t] * noise
+        return noisy_image
+
     def __len__(self):
         return self.timesteps
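The new `forward_step` is the closed-form forward process q(x_t | x_0): x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, which lets training noise a sample to any timestep in one shot instead of iterating t steps. A numpy sketch of the same identity, assuming the linear beta schedule used elsewhere in this commit:

import numpy as np

timesteps = 1000
betas = np.linspace(0.0001, 0.02, timesteps)  # linear schedule
alphas_cumprod = np.cumprod(1.0 - betas)
sqrt_ac = np.sqrt(alphas_cumprod)
sqrt_omac = np.sqrt(1 - alphas_cumprod)

x0 = np.random.randn(4)   # stand-in for a clean image
eps = np.random.randn(4)  # Gaussian noise
t = 500
x_t = sqrt_ac[t] * x0 + sqrt_omac[t] * eps  # what forward_step(x0, eps, t) returns
print(x_t)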
src/diffusers/trainers/training_ddpm.py  0 → 100644

import random

import numpy as np
import torch
import torch.nn.functional as F

import PIL.Image
from accelerate import Accelerator
from datasets import load_dataset
from diffusers import DDPM, DDPMScheduler, UNetModel
from torchvision.transforms import CenterCrop, Compose, Lambda, RandomHorizontalFlip, Resize, ToTensor
from tqdm.auto import tqdm
from transformers import get_linear_schedule_with_warmup


def set_seed(seed):
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


set_seed(0)

accelerator = Accelerator(mixed_precision="fp16")
model = UNetModel(ch=128, ch_mult=(1, 2, 4, 8), resolution=64)
noise_scheduler = DDPMScheduler(timesteps=1000)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

num_epochs = 100
batch_size = 8
gradient_accumulation_steps = 8

augmentations = Compose(
    [
        Resize(64),
        CenterCrop(64),
        RandomHorizontalFlip(),
        ToTensor(),
        Lambda(lambda x: x * 2 - 1),
    ]
)
dataset = load_dataset("huggan/pokemon", split="train")


def transforms(examples):
    images = [augmentations(image.convert("RGB")) for image in examples["image"]]
    return {"input": images}


dataset = dataset.shuffle(seed=0)
dataset.set_transform(transforms)
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=1000,
    num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps,
)

model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, lr_scheduler
)

for epoch in range(num_epochs):
    model.train()
    pbar = tqdm(total=len(train_dataloader), unit="ba")
    pbar.set_description(f"Epoch {epoch}")
    for step, batch in enumerate(train_dataloader):
        clean_images = batch["input"]
        noisy_images = torch.empty_like(clean_images)
        bsz = clean_images.shape[0]
        timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long()
        for idx in range(bsz):
            noise = torch.randn_like(clean_images[0]).to(clean_images.device)
            noisy_images[idx] = noise_scheduler.forward_step(clean_images[idx], noise, timesteps[idx])

        if step % gradient_accumulation_steps == 0:
            with accelerator.no_sync(model):
                output = model(noisy_images, timesteps)
                loss = F.l1_loss(output, clean_images)
                accelerator.backward(loss)
        else:
            output = model(noisy_images, timesteps)
            loss = F.l1_loss(output, clean_images)
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
        pbar.update(1)
        pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"])
    optimizer.step()

    # eval
    model.eval()
    with torch.no_grad():
        pipeline = DDPM(unet=model, noise_scheduler=noise_scheduler)
        generator = torch.Generator()
        generator = generator.manual_seed(0)
        # run pipeline in inference (sample random noise and denoise)
        image = pipeline(generator=generator)

        # process image to PIL
        image_processed = image.cpu().permute(0, 2, 3, 1)
        image_processed = (image_processed + 1.0) * 127.5
        image_processed = image_processed.type(torch.uint8).numpy()
        image_pil = PIL.Image.fromarray(image_processed[0])

        # save image
        pipeline.save_pretrained("./poke-ddpm")
        image_pil.save(f"./poke-ddpm/test_{epoch}.png")
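One caveat worth noting in the loop above: the `no_sync` branch fires when `step % gradient_accumulation_steps == 0`, so only the first batch of each window skips the optimizer and gradients are applied on nearly every step. The more conventional Accelerate accumulation pattern, shown here for comparison as a hypothetical sketch (not what this commit does; it reuses `model`, `optimizer`, `lr_scheduler`, `noise_scheduler`, `accelerator`, and `train_dataloader` exactly as prepared in the script above), delays the synchronized step until the end of each window:

for step, batch in enumerate(train_dataloader):
    clean_images = batch["input"]
    bsz = clean_images.shape[0]
    noisy_images = torch.empty_like(clean_images)
    timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long()
    for idx in range(bsz):
        noise = torch.randn_like(clean_images[idx])
        noisy_images[idx] = noise_scheduler.forward_step(clean_images[idx], noise, timesteps[idx])

    if (step + 1) % gradient_accumulation_steps != 0:
        # accumulate locally without triggering gradient all-reduce
        with accelerator.no_sync(model):
            loss = F.l1_loss(model(noisy_images, timesteps), clean_images)
            accelerator.backward(loss)
    else:
        # sync gradients and apply the optimizer once per window
        loss = F.l1_loss(model(noisy_images, timesteps), clean_images)
        accelerator.backward(loss)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()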