Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
f9cdb4dd
Commit
f9cdb4dd
authored
Jun 08, 2022
by
anton-l
Browse files
Convert glide upsampling weights
parent
43e728d3
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
95 additions
and
17 deletions
+95
-17
models/vision/glide/convert_weights.py
models/vision/glide/convert_weights.py
+34
-6
models/vision/glide/modeling_glide.py
models/vision/glide/modeling_glide.py
+2
-2
models/vision/glide/run_glide.py
models/vision/glide/run_glide.py
+1
-1
src/diffusers/__init__.py
src/diffusers/__init__.py
+1
-1
src/diffusers/models/__init__.py
src/diffusers/models/__init__.py
+1
-1
src/diffusers/models/unet_glide.py
src/diffusers/models/unet_glide.py
+56
-6
No files found.
models/vision/glide/convert_weights.py
View file @
f9cdb4dd
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
diffusers
import
ClassifierFreeGuidanceScheduler
,
CLIPTextModel
,
UNet
GLIDEModel
from
diffusers
import
ClassifierFreeGuidanceScheduler
,
CLIPTextModel
,
GLIDE
TextToImageUNetModel
,
GLIDESuperResUNet
Model
from
modeling_glide
import
GLIDE
from
modeling_glide
import
GLIDE
from
transformers
import
CLIPTextConfig
,
GPT2Tokenizer
from
transformers
import
CLIPTextConfig
,
GPT2Tokenizer
...
@@ -51,9 +51,9 @@ for layer_idx in range(config.num_hidden_layers):
...
@@ -51,9 +51,9 @@ for layer_idx in range(config.num_hidden_layers):
hf_layer
.
mlp
.
fc2
.
weight
=
state_dict
[
f
"transformer.resblocks.
{
layer_idx
}
.mlp.c_proj.weight"
]
hf_layer
.
mlp
.
fc2
.
weight
=
state_dict
[
f
"transformer.resblocks.
{
layer_idx
}
.mlp.c_proj.weight"
]
hf_layer
.
mlp
.
fc2
.
bias
=
state_dict
[
f
"transformer.resblocks.
{
layer_idx
}
.mlp.c_proj.bias"
]
hf_layer
.
mlp
.
fc2
.
bias
=
state_dict
[
f
"transformer.resblocks.
{
layer_idx
}
.mlp.c_proj.bias"
]
### Convert the UNet
### Convert the
Text-to-Image
UNet
unet
_model
=
UNet
GLIDEModel
(
text2im
_model
=
GLIDE
TextToImageUNet
Model
(
in_channels
=
3
,
in_channels
=
3
,
model_channels
=
192
,
model_channels
=
192
,
out_channels
=
6
,
out_channels
=
6
,
...
@@ -69,10 +69,38 @@ unet_model = UNetGLIDEModel(
...
@@ -69,10 +69,38 @@ unet_model = UNetGLIDEModel(
transformer_dim
=
512
,
transformer_dim
=
512
,
)
)
unet
_model
.
load_state_dict
(
state_dict
,
strict
=
False
)
text2im
_model
.
load_state_dict
(
state_dict
,
strict
=
False
)
scheduler
=
ClassifierFreeGuidanceScheduler
(
timesteps
=
1000
,
beta_schedule
=
"squaredcos_cap_v2"
)
text_
scheduler
=
ClassifierFreeGuidanceScheduler
(
timesteps
=
1000
,
beta_schedule
=
"squaredcos_cap_v2"
)
glide
=
GLIDE
(
unet
=
unet_model
,
noise_scheduler
=
scheduler
,
text_encoder
=
model
,
tokenizer
=
tokenizer
)
### Convert the Super-Resolution UNet
# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/upsample.pt
state_dict
=
torch
.
load
(
"upsample.pt"
,
map_location
=
"cpu"
)
superres_model
=
GLIDESuperResUNetModel
(
in_channels
=
6
,
model_channels
=
192
,
out_channels
=
6
,
num_res_blocks
=
2
,
attention_resolutions
=
(
8
,
16
,
32
),
dropout
=
0.1
,
channel_mult
=
(
1
,
1
,
2
,
2
,
4
,
4
),
num_heads
=
1
,
num_head_channels
=
64
,
num_heads_upsample
=
1
,
use_scale_shift_norm
=
True
,
resblock_updown
=
True
,
)
superres_model
.
load_state_dict
(
state_dict
)
upscale_scheduler
=
ClassifierFreeGuidanceScheduler
(
timesteps
=
1000
,
beta_schedule
=
"squaredcos_cap_v2"
)
glide
=
GLIDE
(
text_unet
=
text2im_model
,
text_noise_scheduler
=
text_scheduler
,
text_encoder
=
model
,
tokenizer
=
tokenizer
,
upscale_unet
=
superres_model
,
upscale_noise_scheduler
=
scheduler
)
glide
.
save_pretrained
(
"./glide-base"
)
glide
.
save_pretrained
(
"./glide-base"
)
models/vision/glide/modeling_glide.py
View file @
f9cdb4dd
...
@@ -18,7 +18,7 @@ import numpy as np
...
@@ -18,7 +18,7 @@ import numpy as np
import
torch
import
torch
import
tqdm
import
tqdm
from
diffusers
import
ClassifierFreeGuidanceScheduler
,
CLIPTextModel
,
DiffusionPipeline
,
UNet
GLIDEModel
from
diffusers
import
ClassifierFreeGuidanceScheduler
,
CLIPTextModel
,
DiffusionPipeline
,
GLIDE
TextToImageUNetModel
,
GLIDESuperResUNet
Model
from
transformers
import
GPT2Tokenizer
from
transformers
import
GPT2Tokenizer
...
@@ -41,7 +41,7 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
...
@@ -41,7 +41,7 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
class
GLIDE
(
DiffusionPipeline
):
class
GLIDE
(
DiffusionPipeline
):
def
__init__
(
def
__init__
(
self
,
self
,
unet
:
UNet
GLIDEModel
,
unet
:
GLIDE
TextToImageUNet
Model
,
noise_scheduler
:
ClassifierFreeGuidanceScheduler
,
noise_scheduler
:
ClassifierFreeGuidanceScheduler
,
text_encoder
:
CLIPTextModel
,
text_encoder
:
CLIPTextModel
,
tokenizer
:
GPT2Tokenizer
,
tokenizer
:
GPT2Tokenizer
,
...
...
models/vision/glide/run_glide.py
View file @
f9cdb4dd
...
@@ -12,7 +12,7 @@ generator = generator.manual_seed(0)
...
@@ -12,7 +12,7 @@ generator = generator.manual_seed(0)
# 1. Load models
# 1. Load models
pipeline
=
GLIDE
.
from_pretrained
(
"fusing/glide-base"
)
pipeline
=
GLIDE
.
from_pretrained
(
"fusing/glide-base"
)
img
=
pipeline
(
"a
n oil painting
of a corgi"
,
generator
)
img
=
pipeline
(
"a
pencil sketch
of a corgi"
,
generator
)
img
=
((
img
+
1
)
*
127.5
).
round
().
clamp
(
0
,
255
).
to
(
torch
.
uint8
).
cpu
().
numpy
()
img
=
((
img
+
1
)
*
127.5
).
round
().
clamp
(
0
,
255
).
to
(
torch
.
uint8
).
cpu
().
numpy
()
plt
.
figure
(
figsize
=
(
8
,
8
))
plt
.
figure
(
figsize
=
(
8
,
8
))
...
...
src/diffusers/__init__.py
View file @
f9cdb4dd
...
@@ -7,7 +7,7 @@ __version__ = "0.0.1"
...
@@ -7,7 +7,7 @@ __version__ = "0.0.1"
from
.modeling_utils
import
ModelMixin
from
.modeling_utils
import
ModelMixin
from
.models.clip_text_transformer
import
CLIPTextModel
from
.models.clip_text_transformer
import
CLIPTextModel
from
.models.unet
import
UNetModel
from
.models.unet
import
UNetModel
from
.models.unet_glide
import
UNet
GLIDEModel
from
.models.unet_glide
import
GLIDE
TextToImageUNetModel
,
GLIDESuperResUNet
Model
from
.models.unet_ldm
import
UNetLDMModel
from
.models.unet_ldm
import
UNetLDMModel
from
.pipeline_utils
import
DiffusionPipeline
from
.pipeline_utils
import
DiffusionPipeline
from
.schedulers.classifier_free_guidance
import
ClassifierFreeGuidanceScheduler
from
.schedulers.classifier_free_guidance
import
ClassifierFreeGuidanceScheduler
...
...
src/diffusers/models/__init__.py
View file @
f9cdb4dd
...
@@ -18,5 +18,5 @@
...
@@ -18,5 +18,5 @@
from
.clip_text_transformer
import
CLIPTextModel
from
.clip_text_transformer
import
CLIPTextModel
from
.unet
import
UNetModel
from
.unet
import
UNetModel
from
.unet_glide
import
UNet
GLIDEModel
from
.unet_glide
import
GLIDE
TextToImageUNetModel
,
GLIDESuperResUNet
Model
from
.unet_ldm
import
UNetLDMModel
from
.unet_ldm
import
UNetLDMModel
src/diffusers/models/unet_glide.py
View file @
f9cdb4dd
...
@@ -388,7 +388,7 @@ class QKVAttention(nn.Module):
...
@@ -388,7 +388,7 @@ class QKVAttention(nn.Module):
return
a
.
reshape
(
bs
,
-
1
,
length
)
return
a
.
reshape
(
bs
,
-
1
,
length
)
class
UNet
GLIDEModel
(
ModelMixin
,
ConfigMixin
):
class
GLIDE
UNet
Model
(
ModelMixin
,
ConfigMixin
):
"""
"""
The full UNet model with attention and timestep embedding.
The full UNet model with attention and timestep embedding.
...
@@ -435,7 +435,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
...
@@ -435,7 +435,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
num_heads_upsample
=-
1
,
num_heads_upsample
=-
1
,
use_scale_shift_norm
=
False
,
use_scale_shift_norm
=
False
,
resblock_updown
=
False
,
resblock_updown
=
False
,
transformer_dim
=
512
,
transformer_dim
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
register
(
self
.
register
(
...
@@ -455,7 +455,6 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
...
@@ -455,7 +455,6 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
num_heads_upsample
=
num_heads_upsample
,
num_heads_upsample
=
num_heads_upsample
,
use_scale_shift_norm
=
use_scale_shift_norm
,
use_scale_shift_norm
=
use_scale_shift_norm
,
resblock_updown
=
resblock_updown
,
resblock_updown
=
resblock_updown
,
transformer_dim
=
transformer_dim
,
)
)
if
num_heads_upsample
==
-
1
:
if
num_heads_upsample
==
-
1
:
...
@@ -482,8 +481,6 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
...
@@ -482,8 +481,6 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
linear
(
time_embed_dim
,
time_embed_dim
),
linear
(
time_embed_dim
,
time_embed_dim
),
)
)
self
.
transformer_proj
=
nn
.
Linear
(
transformer_dim
,
self
.
model_channels
*
4
)
ch
=
input_ch
=
int
(
channel_mult
[
0
]
*
model_channels
)
ch
=
input_ch
=
int
(
channel_mult
[
0
]
*
model_channels
)
self
.
input_blocks
=
nn
.
ModuleList
([
TimestepEmbedSequential
(
conv_nd
(
dims
,
in_channels
,
ch
,
3
,
padding
=
1
))])
self
.
input_blocks
=
nn
.
ModuleList
([
TimestepEmbedSequential
(
conv_nd
(
dims
,
in_channels
,
ch
,
3
,
padding
=
1
))])
self
.
_feature_size
=
ch
self
.
_feature_size
=
ch
...
@@ -635,7 +632,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
...
@@ -635,7 +632,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
self
.
middle_block
.
apply
(
convert_module_to_f32
)
self
.
middle_block
.
apply
(
convert_module_to_f32
)
self
.
output_blocks
.
apply
(
convert_module_to_f32
)
self
.
output_blocks
.
apply
(
convert_module_to_f32
)
def
forward
(
self
,
x
,
timesteps
,
transformer_out
):
def
forward
(
self
,
x
,
timesteps
,
y
=
None
):
"""
"""
Apply the model to an input batch.
Apply the model to an input batch.
...
@@ -644,6 +641,42 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
...
@@ -644,6 +641,42 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
:param y: an [N] Tensor of labels, if class-conditional.
:param y: an [N] Tensor of labels, if class-conditional.
:return: an [N x C x ...] Tensor of outputs.
:return: an [N x C x ...] Tensor of outputs.
"""
"""
assert
(
y
is
not
None
)
==
(
self
.
num_classes
is
not
None
),
"must specify y if and only if the model is class-conditional"
hs
=
[]
emb
=
self
.
time_embed
(
timestep_embedding
(
timesteps
,
self
.
model_channels
))
if
self
.
num_classes
is
not
None
:
assert
y
.
shape
==
(
x
.
shape
[
0
],)
emb
=
emb
+
self
.
label_emb
(
y
)
h
=
x
.
type
(
self
.
dtype
)
for
module
in
self
.
input_blocks
:
h
=
module
(
h
,
emb
)
hs
.
append
(
h
)
h
=
self
.
middle_block
(
h
,
emb
)
for
module
in
self
.
output_blocks
:
h
=
torch
.
cat
([
h
,
hs
.
pop
()],
dim
=
1
)
h
=
module
(
h
,
emb
)
h
=
h
.
type
(
x
.
dtype
)
return
self
.
out
(
h
)
class
GLIDETextToImageUNetModel
(
GLIDEUNetModel
):
"""
A UNetModel that performs super-resolution.
Expects an extra kwarg `low_res` to condition on a low-resolution image.
"""
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
transformer_proj
=
nn
.
Linear
(
kwargs
[
"transformer_dim"
],
self
.
model_channels
*
4
)
def
forward
(
self
,
x
,
timesteps
,
transformer_out
=
None
):
hs
=
[]
hs
=
[]
emb
=
self
.
time_embed
(
timestep_embedding
(
timesteps
,
self
.
model_channels
))
emb
=
self
.
time_embed
(
timestep_embedding
(
timesteps
,
self
.
model_channels
))
...
@@ -663,3 +696,20 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
...
@@ -663,3 +696,20 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
h
=
torch
.
cat
([
h
,
other
],
dim
=
1
)
h
=
torch
.
cat
([
h
,
other
],
dim
=
1
)
h
=
module
(
h
,
emb
,
transformer_out
)
h
=
module
(
h
,
emb
,
transformer_out
)
return
self
.
out
(
h
)
return
self
.
out
(
h
)
class
GLIDESuperResUNetModel
(
GLIDEUNetModel
):
"""
A UNetModel that performs super-resolution.
Expects an extra kwarg `low_res` to condition on a low-resolution image.
"""
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
def
forward
(
self
,
x
,
timesteps
,
low_res
=
None
,
**
kwargs
):
_
,
_
,
new_height
,
new_width
=
x
.
shape
upsampled
=
F
.
interpolate
(
low_res
,
(
new_height
,
new_width
),
mode
=
"bilinear"
)
x
=
torch
.
cat
([
x
,
upsampled
],
dim
=
1
)
return
super
().
forward
(
x
,
timesteps
,
**
kwargs
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment