# model.py
import torch
from safetensors import safe_open
import os
import glob
import math
import json

from lightx2v.models.networks.cogvideox.weights.pre_weights import CogvideoxPreWeights
from lightx2v.models.networks.cogvideox.weights.post_weights import CogvideoxPostWeights
from lightx2v.models.networks.cogvideox.weights.transformers_weights import CogvideoxTransformerWeights
from lightx2v.models.networks.cogvideox.infer.pre_infer import CogvideoxPreInfer
from lightx2v.models.networks.cogvideox.infer.transformer_infer import CogvideoxTransformerInfer
from lightx2v.models.networks.cogvideox.infer.post_infer import CogvideoxPostInfer


class CogvideoxModel:
    """Loads CogVideoX transformer weights and runs the pre / transformer / post inference stages."""

    pre_weight_class = CogvideoxPreWeights
    post_weight_class = CogvideoxPostWeights
    transformer_weight_class = CogvideoxTransformerWeights

    def __init__(self, config):
        self.config = config
        self.device = torch.device("cuda")
        self._init_infer_class()
        self._init_weights()
        self._init_infer()

    def _init_infer_class(self):
        self.pre_infer_class = CogvideoxPreInfer
        self.post_infer_class = CogvideoxPostInfer
        self.transformer_infer_class = CogvideoxTransformerInfer

    def _load_safetensor_to_dict(self, file_path):
        with safe_open(file_path, framework="pt") as f:
            tensor_dict = {key: f.get_tensor(key).to(torch.bfloat16).cuda() for key in f.keys()}
        return tensor_dict

    def _load_ckpt(self):
        safetensors_pattern = os.path.join(self.config.model_path, "transformer", "*.safetensors")
        safetensors_files = glob.glob(safetensors_pattern)

        if not safetensors_files:
            raise FileNotFoundError(f"No .safetensors files found in directory: {self.config.model_path}")
        weight_dict = {}
        for file_path in safetensors_files:
            file_weights = self._load_safetensor_to_dict(file_path)
            weight_dict.update(file_weights)
        return weight_dict

    def _init_weights(self):
        weight_dict = self._load_ckpt()
        with open(os.path.join(self.config.model_path, "transformer", "config.json"), "r") as f:
            transformer_cfg = json.load(f)
        # init weights
        self.pre_weight = self.pre_weight_class(transformer_cfg)
        self.transformer_weights = self.transformer_weight_class(transformer_cfg)
        self.post_weight = self.post_weight_class(transformer_cfg)
        # load weights
        self.pre_weight.load_weights(weight_dict)
        self.transformer_weights.load_weights(weight_dict)
        self.post_weight.load_weights(weight_dict)

    def _init_infer(self):
        self.pre_infer = self.pre_infer_class(self.config)
        self.transformer_infer = self.transformer_infer_class(self.config)
        self.post_infer = self.post_infer_class(self.config)

    def set_scheduler(self, scheduler):
        self.scheduler = scheduler
        self.transformer_infer.set_scheduler(scheduler)

    def to_cpu(self):
        self.pre_weight.to_cpu()
        self.post_weight.to_cpu()
        self.transformer_weights.to_cpu()

    def to_cuda(self):
        self.pre_weight.to_cuda()
        self.post_weight.to_cuda()
        self.transformer_weights.to_cuda()

    @torch.no_grad()
    def infer(self, inputs):
        # Current denoising timestep, text conditioning, and scaled latent input for this step.
        t = self.scheduler.timesteps[self.scheduler.step_index]
        text_encoder_output = inputs["text_encoder_output"]["context"]
        do_classifier_free_guidance = self.config.guidance_scale > 1.0
        latent_model_input = self.scheduler.latents
        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
        timestep = t.expand(latent_model_input.shape[0])

        hidden_states, encoder_hidden_states, emb, infer_shapes = self.pre_infer.infer(
            self.pre_weight,
            latent_model_input[0],
            timestep,
            text_encoder_output[0],
        )

        hidden_states, encoder_hidden_states = self.transformer_infer.infer(
            self.transformer_weights,
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            temb=emb,
        )

        noise_pred = self.post_infer.infer(self.post_weight, hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=emb, infer_shapes=infer_shapes)

        noise_pred = noise_pred.float()

        if self.config.use_dynamic_cfg:
            # Dynamic CFG: rescale the guidance strength with a cosine schedule over the denoising steps.
            self.scheduler.guidance_scale = 1 + self.scheduler.guidance_scale * ((1 - math.cos(math.pi * ((self.scheduler.infer_steps - t.item()) / self.scheduler.infer_steps) ** 5.0)) / 2)

        if do_classifier_free_guidance:
            # Split the batched prediction into unconditional and text-conditioned halves and combine them.
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + self.scheduler.guidance_scale * (noise_pred_text - noise_pred_uncond)

        self.scheduler.noise_pred = noise_pred
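

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module). The config
# fields and scheduler attributes referenced below (model_path, guidance_scale,
# use_dynamic_cfg, timesteps, step_index, latents, infer_steps, noise_pred, and
# a step() call that consumes noise_pred) are assumptions inferred from how they
# are accessed in this file, not a confirmed lightx2v API.
#
#   model = CogvideoxModel(config)
#   model.set_scheduler(scheduler)
#   for i in range(scheduler.infer_steps):
#       scheduler.step_index = i
#       model.infer({"text_encoder_output": {"context": text_context}})
#       scheduler.step()  # hypothetical: advance latents using scheduler.noise_pred
#   latents = scheduler.latents  # decode with the VAE afterwards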