Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
d754ce5f
Commit
d754ce5f
authored
Jun 08, 2022
by
anton-l
Browse files
transformer-guided glide sampling
parent
07ffe73f
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
37 additions
and
29 deletions
+37
-29
models/vision/glide/convert_weights.py
models/vision/glide/convert_weights.py
+2
-8
models/vision/glide/modeling_glide.py
models/vision/glide/modeling_glide.py
+20
-4
models/vision/glide/run_glide.py
models/vision/glide/run_glide.py
+1
-1
src/diffusers/models/unet_glide.py
src/diffusers/models/unet_glide.py
+10
-12
src/diffusers/schedulers/classifier_free_guidance.py
src/diffusers/schedulers/classifier_free_guidance.py
+4
-4
No files found.
models/vision/glide/convert_weights.py
View file @
d754ce5f
...
@@ -22,8 +22,7 @@ config = CLIPTextConfig(
...
@@ -22,8 +22,7 @@ config = CLIPTextConfig(
use_padding_embeddings
=
True
,
use_padding_embeddings
=
True
,
)
)
model
=
CLIPTextModel
(
config
).
eval
()
model
=
CLIPTextModel
(
config
).
eval
()
tokenizer
=
GPT2Tokenizer
(
"./glide-base/vocab.json"
,
"./glide-base/merges.txt"
,
pad_token
=
"<|endoftext|>"
)
tokenizer
=
GPT2Tokenizer
(
"./glide-base/tokenizer/vocab.json"
,
"./glide-base/tokenizer/merges.txt"
,
pad_token
=
"<|endoftext|>"
)
# tokenizer.save_pretrained("./glide-base")
hf_encoder
=
model
.
text_model
hf_encoder
=
model
.
text_model
...
@@ -52,12 +51,6 @@ for layer_idx in range(config.num_hidden_layers):
...
@@ -52,12 +51,6 @@ for layer_idx in range(config.num_hidden_layers):
hf_layer
.
mlp
.
fc2
.
weight
=
state_dict
[
f
"transformer.resblocks.
{
layer_idx
}
.mlp.c_proj.weight"
]
hf_layer
.
mlp
.
fc2
.
weight
=
state_dict
[
f
"transformer.resblocks.
{
layer_idx
}
.mlp.c_proj.weight"
]
hf_layer
.
mlp
.
fc2
.
bias
=
state_dict
[
f
"transformer.resblocks.
{
layer_idx
}
.mlp.c_proj.bias"
]
hf_layer
.
mlp
.
fc2
.
bias
=
state_dict
[
f
"transformer.resblocks.
{
layer_idx
}
.mlp.c_proj.bias"
]
# inputs = tokenizer(["an oil painting of a corgi", ""], padding="max_length", max_length=128, return_tensors="pt")
# with torch.no_grad():
# outputs = model(**inputs)
# model.save_pretrained("./glide-base")
### Convert the UNet
### Convert the UNet
unet_model
=
UNetGLIDEModel
(
unet_model
=
UNetGLIDEModel
(
...
@@ -73,6 +66,7 @@ unet_model = UNetGLIDEModel(
...
@@ -73,6 +66,7 @@ unet_model = UNetGLIDEModel(
num_heads_upsample
=
1
,
num_heads_upsample
=
1
,
use_scale_shift_norm
=
True
,
use_scale_shift_norm
=
True
,
resblock_updown
=
True
,
resblock_updown
=
True
,
transformer_dim
=
512
,
)
)
unet_model
.
load_state_dict
(
state_dict
,
strict
=
False
)
unet_model
.
load_state_dict
(
state_dict
,
strict
=
False
)
...
...
models/vision/glide/modeling_glide.py
View file @
d754ce5f
...
@@ -130,21 +130,37 @@ class GLIDE(DiffusionPipeline):
...
@@ -130,21 +130,37 @@ class GLIDE(DiffusionPipeline):
self
.
unet
.
to
(
torch_device
)
self
.
unet
.
to
(
torch_device
)
self
.
text_encoder
.
to
(
torch_device
)
self
.
text_encoder
.
to
(
torch_device
)
# Create a classifier-free guidance sampling function
guidance_scale
=
3.0
def model_fn(x_t, ts, transformer_out, **kwargs):
    """Run the UNet once and apply classifier-free guidance to its output.

    Only the first half of the batch ``x_t`` is used; it is duplicated so
    that both halves of the batch fed to ``self.unet`` hold the same
    latents. The first three output channels are treated as the noise
    prediction: its two batch halves are blended as
    ``uncond + guidance_scale * (cond - uncond)`` and the blended result is
    written back to both halves. All remaining channels are returned
    untouched.

    ``self`` and ``guidance_scale`` are taken from the enclosing scope.
    """
    batch_half = x_t[: len(x_t) // 2]
    doubled = torch.cat([batch_half, batch_half], dim=0)
    prediction = self.unet(doubled, ts, transformer_out, **kwargs)

    # Channels 0..2 are the eps prediction; everything after is passed through.
    eps = prediction[:, :3]
    rest = prediction[:, 3:]

    # First batch half is named conditional, second unconditional.
    cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
    guided = uncond_eps + guidance_scale * (cond_eps - uncond_eps)

    # Both halves receive the same guided eps so the batch stays in sync.
    eps = torch.cat([guided, guided], dim=0)
    return torch.cat([eps, rest], dim=1)
# 1. Sample gaussian noise
# 1. Sample gaussian noise
batch_size
=
2
# second image is empty for classifier-free guidance
image
=
self
.
noise_scheduler
.
sample_noise
(
image
=
self
.
noise_scheduler
.
sample_noise
(
(
1
,
self
.
unet
.
in_channels
,
64
,
64
),
device
=
torch_device
,
generator
=
generator
(
batch_size
,
self
.
unet
.
in_channels
,
64
,
64
),
device
=
torch_device
,
generator
=
generator
)
)
# 2. Encode tokens
# 2. Encode tokens
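# an empty prompt is encoded alongside the real one; its (unconditional) prediction is what guidance steers away from
# an empty prompt is encoded alongside the real one; its (unconditional) prediction is what guidance steers away from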
inputs
=
self
.
tokenizer
([
prompt
,
""
],
padding
=
"max_length"
,
max_length
=
128
,
return_tensors
=
"pt"
)
inputs
=
self
.
tokenizer
([
prompt
,
""
],
padding
=
"max_length"
,
max_length
=
128
,
return_tensors
=
"pt"
)
transformer_out
=
self
.
text_encoder
(
**
inputs
).
last_hidden_state
input_ids
=
inputs
[
"input_ids"
].
to
(
torch_device
)
attention_mask
=
inputs
[
"attention_mask"
].
to
(
torch_device
)
transformer_out
=
self
.
text_encoder
(
input_ids
,
attention_mask
).
last_hidden_state
num_timesteps
=
len
(
self
.
noise_scheduler
)
num_timesteps
=
len
(
self
.
noise_scheduler
)
for
i
in
tqdm
.
tqdm
(
reversed
(
range
(
num_timesteps
)),
total
=
num_timesteps
):
for
i
in
tqdm
.
tqdm
(
reversed
(
range
(
num_timesteps
)),
total
=
num_timesteps
):
t
=
torch
.
tensor
([
i
]
*
image
.
shape
[
0
],
device
=
torch_device
)
t
=
torch
.
tensor
([
i
]
*
image
.
shape
[
0
],
device
=
torch_device
)
mean
,
variance
,
log_variance
,
pred_xstart
=
self
.
p_mean_variance
(
self
.
unet
,
transformer_out
,
image
,
t
)
mean
,
variance
,
log_variance
,
pred_xstart
=
self
.
p_mean_variance
(
model_fn
,
image
,
t
,
transformer_out
)
noise
=
self
.
noise_scheduler
.
sample_noise
(
image
.
shape
)
noise
=
self
.
noise_scheduler
.
sample_noise
(
image
.
shape
,
device
=
torch_device
,
generator
=
generator
)
nonzero_mask
=
(
t
!=
0
).
float
().
view
(
-
1
,
*
([
1
]
*
(
len
(
image
.
shape
)
-
1
)))
# no noise when t == 0
nonzero_mask
=
(
t
!=
0
).
float
().
view
(
-
1
,
*
([
1
]
*
(
len
(
image
.
shape
)
-
1
)))
# no noise when t == 0
image
=
mean
+
nonzero_mask
*
torch
.
exp
(
0.5
*
log_variance
)
*
noise
image
=
mean
+
nonzero_mask
*
torch
.
exp
(
0.5
*
log_variance
)
*
noise
...
...
models/vision/glide/run_glide.py
View file @
d754ce5f
...
@@ -9,6 +9,6 @@ generator = generator.manual_seed(0)
...
@@ -9,6 +9,6 @@ generator = generator.manual_seed(0)
# 1. Load models
# 1. Load models
pipeline
=
GLIDE
.
from_pretrained
(
"fusing/glide-base"
)
pipeline
=
GLIDE
.
from_pretrained
(
"fusing/glide-base"
)
img
=
pipeline
(
generator
)
img
=
pipeline
(
"an oil painting of a corgi"
,
generator
)
print
(
img
)
print
(
img
)
src/diffusers/models/unet_glide.py
View file @
d754ce5f
...
@@ -435,7 +435,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
...
@@ -435,7 +435,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
num_heads_upsample
=-
1
,
num_heads_upsample
=-
1
,
use_scale_shift_norm
=
False
,
use_scale_shift_norm
=
False
,
resblock_updown
=
False
,
resblock_updown
=
False
,
encoder_channels
=
None
,
transformer_dim
=
512
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
register
(
self
.
register
(
...
@@ -455,7 +455,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
...
@@ -455,7 +455,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
num_heads_upsample
=
num_heads_upsample
,
num_heads_upsample
=
num_heads_upsample
,
use_scale_shift_norm
=
use_scale_shift_norm
,
use_scale_shift_norm
=
use_scale_shift_norm
,
resblock_updown
=
resblock_updown
,
resblock_updown
=
resblock_updown
,
encoder_channels
=
encoder_channels
,
transformer_dim
=
transformer_dim
,
)
)
if
num_heads_upsample
==
-
1
:
if
num_heads_upsample
==
-
1
:
...
@@ -482,6 +482,8 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
...
@@ -482,6 +482,8 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
linear
(
time_embed_dim
,
time_embed_dim
),
linear
(
time_embed_dim
,
time_embed_dim
),
)
)
self
.
transformer_proj
=
nn
.
Linear
(
transformer_dim
,
self
.
model_channels
*
4
)
ch
=
input_ch
=
int
(
channel_mult
[
0
]
*
model_channels
)
ch
=
input_ch
=
int
(
channel_mult
[
0
]
*
model_channels
)
self
.
input_blocks
=
nn
.
ModuleList
([
TimestepEmbedSequential
(
conv_nd
(
dims
,
in_channels
,
ch
,
3
,
padding
=
1
))])
self
.
input_blocks
=
nn
.
ModuleList
([
TimestepEmbedSequential
(
conv_nd
(
dims
,
in_channels
,
ch
,
3
,
padding
=
1
))])
self
.
_feature_size
=
ch
self
.
_feature_size
=
ch
...
@@ -508,7 +510,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
...
@@ -508,7 +510,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
use_checkpoint
=
use_checkpoint
,
use_checkpoint
=
use_checkpoint
,
num_heads
=
num_heads
,
num_heads
=
num_heads
,
num_head_channels
=
num_head_channels
,
num_head_channels
=
num_head_channels
,
encoder_channels
=
encoder_channels
,
encoder_channels
=
transformer_dim
,
)
)
)
)
self
.
input_blocks
.
append
(
TimestepEmbedSequential
(
*
layers
))
self
.
input_blocks
.
append
(
TimestepEmbedSequential
(
*
layers
))
...
@@ -551,7 +553,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
...
@@ -551,7 +553,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
use_checkpoint
=
use_checkpoint
,
use_checkpoint
=
use_checkpoint
,
num_heads
=
num_heads
,
num_heads
=
num_heads
,
num_head_channels
=
num_head_channels
,
num_head_channels
=
num_head_channels
,
encoder_channels
=
encoder_channels
,
encoder_channels
=
transformer_dim
,
),
),
ResBlock
(
ResBlock
(
ch
,
ch
,
...
@@ -587,7 +589,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
...
@@ -587,7 +589,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
use_checkpoint
=
use_checkpoint
,
use_checkpoint
=
use_checkpoint
,
num_heads
=
num_heads_upsample
,
num_heads
=
num_heads_upsample
,
num_head_channels
=
num_head_channels
,
num_head_channels
=
num_head_channels
,
encoder_channels
=
encoder_channels
,
encoder_channels
=
transformer_dim
,
)
)
)
)
if
level
and
i
==
num_res_blocks
:
if
level
and
i
==
num_res_blocks
:
...
@@ -642,10 +644,6 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
...
@@ -642,10 +644,6 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
:param y: an [N] Tensor of labels, if class-conditional.
:param y: an [N] Tensor of labels, if class-conditional.
:return: an [N x C x ...] Tensor of outputs.
:return: an [N x C x ...] Tensor of outputs.
"""
"""
assert
(
y
is
not
None
)
==
(
self
.
num_classes
is
not
None
),
"must specify y if and only if the model is class-conditional"
hs
=
[]
hs
=
[]
emb
=
self
.
time_embed
(
timestep_embedding
(
timesteps
,
self
.
model_channels
))
emb
=
self
.
time_embed
(
timestep_embedding
(
timesteps
,
self
.
model_channels
))
...
@@ -655,13 +653,13 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
...
@@ -655,13 +653,13 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
emb
=
emb
+
transformer_proj
.
to
(
emb
)
emb
=
emb
+
transformer_proj
.
to
(
emb
)
h
=
x
.
type
(
self
.
dtype
)
h
=
x
for
module
in
self
.
input_blocks
:
for
module
in
self
.
input_blocks
:
h
=
module
(
h
,
emb
,
transformer_out
)
h
=
module
(
h
,
emb
,
transformer_out
)
hs
.
append
(
h
)
hs
.
append
(
h
)
h
=
self
.
middle_block
(
h
,
emb
,
transformer_out
)
h
=
self
.
middle_block
(
h
,
emb
,
transformer_out
)
for
module
in
self
.
output_blocks
:
for
module
in
self
.
output_blocks
:
h
=
torch
.
cat
([
h
,
hs
.
pop
()],
dim
=
1
)
other
=
hs
.
pop
()
h
=
torch
.
cat
([
h
,
other
],
dim
=
1
)
h
=
module
(
h
,
emb
,
transformer_out
)
h
=
module
(
h
,
emb
,
transformer_out
)
h
=
h
.
type
(
x
.
dtype
)
return
self
.
out
(
h
)
return
self
.
out
(
h
)
src/diffusers/schedulers/classifier_free_guidance.py
View file @
d754ce5f
...
@@ -65,14 +65,14 @@ class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin):
...
@@ -65,14 +65,14 @@ class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin):
if
beta_schedule
==
"squaredcos_cap_v2"
:
if
beta_schedule
==
"squaredcos_cap_v2"
:
# GLIDE cosine schedule
# GLIDE cosine schedule
betas
=
betas_for_alpha_bar
(
self
.
betas
=
betas_for_alpha_bar
(
timesteps
,
timesteps
,
lambda
t
:
math
.
cos
((
t
+
0.008
)
/
1.008
*
math
.
pi
/
2
)
**
2
,
lambda
t
:
math
.
cos
((
t
+
0.008
)
/
1.008
*
math
.
pi
/
2
)
**
2
,
)
)
else
:
else
:
raise
NotImplementedError
(
f
"
{
beta_schedule
}
does is not implemented for
{
self
.
__class__
}
"
)
raise
NotImplementedError
(
f
"
{
beta_schedule
}
does is not implemented for
{
self
.
__class__
}
"
)
alphas
=
1.0
-
betas
alphas
=
1.0
-
self
.
betas
self
.
alphas_cumprod
=
np
.
cumprod
(
alphas
,
axis
=
0
)
self
.
alphas_cumprod
=
np
.
cumprod
(
alphas
,
axis
=
0
)
self
.
alphas_cumprod_prev
=
np
.
append
(
1.0
,
self
.
alphas_cumprod
[:
-
1
])
self
.
alphas_cumprod_prev
=
np
.
append
(
1.0
,
self
.
alphas_cumprod
[:
-
1
])
...
@@ -81,12 +81,12 @@ class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin):
...
@@ -81,12 +81,12 @@ class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin):
self
.
sqrt_recipm1_alphas_cumprod
=
np
.
sqrt
(
1.0
/
self
.
alphas_cumprod
-
1
)
self
.
sqrt_recipm1_alphas_cumprod
=
np
.
sqrt
(
1.0
/
self
.
alphas_cumprod
-
1
)
# calculations for posterior q(x_{t-1} | x_t, x_0)
# calculations for posterior q(x_{t-1} | x_t, x_0)
self
.
posterior_variance
=
betas
*
(
1.0
-
self
.
alphas_cumprod_prev
)
/
(
1.0
-
self
.
alphas_cumprod
)
self
.
posterior_variance
=
self
.
betas
*
(
1.0
-
self
.
alphas_cumprod_prev
)
/
(
1.0
-
self
.
alphas_cumprod
)
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self
.
posterior_log_variance_clipped
=
np
.
log
(
self
.
posterior_log_variance_clipped
=
np
.
log
(
np
.
append
(
self
.
posterior_variance
[
1
],
self
.
posterior_variance
[
1
:])
np
.
append
(
self
.
posterior_variance
[
1
],
self
.
posterior_variance
[
1
:])
)
)
self
.
posterior_mean_coef1
=
betas
*
np
.
sqrt
(
self
.
alphas_cumprod_prev
)
/
(
1.0
-
self
.
alphas_cumprod
)
self
.
posterior_mean_coef1
=
self
.
betas
*
np
.
sqrt
(
self
.
alphas_cumprod_prev
)
/
(
1.0
-
self
.
alphas_cumprod
)
self
.
posterior_mean_coef2
=
(
1.0
-
self
.
alphas_cumprod_prev
)
*
np
.
sqrt
(
alphas
)
/
(
1.0
-
self
.
alphas_cumprod
)
self
.
posterior_mean_coef2
=
(
1.0
-
self
.
alphas_cumprod_prev
)
*
np
.
sqrt
(
alphas
)
/
(
1.0
-
self
.
alphas_cumprod
)
def
sample_noise
(
self
,
shape
,
device
,
generator
=
None
):
def
sample_noise
(
self
,
shape
,
device
,
generator
=
None
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment