Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
433cb3f8
Commit
433cb3f8
authored
Jun 25, 2022
by
Patrick von Platen
Browse files
clean up sde ve more
parent
de810814
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
120 additions
and
58 deletions
+120
-58
README.md
README.md
+24
-0
src/diffusers/__init__.py
src/diffusers/__init__.py
+9
-2
src/diffusers/pipelines/__init__.py
src/diffusers/pipelines/__init__.py
+2
-1
src/diffusers/pipelines/pipeline_score_sde_ve.py
src/diffusers/pipelines/pipeline_score_sde_ve.py
+38
-38
src/diffusers/schedulers/__init__.py
src/diffusers/schedulers/__init__.py
+1
-1
src/diffusers/schedulers/scheduling_sde_ve.py
src/diffusers/schedulers/scheduling_sde_ve.py
+26
-16
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+20
-0
No files found.
README.md
View file @
433cb3f8
...
@@ -226,6 +226,30 @@ image_pil = PIL.Image.fromarray(image_processed[0])
...
@@ -226,6 +226,30 @@ image_pil = PIL.Image.fromarray(image_processed[0])
image_pil
.
save
(
"test.png"
)
image_pil
.
save
(
"test.png"
)
```
```
#### **Example 1024x1024 image generation with SDE VE**
See
[
paper
](
https://arxiv.org/abs/2011.13456
)
for more information on SDE VE.
```
python
from
diffusers
import
DiffusionPipeline
import
torch
import
PIL.Image
torch
.
manual_seed
(
32
)
score_sde_sv
=
DiffusionPipeline
.
from_pretrained
(
"fusing/ffhq_ncsnpp"
)
# Note this might take up to 3 minutes on a GPU
image
=
score_sde_sv
(
num_inference_steps
=
2000
)
image
=
image
.
permute
(
0
,
2
,
3
,
1
).
cpu
().
numpy
()
image
=
np
.
clip
(
image
*
255
,
0
,
255
).
astype
(
np
.
uint8
)
image_pil
=
PIL
.
Image
.
fromarray
(
image
[
0
])
# save image
image_pil
.
save
(
"test.png"
)
```
#### **Text to Image generation with Latent Diffusion**
#### **Text to Image generation with Latent Diffusion**
_Note: To use latent diffusion install transformers from [this branch](https://github.com/patil-suraj/transformers/tree/ldm-bert)._
_Note: To use latent diffusion install transformers from [this branch](https://github.com/patil-suraj/transformers/tree/ldm-bert)._
...
...
src/diffusers/__init__.py
View file @
433cb3f8
...
@@ -9,8 +9,15 @@ __version__ = "0.0.4"
...
@@ -9,8 +9,15 @@ __version__ = "0.0.4"
from
.modeling_utils
import
ModelMixin
from
.modeling_utils
import
ModelMixin
from
.models
import
NCSNpp
,
TemporalUNet
,
UNetLDMModel
,
UNetModel
from
.models
import
NCSNpp
,
TemporalUNet
,
UNetLDMModel
,
UNetModel
from
.pipeline_utils
import
DiffusionPipeline
from
.pipeline_utils
import
DiffusionPipeline
from
.pipelines
import
BDDMPipeline
,
DDIMPipeline
,
DDPMPipeline
,
PNDMPipeline
from
.pipelines
import
BDDMPipeline
,
DDIMPipeline
,
DDPMPipeline
,
PNDMPipeline
,
ScoreSdeVePipeline
from
.schedulers
import
DDIMScheduler
,
DDPMScheduler
,
GradTTSScheduler
,
PNDMScheduler
,
SchedulerMixin
,
VeSdeScheduler
from
.schedulers
import
(
DDIMScheduler
,
DDPMScheduler
,
GradTTSScheduler
,
PNDMScheduler
,
SchedulerMixin
,
ScoreSdeVeScheduler
,
)
if
is_transformers_available
():
if
is_transformers_available
():
...
...
src/diffusers/pipelines/__init__.py
View file @
433cb3f8
...
@@ -3,9 +3,10 @@ from .pipeline_bddm import BDDMPipeline
...
@@ -3,9 +3,10 @@ from .pipeline_bddm import BDDMPipeline
from
.pipeline_ddim
import
DDIMPipeline
from
.pipeline_ddim
import
DDIMPipeline
from
.pipeline_ddpm
import
DDPMPipeline
from
.pipeline_ddpm
import
DDPMPipeline
from
.pipeline_pndm
import
PNDMPipeline
from
.pipeline_pndm
import
PNDMPipeline
from
.pipeline_score_sde_ve
import
ScoreSdeVePipeline
# from .pipeline_score_sde import
NCSNpp
Pipeline
# from .pipeline_score_sde import
ScoreSdeVe
Pipeline
if
is_transformers_available
():
if
is_transformers_available
():
...
...
src/diffusers/pipelines/pipeline_score_sde.py
→
src/diffusers/pipelines/pipeline_score_sde
_ve
.py
View file @
433cb3f8
...
@@ -6,51 +6,44 @@ import PIL
...
@@ -6,51 +6,44 @@ import PIL
from
diffusers
import
DiffusionPipeline
from
diffusers
import
DiffusionPipeline
# from configs.ve import ffhq_ncsnpp_continuous as configs
# TODO(Patrick, Anton, Suraj) - rename `x` to better variable names
# from configs.ve import cifar10_ncsnpp_continuous as configs
# ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
# ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth"
# Note usually we need to restore ema etc...
# ema restored checkpoint used from below
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
False
torch
.
manual_seed
(
0
)
class
NCSNpp
Pipeline
(
DiffusionPipeline
):
class
ScoreSdeVe
Pipeline
(
DiffusionPipeline
):
def
__init__
(
self
,
model
,
scheduler
):
def
__init__
(
self
,
model
,
scheduler
):
super
().
__init__
()
super
().
__init__
()
self
.
register_modules
(
model
=
model
,
scheduler
=
scheduler
)
self
.
register_modules
(
model
=
model
,
scheduler
=
scheduler
)
def
__call__
(
self
,
generator
=
None
):
def
__call__
(
self
,
num_inference_steps
=
2000
,
generator
=
None
):
N
=
self
.
scheduler
.
config
.
N
device
=
torch
.
device
(
"cuda"
)
if
torch
.
cuda
.
is_available
()
else
torch
.
device
(
"cpu"
)
device
=
torch
.
device
(
"cuda"
)
if
torch
.
cuda
.
is_available
()
else
torch
.
device
(
"cpu"
)
img_size
=
self
.
model
.
config
.
image_size
img_size
=
self
.
model
.
config
.
image_size
channels
=
self
.
model
.
config
.
num_channels
channels
=
self
.
model
.
config
.
num_channels
shape
=
(
1
,
channels
,
img_size
,
img_size
)
shape
=
(
1
,
channels
,
img_size
,
img_size
)
model
=
torch
.
nn
.
DataParallel
(
self
.
model
.
to
(
device
)
)
model
=
self
.
model
.
to
(
device
)
centered
=
False
centered
=
False
n_steps
=
1
n_steps
=
1
# Initial sample
x
=
torch
.
randn
(
*
shape
)
*
self
.
scheduler
.
config
.
sigma_max
x
=
torch
.
randn
(
*
shape
)
*
self
.
scheduler
.
config
.
sigma_max
x
=
x
.
to
(
device
)
x
=
x
.
to
(
device
)
for
i
in
range
(
N
):
self
.
scheduler
.
set_timesteps
(
num_inference_steps
)
sigma_t
=
self
.
scheduler
.
get_sigma_t
(
i
)
*
torch
.
ones
(
shape
[
0
],
device
=
device
)
self
.
scheduler
.
set_sigmas
(
num_inference_steps
)
for
i
,
t
in
enumerate
(
self
.
scheduler
.
timesteps
):
sigma_t
=
self
.
scheduler
.
sigmas
[
i
]
*
torch
.
ones
(
shape
[
0
],
device
=
device
)
for
_
in
range
(
n_steps
):
for
_
in
range
(
n_steps
):
with
torch
.
no_grad
():
with
torch
.
no_grad
():
result
=
model
(
x
,
sigma_t
)
result
=
self
.
model
(
x
,
sigma_t
)
x
=
self
.
scheduler
.
step_correct
(
result
,
x
)
x
=
self
.
scheduler
.
step_correct
(
result
,
x
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
result
=
model
(
x
,
sigma_t
)
result
=
model
(
x
,
sigma_t
)
x
,
x_mean
=
self
.
scheduler
.
step_pred
(
result
,
x
,
i
)
x
,
x_mean
=
self
.
scheduler
.
step_pred
(
result
,
x
,
t
)
x
=
x_mean
x
=
x_mean
...
@@ -60,9 +53,16 @@ class NCSNppPipeline(DiffusionPipeline):
...
@@ -60,9 +53,16 @@ class NCSNppPipeline(DiffusionPipeline):
return
x
return
x
pipeline
=
NCSNppPipeline
.
from_pretrained
(
"/home/patrick/ffhq_ncsnpp"
)
# from configs.ve import ffhq_ncsnpp_continuous as configs
x
=
pipeline
()
# from configs.ve import cifar10_ncsnpp_continuous as configs
# ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
# ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth"
# Note usually we need to restore ema etc...
# ema restored checkpoint used from below
# pipeline = ScoreSdeVePipeline.from_pretrained("/home/patrick/ffhq_ncsnpp")
# x = pipeline(num_inference_steps=2)
# for 5 cifar10
# for 5 cifar10
# x_sum = 106071.9922
# x_sum = 106071.9922
...
@@ -73,22 +73,22 @@ x = pipeline()
...
@@ -73,22 +73,22 @@ x = pipeline()
# x_mean = 0.1504
# x_mean = 0.1504
# for N=2 for 1024
# for N=2 for 1024
x_sum
=
3382810112.0
#
x_sum = 3382810112.0
x_mean
=
1075.366455078125
#
x_mean = 1075.366455078125
#
#
def
check_x_sum_x_mean
(
x
,
x_sum
,
x_mean
):
#
def check_x_sum_x_mean(x, x_sum, x_mean):
assert
(
x
.
abs
().
sum
()
-
x_sum
).
abs
().
cpu
().
item
()
<
1e-2
,
f
"sum wrong
{
x
.
abs
().
sum
()
}
"
#
assert (x.abs().sum() - x_sum).abs().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}"
assert
(
x
.
abs
().
mean
()
-
x_mean
).
abs
().
cpu
().
item
()
<
1e-4
,
f
"mean wrong
{
x
.
abs
().
mean
()
}
"
#
assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}"
#
#
check_x_sum_x_mean
(
x
,
x_sum
,
x_mean
)
#
check_x_sum_x_mean(x, x_sum, x_mean)
#
#
def
save_image
(
x
):
#
def save_image(x):
image_processed
=
np
.
clip
(
x
.
permute
(
0
,
2
,
3
,
1
).
cpu
().
numpy
()
*
255
,
0
,
255
).
astype
(
np
.
uint8
)
#
image_processed = np.clip(x.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8)
image_pil
=
PIL
.
Image
.
fromarray
(
image_processed
[
0
])
#
image_pil = PIL.Image.fromarray(image_processed[0])
image_pil
.
save
(
"../images/hey.png"
)
#
image_pil.save("../images/hey.png")
#
#
# save_image(x)
# save_image(x)
src/diffusers/schedulers/__init__.py
View file @
433cb3f8
...
@@ -21,4 +21,4 @@ from .scheduling_ddpm import DDPMScheduler
...
@@ -21,4 +21,4 @@ from .scheduling_ddpm import DDPMScheduler
from
.scheduling_grad_tts
import
GradTTSScheduler
from
.scheduling_grad_tts
import
GradTTSScheduler
from
.scheduling_pndm
import
PNDMScheduler
from
.scheduling_pndm
import
PNDMScheduler
from
.scheduling_utils
import
SchedulerMixin
from
.scheduling_utils
import
SchedulerMixin
from
.scheduling_
ve_sd
e
import
V
eSdeScheduler
from
.scheduling_
sde_v
e
import
Scor
eSde
Ve
Scheduler
src/diffusers/schedulers/scheduling_
ve_sd
e.py
→
src/diffusers/schedulers/scheduling_
sde_v
e.py
View file @
433cb3f8
# Copyright 2022
UC Berkeley Team
and The HuggingFace Team. All rights reserved.
# Copyright 2022
Google Brain
and The HuggingFace Team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -12,7 +12,9 @@
...
@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
# TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean up a bit
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -21,7 +23,7 @@ from ..configuration_utils import ConfigMixin
...
@@ -21,7 +23,7 @@ from ..configuration_utils import ConfigMixin
from
.scheduling_utils
import
SchedulerMixin
from
.scheduling_utils
import
SchedulerMixin
class
V
eSdeScheduler
(
SchedulerMixin
,
ConfigMixin
):
class
Scor
eSde
Ve
Scheduler
(
SchedulerMixin
,
ConfigMixin
):
def
__init__
(
self
,
snr
=
0.15
,
sigma_min
=
0.01
,
sigma_max
=
1348
,
N
=
2
,
sampling_eps
=
1e-5
,
tensor_format
=
"np"
):
def
__init__
(
self
,
snr
=
0.15
,
sigma_min
=
0.01
,
sigma_max
=
1348
,
N
=
2
,
sampling_eps
=
1e-5
,
tensor_format
=
"np"
):
super
().
__init__
()
super
().
__init__
()
self
.
register_to_config
(
self
.
register_to_config
(
...
@@ -31,24 +33,32 @@ class VeSdeScheduler(SchedulerMixin, ConfigMixin):
...
@@ -31,24 +33,32 @@ class VeSdeScheduler(SchedulerMixin, ConfigMixin):
N
=
N
,
N
=
N
,
sampling_eps
=
sampling_eps
,
sampling_eps
=
sampling_eps
,
)
)
# (PVP) - clean up with .config.
self
.
sigma_min
=
sigma_min
self
.
sigma_max
=
sigma_max
self
.
snr
=
snr
self
.
N
=
N
self
.
discrete_sigmas
=
torch
.
exp
(
torch
.
linspace
(
np
.
log
(
self
.
sigma_min
),
np
.
log
(
self
.
sigma_max
),
N
))
self
.
timesteps
=
torch
.
linspace
(
1
,
sampling_eps
,
N
)
def
get_sigma_t
(
self
,
t
):
self
.
sigmas
=
None
return
self
.
sigma_min
*
(
self
.
sigma_max
/
self
.
sigma_min
)
**
self
.
timesteps
[
t
]
self
.
discrete_sigmas
=
None
self
.
timesteps
=
None
def
set_timesteps
(
self
,
num_inference_steps
):
self
.
timesteps
=
torch
.
linspace
(
1
,
self
.
config
.
sampling_eps
,
num_inference_steps
)
def
set_sigmas
(
self
,
num_inference_steps
):
if
self
.
timesteps
is
None
:
self
.
set_timesteps
(
num_inference_steps
)
self
.
discrete_sigmas
=
torch
.
exp
(
torch
.
linspace
(
np
.
log
(
self
.
config
.
sigma_min
),
np
.
log
(
self
.
config
.
sigma_max
),
num_inference_steps
)
)
self
.
sigmas
=
torch
.
tensor
(
[
self
.
config
.
sigma_min
*
(
self
.
config
.
sigma_max
/
self
.
sigma_min
)
**
t
for
t
in
self
.
timesteps
]
)
def
step_pred
(
self
,
result
,
x
,
t
):
def
step_pred
(
self
,
result
,
x
,
t
):
t
=
self
.
timesteps
[
t
]
*
torch
.
ones
(
x
.
shape
[
0
],
device
=
x
.
device
)
t
=
t
*
torch
.
ones
(
x
.
shape
[
0
],
device
=
x
.
device
)
timestep
=
(
t
*
(
2
-
1
)).
long
()
timestep
=
(
t
*
(
self
.
N
-
1
)).
long
()
sigma
=
self
.
discrete_sigmas
.
to
(
t
.
device
)[
timestep
]
sigma
=
self
.
discrete_sigmas
.
to
(
t
.
device
)[
timestep
]
adjacent_sigma
=
torch
.
where
(
adjacent_sigma
=
torch
.
where
(
timestep
==
0
,
torch
.
zeros_like
(
t
),
self
.
discrete_sigmas
[
timestep
-
1
].
to
(
t
.
device
)
timestep
==
0
,
torch
.
zeros_like
(
t
),
self
.
discrete_sigmas
[
timestep
-
1
].
to
(
t
imestep
.
device
)
)
)
f
=
torch
.
zeros_like
(
x
)
f
=
torch
.
zeros_like
(
x
)
G
=
torch
.
sqrt
(
sigma
**
2
-
adjacent_sigma
**
2
)
G
=
torch
.
sqrt
(
sigma
**
2
-
adjacent_sigma
**
2
)
...
@@ -64,7 +74,7 @@ class VeSdeScheduler(SchedulerMixin, ConfigMixin):
...
@@ -64,7 +74,7 @@ class VeSdeScheduler(SchedulerMixin, ConfigMixin):
noise
=
torch
.
randn_like
(
x
)
noise
=
torch
.
randn_like
(
x
)
grad_norm
=
torch
.
norm
(
result
.
reshape
(
result
.
shape
[
0
],
-
1
),
dim
=-
1
).
mean
()
grad_norm
=
torch
.
norm
(
result
.
reshape
(
result
.
shape
[
0
],
-
1
),
dim
=-
1
).
mean
()
noise_norm
=
torch
.
norm
(
noise
.
reshape
(
noise
.
shape
[
0
],
-
1
),
dim
=-
1
).
mean
()
noise_norm
=
torch
.
norm
(
noise
.
reshape
(
noise
.
shape
[
0
],
-
1
),
dim
=-
1
).
mean
()
step_size
=
(
self
.
snr
*
noise_norm
/
grad_norm
)
**
2
*
2
step_size
=
(
self
.
config
.
snr
*
noise_norm
/
grad_norm
)
**
2
*
2
step_size
=
step_size
*
torch
.
ones
(
x
.
shape
[
0
],
device
=
x
.
device
)
step_size
=
step_size
*
torch
.
ones
(
x
.
shape
[
0
],
device
=
x
.
device
)
x_mean
=
x
+
step_size
[:,
None
,
None
,
None
]
*
result
x_mean
=
x
+
step_size
[:,
None
,
None
,
None
]
*
result
...
...
tests/test_modeling_utils.py
View file @
433cb3f8
...
@@ -33,8 +33,11 @@ from diffusers import (
...
@@ -33,8 +33,11 @@ from diffusers import (
GradTTSPipeline
,
GradTTSPipeline
,
GradTTSScheduler
,
GradTTSScheduler
,
LatentDiffusionPipeline
,
LatentDiffusionPipeline
,
NCSNpp
,
PNDMPipeline
,
PNDMPipeline
,
PNDMScheduler
,
PNDMScheduler
,
ScoreSdeVePipeline
,
ScoreSdeVeScheduler
,
UNetGradTTSModel
,
UNetGradTTSModel
,
UNetLDMModel
,
UNetLDMModel
,
UNetModel
,
UNetModel
,
...
@@ -721,6 +724,23 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -721,6 +724,23 @@ class PipelineTesterMixin(unittest.TestCase):
)
)
assert
(
mel_spec
[
0
,
:
3
,
:
3
].
cpu
().
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
assert
(
mel_spec
[
0
,
:
3
,
:
3
].
cpu
().
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
@
slow
def
test_score_sde_ve_pipeline
(
self
):
torch
.
manual_seed
(
0
)
model
=
NCSNpp
.
from_pretrained
(
"fusing/ffhq_ncsnpp"
)
scheduler
=
ScoreSdeVeScheduler
.
from_config
(
"fusing/ffhq_ncsnpp"
)
sde_ve
=
ScoreSdeVePipeline
(
model
=
model
,
scheduler
=
scheduler
)
image
=
sde_ve
(
num_inference_steps
=
2
)
expected_image_sum
=
3382810112.0
expected_image_mean
=
1075.366455078125
assert
(
image
.
abs
().
sum
()
-
expected_image_sum
).
abs
().
cpu
().
item
()
<
1e-2
assert
(
image
.
abs
().
mean
()
-
expected_image_mean
).
abs
().
cpu
().
item
()
<
1e-4
def
test_module_from_pipeline
(
self
):
def
test_module_from_pipeline
(
self
):
model
=
DiffWave
(
num_res_layers
=
4
)
model
=
DiffWave
(
num_res_layers
=
4
)
noise_scheduler
=
DDPMScheduler
(
timesteps
=
12
)
noise_scheduler
=
DDPMScheduler
(
timesteps
=
12
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment