Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
433cb3f8
Commit
433cb3f8
authored
Jun 25, 2022
by
Patrick von Platen
Browse files
clean up sde ve more
parent
de810814
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
120 additions
and
58 deletions
+120
-58
README.md
README.md
+24
-0
src/diffusers/__init__.py
src/diffusers/__init__.py
+9
-2
src/diffusers/pipelines/__init__.py
src/diffusers/pipelines/__init__.py
+2
-1
src/diffusers/pipelines/pipeline_score_sde_ve.py
src/diffusers/pipelines/pipeline_score_sde_ve.py
+38
-38
src/diffusers/schedulers/__init__.py
src/diffusers/schedulers/__init__.py
+1
-1
src/diffusers/schedulers/scheduling_sde_ve.py
src/diffusers/schedulers/scheduling_sde_ve.py
+26
-16
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+20
-0
No files found.
README.md
View file @
433cb3f8
...
@@ -226,6 +226,30 @@ image_pil = PIL.Image.fromarray(image_processed[0])
...
@@ -226,6 +226,30 @@ image_pil = PIL.Image.fromarray(image_processed[0])
image_pil
.
save
(
"test.png"
)
image_pil
.
save
(
"test.png"
)
```
```
#### **Example 1024x1024 image generation with SDE VE**
See
[
paper
](
https://arxiv.org/abs/2011.13456
)
for more information on SDE VE.
```
python
from
diffusers
import
DiffusionPipeline
import
torch
import
PIL.Image
torch
.
manual_seed
(
32
)
score_sde_sv
=
DiffusionPipeline
.
from_pretrained
(
"fusing/ffhq_ncsnpp"
)
# Note this might take up to 3 minutes on a GPU
image
=
score_sde_sv
(
num_inference_steps
=
2000
)
image
=
image
.
permute
(
0
,
2
,
3
,
1
).
cpu
().
numpy
()
image
=
np
.
clip
(
image
*
255
,
0
,
255
).
astype
(
np
.
uint8
)
image_pil
=
PIL
.
Image
.
fromarray
(
image
[
0
])
# save image
image_pil
.
save
(
"test.png"
)
```
#### **Text to Image generation with Latent Diffusion**
#### **Text to Image generation with Latent Diffusion**
_Note: To use latent diffusion install transformers from [this branch](https://github.com/patil-suraj/transformers/tree/ldm-bert)._
_Note: To use latent diffusion install transformers from [this branch](https://github.com/patil-suraj/transformers/tree/ldm-bert)._
...
...
src/diffusers/__init__.py
View file @
433cb3f8
...
@@ -9,8 +9,15 @@ __version__ = "0.0.4"
...
@@ -9,8 +9,15 @@ __version__ = "0.0.4"
from
.modeling_utils
import
ModelMixin
from
.modeling_utils
import
ModelMixin
from
.models
import
NCSNpp
,
TemporalUNet
,
UNetLDMModel
,
UNetModel
from
.models
import
NCSNpp
,
TemporalUNet
,
UNetLDMModel
,
UNetModel
from
.pipeline_utils
import
DiffusionPipeline
from
.pipeline_utils
import
DiffusionPipeline
from
.pipelines
import
BDDMPipeline
,
DDIMPipeline
,
DDPMPipeline
,
PNDMPipeline
from
.pipelines
import
BDDMPipeline
,
DDIMPipeline
,
DDPMPipeline
,
PNDMPipeline
,
ScoreSdeVePipeline
from
.schedulers
import
DDIMScheduler
,
DDPMScheduler
,
GradTTSScheduler
,
PNDMScheduler
,
SchedulerMixin
,
VeSdeScheduler
from
.schedulers
import
(
DDIMScheduler
,
DDPMScheduler
,
GradTTSScheduler
,
PNDMScheduler
,
SchedulerMixin
,
ScoreSdeVeScheduler
,
)
if
is_transformers_available
():
if
is_transformers_available
():
...
...
src/diffusers/pipelines/__init__.py
View file @
433cb3f8
...
@@ -3,9 +3,10 @@ from .pipeline_bddm import BDDMPipeline
...
@@ -3,9 +3,10 @@ from .pipeline_bddm import BDDMPipeline
from
.pipeline_ddim
import
DDIMPipeline
from
.pipeline_ddim
import
DDIMPipeline
from
.pipeline_ddpm
import
DDPMPipeline
from
.pipeline_ddpm
import
DDPMPipeline
from
.pipeline_pndm
import
PNDMPipeline
from
.pipeline_pndm
import
PNDMPipeline
from
.pipeline_score_sde_ve
import
ScoreSdeVePipeline
# from .pipeline_score_sde import
NCSNpp
Pipeline
# from .pipeline_score_sde import
ScoreSdeVe
Pipeline
if
is_transformers_available
():
if
is_transformers_available
():
...
...
src/diffusers/pipelines/pipeline_score_sde.py
→
src/diffusers/pipelines/pipeline_score_sde
_ve
.py
View file @
433cb3f8
...
@@ -6,51 +6,44 @@ import PIL
...
@@ -6,51 +6,44 @@ import PIL
from
diffusers
import
DiffusionPipeline
from
diffusers
import
DiffusionPipeline
# from configs.ve import ffhq_ncsnpp_continuous as configs
# TODO(Patrick, Anton, Suraj) - rename `x` to better variable names
# from configs.ve import cifar10_ncsnpp_continuous as configs
# ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
# ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth"
# Note usually we need to restore ema etc...
# ema restored checkpoint used from below
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
False
torch
.
manual_seed
(
0
)
class
NCSNpp
Pipeline
(
DiffusionPipeline
):
class
ScoreSdeVe
Pipeline
(
DiffusionPipeline
):
def
__init__
(
self
,
model
,
scheduler
):
def
__init__
(
self
,
model
,
scheduler
):
super
().
__init__
()
super
().
__init__
()
self
.
register_modules
(
model
=
model
,
scheduler
=
scheduler
)
self
.
register_modules
(
model
=
model
,
scheduler
=
scheduler
)
def
__call__
(
self
,
generator
=
None
):
def
__call__
(
self
,
num_inference_steps
=
2000
,
generator
=
None
):
N
=
self
.
scheduler
.
config
.
N
device
=
torch
.
device
(
"cuda"
)
if
torch
.
cuda
.
is_available
()
else
torch
.
device
(
"cpu"
)
device
=
torch
.
device
(
"cuda"
)
if
torch
.
cuda
.
is_available
()
else
torch
.
device
(
"cpu"
)
img_size
=
self
.
model
.
config
.
image_size
img_size
=
self
.
model
.
config
.
image_size
channels
=
self
.
model
.
config
.
num_channels
channels
=
self
.
model
.
config
.
num_channels
shape
=
(
1
,
channels
,
img_size
,
img_size
)
shape
=
(
1
,
channels
,
img_size
,
img_size
)
model
=
torch
.
nn
.
DataParallel
(
self
.
model
.
to
(
device
)
)
model
=
self
.
model
.
to
(
device
)
centered
=
False
centered
=
False
n_steps
=
1
n_steps
=
1
# Initial sample
x
=
torch
.
randn
(
*
shape
)
*
self
.
scheduler
.
config
.
sigma_max
x
=
torch
.
randn
(
*
shape
)
*
self
.
scheduler
.
config
.
sigma_max
x
=
x
.
to
(
device
)
x
=
x
.
to
(
device
)
for
i
in
range
(
N
):
self
.
scheduler
.
set_timesteps
(
num_inference_steps
)
sigma_t
=
self
.
scheduler
.
get_sigma_t
(
i
)
*
torch
.
ones
(
shape
[
0
],
device
=
device
)
self
.
scheduler
.
set_sigmas
(
num_inference_steps
)
for
i
,
t
in
enumerate
(
self
.
scheduler
.
timesteps
):
sigma_t
=
self
.
scheduler
.
sigmas
[
i
]
*
torch
.
ones
(
shape
[
0
],
device
=
device
)
for
_
in
range
(
n_steps
):
for
_
in
range
(
n_steps
):
with
torch
.
no_grad
():
with
torch
.
no_grad
():
result
=
model
(
x
,
sigma_t
)
result
=
self
.
model
(
x
,
sigma_t
)
x
=
self
.
scheduler
.
step_correct
(
result
,
x
)
x
=
self
.
scheduler
.
step_correct
(
result
,
x
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
result
=
model
(
x
,
sigma_t
)
result
=
model
(
x
,
sigma_t
)
x
,
x_mean
=
self
.
scheduler
.
step_pred
(
result
,
x
,
i
)
x
,
x_mean
=
self
.
scheduler
.
step_pred
(
result
,
x
,
t
)
x
=
x_mean
x
=
x_mean
...
@@ -60,9 +53,16 @@ class NCSNppPipeline(DiffusionPipeline):
...
@@ -60,9 +53,16 @@ class NCSNppPipeline(DiffusionPipeline):
return
x
return
x
pipeline
=
NCSNppPipeline
.
from_pretrained
(
"/home/patrick/ffhq_ncsnpp"
)
# from configs.ve import ffhq_ncsnpp_continuous as configs
x
=
pipeline
()
# from configs.ve import cifar10_ncsnpp_continuous as configs
# ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
# ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth"
# Note usually we need to restore ema etc...
# ema restored checkpoint used from below
# pipeline = ScoreSdeVePipeline.from_pretrained("/home/patrick/ffhq_ncsnpp")
# x = pipeline(num_inference_steps=2)
# for 5 cifar10
# for 5 cifar10
# x_sum = 106071.9922
# x_sum = 106071.9922
...
@@ -73,22 +73,22 @@ x = pipeline()
...
@@ -73,22 +73,22 @@ x = pipeline()
# x_mean = 0.1504
# x_mean = 0.1504
# for N=2 for 1024
# for N=2 for 1024
x_sum
=
3382810112.0
#
x_sum = 3382810112.0
x_mean
=
1075.366455078125
#
x_mean = 1075.366455078125
#
#
def
check_x_sum_x_mean
(
x
,
x_sum
,
x_mean
):
#
def check_x_sum_x_mean(x, x_sum, x_mean):
assert
(
x
.
abs
().
sum
()
-
x_sum
).
abs
().
cpu
().
item
()
<
1e-2
,
f
"sum wrong
{
x
.
abs
().
sum
()
}
"
#
assert (x.abs().sum() - x_sum).abs().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}"
assert
(
x
.
abs
().
mean
()
-
x_mean
).
abs
().
cpu
().
item
()
<
1e-4
,
f
"mean wrong
{
x
.
abs
().
mean
()
}
"
#
assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}"
#
#
check_x_sum_x_mean
(
x
,
x_sum
,
x_mean
)
#
check_x_sum_x_mean(x, x_sum, x_mean)
#
#
def
save_image
(
x
):
#
def save_image(x):
image_processed
=
np
.
clip
(
x
.
permute
(
0
,
2
,
3
,
1
).
cpu
().
numpy
()
*
255
,
0
,
255
).
astype
(
np
.
uint8
)
#
image_processed = np.clip(x.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8)
image_pil
=
PIL
.
Image
.
fromarray
(
image_processed
[
0
])
#
image_pil = PIL.Image.fromarray(image_processed[0])
image_pil
.
save
(
"../images/hey.png"
)
#
image_pil.save("../images/hey.png")
#
#
# save_image(x)
# save_image(x)
src/diffusers/schedulers/__init__.py
View file @
433cb3f8
...
@@ -21,4 +21,4 @@ from .scheduling_ddpm import DDPMScheduler
...
@@ -21,4 +21,4 @@ from .scheduling_ddpm import DDPMScheduler
from
.scheduling_grad_tts
import
GradTTSScheduler
from
.scheduling_grad_tts
import
GradTTSScheduler
from
.scheduling_pndm
import
PNDMScheduler
from
.scheduling_pndm
import
PNDMScheduler
from
.scheduling_utils
import
SchedulerMixin
from
.scheduling_utils
import
SchedulerMixin
from
.scheduling_
ve_sd
e
import
V
eSdeScheduler
from
.scheduling_
sde_v
e
import
Scor
eSde
Ve
Scheduler
src/diffusers/schedulers/scheduling_
ve_sd
e.py
→
src/diffusers/schedulers/scheduling_
sde_v
e.py
View file @
433cb3f8
# Copyright 2022
UC Berkely Team
and The HuggingFace Team. All rights reserved.
# Copyright 2022
Google Brain
and The HuggingFace Team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -12,7 +12,9 @@
...
@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
# TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -21,7 +23,7 @@ from ..configuration_utils import ConfigMixin
...
@@ -21,7 +23,7 @@ from ..configuration_utils import ConfigMixin
from
.scheduling_utils
import
SchedulerMixin
from
.scheduling_utils
import
SchedulerMixin
class
V
eSdeScheduler
(
SchedulerMixin
,
ConfigMixin
):
class
Scor
eSde
Ve
Scheduler
(
SchedulerMixin
,
ConfigMixin
):
def
__init__
(
self
,
snr
=
0.15
,
sigma_min
=
0.01
,
sigma_max
=
1348
,
N
=
2
,
sampling_eps
=
1e-5
,
tensor_format
=
"np"
):
def
__init__
(
self
,
snr
=
0.15
,
sigma_min
=
0.01
,
sigma_max
=
1348
,
N
=
2
,
sampling_eps
=
1e-5
,
tensor_format
=
"np"
):
super
().
__init__
()
super
().
__init__
()
self
.
register_to_config
(
self
.
register_to_config
(
...
@@ -31,24 +33,32 @@ class VeSdeScheduler(SchedulerMixin, ConfigMixin):
...
@@ -31,24 +33,32 @@ class VeSdeScheduler(SchedulerMixin, ConfigMixin):
N
=
N
,
N
=
N
,
sampling_eps
=
sampling_eps
,
sampling_eps
=
sampling_eps
,
)
)
# (PVP) - clean up with .config.
self
.
sigma_min
=
sigma_min
self
.
sigma_max
=
sigma_max
self
.
snr
=
snr
self
.
N
=
N
self
.
discrete_sigmas
=
torch
.
exp
(
torch
.
linspace
(
np
.
log
(
self
.
sigma_min
),
np
.
log
(
self
.
sigma_max
),
N
))
self
.
timesteps
=
torch
.
linspace
(
1
,
sampling_eps
,
N
)
def
get_sigma_t
(
self
,
t
):
self
.
sigmas
=
None
return
self
.
sigma_min
*
(
self
.
sigma_max
/
self
.
sigma_min
)
**
self
.
timesteps
[
t
]
self
.
discrete_sigmas
=
None
self
.
timesteps
=
None
def
set_timesteps
(
self
,
num_inference_steps
):
self
.
timesteps
=
torch
.
linspace
(
1
,
self
.
config
.
sampling_eps
,
num_inference_steps
)
def
set_sigmas
(
self
,
num_inference_steps
):
if
self
.
timesteps
is
None
:
self
.
set_timesteps
(
num_inference_steps
)
self
.
discrete_sigmas
=
torch
.
exp
(
torch
.
linspace
(
np
.
log
(
self
.
config
.
sigma_min
),
np
.
log
(
self
.
config
.
sigma_max
),
num_inference_steps
)
)
self
.
sigmas
=
torch
.
tensor
(
[
self
.
config
.
sigma_min
*
(
self
.
config
.
sigma_max
/
self
.
sigma_min
)
**
t
for
t
in
self
.
timesteps
]
)
def
step_pred
(
self
,
result
,
x
,
t
):
def
step_pred
(
self
,
result
,
x
,
t
):
t
=
self
.
timesteps
[
t
]
*
torch
.
ones
(
x
.
shape
[
0
],
device
=
x
.
device
)
t
=
t
*
torch
.
ones
(
x
.
shape
[
0
],
device
=
x
.
device
)
timestep
=
(
t
*
(
2
-
1
)).
long
()
timestep
=
(
t
*
(
self
.
N
-
1
)).
long
()
sigma
=
self
.
discrete_sigmas
.
to
(
t
.
device
)[
timestep
]
sigma
=
self
.
discrete_sigmas
.
to
(
t
.
device
)[
timestep
]
adjacent_sigma
=
torch
.
where
(
adjacent_sigma
=
torch
.
where
(
timestep
==
0
,
torch
.
zeros_like
(
t
),
self
.
discrete_sigmas
[
timestep
-
1
].
to
(
t
.
device
)
timestep
==
0
,
torch
.
zeros_like
(
t
),
self
.
discrete_sigmas
[
timestep
-
1
].
to
(
t
imestep
.
device
)
)
)
f
=
torch
.
zeros_like
(
x
)
f
=
torch
.
zeros_like
(
x
)
G
=
torch
.
sqrt
(
sigma
**
2
-
adjacent_sigma
**
2
)
G
=
torch
.
sqrt
(
sigma
**
2
-
adjacent_sigma
**
2
)
...
@@ -64,7 +74,7 @@ class VeSdeScheduler(SchedulerMixin, ConfigMixin):
...
@@ -64,7 +74,7 @@ class VeSdeScheduler(SchedulerMixin, ConfigMixin):
noise
=
torch
.
randn_like
(
x
)
noise
=
torch
.
randn_like
(
x
)
grad_norm
=
torch
.
norm
(
result
.
reshape
(
result
.
shape
[
0
],
-
1
),
dim
=-
1
).
mean
()
grad_norm
=
torch
.
norm
(
result
.
reshape
(
result
.
shape
[
0
],
-
1
),
dim
=-
1
).
mean
()
noise_norm
=
torch
.
norm
(
noise
.
reshape
(
noise
.
shape
[
0
],
-
1
),
dim
=-
1
).
mean
()
noise_norm
=
torch
.
norm
(
noise
.
reshape
(
noise
.
shape
[
0
],
-
1
),
dim
=-
1
).
mean
()
step_size
=
(
self
.
snr
*
noise_norm
/
grad_norm
)
**
2
*
2
step_size
=
(
self
.
config
.
snr
*
noise_norm
/
grad_norm
)
**
2
*
2
step_size
=
step_size
*
torch
.
ones
(
x
.
shape
[
0
],
device
=
x
.
device
)
step_size
=
step_size
*
torch
.
ones
(
x
.
shape
[
0
],
device
=
x
.
device
)
x_mean
=
x
+
step_size
[:,
None
,
None
,
None
]
*
result
x_mean
=
x
+
step_size
[:,
None
,
None
,
None
]
*
result
...
...
tests/test_modeling_utils.py
View file @
433cb3f8
...
@@ -33,8 +33,11 @@ from diffusers import (
...
@@ -33,8 +33,11 @@ from diffusers import (
GradTTSPipeline
,
GradTTSPipeline
,
GradTTSScheduler
,
GradTTSScheduler
,
LatentDiffusionPipeline
,
LatentDiffusionPipeline
,
NCSNpp
,
PNDMPipeline
,
PNDMPipeline
,
PNDMScheduler
,
PNDMScheduler
,
ScoreSdeVePipeline
,
ScoreSdeVeScheduler
,
UNetGradTTSModel
,
UNetGradTTSModel
,
UNetLDMModel
,
UNetLDMModel
,
UNetModel
,
UNetModel
,
...
@@ -721,6 +724,23 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -721,6 +724,23 @@ class PipelineTesterMixin(unittest.TestCase):
)
)
assert
(
mel_spec
[
0
,
:
3
,
:
3
].
cpu
().
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
assert
(
mel_spec
[
0
,
:
3
,
:
3
].
cpu
().
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
@
slow
def
test_score_sde_ve_pipeline
(
self
):
torch
.
manual_seed
(
0
)
model
=
NCSNpp
.
from_pretrained
(
"fusing/ffhq_ncsnpp"
)
scheduler
=
ScoreSdeVeScheduler
.
from_config
(
"fusing/ffhq_ncsnpp"
)
sde_ve
=
ScoreSdeVePipeline
(
model
=
model
,
scheduler
=
scheduler
)
image
=
sde_ve
(
num_inference_steps
=
2
)
expected_image_sum
=
3382810112.0
expected_image_mean
=
1075.366455078125
assert
(
image
.
abs
().
sum
()
-
expected_image_sum
).
abs
().
cpu
().
item
()
<
1e-2
assert
(
image
.
abs
().
mean
()
-
expected_image_mean
).
abs
().
cpu
().
item
()
<
1e-4
def
test_module_from_pipeline
(
self
):
def
test_module_from_pipeline
(
self
):
model
=
DiffWave
(
num_res_layers
=
4
)
model
=
DiffWave
(
num_res_layers
=
4
)
noise_scheduler
=
DDPMScheduler
(
timesteps
=
12
)
noise_scheduler
=
DDPMScheduler
(
timesteps
=
12
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment