OpenDAS / diffusers · Commits

Commit 9c4cd06d (unverified), authored Jun 07, 2022 by Anton Lozhkov, committed via GitHub on Jun 07, 2022

Merge pull request #4 from huggingface/add-glide

Convert glide weights

Parents: f39020bd, d04051e3
Showing 2 changed files with 61 additions and 0 deletions:

models/vision/glide/convert_weights.py  (+60 -0)
models/vision/glide/run_glide.py        (+1 -0)
models/vision/glide/convert_weights.py (new file, mode 100644)
import argparse

import torch
from torch import nn
from transformers import CLIPTextConfig, CLIPTextModel, GPT2Tokenizer

# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base.pt
state_dict = torch.load("base.pt", map_location="cpu")
state_dict = {k: nn.Parameter(v) for k, v in state_dict.items()}

# Text-encoder hyperparameters of the GLIDE base model.
config = CLIPTextConfig(
    hidden_size=512,
    intermediate_size=2048,
    num_hidden_layers=16,
    num_attention_heads=8,
    max_position_embeddings=128,
)
model = CLIPTextModel(config).eval()

# GLIDE reuses the GPT-2 BPE vocabulary and merges.
tokenizer = GPT2Tokenizer(
    "./glide-base/vocab.json", "./glide-base/merges.txt", pad_token="<|endoftext|>"
)
tokenizer.save_pretrained("./glide-base")

hf_encoder = model.text_model

# Embeddings and final layer norm.
hf_encoder.embeddings.token_embedding.weight = state_dict["token_embedding.weight"]
hf_encoder.embeddings.position_embedding.weight.data = state_dict["positional_embedding"]
hf_encoder.embeddings.padding_embedding.weight.data = state_dict["padding_embedding"]

hf_encoder.final_layer_norm.weight = state_dict["final_ln.weight"]
hf_encoder.final_layer_norm.bias = state_dict["final_ln.bias"]

# Per-layer weights: GLIDE stores the attention input projection fused as
# c_qkv; split it into the separate q/k/v projections that CLIP expects.
for layer_idx in range(config.num_hidden_layers):
    hf_layer = hf_encoder.encoder.layers[layer_idx]
    q_proj, k_proj, v_proj = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_qkv.weight"].chunk(3, dim=0)
    q_proj_bias, k_proj_bias, v_proj_bias = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_qkv.bias"].chunk(3, dim=0)

    hf_layer.self_attn.q_proj.weight.data = q_proj
    hf_layer.self_attn.q_proj.bias.data = q_proj_bias
    hf_layer.self_attn.k_proj.weight.data = k_proj
    hf_layer.self_attn.k_proj.bias.data = k_proj_bias
    hf_layer.self_attn.v_proj.weight.data = v_proj
    hf_layer.self_attn.v_proj.bias.data = v_proj_bias

    hf_layer.self_attn.out_proj.weight = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_proj.weight"]
    hf_layer.self_attn.out_proj.bias = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_proj.bias"]

    hf_layer.layer_norm1.weight = state_dict[f"transformer.resblocks.{layer_idx}.ln_1.weight"]
    hf_layer.layer_norm1.bias = state_dict[f"transformer.resblocks.{layer_idx}.ln_1.bias"]
    hf_layer.layer_norm2.weight = state_dict[f"transformer.resblocks.{layer_idx}.ln_2.weight"]
    hf_layer.layer_norm2.bias = state_dict[f"transformer.resblocks.{layer_idx}.ln_2.bias"]

    hf_layer.mlp.fc1.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_fc.weight"]
    hf_layer.mlp.fc1.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_fc.bias"]
    hf_layer.mlp.fc2.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.weight"]
    hf_layer.mlp.fc2.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.bias"]

# Smoke-test a forward pass, then save the converted weights.
inputs = tokenizer(
    ["an oil painting of a corgi", ""], padding="max_length", max_length=128, return_tensors="pt"
)
with torch.no_grad():
    outputs = model(**inputs)

model.save_pretrained("./glide-base")
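For reference, a minimal round-trip check (a sketch, not part of the commit): reload the artifacts the script writes to ./glide-base and confirm the encoder produces embeddings of the expected shape. This assumes the same transformers build the commit targets, since the padding_embedding attribute the script populates is not present in every CLIPTextModel release.

# Hypothetical sanity check, not part of this commit: reload what
# convert_weights.py saved and verify the output shape.
import torch
from transformers import CLIPTextModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("./glide-base")
model = CLIPTextModel.from_pretrained("./glide-base").eval()

inputs = tokenizer(
    ["an oil painting of a corgi", ""],
    padding="max_length",
    max_length=128,
    return_tensors="pt",
)
with torch.no_grad():
    outputs = model(**inputs)

# Two prompts, 128 positions, hidden_size=512 from the converted config.
assert outputs.last_hidden_state.shape == (2, 128, 512)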
models/vision/glide/run_glide.py
...
@@ -6,6 +6,7 @@ generator = torch.Generator()
generator = generator.manual_seed(0)

# 1. Load models
scheduler = GaussianDDPMScheduler.from_config("fusing/glide-base")
model = UNetGLIDEModel.from_pretrained("fusing/glide-base")
...
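The hunk shows only the seeding and model-loading lines of run_glide.py; the sampling loop and the single added line are elided. For orientation only, here is a generic DDPM-style ancestral-sampling sketch in plain PyTorch. It is not the commit's code: the linear beta schedule, the epsilon-prediction convention, and the stand-in eps_model are all assumptions, not the actual GaussianDDPMScheduler/UNetGLIDEModel API.

# Generic DDPM ancestral sampling, sketched in plain PyTorch. NOT the
# commit's run_glide.py: schedule, shapes, and eps_model are illustrative.
import torch

num_timesteps = 1000
betas = torch.linspace(1e-4, 0.02, num_timesteps)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)

def eps_model(x_t, t):
    # Stand-in for a real noise-prediction UNet such as UNetGLIDEModel.
    return torch.zeros_like(x_t)

generator = torch.Generator().manual_seed(0)
x_t = torch.randn(1, 3, 64, 64, generator=generator)

for t in reversed(range(num_timesteps)):
    eps = eps_model(x_t, t)
    # DDPM posterior mean (Ho et al., 2020, Eq. 11).
    mean = (x_t - betas[t] / torch.sqrt(1.0 - alphas_cumprod[t]) * eps) / torch.sqrt(alphas[t])
    if t > 0:
        noise = torch.randn(x_t.shape, generator=generator)
        x_t = mean + torch.sqrt(betas[t]) * noise
    else:
        x_t = mean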