tensor.py
import torch

from lightx2v.utils.envs import *
from lightx2v.utils.registry_factory import TENSOR_REGISTER


@TENSOR_REGISTER("Default")
class DefaultTensor:
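    """Default wrapper for a single named weight tensor.

    Handles loading (eagerly from a preloaded weight dict, or lazily from a
    file handle exposing get_tensor()) and movement between host and device,
    optionally staging transfers through pinned memory.
    """
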
    def __init__(self, tensor_name, lazy_load=False, lazy_load_file=None):
        self.tensor_name = tensor_name
        self.lazy_load = lazy_load
        self.lazy_load_file = lazy_load_file
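        # Compute dtype for inference and the dtype reserved for numerically
        # sensitive layers, both resolved from lightx2v environment settings.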
        self.infer_dtype = GET_DTYPE()
        self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()

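    # Lazy path: read the named tensor from the lazy_load_file handle and cast
    # it to the inference dtype. The result is pinned for fast async transfers,
    # except while torch._dynamo is tracing, where pin_memory() is skipped.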
    def load_from_disk(self):
        if not torch._dynamo.is_compiling():
            self.tensor = self.lazy_load_file.get_tensor(self.tensor_name).to(self.infer_dtype).pin_memory()
        else:
            self.tensor = self.lazy_load_file.get_tensor(self.tensor_name).to(self.infer_dtype)

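    # Eager path: take the tensor straight from the preloaded weight dict and
    # pre-allocate a matching pinned host buffer for later offloading.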
    def load(self, weight_dict):
        if not self.lazy_load:
            self.tensor = weight_dict[self.tensor_name]
            self.pinned_tensor = torch.empty(self.tensor.shape, pin_memory=True, dtype=self.tensor.dtype)

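    # Release references to the weight tensor and its pinned staging buffer.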
    def clear(self):
        attrs = ["tensor", "pinned_tensor"]
        for attr in attrs:
            # Remove the attribute outright so that later hasattr() checks
            # (e.g. the pinned-buffer test in to_cpu) correctly report that
            # the buffer is gone.
            if hasattr(self, attr):
                delattr(self, attr)

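    # Tensor footprint in bytes (element count times bytes per element).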
    def _calculate_size(self):
        return self.tensor.numel() * self.tensor.element_size()

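    # Offload to host. With a pinned staging buffer, copy through it so that
    # non_blocking=True can overlap the device-to-host transfer with compute.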
    def to_cpu(self, non_blocking=False):
        if hasattr(self, "pinned_tensor"):
            self.tensor = self.pinned_tensor.copy_(self.tensor, non_blocking=non_blocking).cpu()
        else:
            self.tensor = self.tensor.to("cpu", non_blocking=non_blocking)

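    # Move the tensor onto the current CUDA device.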
    def to_cuda(self, non_blocking=False):
        self.tensor = self.tensor.cuda(non_blocking=non_blocking)

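    # Export {name: tensor} with a detached CPU copy, following the
    # destination-dict convention of nn.Module.state_dict().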
    def state_dict(self, destination=None):
        if destination is None:
            destination = {}
        destination[self.tensor_name] = self.tensor.cpu().detach().clone()
        return destination
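

if __name__ == "__main__":
    # Minimal usage sketch (an illustrative addition, not part of the module's
    # public API): drive DefaultTensor by hand with a toy weight dict. The
    # tensor name is hypothetical; pinned-memory allocation in load() needs a
    # CUDA-enabled build, hence the guard.
    if torch.cuda.is_available():
        name = "blocks.0.attn.qkv.weight"  # hypothetical weight name
        weights = {name: torch.randn(8, 8, dtype=torch.float16)}
        t = DefaultTensor(name)
        t.load(weights)                # eager path: tensor + pinned buffer
        print(t._calculate_size(), "bytes")
        t.to_cuda(non_blocking=True)   # host -> device
        t.to_cpu(non_blocking=True)    # device -> host via the pinned buffer
        print(list(t.state_dict().keys()))
        t.clear()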