import re

import torch

from lightx2v.utils.envs import *
from lightx2v.utils.registry_factory import TENSOR_REGISTER


@TENSOR_REGISTER("Default")
class DefaultTensor:
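    """Wrapper around a single named weight tensor.

    Handles loading from a weight dict or a lazy-load file handle, and moves
    the tensor between pinned host memory and the accelerator (cuda/mlu/npu).
    """
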
    def __init__(self, tensor_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        self.tensor_name = tensor_name
        self.lazy_load = lazy_load
        self.lazy_load_file = lazy_load_file
        self.is_post_adapter = is_post_adapter
        self.create_cuda_buffer = create_cuda_buffer
        self.infer_dtype = GET_DTYPE()
        self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()

    def load_from_disk(self):
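        # Fetch the named tensor from the lazy-load file handle. Host memory is
        # pinned for fast async H2D copies, but pinning is skipped while
        # torch.compile (dynamo) is tracing, where pin_memory() can graph-break.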
        if not torch._dynamo.is_compiling():
            self.tensor = self.lazy_load_file.get_tensor(self.tensor_name).to(self.infer_dtype).pin_memory()
        else:
            self.tensor = self.lazy_load_file.get_tensor(self.tensor_name).to(self.infer_dtype)

    def load(self, weight_dict):
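        # Eager path: with lazy_load=True this is a no-op and the weight is
        # fetched later through load_from_disk() / load_state_dict().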
        if not self.lazy_load:
            if self.create_cuda_buffer:
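                # Persistent device-side buffer, refilled in-place by load_state_dict().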
                self.tensor_cuda_buffer = weight_dict[self.tensor_name].cuda()
            else:
                device = weight_dict[self.tensor_name].device
                if device.type in ["cuda", "mlu", "npu"]:
                    self.tensor = weight_dict[self.tensor_name]
                elif device.type == "cpu":
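                    # Stage the CPU tensor into pinned memory so later to_cuda()
                    # copies can run asynchronously; drop the original to avoid
                    # holding two host copies.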
                    tensor_shape = weight_dict[self.tensor_name].shape
                    tensor_dtype = weight_dict[self.tensor_name].dtype
                    self.pin_tensor = torch.empty(tensor_shape, pin_memory=True, dtype=tensor_dtype)
                    self.pin_tensor.copy_(weight_dict[self.tensor_name])
                    del weight_dict[self.tensor_name]
                else:
                    raise ValueError(f"Unsupported device type: {device.type}, only 'cpu', 'cuda', 'mlu' and 'npu' are supported")

    def clear(self):
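        # Release references to the device tensor and pinned buffer so their
        # memory can be reclaimed.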
        attrs = ["tensor", "pin_tensor"]  # "pin_tensor" matches the attribute created in load()/to_cpu()
        for attr in attrs:
            if hasattr(self, attr):
                setattr(self, attr, None)

    def _calculate_size(self):
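        # Size of the managed tensor in bytes.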
        return self.tensor.numel() * self.tensor.element_size()

    def to_cuda(self, non_blocking=False):
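        # H2D copy from the pinned staging buffer; with non_blocking=True the
        # copy can overlap compute because the source memory is pinned.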
        self.tensor = self.pin_tensor.cuda(non_blocking=non_blocking)

    def to_cpu(self, non_blocking=False):
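        # Prefer copying back into the existing pinned buffer; fall back to a
        # plain device-to-host copy when no pinned buffer was created.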
        if hasattr(self, "pin_tensor"):
            self.tensor = self.pin_tensor.copy_(self.tensor, non_blocking=non_blocking).cpu()
        else:
            self.tensor = self.tensor.to("cpu", non_blocking=non_blocking)

    def state_dict(self, destination=None):
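        # Export the host-side (pinned) copy when it exists, else the live tensor.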
        if destination is None:
            destination = {}
        destination[self.tensor_name] = self.pin_tensor if hasattr(self, "pin_tensor") else self.tensor
        return destination

    def load_state_dict(self, destination, block_index, adapter_block_index=None):
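        # Rewrite the first ".<digits>" segment of the tensor name to the target
        # block index (post-adapter weights use adapter_block_index), then refill
        # the preallocated CUDA buffer in-place with an async copy.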
        if self.is_post_adapter:
            assert adapter_block_index is not None
            tensor_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.tensor_name, count=1)
        else:
            tensor_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.tensor_name, count=1)

        if tensor_name not in destination:
            self.tensor = None
            return
        self.tensor = self.tensor_cuda_buffer.copy_(destination[tensor_name], non_blocking=True)
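

# --- Usage sketch (illustrative only, not part of the module) ---
# A minimal flow, assuming `weight_dict` holds CPU tensors keyed by name and
# the env-driven dtype helpers are configured upstream; the tensor name below
# is hypothetical:
#
#   t = DefaultTensor("blocks.0.attn.qkv.weight")
#   t.load(weight_dict)            # stages the CPU tensor into pinned memory
#   t.to_cuda(non_blocking=True)   # async H2D copy from the pinned buffer
#   ...                            # run inference
#   t.to_cpu(non_blocking=True)    # copy back into the same pinned buffer
#   t.clear()                      # drop references so memory can be reclaimed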