Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
de810814
Commit
de810814
authored
Jun 25, 2022
by
Patrick von Platen
Browse files
finish first version sde ve
parent
bc2d586d
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
180 additions
and
18 deletions
+180
-18
src/diffusers/__init__.py
src/diffusers/__init__.py
+1
-1
src/diffusers/models/unet_sde_score_estimation.py
src/diffusers/models/unet_sde_score_estimation.py
+8
-17
src/diffusers/pipelines/__init__.py
src/diffusers/pipelines/__init__.py
+3
-0
src/diffusers/pipelines/pipeline_score_sde.py
src/diffusers/pipelines/pipeline_score_sde.py
+94
-0
src/diffusers/schedulers/__init__.py
src/diffusers/schedulers/__init__.py
+1
-0
src/diffusers/schedulers/scheduling_ve_sde.py
src/diffusers/schedulers/scheduling_ve_sde.py
+73
-0
No files found.
src/diffusers/__init__.py
View file @
de810814
...
@@ -10,7 +10,7 @@ from .modeling_utils import ModelMixin
...
@@ -10,7 +10,7 @@ from .modeling_utils import ModelMixin
from
.models
import
NCSNpp
,
TemporalUNet
,
UNetLDMModel
,
UNetModel
from
.models
import
NCSNpp
,
TemporalUNet
,
UNetLDMModel
,
UNetModel
from
.pipeline_utils
import
DiffusionPipeline
from
.pipeline_utils
import
DiffusionPipeline
from
.pipelines
import
BDDMPipeline
,
DDIMPipeline
,
DDPMPipeline
,
PNDMPipeline
from
.pipelines
import
BDDMPipeline
,
DDIMPipeline
,
DDPMPipeline
,
PNDMPipeline
from
.schedulers
import
DDIMScheduler
,
DDPMScheduler
,
GradTTSScheduler
,
PNDMScheduler
,
SchedulerMixin
from
.schedulers
import
DDIMScheduler
,
DDPMScheduler
,
GradTTSScheduler
,
PNDMScheduler
,
SchedulerMixin
,
VeSdeScheduler
if
is_transformers_available
():
if
is_transformers_available
():
...
...
src/diffusers/models/unet_sde_score_estimation.py
View file @
de810814
...
@@ -15,10 +15,6 @@
...
@@ -15,10 +15,6 @@
# helpers functions
# helpers functions
from
..modeling_utils
import
ModelMixin
from
..configuration_utils
import
ConfigMixin
import
functools
import
functools
import
math
import
math
import
string
import
string
...
@@ -28,16 +24,15 @@ import torch
...
@@ -28,16 +24,15 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
..configuration_utils
import
ConfigMixin
from
..modeling_utils
import
ModelMixin
def
upfirdn2d
(
input
,
kernel
,
up
=
1
,
down
=
1
,
pad
=
(
0
,
0
)):
def
upfirdn2d
(
input
,
kernel
,
up
=
1
,
down
=
1
,
pad
=
(
0
,
0
)):
return
upfirdn2d_native
(
return
upfirdn2d_native
(
input
,
kernel
,
up
,
up
,
down
,
down
,
pad
[
0
],
pad
[
1
],
pad
[
0
],
pad
[
1
])
input
,
kernel
,
up
,
up
,
down
,
down
,
pad
[
0
],
pad
[
1
],
pad
[
0
],
pad
[
1
]
)
def
upfirdn2d_native
(
def
upfirdn2d_native
(
input
,
kernel
,
up_x
,
up_y
,
down_x
,
down_y
,
pad_x0
,
pad_x1
,
pad_y0
,
pad_y1
):
input
,
kernel
,
up_x
,
up_y
,
down_x
,
down_y
,
pad_x0
,
pad_x1
,
pad_y0
,
pad_y1
):
_
,
channel
,
in_h
,
in_w
=
input
.
shape
_
,
channel
,
in_h
,
in_w
=
input
.
shape
input
=
input
.
reshape
(
-
1
,
in_h
,
in_w
,
1
)
input
=
input
.
reshape
(
-
1
,
in_h
,
in_w
,
1
)
...
@@ -48,9 +43,7 @@ def upfirdn2d_native(
...
@@ -48,9 +43,7 @@ def upfirdn2d_native(
out
=
F
.
pad
(
out
,
[
0
,
0
,
0
,
up_x
-
1
,
0
,
0
,
0
,
up_y
-
1
])
out
=
F
.
pad
(
out
,
[
0
,
0
,
0
,
up_x
-
1
,
0
,
0
,
0
,
up_y
-
1
])
out
=
out
.
view
(
-
1
,
in_h
*
up_y
,
in_w
*
up_x
,
minor
)
out
=
out
.
view
(
-
1
,
in_h
*
up_y
,
in_w
*
up_x
,
minor
)
out
=
F
.
pad
(
out
=
F
.
pad
(
out
,
[
0
,
0
,
max
(
pad_x0
,
0
),
max
(
pad_x1
,
0
),
max
(
pad_y0
,
0
),
max
(
pad_y1
,
0
)])
out
,
[
0
,
0
,
max
(
pad_x0
,
0
),
max
(
pad_x1
,
0
),
max
(
pad_y0
,
0
),
max
(
pad_y1
,
0
)]
)
out
=
out
[
out
=
out
[
:,
:,
max
(
-
pad_y0
,
0
)
:
out
.
shape
[
1
]
-
max
(
-
pad_y1
,
0
),
max
(
-
pad_y0
,
0
)
:
out
.
shape
[
1
]
-
max
(
-
pad_y1
,
0
),
...
@@ -59,9 +52,7 @@ def upfirdn2d_native(
...
@@ -59,9 +52,7 @@ def upfirdn2d_native(
]
]
out
=
out
.
permute
(
0
,
3
,
1
,
2
)
out
=
out
.
permute
(
0
,
3
,
1
,
2
)
out
=
out
.
reshape
(
out
=
out
.
reshape
([
-
1
,
1
,
in_h
*
up_y
+
pad_y0
+
pad_y1
,
in_w
*
up_x
+
pad_x0
+
pad_x1
])
[
-
1
,
1
,
in_h
*
up_y
+
pad_y0
+
pad_y1
,
in_w
*
up_x
+
pad_x0
+
pad_x1
]
)
w
=
torch
.
flip
(
kernel
,
[
0
,
1
]).
view
(
1
,
1
,
kernel_h
,
kernel_w
)
w
=
torch
.
flip
(
kernel
,
[
0
,
1
]).
view
(
1
,
1
,
kernel_h
,
kernel_w
)
out
=
F
.
conv2d
(
out
,
w
)
out
=
F
.
conv2d
(
out
,
w
)
out
=
out
.
reshape
(
out
=
out
.
reshape
(
...
@@ -350,7 +341,7 @@ conv3x3 = ddpm_conv3x3
...
@@ -350,7 +341,7 @@ conv3x3 = ddpm_conv3x3
def
_einsum
(
a
,
b
,
c
,
x
,
y
):
def
_einsum
(
a
,
b
,
c
,
x
,
y
):
einsum_str
=
'
{},{}->{}
'
.
format
(
''
.
join
(
a
),
''
.
join
(
b
),
''
.
join
(
c
))
einsum_str
=
"
{},{}->{}
"
.
format
(
""
.
join
(
a
),
""
.
join
(
b
),
""
.
join
(
c
))
return
torch
.
einsum
(
einsum_str
,
x
,
y
)
return
torch
.
einsum
(
einsum_str
,
x
,
y
)
...
...
src/diffusers/pipelines/__init__.py
View file @
de810814
...
@@ -5,6 +5,9 @@ from .pipeline_ddpm import DDPMPipeline
...
@@ -5,6 +5,9 @@ from .pipeline_ddpm import DDPMPipeline
from
.pipeline_pndm
import
PNDMPipeline
from
.pipeline_pndm
import
PNDMPipeline
# from .pipeline_score_sde import NCSNppPipeline
if
is_transformers_available
():
if
is_transformers_available
():
from
.pipeline_glide
import
GlidePipeline
from
.pipeline_glide
import
GlidePipeline
from
.pipeline_latent_diffusion
import
LatentDiffusionPipeline
from
.pipeline_latent_diffusion
import
LatentDiffusionPipeline
...
...
src/diffusers/pipelines/pipeline_score_sde.py
0 → 100755
View file @
de810814
#!/usr/bin/env python3
import
numpy
as
np
import
torch
import
PIL
from
diffusers
import
DiffusionPipeline
# from configs.ve import ffhq_ncsnpp_continuous as configs
# from configs.ve import cifar10_ncsnpp_continuous as configs
# ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
# ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth"
# Note usually we need to restore ema etc...
# ema restored checkpoint used from below
# Disable TF32 for CUDA matmuls and seed torch's global RNG.
# NOTE(review): presumably done so the hard-coded reference statistics checked
# further down in this script are reproducible — confirm.
torch.backends.cuda.matmul.allow_tf32 = False
torch.manual_seed(0)
class NCSNppPipeline(DiffusionPipeline):
    """Sampling pipeline that alternates corrector and predictor scheduler steps.

    The pipeline holds a score model and a scheduler. Calling it draws an
    initial noise sample scaled by the scheduler's ``sigma_max`` and then, for
    every scheduler step, runs ``n_steps`` corrector updates followed by one
    predictor update.
    """

    def __init__(self, model, scheduler):
        """Register the score model and the scheduler on the pipeline.

        Args:
            model: network called as ``model(x, sigma_t)``; its ``config`` must
                expose ``image_size`` and ``num_channels``.
            scheduler: object providing ``config.N``, ``config.sigma_max``,
                ``get_sigma_t``, ``step_correct`` and ``step_pred``.
        """
        super().__init__()
        self.register_modules(model=model, scheduler=scheduler)

    def __call__(self, generator=None):
        """Generate one sample.

        Args:
            generator: optional ``torch.Generator`` used for the initial noise
                draw. The draw happens before the tensor is moved to the target
                device, so the generator must live on the device ``torch.randn``
                allocates on by default. Noise drawn inside the scheduler's step
                methods is not controlled by this argument.

        Returns:
            The noise-free mean returned by the last predictor step, with shape
            ``(1, num_channels, image_size, image_size)``. If the scheduler is
            configured with zero steps, the initial noise sample is returned.
        """
        N = self.scheduler.config.N
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

        img_size = self.model.config.image_size
        channels = self.model.config.num_channels
        shape = (1, channels, img_size, img_size)

        model = torch.nn.DataParallel(self.model.to(device))

        centered = False
        n_steps = 1

        # Initial sample; honour the caller-supplied generator (None keeps the
        # global RNG, exactly as before).
        x = torch.randn(*shape, generator=generator) * self.scheduler.config.sigma_max
        x = x.to(device)

        # Fallback so that N == 0 returns the initial sample instead of failing
        # on an unbound ``x_mean`` after the loop.
        x_mean = x

        for i in range(N):
            sigma_t = self.scheduler.get_sigma_t(i) * torch.ones(shape[0], device=device)

            # Corrector updates.
            for _ in range(n_steps):
                with torch.no_grad():
                    result = model(x, sigma_t)
                x = self.scheduler.step_correct(result, x)

            # Predictor update.
            with torch.no_grad():
                result = model(x, sigma_t)

            x, x_mean = self.scheduler.step_pred(result, x, i)

        # The final output is the mean of the last predictor step, without its
        # added noise.
        x = x_mean
        if centered:
            x = (x + 1.0) / 2.0

        return x
# NOTE(review): this executes at import time and loads the pipeline from a
# machine-specific absolute path.
pipeline = NCSNppPipeline.from_pretrained("/home/patrick/ffhq_ncsnpp")
x = pipeline()


# Reference values for the absolute sum and absolute mean of the generated
# sample; the active pair below is what check_x_sum_x_mean is called with.
# for 5 cifar10
# x_sum = 106071.9922
# x_mean = 34.52864456176758

# for 1000 cifar10
# x_sum = 461.9700
# x_mean = 0.1504

# for N=2 for 1024
x_sum = 3382810112.0
x_mean = 1075.366455078125
def check_x_sum_x_mean(x, x_sum, x_mean):
    """Check the absolute sum and absolute mean of ``x`` against reference values.

    Args:
        x: tensor to check.
        x_sum: expected value of ``x.abs().sum()``; tolerance is 1e-2.
        x_mean: expected value of ``x.abs().mean()``; tolerance is 1e-4.

    Raises:
        AssertionError: if either statistic is outside its tolerance. The
            exception is raised explicitly (rather than via ``assert``) so the
            check is not stripped when Python runs with ``-O``.
    """
    abs_sum = x.abs().sum()
    # ``not (... < tol)`` keeps the original behaviour of failing on NaN.
    if not (abs_sum - x_sum).abs().cpu().item() < 1e-2:
        raise AssertionError(f"sum wrong {abs_sum}")

    abs_mean = x.abs().mean()
    if not (abs_mean - x_mean).abs().cpu().item() < 1e-4:
        raise AssertionError(f"mean wrong {abs_mean}")
check_x_sum_x_mean
(
x
,
x_sum
,
x_mean
)
def save_image(x, path="../images/hey.png"):
    """Write the first image of a batch to disk.

    Args:
        x: 4-D tensor whose axes are reordered with ``permute(0, 2, 3, 1)``
            before conversion; values are multiplied by 255, clipped to
            ``[0, 255]`` and cast to ``uint8``. Only index 0 of the first axis
            is saved.
        path: destination file path. Defaults to the location that was
            previously hard-coded.
    """
    # ``import PIL`` alone does not load the ``PIL.Image`` submodule, so import
    # it explicitly where it is needed.
    from PIL import Image

    image_processed = np.clip(x.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8)
    image_pil = Image.fromarray(image_processed[0])
    image_pil.save(path)
# save_image(x)
src/diffusers/schedulers/__init__.py
View file @
de810814
...
@@ -21,3 +21,4 @@ from .scheduling_ddpm import DDPMScheduler
...
@@ -21,3 +21,4 @@ from .scheduling_ddpm import DDPMScheduler
from
.scheduling_grad_tts
import
GradTTSScheduler
from
.scheduling_grad_tts
import
GradTTSScheduler
from
.scheduling_pndm
import
PNDMScheduler
from
.scheduling_pndm
import
PNDMScheduler
from
.scheduling_utils
import
SchedulerMixin
from
.scheduling_utils
import
SchedulerMixin
from
.scheduling_ve_sde
import
VeSdeScheduler
src/diffusers/schedulers/scheduling_ve_sde.py
0 → 100644
View file @
de810814
# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import
numpy
as
np
import
torch
from
..configuration_utils
import
ConfigMixin
from
.scheduling_utils
import
SchedulerMixin
class VeSdeScheduler(SchedulerMixin, ConfigMixin):
    """Scheduler with a predictor step and a corrector step over a ladder of noise levels.

    On construction it builds ``N`` noise levels spaced evenly in log-space
    from ``sigma_min`` to ``sigma_max`` (``discrete_sigmas``) and ``N`` time
    values spaced linearly from 1 down to ``sampling_eps`` (``timesteps``).
    """

    def __init__(self, snr=0.15, sigma_min=0.01, sigma_max=1348, N=2, sampling_eps=1e-5, tensor_format="np"):
        """Store the configuration and precompute the sigma and time grids.

        Args:
            snr: factor used by ``step_correct`` to size the corrector step.
            sigma_min: smallest noise level of the ladder.
            sigma_max: largest noise level of the ladder.
            N: number of entries in ``discrete_sigmas`` and ``timesteps``.
            sampling_eps: last (smallest) value of ``timesteps``.
            tensor_format: accepted for interface compatibility; not used by
                the code in this class.
        """
        super().__init__()
        self.register_to_config(
            snr=snr,
            sigma_min=sigma_min,
            sigma_max=sigma_max,
            N=N,
            sampling_eps=sampling_eps,
        )

        # (PVP) - clean up with .config.
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.snr = snr
        self.N = N
        self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N))
        self.timesteps = torch.linspace(1, sampling_eps, N)

    def get_sigma_t(self, t):
        """Return the noise level for step index ``t``.

        Computes ``sigma_min * (sigma_max / sigma_min) ** timesteps[t]``.
        """
        return self.sigma_min * (self.sigma_max / self.sigma_min) ** self.timesteps[t]

    def step_pred(self, result, x, t):
        """Apply one predictor update.

        Args:
            result: model output for ``x``; must broadcast against ``x``.
            x: current sample; the per-sample coefficients are expanded with
                three trailing singleton axes, so ``x`` is expected to be 4-D.
            t: integer index into ``timesteps``.

        Returns:
            Tuple ``(x, x_mean)``: the updated sample with fresh Gaussian noise
            added, and the same update without the noise.
        """
        # One time value per batch element, on the sample's device.
        time = self.timesteps[t] * torch.ones(x.shape[0], device=x.device)
        timestep = (time * (self.N - 1)).long()

        # Move the sigma table to the index tensor's device once, then index it.
        # Indexing the CPU table with a CUDA index tensor is rejected by
        # current PyTorch.
        sigmas = self.discrete_sigmas.to(time.device)
        sigma = sigmas[timestep]
        # At timestep 0 there is no previous level; use 0 instead (the -1 index
        # that is computed there is discarded by ``where``).
        adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(time), sigmas[timestep - 1])

        G = torch.sqrt(sigma**2 - adjacent_sigma**2)
        f = torch.zeros_like(x) - G[:, None, None, None] ** 2 * result

        z = torch.randn_like(x)
        x_mean = x - f
        x = x_mean + G[:, None, None, None] * z
        return x, x_mean

    def step_correct(self, result, x):
        """Apply one corrector update.

        The step size is ``2 * (snr * noise_norm / grad_norm) ** 2``, where the
        norms are per-sample L2 norms of freshly drawn noise and of ``result``,
        each averaged over the batch.

        Args:
            result: model output for ``x``; must broadcast against ``x``.
            x: current sample; expected to be 4-D (see ``step_pred``).

        Returns:
            The updated sample including the added noise.
        """
        noise = torch.randn_like(x)
        grad_norm = torch.norm(result.reshape(result.shape[0], -1), dim=-1).mean()
        noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
        step_size = (self.snr * noise_norm / grad_norm) ** 2 * 2
        # Expand the scalar step size to one value per batch element.
        step_size = step_size * torch.ones(x.shape[0], device=x.device)
        x_mean = x + step_size[:, None, None, None] * result

        x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None] * noise

        return x
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment