Commit 74d2da99, authored Jun 09, 2022 by patil-suraj

Merge branch 'main' of https://github.com/huggingface/diffusers into main

Parents: 397b31c8, c6a33e3d

Showing 9 changed files with 324 additions and 93 deletions (+324 -93):
models/vision/glide/convert_weights.py   +34   -6
models/vision/glide/modeling_glide.py    +86   -28
models/vision/glide/run_glide.py         +13   -12
src/diffusers/__init__.py                +2    -1
src/diffusers/models/__init__.py         +1    -1
src/diffusers/models/unet_glide.py       +176  -29
src/diffusers/pipeline_utils.py          +2    -1
src/diffusers/schedulers/__init__.py     +1    -0
src/diffusers/schedulers/glide_ddim.py   +9    -15
models/vision/glide/convert_weights.py  View file @ 74d2da99

 import torch
 from torch import nn
-from diffusers import ClassifierFreeGuidanceScheduler, CLIPTextModel, UNetGLIDEModel
+from diffusers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler, CLIPTextModel, GLIDETextToImageUNetModel, GLIDESuperResUNetModel
 from modeling_glide import GLIDE
 from transformers import CLIPTextConfig, GPT2Tokenizer
...
@@ -51,9 +51,9 @@ for layer_idx in range(config.num_hidden_layers):
     hf_layer.mlp.fc2.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.weight"]
     hf_layer.mlp.fc2.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.bias"]

-### Convert the UNet
-unet_model = UNetGLIDEModel(
+### Convert the Text-to-Image UNet
+text2im_model = GLIDETextToImageUNetModel(
     in_channels=3,
     model_channels=192,
     out_channels=6,
...
@@ -69,10 +69,38 @@ unet_model = UNetGLIDEModel(
     transformer_dim=512,
 )
-unet_model.load_state_dict(state_dict, strict=False)
+text2im_model.load_state_dict(state_dict, strict=False)

-scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="squaredcos_cap_v2")
-glide = GLIDE(unet=unet_model, noise_scheduler=scheduler, text_encoder=model, tokenizer=tokenizer)
+text_scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="squaredcos_cap_v2")
+
+### Convert the Super-Resolution UNet
+# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/upsample.pt
+ups_state_dict = torch.load("upsample.pt", map_location="cpu")
+superres_model = GLIDESuperResUNetModel(
+    in_channels=6,
+    model_channels=192,
+    out_channels=6,
+    num_res_blocks=2,
+    attention_resolutions=(8, 16, 32),
+    dropout=0.1,
+    channel_mult=(1, 1, 2, 2, 4, 4),
+    num_heads=1,
+    num_head_channels=64,
+    num_heads_upsample=1,
+    use_scale_shift_norm=True,
+    resblock_updown=True,
+)
+superres_model.load_state_dict(ups_state_dict, strict=False)
+
+upscale_scheduler = GlideDDIMScheduler(timesteps=1000, beta_schedule="linear")
+
+glide = GLIDE(
+    text_unet=text2im_model,
+    text_noise_scheduler=text_scheduler,
+    text_encoder=model,
+    tokenizer=tokenizer,
+    upscale_unet=superres_model,
+    upscale_noise_scheduler=upscale_scheduler,
+)

 glide.save_pretrained("./glide-base")
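Both load_state_dict calls above pass strict=False, which silently tolerates missing and unexpected keys. A quick sanity check worth running after a conversion like this (a sketch, not part of the commit):

# load_state_dict returns the keys it could not match; with strict=False these
# should be inspected rather than ignored.
result = text2im_model.load_state_dict(state_dict, strict=False)
print("missing keys:", len(result.missing_keys))
print("unexpected keys:", len(result.unexpected_keys))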
models/vision/glide/modeling_glide.py  View file @ 74d2da99

...
@@ -18,7 +18,7 @@ import numpy as np
 import torch
 import tqdm
-from diffusers import ClassifierFreeGuidanceScheduler, CLIPTextModel, DiffusionPipeline, UNetGLIDEModel
+from diffusers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler, CLIPTextModel, DiffusionPipeline, GLIDETextToImageUNetModel, GLIDESuperResUNetModel
 from transformers import GPT2Tokenizer
...
@@ -41,17 +41,20 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
 class GLIDE(DiffusionPipeline):
     def __init__(
         self,
-        unet: UNetGLIDEModel,
-        noise_scheduler: ClassifierFreeGuidanceScheduler,
+        text_unet: GLIDETextToImageUNetModel,
+        text_noise_scheduler: ClassifierFreeGuidanceScheduler,
         text_encoder: CLIPTextModel,
         tokenizer: GPT2Tokenizer,
+        upscale_unet: GLIDESuperResUNetModel,
+        upscale_noise_scheduler: GlideDDIMScheduler,
     ):
         super().__init__()
-        self.register_modules(unet=unet, noise_scheduler=noise_scheduler, text_encoder=text_encoder, tokenizer=tokenizer)
+        self.register_modules(
+            text_unet=text_unet,
+            text_noise_scheduler=text_noise_scheduler,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            upscale_unet=upscale_unet,
+            upscale_noise_scheduler=upscale_noise_scheduler,
+        )

-    def q_posterior_mean_variance(self, x_start, x_t, t):
+    def q_posterior_mean_variance(self, scheduler, x_start, x_t, t):
         """
         Compute the mean and variance of the diffusion posterior:
...
@@ -60,12 +63,12 @@ class GLIDE(DiffusionPipeline):
         """
         assert x_start.shape == x_t.shape
         posterior_mean = (
-            _extract_into_tensor(self.noise_scheduler.posterior_mean_coef1, t, x_t.shape) * x_start
-            + _extract_into_tensor(self.noise_scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
+            _extract_into_tensor(scheduler.posterior_mean_coef1, t, x_t.shape) * x_start
+            + _extract_into_tensor(scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
         )
-        posterior_variance = _extract_into_tensor(self.noise_scheduler.posterior_variance, t, x_t.shape)
+        posterior_variance = _extract_into_tensor(scheduler.posterior_variance, t, x_t.shape)
         posterior_log_variance_clipped = _extract_into_tensor(
-            self.noise_scheduler.posterior_log_variance_clipped, t, x_t.shape
+            scheduler.posterior_log_variance_clipped, t, x_t.shape
         )
         assert (
             posterior_mean.shape[0]
...
@@ -75,7 +78,7 @@ class GLIDE(DiffusionPipeline):
         )
         return posterior_mean, posterior_variance, posterior_log_variance_clipped
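The posterior_mean_coef1, posterior_mean_coef2, and posterior_variance buffers consumed here are not defined in this diff. Under the standard DDPM parameterization of q(x_{t-1} | x_t, x_0) from Ho et al., a scheduler would derive them from its beta schedule roughly like this (a sketch under that assumption):

import numpy as np

betas = np.linspace(1e-4, 0.02, 1000)  # example linear schedule
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas)
alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

# q(x_{t-1} | x_t, x_0) = N(coef1 * x_0 + coef2 * x_t, posterior_variance)
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
posterior_mean_coef1 = betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)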
-    def p_mean_variance(self, model, x, t, transformer_out, clip_denoised=True, model_kwargs=None):
+    def p_mean_variance(self, model, scheduler, x, t, transformer_out=None, low_res=None, clip_denoised=True):
         """
         Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
         the initial x, x_0.
...
@@ -93,51 +96,60 @@ class GLIDE(DiffusionPipeline):
         - 'log_variance': the log of 'variance'.
         - 'pred_xstart': the prediction for x_0.
         """
-        if model_kwargs is None:
-            model_kwargs = {}
         B, C = x.shape[:2]
         assert t.shape == (B,)
-        model_output = model(x, t, transformer_out)
+        if transformer_out is None:
+            # super-res model
+            model_output = model(x, t, low_res)
+        else:
+            # text2image model
+            model_output = model(x, t, transformer_out)
         assert model_output.shape == (B, C * 2, *x.shape[2:])
         model_output, model_var_values = torch.split(model_output, C, dim=1)
-        min_log = _extract_into_tensor(self.noise_scheduler.posterior_log_variance_clipped, t, x.shape)
-        max_log = _extract_into_tensor(np.log(self.noise_scheduler.betas), t, x.shape)
+        min_log = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x.shape)
+        max_log = _extract_into_tensor(np.log(scheduler.betas), t, x.shape)
         # The model_var_values is [-1, 1] for [min_var, max_var].
         frac = (model_var_values + 1) / 2
         model_log_variance = frac * max_log + (1 - frac) * min_log
         model_variance = torch.exp(model_log_variance)
-        pred_xstart = self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
+        pred_xstart = self._predict_xstart_from_eps(scheduler, x_t=x, t=t, eps=model_output)
         if clip_denoised:
             pred_xstart = pred_xstart.clamp(-1, 1)
-        model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
+        model_mean, _, _ = self.q_posterior_mean_variance(scheduler, x_start=pred_xstart, x_t=x, t=t)
         assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
         return model_mean, model_variance, model_log_variance, pred_xstart
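The frac interpolation above follows the learned-variance parameterization of Nichol & Dhariwal's improved DDPMs: the model's extra output channels predict a value in [-1, 1] that blends two fixed log-variance bounds. As a standalone sketch:

import torch

def interpolate_log_variance(model_var_values: torch.Tensor, min_log: torch.Tensor, max_log: torch.Tensor) -> torch.Tensor:
    # model_var_values is in [-1, 1]; map it to a convex combination of the bounds
    frac = (model_var_values + 1) / 2
    return frac * max_log + (1 - frac) * min_log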
-    def _predict_xstart_from_eps(self, x_t, t, eps):
+    def _predict_xstart_from_eps(self, scheduler, x_t, t, eps):
         assert x_t.shape == eps.shape
         return (
-            _extract_into_tensor(self.noise_scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
-            - _extract_into_tensor(self.noise_scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
+            _extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+            - _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
         )

+    def _predict_eps_from_xstart(self, scheduler, x_t, t, pred_xstart):
+        return (
+            _extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
+        ) / _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

     @torch.no_grad()
     def __call__(self, prompt, generator=None, torch_device=None):
         torch_device = "cuda" if torch.cuda.is_available() else "cpu"

-        self.unet.to(torch_device)
+        self.text_unet.to(torch_device)
         self.text_encoder.to(torch_device)
+        self.upscale_unet.to(torch_device)

         # Create a classifier-free guidance sampling function
         guidance_scale = 3.0

-        def model_fn(x_t, ts, transformer_out, **kwargs):
+        def text_model_fn(x_t, ts, transformer_out, **kwargs):
             half = x_t[: len(x_t) // 2]
             combined = torch.cat([half, half], dim=0)
-            model_out = self.unet(combined, ts, transformer_out, **kwargs)
+            model_out = self.text_unet(combined, ts, transformer_out, **kwargs)
             eps, rest = model_out[:, :3], model_out[:, 3:]
             cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
             half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
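The doubled batch and the half/combined slicing implement classifier-free guidance: the first half of the batch is text-conditioned, the second half unconditioned, and the two noise predictions are blended. In isolation (a standalone sketch, not part of the diff):

import torch

def classifier_free_guidance(cond_eps: torch.Tensor, uncond_eps: torch.Tensor, guidance_scale: float = 3.0) -> torch.Tensor:
    # eps = eps_uncond + s * (eps_cond - eps_uncond); s > 1 strengthens text adherence
    return uncond_eps + guidance_scale * (cond_eps - uncond_eps)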
...
@@ -146,8 +158,8 @@ class GLIDE(DiffusionPipeline):
         # 1. Sample gaussian noise
         batch_size = 2  # second image is empty for classifier-free guidance
-        image = self.noise_scheduler.sample_noise(
-            (batch_size, self.unet.in_channels, 64, 64), device=torch_device, generator=generator
+        image = self.text_noise_scheduler.sample_noise(
+            (batch_size, self.text_unet.in_channels, 64, 64), device=torch_device, generator=generator
         )

         # 2. Encode tokens
...
@@ -157,14 +169,60 @@ class GLIDE(DiffusionPipeline):
         attention_mask = inputs["attention_mask"].to(torch_device)
         transformer_out = self.text_encoder(input_ids, attention_mask).last_hidden_state

-        num_timesteps = len(self.noise_scheduler)
+        # 3. Run the text2image generation step
+        num_timesteps = len(self.text_noise_scheduler)
         for i in tqdm.tqdm(reversed(range(num_timesteps)), total=num_timesteps):
             t = torch.tensor([i] * image.shape[0], device=torch_device)
-            mean, variance, log_variance, pred_xstart = self.p_mean_variance(model_fn, image, t, transformer_out)
-            noise = self.noise_scheduler.sample_noise(image.shape, device=torch_device, generator=generator)
+            mean, variance, log_variance, pred_xstart = self.p_mean_variance(
+                text_model_fn, self.text_noise_scheduler, image, t, transformer_out=transformer_out
+            )
+            noise = self.text_noise_scheduler.sample_noise(image.shape, device=torch_device, generator=generator)
             nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1)))  # no noise when t == 0
             image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise

+        # 4. Run the upscaling step
+        batch_size = 1
+        image = image[:1]
+        low_res = ((image + 1) * 127.5).round() / 127.5 - 1
+        eta = 0.0
+
+        # Tune this parameter to control the sharpness of 256x256 images.
+        # A value of 1.0 is sharper, but sometimes results in grainy artifacts.
+        upsample_temp = 0.997
+
+        image = (
+            self.upscale_noise_scheduler.sample_noise((batch_size, 3, 256, 256), device=torch_device, generator=generator)
+            * upsample_temp
+        )
+
+        num_timesteps = len(self.upscale_noise_scheduler)
+        for t in tqdm.tqdm(reversed(range(len(self.upscale_noise_scheduler))), total=len(self.upscale_noise_scheduler)):
+            # i) define coefficients for time step t
+            clipped_image_coeff = 1 / torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t))
+            clipped_noise_coeff = torch.sqrt(1 / self.upscale_noise_scheduler.get_alpha_prod(t) - 1)
+            image_coeff = (
+                (1 - self.upscale_noise_scheduler.get_alpha_prod(t - 1))
+                * torch.sqrt(self.upscale_noise_scheduler.get_alpha(t))
+                / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
+            )
+            clipped_coeff = (
+                torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t - 1))
+                * self.upscale_noise_scheduler.get_beta(t)
+                / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
+            )
+
+            # ii) predict noise residual
+            time_input = torch.tensor([t] * image.shape[0], device=torch_device)
+            model_output = self.upscale_unet(image, time_input, low_res)
+            noise_residual, pred_variance = torch.split(model_output, 3, dim=1)
+
+            # iii) compute predicted image from residual
+            # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
+            pred_mean = clipped_image_coeff * image - clipped_noise_coeff * noise_residual
+            pred_mean = torch.clamp(pred_mean, -1, 1)
+            prev_image = clipped_coeff * pred_mean + image_coeff * image
+
+            # iv) sample variance
+            prev_variance = self.upscale_noise_scheduler.sample_variance(
+                t, prev_image.shape, device=torch_device, generator=generator
+            )
+
+            # v) sample x_{t-1} ~ N(prev_image, prev_variance)
+            sampled_prev_image = prev_image + prev_variance
+            image = sampled_prev_image

         image = image[0].permute(1, 2, 0)
         return image
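Every hunk above leans on the _extract_into_tensor helper, whose body the diff elides. For reference, the conventional implementation from OpenAI's guided-diffusion codebase looks like this (a sketch; the version in this file may differ):

import numpy as np
import torch

def _extract_into_tensor(arr: np.ndarray, timesteps: torch.Tensor, broadcast_shape) -> torch.Tensor:
    # Gather per-timestep scalars and reshape so they broadcast over an image batch.
    res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res.expand(broadcast_shape)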
models/vision/glide/run_glide.py  View file @ 74d2da99

 import torch
-from modeling_glide import GLIDE
-import matplotlib
-import matplotlib.pyplot as plt
-matplotlib.rcParams['interactive'] = True
+from diffusers import DiffusionPipeline
+import PIL.Image

 generator = torch.Generator()
 generator = generator.manual_seed(0)

-pipeline = GLIDE.from_pretrained("fusing/glide-base")
+# 1. Load models
+model_id = "fusing/glide-base"
+
+# load model and scheduler
+pipeline = DiffusionPipeline.from_pretrained(model_id)

-img = pipeline("a clip art of a hugging face", generator)
+# run inference (text-conditioned denoising + upscaling)
+img = pipeline("an oil painting of a corgi", generator)

+# process image to PIL
 img = ((img + 1) * 127.5).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
-plt.figure(figsize=(8, 8))
-plt.imshow(img)
-plt.show()
+image_pil = PIL.Image.fromarray(img)
+
+# save image
+image_pil.save("test.png")
\ No newline at end of file
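The post-processing line maps the model's [-1, 1] output range onto 8-bit pixels before handing it to PIL. The same denormalization as a standalone sketch:

import numpy as np

def to_uint8(img: np.ndarray) -> np.ndarray:
    # [-1, 1] -> [0, 255], rounded and clipped
    return np.clip(np.round((img + 1) * 127.5), 0, 255).astype(np.uint8)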
src/diffusers/__init__.py  View file @ 74d2da99

...
@@ -7,9 +7,10 @@ __version__ = "0.0.1"
 from .modeling_utils import ModelMixin
 from .models.clip_text_transformer import CLIPTextModel
 from .models.unet import UNetModel
-from .models.unet_glide import UNetGLIDEModel
+from .models.unet_glide import GLIDETextToImageUNetModel, GLIDESuperResUNetModel
 from .models.unet_ldm import UNetLDMModel
 from .models.vqvae import VQModel
 from .pipeline_utils import DiffusionPipeline
 from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
 from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
+from .schedulers.glide_ddim import GlideDDIMScheduler
src/diffusers/models/__init__.py  View file @ 74d2da99

...
@@ -18,6 +18,6 @@
 from .clip_text_transformer import CLIPTextModel
 from .unet import UNetModel
-from .unet_glide import UNetGLIDEModel
+from .unet_glide import GLIDETextToImageUNetModel, GLIDESuperResUNetModel
 from .unet_ldm import UNetLDMModel
 from .vqvae import VQModel
\ No newline at end of file
src/diffusers/models/unet_glide.py  View file @ 74d2da99

...
@@ -388,7 +388,7 @@ class QKVAttention(nn.Module):
         return a.reshape(bs, -1, length)

-class UNetGLIDEModel(ModelMixin, ConfigMixin):
+class GLIDEUNetModel(ModelMixin, ConfigMixin):
     """
     The full UNet model with attention and timestep embedding.
...
@@ -419,11 +419,11 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
     def __init__(
         self,
-        in_channels,
-        model_channels,
-        out_channels,
-        num_res_blocks,
-        attention_resolutions,
+        in_channels=3,
+        model_channels=192,
+        out_channels=6,
+        num_res_blocks=3,
+        attention_resolutions=(2, 4, 8),
         dropout=0,
         channel_mult=(1, 2, 4, 8),
         conv_resample=True,
...
@@ -435,28 +435,9 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
         num_heads_upsample=-1,
         use_scale_shift_norm=False,
         resblock_updown=False,
-        transformer_dim=512,
+        transformer_dim=None,
     ):
         super().__init__()
-        self.register(
-            in_channels=in_channels,
-            model_channels=model_channels,
-            out_channels=out_channels,
-            num_res_blocks=num_res_blocks,
-            attention_resolutions=attention_resolutions,
-            dropout=dropout,
-            channel_mult=channel_mult,
-            conv_resample=conv_resample,
-            dims=dims,
-            use_checkpoint=use_checkpoint,
-            use_fp16=use_fp16,
-            num_heads=num_heads,
-            num_head_channels=num_head_channels,
-            num_heads_upsample=num_heads_upsample,
-            use_scale_shift_norm=use_scale_shift_norm,
-            resblock_updown=resblock_updown,
-            transformer_dim=transformer_dim,
-        )
         if num_heads_upsample == -1:
             num_heads_upsample = num_heads
...
@@ -482,8 +463,6 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
             linear(time_embed_dim, time_embed_dim),
         )
-        self.transformer_proj = nn.Linear(transformer_dim, self.model_channels * 4)
         ch = input_ch = int(channel_mult[0] * model_channels)
         self.input_blocks = nn.ModuleList([TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))])
         self._feature_size = ch
...
@@ -635,7 +614,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
         self.middle_block.apply(convert_module_to_f32)
         self.output_blocks.apply(convert_module_to_f32)

-    def forward(self, x, timesteps, transformer_out):
+    def forward(self, x, timesteps):
         """
         Apply the model to an input batch.
...
@@ -644,6 +623,91 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
         :param y: an [N] Tensor of labels, if class-conditional.
         :return: an [N x C x ...] Tensor of outputs.
         """
+        hs = []
+        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+
+        h = x.type(self.dtype)
+        for module in self.input_blocks:
+            h = module(h, emb)
+            hs.append(h)
+        h = self.middle_block(h, emb)
+        for module in self.output_blocks:
+            h = torch.cat([h, hs.pop()], dim=1)
+            h = module(h, emb)
+        h = h.type(x.dtype)
+        return self.out(h)
+
+
+class GLIDETextToImageUNetModel(GLIDEUNetModel):
+    """
+    A UNetModel that performs super-resolution.
+    Expects an extra kwarg `low_res` to condition on a low-resolution image.
+    """
+
+    def __init__(
+        self,
+        in_channels=3,
+        model_channels=192,
+        out_channels=6,
+        num_res_blocks=3,
+        attention_resolutions=(2, 4, 8),
+        dropout=0,
+        channel_mult=(1, 2, 4, 8),
+        conv_resample=True,
+        dims=2,
+        use_checkpoint=False,
+        use_fp16=False,
+        num_heads=1,
+        num_head_channels=-1,
+        num_heads_upsample=-1,
+        use_scale_shift_norm=False,
+        resblock_updown=False,
+        transformer_dim=512,
+    ):
+        super().__init__(
+            in_channels=in_channels,
+            model_channels=model_channels,
+            out_channels=out_channels,
+            num_res_blocks=num_res_blocks,
+            attention_resolutions=attention_resolutions,
+            dropout=dropout,
+            channel_mult=channel_mult,
+            conv_resample=conv_resample,
+            dims=dims,
+            use_checkpoint=use_checkpoint,
+            use_fp16=use_fp16,
+            num_heads=num_heads,
+            num_head_channels=num_head_channels,
+            num_heads_upsample=num_heads_upsample,
+            use_scale_shift_norm=use_scale_shift_norm,
+            resblock_updown=resblock_updown,
+            transformer_dim=transformer_dim,
+        )
+        self.register(
+            in_channels=in_channels,
+            model_channels=model_channels,
+            out_channels=out_channels,
+            num_res_blocks=num_res_blocks,
+            attention_resolutions=attention_resolutions,
+            dropout=dropout,
+            channel_mult=channel_mult,
+            conv_resample=conv_resample,
+            dims=dims,
+            use_checkpoint=use_checkpoint,
+            use_fp16=use_fp16,
+            num_heads=num_heads,
+            num_head_channels=num_head_channels,
+            num_heads_upsample=num_heads_upsample,
+            use_scale_shift_norm=use_scale_shift_norm,
+            resblock_updown=resblock_updown,
+            transformer_dim=transformer_dim,
+        )
+        self.transformer_proj = nn.Linear(transformer_dim, self.model_channels * 4)
+
+    def forward(self, x, timesteps, transformer_out=None):
+        hs = []
+        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
...
@@ -663,3 +727,86 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
             h = torch.cat([h, other], dim=1)
             h = module(h, emb, transformer_out)
         return self.out(h)
+
+
+class GLIDESuperResUNetModel(GLIDEUNetModel):
+    """
+    A UNetModel that performs super-resolution.
+    Expects an extra kwarg `low_res` to condition on a low-resolution image.
+    """
+
+    def __init__(
+        self,
+        in_channels=3,
+        model_channels=192,
+        out_channels=6,
+        num_res_blocks=3,
+        attention_resolutions=(2, 4, 8),
+        dropout=0,
+        channel_mult=(1, 2, 4, 8),
+        conv_resample=True,
+        dims=2,
+        use_checkpoint=False,
+        use_fp16=False,
+        num_heads=1,
+        num_head_channels=-1,
+        num_heads_upsample=-1,
+        use_scale_shift_norm=False,
+        resblock_updown=False,
+    ):
+        super().__init__(
+            in_channels=in_channels,
+            model_channels=model_channels,
+            out_channels=out_channels,
+            num_res_blocks=num_res_blocks,
+            attention_resolutions=attention_resolutions,
+            dropout=dropout,
+            channel_mult=channel_mult,
+            conv_resample=conv_resample,
+            dims=dims,
+            use_checkpoint=use_checkpoint,
+            use_fp16=use_fp16,
+            num_heads=num_heads,
+            num_head_channels=num_head_channels,
+            num_heads_upsample=num_heads_upsample,
+            use_scale_shift_norm=use_scale_shift_norm,
+            resblock_updown=resblock_updown,
+        )
+        self.register(
+            in_channels=in_channels,
+            model_channels=model_channels,
+            out_channels=out_channels,
+            num_res_blocks=num_res_blocks,
+            attention_resolutions=attention_resolutions,
+            dropout=dropout,
+            channel_mult=channel_mult,
+            conv_resample=conv_resample,
+            dims=dims,
+            use_checkpoint=use_checkpoint,
+            use_fp16=use_fp16,
+            num_heads=num_heads,
+            num_head_channels=num_head_channels,
+            num_heads_upsample=num_heads_upsample,
+            use_scale_shift_norm=use_scale_shift_norm,
+            resblock_updown=resblock_updown,
+        )
+
+    def forward(self, x, timesteps, low_res=None):
+        _, _, new_height, new_width = x.shape
+        upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
+        x = torch.cat([x, upsampled], dim=1)
+
+        hs = []
+        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+        h = x
+        for module in self.input_blocks:
+            h = module(h, emb)
+            hs.append(h)
+        h = self.middle_block(h, emb)
+        for module in self.output_blocks:
+            h = torch.cat([h, hs.pop()], dim=1)
+            h = module(h, emb)
+        return self.out(h)
\ No newline at end of file
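The super-resolution forward explains the in_channels=6 used in convert_weights.py: the noisy sample at the target resolution is concatenated channel-wise with a bilinearly upsampled copy of the low-resolution conditioning image, so the first conv sees 3 + 3 channels. A standalone sketch of that input assembly:

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 256, 256)      # noisy sample at the target resolution
low_res = torch.randn(1, 3, 64, 64)  # output of the text2image stage
upsampled = F.interpolate(low_res, (256, 256), mode="bilinear")
model_input = torch.cat([x, upsampled], dim=1)
assert model_input.shape == (1, 6, 256, 256)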
src/diffusers/pipeline_utils.py  View file @ 74d2da99

...
@@ -39,9 +39,10 @@ LOADABLE_CLASSES = {
         "CLIPTextModel": ["save_pretrained", "from_pretrained"],  # TODO (Anton): move to transformers
         "GaussianDDPMScheduler": ["save_config", "from_config"],
         "ClassifierFreeGuidanceScheduler": ["save_config", "from_config"],
+        "GlideDDIMScheduler": ["save_config", "from_config"],
     },
     "transformers": {
-        "GPT2Tokenizer": ["save_pretrained", "from_pretrained"],
+        "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
     },
 }
...
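Swapping the "GPT2Tokenizer" key for "PreTrainedTokenizer" lets the save/load dispatch match any transformers tokenizer by base class rather than one concrete type. A rough sketch of how such base-class dispatch could work (hypothetical helper, not shown in this diff):

import importlib

def get_save_load_methods(obj, loadable_classes):
    # Match obj against the registered base classes via isinstance, so any
    # subclass (GPT2Tokenizer, CLIPTokenizer, ...) picks up the same methods.
    transformers = importlib.import_module("transformers")
    for class_name, (save_method, load_method) in loadable_classes["transformers"].items():
        base_class = getattr(transformers, class_name)
        if isinstance(obj, base_class):
            return save_method, load_method
    return None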
src/diffusers/schedulers/__init__.py  View file @ 74d2da99

...
@@ -18,3 +18,4 @@
 from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
 from .gaussian_ddpm import GaussianDDPMScheduler
+from .glide_ddim import GlideDDIMScheduler
src/diffusers/schedulers/ddim.py → src/diffusers/schedulers/glide_ddim.py  View file @ 74d2da99

...
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import torch
+import math
+import numpy as np
 from torch import nn
 from ..configuration_utils import ConfigMixin
...
@@ -22,36 +22,30 @@ from .schedulers_utils import linear_beta_schedule, betas_for_alpha_bar
 SAMPLING_CONFIG_NAME = "scheduler_config.json"

-class GaussianDDPMScheduler(nn.Module, ConfigMixin):
+class GlideDDIMScheduler(nn.Module, ConfigMixin):
     config_name = SAMPLING_CONFIG_NAME

     def __init__(
         self,
         timesteps=1000,
-        beta_start=0.0001,
-        beta_end=0.02,
         beta_schedule="linear",
-        variance_type="fixed_small",
+        variance_type="fixed_large",
     ):
         super().__init__()
         self.register(
             timesteps=timesteps,
-            beta_start=beta_start,
-            beta_end=beta_end,
             beta_schedule=beta_schedule,
             variance_type=variance_type,
         )
         self.num_timesteps = int(timesteps)

         if beta_schedule == "linear":
+            # Linear schedule from Ho et al, extended to work for any number of
+            # diffusion steps.
+            scale = 1000 / self.num_timesteps
+            beta_start = scale * 0.0001
+            beta_end = scale * 0.02
             betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
+        elif beta_schedule == "squaredcos_cap_v2":
+            # GLIDE cosine schedule
+            betas = betas_for_alpha_bar(
+                timesteps,
+                lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
+            )
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
...
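The squaredcos_cap_v2 branch relies on betas_for_alpha_bar, which this file imports but the diff does not show. For context, a typical implementation (a sketch following Nichol & Dhariwal's improved DDPM formulation; the version in schedulers_utils may differ):

import numpy as np

def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    # Choose betas so that the cumulative product of (1 - beta) tracks alpha_bar(t),
    # capping each beta to keep the final steps numerically stable.
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)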