tensor.py

import re

import torch

from lightx2v.utils.envs import *
from lightx2v.utils.registry_factory import TENSOR_REGISTER
from lightx2v_platform.base.global_var import AI_DEVICE


@TENSOR_REGISTER("Default")
class DefaultTensor:
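    """Wrapper around a single model weight tensor.

    Supports eager loading from an in-memory weight dict, lazy loading from a
    file handle exposing get_tensor(), and two optional buffer strategies: a
    persistent CUDA buffer (create_cuda_buffer) and a pinned-CPU staging
    buffer (create_cpu_buffer) for fast, optionally non-blocking host-device
    copies.
    """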
    def __init__(self, tensor_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
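        """Record the tensor's name and loading strategy; no tensor data is read here."""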
        self.tensor_name = tensor_name
        self.lazy_load = lazy_load
        self.lazy_load_file = lazy_load_file
        self.is_post_adapter = is_post_adapter
        self.create_cuda_buffer = create_cuda_buffer
        self.create_cpu_buffer = create_cpu_buffer
        self.infer_dtype = GET_DTYPE()
        self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()

    def load(self, weight_dict):
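        """Entry point: dispatch to the configured loading strategy."""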
        if self.create_cuda_buffer:
            self._load_cuda_buffer(weight_dict)
        elif self.create_cpu_buffer:
            self._load_cpu_pin_buffer()
        else:
            self._load_default_tensors(weight_dict)

    def _load_default_tensors(self, weight_dict):
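        """Eager path: copy CPU-resident weights into pinned memory and keep
        device-resident weights as-is. A no-op under lazy loading."""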
        if not self.lazy_load:
            device = weight_dict[self.tensor_name].device
            if device.type == "cpu":
                tensor = weight_dict[self.tensor_name]
                self.pin_tensor = self._create_cpu_pin_tensor(tensor)
                del weight_dict[self.tensor_name]
            else:
                self.tensor = weight_dict[self.tensor_name]

    def _get_tensor(self, weight_dict=None, use_infer_dtype=False):
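        """Fetch the raw tensor, either from the lazy-load file or the weight dict."""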
        if self.lazy_load:
            tensor = self.lazy_load_file.get_tensor(self.tensor_name)
            if use_infer_dtype:
                tensor = tensor.to(self.infer_dtype)
        else:
            tensor = weight_dict[self.tensor_name]
        return tensor

    def _create_cpu_pin_tensor(self, tensor):
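        """Copy `tensor` into freshly allocated page-locked (pinned) host memory,
        which permits asynchronous copies to the device."""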
        pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=tensor.dtype)
        pin_tensor.copy_(tensor)
        del tensor
        return pin_tensor

    def _load_cuda_buffer(self, weight_dict):
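        """Materialize a persistent device-side buffer for this weight."""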
        tensor = self._get_tensor(weight_dict, use_infer_dtype=self.lazy_load)
        self.tensor_cuda_buffer = tensor.to(AI_DEVICE)

    def _load_cpu_pin_buffer(self):
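        """Stage the weight in pinned host memory; requires lazy loading,
        since no weight dict is passed to _get_tensor here."""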
        tensor = self._get_tensor(use_infer_dtype=True)
        self.pin_tensor = self._create_cpu_pin_tensor(tensor)

    def to_cuda(self, non_blocking=False):
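        """Copy the pinned host tensor to the device (async if non_blocking)."""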
        self.tensor = self.pin_tensor.to(AI_DEVICE, non_blocking=non_blocking)

    def to_cpu(self, non_blocking=False):
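        """Move the active tensor back to host memory, reusing the pinned
        buffer when one exists (copy_ writes into it in place)."""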
        if hasattr(self, "pin_tensor"):
            self.tensor = self.pin_tensor.copy_(self.tensor, non_blocking=non_blocking).cpu()
        else:
            self.tensor = self.tensor.to("cpu", non_blocking=non_blocking)

    def state_dict(self, destination=None):
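        """Export this weight under its tensor name, preferring the pinned copy."""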
        if destination is None:
            destination = {}
        destination[self.tensor_name] = self.pin_tensor if hasattr(self, "pin_tensor") else self.tensor
        return destination

    def load_state_dict(self, destination, block_index, adapter_block_index=None):
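        """Refill the persistent CUDA buffer with the weights for a given block.

        The first ".<digits>" segment of the tensor name is rewritten to the
        target block index (or to adapter_block_index for post-adapter tensors).
        """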
        if self.is_post_adapter:
            assert adapter_block_index is not None
            tensor_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.tensor_name, count=1)
        else:
            tensor_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.tensor_name, count=1)
        if tensor_name not in destination:
            self.tensor = None
            return
        self.tensor = self.tensor_cuda_buffer.copy_(destination[tensor_name], non_blocking=True)

    def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
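        """Re-read this block's weight from disk into the existing pinned buffer,
        rewriting the stored tensor name to the target block index."""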
        if self.is_post_adapter:
            assert adapter_block_index is not None
            self.tensor_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.tensor_name, count=1)
        else:
            self.tensor_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.tensor_name, count=1)

        tensor = self.lazy_load_file.get_tensor(self.tensor_name).to(self.infer_dtype)
        self.pin_tensor = self.pin_tensor.copy_(tensor)
        del tensor