Commit db2a1077 authored by anton-l

Add glide text encoder

parent 9c4cd06d
```diff
@@ -3,7 +3,8 @@ import argparse
 import torch
 from torch import nn
-from transformers import CLIPTextConfig, CLIPTextModel, GPT2Tokenizer
+from transformers import CLIPTextConfig, GPT2Tokenizer
+from modelling_text_encoder import CLIPTextModel
 
 # wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base.pt
 state_dict = torch.load("base.pt", map_location="cpu")
```
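For reference, the conversion script relies on the fused attention layout of OpenAI's `base.pt` checkpoint. A quick inspection sketch using the key names that appear in the script itself (the shape comment is an assumption about the base model, not something this diff states):

```python
import torch

# Sanity-check the GLIDE base checkpoint before converting: the script
# expects fused QKV tensors under these keys.
state_dict = torch.load("base.pt", map_location="cpu")
qkv_w = state_dict["transformer.resblocks.0.attn.c_qkv.weight"]
qkv_b = state_dict["transformer.resblocks.0.attn.c_qkv.bias"]
print(qkv_w.shape, qkv_b.shape)  # expected: (3 * width, width) and (3 * width,)
```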
```diff
@@ -13,7 +14,8 @@ config = CLIPTextConfig(
     intermediate_size=2048,
     num_hidden_layers=16,
     num_attention_heads=8,
-    max_position_embeddings=128
+    max_position_embeddings=128,
+    use_padding_embeddings=True,
 )
 model = CLIPTextModel(config).eval()
 tokenizer = GPT2Tokenizer("./glide-base/vocab.json", "./glide-base/merges.txt", pad_token="<|endoftext|>")
```
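The new `use_padding_embeddings` flag reflects how GLIDE treats padding: its text transformer learns an embedding for padded positions instead of leaving them at the pad token's vector. The actual implementation lives in the custom `CLIPTextModel` from `modelling_text_encoder.py`; the following is only a hypothetical sketch of the mechanism, with every name invented for illustration:

```python
import torch
from torch import nn

class TokenEmbeddingWithPadding(nn.Module):
    """Hypothetical sketch: substitute a learned per-position embedding
    wherever the attention mask marks padding, as GLIDE's text encoder does."""

    def __init__(self, vocab_size: int, max_positions: int, hidden_size: int):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, hidden_size)
        # One learned vector per position, used only at padded slots.
        self.padding_embedding = nn.Parameter(torch.zeros(max_positions, hidden_size))

    def forward(self, input_ids: torch.LongTensor, attention_mask: torch.Tensor) -> torch.Tensor:
        embeds = self.token_embedding(input_ids)             # (batch, seq, hidden)
        keep = attention_mask.bool().unsqueeze(-1)           # (batch, seq, 1)
        pad = self.padding_embedding[: input_ids.shape[1]]   # (seq, hidden)
        return torch.where(keep, embeds, pad)

emb = TokenEmbeddingWithPadding(vocab_size=50257, max_positions=128, hidden_size=512)
ids = torch.tensor([[1, 2, 3, 0]])
mask = torch.tensor([[1, 1, 1, 0]])
out = emb(ids, mask)  # last position takes the learned padding vector
```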
```diff
@@ -30,15 +32,8 @@ hf_encoder.final_layer_norm.bias = state_dict["final_ln.bias"]
 for layer_idx in range(config.num_hidden_layers):
     hf_layer = hf_encoder.encoder.layers[layer_idx]
-    q_proj, k_proj, v_proj = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_qkv.weight"].chunk(3, dim=0)
-    q_proj_bias, k_proj_bias, v_proj_bias = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_qkv.bias"].chunk(3, dim=0)
-    hf_layer.self_attn.q_proj.weight.data = q_proj
-    hf_layer.self_attn.q_proj.bias.data = q_proj_bias
-    hf_layer.self_attn.k_proj.weight.data = k_proj
-    hf_layer.self_attn.k_proj.bias.data = k_proj_bias
-    hf_layer.self_attn.v_proj.weight.data = v_proj
-    hf_layer.self_attn.v_proj.bias.data = v_proj_bias
+    hf_layer.self_attn.qkv_proj.weight = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_qkv.weight"]
+    hf_layer.self_attn.qkv_proj.bias = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_qkv.bias"]
     hf_layer.self_attn.out_proj.weight = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_proj.weight"]
     hf_layer.self_attn.out_proj.bias = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_proj.bias"]
```
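This refactor is weight-preserving: splitting the fused `c_qkv` weight along dim 0 (the old code) and keeping a single `qkv_proj` whose output is chunked along the feature dimension (what the new code implies) yield identical q/k/v, because for `y = x @ W.T + b` the rows of `W` map one-to-one onto output features. A self-contained check of that identity (the width of 512 is an assumption, matching `intermediate_size=2048` at the usual 4:1 ratio):

```python
import torch
from torch import nn

hidden = 512  # assumption: GLIDE base text width (intermediate_size 2048 / 4)
x = torch.randn(2, 7, hidden)

qkv_proj = nn.Linear(hidden, 3 * hidden)
q, k, v = qkv_proj(x).chunk(3, dim=-1)

# The old conversion split the weight instead; both routes give the same q/k/v.
q_w, k_w, v_w = qkv_proj.weight.chunk(3, dim=0)
q_b, k_b, v_b = qkv_proj.bias.chunk(3, dim=0)
assert torch.allclose(q, x @ q_w.T + q_b, atol=1e-6)
assert torch.allclose(k, x @ k_w.T + k_b, atol=1e-6)
assert torch.allclose(v, x @ v_w.T + v_b, atol=1e-6)
```

One caveat: assigning a plain tensor to `nn.Linear.weight` raises a `TypeError` in stock PyTorch, so these assignments presumably rely on the state-dict tensors being wrapped as `nn.Parameter` in the custom module (the removed lines sidestepped the issue by writing to `.data`).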