# cogview3_plus.yaml
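# Inference configuration for CogView3-Plus-3B (SAT / sgm-style config).
# Paths under /home/models are machine-specific; adjust `load`, `model_dir`,
# and `ckpt_path` to wherever the checkpoints live on your system.
#
# `args` controls the sampling run: prompts are read line by line from
# `input_file` (`input_type: txt`), sampling runs in bf16, and results are
# written to `output_dir` as 512x512 images, presumably tiled into a grid
# with `grid_num_columns` images per row.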
args:
  mode: inference
  relay_model: False
  load: "/home/models/CogView4/CogView3/cogview3-plus-3b/transformer"
  batch_size: 4
  grid_num_columns: 2
  input_type: txt
  input_file: "configs/test.txt"
  bf16: True
  force_inference: True
  sampling_image_size_x: 512
  sampling_image_size_y: 512
  sampling_latent_dim: 16
  output_dir: "outputs/cogview3_plus"
  deepspeed_config: { }

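# Top-level diffusion wrapper. `scale_factor: 1` leaves the VAE latents
# unscaled before denoising; `log_keys` selects which conditioning inputs
# (here the text prompt) are logged alongside samples.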
model:
  scale_factor: 1
  disable_first_stage_autocast: true
  log_keys:
    - txt

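  # Denoiser: discrete-time (1000 indices) denoising with eps-prediction
  # weighting and zero-terminal-SNR scaling; `shift_scale: 4` shifts the
  # DDPM schedule and matches the value used by the sampler below.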
  denoiser_config:
    target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
    params:
      num_idx: 1000
      quantize_c_noise: False

      weighting_config:
        target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
      scaling_config:
        target: sgm.modules.diffusionmodules.denoiser_scaling.ZeroSNRScaling
      discretization_config:
        target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
        params:
          shift_scale: 4

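  # Backbone: a 30-layer DiT (hidden size 2560, 64 attention heads) over
  # 16-channel latents split into 2x2 patches. `text_length: 224` matches the
  # T5 token budget, and `text_hidden_size: 4096` suggests the patch-embedding
  # mixin also projects the T5 features into the transformer width. The
  # 1536-dim `adm_in_channels` vector is produced by the size/crop embedders
  # in the conditioner below.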
  network_config:
    target: sgm.modules.diffusionmodules.dit.DiffusionTransformer
    params:
      in_channels: 16
      out_channels: 16
      hidden_size: 2560
      num_layers: 30
      patch_size: 2
      block_size: 16
      num_attention_heads: 64
      text_length: 224
      time_embed_dim: 512
      num_classes: sequential
      adm_in_channels: 1536

      modules:
        pos_embed_config:
          target: sgm.modules.diffusionmodules.dit.PositionEmbeddingMixin
          params:
            max_height: 128
            max_width: 128
            max_length: 4096
        patch_embed_config:
          target: sgm.modules.diffusionmodules.dit.ImagePatchEmbeddingMixin
          params:
            text_hidden_size: 4096
        attention_config:
          target: sgm.modules.diffusionmodules.dit.AdalnAttentionMixin
          params:
            qk_ln: true
        final_layer_config:
          target: sgm.modules.diffusionmodules.dit.FinalLayerMixin

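  # Conditioning: a frozen T5-XXL encoder supplies up to 224 text tokens for
  # the cross-attention condition, and three ConcatTimestepEmbedderND modules
  # embed the SDXL-style size/crop tuples. Each 2-tuple contributes
  # 2 * 256 = 512 dims, so the three vector conditions concatenate to the
  # 1536-dim `adm_in_channels` expected by the network above.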
  conditioner_config:
    target: sgm.modules.GeneralConditioner
    params:
      emb_models:
        # crossattn cond
        - is_trainable: False
          input_key: txt
          target: sgm.modules.encoders.modules.FrozenT5Embedder
          params:
            model_dir: "/home/models/CogView4/t5-v1_1-xxl"
            max_length: 224
        # vector cond
        - is_trainable: False
          input_key: original_size_as_tuple
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256  # per tuple element; the (h, w) pair yields 2 * 256 = 512 dims
        # vector cond
        - is_trainable: False
          input_key: crop_coords_top_left
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256  # per tuple element; the (top, left) pair yields 2 * 256 = 512 dims
        # vector cond
        - is_trainable: False
          input_key: target_size_as_tuple
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256  # per tuple element; the (h, w) pair yields 2 * 256 = 512 dims
  
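  # First stage: a KL-regularized autoencoder (DiagonalGaussianRegularizer)
  # with 16 latent channels, loaded from imagekl_ch16.pt. The loss is an
  # identity stub because no reconstruction loss is needed at inference time;
  # the encoder and decoder mirror each other.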
  first_stage_config:
    target: sgm.models.autoencoder.AutoencodingEngine
    params:
      ckpt_path: "/home/models/CogView4/CogView3/cogview3-plus-3b/vae/imagekl_ch16.pt"
      monitor: val/rec_loss

      loss_config:
        target: torch.nn.Identity
      
      regularizer_config:
        target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer

      encoder_config:
        target: sgm.modules.diffusionmodules.model.Encoder
        params:
          attn_type: vanilla
          double_z: true
          z_channels: 16
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [ 1, 4, 8, 8 ]
          num_res_blocks: 3
          attn_resolutions: [ ]
          mid_attn: False
          dropout: 0.0
      
      decoder_config:
        target: sgm.modules.diffusionmodules.model.Decoder
        params:
          attn_type: vanilla
          double_z: true
          z_channels: 16
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [ 1, 4, 8, 8 ]
          num_res_blocks: 3
          attn_resolutions: [ ]
          mid_attn: False
          dropout: 0.0

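  # The diffusion training loss is stubbed out with torch.nn.Identity since
  # this config runs inference only (`mode: inference` above).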
  loss_fn_config:
    target: torch.nn.Identity
  
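  # Sampler: zero-terminal-SNR DDIM with 10 steps and vanilla classifier-free
  # guidance at scale 5, using the same shifted schedule (`shift_scale: 4`)
  # as the denoiser.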
  sampler_config:
    target: sgm.modules.diffusionmodules.sampling.ZeroSNRDDIMSampler
    params:
      num_steps: 10
      verbose: True

      discretization_config:
        target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
        params:
          shift_scale: 4

      guider_config:
        target: sgm.modules.diffusionmodules.guiders.VanillaCFG
        params:
          scale: 5
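
# Example launch (hypothetical invocation; the exact entry-point script and
# flags depend on the CogView3 repository version, so check its README):
#
#   python sample_dit.py --base configs/cogview3_plus.yaml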