OpenDAS / diffusers / Commits / 11631e81

Commit 11631e81 authored Jun 13, 2022 by Patrick von Platen

merge

parents 13c5a065 b8a67640

Showing 11 changed files with 371 additions and 121 deletions:
README.md                                      +85   -46
src/diffusers/__init__.py                       +1    -1
src/diffusers/configuration_utils.py            +4    -4
src/diffusers/pipelines/__init__.py             +1    -0
src/diffusers/pipelines/conversion_glide.py     +3    -1
src/diffusers/pipelines/pipeline_bddm.py      +113   -36
src/diffusers/pipelines/pipeline_glide.py      +17   -23
src/diffusers/schedulers/glide_ddim.py          +0    -0
src/diffusers/schedulers/scheduling_ddim.py    +17    -9
src/diffusers/schedulers/scheduling_ddpm.py    +14    -1
src/diffusers/trainers/training_ddpm.py       +116    -0

README.md
@@ -164,7 +164,7 @@ image_pil = PIL.Image.fromarray(image_processed[0])
 image_pil.save("test.png")
 ```
 
-**Text to Image generation with Latent Diffusion**
+#### **Text to Image generation with Latent Diffusion**
 
 ```python
 from diffusers import DiffusionPipeline
@@ -184,59 +184,98 @@ image_pil = PIL.Image.fromarray(image_processed[0])
 # save image
 image_pil.save("test.png")
 ```
 
+#### **Text to speech with BDDM**
+
+_Follow the instructions [here](https://pytorch.org/hub/nvidia_deeplearningexamples_tacotron2/) to load the tacotron2 model._
+
+```python
+import torch
+from diffusers import BDDM, DiffusionPipeline
+
+torch_device = "cuda"
+
+# load the BDDM pipeline
+bddm = DiffusionPipeline.from_pretrained("fusing/diffwave-vocoder")
+
+# load tacotron2 to get the mel spectrograms
+tacotron2 = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_tacotron2', model_math='fp16')
+tacotron2 = tacotron2.to(torch_device).eval()
+
+text = "Hello world, I missed you so much."
+
+utils = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_tts_utils')
+sequences, lengths = utils.prepare_input_sequence([text])
+
+# generate mel spectrograms using text
+with torch.no_grad():
+    mel_spec, _, _ = tacotron2.infer(sequences, lengths)
+
+# generate the speech by passing mel spectrograms to BDDM pipeline
+generator = torch.manual_seed(0)
+audio = bddm(mel_spec, generator, torch_device)
+
+# save generated audio
+from scipy.io.wavfile import write as wavwrite
+sampling_rate = 22050
+wavwrite("generated_audio.wav", sampling_rate, audio.squeeze().cpu().numpy())
+```
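
A note on output length: `BDDM.__call__` in this commit upsamples the mel spectrogram by a hop of 256 samples per frame (see `pipeline_bddm.py` below), and the snippet writes the waveform at 22050 Hz, so the clip duration is easy to sanity-check. The frame count here is illustrative:

```python
# duration sanity check: 256 samples per mel frame, written at 22050 Hz
num_mel_frames = 400  # illustrative frame count for a short utterance
duration_s = num_mel_frames * 256 / 22050
print(f"{duration_s:.2f} s of audio")  # 4.64 s
```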
 ## Library structure:
 ```
-├── models
-│   ├── audio
-│   │   └── fastdiff
-│   │       ├── modeling_fastdiff.py
-│   │       ├── README.md
-│   │       └── run_fastdiff.py
-│   ├── __init__.py
-│   └── vision
-│       ├── dalle2
-│       │   ├── modeling_dalle2.py
-│       │   ├── README.md
-│       │   └── run_dalle2.py
-│       ├── ddpm
-│       │   ├── example.py
-│       │   ├── modeling_ddpm.py
-│       │   ├── README.md
-│       │   └── run_ddpm.py
-│       ├── glide
-│       │   ├── modeling_glide.py
-│       │   ├── modeling_vqvae.py.py
-│       │   ├── README.md
-│       │   └── run_glide.py
-│       ├── imagen
-│       │   ├── modeling_dalle2.py
-│       │   ├── README.md
-│       │   └── run_dalle2.py
-│       ├── __init__.py
-│       └── latent_diffusion
-│           ├── modeling_latent_diffusion.py
-│           ├── README.md
-│           └── run_latent_diffusion.py
-├── pyproject.toml
 ├── LICENSE
 ├── Makefile
 ├── README.md
 ├── pyproject.toml
 ├── setup.cfg
 ├── setup.py
 ├── src
-│   └── diffusers
-│       ├── configuration_utils.py
-│       ├── __init__.py
-│       ├── modeling_utils.py
-│       ├── models
-│       │   ├── __init__.py
-│       │   ├── unet_glide.py
-│       │   └── unet.py
-│       ├── pipeline_utils.py
-│       └── schedulers
-│           ├── gaussian_ddpm.py
-│           ├── __init__.py
+│   └── diffusers
+│       ├── __init__.py
+│       ├── configuration_utils.py
+│       ├── dependency_versions_check.py
+│       ├── dependency_versions_table.py
+│       ├── dynamic_modules_utils.py
+│       ├── modeling_utils.py
+│       ├── models
+│       │   ├── __init__.py
+│       │   ├── unet.py
+│       │   ├── unet_glide.py
+│       │   └── unet_ldm.py
+│       ├── pipeline_utils.py
+│       ├── pipelines
+│       │   ├── __init__.py
+│       │   ├── configuration_ldmbert.py
+│       │   ├── conversion_glide.py
+│       │   ├── modeling_vae.py
+│       │   ├── pipeline_bddm.py
+│       │   ├── pipeline_ddim.py
+│       │   ├── pipeline_ddpm.py
+│       │   ├── pipeline_glide.py
+│       │   └── pipeline_latent_diffusion.py
+│       ├── schedulers
+│       │   ├── __init__.py
+│       │   ├── classifier_free_guidance.py
+│       │   ├── scheduling_ddim.py
+│       │   ├── scheduling_ddpm.py
+│       │   ├── scheduling_plms.py
+│       │   └── scheduling_utils.py
+│       ├── testing_utils.py
+│       └── utils
+│           ├── __init__.py
+│           └── logging.py
 ├── tests
-│   └── test_modeling_utils.py
+│   ├── __init__.py
+│   ├── test_modeling_utils.py
+│   └── test_scheduler.py
+└── utils
+    ├── check_config_docstrings.py
+    ├── check_copies.py
+    ├── check_dummies.py
+    ├── check_inits.py
+    ├── check_repo.py
+    ├── check_table.py
+    └── check_tf_ops.py
 ```

src/diffusers/__init__.py
@@ -9,6 +9,6 @@ from .models.unet import UNetModel
 from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from .models.unet_ldm import UNetLDMModel
 from .pipeline_utils import DiffusionPipeline
-from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion, PNDM
+from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion, PNDM, BDDM
 from .schedulers import DDIMScheduler, DDPMScheduler, SchedulerMixin, PNDMScheduler
 from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler

src/diffusers/configuration_utils.py
@@ -225,11 +225,11 @@ class ConfigMixin:
             text = reader.read()
         return json.loads(text)

-    def __eq__(self, other):
-        return self.__dict__ == other.__dict__
+    # def __eq__(self, other):
+    #     return self.__dict__ == other.__dict__

-    def __repr__(self):
-        return f"{self.__class__.__name__} {self.to_json_string()}"
+    # def __repr__(self):
+    #     return f"{self.__class__.__name__} {self.to_json_string()}"

     @property
     def config(self) -> Dict[str, Any]:

src/diffusers/pipelines/__init__.py
@@ -3,3 +3,4 @@ from .pipeline_ddpm import DDPM
 from .pipeline_pndm import PNDM
 from .pipeline_glide import GLIDE
 from .pipeline_latent_diffusion import LatentDiffusion
+from .pipeline_bddm import BDDM

src/diffusers/pipelines/conversion_glide.py
@@ -97,7 +97,9 @@ superres_model = GLIDESuperResUNetModel(
 superres_model.load_state_dict(ups_state_dict, strict=False)

-upscale_scheduler = DDIMScheduler(timesteps=1000, beta_schedule="linear", beta_start=0.0001, beta_end=0.02)
+upscale_scheduler = DDIMScheduler(
+    timesteps=1000, beta_schedule="linear", beta_start=0.0001, beta_end=0.02, tensor_format="pt"
+)

 glide = GLIDE(
     text_unet=text2im_model,

src/diffusers/pipelines/pipeline_bddm.py
@@ -13,11 +13,18 @@
 import math

 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import tqdm

+from ..configuration_utils import ConfigMixin
+from ..modeling_utils import ModelMixin
 from ..pipeline_utils import DiffusionPipeline


 def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_dim_in):
     """
@@ -41,8 +48,7 @@ def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_dim_in):
     _embed = np.log(10000) / (half_dim - 1)
     _embed = torch.exp(torch.arange(half_dim) * -_embed).cuda()
     _embed = diffusion_steps * _embed
-    diffusion_step_embed = torch.cat((torch.sin(_embed),
-                                      torch.cos(_embed)), 1)
+    diffusion_step_embed = torch.cat((torch.sin(_embed), torch.cos(_embed)), 1)

     return diffusion_step_embed
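
For reference, the helper above computes the standard sinusoidal (Transformer-style) diffusion-step embedding. A minimal device-agnostic sketch of the same math, with the `.cuda()` pinning removed and an illustrative function name that is not part of the library:

```python
import math

import torch


def sinusoidal_embedding(diffusion_steps, embed_dim):
    # geometric ladder of frequencies from 1 down to 1/10000
    half_dim = embed_dim // 2
    exponent = torch.arange(half_dim) * -(math.log(10000) / (half_dim - 1))
    freqs = torch.exp(exponent)
    # outer product of steps and frequencies, then sin/cos halves
    args = diffusion_steps.float() * freqs[None, :]
    return torch.cat((torch.sin(args), torch.cos(args)), dim=1)


# e.g. embed steps of shape (batch, 1) into 128 dimensions -> (batch, 128)
print(sinusoidal_embedding(torch.tensor([[0], [10], [999]]), 128).shape)
```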
@@ -62,8 +68,7 @@ class Conv(nn.Module):
     def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1):
         super().__init__()
         self.padding = dilation * (kernel_size - 1) // 2
-        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
-                              dilation=dilation, padding=self.padding)
+        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation, padding=self.padding)
         self.conv = nn.utils.weight_norm(self.conv)
         nn.init.kaiming_normal_(self.conv.weight)
@@ -89,8 +94,7 @@ class ZeroConv1d(nn.Module):
 # every residual block (named residual layer in paper)
 # contains one noncausal dilated conv
 class ResidualBlock(nn.Module):
-    def __init__(self, res_channels, skip_channels, dilation,
-                 diffusion_step_embed_dim_out):
+    def __init__(self, res_channels, skip_channels, dilation, diffusion_step_embed_dim_out):
         super().__init__()
         self.res_channels = res_channels
@@ -98,15 +102,12 @@ class ResidualBlock(nn.Module):
         self.fc_t = nn.Linear(diffusion_step_embed_dim_out, self.res_channels)

         # Dilated conv layer
-        self.dilated_conv_layer = Conv(self.res_channels, 2 * self.res_channels,
-                                       kernel_size=3, dilation=dilation)
+        self.dilated_conv_layer = Conv(self.res_channels, 2 * self.res_channels, kernel_size=3, dilation=dilation)

         # Add mel spectrogram upsampler and conditioner conv1x1 layer
         self.upsample_conv2d = nn.ModuleList()
         for s in [16, 16]:
-            conv_trans2d = nn.ConvTranspose2d(1, 1, (3, 2 * s),
-                                              padding=(1, s // 2),
-                                              stride=(1, s))
+            conv_trans2d = nn.ConvTranspose2d(1, 1, (3, 2 * s), padding=(1, s // 2), stride=(1, s))
             conv_trans2d = nn.utils.weight_norm(conv_trans2d)
             nn.init.kaiming_normal_(conv_trans2d.weight)
             self.upsample_conv2d.append(conv_trans2d)
@@ -152,7 +153,7 @@ class ResidualBlock(nn.Module):
             h += mel_spec

         # Gated-tanh nonlinearity
-        out = torch.tanh(h[:, :self.res_channels, :]) * torch.sigmoid(h[:, self.res_channels:, :])
+        out = torch.tanh(h[:, : self.res_channels, :]) * torch.sigmoid(h[:, self.res_channels :, :])

         # Residual and skip outputs
         res = self.res_conv(out)
@@ -164,10 +165,16 @@
 class ResidualGroup(nn.Module):
-    def __init__(self, res_channels, skip_channels, num_res_layers, dilation_cycle,
-                 diffusion_step_embed_dim_in,
-                 diffusion_step_embed_dim_mid,
-                 diffusion_step_embed_dim_out):
+    def __init__(
+        self,
+        res_channels,
+        skip_channels,
+        num_res_layers,
+        dilation_cycle,
+        diffusion_step_embed_dim_in,
+        diffusion_step_embed_dim_mid,
+        diffusion_step_embed_dim_out,
+    ):
         super().__init__()
         self.num_res_layers = num_res_layers
         self.diffusion_step_embed_dim_in = diffusion_step_embed_dim_in
@@ -180,16 +187,19 @@ class ResidualGroup(nn.Module):
         self.residual_blocks = nn.ModuleList()
         for n in range(self.num_res_layers):
             self.residual_blocks.append(
-                ResidualBlock(res_channels, skip_channels,
-                              dilation=2 ** (n % dilation_cycle),
-                              diffusion_step_embed_dim_out=diffusion_step_embed_dim_out))
+                ResidualBlock(
+                    res_channels,
+                    skip_channels,
+                    dilation=2 ** (n % dilation_cycle),
+                    diffusion_step_embed_dim_out=diffusion_step_embed_dim_out,
+                )
+            )

     def forward(self, input_data):
         x, mel_spectrogram, diffusion_steps = input_data

         # Embed diffusion step t
-        diffusion_step_embed = calc_diffusion_step_embedding(diffusion_steps,
-                                                             self.diffusion_step_embed_dim_in)
+        diffusion_step_embed = calc_diffusion_step_embedding(diffusion_steps, self.diffusion_step_embed_dim_in)
         diffusion_step_embed = swish(self.fc_t1(diffusion_step_embed))
         diffusion_step_embed = swish(self.fc_t2(diffusion_step_embed))
@@ -206,27 +216,52 @@ class ResidualGroup(nn.Module):
         return skip * math.sqrt(1.0 / self.num_res_layers)


-class DiffWave(nn.Module):
-    def __init__(self, in_channels, res_channels, skip_channels, out_channels,
-                 num_res_layers, dilation_cycle,
-                 diffusion_step_embed_dim_in, diffusion_step_embed_dim_mid, diffusion_step_embed_dim_out):
+class DiffWave(ModelMixin, ConfigMixin):
+    def __init__(
+        self,
+        in_channels=1,
+        res_channels=128,
+        skip_channels=128,
+        out_channels=1,
+        num_res_layers=30,
+        dilation_cycle=10,
+        diffusion_step_embed_dim_in=128,
+        diffusion_step_embed_dim_mid=512,
+        diffusion_step_embed_dim_out=512,
+    ):
         super().__init__()

+        # register all init arguments with self.register
+        self.register(
+            in_channels=in_channels,
+            res_channels=res_channels,
+            skip_channels=skip_channels,
+            out_channels=out_channels,
+            num_res_layers=num_res_layers,
+            dilation_cycle=dilation_cycle,
+            diffusion_step_embed_dim_in=diffusion_step_embed_dim_in,
+            diffusion_step_embed_dim_mid=diffusion_step_embed_dim_mid,
+            diffusion_step_embed_dim_out=diffusion_step_embed_dim_out,
+        )
+
         # Initial conv1x1 with relu
         self.init_conv = nn.Sequential(Conv(in_channels, res_channels, kernel_size=1), nn.ReLU(inplace=False))

         # All residual layers
-        self.residual_layer = ResidualGroup(res_channels,
-                                            skip_channels,
-                                            num_res_layers,
-                                            dilation_cycle,
-                                            diffusion_step_embed_dim_in,
-                                            diffusion_step_embed_dim_mid,
-                                            diffusion_step_embed_dim_out)
+        self.residual_layer = ResidualGroup(
+            res_channels,
+            skip_channels,
+            num_res_layers,
+            dilation_cycle,
+            diffusion_step_embed_dim_in,
+            diffusion_step_embed_dim_mid,
+            diffusion_step_embed_dim_out,
+        )

         # Final conv1x1 -> relu -> zeroconv1x1
-        self.final_conv = nn.Sequential(Conv(skip_channels, skip_channels, kernel_size=1),
-                                        nn.ReLU(inplace=False),
-                                        ZeroConv1d(skip_channels, out_channels))
+        self.final_conv = nn.Sequential(
+            Conv(skip_channels, skip_channels, kernel_size=1),
+            nn.ReLU(inplace=False),
+            ZeroConv1d(skip_channels, out_channels),
+        )

     def forward(self, input_data):
         audio, mel_spectrogram, diffusion_steps = input_data
@@ -234,3 +269,45 @@ class DiffWave(nn.Module):
         x = self.init_conv(x).clone()
         x = self.residual_layer((x, mel_spectrogram, diffusion_steps))
         return self.final_conv(x)
+
+
+class BDDM(DiffusionPipeline):
+    def __init__(self, diffwave, noise_scheduler):
+        super().__init__()
+        noise_scheduler = noise_scheduler.set_format("pt")
+        self.register_modules(diffwave=diffwave, noise_scheduler=noise_scheduler)
+
+    @torch.no_grad()
+    def __call__(self, mel_spectrogram, generator, torch_device=None):
+        if torch_device is None:
+            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        self.diffwave.to(torch_device)
+
+        mel_spectrogram = mel_spectrogram.to(torch_device)
+        audio_length = mel_spectrogram.size(-1) * 256
+        audio_size = (1, 1, audio_length)
+
+        # Sample gaussian noise to begin loop
+        audio = torch.normal(0, 1, size=audio_size, generator=generator).to(torch_device)
+
+        timestep_values = self.noise_scheduler.timestep_values
+        num_prediction_steps = len(self.noise_scheduler)
+        for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
+            # 1. predict noise residual
+            ts = (torch.tensor(timestep_values[t]) * torch.ones((1, 1))).to(torch_device)
+            residual = self.diffwave((audio, mel_spectrogram, ts))
+
+            # 2. predict previous mean of audio x_t-1
+            pred_prev_audio = self.noise_scheduler.step(residual, audio, t)
+
+            # 3. optionally sample variance
+            variance = 0
+            if t > 0:
+                noise = torch.normal(0, 1, size=audio_size, generator=generator).to(torch_device)
+                variance = self.noise_scheduler.get_variance(t).sqrt() * noise
+
+            # 4. set current audio to prev_audio: x_t -> x_t-1
+            audio = pred_prev_audio + variance
+
+        return audio
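
Steps 2–4 of the `__call__` loop above are standard DDPM ancestral sampling; in the paper's notation the implemented update is

```latex
x_{t-1} = \underbrace{\mu_\theta(x_t, t)}_{\texttt{pred\_prev\_audio}}
        + \underbrace{\sigma_t z}_{\texttt{variance}},
\qquad z \sim \mathcal{N}(0, I),
```

with the variance term dropped at t = 0.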

src/diffusers/pipelines/pipeline_glide.py
@@ -28,13 +28,7 @@ from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 from transformers.modeling_utils import PreTrainedModel
-from transformers.utils import (
-    ModelOutput,
-    add_start_docstrings,
-    add_start_docstrings_to_model_forward,
-    logging,
-    replace_return_docstrings,
-)
+from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward, logging, replace_return_docstrings

 from ..models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from ..pipeline_utils import DiffusionPipeline
@@ -872,31 +866,31 @@ class GLIDE(DiffusionPipeline):
         # Sample gaussian noise to begin loop
         image = torch.randn(
-            (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
+            (
+                batch_size,
+                self.upscale_unet.in_channels // 2,
+                self.upscale_unet.resolution,
+                self.upscale_unet.resolution,
+            ),
             generator=generator,
         )
-        image = image.to(torch_device)
-
-        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
-        # Ideally, read DDIM paper in-detail understanding
-
-        # Notation (<variable name> -> <name in paper>
-        # - pred_noise_t -> e_theta(x_t, t)
-        # - pred_original_image -> f_theta(x_t, t) or x_0
-        # - std_dev_t -> sigma_t
-        # - eta -> η
-        # - pred_image_direction -> "direction pointing to x_t"
-        # - pred_prev_image -> "x_t-1"
+        image = image.to(torch_device) * upsample_temp

         num_trained_timesteps = self.upscale_noise_scheduler.timesteps
         inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps_upscale)

         # adapt the beta schedule to the number of steps
         # self.upscale_noise_scheduler.rescale_betas(num_inference_steps_upscale)

         for t in tqdm.tqdm(reversed(range(num_inference_steps_upscale)), total=num_inference_steps_upscale):
             # 1. predict noise residual
             with torch.no_grad():
-                time_input = torch.tensor([t] * image.shape[0], device=torch_device)
+                time_input = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
                 model_output = self.upscale_unet(image, time_input, low_res)
                 noise_residual, pred_variance = torch.split(model_output, 3, dim=1)

             # 2. predict previous mean of image x_t-1
             pred_prev_image = self.upscale_noise_scheduler.step(
-                noise_residual, image, t, num_inference_steps_upscale, eta
+                noise_residual, image, t, num_inference_steps_upscale, eta, use_clipped_residual=True
             )

             # 3. optionally sample variance
@@ -910,6 +904,6 @@ class GLIDE(DiffusionPipeline):
             # 4. set current image to prev_image: x_t -> x_t-1
             image = pred_prev_image + variance

-        image = image.permute(0, 2, 3, 1)
+        image = image.clamp(-1, 1).permute(0, 2, 3, 1)

         return image
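
One detail worth flagging in the upscaling loop: the UNet predicts six channels per pixel, the first three being the noise residual and the last three the learned-variance output, which is what `torch.split(model_output, 3, dim=1)` separates. A toy illustration:

```python
import torch

model_output = torch.randn(1, 6, 64, 64)  # stand-in for the 6-channel UNet output
noise_residual, pred_variance = torch.split(model_output, 3, dim=1)
print(noise_residual.shape, pred_variance.shape)  # two (1, 3, 64, 64) tensors
```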

src/diffusers/schedulers/glide_ddim.py

deleted 100644 → 0

src/diffusers/schedulers/scheduling_ddim.py
@@ -26,6 +26,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         beta_start=0.0001,
         beta_end=0.02,
         beta_schedule="linear",
+        trained_betas=None,
+        timestep_values=None,
         clip_predicted_image=True,
         tensor_format="np",
     ):
@@ -37,6 +39,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
             beta_schedule=beta_schedule,
         )
         self.timesteps = int(timesteps)
+        self.timestep_values = timestep_values  # save the fixed timestep values for BDDM
         self.clip_image = clip_predicted_image

         if beta_schedule == "linear":
@@ -69,14 +72,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         #
         # self.register_buffer("log_variance", log_variance.to(torch.float32))

-    def rescale_betas(self, num_timesteps):
-        if self.beta_schedule == "linear":
-            scale = self.timesteps / num_timesteps
-            self.betas = linear_beta_schedule(
-                num_timesteps, beta_start=self.beta_start * scale, beta_end=self.beta_end * scale
-            )
-            self.alphas = 1.0 - self.betas
-            self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+    # def rescale_betas(self, num_timesteps):
+    #     # GLIDE scaling
+    #     if self.beta_schedule == "linear":
+    #         scale = self.timesteps / num_timesteps
+    #         self.betas = linear_beta_schedule(
+    #             num_timesteps, beta_start=self.beta_start * scale, beta_end=self.beta_end * scale
+    #         )
+    #         self.alphas = 1.0 - self.betas
+    #         self.alphas_cumprod = np.cumprod(self.alphas, axis=0)

     def get_alpha(self, time_step):
         return self.alphas[time_step]
@@ -107,7 +111,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         return variance

-    def step(self, residual, image, t, num_inference_steps, eta):
+    def step(self, residual, image, t, num_inference_steps, eta, use_clipped_residual=False):
         # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
         # Ideally, read DDIM paper in-detail understanding
@@ -141,6 +145,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         variance = self.get_variance(t, num_inference_steps)
         std_dev_t = eta * variance ** (0.5)

+        if use_clipped_residual:
+            # the residual is always re-derived from the clipped x_0 in GLIDE
+            residual = (image - alpha_prod_t ** (0.5) * pred_original_image) / beta_prod_t ** (0.5)
+
         # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
         pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * residual
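
The comments point at formulas (12) and (16) of the DDIM paper; written out in the scheduler's own variable names, `step` computes

```latex
x_{t-1} = \sqrt{\bar\alpha_{t-1}}\,\underbrace{\hat{x}_0}_{\texttt{pred\_original\_image}}
        + \underbrace{\sqrt{1 - \bar\alpha_{t-1} - \sigma_t^2}\,\epsilon}_{\texttt{pred\_image\_direction}}
        + \sigma_t z,
```

and the new `use_clipped_residual` branch re-derives the residual from the clipped predicted original image by inverting the forward process, epsilon = (x_t - sqrt(alpha_bar_t) * x_0_hat) / sqrt(1 - alpha_bar_t), exactly as the added code does.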

src/diffusers/schedulers/scheduling_ddpm.py
@@ -26,6 +26,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         beta_start=0.0001,
         beta_end=0.02,
         beta_schedule="linear",
+        trained_betas=None,
+        timestep_values=None,
         variance_type="fixed_small",
         clip_predicted_image=True,
         tensor_format="np",
@@ -36,14 +38,19 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             beta_start=beta_start,
             beta_end=beta_end,
             beta_schedule=beta_schedule,
+            trained_betas=trained_betas,
+            timestep_values=timestep_values,
             variance_type=variance_type,
             clip_predicted_image=clip_predicted_image,
         )
         self.timesteps = int(timesteps)
+        self.timestep_values = timestep_values  # save the fixed timestep values for BDDM
         self.clip_image = clip_predicted_image
         self.variance_type = variance_type

-        if beta_schedule == "linear":
+        if trained_betas is not None:
+            self.betas = np.asarray(trained_betas)
+        elif beta_schedule == "linear":
             self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
         elif beta_schedule == "squaredcos_cap_v2":
             # GLIDE cosine schedule
@@ -56,6 +63,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+        self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
+        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1 - self.alphas_cumprod)
         self.one = np.array(1.0)

         self.set_format(tensor_format=tensor_format)
@@ -131,5 +140,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         return pred_prev_image

+    def forward_step(self, original_image, noise, t):
+        noisy_image = self.sqrt_alphas_cumprod[t] * original_image + self.sqrt_one_minus_alphas_cumprod[t] * noise
+        return noisy_image
+
     def __len__(self):
         return self.timesteps
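
The new `forward_step` is the closed-form DDPM forward (noising) process, using the two factors now cached in `__init__`:

```latex
x_t = \underbrace{\sqrt{\bar\alpha_t}}_{\texttt{sqrt\_alphas\_cumprod}[t]}\, x_0
    + \underbrace{\sqrt{1 - \bar\alpha_t}}_{\texttt{sqrt\_one\_minus\_alphas\_cumprod}[t]}\,\epsilon,
\qquad \epsilon \sim \mathcal{N}(0, I).
```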

src/diffusers/trainers/training_ddpm.py

0 → 100644 (new file)
import random

import numpy as np
import torch
import torch.nn.functional as F

import PIL.Image
from accelerate import Accelerator
from datasets import load_dataset
from diffusers import DDPM, DDPMScheduler, UNetModel
from torchvision.transforms import CenterCrop, Compose, Lambda, RandomHorizontalFlip, Resize, ToTensor
from tqdm.auto import tqdm
from transformers import get_linear_schedule_with_warmup


def set_seed(seed):
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


set_seed(0)

accelerator = Accelerator(mixed_precision="fp16")
model = UNetModel(ch=128, ch_mult=(1, 2, 4, 8), resolution=64)
noise_scheduler = DDPMScheduler(timesteps=1000)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

num_epochs = 100
batch_size = 8
gradient_accumulation_steps = 8

augmentations = Compose(
    [
        Resize(64),
        CenterCrop(64),
        RandomHorizontalFlip(),
        ToTensor(),
        Lambda(lambda x: x * 2 - 1),
    ]
)
dataset = load_dataset("huggan/pokemon", split="train")


def transforms(examples):
    images = [augmentations(image.convert("RGB")) for image in examples["image"]]
    return {"input": images}


dataset = dataset.shuffle(seed=0)
dataset.set_transform(transforms)
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=1000,
    num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps,
)

model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, lr_scheduler
)

for epoch in range(num_epochs):
    model.train()
    pbar = tqdm(total=len(train_dataloader), unit="ba")
    pbar.set_description(f"Epoch {epoch}")
    for step, batch in enumerate(train_dataloader):
        clean_images = batch["input"]
        noisy_images = torch.empty_like(clean_images)
        bsz = clean_images.shape[0]
        timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long()
        for idx in range(bsz):
            noise = torch.randn_like(clean_images[0]).to(clean_images.device)
            noisy_images[idx] = noise_scheduler.forward_step(clean_images[idx], noise, timesteps[idx])

        if step % gradient_accumulation_steps == 0:
            with accelerator.no_sync(model):
                output = model(noisy_images, timesteps)
                loss = F.l1_loss(output, clean_images)
                accelerator.backward(loss)
        else:
            output = model(noisy_images, timesteps)
            loss = F.l1_loss(output, clean_images)
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
        pbar.update(1)
        pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"])
    optimizer.step()

    # eval
    model.eval()
    with torch.no_grad():
        pipeline = DDPM(unet=model, noise_scheduler=noise_scheduler)
        generator = torch.Generator()
        generator = generator.manual_seed(0)
        # run pipeline in inference (sample random noise and denoise)
        image = pipeline(generator=generator)

        # process image to PIL
        image_processed = image.cpu().permute(0, 2, 3, 1)
        image_processed = (image_processed + 1.0) * 127.5
        image_processed = image_processed.type(torch.uint8).numpy()
        image_pil = PIL.Image.fromarray(image_processed[0])

        # save image
        pipeline.save_pretrained("./poke-ddpm")
        image_pil.save(f"./poke-ddpm/test_{epoch}.png")
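
Since the loop saves the pipeline with `save_pretrained("./poke-ddpm")`, sampling from the checkpoint afterwards should mirror the README usage. A sketch, assuming the saved directory can be reloaded with `DiffusionPipeline.from_pretrained` just like the hub checkpoints in the README:

```python
import PIL.Image
import torch

from diffusers import DiffusionPipeline

# reload the pipeline saved by the training script above
pipeline = DiffusionPipeline.from_pretrained("./poke-ddpm")

generator = torch.Generator().manual_seed(0)
image = pipeline(generator=generator)

# same post-processing as the eval block: [-1, 1] floats -> uint8 HWC
image_processed = ((image.cpu().permute(0, 2, 3, 1) + 1.0) * 127.5).type(torch.uint8).numpy()
PIL.Image.fromarray(image_processed[0]).save("sample.png")
```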