# make_paired_data.py
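"""Build paired HR/LQ video data for video super-resolution.

Loads high-resolution clips through an Open-Sora dataset, synthesizes degraded
low-quality counterparts with `degradation_process`, upsamples them 4x with
bicubic interpolation, and writes the ground-truth clips, degraded clips, and
their captions to `<save_path>/gt`, `<save_path>/lq`, and `<save_path>/text`.
"""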
import os

import colossalai
import torch
import torch.distributed as dist
import torch.nn.functional as F
from colossalai.cluster import DistCoordinator
from einops import rearrange
from mmengine.runner import set_random_seed
from tqdm import tqdm

from opensora.acceleration.parallel_states import set_sequence_parallel_group
from opensora.datasets import prepare_dataloader, save_sample
from opensora.datasets.high_order.degrade_video import degradation_process
from opensora.registry import DATASETS, build_module
from opensora.utils.config_utils import parse_configs
from opensora.utils.misc import to_torch_dtype

def main():
    # ======================================================
    # 1. cfg and init distributed env
    # ======================================================
    cfg = parse_configs(training=False)
    print(cfg)

    # init distributed
    if os.environ.get("WORLD_SIZE", None):
        use_dist = True
        colossalai.launch_from_torch({})
        coordinator = DistCoordinator()

        if coordinator.world_size > 1:
            set_sequence_parallel_group(dist.group.WORLD)
            enable_sequence_parallelism = True
        else:
            enable_sequence_parallelism = False
    else:
        use_dist = False
        enable_sequence_parallelism = False
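    # NOTE: enable_sequence_parallelism is set in both branches but is never
    # consumed later in this script; only the dataset/dataloader path is used
    # and no diffusion model is built.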

    # ======================================================
    # 2. runtime variables
    # ======================================================
    torch.set_grad_enabled(False)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = to_torch_dtype(cfg.dtype)
    set_random_seed(seed=cfg.seed)

    # ======================================================
    # 3. build dataset & dataloader
    # ======================================================
    cfg.dataset['data_path'] = cfg.data_path
    dataset = build_module(cfg.dataset, DATASETS)
    dataloader_args = dict(
        dataset=dataset,
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        seed=cfg.seed,
        shuffle=True,
        drop_last=True,
        pin_memory=False,
    )
    dataloader = prepare_dataloader(**dataloader_args)
    dataloader_iter = iter(dataloader)

    # ======================================================
    # 4. inference
    # ======================================================
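    # output layout: <save_path>/gt/<idx>, <save_path>/lq/<idx>, and
    # <save_path>/text/<idx>.txt share a running sample index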
    sample_idx = 0
    save_dir_gt = os.path.join(cfg.save_path, "gt")
    save_dir_lq = os.path.join(cfg.save_path, "lq")
    save_dir_txt = os.path.join(cfg.save_path, "text")
    os.makedirs(save_dir_gt, exist_ok=True)
    os.makedirs(save_dir_lq, exist_ok=True)
    os.makedirs(save_dir_txt, exist_ok=True)

    # 4.1. batch generation with progress bar
    for batch in tqdm(dataloader_iter, total=len(dataloader), desc="Generating paired samples"):
        x = batch.pop("video").to(device, dtype)  # [B, C, T, H, W], HR-video
        fps = batch.pop('fps')

        # generate LR-video
        lr, x = degradation_process(x)
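        # degradation_process returns two tensors: the degraded low-resolution
        # clip and the ground-truth clip (x is reassigned because the returned
        # GT may differ from the raw input)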
        _, _, t, _, _ = lr.shape
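        # F.interpolate only supports bicubic interpolation on 4D (N, C, H, W)
        # inputs, so frames are folded into the batch dimension for the 4x
        # upsample and unfolded back into the temporal dimension afterwards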
        lr = rearrange(F.interpolate(rearrange(lr, "B C T H W -> (B T) C H W"), scale_factor=4, mode='bicubic'), "(B T) C H W -> B C T H W", T=t)
        y = batch.pop("text")

        # 4.2. save samples (master rank only when running distributed)
        if not use_dist or coordinator.is_master():
            for i in range(lr.shape[0]):
                save_dir_gt_ = os.path.join(save_dir_gt, f"{sample_idx}")
                save_dir_lq_ = os.path.join(save_dir_lq, f"{sample_idx}")
                save_dir_txt_ = os.path.join(save_dir_txt, f"{sample_idx}.txt")

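                # dividing fps by the dataset's frame_interval yields the
                # effective frame rate of the temporally subsampled clip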
                save_sample(x[i], fps=fps / cfg.dataset['frame_interval'], save_path=save_dir_gt_)
                save_sample(lr[i], fps=fps / cfg.dataset['frame_interval'], save_path=save_dir_lq_)
                with open(save_dir_txt_, 'w', encoding='utf-8') as file:
                    file.write(y[i])

                sample_idx += 1

if __name__ == "__main__":
    main()
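
# Example launch (hypothetical config path; assumes parse_configs takes the config
# file as its first positional argument, as in other Open-Sora scripts, and that
# data_path, save_path, dataset, dtype, seed, batch_size and num_workers are set
# in that config). torchrun exports WORLD_SIZE, so the distributed branch runs:
#
#   torchrun --standalone --nproc_per_node=1 make_paired_data.py \
#       configs/make_paired_data_config.py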