tensor.py 4.71 KB
Newer Older
1
import os
2
import re
Gu Shiqiao's avatar
Gu Shiqiao committed
3
from pathlib import Path
4

Xinchi Huang's avatar
Xinchi Huang committed
5
import torch
6
from safetensors import safe_open
PengGao's avatar
PengGao committed
7

gushiqiao's avatar
gushiqiao committed
8
from lightx2v.utils.envs import *
PengGao's avatar
PengGao committed
9
from lightx2v.utils.registry_factory import TENSOR_REGISTER
10
from lightx2v_platform.base.global_var import AI_DEVICE
11
12
13
14


@TENSOR_REGISTER("Default")
class DefaultTensor:
    """Container for a single named tensor that can be moved between host and device.

    Depending on the constructor flags, ``load`` stores the tensor in one of
    three attributes:

    * ``tensor_cuda_buffer`` -- a copy on ``AI_DEVICE`` (``create_cuda_buffer``),
    * ``pin_tensor`` -- a pinned host copy (``create_cpu_buffer``, or a CPU
      tensor found in the weight dict),
    * ``tensor`` -- the tensor taken as-is from the weight dict when it is
      not on the CPU.

    Args:
        tensor_name: Key of the tensor in weight dicts / safetensors files.
            The first ``.<digits>`` component is treated as the block index.
        create_cuda_buffer: If true, ``load`` fills ``tensor_cuda_buffer``.
        create_cpu_buffer: If true (and ``create_cuda_buffer`` is false),
            ``load`` fills ``pin_tensor``.
        lazy_load: If true, tensors are read from ``lazy_load_file`` instead
            of from the weight dict.
        lazy_load_file: Either a safetensors file, or a directory that holds
            ``block_<index>.safetensors`` files.
        is_post_adapter: If true, ``adapter_block_index`` (instead of
            ``block_index``) is substituted into the tensor name.
    """

    def __init__(self, tensor_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        self.tensor_name = tensor_name
        self.lazy_load = lazy_load
        self.lazy_load_file = lazy_load_file
        self.is_post_adapter = is_post_adapter
        self.create_cuda_buffer = create_cuda_buffer
        self.create_cpu_buffer = create_cpu_buffer
        self.infer_dtype = GET_DTYPE()
        self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()

    def load(self, weight_dict):
        """Load the tensor according to the buffer flags given at construction.

        ``create_cuda_buffer`` takes precedence over ``create_cpu_buffer``;
        when neither is set the default loading path is used.
        """
        if self.create_cuda_buffer:
            self._load_cuda_buffer(weight_dict)
        elif self.create_cpu_buffer:
            # The CPU buffer path does not receive the weight dict; it reads
            # through ``_get_tensor`` without one.
            self._load_cpu_pin_buffer()
        else:
            self._load_default_tensors(weight_dict)

    def _load_default_tensors(self, weight_dict):
        """Take the tensor from ``weight_dict``; does nothing when lazy loading.

        CPU tensors are copied into pinned memory and removed from
        ``weight_dict``; tensors on any other device are referenced directly.
        """
        if self.lazy_load:
            return
        tensor = weight_dict[self.tensor_name]
        if tensor.device.type == "cpu":
            self.pin_tensor = self._create_cpu_pin_tensor(tensor)
            # Drop the dict's reference now that a pinned copy exists.
            del weight_dict[self.tensor_name]
        else:
            self.tensor = tensor

    def _resolve_lazy_load_path(self, block_index):
        """Return the safetensors file to read for ``block_index``.

        If ``lazy_load_file`` is an existing file it is used directly;
        otherwise it is treated as a directory containing
        ``block_<block_index>.safetensors``.
        """
        if Path(self.lazy_load_file).is_file():
            return self.lazy_load_file
        return os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")

    def _indexed_tensor_name(self, block_index, adapter_block_index=None):
        """Return ``tensor_name`` with its first ``.<digits>`` part replaced.

        Post-adapter tensors use ``adapter_block_index``; all others use
        ``block_index``.
        """
        if self.is_post_adapter:
            assert adapter_block_index is not None, "adapter_block_index is required for post-adapter tensors"
            index = adapter_block_index
        else:
            index = block_index
        # A callable replacement keeps the index text from being interpreted
        # as a regex replacement template.
        return re.sub(r"\.\d+", lambda m: f".{index}", self.tensor_name, count=1)

    def _get_tensor(self, weight_dict=None, use_infer_dtype=False):
        """Fetch the tensor from disk (lazy load) or from ``weight_dict``.

        Args:
            weight_dict: Mapping from tensor name to tensor. Required when
                ``lazy_load`` is false.
            use_infer_dtype: Only honoured on the lazy-load path; converts
                the loaded tensor to ``infer_dtype``.

        Raises:
            TypeError: If ``lazy_load`` is false and ``weight_dict`` is None.
        """
        if not self.lazy_load:
            if weight_dict is None:
                raise TypeError(f"weight_dict is required to load '{self.tensor_name}' when lazy_load is disabled")
            return weight_dict[self.tensor_name]

        # The second dot-separated component of the name is used as the block
        # index when ``lazy_load_file`` is a directory.
        lazy_load_file_path = self._resolve_lazy_load_path(self.tensor_name.split(".")[1])
        with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
            tensor = lazy_load_file.get_tensor(self.tensor_name)
            if use_infer_dtype:
                tensor = tensor.to(self.infer_dtype)
        return tensor

    def _create_cpu_pin_tensor(self, tensor):
        """Return a pinned-memory copy of ``tensor`` with the same shape and dtype."""
        pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=tensor.dtype)
        pin_tensor.copy_(tensor)
        return pin_tensor

    def _load_cuda_buffer(self, weight_dict):
        """Populate ``tensor_cuda_buffer`` with the tensor moved to ``AI_DEVICE``."""
        tensor = self._get_tensor(weight_dict, use_infer_dtype=self.lazy_load)
        self.tensor_cuda_buffer = tensor.to(AI_DEVICE)

    def _load_cpu_pin_buffer(self):
        """Populate ``pin_tensor`` from the lazy-load file.

        No weight dict is passed, so this path requires ``lazy_load``.
        """
        tensor = self._get_tensor(use_infer_dtype=True)
        self.pin_tensor = self._create_cpu_pin_tensor(tensor)

    def to_cuda(self, non_blocking=False):
        """Set ``self.tensor`` to a copy on ``AI_DEVICE``.

        The pinned host copy is used as the source when it exists; otherwise
        the current ``self.tensor`` is moved, mirroring the fallback that
        ``to_cpu`` and ``state_dict`` already apply.
        """
        source = self.pin_tensor if hasattr(self, "pin_tensor") else self.tensor
        self.tensor = source.to(AI_DEVICE, non_blocking=non_blocking)

    def to_cpu(self, non_blocking=False):
        """Move ``self.tensor`` back to the host.

        When a pinned buffer exists, the data is copied into it and
        ``self.tensor`` is rebound to that buffer; otherwise a plain
        ``.to("cpu")`` is performed.
        """
        if hasattr(self, "pin_tensor"):
            # NOTE(review): with non_blocking=True the copy may be
            # asynchronous; callers presumably synchronize before reading.
            self.tensor = self.pin_tensor.copy_(self.tensor, non_blocking=non_blocking).cpu()
        else:
            self.tensor = self.tensor.to("cpu", non_blocking=non_blocking)

    def state_dict(self, destination=None):
        """Store this tensor under ``tensor_name`` in ``destination`` and return it.

        The pinned host copy is preferred when present. A new dict is
        created when ``destination`` is None.
        """
        if destination is None:
            destination = {}
        destination[self.tensor_name] = self.pin_tensor if hasattr(self, "pin_tensor") else self.tensor
        return destination

    def load_state_dict(self, destination, block_index, adapter_block_index=None):
        """Copy the tensor for the given block from ``destination`` into the device buffer.

        The lookup key is ``tensor_name`` with its block index replaced;
        ``self.tensor_name`` itself is left unchanged. If the key is absent,
        ``self.tensor`` is set to None. Requires ``tensor_cuda_buffer`` to
        have been created by ``load``.
        """
        tensor_name = self._indexed_tensor_name(block_index, adapter_block_index)
        if tensor_name not in destination:
            self.tensor = None
            return
        self.tensor = self.tensor_cuda_buffer.copy_(destination[tensor_name], non_blocking=True)

    def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
        """Reload the pinned buffer from disk for the given block.

        Unlike ``load_state_dict``, this rewrites ``self.tensor_name`` to the
        new block index. The file is always selected by ``block_index``, even
        for post-adapter tensors. Requires ``pin_tensor`` to already exist.
        """
        self.tensor_name = self._indexed_tensor_name(block_index, adapter_block_index)
        lazy_load_file_path = self._resolve_lazy_load_path(block_index)
        with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
            tensor = lazy_load_file.get_tensor(self.tensor_name).to(self.infer_dtype)
            self.pin_tensor = self.pin_tensor.copy_(tensor)