dalle2_config.py 1.5 KB
Newer Older
yuguo960516's avatar
bloom  
yuguo960516 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from libai.config import LazyCall
from configs.common.train import train
from dalle2.models import DiffusionPrior, DiffusionPriorNetwork, Unet, Decoder, DALLE2
from dalle2._clip import OpenAIClipAdapter
from omegaconf import DictConfig

# CLIP adapter shared by the prior below (`clip=clip`); provides the
# text/image embedding space the diffusion prior operates in.
# NOTE(review): `name` is empty here — presumably a checkpoint name/path
# filled in by the concrete training config; confirm against callers.
clip = LazyCall(OpenAIClipAdapter)(name="")
# Plain config node for an optional SwinIR checkpoint path.
# NOTE(review): None appears to mean "super-resolution disabled" — confirm
# where `swinir_path` is consumed.
swinir = DictConfig({"swinir_path": None})

# Diffusion prior: denoising process over CLIP image embeddings,
# conditioned on text. The inner transformer is defined first, then
# wrapped by the DiffusionPrior training/sampling machinery.
_prior_net = LazyCall(DiffusionPriorNetwork)(
    dim=768,
    depth=24,
    num_timesteps=1000,
    max_text_len=77,
    num_time_embeds=1,
    num_image_embeds=1,
    num_text_embeds=1,
    dim_head=64,
    heads=32,
    ff_mult=4,
    attn_dropout=0.05,
    ff_dropout=0.05,
    normformer=True,
)

prior = LazyCall(DiffusionPrior)(
    net=_prior_net,
    clip=clip,
    image_embed_dim=768,
    timesteps=1000,
    cond_drop_prob=0.1,  # classifier-free-guidance style conditioning dropout
    loss_type="l2",
    condition_on_text_encodings=True,
)

# Base U-Net for the decoder, conditioned on CLIP image and text
# embeddings. Settings are gathered in a dict and splatted into the
# lazy call; values are identical to the original literal form.
_unet1_cfg = dict(
    dim=320,
    image_embed_dim=768,
    text_embed_dim=768,
    cond_dim=512,
    channels=3,
    dim_mults=(1, 2, 3, 4),
    num_resnet_blocks=4,
    attn_heads=8,
    attn_dim_head=64,
    sparse_attn=True,
    memory_efficient=True,
    cond_on_text_encodings=True,
    self_attn=[False, True, True, True],  # per-resolution self-attention toggle
)
unet1 = LazyCall(Unet)(**_unet1_cfg)

# Image decoder: a single-stage diffusion decoder built on `unet1`.
# NOTE(review): `clip=None` looks deliberate — the decoder receives
# embeddings from the prior rather than holding its own CLIP; confirm.
decoder = LazyCall(Decoder)(
    unet=(unet1,),
    clip=None,
    channels=3,
    image_sizes=[64],
    timesteps=1000,
    loss_type="l2",
    beta_schedule=["cosine"],  # one schedule per unet stage
    learned_variance=True,
)

# Top-level DALL·E 2 model chaining the prior and the decoder.
# NOTE(review): weight paths default to "" — presumably overridden by the
# training config that imports this file; confirm.
model = LazyCall(DALLE2)(
    decoder=decoder,
    prior=prior,
    prior_weight_path="",
    decoder_weight_path="",
)