import os
import sys
import unittest
import oneflow as flow
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
from clip.model import CLIP, ModifiedResNet, Transformer # noqa: E402
class TestCLIP(unittest.TestCase):
def test_modified_resnet(self):
net = ModifiedResNet([3, 4, 6, 3], 120, 16).to_global(
sbp=flow.sbp.broadcast, placement=flow.placement("cuda", ranks=[0])
)
x = flow.rand(
32, 3, 224, 224, sbp=flow.sbp.split(0), placement=flow.placement("cuda", ranks=[0])
)
y = net(x)
assert isinstance(y, flow.Tensor)
def test_transformer(self):
mask = flow.ones(
12, 12, sbp=flow.sbp.broadcast, placement=flow.placement("cuda", ranks=[0])
)
mask = flow.tril(mask) # keep the lower triangle (zero out entries above the diagonal)
# [1, 1, s, s] -> [batch, 1, s, s]
mask = mask.unsqueeze(0).unsqueeze(1).expand(16, 1, 12, 12)
net = Transformer(128, 10, 16, mask)
x = flow.rand(
16, 12, 128, sbp=flow.sbp.split(0), placement=flow.placement("cuda", ranks=[0])
)
y = net(x)
assert isinstance(y, flow.Tensor)
def test_clip(self):
# clip with resnet
net = CLIP(
embed_dim=10,
# vision
image_resolution=224,
vision_layers=6,
vision_width=120,
vision_patch_size=16,
# text
context_length=24,
vocab_size=3000,
transformer_width=128,
transformer_heads=16,
transformer_layers=10,
)
img = flow.rand(
16, 3, 224, 224, sbp=flow.sbp.split(0), placement=flow.placement("cuda", ranks=[0])
)
text = flow.ones(
16,
24,
dtype=flow.int,
sbp=flow.sbp.split(0),
placement=flow.placement("cuda", ranks=[0]),
)
logits_img, logits_text = net(img, text)
print(logits_img)
print(logits_text)
if __name__ == "__main__":
unittest.main()
import os
import sys
import unittest
import numpy as np
import oneflow as flow
import torch
from torch.nn.functional import multi_head_attention_forward as multi_head_attention_forward_torch
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
from clip.ops import multi_head_attention_forward as multi_head_attention_forward_flow # noqa: E402
class TestMultiHeadAttention(unittest.TestCase):
def test_with_torch(self):
k_proj_weight = np.random.normal(size=(32, 32))
k_proj_bias = np.random.normal(size=(32))
q_proj_weight = np.random.normal(size=(32, 32))
q_proj_bias = np.random.normal(size=(32))
v_proj_weight = np.random.normal(size=(32, 32))
v_proj_bias = np.random.normal(size=(32))
c_proj_weight = np.random.normal(size=(64, 32))
c_proj_bias = np.random.normal(size=(64))
x = np.random.normal(size=(65, 16, 32))
x_torch = torch.from_numpy(x)
torch_out, _ = multi_head_attention_forward_torch(
query=x_torch,
key=x_torch,
value=x_torch,
embed_dim_to_check=x_torch.shape[-1],
num_heads=8,
q_proj_weight=torch.from_numpy(q_proj_weight),
k_proj_weight=torch.from_numpy(k_proj_weight),
v_proj_weight=torch.from_numpy(v_proj_weight),
in_proj_weight=None,
in_proj_bias=torch.cat(
[
torch.from_numpy(q_proj_bias),
torch.from_numpy(k_proj_bias),
torch.from_numpy(v_proj_bias),
]
),
bias_k=None,
bias_v=None,
add_zero_attn=False,
dropout_p=0,
out_proj_weight=torch.from_numpy(c_proj_weight),
out_proj_bias=torch.from_numpy(c_proj_bias),
use_separate_proj_weight=True,
training=True,
need_weights=False,
)
x_flow = flow.from_numpy(x).cuda()
flow_out, _ = multi_head_attention_forward_flow(
query=x_flow,
key=x_flow,
value=x_flow,
embed_dim_to_check=x_flow.shape[-1],
num_heads=8,
q_proj_weight=flow.from_numpy(q_proj_weight).cuda(),
k_proj_weight=flow.from_numpy(k_proj_weight).cuda(),
v_proj_weight=flow.from_numpy(v_proj_weight).cuda(),
in_proj_weight=None,
in_proj_bias=flow.cat(
[
flow.from_numpy(q_proj_bias).cuda(),
flow.from_numpy(k_proj_bias).cuda(),
flow.from_numpy(v_proj_bias).cuda(),
]
),
bias_k=None,
bias_v=None,
dropout_p=0,
out_proj_weight=flow.from_numpy(c_proj_weight).cuda(),
out_proj_bias=flow.from_numpy(c_proj_bias).cuda(),
use_separate_proj_weight=True,
training=True,
need_weights=False,
)
assert np.allclose(torch_out.numpy(), flow_out.numpy())
if __name__ == "__main__":
unittest.main()
import os
import sys
dir_path = os.path.abspath(os.path.dirname(__file__))
dir_path = "/".join(dir_path.split("/")[:-1])
sys.path.append(dir_path)
from omegaconf import OmegaConf # noqa
from dataset.dataset import CoupletsDataset # noqa
from modeling.model import Seq2Seq # noqa
from libai.config import get_config # noqa
from libai.config import LazyCall # noqa
from libai.data.build import build_nlp_train_loader, build_nlp_test_loader # noqa
optim = get_config("common/optim.py").optim
graph = get_config("common/models/graph.py").graph
train = get_config("common/train.py").train
dataloader = OmegaConf.create()
dataloader.train = LazyCall(build_nlp_train_loader)(
dataset=[
LazyCall(CoupletsDataset)(
path="data_test/couplets",
is_train=True,
maxlen=64,
)
],
num_workers=4,
)
dataloader.test = [
LazyCall(build_nlp_test_loader)(
dataset=LazyCall(CoupletsDataset)(
path="data_test/couplets",
is_train=False,
maxlen=64,
),
num_workers=4,
)
]
transformer_cfg = dict(
vocab_size=9027,
max_position_embeddings=64,
hidden_size=512,
intermediate_size=512,
hidden_layers=6,
num_attention_heads=8,
embedding_dropout_prob=0.1,
hidden_dropout_prob=0.1,
attention_dropout_prob=0.1,
initializer_range=0.02,
layernorm_epsilon=1e-5,
bias_gelu_fusion=False,
bias_dropout_fusion=False,
scale_mask_softmax_fusion=False,
apply_query_key_layer_scaling=True,
)
model = LazyCall(Seq2Seq)(cfg=transformer_cfg)
train.update(
dict(
rdma_enabled=False,
activation_checkpoint=dict(enabled=False),
amp=dict(enabled=False),
output_dir="output/couplet/",
train_micro_batch_size=128,
test_micro_batch_size=32,
train_epoch=20,
train_iter=0,
log_period=10,
warmup_ratio=0.01,
dist=dict(
data_parallel_size=1,
tensor_parallel_size=1,
pipeline_parallel_size=1,
pipeline_stage_id=None,
pipeline_num_layers=model.cfg.hidden_layers * 2,
),
evaluation=dict(
enabled=False,
),
)
)
import os
import oneflow as flow
from dataset.mask import make_padding_mask, make_sequence_mask
from oneflow.utils.data import Dataset
from tokenizer.tokenizer import CoupletsTokenizer
from libai.data.structures import DistTensorData, Instance
class CoupletsDataset(Dataset):
def __init__(self, path, is_train=True, maxlen=64):
if is_train:
datapath = os.path.join(path, "train")
else:
datapath = os.path.join(path, "test")
src = []
with open(f"{datapath}/in.txt", "r") as f_src:
for line in f_src.readlines():
src.append(line.strip("\n"))
tgt = []
with open(f"{datapath}/out.txt", "r") as f_tgt:
for line in f_tgt.readlines():
tgt.append(line.strip("\n"))
self.data = list(zip(src, tgt))
self.tokenizer = CoupletsTokenizer(f"{path}/vocab.txt")
self.maxlen = maxlen
self.unk_id = self.tokenizer.unk_id
self.pad_id = self.tokenizer.pad_id
self.bos_id = self.tokenizer.bos_id
self.eos_id = self.tokenizer.eos_id
def __len__(self):
return len(self.data)
def text2ids(self, text):
tokens = self.tokenizer.tokenize(text)
ids = self.tokenizer.convert_tokens_to_ids(tokens)
ids = ids[: self.maxlen - 2]
ids = [self.bos_id] + ids + [self.eos_id]
ids = ids + [self.pad_id] * (self.maxlen - len(ids))
return ids
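# Illustration (hypothetical token ids, maxlen=6): ids [3, 4, 5, 6, 7] are
# truncated to [3, 4, 5, 6] and wrapped as [bos, 3, 4, 5, 6, eos]; shorter
# inputs are instead right-padded with pad_id up to maxlen.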
def __getitem__(self, index):
sample = self.data[index]
src_ids = self.text2ids(sample[0])
tgt_ids = self.text2ids(sample[1])
encoder_self_attn_mask = make_padding_mask(src_ids, src_ids, self.pad_id)
decoder_self_attn_mask = make_padding_mask(
tgt_ids, tgt_ids, self.pad_id
) * make_sequence_mask(tgt_ids)
cross_attn_mask = make_padding_mask(tgt_ids, src_ids, self.pad_id)
return Instance(
encoder_input_ids=DistTensorData(flow.tensor(src_ids, dtype=flow.long)),
decoder_input_ids=DistTensorData(flow.tensor(tgt_ids, dtype=flow.long)),
encoder_attn_mask=DistTensorData(flow.tensor(encoder_self_attn_mask, dtype=flow.long)),
decoder_attn_mask=DistTensorData(flow.tensor(decoder_self_attn_mask, dtype=flow.long)),
encoder_decoder_attn_mask=DistTensorData(flow.tensor(cross_attn_mask, dtype=flow.long)),
)
import numpy as np
def make_padding_mask(q_ids, kv_ids, pad_id):
q = (np.array(q_ids) != pad_id).reshape(-1, 1)
kv = (np.array(kv_ids) != pad_id).reshape(1, -1)
padding_mask = (q * kv).astype(float)
return padding_mask
def make_sequence_mask(ids):
seqlen = len(ids)
sequence_mask = np.triu(np.ones((seqlen, seqlen))).transpose()
return sequence_mask
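# Illustration only (pad_id=0 is assumed for this sketch):
#   ids = [5, 7, 0, 0]                      # two real tokens, then two <pad>
#   make_padding_mask(ids, ids, 0)          # 4x4 mask with rows/cols 2-3 zeroed
#   make_padding_mask(ids, ids, 0) * make_sequence_mask(ids)
# The product is a lower-triangular mask restricted to non-padding positions,
# which is exactly how CoupletsDataset.__getitem__ builds decoder_self_attn_mask.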
import os
import sys
dir_path = os.path.abspath(os.path.dirname(__file__)) # noqa
sys.path.append(dir_path) # noqa
import oneflow as flow # noqa
from dataset.mask import make_sequence_mask # noqa
from tokenizer.tokenizer import CoupletsTokenizer # noqa
from libai.data.structures import DistTensorData # noqa
from libai.inference.basic import BasePipeline # noqa
from libai.utils import distributed as dist # noqa
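# Helper shared with infer.py: wraps a plain Python list into a [1, seq_len]
# global tensor (via DistTensorData.to_global) so eager inference runs under
# LiBai's distributed placement.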
def get_global_tensor(rawdata):
t = flow.tensor(rawdata, dtype=flow.long).unsqueeze(0)
dtd = DistTensorData(t)
dtd.to_global()
return dtd.tensor
class CoupletPipeline(BasePipeline):
def _parse_parameters(self, **pipeline_parameters):
preprocess_params = {**pipeline_parameters}
forward_params = {}
postprocess_params = {}
return preprocess_params, forward_params, postprocess_params
def update_cfg(
self,
data_parallel=1,
tensor_parallel=1,
pipeline_parallel=1,
pipeline_stage_id=None,
pipeline_num_layers=None,
):
super().update_cfg(
data_parallel,
tensor_parallel,
pipeline_parallel,
pipeline_stage_id,
pipeline_num_layers,
)
self.cfg.vocab_file = "data_test/couplets/vocab.txt"
def build_tokenizer(self, cfg):
return CoupletsTokenizer(cfg.vocab_file)
def generate(self, sentence):
# Encode
sentence = " ".join([word for word in sentence])
tokens_list = self.tokenizer.tokenize(sentence)
encoder_ids_list = (
[self.tokenizer.bos_id]
+ self.tokenizer.convert_tokens_to_ids(tokens_list)
+ [self.tokenizer.eos_id]
)
seq_len = len(encoder_ids_list)
encoder_input_ids = get_global_tensor(encoder_ids_list)
encoder_states = self.model.encode(encoder_input_ids, None)
# Decode
decoder_ids_list = [self.tokenizer.bos_id]
decoder_input_ids = get_global_tensor(decoder_ids_list)
for i in range(seq_len + 10):
mask_array = make_sequence_mask(decoder_ids_list)
decoder_attn_mask = get_global_tensor(mask_array)
logits = self.model.decode(decoder_input_ids, decoder_attn_mask, encoder_states, None)
prob = logits[:, -1]
_, next_word = flow.max(prob, dim=1)
next_word = next_word.item()
decoder_ids_list = decoder_ids_list + [next_word]
decoder_input_ids = get_global_tensor(decoder_ids_list)
if next_word == self.tokenizer.eos_id:
break
result_tokens_list = self.tokenizer.convert_ids_to_tokens(decoder_ids_list)
return "".join(result_tokens_list).replace("<bos>", "").replace("<eos>", "")
def preprocess(self, sentence) -> dict:
input_dict = {"sentence": sentence}
return input_dict
def forward(self, input_dict) -> dict:
model_output = self.generate(input_dict["sentence"])
model_out_dict = {"下联": model_output}
return model_out_dict
def postprocess(self, model_out_dict) -> dict:
return model_out_dict
if __name__ == "__main__":
pipeline = CoupletPipeline(
"projects/Couplets/configs/config.py",
data_parallel=1,
tensor_parallel=1,
pipeline_parallel=4,
pipeline_stage_id=None,
pipeline_num_layers=12,
model_path="output/couplet/model_final/model",
mode="libai",
)
out = pipeline("滚滚长江东逝水")
if dist.is_main_process():
print(out)
import os
import sys
dir_path = os.path.abspath(os.path.dirname(__file__))
sys.path.append(dir_path)
import oneflow as flow # noqa
from dataset.mask import make_padding_mask, make_sequence_mask # noqa
from modeling.model import Seq2Seq # noqa
from tokenizer.tokenizer import CoupletsTokenizer # noqa
from libai.config import LazyConfig # noqa
from libai.data.structures import DistTensorData # noqa
from libai.engine.default import DefaultTrainer # noqa
from libai.utils.checkpoint import Checkpointer # noqa
def get_global_tensor(rawdata):
t = flow.tensor(rawdata, dtype=flow.long).unsqueeze(0)
dtd = DistTensorData(t)
dtd.to_global()
return dtd.tensor
class GeneratorForEager:
def __init__(self, config_file, checkpoint_file, vocab_file):
cfg = LazyConfig.load(config_file)
self.model = DefaultTrainer.build_model(cfg).eval()
Checkpointer(self.model).load(checkpoint_file)
self.tokenizer = CoupletsTokenizer(vocab_file)
def infer(self, sentence):
# Encode
sentence = " ".join([word for word in sentence])
tokens_list = self.tokenizer.tokenize(sentence)
encoder_ids_list = (
[self.tokenizer.bos_id]
+ self.tokenizer.convert_tokens_to_ids(tokens_list)
+ [self.tokenizer.eos_id]
)
seq_len = len(encoder_ids_list)
encoder_input_ids = get_global_tensor(encoder_ids_list)
encoder_states = self.model.encode(encoder_input_ids, None)
# Decode
decoder_ids_list = [self.tokenizer.bos_id]
decoder_input_ids = get_global_tensor(decoder_ids_list)
for i in range(seq_len + 10):
mask_array = make_sequence_mask(decoder_ids_list)
decoder_attn_mask = get_global_tensor(mask_array)
logits = self.model.decode(decoder_input_ids, decoder_attn_mask, encoder_states, None)
prob = logits[:, -1]
_, next_word = flow.max(prob, dim=1)
next_word = next_word.item()
decoder_ids_list = decoder_ids_list + [next_word]
decoder_input_ids = get_global_tensor(decoder_ids_list)
if next_word == self.tokenizer.eos_id:
break
result_tokens_list = self.tokenizer.convert_ids_to_tokens(decoder_ids_list)
return "".join(result_tokens_list).replace("<bos>", "").replace("<eos>", "")
if __name__ == "__main__":
config_file = "output/couplet/config.yaml"
checkpoint_file = "output/couplet/model_final"
vocab_file = "data_test/couplets/vocab.txt"
generator = GeneratorForEager(config_file, checkpoint_file, vocab_file)
sentence = input("上联:\n")
result = generator.infer(sentence)
print("下联:\n" + result)
import oneflow as flow
from oneflow import nn
from libai.layers.cross_entropy import ParallelCrossEntropyLoss
from libai.utils import distributed as dist
from .transformer_model import TransformerModel
class Seq2SeqLoss(nn.Module):
def __init__(self):
super().__init__()
self.lm_loss = ParallelCrossEntropyLoss()
def forward(self, logits, lm_labels):
logits = logits[:, :-1, :]
lm_labels = lm_labels[:, 1:]
lm_loss = self.lm_loss(logits, lm_labels)
lm_loss = lm_loss.mean()
return lm_loss
class Seq2Seq(nn.Module):
def __init__(self, cfg):
super().__init__()
self.language_model = TransformerModel(cfg)
self.loss_func = Seq2SeqLoss()
def forward(
self,
encoder_input_ids,
decoder_input_ids,
encoder_attn_mask,
decoder_attn_mask,
encoder_decoder_attn_mask,
):
logits = self.language_model(
encoder_input_ids,
decoder_input_ids,
encoder_attn_mask,
decoder_attn_mask,
encoder_decoder_attn_mask,
)
if self.training:
loss = self.loss_func(logits, decoder_input_ids)
return {"total_loss": loss}
logits = logits.view(-1, logits.shape[-1])
return {"prediction_scores": logits}
def encode(
self,
encoder_input_ids,
encoder_attn_mask,
):
encoder_input_embeddings = self.language_model.embedding(encoder_input_ids)
if encoder_attn_mask is not None:
encoder_extended_attn_mask = self.language_model.extended_attn_mask(encoder_attn_mask)
encoder_states = self.language_model.encoder(
encoder_input_embeddings,
encoder_extended_attn_mask,
)
else:
encoder_states = self.language_model.encoder(
encoder_input_embeddings,
None,
)
return encoder_states
def decode(
self,
decoder_input_ids,
decoder_attn_mask,
encoder_states,
encoder_decoder_attn_mask,
):
decoder_input_embeddings = self.language_model.embedding(decoder_input_ids)
decoder_extended_attn_mask = self.language_model.extended_attn_mask(decoder_attn_mask)
if encoder_decoder_attn_mask is not None:
encoder_decoder_extended_attn_mask = self.language_model.extended_attn_mask(
encoder_decoder_attn_mask
)
decoder_states = self.language_model.decoder(
decoder_input_embeddings,
decoder_extended_attn_mask,
encoder_states,
encoder_decoder_extended_attn_mask,
)
else:
decoder_states = self.language_model.decoder(
decoder_input_embeddings,
decoder_extended_attn_mask,
encoder_states,
None,
)
logits = self.language_model.lm_head(decoder_states)
return logits
@staticmethod
def set_pipeline_stage_id(model):
dist_utils = dist.get_dist_util()
from .transformer_model import ExtendedMask, TransformerEmbedding, TransformerLayer
# Set pipeline parallelism stage_id
if hasattr(model.language_model.lm_head, "config"):
# Old API in OneFlow 0.8
for module_block in model.modules():
# module.origin can get the original module
if isinstance(module_block.origin, TransformerEmbedding):
module_block.config.set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
elif isinstance(module_block.origin, ExtendedMask):
module_block.config.set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
elif isinstance(module_block.origin, TransformerLayer):
module_block.config.set_stage(
dist_utils.get_layer_stage_id(module_block.layer_idx),
dist.get_layer_placement(module_block.layer_idx),
)
# Set the lm_head stage id
model.language_model.lm_head.config.set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
else:
for module_block in model.modules():
if isinstance(module_block.to(nn.Module), TransformerEmbedding):
module_block.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
elif isinstance(module_block.to(nn.Module), ExtendedMask):
module_block.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
)
elif isinstance(module_block.to(nn.Module), TransformerLayer):
module_block.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(module_block.layer_idx),
dist.get_layer_placement(module_block.layer_idx),
)
# Set the lm_head stage id
model.language_model.lm_head.to(flow.nn.graph.GraphModule).set_stage(
dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
)
import math
import oneflow as flow
from oneflow import nn
from libai.config import configurable
from libai.layers import (
LayerNorm,
Linear,
SinePositionalEmbedding,
TransformerLayer,
VocabEmbedding,
)
from libai.models.utils import init_method_normal, scaled_init_method_normal
from libai.utils import distributed as dist
class ExtendedMask(flow.nn.Module):
def forward(self, x):
return x.unsqueeze(1)
class TransformerEmbedding(nn.Module):
def __init__(
self,
vocab_size,
hidden_size,
max_sequence_length,
embedding_dropout_prob,
init_method=nn.init.xavier_normal_,
):
super().__init__()
self.hidden_size = hidden_size
self.vocab_size = vocab_size
self.word_embedding = VocabEmbedding(vocab_size, hidden_size, init_method=init_method)
self.positional_encoding = SinePositionalEmbedding(max_sequence_length, hidden_size)
self.position_ids = flow.arange(
max_sequence_length,
dtype=flow.long,
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
placement=dist.get_layer_placement(0),
).unsqueeze(0)
self.embedding_dropout = nn.Dropout(embedding_dropout_prob)
def forward(self, input_ids):
seq_length = input_ids.size()[1]
word_embeddings = self.word_embedding(input_ids)
position_ids = (
self.position_ids[:, :seq_length].expand_as(input_ids).to_global(sbp=input_ids.sbp)
)
positional_encodings = self.positional_encoding(position_ids)
embeddings = word_embeddings * math.sqrt(self.hidden_size) + positional_encodings
embeddings = self.embedding_dropout(embeddings)
return embeddings
class TransformerEncoder(nn.Module):
def __init__(
self,
hidden_size=512,
ffn_hidden_size=512,
hidden_layers=6,
num_attention_heads=8,
is_decoder=False,
initializer_range=0.02,
attention_dropout_prob=0.1,
output_dropout_prob=0.1,
layernorm_epsilon=1e-5,
bias_gelu_fusion=False,
bias_dropout_fusion=False,
scale_mask_softmax_fusion=False,
apply_query_key_layer_scaling=True,
):
super().__init__()
self.encoder_layers = nn.ModuleList(
[
TransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
is_decoder=is_decoder,
attention_dropout_prob=attention_dropout_prob,
output_dropout_prob=output_dropout_prob,
layernorm_epsilon=layernorm_epsilon,
bias_gelu_fusion=bias_gelu_fusion,
bias_dropout_fusion=bias_dropout_fusion,
scale_mask_softmax_fusion=scale_mask_softmax_fusion,
apply_query_key_layer_scaling=apply_query_key_layer_scaling,
init_method=init_method_normal(initializer_range),
output_layer_init_method=scaled_init_method_normal(
initializer_range, hidden_layers
),
layer_idx=i,
)
for i in range(hidden_layers)
]
)
self.encoder_final_layernorm = LayerNorm(
(hidden_size,), eps=layernorm_epsilon, layer_idx=hidden_layers - 1
)
def forward(self, encoder_input_embeddings, encoder_extended_attn_mask):
enc_hidden_states = encoder_input_embeddings
for layer in self.encoder_layers:
enc_hidden_states = layer(enc_hidden_states, encoder_extended_attn_mask)
encoder_states = self.encoder_final_layernorm(enc_hidden_states)
return encoder_states
class TransformerDecoder(nn.Module):
def __init__(
self,
hidden_size=512,
ffn_hidden_size=512,
hidden_layers=6,
num_attention_heads=8,
is_decoder=True,
initializer_range=0.02,
attention_dropout_prob=0.1,
output_dropout_prob=0.1,
layernorm_epsilon=1e-5,
bias_gelu_fusion=False,
bias_dropout_fusion=False,
scale_mask_softmax_fusion=False,
apply_query_key_layer_scaling=True,
):
super().__init__()
self.decoder_layers = nn.ModuleList(
[
TransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
is_decoder=is_decoder,
attention_dropout_prob=attention_dropout_prob,
output_dropout_prob=output_dropout_prob,
layernorm_epsilon=layernorm_epsilon,
bias_gelu_fusion=bias_gelu_fusion,
bias_dropout_fusion=bias_dropout_fusion,
scale_mask_softmax_fusion=scale_mask_softmax_fusion,
apply_query_key_layer_scaling=apply_query_key_layer_scaling,
init_method=init_method_normal(initializer_range),
output_layer_init_method=scaled_init_method_normal(
initializer_range, hidden_layers
),
layer_idx=i,
)
for i in range(hidden_layers, 2 * hidden_layers)
]
)
self.decoder_final_layernorm = LayerNorm(
(hidden_size,), eps=layernorm_epsilon, layer_idx=2 * hidden_layers - 1
)
def forward(
self,
decoder_input_embeddings,
decoder_extended_attn_mask,
encoder_states,
encoder_decoder_extended_attn_mask,
):
dec_hidden_states = decoder_input_embeddings
for layer in self.decoder_layers:
dec_hidden_states = layer(
dec_hidden_states,
decoder_extended_attn_mask,
encoder_states,
encoder_decoder_extended_attn_mask,
)
decoder_states = self.decoder_final_layernorm(dec_hidden_states)
return decoder_states
class TransformerModel(nn.Module):
@configurable
def __init__(
self,
vocab_size,
max_position_embeddings,
hidden_size=512,
intermediate_size=512,
hidden_layers=6,
num_attention_heads=8,
embedding_dropout_prob=0.1,
hidden_dropout_prob=0.1,
attention_dropout_prob=0.1,
initializer_range=0.02,
layernorm_epsilon=1e-5,
bias_gelu_fusion=False,
bias_dropout_fusion=False,
scale_mask_softmax_fusion=False,
apply_query_key_layer_scaling=True,
):
super().__init__()
self.embedding = TransformerEmbedding(
vocab_size,
hidden_size,
max_position_embeddings,
embedding_dropout_prob,
init_method=init_method_normal(initializer_range),
)
self.extended_attn_mask = ExtendedMask()
self.encoder = TransformerEncoder(
hidden_size=hidden_size,
ffn_hidden_size=intermediate_size,
hidden_layers=hidden_layers,
num_attention_heads=num_attention_heads,
is_decoder=False,
initializer_range=0.02,
attention_dropout_prob=attention_dropout_prob,
output_dropout_prob=hidden_dropout_prob,
layernorm_epsilon=layernorm_epsilon,
bias_gelu_fusion=bias_gelu_fusion,
bias_dropout_fusion=bias_dropout_fusion,
scale_mask_softmax_fusion=scale_mask_softmax_fusion,
apply_query_key_layer_scaling=apply_query_key_layer_scaling,
)
self.decoder = TransformerDecoder(
hidden_size=hidden_size,
ffn_hidden_size=intermediate_size,
hidden_layers=hidden_layers,
num_attention_heads=num_attention_heads,
is_decoder=True,
initializer_range=0.02,
attention_dropout_prob=attention_dropout_prob,
output_dropout_prob=hidden_dropout_prob,
layernorm_epsilon=layernorm_epsilon,
bias_gelu_fusion=bias_gelu_fusion,
bias_dropout_fusion=bias_dropout_fusion,
scale_mask_softmax_fusion=scale_mask_softmax_fusion,
apply_query_key_layer_scaling=apply_query_key_layer_scaling,
)
self.lm_head = Linear(
hidden_size,
vocab_size,
layer_idx=-1,
)
@classmethod
def from_config(cls, cfg):
return {
"vocab_size": cfg.vocab_size,
"max_position_embeddings": cfg.max_position_embeddings,
"hidden_size": cfg.hidden_size,
"intermediate_size": cfg.intermediate_size,
"hidden_layers": cfg.hidden_layers,
"num_attention_heads": cfg.num_attention_heads,
"embedding_dropout_prob": cfg.embedding_dropout_prob,
"hidden_dropout_prob": cfg.hidden_dropout_prob,
"attention_dropout_prob": cfg.attention_dropout_prob,
"initializer_range": cfg.initializer_range,
"layernorm_epsilon": cfg.layernorm_epsilon,
"bias_gelu_fusion": cfg.bias_gelu_fusion,
"bias_dropout_fusion": cfg.bias_dropout_fusion,
"scale_mask_softmax_fusion": cfg.scale_mask_softmax_fusion,
"apply_query_key_layer_scaling": cfg.apply_query_key_layer_scaling,
}
def forward(
self,
encoder_input_ids,
decoder_input_ids,
encoder_attn_mask,
decoder_attn_mask,
encoder_decoder_attn_mask,
):
encoder_input_embeddings = self.embedding(encoder_input_ids)
decoder_input_embeddings = self.embedding(decoder_input_ids)
encoder_extended_attn_mask = self.extended_attn_mask(encoder_attn_mask)
decoder_extended_attn_mask = self.extended_attn_mask(decoder_attn_mask)
encoder_decoder_extended_attn_mask = self.extended_attn_mask(encoder_decoder_attn_mask)
encoder_states = self.encoder(encoder_input_embeddings, encoder_extended_attn_mask)
decoder_states = self.decoder(
decoder_input_embeddings,
decoder_extended_attn_mask,
encoder_states,
encoder_decoder_extended_attn_mask,
)
logits = self.lm_head(decoder_states)
return logits
## Couplets in LiBai
Contributor: Yulin Zhuang ([https://github.com/ZylOo0](https://github.com/ZylOo0))

This is the LiBai implementation of Couplets, a seq2seq model that generates the second line (下联) of a Chinese couplet.
## Supported parallel modes and tasks
Based on [libai.layers](https://libai.readthedocs.io/en/latest/modules/libai.layers.html), the Couplets model is automatically configured with the following parallelism modes.
<table class="docutils">
<tbody>
<tr>
<th width="80"> Model </th>
<th valign="bottom" align="left" width="120">Data Parallel</th>
<th valign="bottom" align="left" width="120">Tensor Parallel</th>
<th valign="bottom" align="left" width="120">Pipeline Parallel</th>
</tr>
<tr>
<td align="left"> <b> Couplets training </b> </td>
<td align="left">&#10004;</td>
<td align="left">&#10004;</td>
<td align="left">&#10004;</td>
</tr>
<tr>
<td align="left"> <b> Couplets inference </b> </td>
<td align="left"> - </td>
<td align="left">&#10004;</td>
<td align="left">&#10004;</td>
</tr>
</tbody>
</table>
## Setup env
Install LiBai by referring to the [LiBai install doc](https://libai.readthedocs.io/en/latest/tutorials/get_started/Installation.html).
## Prepare Datasets
Download the dataset and unzip it:
```shell
wget http://oneflow-static.oss-cn-beijing.aliyuncs.com/libai/couplets/couplets.zip
unzip couplets.zip
```
You will get a dataset directory like this:
```
couplets
├── test
│ ├── in.txt
│ └── out.txt
├── train
│ ├── in.txt
│ └── out.txt
└── vocab.txt
```
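Each line of `in.txt`/`out.txt` holds one half of a couplet as whitespace-separated characters (`CoupletsTokenizer.tokenize` simply calls `text.split()`). For example, the first sample from the Results section below would be stored as:
```
天 朗 气 清 风 和 畅
```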
## Training
- Set the dataset path in `configs/config.py`:
```python
dataloader.train = LazyCall(build_nlp_train_loader)(
dataset=[
LazyCall(CoupletsDataset)(
path="data_test/couplets", # set to your data_path
is_train=True,
maxlen=64,
)
],
num_workers=4,
)
dataloader.test = [
LazyCall(build_nlp_test_loader)(
dataset=LazyCall(CoupletsDataset)(
path="data_test/couplets", # set to your data_path
is_train=False,
maxlen=64,
),
num_workers=4,
)
]
```
- Set the model config in `configs/config.py` according to your needs:
```python
transformer_cfg = dict(
vocab_size=9027,
max_position_embeddings=64,
hidden_size=512, # modify it according to your needs
intermediate_size=512, # modify it according to your needs
hidden_layers=6, # modify it according to your needs
num_attention_heads=8, # modify it according to your needs
embedding_dropout_prob=0.1,
hidden_dropout_prob=0.1,
attention_dropout_prob=0.1,
initializer_range=0.02,
layernorm_epsilon=1e-5,
bias_gelu_fusion=False,
bias_dropout_fusion=False,
scale_mask_softmax_fusion=False,
apply_query_key_layer_scaling=True,
)
```
- Set the distributed config in `configs/config.py` according to your needs; refer to the [LiBai distributed doc](https://libai.readthedocs.io/en/latest/tutorials/basics/Distributed_Configuration.html) for more details:
```python
dist=dict(
data_parallel_size=1, # modify it according to your needs
tensor_parallel_size=1, # modify it according to your needs
pipeline_parallel_size=1, # modify it according to your needs
pipeline_stage_id=None, # modify it according to your needs
pipeline_num_layers=model.cfg.hidden_layers * 2,
),
```
- Following [quick_run](https://libai.readthedocs.io/en/latest/tutorials/get_started/quick_run.html) in LiBai, run the training command from the LiBai **root** dir:
```shell
# cd path to libai
bash tools/train.sh tools/train_net.py projects/Couplets/configs/config.py 4
```
- After training finishes, the trained models are saved under `output/couplet`:
```shell
output/couplet
├── config.yaml
├── last_checkpoint
├── log.txt
├── log.txt.rank1
├── log.txt.rank2
├── log.txt.rank3
├── metrics.json
├── model_0004999
├── model_0009999
├── model_0014999
├── model_0019999
├── model_0024999
└── model_final
```
## Inference
- For inference on a single GPU:
```shell
# modify path in projects/Couplets/infer.py
# config_file = "output/couplet/config.yaml"
# checkpoint_file = "output/couplet/model_final"
# vocab_file = "data_test/couplets/vocab.txt"
python projects/Couplets/infer.py
```
- For distributed inference, set your data path and model in `projects/Couplets/distribute_infer.py`:
```python
# line 46
self.cfg.vocab_file = "data_test/couplets/vocab.txt"
# line 97 ~106
pipeline = CoupletPipeline(
# you can also use path output/couplet/config.yaml to replace config.py
"projects/Couplets/configs/config.py",
data_parallel=1,
tensor_parallel=1, # modify it according to your needs
pipeline_parallel=4, # modify it according to your needs
pipeline_stage_id=None, # modify it according to your needs
pipeline_num_layers=12, # modify it according to your needs
model_path="output/couplet/model_final/model", # modify it according to your needs
mode="libai",
)
```
```shell
bash tools/infer.sh projects/Couplets/distribute_infer.py 4
```
## Results
```shell
上联:
天朗气清风和畅
下联:
水流花海月圆融
上联:
千秋月色君长看
下联:
一夜风流人在天
上联:
梦里不知身是客
下联:
此间何处是家乡
```
import collections
def load_vocab(vocab_file):
vocab = collections.OrderedDict()
index = 0
with open(vocab_file, "r", encoding="utf-8") as reader:
while True:
token = reader.readline()
token = token.strip("\n")
if not token:
break
vocab[token] = index
index += 1
return vocab
class CoupletsTokenizer:
def __init__(self, vocab_file):
self.vocab = load_vocab(vocab_file)
self.inv_vocab = {v: k for k, v in self.vocab.items()}
self.vocab_tokens = self.vocab.keys()
self.unk_id = self.vocab["<unk>"]
self.pad_id = self.vocab["<pad>"]
self.bos_id = self.vocab["<bos>"]
self.eos_id = self.vocab["<eos>"]
def tokenize(self, text):
tokens_list = text.split()
return tokens_list
def convert_tokens_to_ids(self, tokens_list):
ids_list = []
for token in tokens_list:
if token not in self.vocab_tokens:
token = "<unk>"
token_id = self.vocab[token]
ids_list.append(token_id)
return ids_list
def convert_ids_to_tokens(self, ids_list):
tokens_list = []
for token_id in ids_list:
token = self.inv_vocab[token_id]
tokens_list.append(token)
return tokens_list
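# Usage sketch (illustration only, using the vocab file from the configs):
#   tokenizer = CoupletsTokenizer("data_test/couplets/vocab.txt")
#   ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("天 朗 气 清"))
#   tokenizer.convert_ids_to_tokens(ids)  # known tokens round-trip; OOV -> <unk>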
from libai.config import LazyCall
from configs.common.train import train
from dalle2.models import DiffusionPrior, DiffusionPriorNetwork, Unet, Decoder, DALLE2
from dalle2._clip import OpenAIClipAdapter
from omegaconf import DictConfig
clip = LazyCall(OpenAIClipAdapter)(name="")
swinir = DictConfig({"swinir_path": None})
prior = LazyCall(DiffusionPrior)(
net=LazyCall(DiffusionPriorNetwork)(
dim=768,
depth=24,
num_timesteps=1000,
max_text_len=77,
num_time_embeds=1,
num_image_embeds=1,
num_text_embeds=1,
dim_head=64,
heads=32,
ff_mult=4,
attn_dropout=0.05,
ff_dropout=0.05,
normformer=True,
),
clip=clip,
image_embed_dim=768,
timesteps=1000,
cond_drop_prob=0.1,
loss_type="l2",
condition_on_text_encodings=True,
)
unet1 = LazyCall(Unet)(
dim=320,
image_embed_dim=768,
text_embed_dim=768,
cond_dim=512,
channels=3,
dim_mults=(1, 2, 3, 4),
num_resnet_blocks=4,
attn_heads=8,
attn_dim_head=64,
sparse_attn=True,
memory_efficient=True,
cond_on_text_encodings=True,
self_attn=[False, True, True, True],
)
decoder = LazyCall(Decoder)(
unet=(unet1,),
image_sizes=[
64,
],
clip=None,
channels=3,
timesteps=1000,
loss_type="l2",
beta_schedule=["cosine"],
learned_variance=True,
)
model = LazyCall(DALLE2)(
prior=prior,
decoder=decoder,
prior_weight_path="",
decoder_weight_path="",
)
from .models import Unet, DALLE2, DiffusionPriorNetwork, DiffusionPrior, Decoder
from ._clip import OpenAIClipAdapter, import_flow_clip
import os
import sys
from collections import namedtuple
import oneflow as flow
from oneflow import nn
from .models import l2norm
def import_flow_clip(fn):
def wrapper(*args, **kwargs):
sys.path.append(
os.path.join(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")), "CLIP")
)
fn(*args, **kwargs)
sys.path.pop()
return wrapper
EmbeddedText = namedtuple("EmbedTextReturn", ["text_embed", "text_encodings"])
EmbeddedImage = namedtuple("EmbedImageReturn", ["image_embed", "image_encodings"])
class BaseClipAdapter(nn.Module):
def __init__(self, clip, **kwargs):
super().__init__()
self.clip = clip
self.overrides = kwargs
@property
def dim_latent(self):
raise NotImplementedError
@property
def image_size(self):
raise NotImplementedError
@property
def image_channels(self):
raise NotImplementedError
@property
def max_text_len(self):
raise NotImplementedError
def embed_text(self, text):
raise NotImplementedError
def embed_image(self, image):
raise NotImplementedError
class OpenAIClipAdapter(BaseClipAdapter):
@import_flow_clip
def __init__(self, name="ViT-L/14"):
import clip
openai_clip, preprocess = clip.load(name)
super().__init__(openai_clip)
self.eos_id = 49407 # for handling 0 being also '!'
text_attention_final = self.find_layer("ln_final")
self.handle = text_attention_final.register_forward_hook(self._hook)
self.clip_normalize = preprocess.transforms[-1]
self.cleared = False
def find_layer(self, layer):
modules = dict([*self.clip.named_modules()])
return modules.get(layer, None)
def clear(self):
if self.cleared:
return
self.handle()
def _hook(self, _, inputs, outputs):
self.text_encodings = outputs
@property
def dim_latent(self):
return 512
@property
def image_size(self):
return self.clip.visual.input_resolution
@property
def image_channels(self):
return 3
@property
def max_text_len(self):
return self.clip.context_length
@flow.no_grad()
def embed_text(self, text):
text = text[..., : self.max_text_len]
assert not self.cleared
text_mask = text != 0 # v0.15.4
text_embed = self.clip.encode_text(text)
text_encodings = self.text_encodings
del self.text_encodings
return l2norm(text_embed.float()), text_encodings.float(), text_mask
@flow.no_grad()
def embed_image(self, image):
assert not self.cleared
image = self.validate_and_resize_image(image)
image = self.clip_normalize(image)
image_embed = self.clip.encode_image(image)
return EmbeddedImage(l2norm(image_embed.float()), None)
import logging
import oneflow as flow
from oneflow.framework.check_point_v2 import _broadcast_py_object
import libai.utils.distributed as dist
from libai.models.build import build_model
from libai.models.utils.model_loader.base_loader import (
ModelLoaderHuggerFace,
_load_state_dict_into_model,
)
logger = logging.getLogger("libai.dalle2." + __name__)
class Dalle2ModelLoader(ModelLoaderHuggerFace):
def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
self.base_model_prefix_1 = ""
self.base_model_prefix_2 = ""
def _convert_state_dict(self, state_dict, module="prior"):
old_keys = []
new_keys = []
if module == "prior":
for k in state_dict.keys():
if "clip." in k:
continue
old_keys.append(k)
if k.endswith(".g"):
k = k[:-1] + "weight"
elif k.startswith("net.causal_transformer"):
if k.endswith("gamma"):
k = k[:-5] + "weight"
elif k.endswith("beta"):
k = k[:-4] + "bias"
new_keys.append("prior." + k)
elif module == "decoder":
for k in state_dict.keys():
if "clip." in k:
continue
old_keys.append(k)
if k.endswith(".g"):
k = k[:-1] + "weight"
elif "cross_attn" in k:
if k.endswith("gamma"):
k = k[:-5] + "weight"
elif k.endswith("beta"):
k = k[:-4] + "bias"
new_keys.append("decoder." + k)
ret_state_dict = {}
for old_key, new_key in zip(old_keys, new_keys):
ret_state_dict[new_key] = state_dict.pop(old_key)
return ret_state_dict
def load(self):
if dist.is_main_process():
# prior
logger.info("loading torch model prior...")
torch_state_dict = self._load_torch_state_dict(self.libai_cfg.model.prior_weight_path)[
"ema_model"
]
logger.info("converting torch model prior into oneflow model...")
flow_state_dict = self._convert_tensors(torch_state_dict)
prior_state_dict = self._convert_state_dict(flow_state_dict)
# decoder
logger.info("loading torch model decoder...")
torch_state_dict = self._load_torch_state_dict(self.libai_cfg.model.decoder_weight_path)
flow_state_dict = self._convert_tensors(torch_state_dict)
logger.info("converting torch model decoder into oneflow model...")
decoder_state_dict = self._convert_state_dict(flow_state_dict, module="decoder")
flow_state_dict = {**prior_state_dict, **decoder_state_dict}
else:
flow_state_dict = None
logger.info("building LiBai model...")
self.libai_cfg = _broadcast_py_object(self.libai_cfg, src=0)
self.model = build_model(self.model)
self.model._apply(dist.convert_to_distributed_default_setting)
self.model = self.model.eval()
flow.cuda.empty_cache()
# State_dict to global
logger.info("transfering state_dict local to global...")
flow_state_dict = self._state_dict_to_global(flow_state_dict, mode="pytorch") # oom
# Load
# (
# model,
# missing_keys,
# unexpected_keys,
# mismatched_keys,
# error_msgs,
# ) = self._load_pretrained_model(self.model, flow_state_dict, self.pretrained_model_path)
logger.info("loading model weights into LiBai...")
_load_state_dict_into_model(self.model, flow_state_dict, "")
return self.model
# From https://github.com/arogozhnikov/einops/blob/master/einops/layers/oneflow.py
import re
from functools import wraps
import oneflow as flow
from einops import rearrange, reduce, repeat
from einops._backends import AbstractBackend
from einops.layers import RearrangeMixin
from oneflow import nn
class Rearrange(RearrangeMixin, flow.nn.Module):
def forward(self, input):
return self._apply_recipe(input)
class OneFlowBackend(AbstractBackend):
framework_name = "oneflow"
def __init__(self):
import oneflow as flow
self.flow = flow
def is_appropriate_type(self, tensor):
return isinstance(tensor, self.flow.Tensor)
def from_numpy(self, x):
variable = self.flow.from_numpy(x)
if self.is_float_type(variable):
# attach grad only to floating types
variable.requires_grad = True
return variable
def to_numpy(self, x):
return x.detach().cpu().numpy()
def arange(self, start, stop):
return self.flow.arange(start, stop, dtype=self.flow.int64)
def reduce(self, x, operation, reduced_axes):
for axis in sorted(reduced_axes, reverse=True):
if operation == "min":
x, _ = x.min(dim=axis)
elif operation == "max":
x, _ = x.max(dim=axis)
elif operation in ["sum", "mean", "prod"]:
x = getattr(x, operation)(dim=axis)
else:
raise NotImplementedError("Unknown reduction ", operation)
return x
def transpose(self, x, axes):
return x.permute(axes)
def stack_on_zeroth_dimension(self, tensors: list):
return self.flow.stack(tensors)
def add_axes(self, x, n_axes, pos2len):
repeats = [-1] * n_axes
for axis_position, axis_length in pos2len.items():
x = self.add_axis(x, axis_position)
repeats[axis_position] = axis_length
return x.expand(*repeats)
def tile(self, x, repeats):
return x.repeat(repeats)
def add_axis(self, x, new_position):
return self.flow.unsqueeze(x, new_position)
def is_float_type(self, x):
return x.dtype in [self.flow.float16, self.flow.float32, self.flow.float64]
def einsum(self, pattern, *x):
return self.flow.einsum(pattern, *x)
# From https://github.com/lucidrains/einops-exts/tree/main/einops_exts
class EinopsToAndFrom(nn.Module):
def __init__(self, from_einops, to_einops, fn):
super().__init__()
self.from_einops = from_einops
self.to_einops = to_einops
self.fn = fn
def forward(self, x, **kwargs):
shape = x.shape
reconstitute_kwargs = dict(tuple(zip(self.from_einops.split(" "), shape)))
x = rearrange(x, f"{self.from_einops} -> {self.to_einops}")
x = self.fn(x, **kwargs)
x = rearrange(x, f"{self.to_einops} -> {self.from_einops}", **reconstitute_kwargs)
return x
# checking shape
# @nils-werner
# https://github.com/arogozhnikov/einops/issues/168#issuecomment-1042933838
def check_shape(tensor, pattern, **kwargs):
return rearrange(tensor, f"{pattern} -> {pattern}", **kwargs)
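# e.g. check_shape(img, "b c h w", c=3) asserts that img is 4-D with 3 channels
# and returns it unchanged: rearranging a pattern onto itself is a shape-checked
# no-op.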
# do same einops operations on a list of tensors
def _many(fn):
@wraps(fn)
def inner(tensors, pattern, **kwargs):
return (fn(tensor, pattern, **kwargs) for tensor in tensors)
return inner
# do einops with unflattening of anonymously named dimensions
# (...flattened) -> ...flattened
def _with_anon_dims(fn):
@wraps(fn)
def inner(tensor, pattern, **kwargs):
regex = r"(\.\.\.[a-zA-Z]+)"
matches = re.findall(regex, pattern)
def get_anon_dim_name(t):
return t.lstrip("...")
dim_prefixes = tuple(map(get_anon_dim_name, set(matches)))
update_kwargs_dict = dict()
for prefix in dim_prefixes:
assert prefix in kwargs, f'dimension list "{prefix}" was not passed in'
dim_list = kwargs[prefix]
assert isinstance(
dim_list, (list, tuple)
), f'dimension list "{prefix}" needs to be a tuple of list of dimensions'
dim_names = list(map(lambda ind: f"{prefix}{ind}", range(len(dim_list))))
update_kwargs_dict[prefix] = dict(zip(dim_names, dim_list))
def sub_with_anonymous_dims(t):
dim_name_prefix = get_anon_dim_name(t.groups()[0])
return " ".join(update_kwargs_dict[dim_name_prefix].keys())
pattern_new = re.sub(regex, sub_with_anonymous_dims, pattern)
for prefix, update_dict in update_kwargs_dict.items():
del kwargs[prefix]
kwargs.update(update_dict)
return fn(tensor, pattern_new, **kwargs)
return inner
# generate all helper functions
rearrange_many = _many(rearrange)
repeat_many = _many(repeat)
reduce_many = _many(reduce)
rearrange_with_anon_dims = _with_anon_dims(rearrange)
repeat_with_anon_dims = _with_anon_dims(repeat)
reduce_with_anon_dims = _with_anon_dims(reduce)
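# e.g. nk, nv = repeat_many(self.null_kv.unbind(dim=-2), "d -> b 1 d", b=b)
# applies a single einops pattern to every tensor in the sequence (see
# Attention.forward below for this exact call).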
import logging
import os
from libai.utils.file_utils import download_file
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
url_map = {
"prior": "https://huggingface.co/nousr/conditioned-prior/resolve/main/vit-l-14/prior_aes_finetune.pth", # noqa
"decoder": "https://huggingface.co/laion/DALLE2-PyTorch/resolve/main/decoder/1.5B_laion2B/latest.pth", # noqa
"clip": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", # noqa
"bpe_vocab": "https://oneflow-static.oss-cn-beijing.aliyuncs.com/libai/clip/bpe_simple_vocab_16e6.txt.gz", # noqa
"swinir": "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/",
}
def _download_if_not_exist(path, name):
if os.path.exists(path):
logger.info((f"using {name}'s weight at {path}"))
return
if name == "swinir":
os.makedirs(os.path.dirname(path), exist_ok=True)
download_file(path, url_map[name] + os.path.basename(path))
return
os.makedirs(os.path.dirname(path), exist_ok=True)
download_file(path, url_map[name])
def download_dalle2_weights(cfg):
if not os.path.exists("./dalle2/data/bpe_simple_vocab_16e6.txt.gz"):
os.makedirs("./dalle2/data", exist_ok=True)
download_file("./dalle2/data/bpe_simple_vocab_16e6.txt.gz", url_map["bpe_vocab"])
_download_if_not_exist(cfg.swinir.swinir_path, "swinir")
_download_if_not_exist(cfg.model.prior_weight_path, "prior")
_download_if_not_exist(cfg.model.decoder_weight_path, "decoder")
_download_if_not_exist(cfg.model.prior.clip.name, "clip")
import math
import random
from contextlib import contextmanager
from functools import partial, wraps
import flowvision.transforms as T
import kornia.augmentation as K
import numpy as np
import oneflow as flow
import oneflow.nn.functional as F
from einops import rearrange, reduce, repeat
from kornia.filters import gaussian_blur2d
from omegaconf import ListConfig
from oneflow import einsum, nn
from oneflow.nn import Conv2d, ConvTranspose2d, GroupNorm
from oneflow.nn.functional import layer_norm
from resize_right import resize
from tqdm.auto import tqdm
from libai.layers import Embedding, LayerNorm, Linear
from libai.utils import distributed as dist
from .einops_exts import EinopsToAndFrom, Rearrange, check_shape, rearrange_many, repeat_many
from .rotary_embedding_flow import RotaryEmbedding
from .tokenizer import SimpleTokenizer
from .vqgan_vae import NullVQGanVAE, VQGanVAE
# rotary embeddings
# constants
def get_default_sbp():
return dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast])
def get_default_placement():
return dist.get_layer_placement(0)
NAT = 1.0 / math.log(2.0)
random.seed(666)
np.random.seed(6666)
# helper functions
def exists(val):
return val is not None
def identity(t, *args, **kwargs):
return t
def first(arr, d=None):
if len(arr) == 0:
return d
return arr[0]
def maybe(fn):
@wraps(fn)
def inner(x):
if not exists(x):
return x
return fn(x)
return inner
def default(val, d):
if exists(val):
return val
return d() if callable(d) else d
def cast_tuple(val, length=None):
if isinstance(val, list) or isinstance(val, ListConfig):
val = tuple(val)
out = val if isinstance(val, tuple) else ((val,) * default(length, 1))
if exists(length):
assert len(out) == length
return out
def module_device(module):
return next(module.parameters()).device
def zero_init_(m):
nn.init.zeros_(m.weight)
if exists(m.bias):
nn.init.zeros_(m.bias)
@contextmanager
def null_context(*args, **kwargs):
yield
def eval_decorator(fn):
def inner(model, *args, **kwargs):
was_training = model.training
model.eval()
out = fn(model, *args, **kwargs)
model.train(was_training)
return out
return inner
def is_list_str(x):
if not isinstance(x, (list, tuple)):
return False
return all(isinstance(el, str) for el in x)
def pad_tuple_to_length(t, length, fillvalue=None):
remain_length = length - len(t)
if remain_length <= 0:
return t
return (*t, *((fillvalue,) * remain_length))
# for controlling freezing of CLIP
def set_module_requires_grad_(module, requires_grad):
for param in module.parameters():
param.requires_grad = requires_grad
def freeze_all_layers_(module):
set_module_requires_grad_(module, False)
def unfreeze_all_layers_(module):
set_module_requires_grad_(module, True)
def freeze_model_and_make_eval_(model):
model.eval()
freeze_all_layers_(model)
# tensor helpers
def log(t, eps=1e-12):
return flow.log(t.clamp(min=eps))
def l2norm(t):
return F.normalize(t, dim=-1)
def resize_image_to(image, target_image_size):
orig_image_size = image.shape[-1]
if orig_image_size == target_image_size:
return image
scale_factors = target_image_size / orig_image_size
return resize(image, scale_factors=scale_factors)
# image normalization functions
# ddpms expect images to be in the range of -1 to 1
# but CLIP may otherwise
def normalize_neg_one_to_one(img):
return img * 2 - 1
def unnormalize_zero_to_one(normed_img):
return (normed_img + 1) * 0.5
# classifier free guidance functions
def prob_mask_like(shape, prob, placement=None, sbp=None):
placement = placement or get_default_placement()
sbp = sbp or get_default_sbp()
if prob == 1:
return flow.ones(shape, dtype=flow.bool, placement=placement, sbp=sbp)
elif prob == 0:
return flow.zeros(shape, dtype=flow.bool, placement=placement, sbp=sbp)
else:
return flow.zeros(shape, placement=placement, sbp=sbp).float().uniform_(0, 1) < prob
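# Typical classifier-free-guidance usage (a sketch mirroring cond_drop_prob in
# the prior/decoder configs): prob_mask_like((batch,), 1 - cond_drop_prob)
# yields a boolean keep-mask that randomly drops conditioning per sample.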
# gaussian diffusion helper functions
def extract(a, t, x_shape):
b, *_ = t.shape
out = a.gather(len(a.shape) - 1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
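# e.g. a.shape == (num_timesteps,), t.shape == (b,), x_shape == (b, c, h, w):
# gathers one coefficient per sample and reshapes it to (b, 1, 1, 1) so it
# broadcasts against x.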
def meanflat(x):
return x.mean(dim=tuple(range(1, len(x.shape))))
def normal_kl(mean1, logvar1, mean2, logvar2):
return 0.5 * (
-1.0
+ logvar2
- logvar1
+ flow.exp(logvar1 - logvar2)
+ ((mean1 - mean2) ** 2) * flow.exp(-logvar2)
)
def approx_standard_normal_cdf(x):
return 0.5 * (1.0 + flow.tanh(((2.0 / math.pi) ** 0.5) * (x + 0.044715 * (x ** 3))))
def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres=0.999):
assert x.shape == means.shape == log_scales.shape
# attempting to correct nan gradients when learned variance is turned on
# in the setting of deepspeed fp16
eps = 1e-12 if x.dtype == flow.float32 else 1e-3
centered_x = x - means
inv_stdv = flow.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
cdf_plus = approx_standard_normal_cdf(plus_in)
min_in = inv_stdv * (centered_x - 1.0 / 255.0)
cdf_min = approx_standard_normal_cdf(min_in)
log_cdf_plus = log(cdf_plus)
log_one_minus_cdf_min = log(1.0 - cdf_min)
cdf_delta = cdf_plus - cdf_min
log_probs = flow.where(
x < -thres, log_cdf_plus, flow.where(x > thres, log_one_minus_cdf_min, log(cdf_delta, eps))
)
return log_probs
def cosine_beta_schedule(timesteps, s=0.008):
"""
cosine schedule
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
"""
steps = timesteps + 1
x = flow.linspace(0, timesteps, steps, dtype=flow.float64)
alphas_cumprod = flow.cos(((x / timesteps) + s) / (1 + s) * np.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / first(alphas_cumprod)
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
return flow.clip(betas, 0, 0.999)
def linear_beta_schedule(timesteps):
scale = 1000 / timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
return flow.linspace(beta_start, beta_end, timesteps, dtype=flow.float64)
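# With timesteps=1000 the scale is 1, so betas run linearly from 1e-4 to 0.02;
# the 1000/timesteps rescaling keeps the total noise comparable for other
# schedule lengths.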
def quadratic_beta_schedule(timesteps):
scale = 1000 / timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
return flow.linspace(beta_start ** 0.5, beta_end ** 0.5, timesteps, dtype=flow.float64) ** 2
def sigmoid_beta_schedule(timesteps):
scale = 1000 / timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
betas = flow.linspace(-6, 6, timesteps, dtype=flow.float64)
return flow.sigmoid(betas) * (beta_end - beta_start) + beta_start
class NoiseScheduler(nn.Module):
def __init__(
self, *, beta_schedule, timesteps, loss_type, p2_loss_weight_gamma=0.0, p2_loss_weight_k=1
):
super().__init__()
if beta_schedule == "cosine":
betas = cosine_beta_schedule(timesteps)
elif beta_schedule == "linear":
betas = linear_beta_schedule(timesteps)
elif beta_schedule == "quadratic":
betas = quadratic_beta_schedule(timesteps)
elif beta_schedule == "jsd":
betas = 1.0 / flow.linspace(timesteps, 1, timesteps)
elif beta_schedule == "sigmoid":
betas = sigmoid_beta_schedule(timesteps)
else:
raise NotImplementedError()
alphas = 1.0 - betas
alphas_cumprod = flow.cumprod(alphas, dim=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
(timesteps,) = betas.shape
self.num_timesteps = int(timesteps)
if loss_type == "l1":
loss_fn = F.l1_loss
elif loss_type == "l2":
loss_fn = flow.nn.MSELoss()
elif loss_type == "huber":
loss_fn = F.smooth_l1_loss
else:
raise NotImplementedError()
self.loss_type = loss_type
self.loss_fn = loss_fn
# register buffer helper function to cast double back to float
def register_buffer(name, val):
self.register_buffer(name, val.to(flow.float32))
register_buffer("betas", betas)
register_buffer("alphas_cumprod", alphas_cumprod)
register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)
# calculations for diffusion q(x_t | x_{t-1}) and others
register_buffer("sqrt_alphas_cumprod", flow.sqrt(alphas_cumprod))
register_buffer("sqrt_one_minus_alphas_cumprod", flow.sqrt(1.0 - alphas_cumprod))
register_buffer("log_one_minus_alphas_cumprod", flow.log(1.0 - alphas_cumprod))
register_buffer("sqrt_recip_alphas_cumprod", flow.sqrt(1.0 / alphas_cumprod))
register_buffer("sqrt_recipm1_alphas_cumprod", flow.sqrt(1.0 / alphas_cumprod - 1))
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
register_buffer("posterior_variance", posterior_variance)
register_buffer(
"posterior_log_variance_clipped", flow.log(posterior_variance.clamp(min=1e-20))
)
register_buffer(
"posterior_mean_coef1", betas * flow.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
)
register_buffer(
"posterior_mean_coef2",
(1.0 - alphas_cumprod_prev) * flow.sqrt(alphas) / (1.0 - alphas_cumprod),
)
# p2 loss reweighting
self.has_p2_loss_reweighting = p2_loss_weight_gamma > 0.0
register_buffer(
"p2_loss_weight",
(p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma,
)
def q_posterior(self, x_start, x_t, t):
posterior_mean = (
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
+ extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: flow.randn_like(x_start))
return (
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
def predict_start_from_noise(self, x_t, t, noise):
return (
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
- extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
)
def p2_reweigh_loss(self, loss, times):
if not self.has_p2_loss_reweighting:
return loss
return loss * extract(self.p2_loss_weight, times, loss.shape)
# diffusion prior
class ChanLayerNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(flow.ones(1, dim, 1, 1))
def forward(self, x):
# var = flow.var(x, dim = 1, unbiased = False, keepdim = True)
# mean = flow.mean(x, dim = 1, keepdim = True)
# return (x - mean) / (var + self.eps).sqrt() * self.weight
x = x.permute(0, 2, 3, 1)
out = layer_norm(x, normalized_shape=(x.shape[-1:]), eps=self.eps)
return out.permute(0, 3, 1, 2) * self.weight
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) + x
# mlp
class MLP(nn.Module):
def __init__(
self,
dim_in,
dim_out,
*,
expansion_factor=2.0,
depth=2,
norm=False,
):
super().__init__()
hidden_dim = int(expansion_factor * dim_out)
def norm_fn():
return LayerNorm(hidden_dim) if norm else nn.Identity()
layers = [nn.Sequential(Linear(dim_in, hidden_dim), nn.SiLU(), norm_fn())]
for _ in range(depth - 1):
layers.append(nn.Sequential(Linear(hidden_dim, hidden_dim), nn.SiLU(), norm_fn()))
layers.append(Linear(hidden_dim, dim_out))
self.net = nn.Sequential(*layers)
def forward(self, x):
return self.net(x.float())
# relative positional bias for causal transformer
class RelPosBias(nn.Module):
def __init__(
self,
heads=8,
num_buckets=32,
max_distance=128,
):
super().__init__()
self.num_buckets = num_buckets
self.max_distance = max_distance
self.relative_attention_bias = Embedding(num_buckets, heads)
@staticmethod
def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
n = -relative_position
n = flow.max(n, flow.zeros_like(n)).long()
max_exact = num_buckets // 2
is_small = n < max_exact
val_if_large = (
max_exact
+ (
flow.log(n.float() / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact)
).long()
)
val_if_large = flow.min(val_if_large, flow.zeros_like(val_if_large) + num_buckets - 1)
return flow.where(is_small, n, val_if_large)
def forward(self, i, j):
q_pos = flow.arange(i, dtype=flow.long)
k_pos = flow.arange(j, dtype=flow.long)
rel_pos = rearrange(k_pos, "j -> 1 j") - rearrange(q_pos, "i -> i 1")
rp_bucket = self._relative_position_bucket(
rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance
)
values = self.relative_attention_bias(
rp_bucket.to_global(sbp=get_default_sbp(), placement=get_default_placement())
)
return rearrange(values, "i j h -> h i j")
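# Used below as self.rel_pos_bias(n, n + 1) in CausalTransformer.forward: keys
# include one extra null position, so the returned bias has shape
# (heads, n, n + 1) and is added directly to the attention similarities.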
# feedforward
class SwiGLU(nn.Module):
"""used successfully in https://arxiv.org/abs/2204.0231"""
def forward(self, x):
x, gate = x.chunk(2, dim=-1)
return x * F.silu(gate)
def FeedForward(dim, mult=4, dropout=0.0, post_activation_norm=False):
"""post-activation norm https://arxiv.org/abs/2110.09456"""
inner_dim = int(mult * dim)
return nn.Sequential(
LayerNorm(dim),
Linear(dim, inner_dim * 2, bias=False, parallel="col"),
SwiGLU(),
LayerNorm(inner_dim) if post_activation_norm else nn.Identity(),
nn.Dropout(dropout),
Linear(inner_dim, dim, bias=False, parallel="row"),
)
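# Note: the first Linear widens to inner_dim * 2 so SwiGLU can split its input
# into a value half and a gate half (x * silu(gate)) before projecting back.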
# attention
class Attention(nn.Module):
def __init__(self, dim, *, dim_head=64, heads=8, dropout=0.0, causal=False, rotary_emb=None):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
inner_dim = dim_head * heads
self.causal = causal
self.norm = LayerNorm(dim)
self.dropout = nn.Dropout(dropout)
self.null_kv = nn.Parameter(flow.randn(2, dim_head))
self.to_q = Linear(dim, inner_dim, bias=False, parallel="col")
self.to_kv = Linear(dim, dim_head * 2, bias=False, parallel="col")
self.rotary_emb = rotary_emb
self.to_out = nn.Sequential(
Linear(inner_dim, dim, bias=False, parallel="row"), LayerNorm(dim)
)
def forward(self, x, mask=None, attn_bias=None):
b, n = x.shape[:2]
x = self.norm(x)
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1))
q = rearrange(q, "b n (h d) -> b h n d", h=self.heads)
q = q * self.scale
# rotary embeddings
if exists(self.rotary_emb):
q, k = map(self.rotary_emb.rotate_queries_or_keys, (q, k))
nk, nv = repeat_many(self.null_kv.unbind(dim=-2), "d -> b 1 d", b=b)
k = flow.cat((nk, k), dim=-2)
v = flow.cat((nv, v), dim=-2)
# calculate query / key similarities
sim = einsum("b h i d, b j d -> b h i j", q, k)
# relative positional encoding (T5 style)
if exists(attn_bias):
sim = sim + attn_bias
# masking
max_neg_value = -3.4028e38  # -flow.finfo(sim.dtype).max
if exists(mask):
mask = F.pad(mask, (1, 0), value=1)
mask = rearrange(mask, "b j -> b 1 1 j")
sim = sim.masked_fill(1 - mask, max_neg_value) # ~mask
if self.causal:
i, j = sim.shape[-2:]
causal_mask = flow.ones(
(i, j), placement=get_default_placement(), sbp=get_default_sbp(), dtype=flow.int32
).triu(j - i + 1)
sim = sim.masked_fill(causal_mask, max_neg_value)
# attention
attn = sim.softmax(dim=-1)
attn = self.dropout(attn)
# aggregate values
out = einsum("b h i j, b j d -> b h i d", attn, v)
out = rearrange(out, "b h n d -> b n (h d)")
return self.to_out(out)
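# Shape sketch for the attention above (illustrative): with b = 2, n = 10,
# heads = 8, dim_head = 64, q is (2, 8, 10, 64) while k and v stay
# single-headed at (2, 10, 64) (multi-query style, shared across all query
# heads by the einsum "b h i d, b j d -> b h i j"); after prepending the
# learned null key/value, k and v are (2, 11, 64) and sim is (2, 8, 10, 11),
# the extra column giving fully-masked rows something to attend to.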
class CausalTransformer(nn.Module):
def __init__(
self,
*,
dim,
depth,
dim_head=64,
heads=8,
ff_mult=4,
norm_out=True,
attn_dropout=0.0,
ff_dropout=0.0,
final_proj=True,
normformer=False,
rotary_emb=True,
):
super().__init__()
self.rel_pos_bias = RelPosBias(heads=heads)
rotary_emb = RotaryEmbedding(dim=min(32, dim_head)) if rotary_emb else None
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList(
[
Attention(
dim=dim,
causal=True,
dim_head=dim_head,
heads=heads,
dropout=attn_dropout,
rotary_emb=rotary_emb,
),
FeedForward(
dim=dim,
mult=ff_mult,
dropout=ff_dropout,
post_activation_norm=normformer,
),
]
)
)
self.norm = LayerNorm(dim) if norm_out else nn.Identity()
self.project_out = Linear(dim, dim, bias=False) if final_proj else nn.Identity()
def forward(
self,
x,
mask=None,
):
n = x.shape[1]
attn_bias = self.rel_pos_bias(n, n + 1)
for attn, ff in self.layers:
x = attn(x, mask=mask, attn_bias=attn_bias) + x
x = ff(x) + x
out = self.norm(x)
return self.project_out(out)
class DiffusionPriorNetwork(nn.Module):
def __init__(
self,
dim,
num_timesteps=None,
num_time_embeds=1,
num_image_embeds=1,
num_text_embeds=1,
max_text_len=256,
**kwargs,
):
super().__init__()
self.num_time_embeds = num_time_embeds
self.num_image_embeds = num_image_embeds
self.num_text_embeds = num_text_embeds
self.to_text_embeds = nn.Sequential(
Linear(dim, dim * num_text_embeds) if num_text_embeds > 1 else nn.Identity(),
Rearrange("b (n d) -> b n d", n=num_text_embeds),
)
self.to_time_embeds = nn.Sequential(
Embedding(num_timesteps, dim * num_time_embeds)
if exists(num_timesteps)
else nn.Sequential(
SinusoidalPosEmb(dim), MLP(dim, dim * num_time_embeds)
), # also offer a continuous version of timestep embeddings, with a 2 layer MLP
Rearrange("b (n d) -> b n d", n=num_time_embeds),
)
self.to_image_embeds = nn.Sequential(
Linear(dim, dim * num_image_embeds) if num_image_embeds > 1 else nn.Identity(),
Rearrange("b (n d) -> b n d", n=num_image_embeds),
)
self.learned_query = nn.Parameter(flow.randn(dim))
self.causal_transformer = CausalTransformer(dim=dim, **kwargs)
def forward_with_cond_scale(self, *args, cond_scale=1.0, **kwargs):
logits = self.forward(*args, **kwargs)
if cond_scale == 1:
return logits
null_logits = self.forward(*args, cond_drop_prob=1.0, **kwargs)
return null_logits + (logits - null_logits) * cond_scale
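# Classifier-free guidance in one line: the returned prediction is
#   null_logits + (logits - null_logits) * cond_scale
# so cond_scale = 1 recovers the conditional output, cond_scale = 0 the
# unconditional one, and e.g. cond_scale = 3 extrapolates twice as far past
# the conditional prediction, away from the unconditional one.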
def forward(
self,
image_embed,
diffusion_timesteps,
*,
text_embed,
text_encodings=None,
mask=None,
cond_drop_prob=0.0,
):
batch, dim, dtype = *image_embed.shape, image_embed.dtype
num_time_embeds, num_image_embeds, num_text_embeds = (
self.num_time_embeds,
self.num_image_embeds,
self.num_text_embeds,
)
text_embed = self.to_text_embeds(text_embed)
image_embed = self.to_image_embeds(image_embed)
# make text encodings optional
# although the paper seems to suggest it is present <--
if not exists(text_encodings):
text_encodings = flow.empty(
(batch, 0, dim),
placement=get_default_placement(),
sbp=get_default_sbp(),
dtype=dtype,
)
if not exists(mask):
mask = flow.ones(
(batch, text_encodings.shape[-2]),
placement=get_default_placement(),
sbp=get_default_sbp(),
dtype=flow.bool,
)
# classifier free guidance
keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob)
keep_mask = rearrange(keep_mask, "b -> b 1").to_global(
placement=get_default_placement(), sbp=get_default_sbp()
)
mask &= keep_mask
keep_mask = repeat(keep_mask, "b 1 -> b n", n=num_text_embeds)
mask = flow.cat((mask, keep_mask), dim=1)
if exists(mask):
attend_padding = (
1 + num_time_embeds + num_image_embeds
) # 1 for learned queries + number of image embeds + time embeds
mask = F.pad(mask.to(flow.int32), (0, attend_padding), value=1)
time_embed = self.to_time_embeds(
diffusion_timesteps.to_global(placement=get_default_placement(), sbp=get_default_sbp())
)
learned_queries = repeat(self.learned_query, "d -> b 1 d", b=batch)
tokens = flow.cat(
(text_encodings, text_embed, time_embed, image_embed, learned_queries), dim=-2
)
# attend
tokens = self.causal_transformer(tokens, mask=mask)
# get learned query, which should predict the image embedding (per DDPM timestep)
pred_image_embed = tokens[..., -1, :]
return pred_image_embed
class DiffusionPrior(nn.Module):
def __init__(
self,
net,
*,
clip=None,
image_embed_dim=None,
image_size=None,
image_channels=3,
timesteps=1000,
cond_drop_prob=0.0,
loss_type="l2",
predict_x_start=True,
beta_schedule="cosine",
condition_on_text_encodings=True,
sampling_clamp_l2norm=False,
training_clamp_l2norm=False,
init_image_embed_l2norm=False,
image_embed_scale=None,
clip_adapter_overrides=dict(),
):
super().__init__()
self.noise_scheduler = NoiseScheduler(
beta_schedule=beta_schedule, timesteps=timesteps, loss_type=loss_type
)
if exists(clip):
assert image_channels == clip.image_channels, (
f"channels of image ({image_channels}) should be equal to the channels "
f"that CLIP accepts ({clip.image_channels})"
)
freeze_model_and_make_eval_(clip)
self.clip = clip
else:
assert exists(
image_embed_dim
), "latent dimension must be given, if training prior network without CLIP given"
self.clip = None
self.net = net
self.image_embed_dim = default(image_embed_dim, lambda: clip.dim_latent)
self.channels = default(image_channels, lambda: clip.image_channels)
self.cond_drop_prob = cond_drop_prob
self.can_classifier_guidance = cond_drop_prob > 0.0
self.condition_on_text_encodings = condition_on_text_encodings
self.predict_x_start = predict_x_start
self.image_embed_scale = default(image_embed_scale, self.image_embed_dim ** 0.5)
# whether to force an l2norm, similar to clipping denoised, when sampling
self.sampling_clamp_l2norm = sampling_clamp_l2norm
self.training_clamp_l2norm = training_clamp_l2norm
self.init_image_embed_l2norm = init_image_embed_l2norm
# device tracker
self.register_buffer("_dummy", flow.tensor([True]), persistent=False)
def p_mean_variance(self, x, t, text_cond, clip_denoised=False, cond_scale=1.0):
assert not (
cond_scale != 1.0 and not self.can_classifier_guidance
), (
"the model was not trained with conditional dropout, and thus one cannot use "
"classifier free guidance (cond_scale anything other than 1)"
)
pred = self.net.forward_with_cond_scale(x, t, cond_scale=cond_scale, **text_cond)
if self.predict_x_start:
x_recon = pred
else:
x_recon = self.noise_scheduler.predict_start_from_noise(x, t=t, noise=pred)
if clip_denoised and not self.predict_x_start:
x_recon.clamp_(-1.0, 1.0)
if self.predict_x_start and self.sampling_clamp_l2norm:
x_recon = l2norm(x_recon) * self.image_embed_scale
model_mean, posterior_variance, posterior_log_variance = self.noise_scheduler.q_posterior(
x_start=x_recon, x_t=x, t=t
)
return model_mean, posterior_variance, posterior_log_variance
@flow.no_grad()
def p_sample(self, x, t, text_cond=None, clip_denoised=True, cond_scale=1.0):
b = x.shape[0]
model_mean, _, model_log_variance = self.p_mean_variance(
x=x, t=t, text_cond=text_cond, clip_denoised=clip_denoised, cond_scale=cond_scale
)
noise = flow.randn(*x.shape, placement=get_default_placement(), sbp=get_default_sbp())
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
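# Ancestral DDPM step above: x_{t-1} = mu_theta(x_t, t) + sigma_t * eps, with
# eps ~ N(0, I) and sigma_t = exp(0.5 * log_variance); nonzero_mask zeroes the
# noise at t == 0 so the final step returns the posterior mean directly.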
@flow.no_grad()
def p_sample_loop(self, shape, text_cond, cond_scale=1.0):
b = shape[0]
image_embed = flow.randn(*shape, placement=get_default_placement(), sbp=get_default_sbp())
if self.init_image_embed_l2norm:
image_embed = l2norm(image_embed) * self.image_embed_scale
for i in tqdm(
reversed(range(0, self.noise_scheduler.num_timesteps)),
desc="sampling loop time step",
total=self.noise_scheduler.num_timesteps,
):
times = flow.full(
(b,), i, placement=get_default_placement(), sbp=get_default_sbp(), dtype=flow.long
)
image_embed = self.p_sample(
image_embed, times, text_cond=text_cond, cond_scale=cond_scale
)
return image_embed
def p_losses(self, image_embed, times, text_cond, noise=None):
noise = default(
noise,
lambda: flow.randn(
*image_embed.shape, placement=get_default_placement(), sbp=get_default_sbp()
),
)
image_embed_noisy = self.noise_scheduler.q_sample(x_start=image_embed, t=times, noise=noise)
pred = self.net(image_embed_noisy, times, cond_drop_prob=self.cond_drop_prob, **text_cond)
if self.predict_x_start and self.training_clamp_l2norm:
pred = l2norm(pred) * self.image_embed_scale
target = noise if not self.predict_x_start else image_embed
loss = self.noise_scheduler.loss_fn(pred, target)
return loss
@flow.no_grad()
@eval_decorator
def sample_batch_size(self, batch_size, text_cond, cond_scale=1.0):
shape = (batch_size, self.image_embed_dim)
img = flow.randn(*shape, placement=get_default_placement(), sbp=get_default_sbp())
for i in tqdm(
reversed(range(0, self.noise_scheduler.num_timesteps)),
desc="sampling loop time step",
total=self.noise_scheduler.num_timesteps,
):
img = self.p_sample(
img,
flow.full(
(batch_size,),
i,
placement=get_default_placement(),
sbp=get_default_sbp(),
dtype=flow.long,
),
text_cond=text_cond,
cond_scale=cond_scale,
)
return img
@flow.no_grad()
@eval_decorator
def sample(
self,
text,
num_samples_per_batch=2,
cond_scale=1.0,
text_embed=None,
text_encodings=None,
text_mask=None,
):
# in the paper, what they did was
# sample 2 image embeddings, choose the top 1 similarity, as judged by CLIP
text = repeat(text, "b ... -> (b r) ...", r=num_samples_per_batch)
batch_size = text.shape[0]
image_embed_dim = self.image_embed_dim
if text_embed is None:
assert self.clip is not None
text_embed, text_encodings, text_mask = self.clip.embed_text(text)
text_cond = dict(text_embed=text_embed)
if self.condition_on_text_encodings:
text_cond = {**text_cond, "text_encodings": text_encodings, "mask": text_mask}
image_embeds = self.p_sample_loop(
(batch_size, image_embed_dim), text_cond=text_cond, cond_scale=cond_scale
)
# retrieve original unscaled image embed
image_embeds /= self.image_embed_scale
text_embeds = text_cond["text_embed"]
text_embeds = rearrange(text_embeds, "(b r) d -> b r d", r=num_samples_per_batch)
image_embeds = rearrange(image_embeds, "(b r) d -> b r d", r=num_samples_per_batch)
text_image_sims = einsum("b r d, b r d -> b r", l2norm(text_embeds), l2norm(image_embeds))
top_sim_indices = text_image_sims.topk(k=1)[1] # .indices
top_sim_indices = repeat(top_sim_indices, "b 1 -> b 1 d", d=image_embed_dim)
top_image_embeds = image_embeds.gather(1, top_sim_indices)
return rearrange(top_image_embeds, "b 1 d -> b d")
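# Usage sketch (illustrative, assuming the prior was built with a CLIP model):
#   tokens = tokenizer.tokenize(["a corgi playing in the snow"])
#   image_embed = prior.sample(tokens, num_samples_per_batch=2)
# This draws 2 candidate image embeddings per prompt and keeps the one with
# the highest cosine similarity to the text embedding, as described above.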
def forward(
self,
text=None,
image=None,
text_embed=None, # allow for training on preprocessed CLIP text and image embeddings
image_embed=None,
text_encodings=None, # as well as CLIP text encodings
text_mask=None,
*args,
**kwargs,
):
assert exists(text) ^ exists(text_embed), "either text or text embedding must be supplied"
assert exists(image) ^ exists(image_embed), "either image or image embedding must be supplied"
assert not (
self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))
), "text encodings must be present if you specified to condition on it on initialization"
if exists(image):
image_embed, _ = self.clip.embed_image(image)
# calculate text conditionings, based on what is passed in
if exists(text):
text_embed, text_encodings, text_mask = self.clip.embed_text(text)
text_cond = dict(text_embed=text_embed)
if self.condition_on_text_encodings:
assert exists(
text_encodings
), "text encodings must be present for diffusion prior if specified"
text_cond = {**text_cond, "text_encodings": text_encodings, "mask": text_mask}
# timestep conditioning from ddpm
batch = image_embed.shape[0]
times = flow.randint(
0,
self.noise_scheduler.num_timesteps,
(batch,),
placement=get_default_placement(),
sbp=get_default_sbp(),
dtype=flow.long,
)
# scale image embed (Katherine)
image_embed *= self.image_embed_scale
# calculate forward loss
return self.p_losses(image_embed, times, text_cond=text_cond, *args, **kwargs)
# decoder
def ConvTransposeUpsample(dim, dim_out=None):
dim_out = default(dim_out, dim)
return ConvTranspose2d(dim, dim_out, 4, 2, 1)
def NearestUpsample(dim, dim_out=None):
dim_out = default(dim_out, dim)
return nn.Sequential(
nn.Upsample(scale_factor=2, mode="nearest"), Conv2d(dim, dim_out, 3, padding=1)
)
def Downsample(dim, *, dim_out=None):
dim_out = default(dim_out, dim)
return Conv2d(dim, dim_out, 4, 2, 1)
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = flow.exp(
flow.arange(half_dim, placement=get_default_placement(), sbp=get_default_sbp()) * -emb
)
emb = rearrange(x, "i -> i 1") * rearrange(emb, "j -> 1 j")
return flow.cat((emb.sin(), emb.cos()), dim=-1)
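# Worked example for the sinusoidal embedding above: with dim = 8, half_dim = 4,
# the frequencies are exp(-j * log(10000) / 3) for j = 0..3, roughly
# (1, 0.046, 0.0022, 0.0001), so timestep t = 5 is embedded as
# (sin(5), sin(0.23), sin(0.011), sin(0.0005), cos(5), cos(0.23), ...).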
class Block(nn.Module):
def __init__(self, dim, dim_out, groups=8):
super().__init__()
self.project = Conv2d(dim, dim_out, 3, padding=1)
self.norm = GroupNorm(groups, dim_out)
self.act = nn.SiLU()
def forward(self, x, scale_shift=None):
x = self.project(x)
x = self.norm(x)
if exists(scale_shift):
scale, shift = scale_shift
x = x * (scale + 1) + shift
x = self.act(x)
return x
class ResnetBlock(nn.Module):
def __init__(self, dim, dim_out, *, cond_dim=None, time_cond_dim=None, groups=8):
super().__init__()
self.time_mlp = None
if exists(time_cond_dim):
self.time_mlp = nn.Sequential(nn.SiLU(), Linear(time_cond_dim, dim_out * 2))
self.cross_attn = None
if exists(cond_dim):
self.cross_attn = EinopsToAndFrom(
"b c h w", "b (h w) c", CrossAttention(dim=dim_out, context_dim=cond_dim)
)
self.block1 = Block(dim, dim_out, groups=groups)
self.block2 = Block(dim_out, dim_out, groups=groups)
self.res_conv = Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
def forward(self, x, time_emb=None, cond=None):
scale_shift = None
if exists(self.time_mlp) and exists(time_emb):
time_emb = self.time_mlp(time_emb)
time_emb = rearrange(time_emb, "b c -> b c 1 1")
scale_shift = time_emb.chunk(2, dim=1)
h = self.block1(x, scale_shift=scale_shift)
if exists(self.cross_attn):
assert exists(cond)
h = self.cross_attn(h, context=cond) + h
h = self.block2(h)
return h + self.res_conv(x)
class CrossAttention(nn.Module):
def __init__(
self, dim, *, context_dim=None, dim_head=64, heads=8, dropout=0.0, norm_context=False
):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
inner_dim = dim_head * heads
context_dim = default(context_dim, dim)
self.norm = LayerNorm(dim)
self.norm_context = LayerNorm(context_dim) if norm_context else nn.Identity()
self.dropout = nn.Dropout(dropout)
self.null_kv = nn.Parameter(
flow.randn(2, dim_head, placement=get_default_placement(), sbp=get_default_sbp())
)
self.to_q = Linear(dim, inner_dim, bias=False, parallel="col")
self.to_kv = Linear(context_dim, inner_dim * 2, bias=False, parallel="col")
self.to_out = nn.Sequential(
Linear(inner_dim, dim, bias=False, parallel="row"), LayerNorm(dim)
)
def forward(self, x, context, mask=None):
b, n = x.shape[:2]
x = self.norm(x)
context = self.norm_context(context)
q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim=-1))
q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.heads)
# add null key / value for classifier free guidance in prior net
nk, nv = repeat_many(self.null_kv.unbind(dim=-2), "d -> b h 1 d", h=self.heads, b=b)
k = flow.cat((nk, k), dim=-2)
v = flow.cat((nv, v), dim=-2)
q = q * self.scale
sim = einsum("b h i d, b h j d -> b h i j", q, k)
max_neg_value = -3.4028e38 # -flow.finfo(sim.dtype).max
if exists(mask):
mask = F.pad(mask, (1, 0), value=1)
mask = rearrange(mask, "b j -> b 1 1 j")
sim = sim.masked_fill(1 - mask, max_neg_value)
attn = sim.softmax(dim=-1)
out = einsum("b h i j, b h j d -> b h i d", attn, v)
out = rearrange(out, "b h n d -> b n (h d)")
return self.to_out(out)
class LinearAttention(nn.Module):
def __init__(self, dim, dim_head=32, heads=8):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
inner_dim = dim_head * heads
self.norm = ChanLayerNorm(dim)
self.nonlin = nn.GELU()
self.to_qkv = Conv2d(dim, inner_dim * 3, 1, bias=False)
self.to_out = nn.Sequential(Conv2d(inner_dim, dim, 1, bias=False), ChanLayerNorm(dim))
def forward(self, fmap):
h, x, y = self.heads, *fmap.shape[-2:]
fmap = self.norm(fmap)
q, k, v = self.to_qkv(fmap).chunk(3, dim=1)
q, k, v = rearrange_many((q, k, v), "b (h c) x y -> (b h) (x y) c", h=h)
q = q.softmax(dim=-1)
k = k.softmax(dim=-2)
q = q * self.scale
context = einsum("b n d, b n e -> b d e", k, v)
out = einsum("b n d, b d e -> b n e", q, context)
out = rearrange(out, "(b h) (x y) d -> b (h d) x y", h=h, x=x, y=y)
out = self.nonlin(out)
return self.to_out(out)
class CrossEmbedLayer(nn.Module):
def __init__(self, dim_in, kernel_sizes, dim_out=None, stride=2):
super().__init__()
assert all([*map(lambda t: (t % 2) == (stride % 2), kernel_sizes)])
dim_out = default(dim_out, dim_in)
kernel_sizes = sorted(kernel_sizes)
num_scales = len(kernel_sizes)
# calculate the dimension at each scale
dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
dim_scales = [*dim_scales, dim_out - sum(dim_scales)]
self.convs = nn.ModuleList([])
for kernel, dim_scale in zip(kernel_sizes, dim_scales):
self.convs.append(
Conv2d(dim_in, dim_scale, kernel, stride=stride, padding=(kernel - stride) // 2)
)
def forward(self, x):
fmaps = tuple(map(lambda conv: conv(x), self.convs))
return flow.cat(fmaps, dim=1)
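# Worked example for CrossEmbedLayer: with dim_out = 128 and kernel_sizes
# (3, 7, 15), dim_scales comes out to [64, 32, 32], so three parallel convs
# of increasing receptive field emit 64-, 32- and 32-channel maps that are
# concatenated back to 128 channels.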
class Unet(nn.Module):
def __init__(
self,
dim,
*,
image_embed_dim=None,
text_embed_dim=None,
cond_dim=None,
num_image_tokens=4,
num_time_tokens=2,
out_dim=None,
dim_mults=(1, 2, 4, 8),
channels=3,
channels_out=None,
self_attn=False,
attn_dim_head=32,
attn_heads=16,
lowres_cond=False,
sparse_attn=False,
attend_at_middle=True,
cond_on_text_encodings=False,
max_text_len=256,
cond_on_image_embeds=False,
add_image_embeds_to_time=True,
init_dim=None,
init_conv_kernel_size=7,
resnet_groups=8,
num_resnet_blocks=2,
init_cross_embed_kernel_sizes=(3, 7, 15),
cross_embed_downsample=False,
cross_embed_downsample_kernel_sizes=(2, 4),
memory_efficient=False,
scale_skip_connection=False,
nearest_upsample=False,
final_conv_kernel_size=1,
**kwargs,
):
super().__init__()
# save locals to take care of some hyperparameters for cascading DDPM
self._locals = locals()
del self._locals["self"]
del self._locals["__class__"]
# for eventual cascading diffusion
self.lowres_cond = lowres_cond
# determine dimensions
self.channels = channels
self.channels_out = default(channels_out, channels)
init_channels = channels if not lowres_cond else channels * 2
init_dim = default(init_dim, dim)
self.init_conv = CrossEmbedLayer(
init_channels, dim_out=init_dim, kernel_sizes=init_cross_embed_kernel_sizes, stride=1
)
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
num_stages = len(in_out)
# time, image embeddings, and optional text encoding
cond_dim = default(cond_dim, dim)
time_cond_dim = dim * 4
self.to_time_hiddens = nn.Sequential(
SinusoidalPosEmb(dim), Linear(dim, time_cond_dim, parallel="col"), nn.GELU()
)
self.to_time_tokens = nn.Sequential(
Linear(time_cond_dim, cond_dim * num_time_tokens, parallel="row"),
Rearrange("b (r d) -> b r d", r=num_time_tokens),
)
self.to_time_cond = nn.Sequential(Linear(time_cond_dim, time_cond_dim, parallel="row"))
self.image_to_tokens = (
nn.Sequential(
Linear(image_embed_dim, cond_dim * num_image_tokens),
Rearrange("b (n d) -> b n d", n=num_image_tokens),
)
if cond_on_image_embeds and image_embed_dim != cond_dim
else nn.Identity()
)
self.to_image_hiddens = (
nn.Sequential(Linear(image_embed_dim, time_cond_dim), nn.GELU())
if cond_on_image_embeds and add_image_embeds_to_time
else None
)
self.norm_cond = LayerNorm(cond_dim)
self.norm_mid_cond = LayerNorm(cond_dim)
# text encoding conditioning (optional)
self.text_to_cond = None
if cond_on_text_encodings:
assert exists(
text_embed_dim
), "text_embed_dim must be given to the unet if cond_on_text_encodings is True"
self.text_to_cond = Linear(text_embed_dim, cond_dim)
# finer control over whether to condition on image embeddings and text encodings
# so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
self.cond_on_text_encodings = cond_on_text_encodings
self.cond_on_image_embeds = cond_on_image_embeds
# for classifier free guidance
self.null_image_embed = nn.Parameter(flow.randn(1, num_image_tokens, cond_dim))
self.null_image_hiddens = nn.Parameter(flow.randn(1, time_cond_dim))
self.max_text_len = max_text_len
self.null_text_embed = nn.Parameter(flow.randn(1, max_text_len, cond_dim))
# whether to scale skip connection, adopted in Imagen
self.skip_connect_scale = 1.0 if not scale_skip_connection else (2 ** -0.5)
# attention related params
attn_kwargs = dict(heads=attn_heads, dim_head=attn_dim_head)
self_attn = cast_tuple(self_attn, num_stages)
def create_self_attn(dim):
return EinopsToAndFrom("b c h w", "b (h w) c", Residual(Attention(dim, **attn_kwargs)))
# resnet block klass
resnet_groups = cast_tuple(resnet_groups, num_stages)
top_level_resnet_group = first(resnet_groups)
num_resnet_blocks = cast_tuple(num_resnet_blocks, num_stages)
# downsample klass
downsample_klass = Downsample
if cross_embed_downsample:
downsample_klass = partial(
CrossEmbedLayer, kernel_sizes=cross_embed_downsample_kernel_sizes
)
# upsample klass
upsample_klass = ConvTransposeUpsample if not nearest_upsample else NearestUpsample
# give memory efficient unet an initial resnet block
self.init_resnet_block = (
ResnetBlock(
init_dim, init_dim, time_cond_dim=time_cond_dim, groups=top_level_resnet_group
)
if memory_efficient
else None
)
# layers
self.downs = nn.ModuleList([])
self.ups = nn.ModuleList([])
num_resolutions = len(in_out)
skip_connect_dims = [] # keeping track of skip connection dimensions
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks, layer_self_attn) in enumerate(
zip(in_out, resnet_groups, num_resnet_blocks, self_attn)
):
is_first = ind == 0
is_last = ind >= (num_resolutions - 1)
layer_cond_dim = cond_dim if not is_first else None
dim_layer = dim_out if memory_efficient else dim_in
skip_connect_dims.append(dim_layer)
attention = nn.Identity()
if layer_self_attn:
attention = create_self_attn(dim_layer)
elif sparse_attn:
attention = Residual(LinearAttention(dim_layer, **attn_kwargs))
self.downs.append(
nn.ModuleList(
[
downsample_klass(dim_in, dim_out=dim_out) if memory_efficient else None,
ResnetBlock(
dim_layer, dim_layer, time_cond_dim=time_cond_dim, groups=groups
),
nn.ModuleList(
[
ResnetBlock(
dim_layer,
dim_layer,
cond_dim=layer_cond_dim,
time_cond_dim=time_cond_dim,
groups=groups,
)
for _ in range(layer_num_resnet_blocks)
]
),
attention,
downsample_klass(dim_layer, dim_out=dim_out)
if not is_last and not memory_efficient
else Conv2d(dim_layer, dim_out, 1),
]
)
)
mid_dim = dims[-1]
self.mid_block1 = ResnetBlock(
mid_dim,
mid_dim,
cond_dim=cond_dim,
time_cond_dim=time_cond_dim,
groups=resnet_groups[-1],
)
self.mid_attn = create_self_attn(mid_dim)
self.mid_block2 = ResnetBlock(
mid_dim,
mid_dim,
cond_dim=cond_dim,
time_cond_dim=time_cond_dim,
groups=resnet_groups[-1],
)
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks, layer_self_attn) in enumerate(
zip(
reversed(in_out),
reversed(resnet_groups),
reversed(num_resnet_blocks),
reversed(self_attn),
)
):
is_last = ind >= (len(in_out) - 1)
layer_cond_dim = cond_dim if not is_last else None
skip_connect_dim = skip_connect_dims.pop()
attention = nn.Identity()
if layer_self_attn:
attention = create_self_attn(dim_out)
elif sparse_attn:
attention = Residual(LinearAttention(dim_out, **attn_kwargs))
self.ups.append(
nn.ModuleList(
[
ResnetBlock(
dim_out + skip_connect_dim,
dim_out,
cond_dim=layer_cond_dim,
time_cond_dim=time_cond_dim,
groups=groups,
),
nn.ModuleList(
[
ResnetBlock(
dim_out + skip_connect_dim,
dim_out,
cond_dim=layer_cond_dim,
time_cond_dim=time_cond_dim,
groups=groups,
)
for _ in range(layer_num_resnet_blocks)
]
),
attention,
upsample_klass(dim_out, dim_in)
if not is_last or memory_efficient
else nn.Identity(),
]
)
)
self.final_resnet_block = ResnetBlock(
dim * 2, dim, time_cond_dim=time_cond_dim, groups=top_level_resnet_group
)
self.to_out = Conv2d(
dim,
self.channels_out,
kernel_size=final_conv_kernel_size,
padding=final_conv_kernel_size // 2,
)
# if the current settings for the unet are not correct
# for cascading DDPM, then reinit the unet with the right settings
def cast_model_parameters(
self, *, lowres_cond, channels, channels_out, cond_on_image_embeds, cond_on_text_encodings
):
if (
lowres_cond == self.lowres_cond
and channels == self.channels
and cond_on_image_embeds == self.cond_on_image_embeds
and cond_on_text_encodings == self.cond_on_text_encodings
and channels_out == self.channels_out
):
return self
updated_kwargs = dict(
lowres_cond=lowres_cond,
channels=channels,
channels_out=channels_out,
cond_on_image_embeds=cond_on_image_embeds,
cond_on_text_encodings=cond_on_text_encodings,
)
return self.__class__(**{**self._locals, **updated_kwargs})
def forward_with_cond_scale(self, *args, cond_scale=1.0, **kwargs):
logits = self.forward(*args, **kwargs)
if cond_scale == 1:
return logits
null_logits = self.forward(
*args, text_cond_drop_prob=1.0, image_cond_drop_prob=1.0, **kwargs
)
return null_logits + (logits - null_logits) * cond_scale
def forward(
self,
x,
time,
*,
image_embed,
lowres_cond_img=None,
text_encodings=None,
text_mask=None,
image_cond_drop_prob=0.0,
text_cond_drop_prob=0.0,
blur_sigma=None,
blur_kernel_size=None,
):
batch_size = x.shape[0]
# add low resolution conditioning, if present
assert not (
self.lowres_cond and not exists(lowres_cond_img)
), "low resolution conditioning image must be present"
if exists(lowres_cond_img):
x = flow.cat((x, lowres_cond_img), dim=1)
# initial convolution
x = self.init_conv(x)
r = x.clone() # final residual
# time conditioning
time_hiddens = self.to_time_hiddens(time)
time_tokens = self.to_time_tokens(time_hiddens)
t = self.to_time_cond(time_hiddens)
# conditional dropout
image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob)
text_keep_mask = prob_mask_like((batch_size,), 1 - text_cond_drop_prob)
text_keep_mask = rearrange(text_keep_mask, "b -> b 1 1")
# image embedding to be summed to time embedding
# discovered by @mhh0318 in the paper
if exists(image_embed) and exists(self.to_image_hiddens):
image_hiddens = self.to_image_hiddens(image_embed)
image_keep_mask_hidden = rearrange(image_keep_mask, "b -> b 1")
null_image_hiddens = self.null_image_hiddens.to(image_hiddens.dtype)
image_hiddens = flow.where(image_keep_mask_hidden, image_hiddens, null_image_hiddens)
t = t + image_hiddens
# mask out image embedding depending on condition dropout
# for classifier free guidance
image_tokens = None
if self.cond_on_image_embeds:
image_keep_mask_embed = rearrange(image_keep_mask, "b -> b 1 1")
image_tokens = self.image_to_tokens(image_embed)
null_image_embed = self.null_image_embed.to(
image_tokens.dtype
) # for some reason AMP is not working without this explicit cast
image_tokens = flow.where(image_keep_mask_embed, image_tokens, null_image_embed)
# take care of text encodings (optional)
text_tokens = None
if exists(text_encodings) and self.cond_on_text_encodings:
text_tokens = self.text_to_cond(text_encodings)
text_tokens = text_tokens[:, : self.max_text_len]
text_tokens_len = text_tokens.shape[1]
remainder = self.max_text_len - text_tokens_len
if remainder > 0:
text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))
if exists(text_mask):
if remainder > 0:
# text_mask = F.pad(text_mask, (0, remainder), value = False)
text_mask = F.pad(text_mask.to(flow.int32), (0, remainder), value=0)
text_mask = rearrange(text_mask, "b n -> b n 1")
text_keep_mask = text_mask & text_keep_mask
null_text_embed = self.null_text_embed.to(
text_tokens.dtype
) # for some reason AMP is not working without this explicit cast
text_tokens = flow.where(text_keep_mask, text_tokens, null_text_embed)
# main conditioning tokens (c)
c = time_tokens
if exists(image_tokens):
c = flow.cat((c, image_tokens), dim=-2)
# text and image conditioning tokens (mid_c), to save on compute,
# only do cross attention based conditioning on the inner most layers of the Unet
mid_c = c if not exists(text_tokens) else flow.cat((c, text_tokens), dim=-2)
# normalize conditioning tokens
c = self.norm_cond(c)
mid_c = self.norm_mid_cond(mid_c)
# initial resnet block
if exists(self.init_resnet_block):
x = self.init_resnet_block(x, t)
# go through the layers of the unet, down and up
hiddens = []
for pre_downsample, init_block, resnet_blocks, attn, post_downsample in self.downs:
if exists(pre_downsample):
x = pre_downsample(x)
x = init_block(x, t, c)
for resnet_block in resnet_blocks:
x = resnet_block(x, t, c)
hiddens.append(x)
x = attn(x)
hiddens.append(x)
if exists(post_downsample):
x = post_downsample(x)
x = self.mid_block1(x, t, mid_c)
if exists(self.mid_attn):
x = self.mid_attn(x)
x = self.mid_block2(x, t, mid_c)
def connect_skip(fmap):
return flow.cat((fmap, hiddens.pop() * self.skip_connect_scale), dim=1)
for init_block, resnet_blocks, attn, upsample in self.ups:
x = connect_skip(x)
x = init_block(x, t, c)
for resnet_block in resnet_blocks:
x = connect_skip(x)
x = resnet_block(x, t, c)
x = attn(x)
x = upsample(x)
x = flow.cat((x, r), dim=1)
x = self.final_resnet_block(x, t)
return self.to_out(x)
class LowresConditioner(nn.Module):
def __init__(
self,
downsample_first=True,
blur_sigma=0.6,
blur_kernel_size=3,
):
super().__init__()
self.downsample_first = downsample_first
self.blur_sigma = blur_sigma
self.blur_kernel_size = blur_kernel_size
def forward(
self,
cond_fmap,
*,
target_image_size,
downsample_image_size=None,
blur_sigma=None,
blur_kernel_size=None,
):
if self.training and self.downsample_first and exists(downsample_image_size):
cond_fmap = resize_image_to(cond_fmap, downsample_image_size)
if self.training:
# when training, blur the low resolution conditional image
blur_sigma = default(blur_sigma, self.blur_sigma)
blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size)
# allow for drawing a random sigma between lo and hi float values
if isinstance(blur_sigma, tuple):
blur_sigma = tuple(map(float, blur_sigma))
blur_sigma = random.uniform(*blur_sigma)
# allow for drawing a random kernel size between lo and hi int values
if isinstance(blur_kernel_size, tuple):
blur_kernel_size = tuple(map(int, blur_kernel_size))
kernel_size_lo, kernel_size_hi = blur_kernel_size
blur_kernel_size = random.randrange(kernel_size_lo, kernel_size_hi + 1)
cond_fmap = gaussian_blur2d(
cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2)
)
cond_fmap = resize_image_to(cond_fmap, target_image_size)
return cond_fmap
class Decoder(nn.Module):
def __init__(
self,
unet,
*,
clip=None,
image_size=None,
channels=3,
vae=tuple(),
timesteps=1000,
image_cond_drop_prob=0.1,
text_cond_drop_prob=0.5,
loss_type="l2",
beta_schedule=None,
predict_x_start=False,
predict_x_start_for_latent_diffusion=False,
image_sizes=None, # for cascading ddpm, image size at each stage
random_crop_sizes=None,
lowres_downsample_first=True,
blur_sigma=0.6, # cascading ddpm - blur sigma
blur_kernel_size=3, # cascading ddpm - blur kernel size
clip_denoised=True,
clip_x_start=True,
clip_adapter_overrides=dict(),
learned_variance=True,
learned_variance_constrain_frac=False,
vb_loss_weight=0.001,
unconditional=False, # set to True for generating images without conditioning
auto_normalize_img=True,
use_dynamic_thres=False, # from the Imagen paper
dynamic_thres_percentile=0.9,
p2_loss_weight_gamma=0.0,
p2_loss_weight_k=1,
):
super().__init__()
# clip
self.clip = None
if exists(clip):
assert not unconditional, "clip must not be given if doing unconditional image training"
assert channels == clip.image_channels, (
f"channels of image ({channels}) should be equal to the "
f"channels that CLIP accepts ({clip.image_channels})"
)
freeze_model_and_make_eval_(clip)
self.clip = clip
# determine image size, with image_size and image_sizes taking precedence
if exists(image_size) or exists(image_sizes):
assert exists(image_size) ^ exists(
image_sizes
), "only one of image_size or image_sizes must be given"
image_size = default(image_size, lambda: image_sizes[-1])
elif exists(clip):
image_size = clip.image_size
else:
raise ValueError("either image_size, image_sizes, or clip must be given to decoder")
# channels
self.channels = channels
# verify conditioning method
unets = cast_tuple(unet)
num_unets = len(unets)
self.unconditional = unconditional
# automatically take care of ensuring that first unet is unconditional while the rest
# of the unets are conditioned on the low resolution image produced by previous unet
vaes = pad_tuple_to_length(
cast_tuple(vae), len(unets), fillvalue=NullVQGanVAE(channels=self.channels)
)
# whether to use learned variance, defaults to True for the first unet in the cascade
learned_variance = pad_tuple_to_length(
cast_tuple(learned_variance), len(unets), fillvalue=False
)
self.learned_variance = learned_variance
# whether to constrain the output of the network (the interpolation fraction) from 0 to 1
self.learned_variance_constrain_frac = learned_variance_constrain_frac
self.vb_loss_weight = vb_loss_weight
# construct unets and vaes
self.unets = nn.ModuleList([])
self.vaes = nn.ModuleList([])
for ind, (one_unet, one_vae, one_unet_learned_var) in enumerate(
zip(unets, vaes, learned_variance)
):
assert isinstance(one_unet, Unet)
assert isinstance(one_vae, (VQGanVAE, NullVQGanVAE))
is_first = ind == 0
latent_dim = one_vae.encoded_dim if exists(one_vae) else None
unet_channels = default(latent_dim, self.channels)
unet_channels_out = unet_channels * (1 if not one_unet_learned_var else 2)
one_unet = one_unet.cast_model_parameters(
lowres_cond=not is_first,
cond_on_image_embeds=not unconditional and is_first,
cond_on_text_encodings=not unconditional and one_unet.cond_on_text_encodings,
channels=unet_channels,
channels_out=unet_channels_out,
)
self.unets.append(one_unet)
self.vaes.append(one_vae.copy_for_eval())
# determine from unets whether conditioning on text encoding is needed
self.condition_on_text_encodings = any([unet.cond_on_text_encodings for unet in self.unets])
# create noise schedulers per unet
if not exists(beta_schedule):
beta_schedule = (
"cosine",
*(("cosine",) * max(num_unets - 2, 0)),
*(("linear",) * int(num_unets > 1)),
)
beta_schedule = cast_tuple(beta_schedule, num_unets)
p2_loss_weight_gamma = cast_tuple(p2_loss_weight_gamma, num_unets)
self.noise_schedulers = nn.ModuleList([])
for unet_beta_schedule, unet_p2_loss_weight_gamma in zip(
beta_schedule, p2_loss_weight_gamma
):
noise_scheduler = NoiseScheduler(
beta_schedule=unet_beta_schedule,
timesteps=timesteps,
loss_type=loss_type,
p2_loss_weight_gamma=unet_p2_loss_weight_gamma,
p2_loss_weight_k=p2_loss_weight_k,
)
self.noise_schedulers.append(noise_scheduler)
# unet image sizes
image_sizes = default(image_sizes, (image_size,))
image_sizes = tuple(sorted(set(image_sizes)))
assert len(self.unets) == len(image_sizes), (
f"you did not supply the correct number of u-nets ({len(self.unets)}) "
f"for resolutions {image_sizes}"
)
self.image_sizes = image_sizes
self.sample_channels = cast_tuple(self.channels, len(image_sizes))
# random crop sizes (for super-resoluting unets at the end of cascade?)
self.random_crop_sizes = cast_tuple(random_crop_sizes, len(image_sizes))
# predict x0 config
self.predict_x_start = (
cast_tuple(predict_x_start, len(unets))
if not predict_x_start_for_latent_diffusion
else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes))
)
# cascading ddpm related stuff
lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
assert lowres_conditions == (
False,
*((True,) * (len(self.unets) - 1)),
), (
"the first unet must be unconditioned (by low resolution image), "
"and the rest of the unets must have `lowres_cond` set to True"
)
self.to_lowres_cond = LowresConditioner(
downsample_first=lowres_downsample_first,
blur_sigma=blur_sigma,
blur_kernel_size=blur_kernel_size,
)
# classifier free guidance
self.image_cond_drop_prob = image_cond_drop_prob
self.text_cond_drop_prob = text_cond_drop_prob
self.can_classifier_guidance = image_cond_drop_prob > 0.0 or text_cond_drop_prob > 0.0
# whether to clip when sampling
self.clip_denoised = clip_denoised
self.clip_x_start = clip_x_start
# dynamic thresholding settings, if clipping denoised during sampling
self.use_dynamic_thres = use_dynamic_thres
self.dynamic_thres_percentile = dynamic_thres_percentile
# normalize and unnormalize image functions
self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
# device tracker
self.register_buffer("_dummy", flow.Tensor([True]), persistent=False)
def get_unet(self, unet_number):
assert 0 < unet_number <= len(self.unets)
index = unet_number - 1
return self.unets[index]
@contextmanager
def one_unet_in_gpu(self, unet_number=None, unet=None):
assert exists(unet_number) ^ exists(unet)
if exists(unet_number):
unet = self.get_unet(unet_number)
self.cuda()
devices = [module_device(unet) for unet in self.unets]
self.unets.cpu()
unet.cuda()
yield
for unet, device in zip(self.unets, devices):
unet.to(device)
def p_mean_variance(
self,
unet,
x,
t,
image_embed,
noise_scheduler,
text_encodings=None,
text_mask=None,
lowres_cond_img=None,
clip_denoised=True,
predict_x_start=False,
learned_variance=False,
cond_scale=1.0,
model_output=None,
):
assert not (
cond_scale != 1.0 and not self.can_classifier_guidance
), (
"the decoder was not trained with conditional dropout, and thus one cannot use "
"classifier free guidance (cond_scale anything other than 1)"
)
pred = default(
model_output,
lambda: unet.forward_with_cond_scale(
x,
t,
image_embed=image_embed,
text_encodings=text_encodings,
text_mask=text_mask,
cond_scale=cond_scale,
lowres_cond_img=lowres_cond_img,
),
)
if learned_variance:
pred, var_interp_frac_unnormalized = pred.chunk(2, dim=1)
if predict_x_start:
x_recon = pred
else:
x_recon = noise_scheduler.predict_start_from_noise(x, t=t, noise=pred)
if clip_denoised:
# s is the threshold amount
# static thresholding would just be s = 1
s = 1.0
if self.use_dynamic_thres:
s = flow.quantile(
rearrange(x_recon, "b ... -> b (...)").abs(),
self.dynamic_thres_percentile,
dim=-1,
)
s.clamp_(min=1.0)
s = s.view(-1, *((1,) * (x_recon.ndim - 1)))
# clip by threshold, depending on whether static or dynamic
x_recon = x_recon.clamp(-s, s) / s
model_mean, posterior_variance, posterior_log_variance = noise_scheduler.q_posterior(
x_start=x_recon, x_t=x, t=t
)
if learned_variance:
# if learned variance, posterior variance and posterior log variance are
# predicted by the network by an interpolation of the max and min log beta values
# eq 15 - https://arxiv.org/abs/2102.09672
min_log = extract(noise_scheduler.posterior_log_variance_clipped, t, x.shape)
max_log = extract(flow.log(noise_scheduler.betas), t, x.shape)
var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized)
if self.learned_variance_constrain_frac:
var_interp_frac = var_interp_frac.sigmoid()
posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
posterior_variance = posterior_log_variance.exp()
return model_mean, posterior_variance, posterior_log_variance
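# The learned-variance interpolation above follows eq. 15 of
# https://arxiv.org/abs/2102.09672: the network emits a fraction v and the
# log variance is v * log(beta_t) + (1 - v) * log(beta_tilde_t), i.e. an
# interpolation between the upper and lower bounds on the reverse variance.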
@flow.no_grad()
def p_sample(
self,
unet,
x,
t,
image_embed,
noise_scheduler,
text_encodings=None,
text_mask=None,
cond_scale=1.0,
lowres_cond_img=None,
predict_x_start=False,
learned_variance=False,
clip_denoised=True,
):
b = x.shape[0]
model_mean, _, model_log_variance = self.p_mean_variance(
unet,
x=x,
t=t,
image_embed=image_embed,
text_encodings=text_encodings,
text_mask=text_mask,
cond_scale=cond_scale,
lowres_cond_img=lowres_cond_img,
clip_denoised=clip_denoised,
predict_x_start=predict_x_start,
noise_scheduler=noise_scheduler,
learned_variance=learned_variance,
)
noise = flow.randn(*x.shape, placement=get_default_placement(), sbp=get_default_sbp())
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@flow.no_grad()
def p_sample_loop(
self,
unet,
shape,
image_embed,
noise_scheduler,
predict_x_start=False,
learned_variance=False,
clip_denoised=True,
lowres_cond_img=None,
text_encodings=None,
text_mask=None,
cond_scale=1,
is_latent_diffusion=False,
):
b = shape[0]
img = flow.randn(*shape, placement=get_default_placement(), sbp=get_default_sbp())
if not is_latent_diffusion:
lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
for i in tqdm(
reversed(range(0, noise_scheduler.num_timesteps)),
desc="sampling loop time step",
total=noise_scheduler.num_timesteps,
):
img = self.p_sample(
unet,
img,
flow.full(
(b,),
i,
placement=get_default_placement(),
sbp=get_default_sbp(),
dtype=flow.long,
),
image_embed=image_embed,
text_encodings=text_encodings,
text_mask=text_mask,
cond_scale=cond_scale,
lowres_cond_img=lowres_cond_img,
predict_x_start=predict_x_start,
noise_scheduler=noise_scheduler,
learned_variance=learned_variance,
clip_denoised=clip_denoised,
)
img = self.unnormalize_img(img)
return img
def p_losses(
self,
unet,
x_start,
times,
*,
image_embed,
noise_scheduler,
lowres_cond_img=None,
text_encodings=None,
text_mask=None,
predict_x_start=False,
noise=None,
learned_variance=False,
clip_denoised=False,
is_latent_diffusion=False,
):
noise = default(
noise,
lambda: flow.randn(
*x_start.shape, placement=get_default_placement(), sbp=get_default_sbp()
),
)
# normalize to [-1, 1]
if not is_latent_diffusion:
x_start = self.normalize_img(x_start)
lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
# get x_t
x_noisy = noise_scheduler.q_sample(x_start=x_start, t=times, noise=noise)
model_output = unet(
x_noisy,
times,
image_embed=image_embed,
text_encodings=text_encodings,
text_mask=text_mask,
lowres_cond_img=lowres_cond_img,
image_cond_drop_prob=self.image_cond_drop_prob,
text_cond_drop_prob=self.text_cond_drop_prob,
)
if learned_variance:
pred, _ = model_output.chunk(2, dim=1)
else:
pred = model_output
target = noise if not predict_x_start else x_start
loss = noise_scheduler.loss_fn(pred, target, reduction="none")
loss = reduce(loss, "b ... -> b (...)", "mean")
loss = noise_scheduler.p2_reweigh_loss(loss, times)
loss = loss.mean()
if not learned_variance:
# return simple loss if not using learned variance
return loss
# most of the code below is transcribed from
# https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/diffusion_utils_2.py
# the Improved DDPM paper then further modified it so that the mean is detached
# (shown a couple lines before), and weighted to be smaller than the l1 or l2 "simple" loss
# it is questionable whether this is really needed, looking at some of the figures in the
# paper, but may as well stay faithful to their implementation
# if learning the variance, also include the extra weight kl loss
true_mean, _, true_log_variance_clipped = noise_scheduler.q_posterior(
x_start=x_start, x_t=x_noisy, t=times
)
model_mean, _, model_log_variance = self.p_mean_variance(
unet,
x=x_noisy,
t=times,
image_embed=image_embed,
noise_scheduler=noise_scheduler,
clip_denoised=clip_denoised,
learned_variance=True,
model_output=model_output,
)
# kl loss with detached model predicted mean, for stability reasons as in paper
detached_model_mean = model_mean.detach()
kl = normal_kl(
true_mean, true_log_variance_clipped, detached_model_mean, model_log_variance
)
kl = meanflat(kl) * NAT
decoder_nll = -discretized_gaussian_log_likelihood(
x_start, means=detached_model_mean, log_scales=0.5 * model_log_variance
)
decoder_nll = meanflat(decoder_nll) * NAT
# at the first timestep return the decoder NLL,
# otherwise KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
vb_losses = flow.where(times == 0, decoder_nll, kl)
# weight the vb loss smaller, for stability, as in the paper (recommended 0.001)
vb_loss = vb_losses.mean() * self.vb_loss_weight
return loss + vb_loss
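# Net objective when learning the variance: loss + 0.001 * vb_loss, where the
# KL term is computed against a detached mean, so the variance head trains on
# the variational bound without disturbing the epsilon / x0 regression above.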
@flow.no_grad()
@eval_decorator
def sample(
self,
image_embed=None,
text=None,
text_mask=None,
text_encodings=None,
batch_size=1,
cond_scale=1.0,
stop_at_unet_number=None,
distributed=False,
):
assert self.unconditional or exists(
image_embed
), "image embed must be present on sampling from decoder unless if trained unconditionally"
if not self.unconditional:
batch_size = image_embed.shape[0]
if exists(text) and not exists(text_encodings) and not self.unconditional:
assert exists(self.clip)
_, text_encodings, text_mask = self.clip.embed_text(text)
assert not (
self.condition_on_text_encodings and not exists(text_encodings)
), "text or text encodings must be passed into decoder if specified"
assert not (
not self.condition_on_text_encodings and exists(text_encodings)
), "decoder specified not to be conditioned on text, yet it is presented"
img = None
for (
unet_number,
unet,
vae,
channel,
image_size,
predict_x_start,
learned_variance,
noise_scheduler,
) in tqdm(
zip(
range(1, len(self.unets) + 1),
self.unets,
self.vaes,
self.sample_channels,
self.image_sizes,
self.predict_x_start,
self.learned_variance,
self.noise_schedulers,
)
):
context = null_context()
with context:
lowres_cond_img = None
shape = (batch_size, channel, image_size, image_size)
if unet.lowres_cond:
lowres_cond_img = self.to_lowres_cond(img, target_image_size=image_size)
is_latent_diffusion = isinstance(vae, VQGanVAE)
image_size = vae.get_encoded_fmap_size(image_size)
shape = (batch_size, vae.encoded_dim, image_size, image_size)
lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
img = self.p_sample_loop(
unet,
shape,
image_embed=image_embed,
text_encodings=text_encodings,
text_mask=text_mask,
cond_scale=cond_scale,
predict_x_start=predict_x_start,
learned_variance=learned_variance,
clip_denoised=not is_latent_diffusion,
lowres_cond_img=lowres_cond_img,
is_latent_diffusion=is_latent_diffusion,
noise_scheduler=noise_scheduler,
)
img = vae.decode(img)
if exists(stop_at_unet_number) and stop_at_unet_number == unet_number:
break
return img
def forward(
self,
image,
text=None,
image_embed=None,
text_encodings=None,
text_mask=None,
unet_number=None,
return_lowres_cond_image=False,
):
assert not (
len(self.unets) > 1 and not exists(unet_number)
), f"you must specify which unet you want trained, from a range of 1 to {len(self.unets)},"
" if you are training cascading DDPM (multiple unets)"
unet_number = default(unet_number, 1)
unet_index = unet_number - 1
unet = self.get_unet(unet_number)
vae = self.vaes[unet_index]
noise_scheduler = self.noise_schedulers[unet_index]
target_image_size = self.image_sizes[unet_index]
predict_x_start = self.predict_x_start[unet_index]
random_crop_size = self.random_crop_sizes[unet_index]
learned_variance = self.learned_variance[unet_index]
b, _, h, w = image.shape
check_shape(image, "b c h w", c=self.channels)
assert h >= target_image_size and w >= target_image_size
times = flow.randint(
0,
noise_scheduler.num_timesteps,
(b,),
placement=get_default_placement(),
sbp=get_default_sbp(),
dtype=flow.long,
)
if not exists(image_embed) and not self.unconditional:
assert exists(self.clip), "if you want to derive CLIP image embeddings automatically, "
"you must supply `clip` to the decoder on init"
image_embed, _ = self.clip.embed_image(image)
if exists(text) and not exists(text_encodings) and not self.unconditional:
assert exists(
self.clip
), "if you are passing in raw text, you need to supply `clip` to the decoder"
_, text_encodings, text_mask = self.clip.embed_text(text)
assert not (
self.condition_on_text_encodings and not exists(text_encodings)
), "text or text encodings must be passed into decoder if specified"
assert not (
not self.condition_on_text_encodings and exists(text_encodings)
), "decoder specified not to be conditioned on text, yet it is presented"
lowres_cond_img = (
self.to_lowres_cond(
image,
target_image_size=target_image_size,
downsample_image_size=self.image_sizes[unet_index - 1],
)
if unet_number > 1
else None
)
image = resize_image_to(image, target_image_size)
if exists(random_crop_size):
aug = K.RandomCrop((random_crop_size, random_crop_size), p=1.0)
# make sure low res conditioner and image both get augmented the same way
image = aug(image)
lowres_cond_img = aug(lowres_cond_img, params=aug._params)
is_latent_diffusion = not isinstance(vae, NullVQGanVAE)
vae.eval()
with flow.no_grad():
image = vae.encode(image)
lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
losses = self.p_losses(
unet,
image,
times,
image_embed=image_embed,
text_encodings=text_encodings,
text_mask=text_mask,
lowres_cond_img=lowres_cond_img,
predict_x_start=predict_x_start,
learned_variance=learned_variance,
is_latent_diffusion=is_latent_diffusion,
noise_scheduler=noise_scheduler,
)
if not return_lowres_cond_image:
return losses
return losses, lowres_cond_img
# main class
class DALLE2(nn.Module):
def __init__(self, *, prior, decoder, prior_num_samples=2, **kwargs):
super().__init__()
# assert isinstance(prior, DiffusionPrior)
# assert isinstance(decoder, Decoder)
self.prior = prior
self.decoder = decoder
self.tokenizer = SimpleTokenizer()
self.prior_num_samples = prior_num_samples
self.decoder_need_text_cond = self.decoder.condition_on_text_encodings
self.to_pil = T.ToPILImage()
@flow.no_grad()
@eval_decorator
def forward(self, text, cond_scale=1.0, prior_cond_scale=1.0, return_pil_images=False):
device = module_device(self)
one_text = isinstance(text, str) or (not is_list_str(text) and text.shape[0] == 1)
if isinstance(text, str) or is_list_str(text):
text = [text] if not isinstance(text, (list, tuple)) else text
text = self.tokenizer.tokenize(text).to(device)
image_embed = self.prior.sample(
text, num_samples_per_batch=self.prior_num_samples, cond_scale=prior_cond_scale
)
text_cond = text if self.decoder_need_text_cond else None
images = self.decoder.sample(image_embed, text=text_cond, cond_scale=cond_scale)
if return_pil_images:
images = list(map(self.to_pil, images.unbind(dim=0)))
if one_text:
return first(images)
return images
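# End-to-end usage sketch (illustrative; the prior and decoder are assumed to
# be trained already):
#   dalle2 = DALLE2(prior=diffusion_prior, decoder=decoder)
#   images = dalle2(["a shiba inu wearing a beret"], cond_scale=2.0)
# The text is tokenized, the prior samples a CLIP image embedding for it, and
# the decoder cascade renders the final image from that embedding.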
from inspect import isfunction
from math import pi
import oneflow as flow
from einops import rearrange, repeat
from oneflow import einsum, nn
from libai.utils import distributed as dist
# helper functions
def exists(val):
return val is not None
def broadcat(tensors, dim=-1):
num_tensors = len(tensors)
shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
shape_len = list(shape_lens)[0]
dim = (dim + shape_len) if dim < 0 else dim
dims = list(zip(*map(lambda t: list(t.shape), tensors)))
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
assert all(
[*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
), "invalid dimensions for broadcastable concatentation"
max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
expanded_dims.insert(dim, (dim, dims[dim]))
expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
return flow.cat(tensors, dim=dim)
# rotary embedding helper functions
def rotate_half(x):
x = rearrange(x, "... (d r) -> ... d r", r=2)
x1, x2 = [i.squeeze(-1) for i in x.chunk(2, -1)] # x.unbind(dim = -1)
x = flow.stack((-x2, x1), dim=-1)
return rearrange(x, "... d r -> ... (d r)")
def apply_rotary_emb(freqs, t, start_index=0):
freqs = freqs.to(t.dtype)
rot_dim = freqs.shape[-1]
end_index = start_index + rot_dim
assert rot_dim <= t.shape[-1], (
f"feature dimension {t.shape[-1]} is not of sufficient size "
f"to rotate in all the positions {rot_dim}"
)
t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
t = (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
return flow.cat((t_left, t, t_right), dim=-1)
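# Worked sketch of the rotation above: for each adjacent channel pair
# (x1, x2) at angle theta, rotate_half maps (x1, x2) -> (-x2, x1), so
#   t * cos(theta) + rotate_half(t) * sin(theta)
# equals (x1 * cos(theta) - x2 * sin(theta), x1 * sin(theta) + x2 * cos(theta)),
# i.e. a plain 2D rotation applied pairwise across the feature dimension.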
# learned rotation helpers
def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
if exists(freq_ranges):
rotations = einsum("..., f -> ... f", rotations, freq_ranges)
rotations = rearrange(rotations, "... r f -> ... (r f)")
rotations = repeat(rotations, "... n -> ... (n r)", r=2)
return apply_rotary_emb(rotations, t, start_index=start_index)
# classes
class RotaryEmbedding(nn.Module):
def __init__(
self,
dim,
custom_freqs=None,
freqs_for="lang",
theta=10000,
max_freq=10,
num_freqs=1,
learned_freq=False,
):
super().__init__()
if exists(custom_freqs):
freqs = custom_freqs
elif freqs_for == "lang":
freqs = 1.0 / (theta ** (flow.arange(0, dim, 2)[: (dim // 2)].float() / dim))
elif freqs_for == "pixel":
freqs = flow.linspace(1.0, max_freq / 2, dim // 2) * pi
elif freqs_for == "constant":
freqs = flow.ones(num_freqs).float()
else:
raise ValueError(f"unknown modality {freqs_for}")
self.cache = dict()
if learned_freq:
self.freqs = nn.Parameter(freqs)
else:
self.register_buffer("freqs", freqs)
def rotate_queries_or_keys(self, t, seq_dim=-2):
seq_len = t.shape[seq_dim]
freqs = self.forward(
lambda: flow.arange(
seq_len,
placement=dist.get_layer_placement(0),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
),
cache_key=seq_len,
)
return apply_rotary_emb(freqs, t)
def forward(self, t, cache_key=None):
if exists(cache_key) and cache_key in self.cache:
return self.cache[cache_key]
if isfunction(t):
t = t()
freqs = self.freqs
freqs = flow.einsum("..., f -> ... f", t.to(freqs.dtype), freqs)
freqs = repeat(freqs, "... n -> ... (n r)", r=2)
if exists(cache_key):
self.cache[cache_key] = freqs
return freqs
# take from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py
# to give users a quick easy start to training DALL-E without doing BPE
import gzip
import html
import os
from functools import lru_cache
from pathlib import Path
import ftfy
import oneflow as flow
import regex as re
from .utils import import_or_print_error
# OpenAI simple tokenizer
@lru_cache()
def default_bpe():
return os.path.join(
os.path.dirname(os.path.abspath(__file__)), "data/bpe_simple_vocab_16e6.txt.gz"
)
@lru_cache()
def bytes_to_unicode():
bs = (
list(range(ord("!"), ord("~") + 1))
+ list(range(ord("¡"), ord("¬") + 1))
+ list(range(ord("®"), ord("ÿ") + 1))
)
cs = bs[:]
n = 0
for b in range(2 ** 8):
if b not in bs:
bs.append(b)
cs.append(2 ** 8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
def get_pairs(word):
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r"\s+", " ", text)
text = text.strip()
return text
class SimpleTokenizer(object):
def __init__(self, bpe_path=default_bpe()):
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
merges = (
gzip.open(bpe_path).read().decode("utf-8").split("\n")
) # Path(bpe_path).read_text(encoding='utf8').split('\n')
merges = merges[1 : 49152 - 256 - 2 + 1]
merges = [tuple(merge.split()) for merge in merges]
vocab = list(bytes_to_unicode().values())
vocab = vocab + [v + "</w>" for v in vocab]
for merge in merges:
vocab.append("".join(merge))
vocab.extend(["<|startoftext|>", "<|endoftext|>"])
self.vocab_size = 49408
self.encoder = dict(zip(vocab, range(len(vocab))))
self.decoder = {v: k for k, v in self.encoder.items()}
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {"<|startoftext|>": "<|startoftext|>", "<|endoftext|>": "<|endoftext|>"}
self.pat = re.compile(
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", # noqa
re.IGNORECASE,
)
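    # The pattern splits text into the two special tokens, common English
    # contraction suffixes, runs of letters, single digits, and runs of other
    # non-space symbols.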
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token[:-1]) + (token[-1] + "</w>",)
pairs = get_pairs(word)
if not pairs:
return token + "</w>"
while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except ValueError:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = " ".join(word)
self.cache[token] = word
return word
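    # Worked sketch of the merge loop in `bpe` above (merge ranks are
    # illustrative, not the real BPE table): for token "low", `word` starts as
    # ("l", "o", "w</w>"); if ("l", "o") is the lowest-ranked known bigram it
    # merges to ("lo", "w</w>"), and the loop repeats until no known bigram
    # remains, returning e.g. "lo w</w>".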
def encode(self, text):
bpe_tokens = []
text = whitespace_clean(basic_clean(text)).lower()
for token in re.findall(self.pat, text):
token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" "))
return bpe_tokens
def decode(self, tokens, remove_start_end=True, pad_tokens=set()):
if flow.is_tensor(tokens):
tokens = tokens.tolist()
if remove_start_end:
            tokens = [token for token in tokens if token not in (49406, 49407, 0)]  # <|startoftext|>, <|endoftext|>, pad
text = "".join([self.decoder[token] for token in tokens if token not in pad_tokens])
text = (
bytearray([self.byte_decoder[c] for c in text])
.decode("utf-8", errors="replace")
.replace("</w>", " ")
)
return text
def tokenize(self, texts, context_length=256, truncate_text=False):
if isinstance(texts, str):
texts = [texts]
all_tokens = [self.encode(text) for text in texts]
result = flow.zeros(len(all_tokens), context_length, dtype=flow.long)
for i, tokens in enumerate(all_tokens):
if len(tokens) > context_length:
if truncate_text:
tokens = tokens[:context_length]
else:
raise RuntimeError(
f"Input {texts[i]} is too long for context length {context_length}"
)
result[i, : len(tokens)] = flow.tensor(tokens)
return result
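# Hedged usage sketch (assumes the bundled BPE vocab at default_bpe() exists;
# the prompt text is illustrative):
def _demo_simple_tokenizer():
    tok = SimpleTokenizer()
    ids = tok.tokenize("a photo of a cat", context_length=77, truncate_text=True)
    print(ids.shape)           # (1, 77), zero-padded on the right
    print(tok.decode(ids[0]))  # round-trips to roughly "a photo of a cat "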
# tokenizer = SimpleTokenizer()
# YTTM tokenizer
class YttmTokenizer:
    def __init__(self, bpe_path=None):
        assert bpe_path is not None, "YttmTokenizer requires the path to a trained BPE model"
        bpe_path = Path(bpe_path)
        assert bpe_path.exists(), f"BPE model path {str(bpe_path)} does not exist"
self.yttm = import_or_print_error(
"youtokentome", "you need to install youtokentome by `pip install youtokentome`"
)
tokenizer = self.yttm.BPE(model=str(bpe_path))
self.tokenizer = tokenizer
self.vocab_size = tokenizer.vocab_size()
def decode(self, tokens, pad_tokens=set()):
if flow.is_tensor(tokens):
tokens = tokens.tolist()
return self.tokenizer.decode(tokens, ignore_ids=pad_tokens.union({0}))
def encode(self, texts):
encoded = self.tokenizer.encode(texts, output_type=self.yttm.OutputType.ID)
return list(map(flow.tensor, encoded))
def tokenize(self, texts, context_length=256, truncate_text=False):
if isinstance(texts, str):
texts = [texts]
all_tokens = self.encode(texts)
result = flow.zeros(len(all_tokens), context_length, dtype=flow.long)
for i, tokens in enumerate(all_tokens):
if len(tokens) > context_length:
if truncate_text:
tokens = tokens[:context_length]
else:
raise RuntimeError(
f"Input {texts[i]} is too long for context length {context_length}"
)
result[i, : len(tokens)] = flow.tensor(tokens)
return result
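# Hedged usage sketch, following the commented-out SimpleTokenizer example
# above (the model path is hypothetical and youtokentome must be installed):
# tokenizer = YttmTokenizer(bpe_path="path/to/yttm.model")
# ids = tokenizer.tokenize(["a photo of a cat"], context_length=77)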