cp_enc_dec.py 46.5 KB
Newer Older
chenzk's avatar
v1.0  
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
import math

import numpy as np
import torch
import torch.distributed
import torch.nn as nn
import torch.nn.functional as F
from beartype import beartype
from beartype.typing import List, Optional, Tuple, Union
from einops import rearrange
from sgm.util import (get_context_parallel_group,
                      get_context_parallel_group_rank,
                      get_context_parallel_rank,
                      get_context_parallel_world_size)
# try:
from vae_modules.utils import SafeConv3d as Conv3d

# except:
#     # Degrade to normal Conv3d if SafeConv3d is not available
#     from torch.nn import Conv3d


def cast_tuple(t, length=1):
    """Return ``t`` itself if it is a tuple, otherwise ``(t,) * length``."""
    if isinstance(t, tuple):
        return t
    return (t, ) * length


def divisible_by(num, den):
    """Return True when ``num`` leaves no remainder modulo ``den``."""
    remainder = num % den
    return remainder == 0


def is_odd(n):
    """Return True when ``n`` is not divisible by 2."""
    return not ((n % 2) == 0)


def exists(v):
    """Return True unless ``v`` is ``None`` (falsy values like 0 still count)."""
    if v is None:
        return False
    return True


def pair(t):
    """Return ``t`` if it is already a tuple, else the 2-tuple ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t, t)


def get_timestep_embedding(timesteps, embedding_dim):
    """Build sinusoidal embeddings for a 1-D tensor of timesteps.

    Follows the DDPM / Fairseq / tensor2tensor construction. Frequencies are
    ``exp(-log(10000) * i / (half_dim - 1))`` for ``i in range(half_dim)``.
    The output is ``[sin(t * f), cos(t * f)]``: the sine half is concatenated
    in front of the cosine half rather than interleaved.

    Args:
        timesteps: 1-D tensor of timesteps (asserted).
        embedding_dim: size of the last output dimension. An odd value gets
            one trailing zero column.

    Returns:
        float32 tensor of shape ``(len(timesteps), embedding_dim)`` on
        ``timesteps.device``.
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    # max(..., 1) avoids ZeroDivisionError for embedding_dim < 4 (half_dim <= 1);
    # larger sizes are computed exactly as before.
    emb = math.log(10000) / max(half_dim - 1, 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = emb.to(device=timesteps.device)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad the odd trailing column
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


def nonlinearity(x):
    """Swish / SiLU activation: ``x * sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate


def leaky_relu(p=0.1):
    """Return a new ``nn.LeakyReLU`` module with negative slope ``p``."""
    module = nn.LeakyReLU(p)
    return module


def _split(input_, dim):
    """Return this context-parallel rank's shard of ``input_`` along ``dim``.

    The first entry along ``dim`` is set aside before sharding. The remaining
    entries are cut into chunks of ``(size - 1) // world_size``, and rank 0
    gets the set-aside first entry prepended to its chunk. When the
    context-parallel world size is 1, ``input_`` is returned unchanged.
    """
    world_size = get_context_parallel_world_size()
    if world_size == 1:
        return input_

    rank = get_context_parallel_rank()

    moved = input_.transpose(0, dim)
    head = moved[:1].transpose(0, dim).contiguous()
    body = moved[1:].transpose(0, dim).contiguous()

    shard_len = body.size()[dim] // world_size
    shard = torch.split(body, shard_len, dim=dim)[rank]
    if rank == 0:
        shard = torch.cat([head, shard], dim=dim)
    return shard.contiguous()


def _gather(input_, dim):
    """All-gather per-rank shards along ``dim`` and concatenate them in rank order.

    This is the counterpart of ``_split``: rank 0's shard is expected to be
    one entry longer along ``dim`` (it carries the leading first frame), and
    all other shards are expected to share one length. When the
    context-parallel world size is 1, ``input_`` is returned unchanged.

    NOTE(review): the receive-buffer sizes are derived from the local shard
    on every rank, so this only works if len(rank-0 shard) ==
    len(other shard) + 1 — confirm callers guarantee that.
    """
    cp_world_size = get_context_parallel_world_size()

    # Bypass the function if context parallel is 1
    if cp_world_size == 1:
        return input_

    group = get_context_parallel_group()
    cp_rank = get_context_parallel_rank()

    # print('in _gather, cp_rank:', cp_rank, 'input_size:', input_.shape)

    # First entry along `dim` of the local shard. It is used to size slot 0
    # of the receive list and is re-attached on rank 0 below.
    input_first_frame_ = input_.transpose(0,
                                          dim)[:1].transpose(0,
                                                             dim).contiguous()
    if cp_rank == 0:
        input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()

    # Slot 0 (rank 0's shard) is one entry longer than the other slots.
    tensor_list = [
        torch.empty_like(torch.cat([input_first_frame_, input_], dim=dim))
    ] + [torch.empty_like(input_) for _ in range(cp_world_size - 1)]

    if cp_rank == 0:
        input_ = torch.cat([input_first_frame_, input_], dim=dim)

    tensor_list[cp_rank] = input_
    torch.distributed.all_gather(tensor_list, input_, group=group)

    output = torch.cat(tensor_list, dim=dim).contiguous()

    # print('out _gather, cp_rank:', cp_rank, 'output_size:', output.shape)

    return output


def _conv_split(input_, dim, kernel_size):
    """Return this context-parallel rank's shard of ``input_`` along ``dim``.

    Shard length is ``(size - kernel_size) // world_size``. Rank 0 takes the
    first ``shard_len + kernel_size`` entries. Rank ``r > 0`` takes the
    ``shard_len`` entries starting at ``r * shard_len + kernel_size``.
    When the context-parallel world size is 1, ``input_`` is returned
    unchanged.
    """
    world_size = get_context_parallel_world_size()
    if world_size == 1:
        return input_

    rank = get_context_parallel_rank()
    shard_len = (input_.size()[dim] - kernel_size) // world_size

    if rank == 0:
        start, stop = 0, shard_len + kernel_size
    else:
        start = rank * shard_len + kernel_size
        stop = start + shard_len

    moved = input_.transpose(dim, 0)
    return moved[start:stop].transpose(dim, 0).contiguous()


def _conv_gather(input_, dim, kernel_size):
    """All-gather per-rank shards along ``dim`` (counterpart of ``_conv_split``).

    Rank 0 contributes its full shard. Other ranks first drop their leading
    ``kernel_size - 1`` entries. The gathered pieces are concatenated in rank
    order. When the context-parallel world size is 1, ``input_`` is returned
    unchanged.

    NOTE(review): receive-buffer sizes are computed from the local shard and
    assume len(rank-0 piece) == len(other piece) + 1. That matches
    ``_conv_split`` for kernel_size == 1 (the only value used in this file),
    but not obviously for larger kernels — confirm before using them.
    """
    cp_world_size = get_context_parallel_world_size()

    # Bypass the function if context parallel is 1
    if cp_world_size == 1:
        return input_

    group = get_context_parallel_group()
    cp_rank = get_context_parallel_rank()

    # print('in _conv_gather, cp_rank:', cp_rank, 'input_size:', input_.shape)

    # First `kernel_size` entries along `dim`. They size slot 0 of the
    # receive list and are re-attached on rank 0.
    input_first_kernel_ = input_.transpose(0, dim)[:kernel_size].transpose(
        0, dim).contiguous()
    if cp_rank == 0:
        input_ = input_.transpose(0, dim)[kernel_size:].transpose(
            0, dim).contiguous()
    else:
        # Non-zero ranks drop their leading kernel_size - 1 entries.
        input_ = input_.transpose(0, dim)[max(kernel_size - 1, 0):].transpose(
            0, dim).contiguous()

    tensor_list = [
        torch.empty_like(torch.cat([input_first_kernel_, input_], dim=dim))
    ] + [torch.empty_like(input_) for _ in range(cp_world_size - 1)]
    if cp_rank == 0:
        input_ = torch.cat([input_first_kernel_, input_], dim=dim)

    tensor_list[cp_rank] = input_
    torch.distributed.all_gather(tensor_list, input_, group=group)

    # Note: torch.cat already creates a contiguous tensor.
    output = torch.cat(tensor_list, dim=dim).contiguous()

    # print('out _conv_gather, cp_rank:', cp_rank, 'input_size:', output.shape)

    return output


def _pass_from_previous_rank(input_, dim, kernel_size):
    """Prepend ``kernel_size - 1`` entries along ``dim`` for a causal conv.

    Each context-parallel rank except the last sends its trailing
    ``kernel_size - 1`` entries to the next rank. Each rank except rank 0
    prepends what it received from the previous rank. Rank 0 instead
    prepends ``kernel_size - 1`` copies of its own first entry. When
    ``kernel_size == 1``, ``input_`` is returned unchanged.

    NOTE(review): the send/recv ranks are global ranks computed as
    ``global_rank +/- 1``, wrapped at multiples of ``cp_world_size``. This
    presumably assumes each CP group is a block of consecutive global ranks
    aligned to ``cp_world_size`` — confirm against the group setup.
    NOTE(review): ``req_send`` is never waited on, and ``cp_group_rank`` /
    ``global_world_size`` are unused.
    """
    # Bypass the function if kernel size is 1
    if kernel_size == 1:
        return input_

    group = get_context_parallel_group()
    cp_rank = get_context_parallel_rank()
    cp_group_rank = get_context_parallel_group_rank()
    cp_world_size = get_context_parallel_world_size()

    # print('in _pass_from_previous_rank, cp_rank:', cp_rank, 'input_size:', input_.shape)

    global_rank = torch.distributed.get_rank()
    global_world_size = torch.distributed.get_world_size()

    # Work with `dim` moved to the front so slicing is along dim 0.
    input_ = input_.transpose(0, dim)

    # Neighbour ranks: send to next, receive from previous (wrapped in-group).
    send_rank = global_rank + 1
    recv_rank = global_rank - 1
    if send_rank % cp_world_size == 0:
        send_rank -= cp_world_size
    if recv_rank % cp_world_size == cp_world_size - 1:
        recv_rank += cp_world_size

    if cp_rank < cp_world_size - 1:
        req_send = torch.distributed.isend(input_[-kernel_size +
                                                  1:].contiguous(),
                                           send_rank,
                                           group=group)
    if cp_rank > 0:
        recv_buffer = torch.empty_like(input_[-kernel_size + 1:]).contiguous()
        req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)

    if cp_rank == 0:
        # First rank has no predecessor: replicate its first entry.
        input_ = torch.cat([input_[:1]] * (kernel_size - 1) + [input_], dim=0)
    else:
        req_recv.wait()
        input_ = torch.cat([recv_buffer, input_], dim=0)

    input_ = input_.transpose(0, dim).contiguous()

    # print('out _pass_from_previous_rank, cp_rank:', cp_rank, 'input_size:', input_.shape)

    return input_


def _fake_cp_pass_from_previous_rank(input_,
                                     dim,
                                     kernel_size,
                                     cache_padding=None):
    """Prepend ``kernel_size - 1`` entries along ``dim``, optionally from a cache.

    This matches ``_pass_from_previous_rank`` except on context-parallel
    rank 0. There, if ``cache_padding`` is given, it is moved to the input's
    device and prepended instead of replicated copies of the first entry.
    ``cache_padding`` is in the caller's layout (``dim`` not moved to front).
    When ``kernel_size == 1``, ``input_`` is returned unchanged.

    NOTE(review): ``torch.distributed.get_rank()`` is called whenever
    kernel_size != 1, so this presumably needs an initialized default process
    group even when cp_world_size == 1. ``req_send`` is never waited on, and
    ``cache_padding`` length is not checked against ``kernel_size - 1``.
    """
    # Bypass the function if kernel size is 1
    if kernel_size == 1:
        return input_

    group = get_context_parallel_group()
    cp_rank = get_context_parallel_rank()
    cp_group_rank = get_context_parallel_group_rank()
    cp_world_size = get_context_parallel_world_size()

    # print('in _pass_from_previous_rank, cp_rank:', cp_rank, 'input_size:', input_.shape)

    global_rank = torch.distributed.get_rank()
    global_world_size = torch.distributed.get_world_size()

    # Work with `dim` moved to the front so slicing is along dim 0.
    input_ = input_.transpose(0, dim)

    # Neighbour ranks: send to next, receive from previous (wrapped in-group;
    # assumes consecutive global ranks per CP group — confirm).
    send_rank = global_rank + 1
    recv_rank = global_rank - 1
    if send_rank % cp_world_size == 0:
        send_rank -= cp_world_size
    if recv_rank % cp_world_size == cp_world_size - 1:
        recv_rank += cp_world_size

    # req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group)
    # recv_buffer = torch.empty_like(input_[-kernel_size + 1:]).contiguous()
    # req_recv = torch.distributed.recv(recv_buffer, recv_rank, group=group)
    # req_recv.wait()
    recv_buffer = torch.empty_like(input_[-kernel_size + 1:]).contiguous()
    if cp_rank < cp_world_size - 1:
        req_send = torch.distributed.isend(input_[-kernel_size +
                                                  1:].contiguous(),
                                           send_rank,
                                           group=group)
    if cp_rank > 0:
        req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)
    # req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group)
    # req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)

    if cp_rank == 0:
        if cache_padding is not None:
            # Continue from a previous call: prepend the cached tail frames.
            input_ = torch.cat(
                [cache_padding.transpose(0, dim).to(input_.device), input_],
                dim=0)
        else:
            # No cache: replicate the first entry kernel_size - 1 times.
            input_ = torch.cat([input_[:1]] * (kernel_size - 1) + [input_],
                               dim=0)
    else:
        req_recv.wait()
        input_ = torch.cat([recv_buffer, input_], dim=0)

    input_ = input_.transpose(0, dim).contiguous()
    return input_


def _drop_from_previous_rank(input_, dim, kernel_size):
    input_ = input_.transpose(0, dim)[kernel_size - 1:].transpose(0, dim)
    return input_


class _ConvolutionScatterToContextParallelRegion(torch.autograd.Function):
    """Autograd op: split with ``_conv_split`` in forward, gather in backward.

    ``dim`` and ``kernel_size`` are non-differentiable (gradient ``None``).
    """

    @staticmethod
    def forward(ctx, input_, dim, kernel_size):
        ctx.dim, ctx.kernel_size = dim, kernel_size
        return _conv_split(input_, dim, kernel_size)

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = _conv_gather(grad_output, ctx.dim, ctx.kernel_size)
        return grad_input, None, None


class _ConvolutionGatherFromContextParallelRegion(torch.autograd.Function):
    """Autograd op: gather with ``_conv_gather`` in forward, split in backward.

    ``dim`` and ``kernel_size`` are non-differentiable (gradient ``None``).
    """

    @staticmethod
    def forward(ctx, input_, dim, kernel_size):
        ctx.dim, ctx.kernel_size = dim, kernel_size
        return _conv_gather(input_, dim, kernel_size)

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = _conv_split(grad_output, ctx.dim, ctx.kernel_size)
        return grad_input, None, None


class _ConvolutionPassFromPreviousRank(torch.autograd.Function):
    """Autograd op around ``_pass_from_previous_rank``.

    Forward prepends ``kernel_size - 1`` entries along ``dim``. Backward only
    trims the gradient rows of those prepended entries. That gradient is not
    sent back to the rank that supplied them.
    """

    @staticmethod
    def forward(ctx, input_, dim, kernel_size):
        ctx.dim, ctx.kernel_size = dim, kernel_size
        return _pass_from_previous_rank(input_, dim, kernel_size)

    @staticmethod
    def backward(ctx, grad_output):
        trimmed = _drop_from_previous_rank(grad_output, ctx.dim,
                                           ctx.kernel_size)
        return trimmed, None, None


class _FakeCPConvolutionPassFromPreviousRank(torch.autograd.Function):
    """Autograd op around ``_fake_cp_pass_from_previous_rank``.

    Forward prepends ``kernel_size - 1`` entries along ``dim``. On rank 0
    they may come from ``cache_padding``. Backward only trims the gradient
    rows of the prepended entries. ``dim``, ``kernel_size`` and
    ``cache_padding`` get no gradient.
    """

    @staticmethod
    def forward(ctx, input_, dim, kernel_size, cache_padding):
        ctx.dim, ctx.kernel_size = dim, kernel_size
        return _fake_cp_pass_from_previous_rank(input_, dim, kernel_size,
                                                cache_padding)

    @staticmethod
    def backward(ctx, grad_output):
        trimmed = _drop_from_previous_rank(grad_output, ctx.dim,
                                           ctx.kernel_size)
        return trimmed, None, None, None


def conv_scatter_to_context_parallel_region(input_, dim, kernel_size):
    """Autograd-aware ``_conv_split`` of ``input_`` along ``dim`` (backward gathers)."""
    scatter = _ConvolutionScatterToContextParallelRegion.apply
    return scatter(input_, dim, kernel_size)


def conv_gather_from_context_parallel_region(input_, dim, kernel_size):
    """Autograd-aware ``_conv_gather`` of ``input_`` along ``dim`` (backward splits)."""
    gather = _ConvolutionGatherFromContextParallelRegion.apply
    return gather(input_, dim, kernel_size)


def conv_pass_from_last_rank(input_, dim, kernel_size):
    """Autograd-aware ``_pass_from_previous_rank`` (prepends kernel_size - 1 entries)."""
    pass_op = _ConvolutionPassFromPreviousRank.apply
    return pass_op(input_, dim, kernel_size)


def fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding):
    """Autograd-aware ``_fake_cp_pass_from_previous_rank`` (rank 0 may use ``cache_padding``)."""
    pass_op = _FakeCPConvolutionPassFromPreviousRank.apply
    return pass_op(input_, dim, kernel_size, cache_padding)


class ContextParallelCausalConv3d(nn.Module):
    """3-D convolution that is causal along the time axis (dim 2).

    Instead of symmetric temporal padding, ``forward`` prepends
    ``time_kernel_size - 1`` frames along dim 2 via
    ``fake_cp_pass_from_previous_rank``. On context-parallel rank 0 these are
    the tail cached by a previous call (if any) or copies of the first frame.
    On other ranks they are the last frames of the previous rank's shard.
    Height and width are zero-padded by ``kernel // 2`` on each side.

    Args:
        chan_in: input channels.
        chan_out: output channels.
        kernel_size: int or ``(time, height, width)``. Height and width must
            be odd (asserted).
        stride: int, applied to all three dims.
            NOTE(review): passing a tuple here would produce a nested tuple.
        **kwargs: forwarded to ``Conv3d`` (``SafeConv3d`` from
            ``vae_modules.utils``).
    """

    def __init__(self,
                 chan_in,
                 chan_out,
                 kernel_size: Union[int, Tuple[int, int, int]],
                 stride=1,
                 **kwargs):
        super().__init__()
        kernel_size = cast_tuple(kernel_size, 3)

        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size

        assert is_odd(height_kernel_size) and is_odd(width_kernel_size)

        # time_pad is recorded but not read by this class's forward(); the
        # temporal padding comes from the frames prepended there.
        time_pad = time_kernel_size - 1
        height_pad = height_kernel_size // 2
        width_pad = width_kernel_size // 2

        self.height_pad = height_pad
        self.width_pad = width_pad
        self.time_pad = time_pad
        self.time_kernel_size = time_kernel_size
        self.temporal_dim = 2  # time axis of the input tensor

        stride = (stride, stride, stride)
        dilation = (1, 1, 1)
        # No padding is passed here (unless a caller supplies `padding` in
        # kwargs); forward() pads time and space manually.
        self.conv = Conv3d(chan_in,
                           chan_out,
                           kernel_size,
                           stride=stride,
                           dilation=dilation,
                           **kwargs)
        # Tail frames kept between calls when forward(clear_cache=False).
        self.cache_padding = None

    def forward(self, input_, clear_cache=True):
        """Apply the causal convolution.

        Args:
            input_: tensor with time on dim 2. Presumably (B, C, T, H, W), since
                the 2-D padding below pads the last two dims — confirm.
            clear_cache: if False, keep the last ``time_kernel_size - 1``
                frames of the time-padded input. The next call on rank 0 then
                prepends them instead of replicating its first frame.

        Returns:
            ``self.conv`` applied to the padded input.
        """
        # if input_.shape[2] == 1: # handle image
        #     # first frame padding
        #     input_parallel = torch.cat([input_] * self.time_kernel_size, dim=2)
        # else:
        #     input_parallel = conv_pass_from_last_rank(input_, self.temporal_dim, self.time_kernel_size)

        # padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
        # input_parallel = F.pad(input_parallel, padding_2d, mode = 'constant', value = 0)

        # output_parallel = self.conv(input_parallel)
        # output = output_parallel
        # return output

        # Prepend time_kernel_size - 1 frames along dim 2: from the previous
        # rank, from self.cache_padding (rank 0), or by replicating frame 0.
        input_parallel = fake_cp_pass_from_previous_rank(
            input_, self.temporal_dim, self.time_kernel_size,
            self.cache_padding)

        # The cache has been handed to the call above; reset it.
        del self.cache_padding
        self.cache_padding = None
        if not clear_cache:
            cp_rank, cp_world_size = get_context_parallel_rank(
            ), get_context_parallel_world_size()
            global_rank = torch.distributed.get_rank()
            if cp_world_size == 1:
                # Single rank: keep the tail locally, detached, on CPU.
                # NOTE(review): if time_kernel_size == 1 the slice start is 0,
                # so the whole padded input would be cached.
                self.cache_padding = (
                    input_parallel[:, :, -self.time_kernel_size +
                                   1:].contiguous().detach().clone().cpu())
            else:
                # The last rank sends its tail to the first rank of the group,
                # which keeps it for the next call.
                # NOTE(review): the destination arithmetic assumes the group is
                # consecutive global ranks, and the isend request below is not
                # waited on.
                if cp_rank == cp_world_size - 1:
                    torch.distributed.isend(
                        input_parallel[:, :, -self.time_kernel_size +
                                       1:].contiguous(),
                        global_rank + 1 - cp_world_size,
                        group=get_context_parallel_group(),
                    )
                if cp_rank == 0:
                    # Blocking receive of the last rank's tail.
                    recv_buffer = torch.empty_like(
                        input_parallel[:, :, -self.time_kernel_size +
                                       1:]).contiguous()
                    torch.distributed.recv(recv_buffer,
                                           global_rank - 1 + cp_world_size,
                                           group=get_context_parallel_group())
                    self.cache_padding = recv_buffer.contiguous().detach(
                    ).clone().cpu()

        # Symmetric zero padding of the last two (spatial) dims.
        padding_2d = (self.width_pad, self.width_pad, self.height_pad,
                      self.height_pad)
        input_parallel = F.pad(input_parallel,
                               padding_2d,
                               mode='constant',
                               value=0)

        output_parallel = self.conv(input_parallel)
        output = output_parallel
        return output


class ContextParallelGroupNorm(torch.nn.GroupNorm):
    """GroupNorm whose statistics cover the full, context-parallel-gathered time axis.

    When dim 2 holds more than one entry, the input is gathered along dim 2
    across context-parallel ranks, normalized, and split back. Single-entry
    inputs are normalized locally.
    """

    def forward(self, input_):
        if input_.shape[2] <= 1:
            return super().forward(input_)

        gathered = conv_gather_from_context_parallel_region(input_,
                                                            dim=2,
                                                            kernel_size=1)
        normed = super().forward(gathered)
        return conv_scatter_to_context_parallel_region(normed,
                                                       dim=2,
                                                       kernel_size=1)


def Normalize(in_channels, gather=False, **kwargs):  # same for 3D and 2D
    """Build a 32-group, eps=1e-6, affine GroupNorm over ``in_channels``.

    With ``gather=True`` a ContextParallelGroupNorm is returned instead.
    Extra keyword arguments are accepted and ignored.
    """
    norm_cls = ContextParallelGroupNorm if gather else torch.nn.GroupNorm
    return norm_cls(num_groups=32,
                    num_channels=in_channels,
                    eps=1e-6,
                    affine=True)


class SpatialNorm3D(nn.Module):
    """Group-normalize ``f`` and modulate it with scale/shift maps derived from ``zq``.

    Computes ``norm_layer(f) * conv_y(zq') + conv_b(zq')``. Here ``zq'`` is
    ``zq`` resized (nearest neighbour) to the last three dims of ``f`` and,
    when ``add_conv`` is set, passed through an extra kernel-3 causal conv.

    Args:
        f_channels: channels of the feature map ``f`` (and of the output).
        zq_channels: channels of the conditioning tensor ``zq``.
        freeze_norm_layer: if True, the norm layer's parameters get
            ``requires_grad = False``.
        add_conv: add a ContextParallelCausalConv3d on ``zq`` before the
            scale/shift projections.
        pad_mode: accepted for compatibility; not used by this class.
        gather: use ContextParallelGroupNorm (gathers dim 2 across
            context-parallel ranks) instead of ``torch.nn.GroupNorm``.
        **norm_layer_params: forwarded to the norm layer constructor
            (e.g. ``num_groups``, ``eps``, ``affine``).
    """

    def __init__(
        self,
        f_channels,
        zq_channels,
        freeze_norm_layer=False,
        add_conv=False,
        pad_mode='constant',
        gather=False,
        **norm_layer_params,
    ):
        super().__init__()
        if gather:
            self.norm_layer = ContextParallelGroupNorm(num_channels=f_channels,
                                                       **norm_layer_params)
        else:
            self.norm_layer = torch.nn.GroupNorm(num_channels=f_channels,
                                                 **norm_layer_params)
        if freeze_norm_layer:
            # `parameters` must be called; iterating the bound method itself
            # raised TypeError.
            for p in self.norm_layer.parameters():
                p.requires_grad = False

        self.add_conv = add_conv
        if add_conv:
            self.conv = ContextParallelCausalConv3d(
                chan_in=zq_channels,
                chan_out=zq_channels,
                kernel_size=3,
            )

        self.conv_y = ContextParallelCausalConv3d(
            chan_in=zq_channels,
            chan_out=f_channels,
            kernel_size=1,
        )
        self.conv_b = ContextParallelCausalConv3d(
            chan_in=zq_channels,
            chan_out=f_channels,
            kernel_size=1,
        )

    def forward(self, f, zq, clear_fake_cp_cache=True):
        """Return the modulated, normalized ``f``.

        Args:
            f: feature tensor with time on dim 2 (presumably (B, C, T, H, W)).
            zq: conditioning tensor with the same layout; resized to ``f``.
            clear_fake_cp_cache: forwarded as ``clear_cache`` to the optional
                ``self.conv`` only.
        """
        force_split = bool(getattr(self, 'force_split', False))

        # For an odd frame count > 1 (or when forced), resize zq's first frame
        # and its remaining frames separately, so the first frame of zq maps
        # to the first frame of f.
        if (f.shape[2] > 1 and f.shape[2] % 2 == 1) or force_split:
            f_first, f_rest = f[:, :, :1], f[:, :, 1:]
            f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
            zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]
            zq_first = torch.nn.functional.interpolate(zq_first,
                                                       size=f_first_size,
                                                       mode='nearest')
            zq_rest = torch.nn.functional.interpolate(zq_rest,
                                                      size=f_rest_size,
                                                      mode='nearest')
            zq = torch.cat([zq_first, zq_rest], dim=2)
        else:
            zq = torch.nn.functional.interpolate(zq,
                                                 size=f.shape[-3:],
                                                 mode='nearest')

        if self.add_conv:
            zq = self.conv(zq, clear_cache=clear_fake_cp_cache)

        norm_f = self.norm_layer(f)
        new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
        return new_f


def Normalize3D(in_channels, zq_ch, add_conv, gather=False):
    """Build a SpatialNorm3D with a 32-group, eps=1e-6, affine GroupNorm core."""
    group_norm_kwargs = dict(num_groups=32, eps=1e-6, affine=True)
    return SpatialNorm3D(in_channels,
                         zq_ch,
                         gather=gather,
                         freeze_norm_layer=False,
                         add_conv=add_conv,
                         **group_norm_kwargs)


# class Upsample3D(nn.Module):
#     def __init__(
#         self,
#         in_channels,
#         with_conv,
#         compress_time=False,
#     ):
#         super().__init__()
#         self.with_conv = with_conv
#         if self.with_conv:
#             self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
#         self.compress_time = compress_time

#     def forward(self, x):
#         if hasattr(self, "force_split") and self.force_split:
#             force_split = True
#         else:
#             force_split = False

#         if self.compress_time and x.shape[2] > 1:
#             if x.shape[2] % 2 == 1 or force_split:
#                 # split first frame
#                 x_first, x_rest = x[:, :, 0], x[:, :, 1:]

#                 x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest")
#                 x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0, mode="nearest")
#                 x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
#             else:
#                 x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")

#         else:
#             # only interpolate 2D
#             t = x.shape[2]
#             x = rearrange(x, "b c t h w -> (b t) c h w")
#             x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
#             x = rearrange(x, "(b t) c h w -> b c t h w", t=t)

#         if self.with_conv:
#             t = x.shape[2]
#             x = rearrange(x, "b c t h w -> (b t) c h w")
#             x = self.conv(x)
#             x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
#         return x


class Upsample3D(nn.Module):
    """Nearest-neighbour 2x upsampling for (B, C, T, H, W) tensors.

    Height and width are always doubled. When ``compress_time`` is set and
    the input has more than one frame, time is doubled too. For an odd frame
    count, or when a truthy ``force_split`` attribute has been set on the
    module, the first frame is upsampled only spatially, so T frames become
    ``1 + 2 * (T - 1)``. An optional 3x3 Conv2d is then applied to each
    frame independently.
    """

    def __init__(
        self,
        in_channels,
        with_conv,
        compress_time=False,
    ):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            self.conv = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)
        self.compress_time = compress_time
        self.scale_factor = 2

    @staticmethod
    def _frames_to_batch(x):
        """(B, C, T, H, W) -> (B*T, C, H, W), batch-major."""
        b, c, t, h, w = x.shape
        return x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)

    @staticmethod
    def _batch_to_frames(x, num_frames):
        """(B*T, C, H, W) -> (B, C, T, H, W); inverse of ``_frames_to_batch``."""
        bt, c, h, w = x.shape
        return x.reshape(bt // num_frames, num_frames, c, h,
                         w).permute(0, 2, 1, 3, 4)

    def forward(self, x):
        force_split = bool(getattr(self, 'force_split', False))
        num_frames = x.shape[2]

        if not (self.compress_time and num_frames > 1):
            # Spatial-only upsampling, frame by frame.
            flat = self._frames_to_batch(x)
            flat = F.interpolate(flat,
                                 scale_factor=self.scale_factor,
                                 mode='nearest')
            x = self._batch_to_frames(flat, num_frames)
        elif num_frames % 2 == 1 or force_split:
            # Keep the first frame single; double the rest in time as well.
            head = F.interpolate(x[:, :, 0],
                                 scale_factor=self.scale_factor,
                                 mode='nearest')
            tail = F.interpolate(x[:, :, 1:],
                                 scale_factor=self.scale_factor,
                                 mode='nearest')
            x = torch.cat([head.unsqueeze(2), tail], dim=2)
        else:
            x = F.interpolate(x, scale_factor=self.scale_factor, mode='nearest')

        if self.with_conv:
            frames = x.shape[2]
            x = self._batch_to_frames(self.conv(self._frames_to_batch(x)),
                                      frames)
        return x


class DownSample3D(nn.Module):
    """2x spatial downsampling for (B, C, T, H, W) tensors, optionally 2x in time.

    When ``compress_time`` is set and there is more than one frame, frames are
    average-pooled in pairs along T. For an odd T the first frame is kept
    as is, so T becomes ``1 + (T - 1) // 2``; for an even T it becomes
    ``T // 2``. Spatially, each frame is either zero-padded by one row/column
    at the bottom/right and passed through a stride-2 3x3 Conv2d
    (``with_conv``), or 2x2 average-pooled.
    """

    def __init__(self,
                 in_channels,
                 with_conv,
                 compress_time=False,
                 out_channels=None):
        super().__init__()
        self.with_conv = with_conv
        if out_channels is None:
            out_channels = in_channels
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels,
                                        out_channels,
                                        kernel_size=3,
                                        stride=2,
                                        padding=0)
        self.compress_time = compress_time

    @staticmethod
    def _avg_pool_time(x):
        """Average-pool the last axis of a 3-D ``(N, C, T)`` tensor with window/stride 2.

        If the single call raises RuntimeError (this fallback was presumably
        added for very large N on GPU), retry in chunks of about N/4 along
        dim 0. Other exceptions, e.g. KeyboardInterrupt, now propagate instead
        of being swallowed by a bare ``except``.
        """
        try:
            return torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2)
        except RuntimeError:
            print('######### for loop the avg_pool1d in else ###########')
            # max(1, ...) avoids a zero split size (which torch.split rejects)
            # when N < 4.
            chunk = max(1, x.shape[0] // 4)
            return torch.cat([
                torch.nn.functional.avg_pool1d(part, kernel_size=2, stride=2)
                for part in x.split(chunk, dim=0)
            ],
                             dim=0)

    @staticmethod
    def _frames_to_batch(x):
        """(B, C, T, H, W) -> (B*T, C, H, W), batch-major."""
        b, c, t, h, w = x.shape
        return x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)

    @staticmethod
    def _batch_to_frames(x, num_frames):
        """(B*T, C, H, W) -> (B, C, T, H, W); inverse of ``_frames_to_batch``."""
        bt, c, h, w = x.shape
        return x.reshape(bt // num_frames, num_frames, c, h,
                         w).permute(0, 2, 1, 3, 4)

    def forward(self, x):
        if self.compress_time and x.shape[2] > 1:
            b, c, t, h, w = x.shape
            # (B, C, T, H, W) -> (B*H*W, C, T): one time series per pixel.
            seq = x.permute(0, 3, 4, 1, 2).reshape(b * h * w, c, t)

            if t % 2 == 1:
                # Keep the first frame; pool the remaining (even) frames.
                first, rest = seq[..., 0], seq[..., 1:]
                if rest.shape[-1] > 0:
                    rest = self._avg_pool_time(rest)
                seq = torch.cat([first[..., None], rest], dim=-1)
            else:
                seq = self._avg_pool_time(seq)

            # (B*H*W, C, T') -> (B, C, T', H, W)
            x = seq.reshape(b, h, w, c, seq.shape[-1]).permute(0, 3, 4, 1, 2)

        num_frames = x.shape[2]
        if self.with_conv:
            # Pad one column on the right and one row at the bottom.
            x = torch.nn.functional.pad(x, (0, 1, 0, 1),
                                        mode='constant',
                                        value=0)
            x = self._batch_to_frames(self.conv(self._frames_to_batch(x)),
                                      num_frames)
        else:
            pooled = torch.nn.functional.avg_pool2d(self._frames_to_batch(x),
                                                    kernel_size=2,
                                                    stride=2)
            x = self._batch_to_frames(pooled, num_frames)
        return x


class ContextParallelResnetBlock3D(nn.Module):
    """Pre-activation residual block built from causal 3D convolutions.

    Data path: norm1 -> nonlinearity -> conv1 -> (+ projected ``temb``) ->
    norm2 -> nonlinearity -> dropout -> conv2, summed with a shortcut of the
    input. The shortcut is the identity when ``in_channels == out_channels``;
    otherwise it is a kernel-3 ``ContextParallelCausalConv3d``
    (``conv_shortcut=True``) or a kernel-1 ``Conv3d``.

    Args:
        in_channels: channel count of the input.
        out_channels: channel count produced by the block; defaults to
            ``in_channels``.
        conv_shortcut: use the causal-conv shortcut instead of the 1x1x1 one
            when the channel count changes.
        dropout: dropout probability applied right before ``conv2``.
        temb_channels: width of the optional embedding; ``temb_proj`` only
            exists when this is positive.
        zq_ch, add_conv, gather_norm: forwarded to ``normalization`` as
            ``zq_ch``, ``add_conv`` and ``gather``.
        normalization: factory used to build ``norm1`` and ``norm2``.
    """

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout,
        temb_channels=512,
        zq_ch=None,
        add_conv=False,
        gather_norm=False,
        normalization=Normalize,
    ):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        # Both norm layers share every option except their channel count.
        norm_opts = dict(zq_ch=zq_ch, add_conv=add_conv, gather=gather_norm)

        # Submodules are created in a fixed order so that parameter
        # initialisation consumes the RNG in the same sequence every time.
        self.norm1 = normalization(in_channels, **norm_opts)
        self.conv1 = ContextParallelCausalConv3d(chan_in=in_channels,
                                                 chan_out=out_channels,
                                                 kernel_size=3)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = normalization(out_channels, **norm_opts)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = ContextParallelCausalConv3d(chan_in=out_channels,
                                                 chan_out=out_channels,
                                                 kernel_size=3)

        # Matching channel counts -> identity shortcut, nothing to build.
        if in_channels == out_channels:
            return
        if conv_shortcut:
            self.conv_shortcut = ContextParallelCausalConv3d(
                chan_in=in_channels,
                chan_out=out_channels,
                kernel_size=3,
            )
        else:
            self.nin_shortcut = Conv3d(in_channels,
                                       out_channels,
                                       kernel_size=1,
                                       stride=1,
                                       padding=0)

    def _normalize(self, norm, h, zq, clear_fake_cp_cache):
        """Apply ``norm`` to ``h``, passing ``zq`` and the cache flag only when ``zq`` is given."""
        if zq is None:
            return norm(h)
        return norm(h, zq, clear_fake_cp_cache=clear_fake_cp_cache)

    def forward(self, x, temb, zq=None, clear_fake_cp_cache=True):
        """Run the residual branch on ``x`` and add the shortcut.

        Args:
            x: input activation.
            temb: optional embedding; when not ``None`` it is passed through
                the nonlinearity and ``temb_proj`` and added to the branch.
            zq: optional conditioning tensor forwarded to both norms.
            clear_fake_cp_cache: forwarded to the norms (when ``zq`` is given)
                and as ``clear_cache`` to the causal convolutions.

        Returns:
            ``shortcut(x) + branch(x)``.
        """
        out = self._normalize(self.norm1, x, zq, clear_fake_cp_cache)
        out = self.conv1(nonlinearity(out), clear_cache=clear_fake_cp_cache)

        if temb is not None:
            # Append three singleton axes so the projection broadcasts
            # against the 5-D activation.
            out = out + self.temb_proj(
                nonlinearity(temb))[:, :, None, None, None]

        out = self._normalize(self.norm2, out, zq, clear_fake_cp_cache)
        out = self.conv2(self.dropout(nonlinearity(out)),
                         clear_cache=clear_fake_cp_cache)

        shortcut = x
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                shortcut = self.conv_shortcut(x,
                                              clear_cache=clear_fake_cp_cache)
            else:
                shortcut = self.nin_shortcut(x)
        return shortcut + out


class ContextParallelEncoder3D(nn.Module):
    """Causal 3D-convolution encoder producing (optionally doubled) latents.

    Structure: ``conv_in`` -> ``len(ch_mult)`` levels of ``num_res_blocks``
    ``ContextParallelResnetBlock3D`` each, where every level except the last
    ends with a ``DownSample3D`` (built with ``compress_time=True`` for the
    first ``int(log2(temporal_compress_times))`` levels) -> two mid resnet
    blocks -> ``norm_out`` -> nonlinearity -> ``conv_out``. ``conv_out``
    emits ``2 * z_channels`` channels when ``double_z`` is true, otherwise
    ``z_channels``.

    ``out_ch``, ``attn_resolutions`` and ``pad_mode`` are accepted for config
    compatibility but are not used; unknown keyword arguments are ignored.
    """

    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        double_z=True,
        pad_mode='first',
        temporal_compress_times=4,
        gather_norm=False,
        **ignore_kwargs,
    ):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # How many leading levels also downsample along time.
        self.temporal_compress_level = int(np.log2(temporal_compress_times))

        self.conv_in = ContextParallelCausalConv3d(chan_in=in_channels,
                                                   chan_out=self.ch,
                                                   kernel_size=3)

        def make_resblock(c_in, c_out):
            return ContextParallelResnetBlock3D(
                in_channels=c_in,
                out_channels=c_out,
                dropout=dropout,
                temb_channels=self.temb_ch,
                gather_norm=gather_norm,
            )

        # Level i reads ch * input_mults[i] channels and writes ch * ch_mult[i].
        input_mults = (1, ) + tuple(ch_mult)
        last_level = self.num_resolutions - 1
        self.down = nn.ModuleList()
        for level, mult in enumerate(ch_mult):
            block_in = ch * input_mults[level]
            block_out = ch * mult
            blocks = nn.ModuleList()
            for _ in range(num_res_blocks):
                blocks.append(make_resblock(block_in, block_out))
                block_in = block_out
            stage = nn.Module()
            stage.block = blocks
            stage.attn = nn.ModuleList()  # attention layers are never added
            if level != last_level:
                stage.downsample = DownSample3D(
                    block_in,
                    resamp_with_conv,
                    compress_time=level < self.temporal_compress_level)
            self.down.append(stage)

        self.mid = nn.Module()
        self.mid.block_1 = make_resblock(block_in, block_in)
        self.mid.block_2 = make_resblock(block_in, block_in)

        self.norm_out = Normalize(block_in, gather=gather_norm)

        latent_channels = 2 * z_channels if double_z else z_channels
        self.conv_out = ContextParallelCausalConv3d(chan_in=block_in,
                                                    chan_out=latent_channels,
                                                    kernel_size=3)

    def forward(self, x, **kwargs):
        """Encode ``x``; extra keyword arguments are accepted and ignored."""
        temb = None  # this encoder uses no timestep embedding

        h = self.conv_in(x)
        last_level = self.num_resolutions - 1
        for level, stage in enumerate(self.down):
            for idx, resblock in enumerate(stage.block):
                h = resblock(h, temb)
                if len(stage.attn) > 0:
                    h = stage.attn[idx](h)
            if level != last_level:
                h = stage.downsample(h)

        h = self.mid.block_2(self.mid.block_1(h, temb), temb)

        h = nonlinearity(self.norm_out(h))
        return self.conv_out(h)


class ContextParallelDecoder3D(nn.Module):
    """Causal 3D-convolution decoder whose norms are conditioned on ``z``.

    Structure: ``conv_in`` lifts ``z`` (``z_channels``) to
    ``ch * ch_mult[-1]`` channels -> two mid resnet blocks -> ``len(ch_mult)``
    up levels, run from the highest index down to 0, each holding
    ``num_res_blocks + 1`` resnet blocks; every level except 0 ends with an
    ``Upsample3D``, built with ``compress_time=True`` for levels
    ``>= len(ch_mult) - int(log2(temporal_compress_times))`` ->
    ``norm_out`` -> nonlinearity -> ``conv_out`` (``out_ch`` channels).
    All norms are ``Normalize3D`` and receive ``z`` itself as ``zq``.

    ``attn_resolutions`` and ``pad_mode`` are accepted but unused; extra
    keyword arguments (e.g. ``double_z`` from a shared config) are ignored.
    With ``give_pre_end=True`` the forward pass stops before ``norm_out``.
    """

    def __init__(
            self,
            *,
            ch,
            out_ch,
            ch_mult=(1, 2, 4, 8),
            num_res_blocks,
            attn_resolutions,
            dropout=0.0,
            resamp_with_conv=True,
            in_channels,
            resolution,
            z_channels,
            give_pre_end=False,
            zq_ch=None,
            add_conv=False,
            pad_mode='first',
            temporal_compress_times=4,
            gather_norm=False,
            **ignorekwargs,
    ):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        # How many of the deepest levels also upsample along time.
        self.temporal_compress_level = int(np.log2(temporal_compress_times))

        # Norm conditioning defaults to the latent's own channel count.
        zq_ch = z_channels if zq_ch is None else zq_ch

        top = self.num_resolutions - 1
        block_in = ch * ch_mult[top]
        curr_res = resolution // 2**top
        self.z_shape = (1, z_channels, curr_res, curr_res)
        print('Working with z of shape {} = {} dimensions.'.format(
            self.z_shape, np.prod(self.z_shape)))

        self.conv_in = ContextParallelCausalConv3d(chan_in=z_channels,
                                                   chan_out=block_in,
                                                   kernel_size=3)

        def make_resblock(c_in, c_out):
            return ContextParallelResnetBlock3D(
                in_channels=c_in,
                out_channels=c_out,
                temb_channels=self.temb_ch,
                dropout=dropout,
                zq_ch=zq_ch,
                add_conv=add_conv,
                normalization=Normalize3D,
                gather_norm=gather_norm,
            )

        self.mid = nn.Module()
        self.mid.block_1 = make_resblock(block_in, block_in)
        self.mid.block_2 = make_resblock(block_in, block_in)

        # Levels at or above this index get a time-expanding upsampler.
        first_time_level = self.num_resolutions - self.temporal_compress_level
        # Build from the deepest level down, prepending each stage so that
        # self.up[i] always holds level i.
        self.up = nn.ModuleList()
        for level in range(top, -1, -1):
            block_out = ch * ch_mult[level]
            blocks = nn.ModuleList()
            for _ in range(num_res_blocks + 1):
                blocks.append(make_resblock(block_in, block_out))
                block_in = block_out
            stage = nn.Module()
            stage.block = blocks
            stage.attn = nn.ModuleList()  # attention layers are never added
            if level != 0:
                stage.upsample = Upsample3D(
                    block_in,
                    with_conv=resamp_with_conv,
                    compress_time=level >= first_time_level)
            self.up.insert(0, stage)

        self.norm_out = Normalize3D(block_in,
                                    zq_ch,
                                    add_conv=add_conv,
                                    gather=gather_norm)

        self.conv_out = ContextParallelCausalConv3d(chan_in=block_in,
                                                    chan_out=out_ch,
                                                    kernel_size=3)

    def forward(self, z, clear_fake_cp_cache=True, **kwargs):
        """Decode latent ``z``; extra keyword arguments are ignored.

        ``clear_fake_cp_cache`` is forwarded to every resnet block and norm,
        and as ``clear_cache`` to ``conv_in``/``conv_out``. The shape of ``z``
        is recorded in ``self.last_z_shape``.
        """
        self.last_z_shape = z.shape
        temb = None  # this decoder uses no timestep embedding
        zq = z  # the latent also conditions every Normalize3D layer
        flag = clear_fake_cp_cache

        h = self.conv_in(z, clear_cache=flag)
        for mid_block in (self.mid.block_1, self.mid.block_2):
            h = mid_block(h, temb, zq, clear_fake_cp_cache=flag)

        for level in reversed(range(self.num_resolutions)):
            stage = self.up[level]
            for idx, resblock in enumerate(stage.block):
                h = resblock(h, temb, zq, clear_fake_cp_cache=flag)
                if len(stage.attn) > 0:
                    h = stage.attn[idx](h, zq)
            if level != 0:
                h = stage.upsample(h)

        if self.give_pre_end:
            return h

        h = self.norm_out(h, zq, clear_fake_cp_cache=flag)
        return self.conv_out(nonlinearity(h), clear_cache=flag)

    def get_last_layer(self):
        """Return the weight of the final convolution (``conv_out.conv``)."""
        return self.conv_out.conv.weight


class SlidingContextParallelEncoder3D(ContextParallelEncoder3D):
    """Encoder variant that threads ``clear_fake_cp_cache`` through the network.

    Construction is inherited unchanged. ``forward`` differs from the parent
    only in passing the flag as ``clear_cache`` to ``conv_in``/``conv_out``
    and as ``clear_fake_cp_cache`` to every resnet block.
    """

    def forward(self, x, clear_fake_cp_cache=True):
        temb = None  # no timestep embedding
        flag = clear_fake_cp_cache

        h = self.conv_in(x, clear_cache=flag)
        last_level = self.num_resolutions - 1
        for level, stage in enumerate(self.down):
            for idx, resblock in enumerate(stage.block):
                h = resblock(h, temb, clear_fake_cp_cache=flag)
                if len(stage.attn) > 0:
                    h = stage.attn[idx](h)
            if level != last_level:
                h = stage.downsample(h)

        for mid_block in (self.mid.block_1, self.mid.block_2):
            h = mid_block(h, temb, clear_fake_cp_cache=flag)

        h = nonlinearity(self.norm_out(h))
        return self.conv_out(h, clear_cache=flag)


class LatentUpscaler(ContextParallelDecoder3D):
    """Latent-to-latent upscaler built on :class:`ContextParallelDecoder3D`.

    The layers are exactly those the parent decoder builds for the given
    arguments. With the defaults (``ch_mult=(2, 4)``, ``out_ch=16`` equal to
    ``z_channels=16``) that is two levels and a single ``Upsample3D`` stage.
    After construction every ``Upsample3D`` gets ``compress_time = False``,
    presumably so that only the spatial axes are upsampled (NOTE(review):
    confirm against ``Upsample3D.forward``), and its ``scale_factor`` set to
    this module's ``scale_factor``.

    The forward pass is inherited from ``ContextParallelDecoder3D``.

    Args:
        scale_factor: stored as ``self.scale_factor`` and assigned to every
            ``Upsample3D.scale_factor``. Previously it was stored but each
            upsampler was hard-wired to 2; the default is still 2. Whether a
            value other than 2 is honoured depends on ``Upsample3D``.
        double_z: accepted for config compatibility; the decoder ignores it.
        Every other argument is forwarded unchanged to
        ``ContextParallelDecoder3D``.
    """

    def __init__(
            self,
            *,
            ch=128,
            out_ch=16,
            scale_factor=2,
            ch_mult=(2, 4),
            num_res_blocks=2,
            attn_resolutions=(),
            dropout=0.0,
            resamp_with_conv=True,
            in_channels=3,
            resolution=256,
            z_channels=16,
            give_pre_end=False,
            zq_ch=None,
            add_conv=False,
            pad_mode='first',
            temporal_compress_times=4,
            gather_norm=False,
            double_z=True):
        # Delegate to the decoder rather than duplicating its construction.
        # Submodules are created in the same order as before, so parameter
        # initialisation is unchanged.
        super().__init__(
            ch=ch,
            out_ch=out_ch,
            ch_mult=ch_mult,
            num_res_blocks=num_res_blocks,
            attn_resolutions=attn_resolutions,
            dropout=dropout,
            resamp_with_conv=resamp_with_conv,
            in_channels=in_channels,
            resolution=resolution,
            z_channels=z_channels,
            give_pre_end=give_pre_end,
            zq_ch=zq_ch,
            add_conv=add_conv,
            pad_mode=pad_mode,
            temporal_compress_times=temporal_compress_times,
            gather_norm=gather_norm,
            double_z=double_z,
        )
        self.scale_factor = scale_factor

        # Turn off time compression in every upsampler and apply the
        # requested scale factor (was hard-coded to 2, ignoring the argument).
        for module in self.modules():
            if isinstance(module, Upsample3D):
                module.compress_time = False
                module.scale_factor = scale_factor


# Mini smoke test for LatentUpscaler: push a random latent through the model
# and print the input and output shapes.
if __name__ == '__main__':
    import torch
    from einops import rearrange
    from torch import nn
    from torch.nn import functional as F

    # The model and its input must live on the same device. Previously only
    # the input was moved to CUDA while the model stayed on the CPU, so
    # model(x) failed with a device mismatch. Fall back to CPU when CUDA is
    # unavailable (slow at this input size).
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    x = torch.randn(2, 16, 49, 60, 90).to(device)
    # b,c,t,h,w
    print(x.shape)
    model = LatentUpscaler(
        ch=128,
        out_ch=16,
        ch_mult=(2, 4),
        num_res_blocks=2,
        attn_resolutions=[],
        dropout=0.0,
        resamp_with_conv=True,
        in_channels=3,
        resolution=256,
        z_channels=16,
        give_pre_end=False,
        zq_ch=None,
        add_conv=False,
        pad_mode='first',
        temporal_compress_times=4,
        gather_norm=False,
        double_z=True).to(device)
    print(model)
    # Shape check only: skip autograd bookkeeping for this large activation.
    with torch.no_grad():
        out = model(x)
    print(out.shape)