tensor.py 4.44 KB
Newer Older
1
import os
2
3
import re

Xinchi Huang's avatar
Xinchi Huang committed
4
import torch
5
from safetensors import safe_open
PengGao's avatar
PengGao committed
6

gushiqiao's avatar
gushiqiao committed
7
from lightx2v.utils.envs import *
PengGao's avatar
PengGao committed
8
from lightx2v.utils.registry_factory import TENSOR_REGISTER
9
from lightx2v_platform.base.global_var import AI_DEVICE
10
11
12
13


@TENSOR_REGISTER("Default")
class DefaultTensor:
    """Registered as "Default": one named tensor plus its host/device copies.

    Attributes populated by the methods below:
        tensor: the tensor currently used for compute (set by ``to_cuda``,
            ``to_cpu``, ``load_state_dict*`` or directly by ``load`` when the
            source tensor is not on the CPU).
        pin_tensor: page-locked CPU copy used as the offload staging area.
        tensor_cuda_buffer: preallocated ``AI_DEVICE`` tensor that
            ``load_state_dict`` copies per-block weights into.
    """

    # First ".<digits>" component of a tensor name; load_state_dict and
    # load_state_dict_from_disk replace it with the requested block index.
    _BLOCK_INDEX_PATTERN = re.compile(r"\.\d+")

    def __init__(self, tensor_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        """Record configuration only; no tensor data is loaded here.

        Args:
            tensor_name: Key of the tensor in the weight dict / safetensors file.
            create_cuda_buffer: If True, ``load`` places a copy on ``AI_DEVICE``
                in ``tensor_cuda_buffer``.
            create_cpu_buffer: If True (and ``create_cuda_buffer`` is False),
                ``load`` places a pinned CPU copy in ``pin_tensor``.
            lazy_load: Read the tensor from ``lazy_load_file`` on demand instead
                of from the weight dict passed to ``load``.
            lazy_load_file: Directory containing ``block_<i>.safetensors`` files.
            is_post_adapter: Retarget the tensor name with
                ``adapter_block_index`` instead of ``block_index``.
        """
        self.tensor_name = tensor_name
        self.lazy_load = lazy_load
        self.lazy_load_file = lazy_load_file
        self.is_post_adapter = is_post_adapter
        self.create_cuda_buffer = create_cuda_buffer
        self.create_cpu_buffer = create_cpu_buffer
        # dtype applied to tensors read lazily from disk.
        self.infer_dtype = GET_DTYPE()
        # Not used in this class; presumably consumed by subclasses — confirm.
        self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()

    def load(self, weight_dict):
        """Materialize the tensor according to the buffer flags.

        ``create_cuda_buffer`` takes precedence over ``create_cpu_buffer``; with
        neither flag set the tensor is taken from ``weight_dict`` (unless lazy).
        """
        if self.create_cuda_buffer:
            self._load_cuda_buffer(weight_dict)
        elif self.create_cpu_buffer:
            # weight_dict must be forwarded: the non-lazy path of _get_tensor
            # indexes into it (it used to receive None and crash).
            self._load_cpu_pin_buffer(weight_dict)
        else:
            self._load_default_tensors(weight_dict)

    def _load_default_tensors(self, weight_dict):
        """Take the tensor from ``weight_dict``; does nothing when lazy loading.

        CPU tensors are copied into pinned memory (``pin_tensor``) and removed
        from ``weight_dict``; tensors on any other device are kept as-is in
        ``tensor``.
        """
        if self.lazy_load:
            return
        tensor = weight_dict[self.tensor_name]
        if tensor.device.type == "cpu":
            self.pin_tensor = self._create_cpu_pin_tensor(tensor)
            # Drop the dict's reference, presumably so the unpinned original
            # can be freed once nothing else holds it.
            del weight_dict[self.tensor_name]
        else:
            self.tensor = tensor

    def _get_tensor(self, weight_dict=None, use_infer_dtype=False):
        """Return the raw tensor, from disk when lazy loading or from ``weight_dict``.

        ``use_infer_dtype`` casts to ``self.infer_dtype`` on the lazy path only;
        tensors taken from ``weight_dict`` are returned unchanged.
        """
        if not self.lazy_load:
            return weight_dict[self.tensor_name]
        # Assumes names shaped like "<prefix>.<block>.<...>": the second
        # dot-separated component selects the per-block file.
        block_id = self.tensor_name.split(".")[1]
        file_path = os.path.join(self.lazy_load_file, f"block_{block_id}.safetensors")
        with safe_open(file_path, framework="pt", device="cpu") as handle:
            tensor = handle.get_tensor(self.tensor_name)
            if use_infer_dtype:
                tensor = tensor.to(self.infer_dtype)
        return tensor

    def _create_cpu_pin_tensor(self, tensor):
        """Return a new page-locked CPU tensor holding a copy of ``tensor``."""
        pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=tensor.dtype)
        pin_tensor.copy_(tensor)
        return pin_tensor

    def _load_cuda_buffer(self, weight_dict):
        """Place a copy of the tensor on ``AI_DEVICE`` as ``tensor_cuda_buffer``."""
        tensor = self._get_tensor(weight_dict, use_infer_dtype=self.lazy_load)
        self.tensor_cuda_buffer = tensor.to(AI_DEVICE)

    def _load_cpu_pin_buffer(self, weight_dict=None):
        """Place a pinned CPU copy of the tensor in ``pin_tensor``.

        ``weight_dict`` is only needed when not lazy loading.
        """
        tensor = self._get_tensor(weight_dict, use_infer_dtype=True)
        self.pin_tensor = self._create_cpu_pin_tensor(tensor)

    def to_cuda(self, non_blocking=False):
        """Set ``tensor`` to a copy on ``AI_DEVICE``.

        The pinned copy is the source when one exists; otherwise the current
        ``tensor`` is moved (mirrors the fallback in ``to_cpu``).
        """
        if hasattr(self, "pin_tensor"):
            self.tensor = self.pin_tensor.to(AI_DEVICE, non_blocking=non_blocking)
        else:
            self.tensor = self.tensor.to(AI_DEVICE, non_blocking=non_blocking)

    def to_cpu(self, non_blocking=False):
        """Bring ``tensor`` back to the CPU.

        When a pinned copy exists the data is written into it (no new
        allocation) and ``tensor`` then aliases ``pin_tensor``.
        """
        if hasattr(self, "pin_tensor"):
            # copy_ writes in place and returns pin_tensor, which is already on the CPU.
            self.tensor = self.pin_tensor.copy_(self.tensor, non_blocking=non_blocking)
        else:
            self.tensor = self.tensor.to("cpu", non_blocking=non_blocking)

    def state_dict(self, destination=None):
        """Add ``{tensor_name: tensor}`` to ``destination`` (a new dict if None) and return it.

        The pinned CPU copy is preferred over ``tensor`` when it exists.
        """
        if destination is None:
            destination = {}
        destination[self.tensor_name] = self.pin_tensor if hasattr(self, "pin_tensor") else self.tensor
        return destination

    def _block_tensor_name(self, block_index, adapter_block_index=None):
        """Return ``tensor_name`` with its first ``.<digits>`` component set to the target index.

        Raises:
            ValueError: if this is a post-adapter tensor and
                ``adapter_block_index`` is None.
        """
        if self.is_post_adapter:
            if adapter_block_index is None:
                raise ValueError(f"adapter_block_index is required for post-adapter tensor {self.tensor_name!r}")
            index = adapter_block_index
        else:
            index = block_index
        # A callable replacement avoids any backslash/group interpretation.
        return self._BLOCK_INDEX_PATTERN.sub(lambda _match: f".{index}", self.tensor_name, count=1)

    def load_state_dict(self, destination, block_index, adapter_block_index=None):
        """Copy this tensor's entry for the given block into ``tensor_cuda_buffer``.

        Sets ``tensor`` to None when ``destination`` has no entry for the
        retargeted name. Requires ``tensor_cuda_buffer`` (``create_cuda_buffer``).
        """
        tensor_name = self._block_tensor_name(block_index, adapter_block_index)
        if tensor_name not in destination:
            self.tensor = None
            return
        self.tensor = self.tensor_cuda_buffer.copy_(destination[tensor_name], non_blocking=True)

    def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
        """Read this tensor for the given block from disk into the existing ``pin_tensor``.

        Note: permanently rewrites ``self.tensor_name`` to the retargeted name.
        """
        self.tensor_name = self._block_tensor_name(block_index, adapter_block_index)
        # NOTE(review): the file is chosen by block_index even for post-adapter
        # tensors (whose name uses adapter_block_index) — confirm intended.
        file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
        with safe_open(file_path, framework="pt", device="cpu") as handle:
            tensor = handle.get_tensor(self.tensor_name).to(self.infer_dtype)
            self.pin_tensor.copy_(tensor)