Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
de810814
Commit
de810814
authored
Jun 25, 2022
by
Patrick von Platen
Browse files
finish first version sde ve
parent
bc2d586d
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
180 additions
and
18 deletions
+180
-18
src/diffusers/__init__.py
src/diffusers/__init__.py
+1
-1
src/diffusers/models/unet_sde_score_estimation.py
src/diffusers/models/unet_sde_score_estimation.py
+8
-17
src/diffusers/pipelines/__init__.py
src/diffusers/pipelines/__init__.py
+3
-0
src/diffusers/pipelines/pipeline_score_sde.py
src/diffusers/pipelines/pipeline_score_sde.py
+94
-0
src/diffusers/schedulers/__init__.py
src/diffusers/schedulers/__init__.py
+1
-0
src/diffusers/schedulers/scheduling_ve_sde.py
src/diffusers/schedulers/scheduling_ve_sde.py
+73
-0
No files found.
src/diffusers/__init__.py
View file @
de810814
...
@@ -10,7 +10,7 @@ from .modeling_utils import ModelMixin
...
@@ -10,7 +10,7 @@ from .modeling_utils import ModelMixin
from
.models
import
NCSNpp
,
TemporalUNet
,
UNetLDMModel
,
UNetModel
from
.models
import
NCSNpp
,
TemporalUNet
,
UNetLDMModel
,
UNetModel
from
.pipeline_utils
import
DiffusionPipeline
from
.pipeline_utils
import
DiffusionPipeline
from
.pipelines
import
BDDMPipeline
,
DDIMPipeline
,
DDPMPipeline
,
PNDMPipeline
from
.pipelines
import
BDDMPipeline
,
DDIMPipeline
,
DDPMPipeline
,
PNDMPipeline
from
.schedulers
import
DDIMScheduler
,
DDPMScheduler
,
GradTTSScheduler
,
PNDMScheduler
,
SchedulerMixin
from
.schedulers
import
DDIMScheduler
,
DDPMScheduler
,
GradTTSScheduler
,
PNDMScheduler
,
SchedulerMixin
,
VeSdeScheduler
if
is_transformers_available
():
if
is_transformers_available
():
...
...
src/diffusers/models/unet_sde_score_estimation.py
View file @
de810814
...
@@ -15,10 +15,6 @@
...
@@ -15,10 +15,6 @@
# helpers functions
# helpers functions
from
..modeling_utils
import
ModelMixin
from
..configuration_utils
import
ConfigMixin
import
functools
import
functools
import
math
import
math
import
string
import
string
...
@@ -28,16 +24,15 @@ import torch
...
@@ -28,16 +24,15 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
..configuration_utils
import
ConfigMixin
from
..modeling_utils
import
ModelMixin
def
upfirdn2d
(
input
,
kernel
,
up
=
1
,
down
=
1
,
pad
=
(
0
,
0
)):
def
upfirdn2d
(
input
,
kernel
,
up
=
1
,
down
=
1
,
pad
=
(
0
,
0
)):
return
upfirdn2d_native
(
return
upfirdn2d_native
(
input
,
kernel
,
up
,
up
,
down
,
down
,
pad
[
0
],
pad
[
1
],
pad
[
0
],
pad
[
1
])
input
,
kernel
,
up
,
up
,
down
,
down
,
pad
[
0
],
pad
[
1
],
pad
[
0
],
pad
[
1
]
)
def
upfirdn2d_native
(
def
upfirdn2d_native
(
input
,
kernel
,
up_x
,
up_y
,
down_x
,
down_y
,
pad_x0
,
pad_x1
,
pad_y0
,
pad_y1
):
input
,
kernel
,
up_x
,
up_y
,
down_x
,
down_y
,
pad_x0
,
pad_x1
,
pad_y0
,
pad_y1
):
_
,
channel
,
in_h
,
in_w
=
input
.
shape
_
,
channel
,
in_h
,
in_w
=
input
.
shape
input
=
input
.
reshape
(
-
1
,
in_h
,
in_w
,
1
)
input
=
input
.
reshape
(
-
1
,
in_h
,
in_w
,
1
)
...
@@ -48,9 +43,7 @@ def upfirdn2d_native(
...
@@ -48,9 +43,7 @@ def upfirdn2d_native(
out
=
F
.
pad
(
out
,
[
0
,
0
,
0
,
up_x
-
1
,
0
,
0
,
0
,
up_y
-
1
])
out
=
F
.
pad
(
out
,
[
0
,
0
,
0
,
up_x
-
1
,
0
,
0
,
0
,
up_y
-
1
])
out
=
out
.
view
(
-
1
,
in_h
*
up_y
,
in_w
*
up_x
,
minor
)
out
=
out
.
view
(
-
1
,
in_h
*
up_y
,
in_w
*
up_x
,
minor
)
out
=
F
.
pad
(
out
=
F
.
pad
(
out
,
[
0
,
0
,
max
(
pad_x0
,
0
),
max
(
pad_x1
,
0
),
max
(
pad_y0
,
0
),
max
(
pad_y1
,
0
)])
out
,
[
0
,
0
,
max
(
pad_x0
,
0
),
max
(
pad_x1
,
0
),
max
(
pad_y0
,
0
),
max
(
pad_y1
,
0
)]
)
out
=
out
[
out
=
out
[
:,
:,
max
(
-
pad_y0
,
0
)
:
out
.
shape
[
1
]
-
max
(
-
pad_y1
,
0
),
max
(
-
pad_y0
,
0
)
:
out
.
shape
[
1
]
-
max
(
-
pad_y1
,
0
),
...
@@ -59,9 +52,7 @@ def upfirdn2d_native(
...
@@ -59,9 +52,7 @@ def upfirdn2d_native(
]
]
out
=
out
.
permute
(
0
,
3
,
1
,
2
)
out
=
out
.
permute
(
0
,
3
,
1
,
2
)
out
=
out
.
reshape
(
out
=
out
.
reshape
([
-
1
,
1
,
in_h
*
up_y
+
pad_y0
+
pad_y1
,
in_w
*
up_x
+
pad_x0
+
pad_x1
])
[
-
1
,
1
,
in_h
*
up_y
+
pad_y0
+
pad_y1
,
in_w
*
up_x
+
pad_x0
+
pad_x1
]
)
w
=
torch
.
flip
(
kernel
,
[
0
,
1
]).
view
(
1
,
1
,
kernel_h
,
kernel_w
)
w
=
torch
.
flip
(
kernel
,
[
0
,
1
]).
view
(
1
,
1
,
kernel_h
,
kernel_w
)
out
=
F
.
conv2d
(
out
,
w
)
out
=
F
.
conv2d
(
out
,
w
)
out
=
out
.
reshape
(
out
=
out
.
reshape
(
...
@@ -350,7 +341,7 @@ conv3x3 = ddpm_conv3x3
...
@@ -350,7 +341,7 @@ conv3x3 = ddpm_conv3x3
def
_einsum
(
a
,
b
,
c
,
x
,
y
):
def
_einsum
(
a
,
b
,
c
,
x
,
y
):
einsum_str
=
'
{},{}->{}
'
.
format
(
''
.
join
(
a
),
''
.
join
(
b
),
''
.
join
(
c
))
einsum_str
=
"
{},{}->{}
"
.
format
(
""
.
join
(
a
),
""
.
join
(
b
),
""
.
join
(
c
))
return
torch
.
einsum
(
einsum_str
,
x
,
y
)
return
torch
.
einsum
(
einsum_str
,
x
,
y
)
...
...
src/diffusers/pipelines/__init__.py
View file @
de810814
...
@@ -5,6 +5,9 @@ from .pipeline_ddpm import DDPMPipeline
...
@@ -5,6 +5,9 @@ from .pipeline_ddpm import DDPMPipeline
from
.pipeline_pndm
import
PNDMPipeline
from
.pipeline_pndm
import
PNDMPipeline
# from .pipeline_score_sde import NCSNppPipeline
if
is_transformers_available
():
if
is_transformers_available
():
from
.pipeline_glide
import
GlidePipeline
from
.pipeline_glide
import
GlidePipeline
from
.pipeline_latent_diffusion
import
LatentDiffusionPipeline
from
.pipeline_latent_diffusion
import
LatentDiffusionPipeline
...
...
src/diffusers/pipelines/pipeline_score_sde.py
0 → 100755
View file @
de810814
#!/usr/bin/env python3
import
numpy
as
np
import
torch
import
PIL
from
diffusers
import
DiffusionPipeline
# from configs.ve import ffhq_ncsnpp_continuous as configs
# from configs.ve import cifar10_ncsnpp_continuous as configs
# ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
# ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth"
# Note usually we need to restore ema etc...
# ema restored checkpoint used from below
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
False
torch
.
manual_seed
(
0
)
class
NCSNppPipeline
(
DiffusionPipeline
):
def
__init__
(
self
,
model
,
scheduler
):
super
().
__init__
()
self
.
register_modules
(
model
=
model
,
scheduler
=
scheduler
)
def
__call__
(
self
,
generator
=
None
):
N
=
self
.
scheduler
.
config
.
N
device
=
torch
.
device
(
"cuda"
)
if
torch
.
cuda
.
is_available
()
else
torch
.
device
(
"cpu"
)
img_size
=
self
.
model
.
config
.
image_size
channels
=
self
.
model
.
config
.
num_channels
shape
=
(
1
,
channels
,
img_size
,
img_size
)
model
=
torch
.
nn
.
DataParallel
(
self
.
model
.
to
(
device
))
centered
=
False
n_steps
=
1
# Initial sample
x
=
torch
.
randn
(
*
shape
)
*
self
.
scheduler
.
config
.
sigma_max
x
=
x
.
to
(
device
)
for
i
in
range
(
N
):
sigma_t
=
self
.
scheduler
.
get_sigma_t
(
i
)
*
torch
.
ones
(
shape
[
0
],
device
=
device
)
for
_
in
range
(
n_steps
):
with
torch
.
no_grad
():
result
=
model
(
x
,
sigma_t
)
x
=
self
.
scheduler
.
step_correct
(
result
,
x
)
with
torch
.
no_grad
():
result
=
model
(
x
,
sigma_t
)
x
,
x_mean
=
self
.
scheduler
.
step_pred
(
result
,
x
,
i
)
x
=
x_mean
if
centered
:
x
=
(
x
+
1.0
)
/
2.0
return
x
pipeline
=
NCSNppPipeline
.
from_pretrained
(
"/home/patrick/ffhq_ncsnpp"
)
x
=
pipeline
()
# for 5 cifar10
# x_sum = 106071.9922
# x_mean = 34.52864456176758
# for 1000 cifar10
# x_sum = 461.9700
# x_mean = 0.1504
# for N=2 for 1024
x_sum
=
3382810112.0
x_mean
=
1075.366455078125
def
check_x_sum_x_mean
(
x
,
x_sum
,
x_mean
):
assert
(
x
.
abs
().
sum
()
-
x_sum
).
abs
().
cpu
().
item
()
<
1e-2
,
f
"sum wrong
{
x
.
abs
().
sum
()
}
"
assert
(
x
.
abs
().
mean
()
-
x_mean
).
abs
().
cpu
().
item
()
<
1e-4
,
f
"mean wrong
{
x
.
abs
().
mean
()
}
"
check_x_sum_x_mean
(
x
,
x_sum
,
x_mean
)
def
save_image
(
x
):
image_processed
=
np
.
clip
(
x
.
permute
(
0
,
2
,
3
,
1
).
cpu
().
numpy
()
*
255
,
0
,
255
).
astype
(
np
.
uint8
)
image_pil
=
PIL
.
Image
.
fromarray
(
image_processed
[
0
])
image_pil
.
save
(
"../images/hey.png"
)
# save_image(x)
src/diffusers/schedulers/__init__.py
View file @
de810814
...
@@ -21,3 +21,4 @@ from .scheduling_ddpm import DDPMScheduler
...
@@ -21,3 +21,4 @@ from .scheduling_ddpm import DDPMScheduler
from
.scheduling_grad_tts
import
GradTTSScheduler
from
.scheduling_grad_tts
import
GradTTSScheduler
from
.scheduling_pndm
import
PNDMScheduler
from
.scheduling_pndm
import
PNDMScheduler
from
.scheduling_utils
import
SchedulerMixin
from
.scheduling_utils
import
SchedulerMixin
from
.scheduling_ve_sde
import
VeSdeScheduler
src/diffusers/schedulers/scheduling_ve_sde.py
0 → 100644
View file @
de810814
# Copyright 2022 UC Berkely Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import
numpy
as
np
import
torch
from
..configuration_utils
import
ConfigMixin
from
.scheduling_utils
import
SchedulerMixin
class
VeSdeScheduler
(
SchedulerMixin
,
ConfigMixin
):
def
__init__
(
self
,
snr
=
0.15
,
sigma_min
=
0.01
,
sigma_max
=
1348
,
N
=
2
,
sampling_eps
=
1e-5
,
tensor_format
=
"np"
):
super
().
__init__
()
self
.
register_to_config
(
snr
=
snr
,
sigma_min
=
sigma_min
,
sigma_max
=
sigma_max
,
N
=
N
,
sampling_eps
=
sampling_eps
,
)
# (PVP) - clean up with .config.
self
.
sigma_min
=
sigma_min
self
.
sigma_max
=
sigma_max
self
.
snr
=
snr
self
.
N
=
N
self
.
discrete_sigmas
=
torch
.
exp
(
torch
.
linspace
(
np
.
log
(
self
.
sigma_min
),
np
.
log
(
self
.
sigma_max
),
N
))
self
.
timesteps
=
torch
.
linspace
(
1
,
sampling_eps
,
N
)
def
get_sigma_t
(
self
,
t
):
return
self
.
sigma_min
*
(
self
.
sigma_max
/
self
.
sigma_min
)
**
self
.
timesteps
[
t
]
def
step_pred
(
self
,
result
,
x
,
t
):
t
=
self
.
timesteps
[
t
]
*
torch
.
ones
(
x
.
shape
[
0
],
device
=
x
.
device
)
timestep
=
(
t
*
(
self
.
N
-
1
)).
long
()
sigma
=
self
.
discrete_sigmas
.
to
(
t
.
device
)[
timestep
]
adjacent_sigma
=
torch
.
where
(
timestep
==
0
,
torch
.
zeros_like
(
t
),
self
.
discrete_sigmas
[
timestep
-
1
].
to
(
t
.
device
)
)
f
=
torch
.
zeros_like
(
x
)
G
=
torch
.
sqrt
(
sigma
**
2
-
adjacent_sigma
**
2
)
f
=
f
-
G
[:,
None
,
None
,
None
]
**
2
*
result
z
=
torch
.
randn_like
(
x
)
x_mean
=
x
-
f
x
=
x_mean
+
G
[:,
None
,
None
,
None
]
*
z
return
x
,
x_mean
def
step_correct
(
self
,
result
,
x
):
noise
=
torch
.
randn_like
(
x
)
grad_norm
=
torch
.
norm
(
result
.
reshape
(
result
.
shape
[
0
],
-
1
),
dim
=-
1
).
mean
()
noise_norm
=
torch
.
norm
(
noise
.
reshape
(
noise
.
shape
[
0
],
-
1
),
dim
=-
1
).
mean
()
step_size
=
(
self
.
snr
*
noise_norm
/
grad_norm
)
**
2
*
2
step_size
=
step_size
*
torch
.
ones
(
x
.
shape
[
0
],
device
=
x
.
device
)
x_mean
=
x
+
step_size
[:,
None
,
None
,
None
]
*
result
x
=
x_mean
+
torch
.
sqrt
(
step_size
*
2
)[:,
None
,
None
,
None
]
*
noise
return
x
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment