renzhc / diffusers_dcu / Commits

Commit dc6324d4, authored Jun 09, 2022 by anton-l

    end-to-end glide pipeline with DDIM scheduler for upscaling

parent ff89f808
Showing 8 changed files with 238 additions and 87 deletions (+238 -87)
models/vision/glide/convert_weights.py    +5   -5
models/vision/glide/modeling_glide.py     +86  -28
models/vision/glide/run_glide.py          +0   -1
src/diffusers/__init__.py                 +1   -0
src/diffusers/models/unet_glide.py        +135 -38
src/diffusers/pipeline_utils.py           +1   -0
src/diffusers/schedulers/__init__.py      +1   -0
src/diffusers/schedulers/glide_ddim.py    +9   -15
models/vision/glide/convert_weights.py

 import torch
 from torch import nn
-from diffusers import ClassifierFreeGuidanceScheduler, CLIPTextModel, GLIDETextToImageUNetModel, GLIDESuperResUNetModel
+from diffusers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler, CLIPTextModel, GLIDETextToImageUNetModel, GLIDESuperResUNetModel
 from modeling_glide import GLIDE
 from transformers import CLIPTextConfig, GPT2Tokenizer

@@ -76,7 +76,7 @@ text_scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="
 ### Convert the Super-Resolution UNet
 # wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/upsample.pt
-state_dict = torch.load("upsample.pt", map_location="cpu")
+ups_state_dict = torch.load("upsample.pt", map_location="cpu")
 superres_model = GLIDESuperResUNetModel(
     in_channels=6,

@@ -93,12 +93,12 @@ superres_model = GLIDESuperResUNetModel(
     resblock_updown=True,
 )
-superres_model.load_state_dict(state_dict)
+superres_model.load_state_dict(ups_state_dict, strict=False)
-upscale_scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="squaredcos_cap_v2")
+upscale_scheduler = GlideDDIMScheduler(timesteps=1000, beta_schedule="linear")
 glide = GLIDE(
     text_unet=text2im_model, text_noise_scheduler=text_scheduler, text_encoder=model, tokenizer=tokenizer,
-    upscale_unet=superres_model, upscale_noise_scheduler=scheduler
+    upscale_unet=superres_model, upscale_noise_scheduler=upscale_scheduler
 )
 glide.save_pretrained("./glide-base")
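Note: strict=False tolerates key mismatches between the OpenAI upsample.pt checkpoint and the new module layout. Assuming the save_pretrained/from_pretrained round trip registered in pipeline_utils.py, the converted directory can be reloaded locally (a sketch; "./glide-base" is the directory written above):

    # Sketch: reload the pipeline serialized by glide.save_pretrained()
    from modeling_glide import GLIDE

    pipeline = GLIDE.from_pretrained("./glide-base")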
models/vision/glide/modeling_glide.py

@@ -18,7 +18,7 @@ import numpy as np
 import torch
 import tqdm
-from diffusers import ClassifierFreeGuidanceScheduler, CLIPTextModel, DiffusionPipeline, GLIDETextToImageUNetModel, GLIDESuperResUNetModel
+from diffusers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler, CLIPTextModel, DiffusionPipeline, GLIDETextToImageUNetModel, GLIDESuperResUNetModel
 from transformers import GPT2Tokenizer
@@ -41,17 +41,20 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
 class GLIDE(DiffusionPipeline):
-    def __init__(self, unet: GLIDETextToImageUNetModel, noise_scheduler: ClassifierFreeGuidanceScheduler,
-                 text_encoder: CLIPTextModel, tokenizer: GPT2Tokenizer):
+    def __init__(self, text_unet: GLIDETextToImageUNetModel, text_noise_scheduler: ClassifierFreeGuidanceScheduler,
+                 text_encoder: CLIPTextModel, tokenizer: GPT2Tokenizer,
+                 upscale_unet: GLIDESuperResUNetModel, upscale_noise_scheduler: GlideDDIMScheduler):
         super().__init__()
-        self.register_modules(unet=unet, noise_scheduler=noise_scheduler, text_encoder=text_encoder, tokenizer=tokenizer)
+        self.register_modules(text_unet=text_unet, text_noise_scheduler=text_noise_scheduler, text_encoder=text_encoder,
+                              tokenizer=tokenizer, upscale_unet=upscale_unet, upscale_noise_scheduler=upscale_noise_scheduler)

-    def q_posterior_mean_variance(self, x_start, x_t, t):
+    def q_posterior_mean_variance(self, scheduler, x_start, x_t, t):
         """
         Compute the mean and variance of the diffusion posterior:
@@ -60,12 +63,12 @@ class GLIDE(DiffusionPipeline):
         """
         assert x_start.shape == x_t.shape
         posterior_mean = (
-            _extract_into_tensor(self.noise_scheduler.posterior_mean_coef1, t, x_t.shape) * x_start
-            + _extract_into_tensor(self.noise_scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
+            _extract_into_tensor(scheduler.posterior_mean_coef1, t, x_t.shape) * x_start
+            + _extract_into_tensor(scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
         )
-        posterior_variance = _extract_into_tensor(self.noise_scheduler.posterior_variance, t, x_t.shape)
+        posterior_variance = _extract_into_tensor(scheduler.posterior_variance, t, x_t.shape)
         posterior_log_variance_clipped = _extract_into_tensor(
-            self.noise_scheduler.posterior_log_variance_clipped, t, x_t.shape
+            scheduler.posterior_log_variance_clipped, t, x_t.shape
         )
         assert (
             posterior_mean.shape[0]
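Note: the scheduler attributes used above are the standard DDPM posterior coefficients. Writing \alpha_t = 1 - \beta_t and \bar{\alpha}_t = \prod_{s \le t} \alpha_s, the quantities computed here are

    \tilde{\mu}_t(x_t, x_0) = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1 - \bar{\alpha}_t}\, x_0 + \frac{\sqrt{\alpha_t}\,(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t}\, x_t, \qquad \tilde{\beta}_t = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t}\,\beta_t

with posterior_mean_coef1 and posterior_mean_coef2 the coefficients of x_0 and x_t respectively. Passing the scheduler in explicitly lets the same formula serve both the text2image and the upscaling stage.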
@@ -75,7 +78,7 @@ class GLIDE(DiffusionPipeline):
         )
         return posterior_mean, posterior_variance, posterior_log_variance_clipped

-    def p_mean_variance(self, model, x, t, transformer_out, clip_denoised=True, model_kwargs=None):
+    def p_mean_variance(self, model, scheduler, x, t, transformer_out=None, low_res=None, clip_denoised=True):
         """
         Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
         the initial x, x_0.
@@ -93,51 +96,60 @@ class GLIDE(DiffusionPipeline):
         - 'log_variance': the log of 'variance'.
         - 'pred_xstart': the prediction for x_0.
         """
-        if model_kwargs is None:
-            model_kwargs = {}
         B, C = x.shape[:2]
         assert t.shape == (B,)
+        if transformer_out is None:
+            # super-res model
+            model_output = model(x, t, low_res)
+        else:
+            # text2image model
+            model_output = model(x, t, transformer_out)
         assert model_output.shape == (B, C * 2, *x.shape[2:])
         model_output, model_var_values = torch.split(model_output, C, dim=1)
-        min_log = _extract_into_tensor(self.noise_scheduler.posterior_log_variance_clipped, t, x.shape)
-        max_log = _extract_into_tensor(np.log(self.noise_scheduler.betas), t, x.shape)
+        min_log = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x.shape)
+        max_log = _extract_into_tensor(np.log(scheduler.betas), t, x.shape)
         # The model_var_values is [-1, 1] for [min_var, max_var].
         frac = (model_var_values + 1) / 2
         model_log_variance = frac * max_log + (1 - frac) * min_log
         model_variance = torch.exp(model_log_variance)
-        pred_xstart = self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
+        pred_xstart = self._predict_xstart_from_eps(scheduler, x_t=x, t=t, eps=model_output)
         if clip_denoised:
             pred_xstart = pred_xstart.clamp(-1, 1)
-        model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
+        model_mean, _, _ = self.q_posterior_mean_variance(scheduler, x_start=pred_xstart, x_t=x, t=t)
         assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
         return model_mean, model_variance, model_log_variance, pred_xstart

-    def _predict_xstart_from_eps(self, x_t, t, eps):
+    def _predict_xstart_from_eps(self, scheduler, x_t, t, eps):
         assert x_t.shape == eps.shape
         return (
-            _extract_into_tensor(self.noise_scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
-            - _extract_into_tensor(self.noise_scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
+            _extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+            - _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
         )

+    def _predict_eps_from_xstart(self, scheduler, x_t, t, pred_xstart):
+        return (
+            _extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
+        ) / _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+
     @torch.no_grad()
     def __call__(self, prompt, generator=None, torch_device=None):
         torch_device = "cuda" if torch.cuda.is_available() else "cpu"

-        self.unet.to(torch_device)
+        self.text_unet.to(torch_device)
         self.text_encoder.to(torch_device)
+        self.upscale_unet.to(torch_device)

         # Create a classifier-free guidance sampling function
         guidance_scale = 3.0

-        def model_fn(x_t, ts, transformer_out, **kwargs):
+        def text_model_fn(x_t, ts, transformer_out, **kwargs):
             half = x_t[: len(x_t) // 2]
             combined = torch.cat([half, half], dim=0)
-            model_out = self.unet(combined, ts, transformer_out, **kwargs)
+            model_out = self.text_unet(combined, ts, transformer_out, **kwargs)
             eps, rest = model_out[:, :3], model_out[:, 3:]
             cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
             half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
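Note: text_model_fn implements classifier-free guidance as in the GLIDE paper: the batch holds a conditional and an unconditional copy of the same sample, and the two noise predictions are extrapolated with guidance scale s = 3.0,

    \hat{\epsilon} = \epsilon_\theta(x_t) + s\,\big(\epsilon_\theta(x_t, c) - \epsilon_\theta(x_t)\big)

where c is the text prompt and s > 1 trades sample diversity for fidelity to the prompt.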
@@ -146,8 +158,8 @@ class GLIDE(DiffusionPipeline):
         # 1. Sample gaussian noise
         batch_size = 2  # second image is empty for classifier-free guidance
-        image = self.noise_scheduler.sample_noise(
-            (batch_size, self.unet.in_channels, 64, 64), device=torch_device, generator=generator
+        image = self.text_noise_scheduler.sample_noise(
+            (batch_size, self.text_unet.in_channels, 64, 64), device=torch_device, generator=generator
         )

         # 2. Encode tokens
@@ -157,14 +169,60 @@ class GLIDE(DiffusionPipeline):
         attention_mask = inputs["attention_mask"].to(torch_device)
         transformer_out = self.text_encoder(input_ids, attention_mask).last_hidden_state

-        num_timesteps = len(self.noise_scheduler)
+        # 3. Run the text2image generation step
+        num_timesteps = len(self.text_noise_scheduler)
         for i in tqdm.tqdm(reversed(range(num_timesteps)), total=num_timesteps):
             t = torch.tensor([i] * image.shape[0], device=torch_device)
-            mean, variance, log_variance, pred_xstart = self.p_mean_variance(model_fn, image, t, transformer_out)
-            noise = self.noise_scheduler.sample_noise(image.shape, device=torch_device, generator=generator)
+            mean, variance, log_variance, pred_xstart = self.p_mean_variance(
+                text_model_fn, self.text_noise_scheduler, image, t, transformer_out=transformer_out
+            )
+            noise = self.text_noise_scheduler.sample_noise(image.shape, device=torch_device, generator=generator)
             nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1)))  # no noise when t == 0
             image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise

+        # 4. Run the upscaling step
+        batch_size = 1
+        image = image[:1]
+        low_res = ((image + 1) * 127.5).round() / 127.5 - 1
+        eta = 0.0
+
+        # Tune this parameter to control the sharpness of 256x256 images.
+        # A value of 1.0 is sharper, but sometimes results in grainy artifacts.
+        upsample_temp = 0.997
+
+        image = (
+            self.upscale_noise_scheduler.sample_noise(
+                (batch_size, 3, 256, 256), device=torch_device, generator=generator
+            )
+            * upsample_temp
+        )
+
+        num_timesteps = len(self.upscale_noise_scheduler)
+        for t in tqdm.tqdm(
+            reversed(range(len(self.upscale_noise_scheduler))), total=len(self.upscale_noise_scheduler)
+        ):
+            # i) define coefficients for time step t
+            clipped_image_coeff = 1 / torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t))
+            clipped_noise_coeff = torch.sqrt(1 / self.upscale_noise_scheduler.get_alpha_prod(t) - 1)
+            image_coeff = (
+                (1 - self.upscale_noise_scheduler.get_alpha_prod(t - 1))
+                * torch.sqrt(self.upscale_noise_scheduler.get_alpha(t))
+                / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
+            )
+            clipped_coeff = (
+                torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t - 1))
+                * self.upscale_noise_scheduler.get_beta(t)
+                / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
+            )
+
+            # ii) predict noise residual
+            time_input = torch.tensor([t] * image.shape[0], device=torch_device)
+            model_output = self.upscale_unet(image, time_input, low_res)
+            noise_residual, pred_variance = torch.split(model_output, 3, dim=1)
+
+            # iii) compute predicted image from residual
+            # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
+            pred_mean = clipped_image_coeff * image - clipped_noise_coeff * noise_residual
+            pred_mean = torch.clamp(pred_mean, -1, 1)
+            prev_image = clipped_coeff * pred_mean + image_coeff * image
+
+            # iv) sample variance
+            prev_variance = self.upscale_noise_scheduler.sample_variance(
+                t, prev_image.shape, device=torch_device, generator=generator
+            )
+
+            # v) sample x_{t-1} ~ N(prev_image, prev_variance)
+            sampled_prev_image = prev_image + prev_variance
+            image = sampled_prev_image
+
         image = image[0].permute(1, 2, 0)
         return image
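Note: with eta defined but unused, this loop is the stochastic (eta = 1) DDPM limit of DDIM rather than a deterministic DDIM update, despite the scheduler's name. Step ii) inverts the noise prediction into an x_0 estimate, and step iii) forms the posterior mean from the same coefficients as q_posterior_mean_variance above:

    \hat{x}_0 = \frac{1}{\sqrt{\bar{\alpha}_t}}\, x_t - \sqrt{\tfrac{1}{\bar{\alpha}_t} - 1}\;\epsilon_\theta, \qquad \mu_{t-1} = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1 - \bar{\alpha}_t}\,\mathrm{clamp}(\hat{x}_0, -1, 1) + \frac{\sqrt{\alpha_t}\,(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t}\, x_t

i.e. clipped_image_coeff and clipped_noise_coeff implement the first formula, while clipped_coeff and image_coeff implement the second.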
models/vision/glide/run_glide.py

@@ -9,7 +9,6 @@ matplotlib.rcParams['interactive'] = True
 generator = torch.Generator()
 generator = generator.manual_seed(0)

 # 1. Load models
 pipeline = GLIDE.from_pretrained("fusing/glide-base")

 img = pipeline("a pencil sketch of a corgi", generator)
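Note: the pipeline returns a single HWC tensor in [-1, 1] (see the final permute in __call__). A minimal sketch for writing it to disk under that convention (the filename is illustrative; the script itself displays the result with matplotlib):

    # Sketch: convert the [-1, 1] HWC tensor returned by the pipeline to uint8.
    import numpy as np
    from PIL import Image

    img_uint8 = ((img.clamp(-1, 1) + 1) * 127.5).round().cpu().numpy().astype(np.uint8)
    Image.fromarray(img_uint8).save("corgi.png")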
src/diffusers/__init__.py

@@ -13,3 +13,4 @@ from .models.vqvae import VQModel
 from .pipeline_utils import DiffusionPipeline
 from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
 from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
+from .schedulers.glide_ddim import GlideDDIMScheduler
src/diffusers/models/unet_glide.py

@@ -419,11 +419,11 @@ class GLIDEUNetModel(ModelMixin, ConfigMixin):
     def __init__(
         self,
-        in_channels,
-        model_channels,
-        out_channels,
-        num_res_blocks,
-        attention_resolutions,
+        in_channels=3,
+        model_channels=192,
+        out_channels=6,
+        num_res_blocks=3,
+        attention_resolutions=(2, 4, 8),
         dropout=0,
         channel_mult=(1, 2, 4, 8),
         conv_resample=True,
@@ -438,24 +438,6 @@ class GLIDEUNetModel(ModelMixin, ConfigMixin):
         transformer_dim=None,
     ):
         super().__init__()
-        self.register(
-            in_channels=in_channels,
-            model_channels=model_channels,
-            out_channels=out_channels,
-            num_res_blocks=num_res_blocks,
-            attention_resolutions=attention_resolutions,
-            dropout=dropout,
-            channel_mult=channel_mult,
-            conv_resample=conv_resample,
-            dims=dims,
-            use_checkpoint=use_checkpoint,
-            use_fp16=use_fp16,
-            num_heads=num_heads,
-            num_head_channels=num_head_channels,
-            num_heads_upsample=num_heads_upsample,
-            use_scale_shift_norm=use_scale_shift_norm,
-            resblock_updown=resblock_updown,
-        )
         if num_heads_upsample == -1:
             num_heads_upsample = num_heads
@@ -632,7 +614,7 @@ class GLIDEUNetModel(ModelMixin, ConfigMixin):
         self.middle_block.apply(convert_module_to_f32)
         self.output_blocks.apply(convert_module_to_f32)

-    def forward(self, x, timesteps, y=None):
+    def forward(self, x, timesteps):
         """
         Apply the model to an input batch.
@@ -641,17 +623,10 @@ class GLIDEUNetModel(ModelMixin, ConfigMixin):
         :param y: an [N] Tensor of labels, if class-conditional.
         :return: an [N x C x ...] Tensor of outputs.
         """
-        assert (y is not None) == (
-            self.num_classes is not None
-        ), "must specify y if and only if the model is class-conditional"
-
         hs = []
         emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
-        if self.num_classes is not None:
-            assert y.shape == (x.shape[0],)
-            emb = emb + self.label_emb(y)
-
         h = x.type(self.dtype)
         for module in self.input_blocks:
             h = module(h, emb)
@@ -671,10 +646,66 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
     Expects an extra kwarg `low_res` to condition on a low-resolution image.
     """

-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(
+        self,
+        in_channels=3,
+        model_channels=192,
+        out_channels=6,
+        num_res_blocks=3,
+        attention_resolutions=(2, 4, 8),
+        dropout=0,
+        channel_mult=(1, 2, 4, 8),
+        conv_resample=True,
+        dims=2,
+        use_checkpoint=False,
+        use_fp16=False,
+        num_heads=1,
+        num_head_channels=-1,
+        num_heads_upsample=-1,
+        use_scale_shift_norm=False,
+        resblock_updown=False,
+        transformer_dim=512,
+    ):
+        super().__init__(
+            in_channels=in_channels,
+            model_channels=model_channels,
+            out_channels=out_channels,
+            num_res_blocks=num_res_blocks,
+            attention_resolutions=attention_resolutions,
+            dropout=dropout,
+            channel_mult=channel_mult,
+            conv_resample=conv_resample,
+            dims=dims,
+            use_checkpoint=use_checkpoint,
+            use_fp16=use_fp16,
+            num_heads=num_heads,
+            num_head_channels=num_head_channels,
+            num_heads_upsample=num_heads_upsample,
+            use_scale_shift_norm=use_scale_shift_norm,
+            resblock_updown=resblock_updown,
+            transformer_dim=transformer_dim,
+        )
+        self.register(
+            in_channels=in_channels,
+            model_channels=model_channels,
+            out_channels=out_channels,
+            num_res_blocks=num_res_blocks,
+            attention_resolutions=attention_resolutions,
+            dropout=dropout,
+            channel_mult=channel_mult,
+            conv_resample=conv_resample,
+            dims=dims,
+            use_checkpoint=use_checkpoint,
+            use_fp16=use_fp16,
+            num_heads=num_heads,
+            num_head_channels=num_head_channels,
+            num_heads_upsample=num_heads_upsample,
+            use_scale_shift_norm=use_scale_shift_norm,
+            resblock_updown=resblock_updown,
+            transformer_dim=transformer_dim,
+        )

-        self.transformer_proj = nn.Linear(kwargs["transformer_dim"], self.model_channels * 4)
+        self.transformer_proj = nn.Linear(transformer_dim, self.model_channels * 4)

     def forward(self, x, timesteps, transformer_out=None):
         hs = []
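Note: with the GLIDE base-model defaults baked into the signature, the text2image UNet can now be built without arguments. A minimal sketch (assuming, as its use in forward suggests, that register exposes the kwargs as model attributes):

    # Sketch: default construction matches the converted GLIDE base checkpoint.
    from diffusers import GLIDETextToImageUNetModel

    model = GLIDETextToImageUNetModel()  # transformer_dim=512, model_channels=192
    # transformer_proj maps 512-dim transformer features into the UNet's
    # time-embedding space: model_channels * 4 = 768.
    print(model.transformer_proj)  # Linear(in_features=512, out_features=768, ...)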
@@ -705,11 +736,77 @@ class GLIDESuperResUNetModel(GLIDEUNetModel):
     Expects an extra kwarg `low_res` to condition on a low-resolution image.
     """

-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(
+        self,
+        in_channels=3,
+        model_channels=192,
+        out_channels=6,
+        num_res_blocks=3,
+        attention_resolutions=(2, 4, 8),
+        dropout=0,
+        channel_mult=(1, 2, 4, 8),
+        conv_resample=True,
+        dims=2,
+        use_checkpoint=False,
+        use_fp16=False,
+        num_heads=1,
+        num_head_channels=-1,
+        num_heads_upsample=-1,
+        use_scale_shift_norm=False,
+        resblock_updown=False,
+    ):
+        super().__init__(
+            in_channels=in_channels,
+            model_channels=model_channels,
+            out_channels=out_channels,
+            num_res_blocks=num_res_blocks,
+            attention_resolutions=attention_resolutions,
+            dropout=dropout,
+            channel_mult=channel_mult,
+            conv_resample=conv_resample,
+            dims=dims,
+            use_checkpoint=use_checkpoint,
+            use_fp16=use_fp16,
+            num_heads=num_heads,
+            num_head_channels=num_head_channels,
+            num_heads_upsample=num_heads_upsample,
+            use_scale_shift_norm=use_scale_shift_norm,
+            resblock_updown=resblock_updown,
+        )
+        self.register(
+            in_channels=in_channels,
+            model_channels=model_channels,
+            out_channels=out_channels,
+            num_res_blocks=num_res_blocks,
+            attention_resolutions=attention_resolutions,
+            dropout=dropout,
+            channel_mult=channel_mult,
+            conv_resample=conv_resample,
+            dims=dims,
+            use_checkpoint=use_checkpoint,
+            use_fp16=use_fp16,
+            num_heads=num_heads,
+            num_head_channels=num_head_channels,
+            num_heads_upsample=num_heads_upsample,
+            use_scale_shift_norm=use_scale_shift_norm,
+            resblock_updown=resblock_updown,
+        )

-    def forward(self, x, timesteps, low_res=None, **kwargs):
+    def forward(self, x, timesteps, low_res=None):
         _, _, new_height, new_width = x.shape
         upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
         x = torch.cat([x, upsampled], dim=1)
-        return super().forward(x, timesteps, **kwargs)
\ No newline at end of file
+
+        hs = []
+        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+
+        h = x
+        for module in self.input_blocks:
+            h = module(h, emb)
+            hs.append(h)
+        h = self.middle_block(h, emb)
+        for module in self.output_blocks:
+            h = torch.cat([h, hs.pop()], dim=1)
+            h = module(h, emb)
+
+        return self.out(h)
\ No newline at end of file
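Note: the inlined forward makes the conditioning explicit: the low-resolution image is bilinearly upsampled and concatenated channel-wise, which is why convert_weights.py builds the super-res UNet with in_channels=6. A shape sketch under those settings:

    # Shape sketch: 3 noisy-image channels + 3 upsampled low-res channels.
    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 3, 256, 256)       # current noisy 256x256 sample
    low_res = torch.randn(1, 3, 64, 64)   # 64x64 output of the text2image stage
    upsampled = F.interpolate(low_res, (256, 256), mode="bilinear")
    model_in = torch.cat([x, upsampled], dim=1)
    assert model_in.shape == (1, 6, 256, 256)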
src/diffusers/pipeline_utils.py

@@ -39,6 +39,7 @@ LOADABLE_CLASSES = {
         "CLIPTextModel": ["save_pretrained", "from_pretrained"],  # TODO (Anton): move to transformers
         "GaussianDDPMScheduler": ["save_config", "from_config"],
         "ClassifierFreeGuidanceScheduler": ["save_config", "from_config"],
+        "GlideDDIMScheduler": ["save_config", "from_config"],
     },
     "transformers": {
         "GPT2Tokenizer": ["save_pretrained", "from_pretrained"],
src/diffusers/schedulers/__init__.py

@@ -18,3 +18,4 @@
 from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
 from .gaussian_ddpm import GaussianDDPMScheduler
+from .glide_ddim import GlideDDIMScheduler
src/diffusers/schedulers/ddim.py → src/diffusers/schedulers/glide_ddim.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import torch
+import math
+import numpy as np
 from torch import nn

 from ..configuration_utils import ConfigMixin

@@ -22,36 +22,30 @@ from .schedulers_utils import linear_beta_schedule, betas_for_alpha_bar
 SAMPLING_CONFIG_NAME = "scheduler_config.json"

-class GaussianDDPMScheduler(nn.Module, ConfigMixin):
+class GlideDDIMScheduler(nn.Module, ConfigMixin):

     config_name = SAMPLING_CONFIG_NAME

     def __init__(
         self,
         timesteps=1000,
         beta_start=0.0001,
         beta_end=0.02,
         beta_schedule="linear",
-        variance_type="fixed_small",
+        variance_type="fixed_large",
     ):
         super().__init__()
         self.register(
             timesteps=timesteps,
             beta_start=beta_start,
             beta_end=beta_end,
             beta_schedule=beta_schedule,
             variance_type=variance_type,
         )
         self.num_timesteps = int(timesteps)

         if beta_schedule == "linear":
+            # Linear schedule from Ho et al, extended to work for any number of
+            # diffusion steps.
+            scale = 1000 / self.num_timesteps
+            beta_start = scale * 0.0001
+            beta_end = scale * 0.02
             betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
         elif beta_schedule == "squaredcos_cap_v2":
             # GLIDE cosine schedule
             betas = betas_for_alpha_bar(
                 timesteps,
                 lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
             )
         else:
             raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
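Note: the schedule helpers imported from schedulers_utils are not part of this diff. For orientation, betas_for_alpha_bar in the upstream improved-diffusion code (which this file follows) is roughly the sketch below: it discretizes a continuous alpha-bar function into per-step betas via beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), capped at max_beta.

    # Sketch of betas_for_alpha_bar, after the upstream improved-diffusion
    # helper; the exact signature in schedulers_utils may differ.
    import numpy as np

    def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
        betas = []
        for i in range(num_diffusion_timesteps):
            t1 = i / num_diffusion_timesteps
            t2 = (i + 1) / num_diffusion_timesteps
            betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
        return np.array(betas)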