conv3d.py 3.64 KB
Newer Older
litzh's avatar
litzh committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from abc import ABCMeta, abstractmethod

import torch
from loguru import logger

from lightx2v.common.ops.utils import *
from lightx2v.utils.envs import *
from lightx2v.utils.registry_factory import CONV3D_WEIGHT_REGISTER
from lightx2v_platform.base.global_var import AI_DEVICE


class Conv3dWeightTemplate(metaclass=ABCMeta):
    """Abstract holder for the parameters of a 3-D convolution.

    Records the state-dict keys of the weight and bias, the convolution
    hyper-parameters, and a pair of additive "diff" tensors. The diffs start
    as zero scalars and may be replaced through ``register_diff``.
    Subclasses must implement ``load`` and ``apply``.
    """

    def __init__(self, weight_name, bias_name, stride=1, padding=0, dilation=1, groups=1, lora_prefix="diffusion_model.blocks"):
        # State-dict keys of the parameters owned by this object.
        self.weight_name, self.bias_name = weight_name, bias_name
        # Convolution hyper-parameters, stored for use by subclasses.
        self.stride, self.padding = stride, padding
        self.dilation, self.groups = dilation, groups
        self.config = {}
        # Prefix handed to build_lora_and_diff_names when deriving the diff keys.
        self.lora_prefix = lora_prefix
        # Flags initialised here; this class never changes them itself.
        self.has_lora_branch = False
        self.has_diff = False
        self._get_base_attrs_mapping()
        self._get_lora_attr_mapping()

    def _get_base_attrs_mapping(self):
        """Describe the base parameters as ``(state-dict key, attribute name, flag)`` triples.

        The flag is always ``False`` here. It is consumed by the helpers that
        receive ``base_attrs`` (e.g. ``create_default_tensors`` in subclasses).
        """
        self.base_attrs = [
            (self.weight_name, "weight", False),
            (self.bias_name, "bias", False),
        ]

    def _get_lora_attr_mapping(self):
        """Derive the diff state-dict keys and reset both diffs to zero scalars.

        ``lora_attrs`` maps each diff attribute to the name of the attribute
        holding its state-dict key.
        """
        # Only the last two names produced by the helper are used here.
        _, _, _, self.weight_diff_name, self.bias_diff_name = build_lora_and_diff_names(self.weight_name, self.lora_prefix)
        self.lora_attrs = {
            "weight_diff": "weight_diff_name",
            "bias_diff": "bias_diff_name",
        }
        dtype = GET_DTYPE()
        # 0-dim zeros: adding them leaves the base tensors' values unchanged.
        self.weight_diff = torch.tensor(0.0, dtype=dtype, device=AI_DEVICE)
        self.bias_diff = torch.tensor(0.0, dtype=dtype, device=AI_DEVICE)

    def register_diff(self, weight_dict):
        """Replace the weight/bias diffs with entries from ``weight_dict`` when their keys are present.

        Missing keys leave the corresponding diff untouched.
        """
        candidates = (
            ("weight_diff", self.weight_diff_name, self.weight_name),
            ("bias_diff", self.bias_diff_name, self.bias_name),
        )
        for diff_attr, diff_key, target_name in candidates:
            if diff_key not in weight_dict:
                continue
            setattr(self, diff_attr, weight_dict[diff_key])
            logger.debug(f"Register Diff to {target_name}")

    def set_config(self, config=None):
        """Store ``config`` as this object's config; ``None`` keeps the current one."""
        if config is None:
            return
        self.config = config

    @abstractmethod
    def load(self, weight_dict):
        """Populate the parameter tensors from ``weight_dict`` (implemented by subclasses)."""

    @abstractmethod
    def apply(self, input_tensor):
        """Run the convolution on ``input_tensor`` (implemented by subclasses)."""


@CONV3D_WEIGHT_REGISTER("Default")
class Conv3dWeight(Conv3dWeightTemplate):
    """Default 3-D convolution weight, evaluated with ``torch.nn.functional.conv3d``.

    The registered weight/bias diffs are added to the base tensors on every
    call to ``apply``.
    """

    def __init__(self, weight_name, bias_name, stride=1, padding=0, dilation=1, groups=1, lora_prefix="diffusion_model.blocks"):
        super().__init__(weight_name, bias_name, stride, padding, dilation, groups, lora_prefix)

    def load(self, weight_dict):
        """Create the weight/bias tensors from ``weight_dict`` via ``create_default_tensors``.

        Both mappings returned by the helper are kept. Entries are fetched
        with ``dict.get``, so any of them may end up ``None``.
        """
        device_tensors, pin_tensors = create_default_tensors(self.base_attrs, weight_dict)
        self.weight = device_tensors.get("weight")
        self.bias = device_tensors.get("bias")
        # NOTE(review): presumably pinned host copies used for offloading — confirm in create_default_tensors.
        self.pin_weight = pin_tensors.get("weight")
        self.pin_bias = pin_tensors.get("bias")

    def apply(self, input_tensor):
        """Convolve ``input_tensor`` with (weight + weight_diff) and, if present, (bias + bias_diff).

        Args:
            input_tensor: input to ``torch.nn.functional.conv3d``.

        Returns:
            The output tensor of ``conv3d`` with this object's stride, padding,
            dilation and groups.
        """
        weight = self.weight + self.weight_diff
        # `load` may leave the bias as None (bias-less conv); conv3d accepts bias=None,
        # whereas `None + tensor` would raise TypeError.
        bias = None if self.bias is None else self.bias + self.bias_diff
        return torch.nn.functional.conv3d(
            input_tensor,
            weight=weight,
            bias=bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )

    def to_cuda(self, non_blocking=False):
        """Move the base and diff attributes to the accelerator via ``move_attr_to_cuda``."""
        move_attr_to_cuda(self, self.base_attrs, self.lora_attrs, non_blocking)

    def to_cpu(self, non_blocking=False):
        """Move the base and diff attributes to host memory via ``move_attr_to_cpu``."""
        move_attr_to_cpu(self, self.base_attrs, self.lora_attrs, non_blocking)

    def state_dict(self, destination=None):
        """Return the state dict produced by the shared ``state_dict`` helper for this object."""
        return state_dict(self, self.base_attrs, self.lora_attrs, destination)

    def load_state_dict(self, destination, block_index, adapter_block_index=None):
        """Delegate to the shared ``load_state_dict`` helper and return its result."""
        return load_state_dict(self, self.base_attrs, self.lora_attrs, destination, block_index, adapter_block_index)