Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
dc6d0286
Commit
dc6d0286
authored
Jun 26, 2022
by
Patrick von Platen
Browse files
add vp sampler
parent
d5c527a4
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
121 additions
and
2 deletions
+121
-2
src/diffusers/__init__.py
src/diffusers/__init__.py
+2
-1
src/diffusers/models/unet_sde_score_estimation.py
src/diffusers/models/unet_sde_score_estimation.py
+1
-1
src/diffusers/pipelines/__init__.py
src/diffusers/pipelines/__init__.py
+1
-0
src/diffusers/pipelines/pipeline_score_sde_ve.py
src/diffusers/pipelines/pipeline_score_sde_ve.py
+0
-0
src/diffusers/pipelines/pipeline_score_sde_vp.py
src/diffusers/pipelines/pipeline_score_sde_vp.py
+42
-0
src/diffusers/schedulers/__init__.py
src/diffusers/schedulers/__init__.py
+1
-0
src/diffusers/schedulers/scheduling_sde_vp.py
src/diffusers/schedulers/scheduling_sde_vp.py
+55
-0
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+19
-0
No files found.
src/diffusers/__init__.py
View file @
dc6d0286
...
@@ -9,7 +9,7 @@ __version__ = "0.0.4"
...
@@ -9,7 +9,7 @@ __version__ = "0.0.4"
from
.modeling_utils
import
ModelMixin
from
.modeling_utils
import
ModelMixin
from
.models
import
NCSNpp
,
TemporalUNet
,
UNetLDMModel
,
UNetModel
from
.models
import
NCSNpp
,
TemporalUNet
,
UNetLDMModel
,
UNetModel
from
.pipeline_utils
import
DiffusionPipeline
from
.pipeline_utils
import
DiffusionPipeline
from
.pipelines
import
BDDMPipeline
,
DDIMPipeline
,
DDPMPipeline
,
PNDMPipeline
,
ScoreSdeVePipeline
from
.pipelines
import
BDDMPipeline
,
DDIMPipeline
,
DDPMPipeline
,
PNDMPipeline
,
ScoreSdeVePipeline
,
ScoreSdeVpPipeline
from
.schedulers
import
(
from
.schedulers
import
(
DDIMScheduler
,
DDIMScheduler
,
DDPMScheduler
,
DDPMScheduler
,
...
@@ -17,6 +17,7 @@ from .schedulers import (
...
@@ -17,6 +17,7 @@ from .schedulers import (
PNDMScheduler
,
PNDMScheduler
,
SchedulerMixin
,
SchedulerMixin
,
ScoreSdeVeScheduler
,
ScoreSdeVeScheduler
,
ScoreSdeVpScheduler
,
)
)
...
...
src/diffusers/models/unet_sde_score_estimation.py
View file @
dc6d0286
...
@@ -766,7 +766,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
...
@@ -766,7 +766,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
continuous
=
continuous
,
continuous
=
continuous
,
)
)
self
.
act
=
act
=
get_act
(
nonlinearity
)
self
.
act
=
act
=
get_act
(
nonlinearity
)
#
self.register_buffer('sigmas', torch.tensor(
utils.get_sigmas(config
)))
self
.
register_buffer
(
'sigmas'
,
torch
.
tensor
(
np
.
linspace
(
np
.
log
(
50
),
np
.
log
(
0.01
),
10
)))
self
.
nf
=
nf
self
.
nf
=
nf
self
.
num_res_blocks
=
num_res_blocks
self
.
num_res_blocks
=
num_res_blocks
...
...
src/diffusers/pipelines/__init__.py
View file @
dc6d0286
...
@@ -4,6 +4,7 @@ from .pipeline_ddim import DDIMPipeline
...
@@ -4,6 +4,7 @@ from .pipeline_ddim import DDIMPipeline
from
.pipeline_ddpm
import
DDPMPipeline
from
.pipeline_ddpm
import
DDPMPipeline
from
.pipeline_pndm
import
PNDMPipeline
from
.pipeline_pndm
import
PNDMPipeline
from
.pipeline_score_sde_ve
import
ScoreSdeVePipeline
from
.pipeline_score_sde_ve
import
ScoreSdeVePipeline
from
.pipeline_score_sde_vp
import
ScoreSdeVpPipeline
# from .pipeline_score_sde import ScoreSdeVePipeline
# from .pipeline_score_sde import ScoreSdeVePipeline
...
...
src/diffusers/pipelines/pipeline_score_sde_ve.py
100755 → 100644
View file @
dc6d0286
File mode changed from 100755 to 100644
src/diffusers/pipelines/pipeline_score_sde_vp.py
0 → 100644
View file @
dc6d0286
#!/usr/bin/env python3
import
torch
from
diffusers
import
DiffusionPipeline
# TODO(Patrick, Anton, Suraj) - rename `x` to better variable names
class
ScoreSdeVpPipeline
(
DiffusionPipeline
):
def
__init__
(
self
,
model
,
scheduler
):
super
().
__init__
()
self
.
register_modules
(
model
=
model
,
scheduler
=
scheduler
)
def
__call__
(
self
,
num_inference_steps
=
1000
,
generator
=
None
):
device
=
torch
.
device
(
"cuda"
)
if
torch
.
cuda
.
is_available
()
else
torch
.
device
(
"cpu"
)
img_size
=
self
.
model
.
config
.
image_size
channels
=
self
.
model
.
config
.
num_channels
shape
=
(
1
,
channels
,
img_size
,
img_size
)
beta_min
,
beta_max
=
0.1
,
20
model
=
self
.
model
.
to
(
device
)
x
=
torch
.
randn
(
*
shape
).
to
(
device
)
self
.
scheduler
.
set_timesteps
(
num_inference_steps
)
for
i
,
t
in
enumerate
(
self
.
scheduler
.
timesteps
):
t
=
t
*
torch
.
ones
(
shape
[
0
],
device
=
device
)
sigma_t
=
t
*
(
num_inference_steps
-
1
)
with
torch
.
no_grad
():
result
=
model
(
x
,
sigma_t
)
log_mean_coeff
=
-
0.25
*
t
**
2
*
(
beta_max
-
beta_min
)
-
0.5
*
t
*
beta_min
std
=
torch
.
sqrt
(
1.
-
torch
.
exp
(
2.
*
log_mean_coeff
))
result
=
-
result
/
std
[:,
None
,
None
,
None
]
x
,
x_mean
=
self
.
scheduler
.
step_pred
(
result
,
x
,
t
)
x_mean
=
(
x_mean
+
1.
)
/
2.
return
x_mean
src/diffusers/schedulers/__init__.py
View file @
dc6d0286
...
@@ -22,3 +22,4 @@ from .scheduling_grad_tts import GradTTSScheduler
...
@@ -22,3 +22,4 @@ from .scheduling_grad_tts import GradTTSScheduler
from
.scheduling_pndm
import
PNDMScheduler
from
.scheduling_pndm
import
PNDMScheduler
from
.scheduling_utils
import
SchedulerMixin
from
.scheduling_utils
import
SchedulerMixin
from
.scheduling_sde_ve
import
ScoreSdeVeScheduler
from
.scheduling_sde_ve
import
ScoreSdeVeScheduler
from
.scheduling_sde_vp
import
ScoreSdeVpScheduler
src/diffusers/schedulers/scheduling_sde_vp.py
0 → 100644
View file @
dc6d0286
# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
# TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit
import
numpy
as
np
import
torch
from
..configuration_utils
import
ConfigMixin
from
.scheduling_utils
import
SchedulerMixin
class
ScoreSdeVpScheduler
(
SchedulerMixin
,
ConfigMixin
):
def
__init__
(
self
,
beta_min
=
0.1
,
beta_max
=
20
,
sampling_eps
=
1e-3
,
tensor_format
=
"np"
):
super
().
__init__
()
self
.
register_to_config
(
beta_min
=
beta_min
,
beta_max
=
beta_max
,
sampling_eps
=
sampling_eps
,
)
self
.
sigmas
=
None
self
.
discrete_sigmas
=
None
self
.
timesteps
=
None
def
set_timesteps
(
self
,
num_inference_steps
):
self
.
timesteps
=
torch
.
linspace
(
1
,
self
.
config
.
sampling_eps
,
num_inference_steps
)
def
step_pred
(
self
,
result
,
x
,
t
):
dt
=
-
1.
/
len
(
self
.
timesteps
)
z
=
torch
.
randn_like
(
x
)
beta_t
=
self
.
beta_min
+
t
*
(
self
.
beta_max
-
self
.
beta_min
)
drift
=
-
0.5
*
beta_t
[:,
None
,
None
,
None
]
*
x
diffusion
=
torch
.
sqrt
(
beta_t
)
drift
=
drift
-
diffusion
[:,
None
,
None
,
None
]
**
2
*
result
x_mean
=
x
+
drift
*
dt
x
=
x_mean
+
diffusion
[:,
None
,
None
,
None
]
*
np
.
sqrt
(
-
dt
)
*
z
return
x
,
x_mean
tests/test_modeling_utils.py
View file @
dc6d0286
...
@@ -38,6 +38,8 @@ from diffusers import (
...
@@ -38,6 +38,8 @@ from diffusers import (
PNDMScheduler
,
PNDMScheduler
,
ScoreSdeVePipeline
,
ScoreSdeVePipeline
,
ScoreSdeVeScheduler
,
ScoreSdeVeScheduler
,
ScoreSdeVpPipeline
,
ScoreSdeVpScheduler
,
UNetGradTTSModel
,
UNetGradTTSModel
,
UNetLDMModel
,
UNetLDMModel
,
UNetModel
,
UNetModel
,
...
@@ -741,6 +743,23 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -741,6 +743,23 @@ class PipelineTesterMixin(unittest.TestCase):
assert
(
image
.
abs
().
sum
()
-
expected_image_sum
).
abs
().
cpu
().
item
()
<
1e-2
assert
(
image
.
abs
().
sum
()
-
expected_image_sum
).
abs
().
cpu
().
item
()
<
1e-2
assert
(
image
.
abs
().
mean
()
-
expected_image_mean
).
abs
().
cpu
().
item
()
<
1e-4
assert
(
image
.
abs
().
mean
()
-
expected_image_mean
).
abs
().
cpu
().
item
()
<
1e-4
@
slow
def
test_score_sde_vp_pipeline
(
self
):
model
=
NCSNpp
.
from_pretrained
(
"/home/patrick/cifar10-ddpmpp-vp"
)
scheduler
=
ScoreSdeVpScheduler
()
sde_vp
=
ScoreSdeVpPipeline
(
model
=
model
,
scheduler
=
scheduler
)
torch
.
manual_seed
(
0
)
image
=
sde_vp
(
num_inference_steps
=
10
)
expected_image_sum
=
4183.2012
expected_image_mean
=
1.3617
assert
(
image
.
abs
().
sum
()
-
expected_image_sum
).
abs
().
cpu
().
item
()
<
1e-2
assert
(
image
.
abs
().
mean
()
-
expected_image_mean
).
abs
().
cpu
().
item
()
<
1e-4
def
test_module_from_pipeline
(
self
):
def
test_module_from_pipeline
(
self
):
model
=
DiffWave
(
num_res_layers
=
4
)
model
=
DiffWave
(
num_res_layers
=
4
)
noise_scheduler
=
DDPMScheduler
(
timesteps
=
12
)
noise_scheduler
=
DDPMScheduler
(
timesteps
=
12
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment